Fix: model_dump(mode="json") falls back to __repr__ #20

Open
slevental wants to merge 1 commit into thinking-machines-lab:main from slevental:main

Conversation

@slevental

Fix: model_dump(mode="json") falls back to __repr__ for ModelInput chunks, burning 97% CPU on token formatting

Problem

Every sample_async() call serializes the SampleRequest via:

# tinker/resources/sampling.py:59
body=model_dump(request, exclude_unset=True, mode="json")

This triggers pydantic v2's model_dump(mode="json") on the full request,
including prompt: ModelInput → chunks: List[ModelInputChunk].

ModelInputChunk is defined as:

ModelInputChunk: TypeAlias = Annotated[
    Union[EncodedTextChunk, ImageAssetPointerChunk, ImageChunk],
    PropertyInfo(discriminator="type")
]

PropertyInfo(discriminator=...) is a tinker-internal annotation — pydantic v2
does not recognize it as a discriminator for serialization. When pydantic's JSON-mode
serializer encounters the union variants, it cannot resolve their serialization schema
and falls back to __repr__() on each chunk object.

EncodedTextChunk.__repr__ (inherited from pydantic's default) recursively formats
every field, including tokens: Sequence[int] — which typically contains
thousands of token IDs. This turns every single LLM sampling call into an
O(n_tokens) string-formatting operation under the GIL.
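To make the scaling concrete, here is a minimal stdlib-only sketch. It does not use pydantic or tinker: `VerboseChunk` and `CompactChunk` are hypothetical stand-ins, and the auto-generated dataclass `__repr__` is used only because, like pydantic's default, it formats every field.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class VerboseChunk:
    # Hypothetical stand-in: the generated __repr__ formats every field,
    # so its cost and output size grow with len(tokens).
    tokens: List[int]


@dataclass(repr=False)
class CompactChunk:
    # Hypothetical stand-in with a summary-only __repr__.
    tokens: List[int]

    def __repr__(self) -> str:
        # len() on a list is O(1), so this does not depend on token count.
        return f"CompactChunk(tokens=[{len(self.tokens)} tokens])"


for n in (8, 8192):
    tokens = list(range(n))
    verbose_len = len(repr(VerboseChunk(tokens)))
    compact_len = len(repr(CompactChunk(tokens)))
    print(f"n={n}: verbose repr {verbose_len} chars, compact repr {compact_len} chars")
```

The verbose repr grows with the number of tokens, while the compact one stays a few dozen characters regardless of prompt size.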

Impact

Profiling an RL training loop (8 tasks × 4 rollouts, multi-turn agent with ~8K
context tokens per turn) showed:

GIL: 94.00%, Active: 101.00%

  %Own   %Total  OwnTime  TotalTime  Function (filename)
 97.00%  98.00%   54.56s    55.23s   <genexpr> (pydantic/_internal/_repr.py)
  0.00% 100.00%    1.03s    57.21s   serialize_sequence_via_list (pydantic/_internal/_serializers.py)
  1.00%   1.00%   0.720s    0.720s   __repr_args__ (pydantic/main.py)
  2.00% 100.00%   0.450s    55.73s   __repr_str__ (pydantic/_internal/_repr.py)

97% of all CPU time was spent formatting token lists as strings that are
immediately discarded. The GIL was held 94% of the time, so all concurrent async
episodes were forced to run one at a time on a single core.
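A small sketch of why this hurts async concurrency, using only asyncio. The `episode` coroutine and its `repr(tokens)` call are illustrative stand-ins for a rollout that serializes a request; this is not tinker code. Synchronous work inside a coroutine gives the event loop no chance to switch tasks, so each episode's formatting runs to completion before the next one starts.

```python
import asyncio
from typing import List

events: List[str] = []


async def episode(name: str, tokens: List[int]) -> None:
    events.append(f"{name}:start")
    # Synchronous CPU-bound work (stand-in for the repr fallback).
    # There is no await here, so no other task can run in between.
    _ = repr(tokens)
    events.append(f"{name}:end")
    # First point at which the event loop may switch to another task.
    await asyncio.sleep(0)


async def main() -> None:
    tokens = list(range(8192))
    await asyncio.gather(*(episode(f"ep{i}", tokens) for i in range(3)))


asyncio.run(main())
print(events)
```

Each `start` is immediately followed by its own `end`: the episodes never overlap during the formatting step, no matter how many are in flight.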

Fix

Add cheap __repr__ overrides on EncodedTextChunk and ModelInput so the
fallback path costs at most O(n_chunks) instead of O(n_tokens):

# EncodedTextChunk
def __repr__(self) -> str:
    return f"EncodedTextChunk(tokens=[{len(self.tokens)} tokens])"

# ModelInput
def __repr__(self) -> str:
    total = sum(c.length for c in self.chunks)
    return f"ModelInput(chunks={len(self.chunks)}, total_tokens={total})"

This is a targeted symptom fix. The underlying issue is that ModelInputChunk
uses PropertyInfo(discriminator=...) instead of pydantic v2's native
pydantic.Discriminator, but changing the type alias has broader compatibility
implications.
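For reference, a minimal sketch of what a pydantic-native discriminated union looks like. It assumes pydantic v2 is installed; `DemoTextChunk`, `DemoImageChunk`, and `DemoInput` are hypothetical models for illustration, not the tinker types, and this is not a proposed change to ModelInputChunk.

```python
from typing import Annotated, List, Literal, Union

from pydantic import BaseModel, Field


class DemoTextChunk(BaseModel):
    # Hypothetical model; the Literal field acts as the discriminator tag.
    type: Literal["encoded_text"] = "encoded_text"
    tokens: List[int]


class DemoImageChunk(BaseModel):
    # Hypothetical model with a different tag value.
    type: Literal["image"] = "image"
    data: str


# Field(discriminator=...) is pydantic v2's own way to mark a tagged union.
DemoChunk = Annotated[
    Union[DemoTextChunk, DemoImageChunk],
    Field(discriminator="type"),
]


class DemoInput(BaseModel):
    chunks: List[DemoChunk]


inp = DemoInput(chunks=[DemoTextChunk(tokens=[1, 2, 3]), DemoImageChunk(data="abc")])
dumped = inp.model_dump(mode="json")
print(dumped)

# Round-trip: the "type" tag selects the right variant on validation.
restored = DemoInput.model_validate(dumped)
print(type(restored.chunks[0]).__name__, type(restored.chunks[1]).__name__)
```

With the union declared this way, pydantic knows each variant's schema up front and picks the variant from the `type` tag when validating.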

How to reproduce

import time
from tinker.types.encoded_text_chunk import EncodedTextChunk
from tinker.types.model_input import ModelInput
from tinker.types.sample_request import SampleRequest
from tinker.types.sampling_params import SamplingParams
from tinker._compat import model_dump

# Simulate a realistic prompt (~8K tokens)
tokens = list(range(8192))
prompt = ModelInput.from_ints(tokens)

request = SampleRequest(
    prompt=prompt,
    sampling_params=SamplingParams(max_tokens=1024),
)

start = time.perf_counter()
model_dump(request, exclude_unset=True, mode="json")
elapsed = time.perf_counter() - start
print(f"model_dump took {elapsed:.3f}s")
# Before fix: ~2-5s (scales with token count)
# After fix:  <0.001s

@slevental changed the title from "Fix: model_dump(mode="json") falls back to __repr__ for ModelInput chunks, burning 97% CPU on token formatting" to "Fix: model_dump(mode="json") falls back to __repr__" on Mar 22, 2026