Draft
118 commits
d5b8288
Checkpoint Dripper Common Crawl integration
VibhuJawa Jun 9, 2026
3435ced
Add large precomputed layout group splitting and baseline comparison …
VibhuJawa Jun 10, 2026
1790810
Move layout diagnostic scripts into tutorial directory
VibhuJawa Jun 10, 2026
2a8d7de
Fix audit quick wins: F1 accounting, determinism, safety, metrics
VibhuJawa Jun 10, 2026
cd5c906
Add LAYOUT_PRECOMPUTED_MANIFEST support to bypass per-shard DBSCAN
VibhuJawa Jun 10, 2026
38c77d5
Add deferred layout propagation to move CPU work off H100 critical path
VibhuJawa Jun 10, 2026
107a618
Wire defer_propagation, fix singleton shards, add dynamic max tokens
VibhuJawa Jun 10, 2026
8863426
Fix deferred propagation: store mapping_json on representative row
VibhuJawa Jun 10, 2026
14ad7a0
Add Dripper layout clustering tutorial notebook
VibhuJawa Jun 10, 2026
5e490b4
Use dc (data-copier) nodes for all rsync transfers
VibhuJawa Jun 10, 2026
b3e4168
Fix notebook: read_parquet_safe() bypasses ParquetDataset buffer issu…
VibhuJawa Jun 10, 2026
6d2d129
Fix notebook: use correct MinerU-HTML bindings API
VibhuJawa Jun 10, 2026
0074607
Add pipeline timing analysis doc
VibhuJawa Jun 11, 2026
a12cf85
Add comparison notebook: clustering pipeline vs standalone Dripper
VibhuJawa Jun 11, 2026
47adab5
Add MinerU-HTML standalone baseline + comparison notebook
VibhuJawa Jun 11, 2026
eb69946
Add GPU-accelerated DBSCAN clustering via cuML
VibhuJawa Jun 11, 2026
8d81b84
Add deduplication_cuda12 extra to uv sync for cuML DBSCAN
VibhuJawa Jun 11, 2026
f0dbfa4
Use cached venv when available to skip 15-20min install per job
VibhuJawa Jun 11, 2026
3af3ea4
Add CC-scale MinerU-HTML layout-clustering + propagation pipeline
VibhuJawa Jun 13, 2026
e0d6010
Simplify pipeline code: reuse upstream helpers, dedup, tighten
VibhuJawa Jun 13, 2026
be5af73
Remove cluster-specific scripts and hardcoded paths from tutorial
VibhuJawa Jun 13, 2026
2a9b509
Update tutorial README: drop removed cluster submit script references
VibhuJawa Jun 13, 2026
0326a98
Fix stage1b GPU OOM: chunk oversized hosts (>3k pages) via STAGE1B_MA…
VibhuJawa Jun 13, 2026
512d913
Apply pre-commit checks: ruff format, lint fixes, pyproject ignores f…
VibhuJawa Jun 13, 2026
a7cf17f
Fix ruff errors and secrets-detector issues introduced by our PR
VibhuJawa Jun 13, 2026
c399236
Fix all ruff CI failures: scope test ignores, fix real errors, add tu…
VibhuJawa Jun 13, 2026
390662c
Scope ruff tutorial ignores to dripper-cc dir; add streaming pipeline…
VibhuJawa Jun 13, 2026
21aa89e
Remove non-essential tutorial files from PR; keep only pipeline scripts
VibhuJawa Jun 13, 2026
4b4e704
Fix secrets-detector: mark World Bank URL test strings as allowlist
VibhuJawa Jun 13, 2026
e984eaf
Enable per-shard streaming: aftercorr dependencies + Stage 3 exact-sh…
VibhuJawa Jun 13, 2026
61eaaae
Fix Stage 2b serial bottleneck + partial LOC cuts + dashboard v3 path
VibhuJawa Jun 13, 2026
d35d055
Fix: restore _parse_xpath_rules, remove test file for deleted scripts
VibhuJawa Jun 13, 2026
b61a463
Parallelize Stage 1c + fix Stage 3 time limit
VibhuJawa Jun 13, 2026
542855c
Apply NeMo Curator dedup/SemDedup/SDG patterns: RayActorPool for stag…
VibhuJawa Jun 13, 2026
f82e293
Fix stage1a arg: --workers -> --cpus-per-actor (RayActorPool rewrite)
VibhuJawa Jun 13, 2026
ede98e5
Fix cluster env + LOC reductions + Ray tmp dir + library sync
VibhuJawa Jun 13, 2026
5e41953
Rewrite stage1b: ProcessingStage + RayActorPoolExecutor (no multiproc…
VibhuJawa Jun 13, 2026
352bf02
Tune stage1a: cpus-per-actor 4->1 for max parallelism (64 actors vs 16)
VibhuJawa Jun 13, 2026
508a93f
Fix Pipeline API: executor goes in run() not __init__()
VibhuJawa Jun 13, 2026
3b729d0
Fix compare_f1.py: handle directory baseline with glob pattern
VibhuJawa Jun 13, 2026
9379d4f
Fix stage1b Ray throughput: exclude HTML from actor results, join on …
VibhuJawa Jun 13, 2026
5d65832
Fix abstract method: add process() to Stage1c and Stage2b ProcessingS…
VibhuJawa Jun 13, 2026
6b46510
Fix GPU utilization in HostDBSCANStage: lower threshold + batch 16 hosts
VibhuJawa Jun 13, 2026
b6b25ae
Fix Stage 1c/2b: RayDataExecutor -> RayActorPoolExecutor for true par…
VibhuJawa Jun 13, 2026
4058a36
Fix GPU not used: set LD_LIBRARY_PATH for cuML in actor setup()
VibhuJawa Jun 13, 2026
ae6c042
Clean GPU fix: use ProcessingStage.runtime_env for LD_LIBRARY_PATH (C…
VibhuJawa Jun 13, 2026
a1a4771
Use dripper_cached_venv for Stage 1b — unified GPU env with cuML + vl…
VibhuJawa Jun 13, 2026
2c27fdf
ruff fix runtime_env in stage1b
VibhuJawa Jun 13, 2026
7cce928
Remove runtime_env LD_LIBRARY_PATH — dripper_cached_venv works natively
VibhuJawa Jun 13, 2026
3455f9f
Fix batch_size=1 for Stage1c+Stage2b: max actor parallelism
VibhuJawa Jun 13, 2026
ebfe5bf
Simplify: reduce LOC, remove dead code and unused paths in tutorial s…
VibhuJawa Jun 14, 2026
a42a77c
feat: remove dead ProcessPool path, collapse argparse, drop dashboard…
VibhuJawa Jun 14, 2026
8dd6c85
Remove non-tutorial files, cut test_stage.py from 2435 to 773 lines
VibhuJawa Jun 14, 2026
90704cd
Deep simplify: -1,433 lines via Curator patterns + dead code removal
VibhuJawa Jun 14, 2026
323a1bf
Add single-command YAML-driven pipeline runner with validation
VibhuJawa Jun 14, 2026
6e17b5c
Apply simplify review: remove dead code, dedup helpers, fix output_ba…
VibhuJawa Jun 14, 2026
3eac0dd
Add DripperHTMLWorkflow — SemanticDedup-style user entry point
VibhuJawa Jun 14, 2026
1071962
Restructure to match SemanticDedup pattern: workflow, simplified tuto…
VibhuJawa Jun 14, 2026
dab9753
Remove defensive binding guards; assume mineru-html and llm-web-kit i…
VibhuJawa Jun 14, 2026
093e688
Remove local-only scripts accidentally added by tutorial fix agent
VibhuJawa Jun 14, 2026
ba951d6
Add quickstart.py and test_workflow.py matching SemanticDedup style
VibhuJawa Jun 14, 2026
2ba4012
Replace print() with loguru.logger in tutorial scripts
VibhuJawa Jun 14, 2026
5ecf514
Complete type annotations; add DripperConfig typed config dataclass
VibhuJawa Jun 14, 2026
f08e490
Fix 3 bugs found during retest; retest confirms F1=0.8443 > 0.84 ✅
VibhuJawa Jun 14, 2026
a2f6b3a
Remove local-only dev files accidentally added to PR
VibhuJawa Jun 14, 2026
1a1fc94
Cut quickstart.py from 344 to 140 lines matching SemanticDedup tutori…
VibhuJawa Jun 14, 2026
bc2f514
Update STYLE_GAPS.md: mark completed items, add new gaps
VibhuJawa Jun 14, 2026
4da518a
Split stage.py monolith into focused per-stage files (SemanticDedup p…
VibhuJawa Jun 14, 2026
2e3c771
WorkflowRunResult return type; cut test_workflow.py to ~120 lines
VibhuJawa Jun 14, 2026
fc1e2d8
Fix workflow.py import after stage.py split
VibhuJawa Jun 14, 2026
f5e4342
Reduce stage.py and stage_gpu_pipeline.py LOC
VibhuJawa Jun 14, 2026
229c141
Simplify stage3 argparse: optional DripperConfig loading
VibhuJawa Jun 14, 2026
5329814
Extract DripperHTMLLayoutTemplateStage to layout_template.py
VibhuJawa Jun 14, 2026
0a2fbf4
Cut test_stage.py 775->under 500; update STYLE_GAPS.md iter 4 status
VibhuJawa Jun 14, 2026
1582e02
Compress layout_template.py: remove verbose private docstrings (-52 l…
VibhuJawa Jun 14, 2026
39f7548
Reduce stage3_cpu_propagation.py: merge config dataclasses, remove cruft
VibhuJawa Jun 14, 2026
f1d5ed0
Cut layout_template.py by removing verbose private docstrings and com…
VibhuJawa Jun 14, 2026
ef4c978
Extract URL helpers to _url_helpers.py; collapse signature dispatchers
VibhuJawa Jun 14, 2026
cdf862c
Restore _append_warning to stage.py (accidentally removed during redu…
VibhuJawa Jun 14, 2026
fabad0b
Fix _token_f1 import: moved to _url_helpers.py by extraction
VibhuJawa Jun 14, 2026
89c1cbc
Fix P1 bugs: broken imports, missing @dataclass, assert in production…
VibhuJawa Jun 14, 2026
7a47c60
Reduce DripperHTMLLayoutTemplateStage from 61 to ~20 fields
VibhuJawa Jun 14, 2026
20148ba
Migrate LBP logic to library; thin tutorial scripts
VibhuJawa Jun 14, 2026
9147b2c
Auto-fix ruff lint and format issues
VibhuJawa Jun 14, 2026
ecd7520
Update STYLE_GAPS.md: swarm results + next iteration gaps
VibhuJawa Jun 14, 2026
b9eca4c
Fix Gap 7.3/7.5: replace anonymous _make_stage_cls with DripperHTMLPr…
VibhuJawa Jun 14, 2026
64b5c3e
Reduce layout_template.py: extract planning fns, tighten exceptions
VibhuJawa Jun 14, 2026
4019149
Remove non-essential files: reduce PR to core library + minimal tutorial
VibhuJawa Jun 14, 2026
33c5db2
Merge extraction/inference/preprocessing into _base_stages.py (-71 li…
VibhuJawa Jun 14, 2026
8e4ddc2
Update __init__.py: import from _base_stages instead of 3 separate files
VibhuJawa Jun 14, 2026
74efd57
Remove extraction/inference/preprocessing (merged into _base_stages.py)
VibhuJawa Jun 14, 2026
2834024
Cut layout_template.py from 1400 to ~1255 lines (-418 lines removed)
VibhuJawa Jun 14, 2026
b238003
Thin tutorial scripts to minimal wrappers around library stages
VibhuJawa Jun 14, 2026
510bd51
Cut _layout_planning/url_helpers (-110 lines), rewrite test_stage.py …
VibhuJawa Jun 14, 2026
aec613f
Cut tutorial script docstrings/helpers: stage1a/1c/2b/compare_f1 (-79…
VibhuJawa Jun 14, 2026
e0b3d66
Fix workflow.py: import from _base_stages (extraction/inference/prepr…
VibhuJawa Jun 14, 2026
58e32e5
Cut gpu_layout_clustering.py: remove verbose docstrings/comments (-60…
VibhuJawa Jun 14, 2026
be5802a
Trim quickstart.py module docstring (-16 lines)
VibhuJawa Jun 14, 2026
49de613
Trim stage1b module docstring (-6 lines)
VibhuJawa Jun 14, 2026
d5d5972
Fix layout_template.py: import from _layout_planning (not deleted _ur…
VibhuJawa Jun 14, 2026
e4fef09
Agent cuts: merge _url_helpers into _layout_planning; cut stage.py/pr…
VibhuJawa Jun 14, 2026
0b7a431
Minor additional cuts to _base_stages.py, propagation_stage.py, stage…
VibhuJawa Jun 14, 2026
043014b
Merge _url_helpers into _layout_planning; cut stage.py and propagatio…
VibhuJawa Jun 14, 2026
f66e457
Trim __init__.py module docstring (-12 lines)
VibhuJawa Jun 14, 2026
71b89a8
Cut layout_template.py: remove class/method docstrings (-12 lines)
VibhuJawa Jun 14, 2026
94902c9
Cut layout_template.py: remove module docstring + section header comm…
VibhuJawa Jun 14, 2026
70fa357
Move DripperLayoutAdvancedConfig to _layout_planning.py; cut layout_t…
VibhuJawa Jun 14, 2026
56291e0
Fix duplicate DripperLayoutAdvancedConfig: remove from layout_templat…
VibhuJawa Jun 14, 2026
ff1198b
Trim tutorial script docstrings: stage_gpu_pipeline (-8 lines), stage…
VibhuJawa Jun 14, 2026
7d0318c
Trim gpu_layout_clustering.py: remove section separator comments (-6 …
VibhuJawa Jun 14, 2026
12ad184
Cut layout_template.py: inline DripperLayoutAdvancedConfig, remove me…
VibhuJawa Jun 14, 2026
f4de8ff
Cut layout_template.py: simplify _adv property, flatten validation (-…
VibhuJawa Jun 14, 2026
1fbdaf2
Cut layout_template.py: inline helpers, remove CC config boilerplate
VibhuJawa Jun 14, 2026
cd07244
Add module docstrings to _base_stages.py and layout_template.py (styl…
VibhuJawa Jun 15, 2026
badd5dd
Add DripperHTMLWorkflow.__post_init__ validation (SemanticDedup pattern)
VibhuJawa Jun 15, 2026
7c825c5
Add workflow validation tests (none client, empty model, bad threshold)
VibhuJawa Jun 15, 2026
b496489
Restore stage3b: GPU LLM fallback for siblings where propagation failed
VibhuJawa Jun 15, 2026
5786aa1
Add single-command run_pipeline.py; fix DripperHTMLWorkflow._build_st…
VibhuJawa Jun 15, 2026
6 changes: 6 additions & 0 deletions nemo_curator/core/client.py
@@ -60,6 +60,8 @@ class RayClient:
Args:
ray_port: The port number of the Ray GCS.
ray_dashboard_port: The port number of the Ray dashboard.
ray_min_worker_port: The first worker port Ray may bind.
ray_max_worker_port: The last worker port Ray may bind.
ray_temp_dir: The temporary directory to use for Ray.
include_dashboard: Whether to include dashboard integration. If true, adds Ray metrics service discovery.
ray_metrics_port: The port number of the Ray metrics.
@@ -79,6 +81,8 @@ class RayClient:
ray_port: int = DEFAULT_RAY_PORT
ray_dashboard_port: int = DEFAULT_RAY_DASHBOARD_PORT
ray_client_server_port: int = DEFAULT_RAY_CLIENT_SERVER_PORT
ray_min_worker_port: int | None = None
ray_max_worker_port: int | None = None
ray_temp_dir: str = DEFAULT_RAY_TEMP_DIR
include_dashboard: bool = True
ray_metrics_port: int = DEFAULT_RAY_METRICS_PORT
@@ -155,6 +159,8 @@ def start(self) -> None:
ray_metrics_port=self.ray_metrics_port,
ray_client_server_port=self.ray_client_server_port,
ray_dashboard_host=self.ray_dashboard_host,
ray_min_worker_port=self.ray_min_worker_port,
ray_max_worker_port=self.ray_max_worker_port,
num_gpus=self.num_gpus,
num_cpus=self.num_cpus,
object_store_memory=self.object_store_memory,
2 changes: 1 addition & 1 deletion nemo_curator/core/serve/dynamo/backend.py
@@ -290,7 +290,7 @@ def _resolve_effective_router(

- ``mode``: honor ``router.mode`` if set; otherwise auto-pick ``"kv"``
when any model uses ``mode="disagg"``, else leave unset so the
Dynamo frontend falls back to its own ``round_robin`` default.
Dynamo frontend falls back to its own ``round-robin`` default.
- ``kv_events``: when we auto-pick ``mode="kv"`` we also auto-enable
``kv_events`` so the router consumes what prefill workers publish
unconditionally in disagg. If the user set ``router.mode`` explicitly
19 changes: 17 additions & 2 deletions nemo_curator/core/serve/dynamo/config.py
@@ -36,26 +36,41 @@ def __post_init__(self) -> None:
raise ValueError(msg)


DynamoRouterMode = Literal[
"round-robin",
"round_robin",
"random",
"power-of-two",
"kv",
"direct",
"least-loaded",
"device-aware-weighted",
]


@dataclass
class DynamoRouterConfig:
"""Frontend router config for Dynamo.

``mode=None`` means "auto": Curator picks ``"kv"`` if any model uses
``mode="disagg"``, else leaves ``--router-mode`` unset so the Dynamo
frontend falls back to its own ``round_robin`` default. ``kv_events``
frontend falls back to its own ``round-robin`` default. ``kv_events``
only applies when ``mode == "kv"``: pass ``kv_events=True`` to opt into
exact ZMQ KV-cache event publishing; the default uses the router's
approximate tree-based tracking. Anything else is forwarded to the
Dynamo frontend as CLI args via ``router_kwargs``.
"""

mode: Literal["round_robin", "random", "kv", "direct"] | None = None
mode: DynamoRouterMode | None = None
kv_events: bool = False
router_kwargs: dict[str, Any] = field(default_factory=dict)

_RESERVED_ROUTER_KWARGS: ClassVar[frozenset[str]] = frozenset({"router_mode", "router_kv_events"})
_MODE_ALIASES: ClassVar[dict[str, str]] = {"round_robin": "round-robin"}

def __post_init__(self) -> None:
if self.mode is not None:
self.mode = self._MODE_ALIASES.get(self.mode, self.mode) # type: ignore[assignment]
if self.mode is not None and self.mode != "kv" and self.kv_events:
msg = f"kv_events=True is only meaningful when mode='kv'; got mode={self.mode!r}."
raise ValueError(msg)
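
The alias handling in `__post_init__` above can be sketched in isolation. This is a minimal illustration, not the library class: `RouterConfigSketch` is a hypothetical stand-in that copies only the `mode`, `kv_events`, and `_MODE_ALIASES` pieces visible in this hunk.

```python
from dataclasses import dataclass
from typing import ClassVar, Optional


@dataclass
class RouterConfigSketch:
    """Hypothetical stand-in for DynamoRouterConfig (illustration only)."""

    mode: Optional[str] = None
    kv_events: bool = False

    # ClassVar keeps the alias table out of the generated dataclass fields.
    _MODE_ALIASES: ClassVar[dict] = {"round_robin": "round-robin"}

    def __post_init__(self) -> None:
        if self.mode is not None:
            # Map the legacy underscore spelling onto the hyphenated one;
            # any other value passes through unchanged.
            self.mode = self._MODE_ALIASES.get(self.mode, self.mode)
        if self.mode is not None and self.mode != "kv" and self.kv_events:
            msg = f"kv_events=True is only meaningful when mode='kv'; got mode={self.mode!r}."
            raise ValueError(msg)


legacy = RouterConfigSketch(mode="round_robin")
print(legacy.mode)  # round-robin
```

Normalizing before the `kv_events` check means the error message always reports the canonical spelling, whichever alias the caller passed.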
12 changes: 12 additions & 0 deletions nemo_curator/core/serve/dynamo/vllm.py
@@ -17,6 +17,7 @@
from __future__ import annotations

import json
import os
import tempfile
from functools import reduce
from pathlib import Path
@@ -67,12 +68,19 @@
"config": {"setup_timeout_seconds": 600},
}

_USE_DRIVER_ENV_VAR = "NEMO_CURATOR_DYNAMO_USE_DRIVER_ENV"


@ray.remote
def _write_actor_overrides_file(path: str, body: str) -> None:
Path(path).write_text(body)


def _use_driver_env_for_dynamo() -> bool:
"""Return true when Dynamo actors should use the driver's Python env."""
return os.environ.get(_USE_DRIVER_ENV_VAR, "0").lower() in {"1", "true", "yes", "on"}


def ensure_actor_overrides_on_all_nodes(*, ignore_head_node: bool = False) -> None:
"""Write the actor-venv ``--override`` file at a fixed path on every alive node.

@@ -109,13 +117,17 @@ def ensure_actor_overrides_on_all_nodes(*, ignore_head_node: bool = False) -> No

def dynamo_runtime_env(model_config: DynamoVLLMModelConfig) -> dict[str, Any]:
"""Merge the user's ``runtime_env`` with the Dynamo-vLLM package pin."""
if _use_driver_env_for_dynamo():
return model_config.runtime_env or {}
return BaseModelConfig.merge_runtime_envs(DYNAMO_VLLM_RUNTIME_ENV, model_config.runtime_env or None)


def merge_model_runtime_envs(models: list[DynamoVLLMModelConfig]) -> dict[str, Any]:
"""Merge every model's ``runtime_env`` onto the Dynamo-vLLM pin for the shared frontend actor."""
envs = [m.runtime_env for m in models if m.runtime_env]
user_merged = reduce(BaseModelConfig.merge_runtime_envs, envs) if envs else None
if _use_driver_env_for_dynamo():
return user_merged or {}
return BaseModelConfig.merge_runtime_envs(DYNAMO_VLLM_RUNTIME_ENV, user_merged)


12 changes: 9 additions & 3 deletions nemo_curator/core/serve/ray_serve/backend.py
@@ -70,11 +70,17 @@ def _deploy(self) -> None:
llm_configs = [self._to_llm_config(model, quiet_runtime_env=quiet_env) for model in server.models]

build_args: dict[str, Any] = {"llm_configs": llm_configs}
ingress_deployment_config = dict(server.backend.ingress_deployment_config)
if quiet_env:
# Suppress access logs on the OpenAI ingress deployment too.
build_args["ingress_deployment_config"] = {
"ray_actor_options": {"runtime_env": quiet_env},
}
ray_actor_options = dict(ingress_deployment_config.get("ray_actor_options", {}))
ray_actor_options["runtime_env"] = BaseModelConfig.merge_runtime_envs(
ray_actor_options.get("runtime_env", {}),
quiet_env,
)
ingress_deployment_config["ray_actor_options"] = ray_actor_options
if ingress_deployment_config:
build_args["ingress_deployment_config"] = ingress_deployment_config

from ray import serve
from ray.serve.llm import build_openai_app
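
The new `_deploy` logic layers the quiet-logging `runtime_env` on top of a user-supplied ingress config instead of replacing that config. The sketch below shows the copy-then-merge shape only. `merge_runtime_envs` here is a hypothetical stand-in whose semantics (override wins, `env_vars` merged key by key) are an assumption, not the behavior of `BaseModelConfig.merge_runtime_envs`, and the `QUIET` variable is made up.

```python
from typing import Any, Optional


def merge_runtime_envs(base: Optional[dict], override: Optional[dict]) -> dict:
    """Hypothetical merge: override wins, except env_vars are merged key by key."""
    merged = dict(base or {})
    for key, value in (override or {}).items():
        if key == "env_vars":
            merged["env_vars"] = {**merged.get("env_vars", {}), **value}
        else:
            merged[key] = value
    return merged


def build_ingress_config(user_config: dict, quiet_env: Optional[dict]) -> dict:
    # Copy each level before writing to it, so the caller's dict is left untouched.
    ingress: dict[str, Any] = dict(user_config)
    if quiet_env:
        actor_options = dict(ingress.get("ray_actor_options", {}))
        actor_options["runtime_env"] = merge_runtime_envs(
            actor_options.get("runtime_env", {}), quiet_env
        )
        ingress["ray_actor_options"] = actor_options
    return ingress


user = {"ray_actor_options": {"num_cpus": 1, "runtime_env": {"env_vars": {"A": "1"}}}}
out = build_ingress_config(user, {"env_vars": {"QUIET": "1"}})
print(out["ray_actor_options"]["runtime_env"])  # {'env_vars': {'A': '1', 'QUIET': '1'}}
```

Because `dict(...)` is a shallow copy, the second `dict(...)` around `ray_actor_options` is what keeps the nested user mapping from being modified in place.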
1 change: 1 addition & 0 deletions nemo_curator/core/serve/ray_serve/config.py
@@ -31,3 +31,4 @@ class RayServeServerConfig(BaseServerConfig):
"""Server-level Ray Serve config."""

model_configs: ClassVar[tuple[type[BaseModelConfig], ...]] = (RayServeModelConfig,)
ingress_deployment_config: dict[str, Any] = field(default_factory=dict)
6 changes: 6 additions & 0 deletions nemo_curator/core/utils.py
@@ -139,6 +139,8 @@ def init_cluster( # noqa: PLR0913
ray_metrics_port: int,
ray_client_server_port: int,
ray_dashboard_host: str,
ray_min_worker_port: int | None = None,
ray_max_worker_port: int | None = None,
num_gpus: int | None = None,
num_cpus: int | None = None,
object_store_memory: int | None = None,
@@ -164,6 +166,10 @@ def init_cluster( # noqa: PLR0913
ray_command.extend(["--dashboard-port", str(ray_dashboard_port)])
ray_command.extend(["--ray-client-server-port", str(ray_client_server_port)])
ray_command.extend(["--temp-dir", ray_temp_dir])
if ray_min_worker_port is not None:
ray_command.extend(["--min-worker-port", str(ray_min_worker_port)])
if ray_max_worker_port is not None:
ray_command.extend(["--max-worker-port", str(ray_max_worker_port)])
if object_store_memory is not None:
ray_command.extend(["--object-store-memory", str(object_store_memory)])
ray_command.extend(["--disable-usage-stats"])
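
The optional worker-port flags follow a simple pattern: each flag is emitted only when its value is set. A minimal sketch of that pattern, using the flag names from the hunk above; `worker_port_args` is a hypothetical helper, not a function in the PR.

```python
from typing import Optional


def worker_port_args(
    min_port: Optional[int] = None,
    max_port: Optional[int] = None,
) -> list:
    """Build the optional worker-port flags; unset values add nothing."""
    args: list = []
    if min_port is not None:
        args.extend(["--min-worker-port", str(min_port)])
    if max_port is not None:
        args.extend(["--max-worker-port", str(max_port)])
    return args


print(worker_port_args(10002, 19999))
# ['--min-worker-port', '10002', '--max-worker-port', '19999']
print(worker_port_args())  # []
```

Leaving both values as `None` keeps the command identical to what it was before this change, so existing callers see no difference.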
60 changes: 35 additions & 25 deletions nemo_curator/models/client/llm_client.py
@@ -15,11 +15,14 @@
import asyncio
import secrets
from abc import ABC, abstractmethod
from collections.abc import Iterable
from collections.abc import Awaitable, Callable, Iterable
from dataclasses import dataclass
from typing import TypeVar

from loguru import logger

T = TypeVar("T")


class ConversationFormatter(ABC):
"""
@@ -116,23 +119,15 @@ async def _query_model_impl(
msg = "Subclass of AsyncLLMClient must implement '_query_model_impl'"
raise NotImplementedError(msg)

async def query_model( # noqa: C901, PLR0912
self,
*,
messages: Iterable,
model: str,
conversation_formatter: ConversationFormatter | None = None,
generation_config: GenerationConfig | dict | None = None,
) -> list[str]:
"""
Query the model with automatic retry and concurrency control.
"""
# Use default config if none provided
@staticmethod
def _coerce_generation_config(generation_config: GenerationConfig | dict | None) -> GenerationConfig:
if generation_config is None:
generation_config = GenerationConfig()
elif isinstance(generation_config, dict):
generation_config = GenerationConfig(**generation_config)
return GenerationConfig()
if isinstance(generation_config, dict):
return GenerationConfig(**generation_config)
return generation_config

async def _run_with_retry_and_concurrency(self, operation: Callable[[], Awaitable[T]]) -> T: # noqa: C901, PLR0912
# Initialize semaphore if not already done or if we're in a different event loop
current_loop = asyncio.get_running_loop()
if self._semaphore is None or self._semaphore_loop != current_loop:
@@ -179,12 +174,7 @@ async def query_model( # noqa: C901, PLR0912

# Attempt the query
try:
return await self._query_model_impl(
messages=messages,
model=model,
conversation_formatter=conversation_formatter,
generation_config=generation_config,
)
return await operation()
except Exception as e:
last_exception = e
# If this is the last attempt, provide helpful error message
@@ -208,7 +198,27 @@ async def query_model( # noqa: C901, PLR0912
raise last_exception

# This should never be reached, but add explicit return for linter
logger.warning(
"Unexpected code path: AsyncLLMClient.query_model completed without returning a result or raising an exception"
msg = "Unexpected code path: AsyncLLMClient operation completed without returning a result or raising"
raise RuntimeError(msg)

async def query_model(
self,
*,
messages: Iterable,
model: str,
conversation_formatter: ConversationFormatter | None = None,
generation_config: GenerationConfig | dict | None = None,
) -> list[str]:
"""
Query the model with automatic retry and concurrency control.
"""
# Use default config if none provided
generation_config = self._coerce_generation_config(generation_config)
return await self._run_with_retry_and_concurrency(
lambda: self._query_model_impl(
messages=messages,
model=model,
conversation_formatter=conversation_formatter,
generation_config=generation_config,
)
return []
)
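
The refactor above passes the request to the retry helper as a zero-argument callable rather than as an already-created coroutine. The sketch below illustrates why. The retry loop, backoff, and names (`run_with_retry`, `demo`) are simplified assumptions, since the real loop body is elided from this diff.

```python
import asyncio
from collections.abc import Awaitable, Callable
from typing import Optional, TypeVar

T = TypeVar("T")


async def run_with_retry(
    operation: Callable[[], Awaitable[T]],
    *,
    semaphore: asyncio.Semaphore,
    max_retries: int = 3,
    base_delay: float = 0.0,
) -> T:
    """Simplified retry loop (assumed shape; the PR's loop is not shown in full)."""
    last_exception: Optional[Exception] = None
    for attempt in range(max_retries):
        async with semaphore:  # Bound the number of in-flight operations.
            try:
                # Calling operation() builds a fresh awaitable for every attempt.
                return await operation()
            except Exception as exc:
                last_exception = exc
        if attempt < max_retries - 1:
            await asyncio.sleep(base_delay * (2**attempt))
    if last_exception is not None:
        raise last_exception
    msg = "operation was never attempted"
    raise RuntimeError(msg)


async def demo() -> tuple:
    attempts = 0

    async def flaky() -> str:
        nonlocal attempts
        attempts += 1
        if attempts < 3:
            raise ConnectionError("transient failure")
        return "ok"

    result = await run_with_retry(lambda: flaky(), semaphore=asyncio.Semaphore(2))
    return result, attempts


if __name__ == "__main__":
    print(asyncio.run(demo()))  # ('ok', 3)
```

A coroutine object can be awaited only once, so a retry helper that accepted the coroutine itself could not retry it. Taking a factory also lets the same helper serve both `query_model` and `query_model_with_usage`.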
96 changes: 94 additions & 2 deletions nemo_curator/models/client/openai_client.py
@@ -14,13 +14,25 @@

import warnings
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any

from loguru import logger
from openai import AsyncOpenAI, OpenAI

from nemo_curator.models.client.llm_client import AsyncLLMClient, ConversationFormatter, GenerationConfig, LLMClient


@dataclass(frozen=True)
class OpenAIChatCompletionResult:
"""OpenAI-compatible chat completion content and aggregate usage."""

contents: list[str]
prompt_tokens: int | None = None
completion_tokens: int | None = None
total_tokens: int | None = None


class OpenAIClient(LLMClient):
"""
A wrapper around OpenAI's Python client for querying models
@@ -45,6 +57,21 @@ def query_model(
conversation_formatter: ConversationFormatter | None = None,
generation_config: GenerationConfig | dict | None = None,
) -> list[str]:
return self.query_model_with_usage(
messages=messages,
model=model,
conversation_formatter=conversation_formatter,
generation_config=generation_config,
).contents

def query_model_with_usage(
self,
*,
messages: Iterable,
model: str,
conversation_formatter: ConversationFormatter | None = None,
generation_config: GenerationConfig | dict | None = None,
) -> OpenAIChatCompletionResult:
if conversation_formatter is not None:
warnings.warn("conversation_formatter is not used in an OpenAIClient", stacklevel=2)

@@ -80,7 +107,7 @@ def query_model(

response = self.client.chat.completions.create(**create_kwargs)

return [choice.message.content for choice in response.choices]
return _completion_result_from_response(response)


class AsyncOpenAIClient(AsyncLLMClient):
@@ -122,6 +149,25 @@ async def _query_model_impl(
"""
Internal implementation of query_model without retry/concurrency logic.
"""
result = await self._query_model_with_usage_impl(
messages=messages,
model=model,
conversation_formatter=conversation_formatter,
generation_config=generation_config,
)
return result.contents

async def _query_model_with_usage_impl(
self,
*,
messages: Iterable,
model: str,
conversation_formatter: ConversationFormatter | None = None,
generation_config: GenerationConfig | dict | None = None,
) -> OpenAIChatCompletionResult:
"""
Internal implementation of query_model_with_usage without retry/concurrency logic.
"""
if conversation_formatter is not None:
warnings.warn("conversation_formatter is not used in an AsyncOpenAIClient", stacklevel=2)

@@ -157,4 +203,50 @@ async def _query_model_impl(

response = await self.client.chat.completions.create(**create_kwargs)

return [choice.message.content for choice in response.choices]
return _completion_result_from_response(response)

async def query_model_with_usage(
self,
*,
messages: Iterable,
model: str,
conversation_formatter: ConversationFormatter | None = None,
generation_config: GenerationConfig | dict | None = None,
) -> OpenAIChatCompletionResult:
"""
Query the model and keep OpenAI-compatible usage counters when the server returns them.
"""
generation_config = self._coerce_generation_config(generation_config)
return await self._run_with_retry_and_concurrency(
lambda: self._query_model_with_usage_impl(
messages=messages,
model=model,
conversation_formatter=conversation_formatter,
generation_config=generation_config,
)
)


def _completion_result_from_response(response: Any) -> OpenAIChatCompletionResult: # noqa: ANN401
usage = getattr(response, "usage", None)
return OpenAIChatCompletionResult(
contents=[choice.message.content for choice in response.choices],
prompt_tokens=_usage_int(usage, "prompt_tokens"),
completion_tokens=_usage_int(usage, "completion_tokens"),
total_tokens=_usage_int(usage, "total_tokens"),
)


def _usage_int(usage: Any, field: str) -> int | None: # noqa: ANN401
if usage is None:
return None
value = usage.get(field) if isinstance(usage, dict) else getattr(usage, field, None)
if isinstance(value, bool):
return None
if isinstance(value, int):
return value
if isinstance(value, float) and value.is_integer():
return int(value)
if isinstance(value, str) and value.isdigit():
return int(value)
return None
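
The coercion rules in `_usage_int` are easiest to see with a few inputs. The sketch below repeats the function body from the diff under the name `usage_int` and feeds it hand-made usage payloads. The payloads are illustrative, not captured server responses.

```python
from types import SimpleNamespace
from typing import Any, Optional


def usage_int(usage: Any, field: str) -> Optional[int]:
    """Same logic as `_usage_int` in the diff above, reproduced for illustration."""
    if usage is None:
        return None
    value = usage.get(field) if isinstance(usage, dict) else getattr(usage, field, None)
    if isinstance(value, bool):
        # bool is a subclass of int, so it has to be rejected before the int check.
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)
    if isinstance(value, str) and value.isdigit():
        return int(value)
    return None


print(usage_int({"prompt_tokens": 12}, "prompt_tokens"))            # 12
print(usage_int(SimpleNamespace(total_tokens=30.0), "total_tokens"))  # 30
print(usage_int({"completion_tokens": "7"}, "completion_tokens"))   # 7
print(usage_int({"prompt_tokens": True}, "prompt_tokens"))          # None
print(usage_int(None, "prompt_tokens"))                             # None
```

Anything that cannot be read as a non-negative integer count (for example `"-3"`, `12.5`, or a missing field) comes back as `None`.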