From 449cc79531078324135111d1d128f6e4af8b9e70 Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Fri, 22 Nov 2024 18:07:06 +0800 Subject: [PATCH] feat(serial): Standardise de/serialisation Initial draft to standardise how we de/serialise objects across LLMeter: From endpoint configurations, to tokenizers, and test results. BREAKING CHANGES to various save & load methods to drive consistency. Not yet updated tests or worked through full scope of breaking change to communicate in release note. --- llmeter/endpoints/base.py | 105 ++++------------ llmeter/results.py | 142 ++++++++++++++++----- llmeter/runner.py | 100 ++++++++++----- llmeter/serde.py | 256 ++++++++++++++++++++++++++++++++++++++ llmeter/tokenizers.py | 160 +++++++++--------------- 5 files changed, 511 insertions(+), 252 deletions(-) create mode 100644 llmeter/serde.py diff --git a/llmeter/endpoints/base.py b/llmeter/endpoints/base.py index 4c7019d..850b289 100644 --- a/llmeter/endpoints/base.py +++ b/llmeter/endpoints/base.py @@ -2,14 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import importlib -import json -import os from abc import ABC, abstractmethod -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Dict, TypeVar from uuid import uuid4 -from upath import UPath as Path +from llmeter.serde import JSONableBase Self = TypeVar( "Self", bound="Endpoint" @@ -17,7 +15,7 @@ @dataclass -class InvocationResponse: +class InvocationResponse(JSONableBase): """ A class representing a invocation result. 
@@ -43,9 +41,6 @@ class InvocationResponse: time_per_output_token: float | None = None error: str | None = None - def to_json(self, **kwargs) -> str: - return json.dumps(self.__dict__, **kwargs) - @staticmethod def error_output( input_prompt: str | None = None, error=None, id: str | None = None @@ -64,11 +59,8 @@ def __repr__(self): def __str__(self): return self.to_json(indent=4, default=str) - def to_dict(self): - return asdict(self) - -class Endpoint(ABC): +class Endpoint(JSONableBase, ABC): """ An abstract base class for endpoint implementations. @@ -154,79 +146,26 @@ def __subclasshook__(cls, C): return True return NotImplemented - def save(self, output_path: os.PathLike) -> os.PathLike: - """ - Save the endpoint configuration to a JSON file. - - This method serializes the endpoint's configuration (excluding private attributes) - to a JSON file at the specified path. - - Args: - output_path (str | UPath): The path where the configuration file will be saved. - - Returns: - None - """ - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - with output_path.open("w") as f: - endpoint_conf = self.to_dict() - json.dump(endpoint_conf, f, indent=4, default=str) - return output_path - - def to_dict(self) -> Dict: - """ - Convert the endpoint configuration to a dictionary. - - Returns: - Dict: A dictionary representation of the endpoint configuration. - """ - endpoint_conf = {k: v for k, v in vars(self).items() if not k.startswith("_")} - endpoint_conf["endpoint_type"] = self.__class__.__name__ - return endpoint_conf - - @classmethod - def load_from_file(cls, input_path: os.PathLike) -> Self: - """ - Load an endpoint configuration from a JSON file. - - This class method reads a JSON file containing an endpoint configuration, - determines the appropriate endpoint class, and instantiates it with the - loaded configuration. - - Args: - input_path (str|UPath): The path to the JSON configuration file. 
- - Returns: - Endpoint: An instance of the appropriate endpoint class, initialized - with the configuration from the file. - """ - - input_path = Path(input_path) - with input_path.open("r") as f: - data = json.load(f) - endpoint_type = data.pop("endpoint_type") - endpoint_module = importlib.import_module("llmeter.endpoints") - endpoint_class = getattr(endpoint_module, endpoint_type) - return endpoint_class(**data) - @classmethod - def load(cls, endpoint_config: Dict) -> Self: # type: ignore - """ - Load an endpoint configuration from a dictionary. - - This class method reads a dictionary containing an endpoint configuration, - determines the appropriate endpoint class, and instantiates it with the - loaded configuration. + def from_dict( + cls: Self, raw: Dict, alt_classes: Dict[str, Self] = {}, **kwargs + ) -> Self: + """Load any built-in Endpoint type (or custom ones) from a plain JSON dictionary Args: - data (Dict): A dictionary containing the endpoint configuration. + raw: A plain Endpoint config dictionary, as created with `to_dict()`, `to_json`, etc. + alt_classes (Dict[str, type[Endpoint]]): A dictionary mapping additional custom type + names (beyond those in `llmeter.endpoints`, which are included automatically), to + corresponding classes for loading custom endpoint types. + **kwargs: Optional extra keyword arguments to pass to the constructor Returns: - Endpoint: An instance of the appropriate endpoint class, initialized - with the configuration from the dictionary. - """ - endpoint_type = endpoint_config.pop("endpoint_type") - endpoint_module = importlib.import_module("llmeter.endpoints") - endpoint_class = getattr(endpoint_module, endpoint_type) - return endpoint_class(**endpoint_config) + endpoint: An instance of the appropriate endpoint class, initialized with the + configuration from the file. 
+ """ + builtin_endpoint_types = importlib.import_module("llmeter.endpoints") + class_map = { + **builtin_endpoint_types, + **alt_classes, + } + return super().from_dict(raw, alt_classes=class_map, **kwargs) diff --git a/llmeter/results.py b/llmeter/results.py index ff802ee..c460efb 100644 --- a/llmeter/results.py +++ b/llmeter/results.py @@ -9,18 +9,22 @@ from math import isnan import os from statistics import StatisticsError, mean, median, quantiles -from typing import Dict, Sequence +from typing import Any, Callable, Dict, Sequence, TypeVar import jmespath from upath import UPath as Path from .endpoints import InvocationResponse +from .serde import from_dict_with_class_map, JSONableBase logger = logging.getLogger(__name__) +TResult = TypeVar("TResult", bound="Result") + + @dataclass -class Result: +class Result(JSONableBase): """Results of an experiment run.""" responses: list[InvocationResponse] @@ -65,40 +69,108 @@ def save(self, output_path: os.PathLike | str | None = None): The method uses the Universal Path (UPath) library for file operations, which provides a unified interface for working with different file systems. 
""" - try: output_path = Path(self.output_path or output_path) except TypeError: raise ValueError("No output path provided") - output_path.mkdir(parents=True, exist_ok=True) - - summary_path = output_path / "summary.json" + self.to_file( + output_path / "summary.json", + include_responses=False, + ) # Already creates output_path folders if needed stats_path = output_path / "stats.json" - with summary_path.open("w") as f, stats_path.open("w") as s: - f.write(self.to_json(indent=4)) + with stats_path.open("w") as s: s.write(json.dumps(self.stats, indent=4, default=str)) responses_path = output_path / "responses.jsonl" - if not responses_path.exists(): - with responses_path.open("w") as f: - for response in self.responses: - f.write(json.dumps(asdict(response)) + "\n") - - def to_json(self, **kwargs): - """Return the results as a JSON string.""" - summary = { - k: o for k, o in asdict(self).items() if k not in ["responses", "stats"] - } - return json.dumps(summary, default=str, **kwargs) + with responses_path.open("w") as f: + for response in self.responses: + f.write(json.dumps(asdict(response)) + "\n") - def to_dict(self, include_responses: bool = False): - """Return the results as a dictionary.""" - if include_responses: - return asdict(self) - return { - k: o for k, o in asdict(self).items() if k not in ["responses", "stats"] - } + @classmethod + def from_dict( + cls, raw: dict, alt_classes: dict[str, TResult] = {}, **kwargs + ) -> TResult: + """Load a run Result from a plain dict (with optional extra kwargs) + + Args: + raw: A plain Python dict, for example loaded from a JSON file + alt_classes: By default, this method will only use the class of the current object + (i.e. `cls`). If you want to support loading of subclasses, provide a mapping + from your raw dict's `_type` field to class, for example `{cls.__name__: cls}`. 
+            **kwargs: Optional extra keyword arguments to pass to the constructor
+        """
+        data = {**raw}
+        if "responses" in data:
+            data["responses"] = [
+                # Just in case users wanted to override InvocationResponse itself...
+                from_dict_with_class_map(
+                    resp,
+                    alt_classes={
+                        InvocationResponse.__name__: InvocationResponse,
+                        **alt_classes,
+                    },
+                )
+                for resp in data["responses"]
+            ]
+        else:
+            data["responses"] = []
+        data.pop("stats", None)  # Calculated property should be omitted
+        return super().from_dict(data, alt_classes, **kwargs)
+
+    def to_dict(self, include_responses: bool = False, **kwargs) -> dict:
+        """Save the results to a JSON-dumpable dictionary (with optional extra kwargs)
+
+        Args:
+            include_responses: Set `True` to include the `responses` and `stats` in the output.
+                By default, these will be omitted.
+            **kwargs: Additional fields to save in the output dictionary.
+        """
+        result = super().to_dict(**kwargs)
+        if not include_responses:
+            result.pop("responses", None)
+            result.pop("stats", None)
+        return result
+
+    def to_file(
+        self,
+        output_path: os.PathLike,
+        include_responses: bool = False,
+        indent: int | str | None = 4,
+        # NOTE(review): default must be a callable (per the annotation and the base
+        # class), not `{}` — json.dumps would raise "'dict' object is not callable":
+        default: Callable[[Any], Any] | None = str,
+        **kwargs,
+    ) -> Path:
+        """Save the Run Result to a (local or Cloud) JSON file
+
+        Args:
+            output_path: The path where the file will be saved.
+            include_responses: Set `True` to include the `responses` and `stats` in the output.
+                By default, these will be omitted.
+            indent: Optional indentation passed through to `to_json()` and therefore `json.dumps()`
+            default: Optional function to convert non-JSON-serializable objects to strings, passed
+                through to `to_json()` and therefore to `json.dumps()`
+            **kwargs: Optional extra keyword arguments to pass to `to_json()`
+
+        Returns:
+            output_path: Universal Path representation of the target file.
+ """ + return super().to_file( + output_path, + include_responses=include_responses, + indent=indent, + default=default, + **kwargs, + ) + + def to_json(self, include_responses: bool = False, **kwargs) -> str: + """Serialize the results to JSON, with optional kwargs passed through to `json.dumps()` + + Args: + include_responses: Set `True` to include the `responses` and `stats` in the output. + By default, these will be omitted. + **kwargs: Optional arguments to pass to `json.dumps()`. + """ + return json.dumps(self.to_dict(include_responses=include_responses), **kwargs) @classmethod def load(cls, result_path: os.PathLike | str): @@ -126,12 +198,20 @@ def load(cls, result_path: os.PathLike | str): """ result_path = Path(result_path) - responses_path = result_path / "responses.jsonl" summary_path = result_path / "summary.json" - with open(responses_path, "r") as f, summary_path.open("r") as g: - responses = [InvocationResponse(**json.loads(line)) for line in f if line] - summary = json.load(g) - return cls(responses=responses, **summary) + with summary_path.open("r") as f: + raw = json.load(f) + + responses_path = result_path / "responses.jsonl" + try: + with responses_path.open("r") as f: + raw["responses"] = [ + InvocationResponse.from_json(line) for line in f if line + ] + except FileNotFoundError: + logger.info("Result.load: No responses data found at %s", responses_path) + + return cls.from_dict(raw) @cached_property def stats(self) -> Dict: diff --git a/llmeter/runner.py b/llmeter/runner.py index a0ffbc1..5e315a2 100644 --- a/llmeter/runner.py +++ b/llmeter/runner.py @@ -8,10 +8,10 @@ import random import time from concurrent.futures import ThreadPoolExecutor -from dataclasses import InitVar, asdict, dataclass, replace +from dataclasses import InitVar, dataclass, replace from datetime import datetime from itertools import cycle -from typing import Any +from typing import Any, Dict, TypeVar from uuid import uuid4 from tqdm.auto import tqdm, trange @@ -20,6 
+20,7 @@ from .endpoints.base import Endpoint, InvocationResponse from .prompt_utils import load_payloads, save_payloads from .results import Result +from .serde import JSONableBase from .tokenizers import DummyTokenizer, Tokenizer logger = logging.getLogger(__name__) @@ -31,8 +32,11 @@ _disable_tqdm = True +TRunConfig = TypeVar("TRunConfig", bound="_RunConfig") + + @dataclass -class _RunConfig: +class _RunConfig(JSONableBase): """A class to store the configuration for a test run.""" endpoint: Endpoint | dict @@ -72,42 +76,70 @@ def __post_init__(self, disable_client_progress_bar, disable_clients_progress_ba else: self._tokenizer = self.tokenizer - def save( - self, - output_path: os.PathLike | str | None = None, - file_name: str = "run_config.json", - ): - """Save the configuration to a disk or could storage.""" - output_path = Path(output_path or self.output_path) - output_path.mkdir(parents=True, exist_ok=True) - if self.run_name: - output_path = output_path / self.run_name - run_config_path = output_path / file_name + @classmethod + def from_dict( + cls, raw: dict, alt_classes: Dict[str, TRunConfig] = {}, **kwargs + ) -> TRunConfig: + """Load a Run Configuration from a plain JSON-compatible dictionary. - config_copy = replace(self) + Args: + raw: A plain Python dict, for example loaded from a JSON file + alt_classes: Optional mapping from raw dictionary `_type` field to additional custom + classes to support during loading (for example if using custom Endpoint or + Tokenizer classes). Format like: `{cls.__name__: cls}`. 
+ **kwargs: Optional extra keyword arguments to pass to the RunConfig constructor + """ + data = {**raw} + data["endpoint"] = Endpoint.from_dict(raw["endpoint"], alt_classes=alt_classes) + data["tokenizer"] = Tokenizer.from_dict( + raw["tokenizer"], alt_classes=alt_classes + ) + return super().from_dict(data, alt_classes, **kwargs) - if self.payload and (not isinstance(self.payload, (os.PathLike, str))): - payload_path = save_payloads(self.payload, output_path) - config_copy.payload = payload_path + @classmethod + def from_file( + cls: TRunConfig, + input_path: os.PathLike, + file_name: str = "run_config.json", + **kwargs, + ) -> TRunConfig: + """Load a Run Configuration from a (local or Cloud) JSON file - if not isinstance(self.endpoint, dict): - config_copy.endpoint = self.endpoint.to_dict() + Args: + input_path: The base path where the file is stored. + file_name: The base name of the target file within `input_path` + **kwargs: Optional extra keyword arguments to pass to `from_dict()` + """ + # Superclass expects input_path to be the actual final JSON file path: + input_path = Path(input_path) + return super().from_file(input_path / file_name, **kwargs) - if not isinstance(self.tokenizer, dict): - config_copy.tokenizer = Tokenizer.to_dict(self.tokenizer) + def to_file( + self, + output_path: os.PathLike | None = None, + file_name: str = "run_config.json", + **kwargs, + ) -> Path: + """Save the Run Configuration to a (local or Cloud) JSON file - with run_config_path.open("w") as f: - f.write(json.dumps(asdict(config_copy), default=str, indent=4)) + Args: + output_path: The base folder where files should be stored. If `self.run_name` is set, + this will be appended as a subfolder. + file_name: The base name of the file to save the configuration to. 
+            **kwargs: Optional extra keyword arguments to pass to `to_json()`
 
-    @classmethod
-    def load(cls, load_path: Path | str, file_name: str = "run_config.json"):
-        """Load a configuration from a JSON file."""
-        load_path = Path(load_path)
-        with open(load_path / file_name) as f:
-            config = json.load(f)
-        config["endpoint"] = Endpoint.load(config["endpoint"])
-        config["tokenizer"] = Tokenizer.load(config["tokenizer"])
-        return cls(**config)
+        Returns:
+            output_path: Universal Path representation of the target file.
+        """
+        # Superclass expects output_path to be the actual final JSON file path:
+        if not output_path and not self.output_path:
+            raise ValueError(
+                "Can't save RunConfig: No output_path provided in to_file() or in config"
+            )
+        # NOTE(review): honour the explicit argument (the guard above validates it,
+        # but it was previously ignored in favour of self.output_path):
+        output_path = Path(output_path or self.output_path)
+        if self.run_name:
+            output_path = output_path / self.run_name
+        return super().to_file(output_path / file_name, **kwargs)
 
 
 @dataclass
@@ -397,7 +429,7 @@ async def run(
         assert isinstance(run_config.payload, list)
         assert isinstance(run_config.run_name, str)
         if run_config.output_path:
-            run_config.save()
+            run_config.to_file()
 
         result = self._initialize_result(run_config)
 
diff --git a/llmeter/serde.py b/llmeter/serde.py
new file mode 100644
index 0000000..a6663bf
--- /dev/null
+++ b/llmeter/serde.py
@@ -0,0 +1,256 @@
+"""Common (De/re)serialization interfaces for saving objects to file and loading them back
+
+Our patterns for this are more verbose and superclass-based than ideal (improvements welcome!),
+but there are some challenges that make this more complicated than you might expect:
+
+1. To help builders, we're trying to keep type hints as accurate as possible; and avoid
+   introducing extra heavy dependencies without good reason.
+2. `dataclasses.asdict()` doesn't recursively convert dataclass fields to dictionaries, so gives
+   non-JSONable results for nested dataclasses.
+3.
(At least our targeted min version of) Python doesn't support type intersections [See + https://github.com/python/typing/issues/213]. One consequence of this is that there's no nice + way of type hinting a class decorator that adds (e.g. to_json) methods [See + https://discuss.python.org/t/how-to-type-hint-a-class-decorator/63010/7]. Another is that + although we *can* define an abstract interface as a "Protocol", it's not very useful except in + settings where you *only* need to assert that one interface at a time. +""" + +# Python Built-Ins: +from dataclasses import asdict, is_dataclass +from datetime import date, datetime, time +import json +import logging +import os +from typing import Any, Callable, Dict, Protocol, Type, TypeVar + +# External Dependencies: +from upath import UPath as Path + + +_TJSONDictable = TypeVar("_TJSONDictable", bound="IJSONDictable") + + +logger = logging.getLogger(__name__) + + +class IJSONDictable(Protocol): + """Typing protocol for supporting copying to, and initializing from, JSON-able dictionaries""" + + @classmethod + def from_dict(cls: Type[_TJSONDictable], raw: dict, **kwargs) -> _TJSONDictable: + """Initialize an instance of this class from a plain dict (with optional extra kwargs) + + Args: + raw: A plain Python dict, for example loaded from a JSON file + **kwargs: Optional extra keyword arguments to pass to the constructor + """ + ... + + def to_dict(self, **kwargs) -> dict: + """Save the state of the object to a JSON-dumpable dictionary (with optional extra kwargs) + + Implementers of this method should ensure that the returned dict is fully JSON-compatible: + Mapping any child fields from Python classes to dicts if necessary, avoiding any circular + references, etc. + """ + ... 
+ + +_TJSONStringable = TypeVar("_TJSONStringable", bound="IJSONStringable") + + +class IJSONStringable(Protocol): + """Typing for an object that supports serializing to JSON and loading from JSON (strings)""" + + @classmethod + def from_json( + cls: Type[_TJSONStringable], json_string: str, **kwargs + ) -> _TJSONStringable: + """Initialize an instance of this class from a JSON string (with optional extra kwargs) + + Args: + json_string: A string containing valid JSON data + **kwargs: Optional extra keyword arguments to pass to the class constructor + """ + ... + + def to_json(self, **kwargs) -> str: + """Serialize this object to JSON, with optional kwargs passed through to `json.dumps()`""" + ... + + +class IJSONable(IJSONDictable, IJSONStringable): + pass + + +def is_dataclass_instance(obj): + """Check whether `obj` is an instance of any dataclass + + See: https://docs.python.org/3/library/dataclasses.html#dataclasses.is_dataclass + """ + return is_dataclass(obj) and not isinstance(obj, type) + + +def to_dict_recursive_generic(obj: object, **kwargs) -> dict: + """Convert a vaguely dataclass-like object (with maybe IJSONable fields) to a JSON-ready dict + + The output dict is augmented with `_type` storing the `__class__.__name__` of the provided + `obj`. 
+ + Args: + obj: The object to convert + **kwargs: Optional extra parameters to insert in the output dictionary + """ + result = { + "_type": obj.__class__.__name__, + **(asdict(obj) if is_dataclass_instance(obj) else obj.__dict__), + **kwargs, + } + for k, v in result.items(): + if hasattr(v, "to_dict"): + result[k] = v.to_dict() + elif isinstance(v, (list, tuple)): + result[k] = [to_dict_recursive_generic(item) for item in v] + elif isinstance(v, (date, datetime, time)): + result[k] = v.isoformat() + return result + + +TFromDict = TypeVar("TFromDict") + + +def from_dict_with_class(raw: dict, cls: Type[TFromDict], **kwargs) -> TFromDict: + """Initialize an instance of a class from a plain dict (with optional extra kwargs) + + If the input dictionary contains a `_type` key, and this doesn't match the provided + `cls.__name__`, a warning will be logged. + + Args: + raw: A plain Python dict, for example loaded from a JSON file + cls: The class to create an instance of + **kwargs: Optional extra keyword arguments to pass to the constructor + """ + raw_args = {k: v for k, v in raw.items()} + raw_type = raw_args.pop("_type", None) + if raw_type is not None and raw_type != cls.__name__: + logger.warning( + "from_dict: _type '%s' doesn't match class '%s' being loaded. 
%s" + % (raw_type, cls.__name__, raw) + ) + return cls(**raw, **kwargs) + + +def from_dict_with_class_map( + raw: dict, class_map: Dict[str, Type[TFromDict]], **kwargs +) -> TFromDict: + """Initialize an instance of a class from a plain dict (with optional extra kwargs) + + Args: + raw: A plain Python dict which must contain a `_type` key + classes: A mapping from `_type` string to class to create an instance of + **kwargs: Optional extra keyword arguments to pass to the constructor + """ + if "_type" not in raw: + raise ValueError("from_dict_with_class_map: No _type in raw dict: %s" % raw) + if raw["_type"] not in class_map: + raise ValueError( + "Object _type '%s' not found in provided class_map %s" + % (raw["_type"], class_map) + ) + return from_dict_with_class(raw, class_map[raw["_type"]], **kwargs) + + +TJSONable = TypeVar("TJSONable", bound="JSONableBase") + + +class JSONableBase: + """An *optional* base class for speeding up implementation of JSONable objects + + Don't check `isinstance` of this class, because not all JSONable objects are guaranteed to + inherit from it: Use `IJSONable` instead. + """ + + @classmethod + def from_dict( + cls: Type[TJSONable], + raw: dict, + alt_classes: Dict[str, TJSONable] = {}, + **kwargs, + ) -> TJSONable: + """Initialize an instance of this class from a plain dict (with optional extra kwargs) + + Args: + raw: A plain Python dict, for example loaded from a JSON file + alt_classes: By default, this method will only use the class of the current object + (i.e. `cls`). If you want to support loading of subclasses, provide a mapping + from your raw dict's `_type` field to class, for example `{cls.__name__: cls}`. 
+ **kwargs: Optional extra keyword arguments to pass to the constructor + """ + if alt_classes: + return from_dict_with_class_map( + raw=raw, + class_map={cls.__name__: cls, **alt_classes}, + **kwargs, + ) + else: + return from_dict_with_class(raw=raw, cls=cls, **kwargs) + + @classmethod + def from_file(cls: Type[TJSONable], input_path: os.PathLike, **kwargs) -> TJSONable: + """Initialize an instance of this class from a (local or Cloud) JSON file + + Args: + input_path: The path to the JSON data file. + **kwargs: Optional extra keyword arguments to pass to `from_dict()` + """ + input_path = Path(input_path) + with input_path.open("r") as f: + return cls.from_json(f.read(), **kwargs) + + @classmethod + def from_json(cls: Type[TJSONable], json_string: str, **kwargs) -> TJSONable: + """Initialize an instance of this class from a JSON string (with optional extra kwargs) + + Args: + json_string: A string containing valid JSON data + **kwargs: Optional extra keyword arguments to pass to `from_dict()`` + """ + return cls.from_dict(json.loads(json_string), **kwargs) + + def to_dict(self, **kwargs) -> dict: + """Save the state of the object to a JSON-dumpable dictionary (with optional extra kwargs) + + Implementers of this method should ensure that the returned dict is fully JSON-compatible: + Mapping any child fields from Python classes to dicts if necessary, avoiding any circular + references, etc. + """ + return to_dict_recursive_generic(self, **kwargs) + + def to_file( + self, + output_path: os.PathLike, + indent: int | str | None = 4, + default: Callable[[Any], Any] | None = str, + **kwargs, + ) -> Path: + """Save the state of the object to a (local or Cloud) JSON file + + Args: + output_path: The path where the configuration file will be saved. 
+ indent: Optional indentation passed through to `to_json()` and therefore `json.dumps()` + default: Optional function to convert non-JSON-serializable objects to strings, passed + through to `to_json()` and therefore to `json.dumps()` + **kwargs: Optional extra keyword arguments to pass to `to_json()` + + Returns: + output_path: Universal Path representation of the target file. + """ + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w") as f: + f.write(self.to_json(indent=indent, default=default, **kwargs)) + return output_path + + def to_json(self, **kwargs) -> str: + """Serialize this object to JSON, with optional kwargs passed through to `json.dumps()`""" + return json.dumps(self.to_dict(), **kwargs) diff --git a/llmeter/tokenizers.py b/llmeter/tokenizers.py index 3d78b5a..2dc1ac1 100644 --- a/llmeter/tokenizers.py +++ b/llmeter/tokenizers.py @@ -2,16 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Any +from importlib import import_module +from typing import Dict, TypeVar -from upath import UPath -import json +from .serde import JSONableBase +TTokenizer = TypeVar("TTokenizer", bound="Tokenizer") -class Tokenizer(ABC): - def __init__(self, *args, **kwargs): - pass +class Tokenizer(JSONableBase, ABC): @abstractmethod def encode(self, text: str): raise NotImplementedError @@ -40,136 +39,89 @@ def __subclasscheck__(cls, subclass): return False return True - @staticmethod - def load_from_file(tokenizer_path: UPath | None): - """ - Loads a tokenizer from a file. - - Args: - tokenizer_path (UPath): The path to the serialized tokenizer file. - - Returns: - Tokenizer: The loaded tokenizer. 
- """ - if tokenizer_path is None: - return DummyTokenizer() - with open(tokenizer_path, "r") as f: - tokenizer_info = json.load(f) - - return _load_tokenizer_from_info(tokenizer_info) - - @staticmethod - def load(tokenizer_info: dict): - """ - Loads a tokenizer from a dictionary. - - Args: - tokenizer_info (Dict): The tokenizer information to load. - - Returns: - Tokenizer: The loaded tokenizer. - """ - return _load_tokenizer_from_info(tokenizer_info) + # Load any built-in Endpoint type (or custom ones) from a plain JSON dictionary - @staticmethod - def to_dict(tokenizer: Any) -> dict: - """ - Serializes a tokenizer to a dictionary. + @classmethod + def from_dict( + cls, raw: dict, alt_classes: Dict[str, TTokenizer] = {}, **kwargs + ) -> TTokenizer: + """Load any built-in Tokenizer type (or custom ones) from a plain JSON dictionary Args: - tokenizer (Tokenizer): The tokenizer to serialize. + raw: A plain Tokenizer config dictionary, as created with `to_dict()`, `to_json`, etc. + alt_classes (Dict[str, type[Endpoint]]): A dictionary mapping additional custom type + names (beyond those in `llmeter.tokenizers`, which are included automatically), to + corresponding classes for loading custom endpoint types. + **kwargs: Optional extra keyword arguments to pass to the constructor Returns: - Dict: The serialized tokenizer. + endpoint: An instance of the appropriate endpoint class, initialized with the + configuration from the file. """ - return _to_dict(tokenizer) - - -def _to_dict(tokenizer: Any) -> dict: - """ - Serializes a tokenizer to a dictionary. + builtin_endpoint_types = import_module("llmeter.tokenizers") + class_map = { + **builtin_endpoint_types, + **alt_classes, + } + return super().from_dict(raw, alt_classes=class_map, **kwargs) - Args: - tokenizer (Tokenizer): The tokenizer to serialize. - Returns: - Dict: The serialized tokenizer. 
- """ - if tokenizer.__module__.split(".")[0] == "transformers": - return {"tokenizer_module": "transformers", "name": tokenizer.name_or_path} - - if tokenizer.__module__.split(".")[0] == "tiktoken": - return {"tokenizer_module": "tiktoken", "name": tokenizer.name} - - if tokenizer.__module__.split(".")[0] == "llmeter": - return {"tokenizer_module": "llmeter"} - - raise ValueError(f"Unknown tokenizer module: {tokenizer.__module__}") - - -def save_tokenizer(tokenizer: Any, output_path: UPath | str): +class DummyTokenizer(Tokenizer): """ - Save a tokenizer information to a file. + A dummy tokenizer that splits the input text on whitespace and returns the tokens as is. - Args: - tokenizer (Tokenizer): The tokenizer to serialize. - output_path (UPath): The path to save the serialized tokenizer to. + This tokenizer will generally under-estimate token counts in English and latin languages (where + words comprise more than one token on average), and will give very poor results for languages + where the whitespace/"word" heuristic doesn't work well (e.g. Chinese, Japanese, Korean, Thai). - Returns: - UPath: The path to the serialized tokenizer file. + However, it requires no dependencies beyond the Python standard library, using `str.split()` """ - tokenizer_info = _to_dict(tokenizer) - output_path = UPath(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - with open(output_path, "w") as f: - json.dump(tokenizer_info, f) + def __init__(self, *args, **kwargs): + pass - return output_path + def encode(self, text: str): + return [k for k in text.split()] + def decode(self, tokens: list[str]): + return " ".join(k for k in tokens) -def _load_tokenizer_from_info(tokenizer_info: dict) -> Tokenizer: - """ - Loads a tokenizer from a file. - Args: - tokenizer_info (Dict): The tokenizer information to load. +class TikTokenTokenizer(Tokenizer): + """A tokenizer based on TikToken get_encoding - Returns: - Tokenizer: The loaded tokenizer. 
+    (Note: You must have the `tiktoken` library installed to use this in LLMeter)
     """
-    if tokenizer_info["tokenizer_module"] == "transformers":
-        from transformers import AutoTokenizer
 
-        return AutoTokenizer.from_pretrained(tokenizer_info["name"])  # type: ignore
+    name: str
 
-    if tokenizer_info["tokenizer_module"] == "tiktoken":
+    def __init__(self, name: str):
         from tiktoken import get_encoding
 
-        return get_encoding(tokenizer_info["name"])  # type: ignore
+        # NOTE(review): persist the declared `name` field, otherwise serde's
+        # `obj.__dict__`-based to_dict() can never round-trip this tokenizer:
+        self.name = name
+        self._tokenizer = get_encoding(name)
+
+    def encode(self, text: str):
+        return self._tokenizer.encode(text)
 
-    if tokenizer_info["tokenizer_module"] == "llmeter":
-        return DummyTokenizer()
+    def decode(self, tokens: list[int]):  # tiktoken decodes integer token ids
+        return self._tokenizer.decode(tokens)
 
-    raise ValueError(f"Unknown tokenizer module: {tokenizer_info['tokenizer_module']}")
 
+class TransformersAutoTokenizer(Tokenizer):
+    """A tokenizer based on Hugging Face Transformers' AutoTokenizer
 
-class DummyTokenizer(Tokenizer):
+    (Note: You must have the `transformers` library installed to use this in LLMeter)
     """
-    A dummy tokenizer that splits the input text on whitespace and returns the tokens as is.
+
+    name_or_path: str
 
-    This tokenizer will generally under-estimate token counts in English and latin languages (where
-    words comprise more than one token on average), and will give very poor results for languages
-    where the whitespace/"word" heuristic doesn't work well (e.g. Chinese, Japanese, Korean, Thai).
+    def __init__(self, name_or_path: str):
+        from transformers import AutoTokenizer
 
-    However, it requires no dependencies beyond the Python standard library, using `str.split()`
-    """
+        # NOTE(review): persist the declared `name_or_path` field for serialization:
+        self.name_or_path = name_or_path
+        self._tokenizer = AutoTokenizer.from_pretrained(name_or_path)
 
-    def __init__(self, *args, **kwargs):
-        pass
 
     def encode(self, text: str):
-        return [k for k in text.split()]
+        return self._tokenizer.encode(text)
 
     def decode(self, tokens: list[str]):
-        return " ".join(k for k in tokens)
+        return self._tokenizer.decode(tokens)