diff --git a/.gitignore b/.gitignore index e7f8718b7..e4ac9d9bf 100644 --- a/.gitignore +++ b/.gitignore @@ -85,6 +85,14 @@ venv.bak/ # Local scratch space .scratch/ +# Generated benchmark/report output +/artifacts/ +/reports/ +/scripts/benchmarks/benchmark_async_scheduling.py +/scripts/benchmarks/export_async_scheduling_perfetto.py +/scripts/benchmarks/generate_async_scheduling_idle_report.py +/scripts/benchmarks/run_async_scheduling_idle_regression.py + docs/notebooks/ docs/notebook_source/*.ipynb docs/notebook_source/*.csv diff --git a/architecture/dataset-builders.md b/architecture/dataset-builders.md index 825a2a392..a1ffbfe2a 100644 --- a/architecture/dataset-builders.md +++ b/architecture/dataset-builders.md @@ -35,7 +35,7 @@ Preparation (`_prepare_async_run`): 4. Constructs `CompletionTracker`, `RowGroupBufferManager`, `AsyncTaskScheduler` 5. Hooks `ProcessorRunner` for pre-batch and post-batch stages -`AsyncTaskScheduler` runs on a dedicated async loop with frontier-driven dispatch, semaphore-based capacity limits, salvage rounds for failed tasks, and order-dependent locks for columns that must execute sequentially. Ready frontier tasks are admitted through a virtual-time fair queue so one hot column or model-backed generator cannot consume the whole submission window before peer work gets a turn. +`AsyncTaskScheduler` runs on a dedicated async loop with frontier-driven dispatch, task-admission leases, salvage rounds for failed tasks, and order-dependent locks for columns that must execute sequentially. Ready frontier tasks enter `FairTaskQueue`, are selected through virtual-time ordering, and are committed only after `TaskAdmissionController` acquires the required scheduler resources. ### Execution Graph @@ -121,7 +121,7 @@ DatasetBuilder.build() → _prepare_async_run() → ExecutionGraph.create() → CompletionTracker.with_graph() - → AsyncTaskScheduler(semaphores, salvage_rounds) + → AsyncTaskScheduler(task admission, fair queue, salvage_rounds) → scheduler.run() → for each row group, fairly admit ready tasks from frontier → tasks execute generators, update CompletionTracker @@ -133,7 +133,7 @@ DatasetBuilder.build() - **Dual execution engines behind one API.** The sequential engine is simpler and easier to debug; the async engine adds row-group parallelism for throughput. Users switch via an environment variable without changing their code. - **DAG-driven ordering** ensures columns with dependencies (e.g., a judge column that depends on a text column) are generated in the correct order, regardless of the order they appear in the config. -- **Fair async admission** keeps the scheduler flowing across ready columns and model groups. Global semaphores still bound memory/coroutine growth, while per-group virtual-time queues prevent a large ready frontier from degenerating into a column-by-column wave. LLM admission caps are peer-sensitive: a solo model group can fill available global capacity, but once another scheduling group has queued work the saturated group yields until peers get admission slots or admitted tasks complete. +- **Fair async admission** keeps the scheduler flowing across ready columns and model groups. `FairTaskQueue.select_next(...)` chooses eligible ready work, `TaskAdmissionController` leases scheduler resources before spawn, and `FairTaskQueue.commit(...)` removes the selected task only after admission succeeds. Per-group virtual-time ordering prevents a large ready frontier from degenerating into a column-by-column wave, and scheduler-resource accounting remains separate from provider/model request admission. - **Salvage rounds in async mode** retry failed tasks after all other tasks in a round complete, improving resilience against transient LLM failures without blocking the entire generation. - **Unified DAG construction.** `topologically_sort_column_configs` (in `execution_graph.py`) determines column ordering using Kahn's algorithm; the runtime `ExecutionGraph` adds strategy-aware dependency tracking for the async scheduler. diff --git a/architecture/models.md b/architecture/models.md index d7af0cdac..fc90bcb8a 100644 --- a/architecture/models.md +++ b/architecture/models.md @@ -1,6 +1,6 @@ # Models -The model subsystem provides a unified interface for LLM access: chat completions, embeddings, and image generation. It handles client creation, retry, rate-limit throttling, usage tracking, and MCP tool integration. +The model subsystem provides a unified interface for LLM access: chat completions, embeddings, and image generation. It handles client creation, retry, request admission, usage tracking, and MCP tool integration. Source: `packages/data-designer-engine/src/data_designer/engine/models/` @@ -11,12 +11,12 @@ The model subsystem is layered: ``` ModelRegistry (lazy facade-per-alias) └── ModelFacade (completion, embeddings, image gen, MCP tool loops) - └── ThrottledModelClient (AIMD rate limiting) + └── ModelRequestExecutor (request admission + provider execution) └── ModelClient (OpenAI-compatible or Anthropic adapter) └── RetryTransport (httpx-level retries) ``` -Generators never interact with HTTP clients directly. They request a `ModelFacade` by alias from the `ModelRegistry`, which handles lazy construction and shared throttle state. +Generators never interact with HTTP clients directly. They request a `ModelFacade` by alias from the `ModelRegistry`, which handles lazy construction, request-resource canonicalization, and shared adaptive request admission state. ## Key Components @@ -31,13 +31,13 @@ Defines the contract: sync/async chat, embeddings, image generation, `supports_* `create_model_client` routes by provider type to the appropriate adapter. Optionally wraps with: - **`RetryTransport`** — httpx-level retries via `httpx_retries.RetryTransport`. `HttpModelClient` sets `strip_rate_limit_codes=True` for the async client and `False` for the sync client (`http_model_client.py`), which controls whether 429 responses are eligible for transport-layer retries. -- **`ThrottledModelClient`** — AIMD (Additive Increase, Multiplicative Decrease) concurrency control per throttle domain. +- **`ModelRequestExecutor`** — maps model-call attempts to request-admission items, acquires request leases, invokes the provider client, and releases the exact lease on every terminal path. -### ThrottleManager +### Request Admission -Manages concurrency limits per `ThrottleDomain` (CHAT, EMBEDDING, IMAGE, HEALTHCHECK), keyed by `(provider_name, model_id)`. Thread-safe with a shared lock for sync/async access. +`RequestAdmissionController` manages provider/model/domain request resources. `AdaptiveRequestAdmissionController` adds AIMD (Additive Increase, Multiplicative Decrease) adaptation per `RequestDomain` (`chat`, `embedding`, `image`, `healthcheck`) under the provider/model static cap. -`ThrottledModelClient` wraps each API call in a context manager that acquires/releases throttle capacity and adjusts limits on success (additive increase) or rate-limit errors (multiplicative decrease). +`ModelRequestExecutor` wraps each provider call with a request-admission lease and feeds success or rate-limit outcomes back to the controller. `RequestResourceResolver` owns canonical provider/model/domain identity so aliases that target the same endpoint share request capacity. ### ModelFacade @@ -50,7 +50,7 @@ The primary interface for generators. Holds a `ModelConfig`, `ModelClient`, opti ### ModelRegistry -Lazy `ModelFacade` construction per alias. Registers a shared `ThrottleManager` across all facades for coordinated rate limiting. Provides `get_model_usage_stats` and `log_model_usage` for post-build reporting. +Lazy `ModelFacade` construction per alias. Registers shared request-admission state across all facades for coordinated provider/model/domain capacity. Provides `get_model_usage_stats` and `log_model_usage` for post-build reporting. ### Usage Tracking @@ -59,18 +59,18 @@ Lazy `ModelFacade` construction per alias. Registers a shared `ThrottleManager` ## Data Flow 1. Generator requests a model by alias from `ModelRegistry` -2. Registry lazily creates `ModelFacade` with the appropriate client and throttle config +2. Registry lazily creates `ModelFacade` with the appropriate client and request-admission executor 3. Generator calls `completion()` with prompt/messages -4. `ModelFacade` builds kwargs, calls `ThrottledModelClient` -5. Throttle layer acquires capacity, delegates to `ModelClient` +4. `ModelFacade` builds kwargs, calls `ModelRequestExecutor` +5. Request admission acquires a provider/model/domain lease, delegates to `ModelClient` 6. `ModelClient` makes the HTTP request through `RetryTransport` 7. Response flows back; usage is tracked; if MCP tools are configured, tool calls are executed and results fed back for another completion round ## Design Decisions -- **Facade pattern** hides HTTP, retry, throttle, and MCP complexity from generators. Generators see `completion()` and get back parsed results. -- **AIMD throttling at the application layer** rather than relying solely on HTTP retries. This provides smoother throughput under rate limits — the transport layer still handles many transient failures, while the throttle manager adjusts concurrency to avoid sustained 429 storms. -- **429 handling depends on sync vs async `HttpModelClient`** — The async client uses `strip_rate_limit_codes=True`, so 429s are not retried at the transport layer and rate-limit signals reach `ThrottledModelClient` / AIMD quickly. The sync client uses `strip_rate_limit_codes=False`, so 429s may still be retried transparently at the transport layer before surfacing to callers. +- **Facade pattern** hides HTTP, retry, request admission, and MCP complexity from generators. Generators see `completion()` and get back parsed results. +- **AIMD request admission at the application layer** rather than relying solely on HTTP retries. This provides smoother throughput under rate limits: the transport layer still handles many transient failures, while adaptive request admission adjusts concurrency to avoid sustained 429 storms. +- **429 handling depends on sync vs async `HttpModelClient`** — The async client uses `strip_rate_limit_codes=True`, so 429s are not retried at the transport layer and rate-limit signals reach `ModelRequestExecutor` / request admission quickly. The sync client uses `strip_rate_limit_codes=False`, so 429s may still be retried transparently at the transport layer before surfacing to callers. - **Distribution-valued inference parameters** (`temperature`, `top_p` as `UniformDistribution` or `ManualDistribution`) enable controlled randomness across a dataset without per-row config changes. - **Lazy facade construction** avoids health-checking or connecting to models that are configured but never used in a particular generation run. diff --git a/architecture/overview.md b/architecture/overview.md index 30c91bdfb..10bde6c90 100644 --- a/architecture/overview.md +++ b/architecture/overview.md @@ -30,7 +30,7 @@ Users declare what their data should look like through config objects (columns, | `DataDesigner` | `data-designer` | Public API — `create()`, `preview()`, `validate()` | | `DataDesignerConfigBuilder` | `data-designer-config` | Fluent builder for dataset configs | | `DatasetBuilder` | `data-designer-engine` | Orchestrates generation (sync or async) | -| `ModelFacade` / `ModelRegistry` | `data-designer-engine` | LLM client abstraction with retry, throttle, usage tracking | +| `ModelFacade` / `ModelRegistry` | `data-designer-engine` | LLM client abstraction with retry, request admission, usage tracking | | `MCPFacade` / `MCPRegistry` | `data-designer-engine` | Tool execution via Model Context Protocol | | `ColumnGeneratorRegistry` | `data-designer-engine` | Maps column types to generator implementations | | `PluginRegistry` | `data-designer-config` | Discovers and registers entry-point plugins | @@ -44,7 +44,7 @@ Users declare what their data should look like through config objects (columns, 3. **Generation** — `DatasetBuilder` instantiates column generators from the registry, then executes one of two paths: - **Sequential** (default): batch loop over columns in topological order. Each generator produces its column via `CELL_BY_CELL` (threaded fan-out) or `FULL_COLUMN` strategy. - - **Async** (`DATA_DESIGNER_ASYNC_ENGINE=1`): builds an `ExecutionGraph`, partitions rows into groups, and dispatches tasks via `AsyncTaskScheduler` with semaphore-based concurrency, salvage rounds, and per-row-group checkpointing. + - **Async** (`DATA_DESIGNER_ASYNC_ENGINE=1`): builds an `ExecutionGraph`, partitions rows into groups, and dispatches tasks via `AsyncTaskScheduler` with `FairTaskQueue` selection, `TaskAdmissionController` scheduler-resource leases, salvage rounds, and per-row-group checkpointing. 4. **Post-processing** — `ProcessorRunner` applies transformations (pre-batch, post-batch, after-generation). Profilers analyze the generated dataset. @@ -61,7 +61,7 @@ Users declare what their data should look like through config objects (columns, - [Config Layer](config.md) — builder API, column types, model configs, plugin system - [Engine Layer](engine.md) — compilation, generators, registries -- [Models](models.md) — model facade, adapters, retry/throttle +- [Models](models.md) — model facade, adapters, retry, request admission - [Dataset Builders](dataset-builders.md) — sync/async orchestration, DAG, batching - [MCP](mcp.md) — tool execution, session pooling - [Sampling](sampling.md) — statistical generators, person/entity data diff --git a/docs/concepts/architecture-and-performance.md b/docs/concepts/architecture-and-performance.md index 5d1545ad7..537fb1aca 100644 --- a/docs/concepts/architecture-and-performance.md +++ b/docs/concepts/architecture-and-performance.md @@ -48,7 +48,7 @@ This guide explains the architecture, execution model, and how to tune performan ## Execution Model !!! note "Two execution engines" - The default execution path is the **async engine**, which dispatches work at the cell level and overlaps independent columns — see [Async Engine](#async-engine) below for its semantics. The legacy **sync engine** is still available for one transitional release via `DATA_DESIGNER_ASYNC_ENGINE=0` and is what this section describes. The configuration knobs documented below (`buffer_size`, `max_parallel_requests`, AIMD throttle config, error handling) apply to both engines; the differences are flagged inline. + The default execution path is the **async engine**, which dispatches work at the cell level and overlaps independent columns — see [Async Engine](#async-engine) below for its semantics. The legacy **sync engine** is still available for one transitional release via `DATA_DESIGNER_ASYNC_ENGINE=0` and is what this section describes. The public configuration knobs documented below (`buffer_size`, `max_parallel_requests`, error handling) apply to both engines; the differences are flagged inline. The sync engine processes datasets in **batches**, with **parallel** operations within each batch. @@ -104,12 +104,12 @@ At any moment, the number of concurrent LLM requests is: ```python concurrent_requests = min( buffer_size, # Records in current batch - current_throttle_limit, # AIMD-managed limit (≤ max_parallel_requests) + current_request_limit, # AIMD-managed limit (≤ max_parallel_requests) remaining_cells_in_column # Cells left to generate ) ``` -`max_parallel_requests` sets the **ceiling**. The actual limit (`current_throttle_limit`) is managed at runtime by an AIMD (Additive Increase / Multiplicative Decrease) controller that reacts to rate-limit signals from the inference server: +`max_parallel_requests` sets the **ceiling**. The actual limit (`current_request_limit`) is managed at runtime by adaptive request admission that reacts to rate-limit signals from the inference server: - **On the first 429 in a burst**: the limit is reduced by a configurable factor (default: 25% reduction) and a cooldown is applied. Further 429s from already in-flight requests in the same burst do not reduce the limit again — they release their permits and hold the limit steady. - **After consecutive successes**: the limit increases by 1 (by default) until it reaches the ceiling or a stabilized rate-limit threshold. @@ -117,7 +117,7 @@ concurrent_requests = min( This means Data Designer automatically finds the right concurrency level for your server without manual tuning. !!! note "Engine paths" - AIMD adaptive concurrency is fully active on the default **async engine**. The legacy **sync engine** is available for one transitional release via `DATA_DESIGNER_ASYNC_ENGINE=0`; on that path 429s are first retried at the HTTP transport layer and AIMD only engages as a fallback. See [Async engine](#async-engine) below. + Adaptive request admission is fully active on the default **async engine**. The legacy **sync engine** is available for one transitional release via `DATA_DESIGNER_ASYNC_ENGINE=0`; on that path 429s are first retried at the HTTP transport layer and AIMD only engages as a fallback. See [Async engine](#async-engine) below. **Example**: With `buffer_size=100` and `max_parallel_requests=32`, Data Designer starts sending up to 32 requests in parallel. If the server returns 429s, concurrency drops automatically (e.g., to 24, then 18) and recovers once the server catches up. @@ -153,7 +153,7 @@ designer.set_run_config(run_config) ### `max_parallel_requests` (InferenceParams) -Sets the **maximum** concurrent LLM API calls **per model**. This is the ceiling that the AIMD throttle controller can ramp up to — the actual concurrency at runtime may be lower if the server signals rate limits. +Sets the **maximum** concurrent LLM API calls **per model**. This is the ceiling that adaptive request admission can ramp up to — the actual concurrency at runtime may be lower if the server signals rate limits. ```python import data_designer.config as dd @@ -170,14 +170,14 @@ model = dd.ModelConfig( **Default**: 4 -**When to increase**: Your inference backend has high throughput capacity, you're using a cloud API with generous rate limits, or you're running vLLM/TensorRT-LLM with multiple GPUs. With AIMD, setting an aggressively high value is safer than before — the system will self-correct downward if the server can't keep up. The salvage queue on the async engine (default) reclaims failed rows; on the sync engine the initial burst of 429s before AIMD stabilizes can drop rows, so start with a more conservative ceiling if you've opted into sync. +**When to increase**: Your inference backend has high throughput capacity, you're using a cloud API with generous rate limits, or you're running vLLM/TensorRT-LLM with multiple GPUs. With adaptive request admission, setting an aggressively high value is safer than before — the system will self-correct downward if the server can't keep up. The salvage queue on the async engine (default) reclaims failed rows; on the sync engine the initial burst of 429s before AIMD stabilizes can drop rows, so start with a more conservative ceiling if you've opted into sync. **When to decrease**: You want to cap resource usage to a known safe level, or you want more predictable/debuggable execution. !!! tip "Finding the optimal value" The right value depends on your inference stack and model. Self-hosted vLLM servers can often handle values as high as 256, 512, or even 1024 depending on your hardware. - With AIMD, a practical approach is to set `max_parallel_requests` to the **upper bound** you're comfortable with and let the throttle controller find the sustainable level automatically. If you see frequent 429 → recovery cycles in the logs, your ceiling is above the server's true capacity but the system is handling it. If you never see any throttle activity, you may have room to increase the ceiling further. + With adaptive request admission, a practical approach is to set `max_parallel_requests` to the **upper bound** you're comfortable with and let the request controller find the sustainable level automatically. If you see frequent 429 → recovery cycles in the logs, your ceiling is above the server's true capacity but the system is handling it. If you never see any request-admission activity, you may have room to increase the ceiling further. **Benchmark approach**: Run a small dataset (e.g., 100 records) with increasing `max_parallel_requests` values (4 → 8 → 16 → 32 → ...) and measure generation time. Stop increasing when the runtime stops decreasing—that's when your inference server is saturated. @@ -198,38 +198,9 @@ designer.set_run_config(run_config) --- -### Adaptive Throttling (RunConfig) +### Adaptive Request Admission -Data Designer uses an AIMD (Additive Increase / Multiplicative Decrease) controller to automatically adjust concurrency per model based on rate-limit feedback from the inference server. The defaults work well for most workloads. Override them via `ThrottleConfig` only when you understand the trade-offs. - -!!! note "Engine paths" - Adaptive throttling is fully active on the default **async engine**, where 429 responses propagate directly to the AIMD controller. On the legacy **sync engine** (`DATA_DESIGNER_ASYNC_ENGINE=0`), 429s are first retried at the HTTP transport layer; `ThrottleConfig` settings only take effect as a fallback if transport retries are exhausted. - -```python -import data_designer.config as dd -from data_designer.interface import DataDesigner - -run_config = dd.RunConfig( - throttle=dd.ThrottleConfig( - reduce_factor=0.75, # Multiply limit by this on a 429 (default: 0.75) - additive_increase=1, # Add this many slots after success_window successes (default: 1) - success_window=25, # Consecutive successes before increasing (default: 25) - cooldown_seconds=2.0, # Pause after a 429 when no Retry-After header (default: 2.0) - ceiling_overshoot=0.10, # Probe 10% above observed server limit (default: 0.10) - ), -) - -designer = DataDesigner() -designer.set_run_config(run_config) -``` - -| Parameter | Default | Effect | -|-----------|---------|--------| -| `reduce_factor` | 0.75 | How aggressively to cut concurrency on a 429. Lower = more aggressive. | -| `additive_increase` | 1 | Slots added per recovery step. Higher = faster ramp-up, but riskier. | -| `success_window` | 25 | Consecutive successes required before each increase step. | -| `cooldown_seconds` | 2.0 | Pause duration after a 429 (used when the server doesn't send `Retry-After`). | -| `ceiling_overshoot` | 0.10 | Fraction above the observed rate-limit ceiling the controller is allowed to probe. | +Data Designer uses AIMD (Additive Increase / Multiplicative Decrease) request admission to automatically adjust concurrency per provider/model/domain based on rate-limit feedback from the inference server. This is an internal runtime controller, not a public `RunConfig` knob. Set `max_parallel_requests` as the user-facing ceiling and inspect `AsyncCapacityPlan`/logs to understand the effective runtime limits. !!! tip "How it works in practice" When a model endpoint returns HTTP 429, the controller reduces the concurrency limit for that model and pauses briefly. After enough consecutive successes, it begins ramping back up. If the server rate-limits again, the controller records that level as a ceiling and stabilizes just below it, with a small overshoot band to detect when the server can handle more load. @@ -263,11 +234,11 @@ designer.set_run_config(run_config) ## Async Engine -The async engine is the default execution path. It dispatches work at the cell level rather than the column level, so independent columns overlap in time and per-(provider, model) AIMD pools tune themselves independently. See the [Async All the Way Down](../devnotes/posts/async-all-the-way-down.md) dev note for the full architecture. +The async engine is the default execution path. It dispatches work at the cell level rather than the column level, so independent columns overlap in time and provider/model/domain request resources tune themselves independently. See the [Async All the Way Down](../devnotes/posts/async-all-the-way-down.md) dev note for the full architecture. ### Per-model timeouts drive every deadline -The `inference_parameters.timeout` field on a `ModelConfig` sets the per-request HTTP timeout. The same value also drives the sync→async bridge that custom columns use when they call `model.generate()`. There is no separate queue-wait deadline — waits scale with provider speed and AIMD's adaptive concurrency. Slow self-hosted endpoints (e.g. large models on a single GPU) only need this one knob raised: +The `inference_parameters.timeout` field on a `ModelConfig` sets the per-request HTTP timeout. The same value also drives the sync→async bridge that custom columns use when they call `model.generate()`. There is no separate queue-wait deadline — waits scale with provider speed and adaptive request admission. Slow self-hosted endpoints (e.g. large models on a single GPU) only need this one knob raised: ```python import data_designer.config as dd @@ -315,8 +286,8 @@ DATA_DESIGNER_ASYNC_ENGINE=0 python my_pipeline.py | Problem | Symptom | Solution | |---------|---------|----------| -| **Low throughput** | Low GPU utilization | Increase `max_parallel_requests` and/or `buffer_size`. If the throttle has self-reduced due to earlier 429s (check logs for "concurrency reduced" messages), the server may need more capacity or you can wait for AIMD recovery. | -| **Frequent 429 → recovery cycles** | Logs show repeated concurrency drops and ramp-ups | The `max_parallel_requests` ceiling is above the server's sustained capacity. This is handled automatically, but you can lower the ceiling to reduce the sawtooth or tune `reduce_factor` / `success_window`. | +| **Low throughput** | Low GPU utilization | Increase `max_parallel_requests` and/or `buffer_size`. If request admission has self-reduced due to earlier 429s (check logs for "concurrency reduced" messages), the server may need more capacity or you can wait for AIMD recovery. | +| **Frequent 429 → recovery cycles** | Logs show repeated concurrency drops and ramp-ups | The `max_parallel_requests` ceiling is above the server's sustained capacity. This is handled automatically, but you can lower the ceiling to reduce the sawtooth. | | **Long tail of slow generations** | Most records fast, few very slow | Reduce `max_conversation_restarts`, simplify schemas, improve prompts | | **Multi-model idle periods** | One model busy, others idle | Reduce `buffer_size` for faster cycling, or consolidate models | | **Memory errors** | OOM crashes | Reduce `buffer_size` and `max_parallel_requests` | @@ -326,10 +297,10 @@ DATA_DESIGNER_ASYNC_ENGINE=0 python my_pipeline.py ## Tuning Workflow -1. **Start with defaults** for initial development — AIMD handles rate-limit adaptation automatically +1. **Start with defaults** for initial development — adaptive request admission handles rate-limit adaptation automatically 2. **Profile your workload**: How many LLM columns? How many records? What models? -3. **Identify bottleneck**: Low GPU util → increase `max_parallel_requests` (AIMD will self-correct if you overshoot). Memory issues → decrease `buffer_size`. Long tails → tune retry settings. -4. **Check throttle logs**: Look for "concurrency reduced" / "concurrency increased" messages to understand whether rate limits are the bottleneck +3. **Identify bottleneck**: Low GPU util → increase `max_parallel_requests` (request admission will self-correct if you overshoot). Memory issues → decrease `buffer_size`. Long tails → tune retry settings. +4. **Check request-admission logs**: Look for "concurrency reduced" / "concurrency increased" messages to understand whether rate limits are the bottleneck 5. **Iterate**: Make one change at a time, measure impact before next change --- diff --git a/packages/data-designer-config/src/data_designer/config/__init__.py b/packages/data-designer-config/src/data_designer/config/__init__.py index eb385e15a..d3af55838 100644 --- a/packages/data-designer-config/src/data_designer/config/__init__.py +++ b/packages/data-designer-config/src/data_designer/config/__init__.py @@ -58,7 +58,7 @@ ProcessorType, SchemaTransformProcessorConfig, ) - from data_designer.config.run_config import JinjaRenderingEngine, RunConfig, ThrottleConfig # noqa: F401 + from data_designer.config.run_config import JinjaRenderingEngine, RunConfig # noqa: F401 from data_designer.config.sampler_constraints import ( # noqa: F401 ColumnInequalityConstraint, ConstraintType, @@ -82,6 +82,7 @@ UniformSamplerParams, UUIDSamplerParams, ) + from data_designer.config.scheduling import SchedulingMetadata, SchedulingMetadataError # noqa: F401 from data_designer.config.seed import ( # noqa: F401 IndexRange, PartitionBlock, @@ -177,7 +178,9 @@ # run_config "JinjaRenderingEngine": (f"{_MOD_BASE}.run_config", "JinjaRenderingEngine"), "RunConfig": (f"{_MOD_BASE}.run_config", "RunConfig"), - "ThrottleConfig": (f"{_MOD_BASE}.run_config", "ThrottleConfig"), + # scheduling metadata + "SchedulingMetadata": (f"{_MOD_BASE}.scheduling", "SchedulingMetadata"), + "SchedulingMetadataError": (f"{_MOD_BASE}.scheduling", "SchedulingMetadataError"), # sampler_constraints "ColumnInequalityConstraint": (_MOD_SAMPLER_CONSTRAINTS, "ColumnInequalityConstraint"), "ConstraintType": (_MOD_SAMPLER_CONSTRAINTS, "ConstraintType"), diff --git a/packages/data-designer-config/src/data_designer/config/run_config.py b/packages/data-designer-config/src/data_designer/config/run_config.py index d5f10c9e5..a67740aed 100644 --- a/packages/data-designer-config/src/data_designer/config/run_config.py +++ b/packages/data-designer-config/src/data_designer/config/run_config.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import ClassVar +from typing import Any from pydantic import Field, model_validator from typing_extensions import Self @@ -19,63 +19,6 @@ class JinjaRenderingEngine(StrEnum): SECURE = "secure" -class ThrottleConfig(ConfigBase): - """AIMD throttle tuning parameters for adaptive concurrency control. - - These knobs configure the ``ThrottleManager`` that wraps every outbound - model HTTP request. The defaults are conservative and suitable for most - workloads; override only when you understand the trade-offs. - - Attributes: - reduce_factor: Multiplicative decrease factor applied to the per-domain - concurrency limit on a 429 / rate-limit signal. Must be in (0, 1). - Default is 0.75 (reduce by 25% on rate-limit). - additive_increase: Additive increase step applied after every - ``success_window`` consecutive successes. Default is 1. - success_window: Number of consecutive successful releases before - the additive increase is applied. Default is 25. - cooldown_seconds: Default cooldown duration (seconds) applied after a - rate-limit when the provider does not include a ``Retry-After`` - header. Default is 2.0. - ceiling_overshoot: Fraction above the observed rate-limit ceiling - that additive increase is allowed to probe before capping. - Default is 0.10 (10% overshoot). - """ - - DEFAULT_REDUCE_FACTOR: ClassVar[float] = 0.75 - DEFAULT_ADDITIVE_INCREASE: ClassVar[int] = 1 - DEFAULT_SUCCESS_WINDOW: ClassVar[int] = 25 - DEFAULT_COOLDOWN_SECONDS: ClassVar[float] = 2.0 - DEFAULT_CEILING_OVERSHOOT: ClassVar[float] = 0.10 - - reduce_factor: float = Field( - default=DEFAULT_REDUCE_FACTOR, - gt=0.0, - lt=1.0, - description="Multiplicative decrease factor applied to the per-domain concurrency limit on a 429 signal.", - ) - additive_increase: int = Field( - default=DEFAULT_ADDITIVE_INCREASE, - ge=1, - description="Additive increase step applied after every `success_window` consecutive successes.", - ) - success_window: int = Field( - default=DEFAULT_SUCCESS_WINDOW, - ge=1, - description="Number of consecutive successful releases before the additive increase is applied.", - ) - cooldown_seconds: float = Field( - default=DEFAULT_COOLDOWN_SECONDS, - gt=0.0, - description="Default cooldown duration (seconds) after a rate-limit when no Retry-After header is present.", - ) - ceiling_overshoot: float = Field( - default=DEFAULT_CEILING_OVERSHOOT, - ge=0.0, - description="Fraction above the rate-limit ceiling that additive increase is allowed to probe.", - ) - - class RunConfig(ConfigBase): """Runtime configuration for dataset generation. @@ -112,7 +55,10 @@ class RunConfig(ConfigBase): fewer Data Designer-specific restrictions. ``secure`` uses Data Designer's hardened sandbox with additional AST, filter, and output guards. Default is ``secure``. - throttle: AIMD throttle tuning parameters. See ``ThrottleConfig`` for details. + + Notes: + Request admission is engine-internal in V1 and is not exposed as a + public run-config knob. """ disable_early_shutdown: bool = False @@ -132,7 +78,16 @@ class RunConfig(ConfigBase): "`native` uses Jinja2's built-in sandbox; `secure` uses Data Designer's hardened sandbox." ), ) - throttle: ThrottleConfig = Field(default_factory=ThrottleConfig) + + @model_validator(mode="before") + @classmethod + def reject_removed_throttle_config(cls, data: Any) -> Any: + if isinstance(data, dict) and "throttle" in data: + raise ValueError( + "RunConfig.throttle was removed. Request admission is now managed internally by the async " + "scheduling engine; remove the throttle field from your run config." + ) + return data @model_validator(mode="after") def normalize_shutdown_settings(self) -> Self: diff --git a/packages/data-designer-config/src/data_designer/config/scheduling.py b/packages/data-designer-config/src/data_designer/config/scheduling.py new file mode 100644 index 000000000..84e36b3a0 --- /dev/null +++ b/packages/data-designer-config/src/data_designer/config/scheduling.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Literal + +SchedulingMetadataKind = Literal["local", "model", "custom_model"] + + +@dataclass(frozen=True) +class SchedulingMetadata: + """Static generator-facing scheduling metadata. + + The metadata describes broad resource shape only. It intentionally does + not expose ready-queue state, task-admission state, request-admission + pressure, provider cooldowns, or adaptive request limits. + """ + + kind: SchedulingMetadataKind = "local" + identity: tuple[str, ...] = ("local", "default") + weight: int = 1 + diagnostics: dict[str, object] = field(default_factory=dict) + + @classmethod + def local(cls, resource_name: str = "default", *, weight: int = 1) -> SchedulingMetadata: + return cls(kind="local", identity=("local", resource_name), weight=weight) + + @classmethod + def model( + cls, + provider_name: str, + model_id: str, + generation_kind: str, + *, + weight: int, + diagnostics: dict[str, object] | None = None, + ) -> SchedulingMetadata: + return cls( + kind="model", + identity=("model", provider_name, model_id, generation_kind), + weight=weight, + diagnostics=diagnostics or {}, + ) + + @classmethod + def custom_model( + cls, + plugin_namespace: str, + resource_name: str, + version: str, + *, + weight: int = 1, + diagnostics: dict[str, object] | None = None, + ) -> SchedulingMetadata: + return cls( + kind="custom_model", + identity=("custom_model", plugin_namespace, resource_name, version), + weight=weight, + diagnostics=diagnostics or {}, + ) + + def __post_init__(self) -> None: + if self.kind not in {"local", "model", "custom_model"}: + raise SchedulingMetadataError( + code="invalid_kind", + message=f"Unknown scheduling metadata kind: {self.kind!r}", + diagnostics={"kind": self.kind}, + ) + if not isinstance(self.identity, tuple) or not self.identity: + raise SchedulingMetadataError( + code="invalid_identity", + message="Scheduling metadata identity must be a non-empty tuple of non-empty strings.", + diagnostics={"identity": self.identity}, + ) + if any(not isinstance(part, str) or not part for part in self.identity): + raise SchedulingMetadataError( + code="invalid_identity", + message="Scheduling metadata identity must contain only non-empty strings.", + diagnostics={"identity": self.identity}, + ) + expected_identity_lengths = {"local": 2, "model": 4, "custom_model": 4} + if self.identity[0] != self.kind or len(self.identity) != expected_identity_lengths[self.kind]: + raise SchedulingMetadataError( + code="invalid_identity", + message=f"Scheduling metadata identity for kind {self.kind!r} has an invalid shape.", + diagnostics={ + "kind": self.kind, + "identity": self.identity, + "expected_prefix": self.kind, + "expected_length": expected_identity_lengths[self.kind], + }, + ) + if isinstance(self.weight, bool) or not isinstance(self.weight, int) or self.weight <= 0: + raise SchedulingMetadataError( + code="invalid_weight", + message="Scheduling metadata weight must be a positive integer.", + diagnostics={"weight": self.weight}, + ) + + +class SchedulingMetadataError(ValueError): + """Typed scheduling metadata resolution error.""" + + def __init__( + self, + *, + code: str, + message: str, + fallback: SchedulingMetadata | None = None, + diagnostics: dict[str, object] | None = None, + ) -> None: + super().__init__(message) + self.code = code + self.message = message + self.fallback = fallback + self.diagnostics = diagnostics or {} diff --git a/packages/data-designer-config/tests/config/test_run_config.py b/packages/data-designer-config/tests/config/test_run_config.py index 98c819b38..a90a2d57a 100644 --- a/packages/data-designer-config/tests/config/test_run_config.py +++ b/packages/data-designer-config/tests/config/test_run_config.py @@ -3,6 +3,9 @@ from __future__ import annotations +import pytest +from pydantic import ValidationError + from data_designer.config.run_config import JinjaRenderingEngine, RunConfig @@ -13,3 +16,8 @@ def test_run_config_defaults_to_secure_jinja_renderer() -> None: def test_run_config_accepts_native_renderer() -> None: run_config = RunConfig(jinja_rendering_engine=JinjaRenderingEngine.NATIVE) assert JinjaRenderingEngine(run_config.jinja_rendering_engine) == JinjaRenderingEngine.NATIVE + + +def test_run_config_rejects_removed_throttle_with_targeted_message() -> None: + with pytest.raises(ValidationError, match="RunConfig.throttle was removed"): + RunConfig(throttle={"max_concurrent_requests": 1}) diff --git a/packages/data-designer-config/tests/config/test_scheduling.py b/packages/data-designer-config/tests/config/test_scheduling.py new file mode 100644 index 000000000..e219daddd --- /dev/null +++ b/packages/data-designer-config/tests/config/test_scheduling.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +from data_designer.config.scheduling import SchedulingMetadata, SchedulingMetadataError + + +@pytest.mark.parametrize( + "metadata", + [ + SchedulingMetadata.local(), + SchedulingMetadata.model("nvidia", "nemotron", "chat", weight=2), + SchedulingMetadata.custom_model("plugin", "resource", "v1"), + ], +) +def test_scheduling_metadata_accepts_normative_shapes(metadata: SchedulingMetadata) -> None: + assert metadata.weight >= 1 + + +@pytest.mark.parametrize( + "kwargs", + [ + {"identity": ["local", "default"]}, + {"weight": True}, + {"kind": "model", "identity": ("local", "default")}, + {"kind": "local", "identity": ("local", "default", "extra")}, + {"kind": "custom_model", "identity": ("custom_model", "plugin")}, + ], +) +def test_scheduling_metadata_rejects_non_normative_direct_construction(kwargs: dict[str, object]) -> None: + with pytest.raises(SchedulingMetadataError): + SchedulingMetadata(**kwargs) # type: ignore[arg-type] diff --git a/packages/data-designer-engine/src/data_designer/engine/capacity.py b/packages/data-designer-engine/src/data_designer/engine/capacity.py new file mode 100644 index 000000000..4dc1abb0b --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/capacity.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from typing import Generic, Literal, TypeVar + +from data_designer.engine.dataset_builders.scheduling.resources import SchedulerResourceKey, TaskGroupKey +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.resources import RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey, ProviderModelStaticCap + +_T = TypeVar("_T") + +CapacityValueSource = Literal[ + "default", + "run_config", + "dataset_builder", + "model_metadata", + "engine_internal_config", + "adapter_config", + "environment", + "runtime_snapshot", + "benchmark_override", +] + + +@dataclass(frozen=True) +class CapacityValue(Generic[_T]): + value: _T | None + source: CapacityValueSource + fallback_from: str | None = None + missing_reason: str | None = None + + +@dataclass(frozen=True) +class RowGroupAdmission: + row_group_concurrency: CapacityValue[int] + observed_in_flight: int | None = None + mode: Literal["fixed", "adaptive"] = "fixed" + target_in_flight: int | None = None + observed_max_target: int | None = None + max_admitted_rows: int | None = None + blocked_reasons: Mapping[str, int] = field(default_factory=dict) + + +@dataclass(frozen=True) +class RequestAdmissionConfigSnapshot: + resources: Sequence[RequestResourceKey] + initial_limits: Mapping[RequestResourceKey, int] + max_limit_clamps: Mapping[RequestResourceKey, int | None] + cooldown_seconds: float + multiplicative_decrease_factor: float + additive_increase_step: int + increase_after_successes: int + default_queue_wait_timeout_seconds: float | None + + @classmethod + def from_config(cls, config: RequestAdmissionConfig) -> RequestAdmissionConfigSnapshot: + resources = tuple(sorted({*config.initial_limits, *config.max_limit_clamps})) + return cls( + resources=resources, + initial_limits=dict(config.initial_limits), + max_limit_clamps=dict(config.max_limit_clamps), + cooldown_seconds=config.cooldown_seconds, + multiplicative_decrease_factor=config.multiplicative_decrease_factor, + additive_increase_step=config.additive_increase_step, + increase_after_successes=config.increase_after_successes, + default_queue_wait_timeout_seconds=config.default_queue_wait_timeout_seconds, + ) + + +@dataclass(frozen=True) +class AsyncCapacityConfigured: + buffer_size: CapacityValue[int] + row_group_admission: RowGroupAdmission + submission_capacity: CapacityValue[int] + task_resource_limits: CapacityValue[Mapping[SchedulerResourceKey, int]] + request_resources: CapacityValue[Sequence[RequestResourceKey]] + provider_model_static_caps: CapacityValue[Mapping[ProviderModelKey, ProviderModelStaticCap]] + request_domain_initial_limits: CapacityValue[Mapping[RequestResourceKey, int]] + request_admission_config: CapacityValue[RequestAdmissionConfigSnapshot] + transport_pool_limits: CapacityValue[Mapping[ProviderModelKey, int]] + + +@dataclass(frozen=True) +class AsyncCapacityRuntimeSnapshot: + request_domain_current_limits: Mapping[RequestResourceKey, int] | None = None + request_domain_effective_max: Mapping[RequestResourceKey, int] | None = None + request_domain_blocked_until: Mapping[RequestResourceKey, float | None] | None = None + provider_model_aggregate_in_flight: Mapping[ProviderModelKey, int] | None = None + + +@dataclass(frozen=True) +class AsyncCapacityObservedMaxima: + row_groups_in_flight: int = 0 + queued_tasks_by_group: Mapping[TaskGroupKey | str, int] = field(default_factory=dict) + task_leases_by_resource: Mapping[SchedulerResourceKey, int] = field(default_factory=dict) + request_waiters_by_resource: Mapping[RequestResourceKey, int] = field(default_factory=dict) + request_in_flight_by_resource: Mapping[RequestResourceKey, int] = field(default_factory=dict) + provider_model_aggregate_in_flight: Mapping[ProviderModelKey, int] = field(default_factory=dict) + request_domain_current_limits: Mapping[RequestResourceKey, int] = field(default_factory=dict) + transport_pool_utilization: Mapping[ProviderModelKey, int] | None = None + + +@dataclass(frozen=True) +class AsyncCapacityPlan: + configured: AsyncCapacityConfigured + runtime_snapshot: AsyncCapacityRuntimeSnapshot + observed_maxima: AsyncCapacityObservedMaxima + + +def missing_capacity_value( + *, + source: CapacityValueSource, + missing_reason: str, + fallback_from: str | None = None, +) -> CapacityValue[object]: + return CapacityValue(value=None, source=source, fallback_from=fallback_from, missing_reason=missing_reason) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py index 2431c0eb6..e2a333c69 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py @@ -8,16 +8,18 @@ import functools import logging from abc import ABC, abstractmethod +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Coroutine, TypeVar, overload from data_designer.config.column_configs import GenerationStrategy +from data_designer.config.scheduling import SchedulingMetadata, SchedulingMetadataError from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT from data_designer.logging import LOG_DOUBLE_INDENT, LOG_INDENT _T = TypeVar("_T") # Preserved deliberately. Two other 300s deadlines were retired in the -# async-default flip (PR #592): the throttle queue-wait and the +# async-default flip (PR #592): the request-admission queue wait and the # ``_AsyncBridgedModelFacade`` bridge in ``custom.py`` — both have # ``ModelFacade`` context and could derive a per-call deadline from # ``inference_parameters.timeout``. This generic ``ColumnGenerator.generate()`` @@ -26,6 +28,20 @@ # tracked as a structural follow-up. SYNC_BRIDGE_TIMEOUT = 300 + +@dataclass +class _EndpointBucket: + aliases: list[str] = field(default_factory=list) + caps: list[int] = field(default_factory=list) + + +def _scheduling_generation_kind(generation_type: object) -> str: + value = getattr(generation_type, "value", generation_type) + if value == "chat-completion": + return "chat" + return str(value) + + if TYPE_CHECKING: import pandas as pd @@ -65,10 +81,14 @@ class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC): def can_generate_from_scratch(self) -> bool: return False - @property - def is_llm_bound(self) -> bool: - """Whether this generator makes model/API calls during generation.""" - return False + def get_scheduling_metadata(self) -> SchedulingMetadata: + """Return static scheduler metadata for this generator. + + Generators that do not declare model-backed behavior use the documented + local default. Model-aware base classes override this with provider/model + resource identity derived from registered model aliases. + """ + return SchedulingMetadata.local() @property def is_order_dependent(self) -> bool: @@ -143,10 +163,6 @@ async def agenerate_from_scratch(self, num_records: int) -> pd.DataFrame: class ColumnGeneratorWithModelRegistry(ColumnGenerator[TaskConfigT], ABC): - @property - def is_llm_bound(self) -> bool: - return True - @property def model_registry(self) -> ModelRegistry: return self.resource_provider.model_registry @@ -161,6 +177,72 @@ def get_model_provider_name(self, model_alias: str) -> str: provider = self.model_registry.get_model_provider(model_alias=model_alias) return provider.name + def get_scheduling_metadata(self) -> SchedulingMetadata: + aliases = self._get_scheduling_model_aliases() + if not aliases: + raise SchedulingMetadataError( + code="missing_model_alias", + message=f"{type(self).__name__} has no model aliases for scheduling metadata.", + fallback=SchedulingMetadata.local(), + diagnostics={"generator_type": type(self).__name__}, + ) + + endpoints: dict[tuple[str, str, str], _EndpointBucket] = {} + for alias in aliases: + try: + model_config = self.get_model_config(model_alias=alias) + provider_name = self.get_model_provider_name(model_alias=alias) + except Exception as exc: + raise SchedulingMetadataError( + code="alias_resolution_failed", + message=f"Could not resolve model alias {alias!r} for scheduling metadata.", + diagnostics={"alias": alias, "generator_type": type(self).__name__}, + ) from exc + + endpoint = ( + provider_name, + str(model_config.model), + _scheduling_generation_kind(model_config.generation_type), + ) + max_parallel = getattr(model_config.inference_parameters, "max_parallel_requests", 1) + cap = max_parallel if isinstance(max_parallel, int) and max_parallel > 0 else 1 + bucket = endpoints.setdefault(endpoint, _EndpointBucket()) + bucket.aliases.append(alias) + bucket.caps.append(cap) + + if len(endpoints) != 1: + raise SchedulingMetadataError( + code="ambiguous_model_aliases", + message="Model scheduling metadata must resolve to one provider/model/generation endpoint.", + diagnostics={"endpoints": sorted(str(endpoint) for endpoint in endpoints)}, + ) + + endpoint, bucket = next(iter(endpoints.items())) + provider_name, model_id, generation_kind = endpoint + effective_cap = max(1, min(bucket.caps)) + return SchedulingMetadata.model( + provider_name, + model_id, + generation_kind, + weight=effective_cap, + diagnostics={ + "aliases": tuple(bucket.aliases), + "raw_caps": tuple(bucket.caps), + "merge_rule": "min_same_endpoint", + }, + ) + + def _get_scheduling_model_aliases(self) -> list[str]: + get_aliases = getattr(self.config, "get_model_aliases", None) + if callable(get_aliases): + aliases = get_aliases() + else: + aliases = [] + if (alias := getattr(self.config, "model_alias", None)) is not None: + aliases.append(alias) + aliases.extend(getattr(self.config, "model_aliases", []) or []) + return list(dict.fromkeys(str(alias) for alias in aliases if alias)) + class ColumnGeneratorWithModel(ColumnGeneratorWithModelRegistry[TaskConfigT], ABC): @functools.cached_property diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py index b4c863542..08c78120b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py @@ -13,6 +13,7 @@ import data_designer.lazy_heavy_imports as lazy from data_designer.config.column_configs import CustomColumnConfig, GenerationStrategy +from data_designer.config.scheduling import SchedulingMetadata from data_designer.engine.column_generators.generators.base import ColumnGenerator from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS, ModelTimeoutError @@ -105,7 +106,7 @@ def generate(self, *args: Any, **kwargs: Any) -> tuple[Any, list]: except concurrent.futures.TimeoutError as exc: future.cancel() # Demoted to debug: the raised ModelTimeoutError already surfaces - # the timeout at the scheduler with full context, and the throttled + # the timeout at the scheduler with full context, and the request-admission # degraded-provider WARN is the user-facing signal under sustained # bridge timeouts. Per-event WARN was noise on top of those. logger.debug("Async model bridge timed out after %.0fs; coroutine cancelled", bridge_timeout) @@ -137,10 +138,18 @@ class CustomColumnGenerator(ColumnGenerator[CustomColumnConfig]): The models dict provides direct access to ModelFacade instances keyed by alias. """ - @property - def is_llm_bound(self) -> bool: - """Custom generators with model_aliases make LLM calls and need the handoff.""" - return bool(self.config.model_aliases) + def get_scheduling_metadata(self) -> SchedulingMetadata: + """Return custom-model metadata when the custom column declares model aliases.""" + if not self.config.model_aliases: + return SchedulingMetadata.local() + identity = "-".join(sorted(str(alias) for alias in self.config.model_aliases)) + return SchedulingMetadata.custom_model( + "custom_column", + identity or self.config.name, + "v1", + weight=max(1, len(self.config.model_aliases)), + diagnostics={"aliases": tuple(sorted(str(alias) for alias in self.config.model_aliases))}, + ) def get_generation_strategy(self) -> GenerationStrategy: """Return strategy based on config.""" diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 778501da1..df19c5716 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -5,51 +5,80 @@ import asyncio import contextlib +import hashlib import logging import time -from collections import defaultdict, deque -from collections.abc import Coroutine +import uuid +from collections import Counter, defaultdict, deque +from collections.abc import Coroutine, Mapping from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable import data_designer.lazy_heavy_imports as lazy from data_designer.config.column_configs import GenerationStrategy +from data_designer.engine.capacity import ( + AsyncCapacityConfigured, + AsyncCapacityObservedMaxima, + AsyncCapacityPlan, + AsyncCapacityRuntimeSnapshot, + CapacityValue, + RequestAdmissionConfigSnapshot, + RowGroupAdmission, +) from data_designer.engine.context import current_row_group from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig +from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta +from data_designer.engine.dataset_builders.scheduling.queue import ( + FairTaskQueue, +) +from data_designer.engine.dataset_builders.scheduling.resolver import TaskSchedulingResolver +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + stable_task_id, +) +from data_designer.engine.dataset_builders.scheduling.task_admission import ( + TaskAdmissionConfig, + TaskAdmissionController, + TaskAdmissionDenied, + TaskAdmissionLease, +) +from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef, Task, TaskTrace from data_designer.engine.dataset_builders.utils.async_progress_reporter import ( DEFAULT_REPORT_INTERVAL, AsyncProgressReporter, ) -from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker, FrontierDelta -from data_designer.engine.dataset_builders.utils.fair_task_queue import ( - FairTaskQueue, - TaskGroupKey, - TaskGroupSpec, -) from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker -from data_designer.engine.dataset_builders.utils.scheduling_hints import SchedulingHint, SchedulingHintResolver from data_designer.engine.dataset_builders.utils.skip_evaluator import should_skip_column_for_record from data_designer.engine.dataset_builders.utils.skip_tracker import ( apply_skip_to_record, strip_skip_metadata_from_records, ) from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar -from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task, TaskTrace -from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS +from data_designer.engine.errors import DataDesignerError +from data_designer.engine.models.clients.errors import ProviderError +from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS, GenerationValidationFailureError +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.resources import RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey, ProviderModelStaticCap +from data_designer.engine.observability import ( + RuntimeCorrelation, + SchedulerAdmissionEvent, + SchedulerAdmissionEventSink, + runtime_correlation_provider, +) if TYPE_CHECKING: from data_designer.engine.column_generators.generators.base import ColumnGenerator from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager + from data_designer.engine.models.request_admission.pressure import RequestPressureSnapshotProvider logger = logging.getLogger(__name__) DEFAULT_TASK_POOL_SIZE: int = 256 -# Global LLM wait-pool headroom sizes the memory-safety semaphore above provider capacity. -GLOBAL_LLM_WAIT_POOL_HEADROOM_MULTIPLIER: int = 2 -# Per-group admission backlog caps how many ready LLM tasks one fair-queue group can hold. -LLM_GROUP_ADMISSION_BACKLOG_MULTIPLIER: int = 2 +MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER: int = 2 +MODEL_GROUP_ADMISSION_BACKLOG_MULTIPLIER: int = 2 # Degraded-provider WARN: emit at most one warning per interval when the # rolling fraction of retryable errors exceeds the threshold. Distinct from @@ -58,21 +87,27 @@ DEGRADED_WARN_RATE: float = 0.5 DEGRADED_WARN_WINDOW: int = 20 DEGRADED_WARN_INTERVAL_S: float = 60.0 +INTERNAL_BUG_EXCEPTIONS = (KeyError, TypeError, AttributeError, AssertionError) -class TrackingSemaphore(asyncio.Semaphore): - """``asyncio.Semaphore`` subclass that exposes available permits publicly.""" +def _identity_hash(identity: tuple[str, ...]) -> str: + return hashlib.sha1("\0".join(identity).encode()).hexdigest()[:16] - @property - def available_permits(self) -> int: - return self._value # type: ignore[attr-defined] - def try_acquire(self) -> bool: - """Non-blocking acquire. Returns ``True`` if a permit was taken.""" - if self._value > 0: # type: ignore[attr-defined] - self._value -= 1 # type: ignore[attr-defined] - return True - return False +def _request_resource_label(resource: object | None) -> str | None: + if resource is None: + return None + provider = getattr(resource, "provider_name", None) + model = getattr(resource, "model_id", None) + domain = getattr(resource, "domain", None) + domain_value = getattr(domain, "value", domain) + if provider is None or model is None or domain_value is None: + return str(resource) + return f"{provider}/{model}/{domain_value}" + + +def _string_keyed_counts(values: Mapping[object, int]) -> dict[str, int]: + return {str(key): int(value) for key, value in values.items()} @dataclass @@ -90,8 +125,7 @@ class _DispatchOutcome: """Result of one fair-dispatch pass over the persistent ready queue.""" dispatched: bool = False - submission_full: bool = False - group_blocked: bool = False + admission_blocked: bool = False class AsyncTaskScheduler: @@ -111,7 +145,8 @@ def __init__( *, max_concurrent_row_groups: int = 3, max_submitted_tasks: int = DEFAULT_TASK_POOL_SIZE, - max_llm_wait_tasks: int = DEFAULT_TASK_POOL_SIZE, + max_model_task_admission: int = DEFAULT_TASK_POOL_SIZE, + task_admission_config: TaskAdmissionConfig | None = None, salvage_max_rounds: int = 2, on_finalize_row_group: Callable[[int], None] | None = None, on_seeds_complete: Callable[[int, int], FrontierDelta | None] | None = None, @@ -127,6 +162,12 @@ def __init__( buffer_size: int = 0, progress_interval: float | None = None, progress_bar: bool = False, + scheduler_event_sink: SchedulerAdmissionEventSink | None = None, + run_id: str | None = None, + adaptive_row_group_admission: bool = False, + adaptive_row_group_initial_target: int = 1, + request_pressure_provider: RequestPressureSnapshotProvider | None = None, + request_pressure_advisory: bool = False, ) -> None: self._generators = generators self._graph = graph @@ -135,22 +176,29 @@ def __init__( self._buffer_manager = buffer_manager self._rg_semaphore = asyncio.Semaphore(max_concurrent_row_groups) - self._submission_semaphore = TrackingSemaphore(max_submitted_tasks) - self._llm_wait_semaphore = TrackingSemaphore(max_llm_wait_tasks) - self._max_llm_wait_tasks = max_llm_wait_tasks - self._llm_bound_lookup = build_llm_bound_lookup(generators) - self._scheduling_hints = SchedulingHintResolver(generators) + self._task_scheduling = TaskSchedulingResolver( + generators, + model_group_limit_multiplier=MODEL_GROUP_ADMISSION_BACKLOG_MULTIPLIER, + model_group_limit_cap=max_model_task_admission, + ) + admission_config = task_admission_config or TaskAdmissionConfig( + submission_capacity=max_submitted_tasks, + resource_limits={"llm_wait": max_model_task_admission, "local": max_submitted_tasks}, + ) + self._task_admission = TaskAdmissionController(admission_config) + self._task_admission_config = admission_config self._fair_queue = FairTaskQueue() self._pending_pre_batch_ready: defaultdict[int, list[Task]] = defaultdict(list) self._pending_pre_batch_ready_tasks: set[Task] = set() - # Task group specs are derived from per-generator scheduling hints and flow identity. - self._task_group_spec_cache: dict[int, TaskGroupSpec] = {} self._dispatched: set[Task] = set() self._in_flight: set[Task] = set() self._worker_tasks: set[asyncio.Task] = set() self._wake_event = asyncio.Event() + self._run_id = run_id or f"run-{uuid.uuid4().hex}" + self._scheduler_event_sink = scheduler_event_sink + self._scheduler_event_sequence = 0 self._salvage_max_rounds = salvage_max_rounds self._on_finalize_row_group = on_finalize_row_group self._on_seeds_complete = on_seeds_complete @@ -202,7 +250,7 @@ def __init__( self._all_rgs_admitted = False # Degraded-provider WARN: separate window tracking retryable-vs-not for - # every outcome (success or failure), throttled to one log per interval. + # every outcome (success or failure), rate-limited to one log per interval. self._degraded_warn_rate = degraded_warn_rate self._degraded_warn_window = degraded_warn_window self._degraded_warn_interval_s = degraded_warn_interval_s @@ -224,9 +272,38 @@ def __init__( # context naturally because the from_scratch task raised; the async # engine drops rows and continues, losing the cause unless we capture it. self._first_non_retryable_error: Exception | None = None + self._fatal_worker_error: BaseException | None = None # Pre-compute row-group sizes for O(1) lookup self._rg_size_map: dict[int, int] = dict(row_groups) + self._max_concurrent_row_groups = max_concurrent_row_groups + self._max_submitted_tasks = max_submitted_tasks + self._max_model_task_admission = max_model_task_admission + self._num_records = num_records + self._buffer_size = buffer_size + self._observed_max_row_groups_in_flight = 0 + self._observed_max_task_leases_by_resource: dict[str, int] = {} + self._observed_max_queued_by_group: dict[str, int] = {} + self._observed_max_request_waiters_by_resource: dict[RequestResourceKey, int] = {} + self._observed_max_request_in_flight_by_resource: dict[RequestResourceKey, int] = {} + self._observed_max_provider_model_aggregate_in_flight: dict[ProviderModelKey, int] = {} + self._observed_max_request_domain_current_limits: dict[RequestResourceKey, int] = {} + self._adaptive_row_group_admission = adaptive_row_group_admission + self._row_group_admission_hard_cap = max(1, max_concurrent_row_groups) + self._row_group_admission_target = ( + max(1, min(self._row_group_admission_hard_cap, adaptive_row_group_initial_target)) + if adaptive_row_group_admission + else self._row_group_admission_hard_cap + ) + self._observed_max_row_group_admission_target = self._row_group_admission_target + self._row_group_admission_event = asyncio.Event() + self._row_group_admission_event.set() + self._row_group_admission_pressure_ticks = 0 + self._row_group_admission_blocked_reasons: Counter[str] = Counter() + self._adaptive_max_admitted_rows = self._max_admitted_rows_guardrail() + self._request_pressure_provider = request_pressure_provider + self._request_pressure_advisory = request_pressure_advisory and request_pressure_provider is not None + self._request_pressure_advisory_skips = 0 # Pre-compute seed columns (graph is static) self._seed_cols: tuple[str, ...] = tuple(c for c in graph.columns if not graph.get_upstream_columns(c)) @@ -293,6 +370,13 @@ def first_non_retryable_error(self) -> Exception | None: """ return self._first_non_retryable_error + def _raise_if_fatal_worker_error(self) -> None: + if self._fatal_worker_error is None: + return + raise DatasetGenerationError( + "Unexpected internal task failure in async scheduler." + ) from self._fatal_worker_error + def _spawn_worker(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task: """Create a tracked worker task that auto-removes itself on completion.""" task = asyncio.create_task(coro) @@ -300,6 +384,230 @@ def _spawn_worker(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task: task.add_done_callback(self._worker_tasks.discard) return task + def _emit_scheduler_event( + self, + event_kind: str, + *, + task: Task | None = None, + lease: TaskAdmissionLease | None = None, + task_execution_id: str | None = None, + scheduler_resource_key: str | None = None, + reason_or_result: str | None = None, + diagnostics: dict[str, object] | None = None, + ) -> None: + if self._scheduler_event_sink is None: + return + self._scheduler_event_sequence += 1 + correlation = None + event_diagnostics = dict(diagnostics or {}) + if task is not None: + schedulable = lease.item if lease is not None else self._schedulable_task(task) + group = schedulable.group + identity_hash = _identity_hash(group.key.identity) + event_diagnostics.setdefault("task_group_key", group.key) + event_diagnostics.setdefault("resource_request", dict(schedulable.resource_request.amounts)) + correlation = RuntimeCorrelation( + run_id=self._run_id, + row_group=task.row_group, + task_column=task.column, + task_type=task.task_type, + scheduling_group_kind=group.key.kind, + scheduling_group_identity_hash=identity_hash, + task_execution_id=task_execution_id, + ) + try: + self._scheduler_event_sink.emit_scheduler_event( + SchedulerAdmissionEvent.capture( + event_kind, # type: ignore[arg-type] + sequence=self._scheduler_event_sequence, + correlation=correlation, + task_id=stable_task_id(task) if task is not None else None, + task_execution_id=task_execution_id, + task_lease_id=lease.lease_id if lease is not None else None, + scheduler_resource_key=scheduler_resource_key, + reason_or_result=reason_or_result, + snapshot=self.task_admission_snapshot(), + diagnostics=event_diagnostics, + ) + ) + except Exception: + logger.warning("Scheduler admission event sink raised; dropping event.", exc_info=True) + return + + def _record_observed_task_state(self) -> None: + self._observed_max_row_groups_in_flight = max(self._observed_max_row_groups_in_flight, len(self._rg_states)) + view = self._task_admission.view() + for resource, count in view.leased_resources.items(): + self._observed_max_task_leases_by_resource[resource] = max( + self._observed_max_task_leases_by_resource.get(resource, 0), + count, + ) + queue_view = self._fair_queue.view() + for group, count in queue_view.queued_by_group.items(): + label = f"{group.kind}:{'/'.join(group.identity)}" + self._observed_max_queued_by_group[label] = max(self._observed_max_queued_by_group.get(label, 0), count) + if self._request_pressure_provider is None: + return + for resource, snapshot in self._request_pressure_provider.snapshots().items(): + self._observed_max_request_waiters_by_resource[resource] = max( + self._observed_max_request_waiters_by_resource.get(resource, 0), + snapshot.waiters, + ) + self._observed_max_request_in_flight_by_resource[resource] = max( + self._observed_max_request_in_flight_by_resource.get(resource, 0), + snapshot.in_flight_count, + ) + self._observed_max_request_domain_current_limits[resource] = max( + self._observed_max_request_domain_current_limits.get(resource, 0), + snapshot.current_limit, + ) + for provider_model, snapshot in self._request_pressure_provider.global_snapshots().items(): + self._observed_max_provider_model_aggregate_in_flight[provider_model] = max( + self._observed_max_provider_model_aggregate_in_flight.get(provider_model, 0), + snapshot.aggregate_in_flight, + ) + + def _emit_scheduler_health_snapshot(self, reason: str) -> None: + self._emit_scheduler_event( + "scheduler_health_snapshot", + diagnostics=self._scheduler_health_diagnostics(reason=reason), + ) + + def _scheduler_health_diagnostics(self, *, reason: str) -> dict[str, object]: + queue_view = self._fair_queue.view() + task_view = self._task_admission.view() + return { + "reason": reason, + "active_row_groups": len(self._rg_states), + "target_row_groups": self._row_group_admission_target, + "hard_cap_row_groups": self._row_group_admission_hard_cap, + "active_admitted_rows": self._active_admitted_row_count(), + "max_admitted_rows": self._adaptive_max_admitted_rows, + "all_row_groups_admitted": self._all_rgs_admitted, + "queued_total": queue_view.queued_total, + "queued_by_group": _string_keyed_counts(queue_view.queued_by_group), + "queued_demand_by_resource": dict(queue_view.queued_peer_demand_by_resource), + "leased_resources": dict(task_view.leased_resources), + "resource_limits": dict(task_view.resource_limits), + "resources_available": dict(task_view.resources_available), + "in_flight_tasks": len(self._in_flight), + "active_workers": self.active_worker_count, + "deferred_tasks": len(self._deferred), + "pending_pre_batch_tasks": len(self._pending_pre_batch_ready_tasks), + "dispatched_tasks": len(self._dispatched), + "request_pressure_advisory_enabled": self._request_pressure_advisory, + "request_pressure_advisory_skips": self._request_pressure_advisory_skips, + "row_group_admission_blocked_reasons": dict(self._row_group_admission_blocked_reasons), + "request_pressure": self._request_pressure_diagnostics(), + } + + def _scheduler_job_diagnostics(self) -> dict[str, object]: + row_group_sizes = [size for _rg_id, size in self._row_groups] + strategies = {column: self._graph.get_strategy(column).value for column in self._graph.columns} + task_count_by_strategy = Counter(strategies.values()) + return { + "run_id": self._run_id, + "num_records": self._num_records, + "buffer_size": self._buffer_size, + "row_group_count": len(self._row_groups), + "row_group_total_rows": sum(row_group_sizes), + "row_group_min_size": min(row_group_sizes, default=0), + "row_group_max_size": max(row_group_sizes, default=0), + "graph_column_count": len(self._graph.columns), + "graph_root_columns": tuple(self._graph.get_root_columns()), + "graph_depth": len(self._graph.get_longest_dependency_chain()), + "task_count_by_strategy": dict(task_count_by_strategy), + "column_scheduling": self._column_scheduling_diagnostics(strategies), + "resource_limits": dict(self._task_admission_config.resource_limits), + "submission_capacity": self._task_admission_config.submission_capacity, + "adaptive_row_group_admission": self._adaptive_row_group_admission, + "row_group_initial_target": self._row_group_admission_target, + "row_group_hard_cap": self._row_group_admission_hard_cap, + "max_admitted_rows": self._adaptive_max_admitted_rows, + "request_pressure_advisory_enabled": self._request_pressure_advisory, + } + + def _column_scheduling_diagnostics(self, strategies: dict[str, str]) -> tuple[dict[str, object], ...]: + diagnostics = [] + for column in self._graph.columns: + task_type = "batch" if self._graph.get_strategy(column) != GenerationStrategy.CELL_BY_CELL else "cell" + row_index = None if task_type == "batch" else 0 + task = Task(column=column, row_group=0, row_index=row_index, task_type=task_type) + resolved = self._task_scheduling.scheduling_for_task(task, self._task_flow_identity(task)) + diagnostics.append( + { + "column": column, + "strategy": strategies[column], + "group_kind": resolved.group.key.kind, + "group_identity_hash": _identity_hash(resolved.group.key.identity), + "group_weight": resolved.group.weight, + "group_admitted_limit": resolved.group.admitted_limit, + "resource_request": dict(resolved.resource_request.amounts), + "request_resource": _request_resource_label(resolved.request_resource_key), + } + ) + return tuple(diagnostics) + + def _request_pressure_diagnostics(self) -> dict[str, object]: + if self._request_pressure_provider is None: + return {"enabled": False} + return { + "enabled": True, + "resources": { + _request_resource_label(resource): { + "effective_max": snapshot.effective_max, + "current_limit": snapshot.current_limit, + "in_flight_count": snapshot.in_flight_count, + "active_lease_count": snapshot.active_lease_count, + "waiters": snapshot.waiters, + "blocked": snapshot.blocked_until_monotonic is not None, + "cooldown_remaining_seconds": snapshot.cooldown_remaining_seconds, + "last_outcome": snapshot.last_outcome, + } + for resource, snapshot in self._request_pressure_provider.snapshots().items() + }, + "provider_models": { + f"{provider_model.provider_name}/{provider_model.model_id}": { + "static_cap": snapshot.static_cap, + "aggregate_in_flight": snapshot.aggregate_in_flight, + "aggregate_active_lease_count": snapshot.aggregate_active_lease_count, + "domains": {domain.value: count for domain, count in snapshot.domains.items()}, + } + for provider_model, snapshot in self._request_pressure_provider.global_snapshots().items() + }, + } + + def _request_pressure_item_diagnostics(self, item: SchedulableTask) -> dict[str, object]: + if item.request_resource_key is None or self._request_pressure_provider is None: + return {"request_resource": None} + snapshot = self._request_pressure_provider.snapshot(item.request_resource_key) + global_snapshot = self._request_pressure_provider.global_snapshot( + item.request_resource_key.provider_name, + item.request_resource_key.model_id, + ) + diagnostics: dict[str, object] = { + "request_resource": _request_resource_label(item.request_resource_key), + "pressure_reason": self._request_pressure_reason(item), + "resource_snapshot": None, + "provider_model_snapshot": None, + } + if snapshot is not None: + diagnostics["resource_snapshot"] = { + "effective_max": snapshot.effective_max, + "current_limit": snapshot.current_limit, + "in_flight_count": snapshot.in_flight_count, + "waiters": snapshot.waiters, + "blocked": snapshot.blocked_until_monotonic is not None, + "cooldown_remaining_seconds": snapshot.cooldown_remaining_seconds, + } + if global_snapshot is not None: + diagnostics["provider_model_snapshot"] = { + "static_cap": global_snapshot.static_cap, + "aggregate_in_flight": global_snapshot.aggregate_in_flight, + "aggregate_active_lease_count": global_snapshot.aggregate_active_lease_count, + } + return diagnostics + async def _cancel_workers(self) -> None: """Cancel all tracked worker tasks and wait for them to finish.""" for t in self._worker_tasks: @@ -313,114 +621,370 @@ def _apply_frontier_delta(self, delta: FrontierDelta) -> None: return for task in delta.removed: self._discard_ready_task(task) - for task in delta.added: - self._enqueue_ready_task(task) + self._enqueue_ready_tasks(delta.added) def _enqueue_ready_task(self, task: Task) -> None: - if task in self._dispatched or task.row_group not in self._rg_states: - return - if not self._tracker.is_frontier_task(task): - return - state = self._rg_states[task.row_group] - if self._on_seeds_complete is not None and not state.pre_batch_done: - if task not in self._pending_pre_batch_ready_tasks: - self._pending_pre_batch_ready[task.row_group].append(task) - self._pending_pre_batch_ready_tasks.add(task) + self._enqueue_ready_tasks((task,)) + + def _enqueue_ready_tasks(self, tasks: tuple[Task, ...]) -> None: + schedulables: list[SchedulableTask] = [] + accepted_tasks_by_id: dict[str, Task] = {} + for task in tasks: + if task in self._dispatched or task.row_group not in self._rg_states: + continue + if not self._tracker.is_frontier_task(task): + continue + self._emit_scheduler_event("dependency_ready", task=task) + state = self._rg_states[task.row_group] + if self._on_seeds_complete is not None and not state.pre_batch_done and task.column not in self._seed_cols: + if task not in self._pending_pre_batch_ready_tasks: + self._pending_pre_batch_ready[task.row_group].append(task) + self._pending_pre_batch_ready_tasks.add(task) + continue + schedulable = self._schedulable_task(task) + schedulables.append(schedulable) + accepted_tasks_by_id[schedulable.task_id] = task + + if not schedulables: return - self._fair_queue.enqueue(task, self._task_group_spec(task)) + accepted = self._fair_queue.enqueue(schedulables) + if accepted: + self._tracker.mark_enqueued(accepted) + for task_id in accepted: + self._emit_scheduler_event("ready_enqueued", task=accepted_tasks_by_id[task_id]) + self._record_observed_task_state() + self._wake_event.set() def _discard_ready_task(self, task: Task) -> None: - self._fair_queue.discard(task) + self._fair_queue.discard(stable_task_id(task)) self._pending_pre_batch_ready_tasks.discard(task) def _flush_pre_batch_ready(self, row_group: int) -> None: pending = self._pending_pre_batch_ready.pop(row_group, []) + ready = [] for task in pending: if task not in self._pending_pre_batch_ready_tasks: continue self._pending_pre_batch_ready_tasks.discard(task) - self._enqueue_ready_task(task) + ready.append(task) + self._enqueue_ready_tasks(tuple(ready)) def _drop_pending_ready_for_row_group(self, row_group: int) -> None: pending = self._pending_pre_batch_ready.pop(row_group, []) for task in pending: self._pending_pre_batch_ready_tasks.discard(task) - self._fair_queue.discard_where(lambda task: task.row_group == row_group) + self._fair_queue.discard_where(lambda item: item.payload.row_group == row_group) def _dispatch_queued_tasks(self) -> _DispatchOutcome: dispatched = False while self._fair_queue.has_queued_tasks: - if not self._submission_semaphore.try_acquire(): - return _DispatchOutcome(dispatched=dispatched, submission_full=True) - - selection = self._fair_queue.admit_next() + selection = self._fair_queue.select_next(self._is_dispatch_eligible) if selection is None: - self._submission_semaphore.release() - return _DispatchOutcome(dispatched=dispatched, group_blocked=True) + summary = self._task_admission.explain_blocked(self._fair_queue.view()) + if "group_cap" in summary.dominant_denial_reasons: + event_kind = "group_capped" + elif summary.dominant_denial_reasons: + event_kind = "admission_blocked" + else: + event_kind = "queue_empty" + self._emit_scheduler_event( + event_kind, + diagnostics={ + "queued_count": summary.queued_count, + "reasons": dict(summary.dominant_denial_reasons), + }, + ) + self._emit_scheduler_health_snapshot(event_kind) + return _DispatchOutcome(dispatched=dispatched, admission_blocked=True) + + self._emit_scheduler_event("selected", task=selection.item.payload) + decision = self._task_admission.try_acquire(selection.item, selection.queue_view) + if isinstance(decision, TaskAdmissionDenied): + self._emit_scheduler_event( + "admission_denied", + task=selection.item.payload, + reason_or_result=decision.reason, + diagnostics=dict(decision.diagnostics), + ) + return _DispatchOutcome(dispatched=dispatched, admission_blocked=True) + self._emit_scheduler_event("task_lease_acquired", task=selection.item.payload, lease=decision) + + committed = self._fair_queue.commit(selection) + if committed is None: + result = self._task_admission.release(decision) + self._emit_scheduler_event( + "stale_selection", + task=selection.item.payload, + lease=decision, + reason_or_result=result.reason, + ) + return _DispatchOutcome(dispatched=dispatched, admission_blocked=True) - self._dispatch_selected_task(selection.task) + self._dispatch_selected_task(committed, decision) dispatched = True + self._record_observed_task_state() + if dispatched: + self._emit_scheduler_event("queue_drained") + self._emit_scheduler_health_snapshot("queue_drained") return _DispatchOutcome(dispatched=dispatched) - def _dispatch_selected_task(self, task: Task) -> None: + def _is_dispatch_eligible(self, item: SchedulableTask, view: Any) -> bool: + if not self._task_admission.is_eligible(item, view): + return False + if not self._request_pressure_advisory: + return True + if not self._is_request_pressure_limited(item): + return True + open_peer = self._request_pressure_open_peer(item, view) + if open_peer is not None: + self._request_pressure_advisory_skips += 1 + self._emit_scheduler_event( + "request_pressure_advisory_skipped", + task=item.payload, + diagnostics=self._request_pressure_item_diagnostics(item) + | { + "open_peer_task_id": open_peer.task_id, + "open_peer_column": open_peer.payload.column, + "open_peer_request_resource": _request_resource_label(open_peer.request_resource_key), + "skip_count": self._request_pressure_advisory_skips, + }, + ) + return False + return True + + def _is_request_pressure_limited(self, item: SchedulableTask) -> bool: + return self._request_pressure_reason(item) is not None + + def _request_pressure_reason(self, item: SchedulableTask) -> str | None: + if item.request_resource_key is None or self._request_pressure_provider is None: + return None + snapshot = self._request_pressure_provider.snapshot(item.request_resource_key) + global_snapshot = self._request_pressure_provider.global_snapshot( + item.request_resource_key.provider_name, + item.request_resource_key.model_id, + ) + if ( + global_snapshot is not None + and global_snapshot.static_cap > 0 + and global_snapshot.aggregate_in_flight >= global_snapshot.static_cap + ): + return "provider_model_aggregate_cap" + if snapshot is None: + return None + if snapshot.cooldown_remaining_seconds > 0.0 or snapshot.blocked_until_monotonic is not None: + return "cooldown" + if snapshot.waiters > 0: + return "waiters" + if snapshot.current_limit > 0 and snapshot.in_flight_count >= snapshot.current_limit: + return "resource_limit" + return None + + def _has_request_pressure_open_peer(self, item: SchedulableTask, view: Any) -> bool: + return self._request_pressure_open_peer(item, view) is not None + + def _request_pressure_open_peer(self, item: SchedulableTask, view: Any) -> SchedulableTask | None: + for peer in view.first_candidate_tasks_by_group.values(): + if peer.task_id == item.task_id: + continue + if not self._task_admission.is_eligible(peer, view): + continue + if not self._is_request_pressure_limited(peer): + return peer + return None + + def _dispatch_selected_task(self, item: SchedulableTask, lease: TaskAdmissionLease) -> None: + task = item.payload + task_execution_id = f"task-exec-{uuid.uuid4().hex}" self._dispatched.add(task) self._in_flight.add(task) if (s := self._rg_states.get(task.row_group)) is not None: s.in_flight_count += 1 - self._spawn_worker(self._execute_task(task)) - - def _task_group_spec(self, task: Task) -> TaskGroupSpec: - generator = self._generators[task.column] - generator_id = id(generator) - cached = self._task_group_spec_cache.get(generator_id) - if cached is not None: - return cached - - spec = self._task_group_spec_from_hint( - self._scheduling_hints.hint_for(generator), - self._task_flow_identity(task), - ) - self._task_group_spec_cache[generator_id] = spec - return spec - - def _task_group_spec_from_hint(self, hint: SchedulingHint, flow_identity: tuple[str, ...]) -> TaskGroupSpec: - if hint.group_kind == "local": - return TaskGroupSpec(key=TaskGroupKey(kind="local", identity=flow_identity)) + try: + self._spawn_worker(self._execute_task(task, lease, task_execution_id)) + self._emit_scheduler_event("worker_spawned", task=task, lease=lease, task_execution_id=task_execution_id) + except Exception: + result = self._task_admission.release(lease) + self._emit_scheduler_event( + "worker_spawn_failed", + task=task, + lease=lease, + task_execution_id=task_execution_id, + reason_or_result=result.reason, + ) + self._in_flight.discard(task) + raise - if hint.group_kind == "custom_model": - identity = (*flow_identity, *hint.identity_suffix) - else: - identity = (*hint.identity_prefix, *flow_identity, *hint.identity_suffix) - - weight = max(1, hint.weight) - return TaskGroupSpec( - key=TaskGroupKey(kind=hint.group_kind, identity=identity), - weight=float(weight), - admitted_limit=self._llm_group_admitted_limit(weight), - ) + def _schedulable_task(self, task: Task) -> SchedulableTask: + return self._task_scheduling.schedulable_task(task, self._task_flow_identity(task)) def _task_flow_identity(self, task: Task) -> tuple[str, ...]: generator = self._generators[task.column] output_columns = self._gen_instance_to_columns.get(id(generator), [task.column]) return tuple(output_columns) - def _llm_group_admitted_limit(self, weight: int) -> int: - return max(1, min(self._max_llm_wait_tasks, LLM_GROUP_ADMISSION_BACKLOG_MULTIPLIER * weight)) + def _max_admitted_rows_guardrail(self) -> int: + if self._num_records > 0 and self._buffer_size > 0: + return min(self._num_records, max(3 * self._buffer_size, 8192)) + total_rows = sum(size for _rg_id, size in self._row_groups) + return max(1, total_rows) + + async def _wait_for_row_group_admission_capacity(self, row_group_size: int) -> None: + while True: + target_blocked = len(self._rg_states) >= self._row_group_admission_target + row_guard_blocked = not self._row_group_row_guard_allows(row_group_size) + if not target_blocked and not row_guard_blocked: + return + self._row_group_admission_event.clear() + target_blocked = len(self._rg_states) >= self._row_group_admission_target + row_guard_blocked = not self._row_group_row_guard_allows(row_group_size) + if not target_blocked and not row_guard_blocked: + return + if row_guard_blocked: + self._row_group_admission_blocked_reasons["max_admitted_rows"] += 1 + self._emit_scheduler_event( + "row_group_admission_blocked", + diagnostics=self._row_group_admission_diagnostics(reason="max_admitted_rows"), + ) + self._emit_scheduler_health_snapshot("row_group_admission_blocked") + await self._row_group_admission_event.wait() + self._raise_if_fatal_worker_error() + + def _row_group_row_guard_allows(self, row_group_size: int) -> bool: + if not self._adaptive_row_group_admission: + return True + admitted_rows = self._active_admitted_row_count() + return admitted_rows == 0 or admitted_rows + row_group_size <= self._adaptive_max_admitted_rows + + def _active_admitted_row_count(self) -> int: + return sum(state.size for state in self._rg_states.values()) + + def _maybe_update_adaptive_row_group_target(self) -> None: + if not self._adaptive_row_group_admission: + return + if self._all_rgs_admitted or self._early_shutdown or self._fatal_worker_error is not None: + return + if len(self._rg_states) >= self._row_group_admission_hard_cap: + self._row_group_admission_pressure_ticks = 0 + return + reason = self._adaptive_row_group_block_reason() + if reason is not None: + self._row_group_admission_blocked_reasons[reason] += 1 + self._row_group_admission_pressure_ticks = 0 + self._emit_scheduler_event( + "row_group_admission_blocked", + diagnostics=self._row_group_admission_diagnostics(reason=reason), + ) + self._emit_scheduler_health_snapshot("row_group_admission_blocked") + return + + self._row_group_admission_pressure_ticks += 1 + if self._fair_queue.view().queued_total > 0 and self._row_group_admission_pressure_ticks < 2: + return + old_target = self._row_group_admission_target + self._row_group_admission_target = min(self._row_group_admission_hard_cap, old_target + 1) + self._observed_max_row_group_admission_target = max( + self._observed_max_row_group_admission_target, + self._row_group_admission_target, + ) + self._row_group_admission_pressure_ticks = 0 + if self._row_group_admission_target != old_target: + self._emit_scheduler_event( + "row_group_admission_target_changed", + diagnostics=self._row_group_admission_diagnostics(reason="horizon_limited") + | {"old_target": old_target, "new_target": self._row_group_admission_target}, + ) + self._emit_scheduler_health_snapshot("row_group_admission_target_changed") + self._row_group_admission_event.set() + + def _adaptive_row_group_block_reason(self) -> str | None: + if self._deferred: + return "deferred_tasks" + next_size = self._next_unadmitted_row_group_size() + if next_size is None: + return "no_pending_row_groups" + if not self._row_group_row_guard_allows(next_size): + return "max_admitted_rows" + queue_view = self._fair_queue.view() + queue_guard = max(self._max_submitted_tasks * 4, self._max_model_task_admission * 2) + if queue_view.queued_total >= queue_guard: + return "queued_task_guardrail" + task_view = self._task_admission.view() + llm_limit = task_view.resource_limits.get("llm_wait", 0) + if llm_limit <= 0: + return "no_llm_wait_resource" + llm_available = task_view.resources_available.get("llm_wait", 0) + queued_llm = queue_view.queued_peer_demand_by_resource.get("llm_wait", 0) + if llm_available <= queued_llm and queue_view.queued_total > 0: + return "queued_llm_demand" + if llm_available <= 0: + return "llm_wait_saturated" + return None + + def _next_unadmitted_row_group_size(self) -> int | None: + for rg_id, rg_size in self._row_groups: + if rg_id not in self._rg_states and not self._tracker.is_row_group_complete( + rg_id, rg_size, self._graph.columns + ): + return rg_size + return None + + def _row_group_admission_diagnostics(self, *, reason: str) -> dict[str, object]: + queue_view = self._fair_queue.view() + task_view = self._task_admission.view() + admitted_rows = self._active_admitted_row_count() + return { + "mode": "adaptive" if self._adaptive_row_group_admission else "fixed", + "reason": reason, + "active_row_groups": len(self._rg_states), + "target_row_groups": self._row_group_admission_target, + "hard_cap": self._row_group_admission_hard_cap, + "admitted_rows": admitted_rows, + "max_admitted_rows": self._adaptive_max_admitted_rows, + "queued_total": queue_view.queued_total, + "queued_llm_wait_demand": queue_view.queued_peer_demand_by_resource.get("llm_wait", 0), + "llm_wait_limit": task_view.resource_limits.get("llm_wait", 0), + "llm_wait_leased": task_view.leased_resources.get("llm_wait", 0), + "llm_wait_available": task_view.resources_available.get("llm_wait", 0), + "blocked_reasons": dict(self._row_group_admission_blocked_reasons), + } async def _admit_row_groups(self) -> None: """Admit row groups as semaphore slots become available.""" + all_admitted = True for rg_id, rg_size in self._row_groups: + await self._wait_for_row_group_admission_capacity(rg_size) + if self._early_shutdown or self._fatal_worker_error is not None: + all_admitted = False + break await self._rg_semaphore.acquire() + if self._early_shutdown or self._fatal_worker_error is not None: + self._rg_semaphore.release() + all_admitted = False + break + if not self._row_group_row_guard_allows(rg_size): + self._rg_semaphore.release() + await self._wait_for_row_group_admission_capacity(rg_size) + await self._rg_semaphore.acquire() + if self._early_shutdown or self._fatal_worker_error is not None: + self._rg_semaphore.release() + all_admitted = False + break self._rg_states[rg_id] = _RowGroupState(size=rg_size) if self._buffer_manager is not None: self._buffer_manager.init_row_group(rg_id, rg_size) await self._dispatch_seeds(rg_id, rg_size) + self._emit_scheduler_event( + "row_group_admitted", + diagnostics=self._row_group_admission_diagnostics(reason="admitted") + | {"row_group": rg_id, "row_group_size": rg_size}, + ) + self._emit_scheduler_health_snapshot("row_group_admitted") self._wake_event.set() - self._all_rgs_admitted = True + self._all_rgs_admitted = all_admitted self._wake_event.set() async def run(self) -> None: @@ -440,6 +1004,9 @@ async def run(self) -> None: if self._reporter: self._reporter.log_start(num_row_groups=num_rgs) + self._emit_scheduler_event("scheduler_job_started", diagnostics=self._scheduler_job_diagnostics()) + self._emit_scheduler_health_snapshot("start") + # Launch admission as a background task so it interleaves with dispatch. admission_task = asyncio.create_task(self._admit_row_groups()) @@ -466,6 +1033,11 @@ async def run(self) -> None: if self._reporter: self._reporter.log_final() + self._emit_scheduler_health_snapshot("completed") + self._emit_scheduler_event( + "scheduler_job_completed", diagnostics=self._scheduler_health_diagnostics(reason="completed") + ) + if self._rg_states: incomplete = list(self._rg_states) logger.error( @@ -481,6 +1053,7 @@ async def _main_dispatch_loop( ) -> None: """Core dispatch loop extracted from ``run()``.""" while True: + self._raise_if_fatal_worker_error() if self._early_shutdown: logger.warning("Early shutdown triggered - non-retryable error rate exceeded threshold") if self._deferred: @@ -496,28 +1069,36 @@ async def _main_dispatch_loop( dispatch_outcome = self._dispatch_queued_tasks() self._checkpoint_completed_row_groups(all_columns) + self._maybe_update_adaptive_row_group_target() # Eagerly salvage any row groups that have only deferred tasks, # even if other row groups are still in-flight. This frees # semaphore slots so admission doesn't lose capacity. if self._deferred: await self._salvage_stalled_row_groups(seed_cols, has_pre_batch, all_columns) + self._maybe_update_adaptive_row_group_target() # Are we done? all_done = self._all_rgs_admitted and not self._rg_states and not self._in_flight if all_done: break + pending_pre_batch = has_pre_batch and any( + state.seeds_dispatched and not state.pre_batch_done for state in self._rg_states.values() + ) if not self._fair_queue.has_queued_tasks and not self._in_flight: - if self._all_rgs_admitted: + if self._all_rgs_admitted and not pending_pre_batch: break + if pending_pre_batch: + continue - if ( - not self._fair_queue.has_queued_tasks - or dispatch_outcome.submission_full - or dispatch_outcome.group_blocked - ): + if not self._fair_queue.has_queued_tasks or dispatch_outcome.admission_blocked: + if self._fair_queue.has_queued_tasks and not dispatch_outcome.dispatched and not self._in_flight: + raise RuntimeError( + "Ready frontier is admission-blocked with no in-flight task to release scheduler capacity." + ) await self._wake_event.wait() + self._raise_if_fatal_worker_error() async def _salvage_rounds( self, @@ -549,34 +1130,10 @@ async def _salvage_rounds( self._dispatched.discard( Task(column=sibling, row_group=task.row_group, row_index=None, task_type="batch") ) - # Acquire stateful lock (mirrors _dispatch_seeds) so - # _execute_seed_task can safely release it in finally. - if gid in self._stateful_locks: - await self._stateful_locks[gid].acquire() - await self._submission_semaphore.acquire() - self._dispatched.add(task) - # Re-register batch alias to mirror _dispatch_seeds and prevent - # duplicate dispatch if the frontier contains a stale batch task. - self._dispatched.add( - Task(column=task.column, row_group=task.row_group, row_index=None, task_type="batch") - ) - # Re-mark sibling columns as dispatched to mirror _dispatch_seeds - # and prevent _drain_frontier from re-dispatching them. - for sibling in self._gen_instance_to_columns.get(gid, []): - if sibling != task.column: - self._dispatched.add( - Task(column=sibling, row_group=task.row_group, row_index=None, task_type="from_scratch") - ) - self._dispatched.add( - Task(column=sibling, row_group=task.row_group, row_index=None, task_type="batch") - ) - self._in_flight.add(task) - if (s := self._rg_states.get(task.row_group)) is not None: - s.in_flight_count += 1 - self._spawn_worker(self._execute_seed_task(task, gid)) + self._apply_frontier_delta(self._tracker.add_ready_tasks((task,))) else: self._dispatched.discard(task) - self._enqueue_ready_task(task) + self._apply_frontier_delta(self._tracker.add_ready_tasks((task,))) # Drain: dispatch frontier tasks and any newly-ready downstream tasks # until nothing remains in-flight or in the frontier. await self._drain_frontier(seed_cols, has_pre_batch) @@ -585,6 +1142,7 @@ async def _salvage_rounds( async def _drain_frontier(self, seed_cols: tuple[str, ...], has_pre_batch: bool) -> None: """Dispatch all frontier tasks and their downstream until quiescent.""" while True: + self._raise_if_fatal_worker_error() if has_pre_batch: self._run_seeds_complete_check(seed_cols) dispatch_outcome = self._dispatch_queued_tasks() @@ -599,6 +1157,7 @@ async def _drain_frontier(self, seed_cols: tuple[str, ...], has_pre_batch: bool) continue self._wake_event.clear() await self._wake_event.wait() + self._raise_if_fatal_worker_error() async def _salvage_stalled_row_groups( self, @@ -657,6 +1216,9 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: if self._tracker.is_row_group_complete(rg_id, state.size, all_columns) ] for rg_id, rg_size in completed: + dropped_rows = sum(1 for ri in range(rg_size) if self._tracker.is_dropped(rg_id, ri)) + checkpointed = False + checkpoint_result = "unknown" try: if self._on_before_checkpoint: try: @@ -670,17 +1232,36 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: # Remove from tracking only after the callback succeeds. del self._rg_states[rg_id] # If all rows were dropped (e.g. seed failure), free instead of finalizing - if all(self._tracker.is_dropped(rg_id, ri) for ri in range(rg_size)): + if dropped_rows == rg_size: if self._buffer_manager: self._buffer_manager.free_row_group(rg_id) + checkpoint_result = "all_rows_dropped" elif self._on_finalize_row_group is not None: self._on_finalize_row_group(rg_id) + checkpoint_result = "finalized" + else: + checkpoint_result = "completed" + checkpointed = True except DatasetGenerationError: raise except Exception: logger.error(f"Failed to checkpoint row group {rg_id}.", exc_info=True) finally: self._rg_semaphore.release() + self._row_group_admission_event.set() + if checkpointed: + self._emit_scheduler_event( + "row_group_checkpointed", + diagnostics={ + "row_group": rg_id, + "row_group_size": rg_size, + "dropped_rows": dropped_rows, + "surviving_rows": rg_size - dropped_rows, + "result": checkpoint_result, + "active_row_groups": len(self._rg_states), + }, + ) + self._emit_scheduler_health_snapshot("row_group_checkpointed") # Clean up deferred tasks for checkpointed row groups if completed: @@ -803,7 +1384,7 @@ def _check_error_rate(self, *, success: bool) -> None: self._early_shutdown = True def _record_retryable_outcome(self, *, retryable: bool) -> None: - """Track retryable-error rate and emit a throttled WARN under provider degradation. + """Track retryable-error rate and emit a rate-limited WARN under provider degradation. Distinct from ``_check_error_rate``: every LLM-bound task outcome (success or failure) feeds this window so the rate reflects the provider's overall @@ -832,7 +1413,7 @@ def _record_retryable_outcome(self, *, retryable: bool) -> None: ) async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: - """Dispatch from_scratch tasks for a row group.""" + """Make from-scratch/root tasks ready for a row group.""" self._rg_states[rg_id].seeds_dispatched = True seed_cols = self._seed_cols if not seed_cols: @@ -841,6 +1422,7 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: width = len(str(num_rgs)) logger.info(f"🚀 ({rg_id + 1:0{width}d}/{num_rgs}) Dispatching with {rg_size} records") seen_instances: set[int] = set() + root_columns: list[str] = [] for col in seed_cols: gen = self._generators[col] @@ -848,64 +1430,38 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: if gid in seen_instances: continue seen_instances.add(gid) + root_columns.append(col) - task = Task(column=col, row_group=rg_id, row_index=None, task_type="from_scratch") - # Also mark the "batch" variant as dispatched to prevent duplicate - # scheduling for this column. - batch_alias = Task(column=col, row_group=rg_id, row_index=None, task_type="batch") - if task in self._dispatched or batch_alias in self._dispatched: - continue - - # Seeds bypass fair-queue admission while row groups are being admitted; - # direct dispatch preserves stateful lock ordering across row groups. - # Acquire stateful lock *before* submission semaphore to preserve - # row-group ordering. Held until generation completes (_execute_seed_task). - if gid in self._stateful_locks: - await self._stateful_locks[gid].acquire() - - await self._submission_semaphore.acquire() - self._dispatched.add(task) - self._dispatched.add(batch_alias) - # Also mark all sibling output columns as dispatched (multi-column dedup) - for sibling_col in self._gen_instance_to_columns.get(gid, []): - if sibling_col != col: - self._dispatched.add( - Task(column=sibling_col, row_group=rg_id, row_index=None, task_type="from_scratch") - ) - self._dispatched.add(Task(column=sibling_col, row_group=rg_id, row_index=None, task_type="batch")) - self._in_flight.add(task) - if (s := self._rg_states.get(task.row_group)) is not None: - s.in_flight_count += 1 - self._spawn_worker(self._execute_seed_task(task, gid)) - - async def _execute_seed_task(self, task: Task, generator_id: int) -> None: - """Execute a from_scratch task and release stateful lock if held.""" - try: - await self._execute_task_inner(task) - finally: - if generator_id in self._stateful_locks: - self._stateful_locks[generator_id].release() + self._apply_frontier_delta(self._tracker.add_root_tasks(rg_id, rg_size, columns=tuple(root_columns))) - async def _execute_task(self, task: Task) -> None: + async def _execute_task(self, task: Task, lease: TaskAdmissionLease, task_execution_id: str) -> None: """Execute a single task (cell or batch).""" - await self._execute_task_inner(task) - - async def _execute_task_inner(self, task: Task) -> None: - """Core task execution logic. + await self._execute_task_inner(task, lease, task_execution_id) - For LLM-bound tasks, uses a one-way semaphore handoff: acquires the - LLM-wait slot while still holding the submission slot, then releases - the submission slot (never reacquired). This prevents cross-key - starvation while bounding live coroutines. - """ + async def _execute_task_inner(self, task: Task, lease: TaskAdmissionLease, task_execution_id: str) -> None: + """Core task execution logic.""" num_rgs = len(self._row_groups) token = current_row_group.set((task.row_group, num_rgs)) + group = lease.item.group + identity_hash = hashlib.sha1("\0".join(group.key.identity).encode()).hexdigest()[:16] + correlation_token = runtime_correlation_provider.set( + RuntimeCorrelation( + run_id=self._run_id, + row_group=task.row_group, + task_column=task.column, + task_type=task.task_type, + scheduling_group_kind=group.key.kind, + scheduling_group_identity_hash=identity_hash, + task_execution_id=task_execution_id, + ) + ) try: - await self._execute_task_inner_impl(task) + await self._execute_task_inner_impl(task, lease, task_execution_id) finally: + runtime_correlation_provider.reset(correlation_token) current_row_group.reset(token) - async def _execute_task_inner_impl(self, task: Task) -> None: + async def _execute_task_inner_impl(self, task: Task, lease: TaskAdmissionLease, task_execution_id: str) -> None: trace: TaskTrace | None = None if self._trace: trace = TaskTrace.from_task(task) @@ -914,12 +1470,12 @@ async def _execute_task_inner_impl(self, task: Task) -> None: generator = self._generators[task.column] output_cols = self._gen_instance_to_columns.get(id(generator), [task.column]) retryable = False + cancelled = False # When True, skip removing from _dispatched so the task isn't re-dispatched # from the frontier (it was never completed, so it stays in the frontier). skipped = False - is_llm = self._llm_bound_lookup.get(task.column, False) - holds_submission = True - holds_llm_wait = False + uses_model_stage_resource = "llm_wait" in lease.resources + stateful_lock_acquired = False try: # Skip tasks whose row group was already checkpointed (can happen @@ -929,11 +1485,9 @@ async def _execute_task_inner_impl(self, task: Task) -> None: skipped = True return - if is_llm: - await self._llm_wait_semaphore.acquire() - holds_llm_wait = True - self._submission_semaphore.release() - holds_submission = False + if task.task_type == "from_scratch" and id(generator) in self._stateful_locks: + await self._stateful_locks[id(generator)].acquire() + stateful_lock_acquired = True if self._trace and trace: trace.slot_acquired_at = time.perf_counter() @@ -962,7 +1516,7 @@ async def _execute_task_inner_impl(self, task: Task) -> None: # window from LLM-bound tasks so a healthy non-model task mix # (samplers, expressions, non-LLM customs) doesn't dilute the # rate and silence the WARN under genuine provider stress. - if is_llm: + if uses_model_stage_resource: self._record_retryable_outcome(retryable=False) if self._reporter: if cell_skipped: @@ -972,6 +1526,13 @@ async def _execute_task_inner_impl(self, task: Task) -> None: if self._trace and trace: trace.status = "ok" + except asyncio.CancelledError: + cancelled = True + if self._trace and trace: + trace.status = "cancelled" + self._emit_scheduler_event("cancelled", task=task, lease=lease, task_execution_id=task_execution_id) + raise + except Exception as exc: retryable = self._is_retryable(exc) # Only non-retryable errors (auth, schema, code bugs) count toward @@ -980,7 +1541,7 @@ async def _execute_task_inner_impl(self, task: Task) -> None: # and would otherwise trip the gate even when salvage could recover. if not retryable: self._check_error_rate(success=False) - if is_llm: + if uses_model_stage_resource: self._record_retryable_outcome(retryable=retryable) if not retryable and self._reporter: self._reporter.record_failure(task.column) @@ -990,21 +1551,41 @@ async def _execute_task_inner_impl(self, task: Task) -> None: if retryable: self._deferred.append(task) + self._emit_scheduler_event( + "retry_deferred", task=task, lease=lease, task_execution_id=task_execution_id + ) else: # Capture the first non-retryable error for the interface to surface # as the root cause when the run produces 0 records (e.g. deterministic # seed failures). Subsequent failures are still logged below. if self._first_non_retryable_error is None: self._first_non_retryable_error = exc - # Non-retryable: drop the affected row(s) + log_message = ( + f"Non-retryable failure on {task.column}[rg={task.row_group}, row={task.row_index}]: {exc}" + ) + if self._is_expected_non_retryable(exc): + logger.warning(log_message) + elif self._is_internal_bug(exc): + logger.error("Unexpected fatal %s", log_message, exc_info=True) + self._fatal_worker_error = exc + self._wake_event.set() + raise + else: + logger.error("Unexpected %s", log_message, exc_info=True) + # Non-retryable data/user/provider failures drop the affected row(s); + # internal bug-shaped failures above abort the run instead. if task.row_index is not None: self._drop_row(task.row_group, task.row_index, exclude_columns={task.column}) else: # Batch/from_scratch failure: drop all rows in the row group rg_size = self._get_rg_size(task.row_group) self._drop_row_group(task.row_group, rg_size, exclude_columns={task.column}) - logger.warning( - f"Non-retryable failure on {task.column}[rg={task.row_group}, row={task.row_index}]: {exc}" + self._emit_scheduler_event( + "non_retryable_dropped", + task=task, + lease=lease, + task_execution_id=task_execution_id, + diagnostics={"error_type": type(exc).__name__}, ) finally: @@ -1012,16 +1593,38 @@ async def _execute_task_inner_impl(self, task: Task) -> None: trace.completed_at = time.perf_counter() self.traces.append(trace) - self._fair_queue.release(task) + self._tracker.mark_complete(task) + if not cancelled: + self._emit_scheduler_event( + "task_completed", + task=task, + lease=lease, + task_execution_id=task_execution_id, + ) self._in_flight.discard(task) if (s := self._rg_states.get(task.row_group)) is not None: s.in_flight_count = max(0, s.in_flight_count - 1) if not retryable and not skipped: self._dispatched.discard(task) - if holds_llm_wait: - self._llm_wait_semaphore.release() - if holds_submission: - self._submission_semaphore.release() + if stateful_lock_acquired: + self._stateful_locks[id(generator)].release() + release_result = self._task_admission.release(lease) + self._emit_scheduler_event( + "task_lease_released", + task=task, + lease=lease, + task_execution_id=task_execution_id, + reason_or_result=release_result.reason, + ) + if not release_result.released: + self._emit_scheduler_event( + "release_diagnostic", + task=task, + lease=lease, + task_execution_id=task_execution_id, + reason_or_result=release_result.reason, + ) + self._record_observed_task_state() self._wake_event.set() async def _run_from_scratch(self, task: Task, generator: ColumnGenerator) -> Any: @@ -1176,18 +1779,154 @@ def _get_rg_size(self, row_group: int) -> int: except KeyError: raise ValueError(f"Unknown row group: {row_group}") from None - def get_semaphore_permits(self) -> tuple[int, int]: - """Return ``(submission_available, llm_wait_available)`` for diagnostics.""" - return ( - self._submission_semaphore.available_permits, - self._llm_wait_semaphore.available_permits, + def task_admission_snapshot(self) -> object: + """Return the current scheduler task-admission snapshot for diagnostics.""" + return self._task_admission.view() + + def capacity_plan(self) -> AsyncCapacityPlan: + """Return the scheduler-side async capacity explanation for this run.""" + task_view = self._task_admission.view() + request_snapshots = ( + dict(self._request_pressure_provider.snapshots()) if self._request_pressure_provider is not None else {} + ) + provider_snapshots = ( + dict(self._request_pressure_provider.global_snapshots()) + if self._request_pressure_provider is not None + else {} + ) + request_resources = tuple(sorted(request_snapshots)) + provider_model_static_caps = { + provider_model: ProviderModelStaticCap( + cap=snapshot.static_cap, + aliases=snapshot.aliases, + raw_caps=snapshot.raw_caps, + ) + for provider_model, snapshot in provider_snapshots.items() + } + request_config = self._request_pressure_provider.config if self._request_pressure_provider is not None else None + request_config_snapshot = ( + RequestAdmissionConfigSnapshot.from_config(request_config) + if isinstance(request_config, RequestAdmissionConfig) + else None + ) + request_domain_initial_limits: dict[RequestResourceKey, int] = {} + if request_config_snapshot is not None: + request_domain_initial_limits.update(request_config_snapshot.initial_limits) + for resource, snapshot in request_snapshots.items(): + configured_initial = ( + request_config_snapshot.initial_limits.get(resource) if request_config_snapshot is not None else None + ) + request_domain_initial_limits[resource] = ( + max(1, min(configured_initial, snapshot.effective_max)) + if configured_initial is not None + else snapshot.effective_max + ) + request_domain_current_limits = { + resource: snapshot.current_limit for resource, snapshot in request_snapshots.items() + } + request_domain_effective_max = { + resource: snapshot.effective_max for resource, snapshot in request_snapshots.items() + } + request_domain_blocked_until = { + resource: snapshot.blocked_until_monotonic for resource, snapshot in request_snapshots.items() + } + provider_model_aggregate_in_flight = { + provider_model: snapshot.aggregate_in_flight for provider_model, snapshot in provider_snapshots.items() + } + return AsyncCapacityPlan( + configured=AsyncCapacityConfigured( + buffer_size=CapacityValue(value=self._buffer_size, source="run_config"), + row_group_admission=RowGroupAdmission( + row_group_concurrency=CapacityValue( + value=self._max_concurrent_row_groups, + source="dataset_builder", + ), + observed_in_flight=len(self._rg_states), + mode="adaptive" if self._adaptive_row_group_admission else "fixed", + target_in_flight=self._row_group_admission_target, + observed_max_target=self._observed_max_row_group_admission_target, + max_admitted_rows=self._adaptive_max_admitted_rows, + blocked_reasons=dict(self._row_group_admission_blocked_reasons), + ), + submission_capacity=CapacityValue(value=self._max_submitted_tasks, source="dataset_builder"), + task_resource_limits=CapacityValue( + value=dict(self._task_admission_config.resource_limits), + source="engine_internal_config", + ), + request_resources=CapacityValue( + value=request_resources, + source="runtime_snapshot", + missing_reason=None if request_resources else "request admission has not observed any resources", + ), + provider_model_static_caps=CapacityValue( + value=provider_model_static_caps, + source="model_metadata", + missing_reason=None if provider_model_static_caps else "request admission has no registered models", + ), + request_domain_initial_limits=CapacityValue( + value=request_domain_initial_limits, + source="engine_internal_config" if request_config_snapshot is not None else "runtime_snapshot", + missing_reason=None + if request_domain_initial_limits + else "request admission has not observed any domain limits", + ), + request_admission_config=CapacityValue( + value=request_config_snapshot, + source="engine_internal_config", + missing_reason=None + if request_config_snapshot is not None + else "request admission config is not exposed by the pressure provider", + ), + transport_pool_limits=CapacityValue( + value={}, + source="adapter_config", + missing_reason="transport pool utilization is adapter-specific", + ), + ), + runtime_snapshot=AsyncCapacityRuntimeSnapshot( + request_domain_current_limits=request_domain_current_limits, + request_domain_effective_max=request_domain_effective_max, + request_domain_blocked_until=request_domain_blocked_until, + provider_model_aggregate_in_flight=provider_model_aggregate_in_flight, + ), + observed_maxima=AsyncCapacityObservedMaxima( + row_groups_in_flight=self._observed_max_row_groups_in_flight, + queued_tasks_by_group=dict(self._observed_max_queued_by_group), + task_leases_by_resource=dict(self._observed_max_task_leases_by_resource or task_view.leased_resources), + request_waiters_by_resource=dict( + self._observed_max_request_waiters_by_resource + or {resource: snapshot.waiters for resource, snapshot in request_snapshots.items()} + ), + request_in_flight_by_resource=dict( + self._observed_max_request_in_flight_by_resource + or {resource: snapshot.in_flight_count for resource, snapshot in request_snapshots.items()} + ), + provider_model_aggregate_in_flight=dict( + self._observed_max_provider_model_aggregate_in_flight or provider_model_aggregate_in_flight + ), + request_domain_current_limits=dict( + self._observed_max_request_domain_current_limits or request_domain_current_limits + ), + transport_pool_utilization=None, + ), ) @staticmethod - def _is_retryable(exc: Exception) -> bool: + def _is_retryable(exc: BaseException) -> bool: """Classify whether an exception is retryable.""" return isinstance(exc, RETRYABLE_MODEL_ERRORS) + @staticmethod + def _is_expected_non_retryable(exc: BaseException) -> bool: + return isinstance( + exc, + ( + DataDesignerError, + DatasetGenerationError, + GenerationValidationFailureError, + ProviderError, + ), + ) -def build_llm_bound_lookup(generators: dict[str, ColumnGenerator]) -> dict[str, bool]: - return {col: gen.is_llm_bound for col, gen in generators.items()} + def _is_internal_bug(self, exc: BaseException) -> bool: + return isinstance(exc, INTERNAL_BUG_EXCEPTIONS) and not self._is_expected_non_retryable(exc) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index b820c95aa..93ca2ea25 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -70,7 +70,7 @@ from data_designer.config.run_config import RunConfig from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry - from data_designer.engine.dataset_builders.utils.task_model import TaskTrace + from data_designer.engine.dataset_builders.scheduling.task_model import TaskTrace from data_designer.engine.models.usage import ModelUsageStats logger = logging.getLogger(__name__) @@ -85,14 +85,14 @@ from data_designer.engine.dataset_builders.async_scheduler import ( DEFAULT_TASK_POOL_SIZE, - GLOBAL_LLM_WAIT_POOL_HEADROOM_MULTIPLIER, + MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER, AsyncTaskScheduler, ) + from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta from data_designer.engine.dataset_builders.utils.async_concurrency import ( AsyncConcurrentExecutor, ensure_async_engine_loop, ) - from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker, FrontierDelta from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager @@ -1015,9 +1015,9 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: df = self._processor_runner.run_post_batch(df, current_batch_number=rg_id, strict_row_count=True) buffer_manager.replace_dataframe(rg_id, df) - # Coarse upper bound: sums all registered aliases, not just those used - # in this build. Oversizing is harmless - ThrottleManager enforces - # the real per-key limit; the semaphore is a memory-safety cap. + # Coarse upper bound used only for scheduler task-stage model admission. + # Concrete provider/model request capacity is enforced by request admission + # at the model-call boundary. aggregate = self._resource_provider.model_registry.get_aggregate_max_parallel_requests() scheduler = AsyncTaskScheduler( @@ -1027,7 +1027,7 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: row_groups=row_groups, buffer_manager=buffer_manager, max_submitted_tasks=DEFAULT_TASK_POOL_SIZE, - max_llm_wait_tasks=max(DEFAULT_TASK_POOL_SIZE, GLOBAL_LLM_WAIT_POOL_HEADROOM_MULTIPLIER * aggregate), + max_model_task_admission=max(DEFAULT_TASK_POOL_SIZE, MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER * aggregate), on_finalize_row_group=on_finalize_row_group, on_seeds_complete=( on_seeds_complete if self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH) else None @@ -1045,6 +1045,8 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: buffer_size=buffer_size, progress_interval=self._resource_provider.run_config.progress_interval, progress_bar=self._resource_provider.run_config.progress_bar, + request_pressure_provider=self._resource_provider.model_registry.request_admission, + request_pressure_advisory=True, ) return scheduler, buffer_manager diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py similarity index 83% rename from packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py rename to packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py index 2d35ec0be..b34ffe69a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py @@ -8,7 +8,8 @@ from typing import TYPE_CHECKING from data_designer.config.column_configs import GenerationStrategy -from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task +from data_designer.engine.dataset_builders.scheduling.resources import stable_task_id +from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef, Task if TYPE_CHECKING: from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph @@ -147,14 +148,32 @@ def is_row_group_complete( return False return True - def get_ready_tasks(self, dispatched: set[Task], admitted_rgs: set[int] | None = None) -> list[Task]: - """Return all currently dispatchable tasks from the frontier. + def ready_frontier(self) -> tuple[Task, ...]: + """Return dependency-ready tasks not yet acknowledged as enqueued.""" + return tuple(self._frontier) - Excludes already-dispatched/in-flight tasks and tasks for row groups - not yet admitted (if ``admitted_rgs`` is provided). - """ + def mark_enqueued(self, task_ids: set[str] | list[str] | tuple[str, ...]) -> None: + """Acknowledge tasks accepted by the ready queue.""" + wanted = set(task_ids) + self._frontier = {task for task in self._frontier if stable_task_id(task) not in wanted} + + def mark_complete(self, task: Task) -> None: + """Compatibility hook for scheduler terminal accounting.""" + + def add_ready_tasks(self, tasks: list[Task] | tuple[Task, ...]) -> FrontierDelta: + """Add ready tasks to the frontier idempotently.""" + added: list[Task] = [] + for task in tasks: + if self._add_frontier_task(task): + added.append(task) + return self._record_delta(added=added, removed=[]) + + def get_ready_tasks(self, dispatched: set[Task], admitted_rgs: set[int] | None = None) -> list[Task]: + """Return all currently dispatchable tasks from the frontier.""" return [ - t for t in self._frontier if t not in dispatched and (admitted_rgs is None or t.row_group in admitted_rgs) + t + for t in self.ready_frontier() + if t not in dispatched and (admitted_rgs is None or t.row_group in admitted_rgs) ] def is_frontier_task(self, task: Task) -> bool: @@ -171,13 +190,36 @@ def seed_frontier(self) -> None: if self._graph is None: raise RuntimeError("This method requires a graph to be set.") for col in self._graph.get_root_columns(): - strategy = self._graph.get_strategy(col) for rg_id, rg_size in self._row_group_sizes.items(): - if strategy == GenerationStrategy.CELL_BY_CELL: - for ri in range(rg_size): - self._frontier.add(Task(column=col, row_group=rg_id, row_index=ri, task_type="cell")) - else: - self._frontier.add(Task(column=col, row_group=rg_id, row_index=None, task_type="batch")) + self.add_root_tasks(rg_id, rg_size, columns=(col,)) + + def add_root_tasks( + self, + row_group: int, + row_group_size: int, + *, + columns: tuple[str, ...] | None = None, + ) -> FrontierDelta: + """Add root/from-scratch tasks for one admitted row group.""" + if self._graph is None: + raise RuntimeError("This method requires a graph to be set.") + expected = self._validate_row_group(row_group) + if expected is not None and expected != row_group_size: + raise ValueError(f"Row-group size mismatch for rg={row_group}: got {row_group_size}, expected {expected}") + root_columns = columns or tuple(self._graph.get_root_columns()) + added: list[Task] = [] + for col in root_columns: + strategy = self._graph.get_strategy(col) + if strategy == GenerationStrategy.CELL_BY_CELL: + for ri in range(row_group_size): + task = Task(column=col, row_group=row_group, row_index=ri, task_type="cell") + if self._add_frontier_task(task): + added.append(task) + else: + task = Task(column=col, row_group=row_group, row_index=None, task_type="from_scratch") + if self._add_frontier_task(task): + added.append(task) + return self._record_delta(added=added, removed=[]) def _record_delta(self, *, added: list[Task], removed: list[Task]) -> FrontierDelta: return FrontierDelta(added=tuple(added), removed=tuple(removed)) @@ -204,7 +246,7 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None rg_batch_complete = self._batch_complete.get(row_group, set()) rg_size = self._row_group_sizes[row_group] - for down in self._graph.get_downstream_columns(column): + for down in sorted(self._graph.get_downstream_columns(column)): batch_ups, cell_ups = self._graph.split_upstream_by_strategy(down) if any(up not in rg_batch_complete for up in batch_ups): diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/queue.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/queue.py new file mode 100644 index 000000000..2cdd99b36 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/queue.py @@ -0,0 +1,202 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import heapq +from collections import Counter, defaultdict, deque +from collections.abc import Callable, Iterable, Mapping +from dataclasses import dataclass + +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceKey, + TaskGroupKey, + TaskGroupSpec, +) + + +@dataclass(frozen=True) +class QueueView: + """Read-only queue facts supplied to task admission policies.""" + + queued_total: int + queued_by_group: Mapping[TaskGroupKey, int] + queued_resource_demand_by_group: Mapping[TaskGroupKey, Mapping[SchedulerResourceKey, int]] + first_candidate_resources_by_group: Mapping[TaskGroupKey, Mapping[SchedulerResourceKey, int]] + first_candidate_tasks_by_group: Mapping[TaskGroupKey, SchedulableTask] + first_candidate_group_specs_by_group: Mapping[TaskGroupKey, TaskGroupSpec] + queued_peer_demand_by_resource: Mapping[SchedulerResourceKey, int] + + +@dataclass(frozen=True) +class QueueSelection: + """Non-mutating fair-queue selection returned to the scheduler.""" + + item: SchedulableTask + queue_view: QueueView + sequence_version: int + + +class FairTaskQueue: + """Virtual-time fair queue that owns ready membership and ordering only.""" + + def __init__(self) -> None: + self._queues: dict[TaskGroupKey, deque[SchedulableTask]] = {} + self._queued: dict[str, SchedulableTask] = {} + self._task_groups: dict[str, TaskGroupKey] = {} + self._group_specs: dict[TaskGroupKey, TaskGroupSpec] = {} + self._group_finish: dict[TaskGroupKey, float] = {} + self._heap: list[tuple[float, int, TaskGroupKey]] = [] + self._active_heap_keys: set[TaskGroupKey] = set() + self._active_heap_entries: dict[TaskGroupKey, tuple[float, int]] = {} + self._sequence = 0 + self._sequence_version = 0 + self._virtual_time = 0.0 + + @property + def has_queued_tasks(self) -> bool: + return bool(self._queued) + + def enqueue(self, items: Iterable[SchedulableTask]) -> tuple[str, ...]: + """Add ready tasks idempotently and return newly accepted task ids.""" + accepted: list[str] = [] + for item in items: + if item.task_id in self._queued: + continue + self._group_specs[item.group.key] = item.group + queue = self._queues.setdefault(item.group.key, deque()) + queue.append(item) + self._queued[item.task_id] = item + self._task_groups[item.task_id] = item.group.key + self._activate_group(item.group.key) + accepted.append(item.task_id) + if accepted: + self._sequence_version += 1 + return tuple(accepted) + + def discard(self, task_id: str) -> None: + """Remove a queued task lazily if it is no longer dispatchable.""" + if task_id in self._queued: + self._sequence_version += 1 + self._queued.pop(task_id, None) + self._task_groups.pop(task_id, None) + + def discard_where(self, predicate: Callable[[SchedulableTask], bool]) -> None: + """Remove queued tasks matching a predicate.""" + for task_id, item in tuple(self._queued.items()): + if predicate(item): + self.discard(task_id) + + def select_next(self, is_eligible: Callable[[SchedulableTask, QueueView], bool]) -> QueueSelection | None: + """Return the next eligible task without mutating queue state.""" + view = self.view() + heap_copy = list(self._heap) + heapq.heapify(heap_copy) + active_seen: set[TaskGroupKey] = set() + while heap_copy: + finish, sequence, key = heapq.heappop(heap_copy) + if key in active_seen: + continue + if self._active_heap_entries.get(key) != (finish, sequence): + continue + active_seen.add(key) + item = self._first_valid_item(key) + if item is None: + continue + if not is_eligible(item, view): + continue + return QueueSelection(item=item, queue_view=view, sequence_version=self._sequence_version) + return None + + def commit(self, selection: QueueSelection) -> SchedulableTask | None: + """Remove a previously selected task and advance fair-queue state.""" + if selection.sequence_version != self._sequence_version: + return None + item = selection.item + key = self._task_groups.get(item.task_id) + if key is None or key != item.group.key: + return None + queue = self._queues.get(key) + if queue is None: + return None + self._purge_queue_head(key) + if not queue or queue[0].task_id != item.task_id: + return None + + queue.popleft() + self._queued.pop(item.task_id, None) + self._task_groups.pop(item.task_id, None) + self._active_heap_keys.discard(key) + self._active_heap_entries.pop(key, None) + group = self._group_specs[key] + finish = self._group_finish.get(key, self._virtual_time) + self._virtual_time = max(self._virtual_time, finish) + self._group_finish[key] = self._virtual_time + (1.0 / max(group.weight, 1.0)) + self._sequence_version += 1 + self._purge_queue_head(key) + if queue: + self._activate_group(key) + return item + + def view(self) -> QueueView: + queued_by_group: Counter[TaskGroupKey] = Counter() + demand_by_group: dict[TaskGroupKey, dict[SchedulerResourceKey, int]] = defaultdict(lambda: defaultdict(int)) + first_by_group: dict[TaskGroupKey, Mapping[SchedulerResourceKey, int]] = {} + first_tasks_by_group: dict[TaskGroupKey, SchedulableTask] = {} + first_group_specs: dict[TaskGroupKey, TaskGroupSpec] = {} + demand_by_resource: Counter[SchedulerResourceKey] = Counter() + + for item in self._queued.values(): + key = item.group.key + queued_by_group[key] += 1 + for resource, amount in item.resource_request.amounts.items(): + demand_by_group[key][resource] += amount + demand_by_resource[resource] += amount + + for key, queue in self._queues.items(): + first = self._first_valid_item(key) + if first is not None: + first_by_group[key] = dict(first.resource_request.amounts) + first_tasks_by_group[key] = first + first_group_specs[key] = first.group + + return QueueView( + queued_total=len(self._queued), + queued_by_group=dict(queued_by_group), + queued_resource_demand_by_group={key: dict(value) for key, value in demand_by_group.items()}, + first_candidate_resources_by_group=first_by_group, + first_candidate_tasks_by_group=first_tasks_by_group, + first_candidate_group_specs_by_group=first_group_specs, + queued_peer_demand_by_resource=dict(demand_by_resource), + ) + + def _activate_group(self, key: TaskGroupKey) -> None: + self._purge_queue_head(key) + queue = self._queues.get(key) + if not queue or key in self._active_heap_keys: + return + self._sequence += 1 + finish = self._group_finish.get(key, self._virtual_time) + heapq.heappush(self._heap, (finish, self._sequence, key)) + self._active_heap_keys.add(key) + self._active_heap_entries[key] = (finish, self._sequence) + + def _first_valid_item(self, key: TaskGroupKey) -> SchedulableTask | None: + queue = self._queues.get(key) + if queue is None: + return None + for item in queue: + if item.task_id in self._queued and self._task_groups.get(item.task_id) == key: + return item + return None + + def _purge_queue_head(self, key: TaskGroupKey) -> None: + queue = self._queues.get(key) + if queue is None: + return + while queue: + item = queue[0] + if item.task_id in self._queued and self._task_groups.get(item.task_id) == key: + break + queue.popleft() diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resolver.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resolver.py new file mode 100644 index 000000000..c2f61e1e1 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resolver.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from data_designer.config.scheduling import SchedulingMetadata, SchedulingMetadataError +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceRequest, + TaskGroupKey, + TaskGroupSpec, + stable_task_id, +) +from data_designer.engine.dataset_builders.scheduling.task_model import Task +from data_designer.engine.models.request_admission.resources import RequestDomain, RequestResourceKey + +if TYPE_CHECKING: + from data_designer.engine.column_generators.generators.base import ColumnGenerator + + +@dataclass(frozen=True) +class ResolvedTaskScheduling: + """Scheduler inputs resolved from generator-facing metadata.""" + + group: TaskGroupSpec + resource_request: SchedulerResourceRequest + request_resource_key: RequestResourceKey | None = None + + +class TaskSchedulingResolver: + """Resolve generator metadata into scheduler-internal task inputs.""" + + def __init__( + self, + generators: Mapping[str, ColumnGenerator], + *, + model_group_limit_multiplier: int = 2, + model_group_limit_cap: int = 256, + ) -> None: + self._generators = generators + self._model_group_limit_multiplier = model_group_limit_multiplier + self._model_group_limit_cap = model_group_limit_cap + self._metadata_by_generator_id: dict[int, SchedulingMetadata] = {} + self._diagnostics: list[dict[str, object]] = [] + for generator in dict.fromkeys(generators.values()): + self._metadata_by_generator_id[id(generator)] = self._resolve_metadata(generator) + + @property + def diagnostics(self) -> tuple[dict[str, object], ...]: + return tuple(self._diagnostics) + + def scheduling_for_task(self, task: Task, flow_identity: tuple[str, ...]) -> ResolvedTaskScheduling: + generator = self._generators[task.column] + metadata = self._metadata_by_generator_id[id(generator)] + return self._resolved_from_metadata(metadata, flow_identity) + + def schedulable_task(self, task: Task, flow_identity: tuple[str, ...]) -> SchedulableTask: + resolved = self.scheduling_for_task(task, flow_identity) + return SchedulableTask( + task_id=stable_task_id(task), + payload=task, + group=resolved.group, + resource_request=resolved.resource_request, + request_resource_key=resolved.request_resource_key, + ) + + def _resolve_metadata(self, generator: ColumnGenerator) -> SchedulingMetadata: + try: + return generator.get_scheduling_metadata() + except SchedulingMetadataError as exc: + if exc.fallback is None: + raise + self._diagnostics.append( + { + "code": exc.code, + "message": exc.message, + "fallback": exc.fallback.identity, + "diagnostics": exc.diagnostics, + } + ) + return exc.fallback + + def _resolved_from_metadata( + self, + metadata: SchedulingMetadata, + flow_identity: tuple[str, ...], + ) -> ResolvedTaskScheduling: + weight = max(1, metadata.weight) + if metadata.kind == "local": + key = TaskGroupKey(kind="local", identity=(*metadata.identity, *flow_identity)) + return ResolvedTaskScheduling( + group=TaskGroupSpec(key=key, weight=float(weight)), + resource_request=SchedulerResourceRequest({"submission": 1}), + ) + + identity = (*metadata.identity, *flow_identity) + admitted_limit = max(1, min(self._model_group_limit_cap, self._model_group_limit_multiplier * weight)) + request_resource_key = _request_resource_key(metadata) + return ResolvedTaskScheduling( + group=TaskGroupSpec( + key=TaskGroupKey(kind=metadata.kind, identity=identity), + weight=float(weight), + admitted_limit=admitted_limit, + ), + resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 1}), + request_resource_key=request_resource_key, + ) + + +def _request_resource_key(metadata: SchedulingMetadata) -> RequestResourceKey | None: + if metadata.kind != "model": + return None + _kind, provider_name, model_id, generation_kind = metadata.identity + try: + domain = RequestDomain(generation_kind) + except ValueError: + return None + return RequestResourceKey(provider_name=provider_name, model_id=model_id, domain=domain) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resources.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resources.py new file mode 100644 index 000000000..35a0ec18f --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resources.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import hashlib +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Literal + +from data_designer.engine.dataset_builders.scheduling.task_model import Task +from data_designer.engine.models.request_admission.resources import RequestResourceKey + +SchedulerResourceKey = Literal["submission", "llm_wait", "local"] + + +@dataclass(frozen=True, order=True) +class TaskGroupKey: + """Stable identity for a stream of related scheduler tasks.""" + + kind: Literal["model", "custom_model", "local"] + identity: tuple[str, ...] + + +@dataclass(frozen=True) +class TaskGroupSpec: + """Scheduler-internal task group metadata.""" + + key: TaskGroupKey + weight: float = 1.0 + admitted_limit: int | None = None + + +@dataclass(frozen=True) +class SchedulerResourceRequest: + """Scheduler task-stage resource request.""" + + amounts: Mapping[SchedulerResourceKey, int] = field(default_factory=lambda: {"submission": 1}) + + def __post_init__(self) -> None: + for resource, amount in self.amounts.items(): + if resource not in {"submission", "llm_wait", "local"}: + raise ValueError(f"Unknown scheduler resource key: {resource!r}") + if not isinstance(amount, int) or amount <= 0: + raise ValueError(f"Scheduler resource amount for {resource!r} must be a positive integer.") + + +@dataclass(frozen=True) +class SchedulableTask: + """Ready task plus scheduler-owned grouping and resource request.""" + + task_id: str + payload: Task + group: TaskGroupSpec + resource_request: SchedulerResourceRequest + request_resource_key: RequestResourceKey | None = None + + +def stable_task_id(task: Task) -> str: + """Return a stable scheduler task id for queue/admission membership.""" + raw = f"{task.column}\0{task.row_group}\0{task.row_index}\0{task.task_type}".encode() + digest = hashlib.sha1(raw).hexdigest()[:16] + return f"task-{digest}" diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_admission.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_admission.py new file mode 100644 index 000000000..89fb3e280 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_admission.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import time +import uuid +from collections import Counter, defaultdict, deque +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Literal + +from data_designer.engine.dataset_builders.scheduling.queue import QueueView +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceKey, + SchedulerResourceRequest, + TaskGroupKey, +) +from data_designer.engine.dataset_builders.scheduling.task_model import Task +from data_designer.engine.dataset_builders.scheduling.task_policies import ( + BoundedBorrowTaskAdmissionPolicy, + BoundedBorrowTaskAdmissionPolicyConfig, + PolicyStateDelta, + StrictFairTaskAdmissionPolicy, + TaskAdmissionDenyReason, + TaskAdmissionPolicy, + TaskAdmissionPolicyDecision, +) + +ReleaseReason = Literal[ + "released", + "duplicate", + "stale_lease", + "wrong_controller_generation", + "unknown_lease", +] +RELEASED_TASK_LEASE_HISTORY_LIMIT = 8192 + + +@dataclass(frozen=True) +class TaskAdmissionConfig: + """Engine-internal scheduler task-stage admission configuration.""" + + submission_capacity: int = 256 + resource_limits: Mapping[SchedulerResourceKey, int] = field(default_factory=dict) + bounded_borrow: BoundedBorrowTaskAdmissionPolicyConfig | None = None + + def __post_init__(self) -> None: + if self.submission_capacity <= 0: + raise ValueError("submission_capacity must be positive.") + merged = {"submission": self.submission_capacity, **self.resource_limits} + for resource, limit in merged.items(): + if limit <= 0: + raise ValueError(f"Task admission limit for {resource!r} must be positive.") + object.__setattr__(self, "resource_limits", merged) + + +@dataclass(frozen=True) +class TaskAdmissionView: + resource_limits: Mapping[SchedulerResourceKey, int] + resources_available: Mapping[SchedulerResourceKey, int] + leased_resources: Mapping[SchedulerResourceKey, int] + leased_resources_by_group: Mapping[TaskGroupKey, Mapping[SchedulerResourceKey, int]] + running_counts_by_group: Mapping[TaskGroupKey, int] + policy_debt_by_group_resource: Mapping[tuple[TaskGroupKey, SchedulerResourceKey], int] + + +@dataclass(frozen=True) +class TaskAdmissionLease: + lease_id: str + item: SchedulableTask + resources: Mapping[SchedulerResourceKey, int] + acquired_at: float + controller_generation: str + + +@dataclass(frozen=True) +class TaskAdmissionDenied: + item: SchedulableTask + reason: TaskAdmissionDenyReason + available_after: float | None = None + snapshot: TaskAdmissionView | None = None + diagnostics: Mapping[str, object] = field(default_factory=dict) + + +TaskAdmissionDecision = TaskAdmissionLease | TaskAdmissionDenied + + +@dataclass(frozen=True) +class ReleaseResult: + released: bool + reason: ReleaseReason + diagnostics: Mapping[str, object] = field(default_factory=dict) + + +@dataclass(frozen=True) +class TaskAdmissionBlockSummary: + queued_count: int + dominant_denial_reasons: Mapping[TaskAdmissionDenyReason, int] + available_after: float | None = None + diagnostics: Mapping[str, object] = field(default_factory=dict) + + +class TaskAdmissionController: + """Owns scheduler-level task leases and resource accounting.""" + + def __init__( + self, + config: TaskAdmissionConfig | None = None, + policy: TaskAdmissionPolicy | None = None, + ) -> None: + self._config = config or TaskAdmissionConfig() + self._generation = uuid.uuid4().hex + self._leases: dict[str, TaskAdmissionLease] = {} + self._released: set[str] = set() + self._released_order: deque[str] = deque(maxlen=RELEASED_TASK_LEASE_HISTORY_LIMIT) + self._leased_by_resource: Counter[SchedulerResourceKey] = Counter() + self._leased_by_group: dict[TaskGroupKey, Counter[SchedulerResourceKey]] = defaultdict(Counter) + self._running_by_group: Counter[TaskGroupKey] = Counter() + self._policy_debt: Counter[tuple[TaskGroupKey, SchedulerResourceKey]] = Counter() + self._release_diagnostics: Counter[str] = Counter() + if policy is not None: + self._policy = policy + elif self._config.bounded_borrow is not None: + self._policy = BoundedBorrowTaskAdmissionPolicy(self._config.bounded_borrow) + else: + self._policy = StrictFairTaskAdmissionPolicy() + + def is_eligible(self, item: SchedulableTask, queue_view: QueueView) -> bool: + return not isinstance(self.try_evaluate(item, queue_view), TaskAdmissionDenied) + + def try_evaluate( + self, item: SchedulableTask, queue_view: QueueView + ) -> TaskAdmissionPolicyDecision | TaskAdmissionDenied: + view = self.view() + missing = self._missing_resources(item, view) + if missing: + return TaskAdmissionDenied( + item=item, + reason="no_capacity", + snapshot=view, + diagnostics={"missing_resources": missing}, + ) + decision = self._policy.evaluate(item, queue_view, view) + if not decision.allowed: + return TaskAdmissionDenied( + item=item, + reason=decision.reason or "policy_denial", + available_after=decision.available_after, + snapshot=view, + diagnostics=decision.diagnostics, + ) + return decision + + def try_acquire(self, item: SchedulableTask, queue_view: QueueView) -> TaskAdmissionDecision: + evaluated = self.try_evaluate(item, queue_view) + if isinstance(evaluated, TaskAdmissionDenied): + return evaluated + lease = TaskAdmissionLease( + lease_id=uuid.uuid4().hex, + item=item, + resources=dict(item.resource_request.amounts), + acquired_at=time.monotonic(), + controller_generation=self._generation, + ) + for resource, amount in lease.resources.items(): + self._leased_by_resource[resource] += amount + self._leased_by_group[item.group.key][resource] += amount + self._running_by_group[item.group.key] += 1 + self._apply_delta(self._policy.on_acquire(lease, evaluated)) + self._leases[lease.lease_id] = lease + return lease + + def release(self, lease: TaskAdmissionLease) -> ReleaseResult: + if lease.controller_generation != self._generation: + self._release_diagnostics["wrong_controller_generation"] += 1 + return ReleaseResult(released=False, reason="wrong_controller_generation") + active = self._leases.pop(lease.lease_id, None) + if active is None: + reason: ReleaseReason = "duplicate" if lease.lease_id in self._released else "unknown_lease" + self._release_diagnostics[reason] += 1 + return ReleaseResult(released=False, reason=reason) + if active.item.task_id != lease.item.task_id: + self._leases[lease.lease_id] = active + self._release_diagnostics["stale_lease"] += 1 + return ReleaseResult(released=False, reason="stale_lease") + + self._remember_released(lease.lease_id) + for resource, amount in active.resources.items(): + self._leased_by_resource[resource] = max(0, self._leased_by_resource[resource] - amount) + self._leased_by_group[active.item.group.key][resource] = max( + 0, + self._leased_by_group[active.item.group.key][resource] - amount, + ) + self._running_by_group[active.item.group.key] = max(0, self._running_by_group[active.item.group.key] - 1) + self._apply_delta(self._policy.on_release(active)) + return ReleaseResult(released=True, reason="released") + + def view(self) -> TaskAdmissionView: + limits = dict(self._config.resource_limits) + leased = {resource: count for resource, count in self._leased_by_resource.items() if count > 0} + available = { + resource: max(0, limit - self._leased_by_resource.get(resource, 0)) for resource, limit in limits.items() + } + return TaskAdmissionView( + resource_limits=limits, + resources_available=available, + leased_resources=leased, + leased_resources_by_group={ + group: {resource: count for resource, count in counts.items() if count > 0} + for group, counts in self._leased_by_group.items() + }, + running_counts_by_group={group: count for group, count in self._running_by_group.items() if count > 0}, + policy_debt_by_group_resource={key: count for key, count in self._policy_debt.items() if count > 0}, + ) + + def explain_blocked(self, queue_view: QueueView) -> TaskAdmissionBlockSummary: + reasons: Counter[TaskAdmissionDenyReason] = Counter() + available_after_values: list[float] = [] + view = self.view() + for group_key, resources in queue_view.first_candidate_resources_by_group.items(): + for resource, amount in resources.items(): + if view.resources_available.get(resource, 0) < amount: + reasons["no_capacity"] += 1 + break + else: + group = queue_view.first_candidate_group_specs_by_group.get(group_key) + if group is None: + continue + task = SchedulableTask( + task_id=f"blocked-{group_key.kind}-{'-'.join(group_key.identity)}", + payload=Task(column="", row_group=-1, row_index=None, task_type="batch"), + group=group, + resource_request=SchedulerResourceRequest(dict(resources)), + ) + decision = self._policy.evaluate(task, queue_view, view) + if not decision.allowed: + reasons[decision.reason or "policy_denial"] += 1 + if decision.available_after is not None: + available_after_values.append(decision.available_after) + return TaskAdmissionBlockSummary( + queued_count=queue_view.queued_total, + dominant_denial_reasons=dict(reasons), + available_after=min(available_after_values) if available_after_values else None, + diagnostics={"snapshot": self.view()}, + ) + + def _missing_resources( + self, + item: SchedulableTask, + view: TaskAdmissionView, + ) -> dict[SchedulerResourceKey, dict[str, int]]: + missing: dict[SchedulerResourceKey, dict[str, int]] = {} + for resource, amount in item.resource_request.amounts.items(): + available = view.resources_available.get(resource, 0) + if available < amount: + missing[resource] = {"requested": amount, "available": available} + return missing + + def _apply_delta(self, delta: PolicyStateDelta) -> None: + for key, change in delta.debt_changes.items(): + self._policy_debt[key] = max(0, self._policy_debt[key] + change) + + def _remember_released(self, lease_id: str) -> None: + if lease_id in self._released: + return + maxlen = self._released_order.maxlen + if maxlen is not None and len(self._released_order) >= maxlen: + self._released.discard(self._released_order[0]) + self._released.add(lease_id) + self._released_order.append(lease_id) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_model.py similarity index 100% rename from packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py rename to packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_model.py diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_policies.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_policies.py new file mode 100644 index 000000000..72d8b46f5 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_policies.py @@ -0,0 +1,226 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Literal, Protocol + +from data_designer.engine.dataset_builders.scheduling.queue import QueueView +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceKey, + TaskGroupKey, +) + +if TYPE_CHECKING: + from data_designer.engine.dataset_builders.scheduling.task_admission import ( + TaskAdmissionLease, + TaskAdmissionView, + ) + +TaskAdmissionDenyReason = Literal[ + "no_capacity", + "group_cap", + "borrow_debt", + "shutdown", + "policy_denial", +] + + +@dataclass(frozen=True) +class BoundedBorrowTaskAdmissionPolicyConfig: + """Engine-internal bounded-borrow policy configuration.""" + + borrow_ceiling_by_group_resource: Mapping[tuple[TaskGroupKey, SchedulerResourceKey], int] = field( + default_factory=dict + ) + default_borrow_ceiling: int = 0 + strict_share_rounding: Literal["floor", "ceil"] = "floor" + repay_on_withheld_peer_pressure: bool = True + + +@dataclass(frozen=True) +class TaskAdmissionPolicyDecision: + allowed: bool + reason: TaskAdmissionDenyReason | None = None + available_after: float | None = None + diagnostics: Mapping[str, object] = field(default_factory=dict) + + +@dataclass(frozen=True) +class PolicyStateDelta: + debt_changes: Mapping[tuple[TaskGroupKey, SchedulerResourceKey], int] = field(default_factory=dict) + diagnostic_counters: Mapping[str, int] = field(default_factory=dict) + + +class TaskAdmissionPolicy(Protocol): + def evaluate( + self, + item: SchedulableTask, + queue_view: QueueView, + admission_view: TaskAdmissionView, + ) -> TaskAdmissionPolicyDecision: ... + + def on_acquire( + self, + lease: TaskAdmissionLease, + decision: TaskAdmissionPolicyDecision, + ) -> PolicyStateDelta: ... + + def on_release(self, lease: TaskAdmissionLease) -> PolicyStateDelta: ... + + +class StrictFairTaskAdmissionPolicy: + """Behavior-preserving policy that enforces per-group admitted caps.""" + + def evaluate( + self, + item: SchedulableTask, + queue_view: QueueView, + admission_view: TaskAdmissionView, + ) -> TaskAdmissionPolicyDecision: + if item.group.admitted_limit is None: + return TaskAdmissionPolicyDecision(allowed=True) + leased_count = admission_view.running_counts_by_group.get(item.group.key, 0) + if leased_count < item.group.admitted_limit: + return TaskAdmissionPolicyDecision(allowed=True) + pressure_resources = _queued_peer_pressure_resources(item, queue_view, admission_view) + if not pressure_resources: + return TaskAdmissionPolicyDecision(allowed=True) + return TaskAdmissionPolicyDecision( + allowed=False, + reason="group_cap", + diagnostics={ + "admitted_limit": item.group.admitted_limit, + "leased_count": leased_count, + "pressure_resources": pressure_resources, + }, + ) + + def on_acquire( + self, + lease: TaskAdmissionLease, + decision: TaskAdmissionPolicyDecision, + ) -> PolicyStateDelta: + return PolicyStateDelta() + + def on_release(self, lease: TaskAdmissionLease) -> PolicyStateDelta: + return PolicyStateDelta() + + +class BoundedBorrowTaskAdmissionPolicy(StrictFairTaskAdmissionPolicy): + """Strict policy with optional bounded borrow debt over peer pressure.""" + + def __init__(self, config: BoundedBorrowTaskAdmissionPolicyConfig) -> None: + self._config = config + + def evaluate( + self, + item: SchedulableTask, + queue_view: QueueView, + admission_view: TaskAdmissionView, + ) -> TaskAdmissionPolicyDecision: + limit = item.group.admitted_limit + if limit is None: + return TaskAdmissionPolicyDecision(allowed=True) + + leased_count = admission_view.running_counts_by_group.get(item.group.key, 0) + if leased_count < limit: + return TaskAdmissionPolicyDecision(allowed=True) + + pressure_resources = _queued_peer_pressure_resources(item, queue_view, admission_view) + if pressure_resources: + for resource in pressure_resources: + debt_key = (item.group.key, resource) + debt = admission_view.policy_debt_by_group_resource.get(debt_key, 0) + if debt > 0: + return TaskAdmissionPolicyDecision( + allowed=False, + reason="borrow_debt", + diagnostics={"resource": resource, "debt": debt}, + ) + return TaskAdmissionPolicyDecision( + allowed=False, + reason="group_cap", + diagnostics={ + "admitted_limit": limit, + "leased_count": leased_count, + "pressure_resources": pressure_resources, + }, + ) + + borrow_resources: list[tuple[SchedulerResourceKey, int]] = [] + for resource, amount in item.resource_request.amounts.items(): + debt_key = (item.group.key, resource) + debt = admission_view.policy_debt_by_group_resource.get(debt_key, 0) + ceiling = self._config.borrow_ceiling_by_group_resource.get( + debt_key, + self._config.default_borrow_ceiling, + ) + if debt + amount > ceiling: + return TaskAdmissionPolicyDecision( + allowed=False, + reason="borrow_debt", + diagnostics={"resource": resource, "debt": debt, "requested": amount, "ceiling": ceiling}, + ) + borrow_resources.append((resource, amount)) + return TaskAdmissionPolicyDecision(allowed=True, diagnostics={"borrow_resources": tuple(borrow_resources)}) + + def on_acquire( + self, + lease: TaskAdmissionLease, + decision: TaskAdmissionPolicyDecision, + ) -> PolicyStateDelta: + borrow_resources = decision.diagnostics.get("borrow_resources") + if borrow_resources: + changes = { + (lease.item.group.key, resource): amount + for resource, amount in borrow_resources + if isinstance(resource, str) and isinstance(amount, int) + } + return PolicyStateDelta(debt_changes=changes) + return PolicyStateDelta() + + def on_release(self, lease: TaskAdmissionLease) -> PolicyStateDelta: + if not self._config.repay_on_withheld_peer_pressure: + return PolicyStateDelta() + # Borrow debt is group-level: any completed lease in the group repays it, clamped to zero by the controller. + return PolicyStateDelta( + debt_changes={(lease.item.group.key, resource): -amount for resource, amount in lease.resources.items()} + ) + + +def _queued_peer_pressure_resources( + item: SchedulableTask, + queue_view: QueueView, + admission_view: TaskAdmissionView, +) -> tuple[SchedulerResourceKey, ...]: + candidate_resources = _fair_pressure_resources(item.resource_request.amounts) + pressure_resources: list[SchedulerResourceKey] = [] + for group_key, peer_resources in queue_view.first_candidate_resources_by_group.items(): + if group_key == item.group.key: + continue + if not _is_hard_resource_eligible(peer_resources, admission_view): + continue + for resource in candidate_resources: + if peer_resources.get(resource, 0) > 0 and resource not in pressure_resources: + pressure_resources.append(resource) + return tuple(pressure_resources) + + +def _fair_pressure_resources( + resources: Mapping[SchedulerResourceKey, int], +) -> tuple[SchedulerResourceKey, ...]: + typed_resources = tuple(resource for resource in resources if resource != "submission") + if typed_resources: + return typed_resources + return tuple(resources) + + +def _is_hard_resource_eligible( + resources: Mapping[SchedulerResourceKey, int], + admission_view: TaskAdmissionView, +) -> bool: + return all(admission_view.resources_available.get(resource, 0) >= amount for resource, amount in resources.items()) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index b090cf63d..29b7d99bc 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -15,8 +15,8 @@ DatasetBuilderColumnConfigT, MultiColumnConfig, ) +from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError -from data_designer.engine.dataset_builders.utils.task_model import SliceRef from data_designer.logging import LOG_INDENT logger = logging.getLogger(__name__) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/fair_task_queue.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/fair_task_queue.py deleted file mode 100644 index 32301b767..000000000 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/fair_task_queue.py +++ /dev/null @@ -1,156 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import heapq -from collections import deque -from collections.abc import Callable -from dataclasses import dataclass -from typing import Literal - -from data_designer.engine.dataset_builders.utils.task_model import Task - - -@dataclass(frozen=True, order=True) -class TaskGroupKey: - """Stable identity for a stream of related scheduler tasks.""" - - kind: Literal["model", "custom_model", "local"] - identity: tuple[str, ...] - - -@dataclass(frozen=True) -class TaskGroupSpec: - """Scheduling metadata for a task group.""" - - key: TaskGroupKey - weight: float = 1.0 - admitted_limit: int | None = None - - -@dataclass(frozen=True) -class TaskSelection: - """A task selected for dispatch with the group metadata used to choose it.""" - - task: Task - group: TaskGroupSpec - - -class FairTaskQueue: - """Virtual-time fair queue with peer-sensitive per-group FIFO admission limits.""" - - def __init__(self) -> None: - self._queues: dict[TaskGroupKey, deque[Task]] = {} - self._queued: set[Task] = set() - self._task_groups: dict[Task, TaskGroupKey] = {} - self._group_specs: dict[TaskGroupKey, TaskGroupSpec] = {} - self._group_finish: dict[TaskGroupKey, float] = {} - self._admitted_by_group: dict[TaskGroupKey, int] = {} - self._admitted_task_groups: dict[Task, TaskGroupKey] = {} - self._heap: list[tuple[float, int, TaskGroupKey]] = [] - self._active_heap_keys: set[TaskGroupKey] = set() - self._sequence = 0 - self._virtual_time = 0.0 - - @property - def has_queued_tasks(self) -> bool: - return bool(self._queued) - - def enqueue(self, task: Task, group: TaskGroupSpec) -> None: - """Add one ready task to its fair scheduling group.""" - self._group_specs[group.key] = group - if task in self._queued: - return - queue = self._queues.setdefault(group.key, deque()) - queue.append(task) - self._queued.add(task) - self._task_groups[task] = group.key - self._activate_group(group.key) - - def discard(self, task: Task) -> None: - """Remove a queued task lazily if it is no longer dispatchable.""" - self._queued.discard(task) - self._task_groups.pop(task, None) - - def discard_where(self, predicate: Callable[[Task], bool]) -> None: - """Remove queued tasks matching a predicate.""" - for task in tuple(self._queued): - if predicate(task): - self.discard(task) - - def admit_next(self) -> TaskSelection | None: - """Admit the next eligible task, or ``None`` if no queued group can run.""" - blocked: list[TaskGroupKey] = [] - try: - while self._heap: - finish, _, key = heapq.heappop(self._heap) - self._active_heap_keys.discard(key) - self._purge_queue_head(key) - queue = self._queues.get(key) - if not queue: - continue - if not self._can_admit_group(key): - blocked.append(key) - continue - - task = queue.popleft() - self._queued.discard(task) - self._task_groups.pop(task, None) - self._admitted_task_groups[task] = key - self._admitted_by_group[key] = self._admitted_by_group.get(key, 0) + 1 - - group = self._group_specs[key] - self._virtual_time = max(self._virtual_time, finish) - self._group_finish[key] = self._virtual_time + (1.0 / max(group.weight, 1.0)) - self._purge_queue_head(key) - if queue: - self._activate_group(key) - return TaskSelection(task=task, group=group) - return None - finally: - for key in blocked: - self._activate_group(key) - - def release(self, task: Task) -> None: - """Release one previously admitted task from its group limit.""" - key = self._admitted_task_groups.pop(task, None) - if key is None: - return - admitted = self._admitted_by_group.get(key, 0) - if admitted <= 1: - self._admitted_by_group.pop(key, None) - else: - self._admitted_by_group[key] = admitted - 1 - self._activate_group(key) - - def _activate_group(self, key: TaskGroupKey) -> None: - self._purge_queue_head(key) - queue = self._queues.get(key) - if not queue or key in self._active_heap_keys: - return - self._sequence += 1 - finish = self._group_finish.get(key, self._virtual_time) - heapq.heappush(self._heap, (finish, self._sequence, key)) - self._active_heap_keys.add(key) - - def _purge_queue_head(self, key: TaskGroupKey) -> None: - queue = self._queues.get(key) - if queue is None: - return - while queue: - task = queue[0] - if task in self._queued and self._task_groups.get(task) == key: - break - queue.popleft() - - def _can_admit_group(self, key: TaskGroupKey) -> bool: - group = self._group_specs[key] - if group.admitted_limit is None: - return True - if self._admitted_by_group.get(key, 0) < group.admitted_limit: - return True - return not self._has_queued_peer_group(key) - - def _has_queued_peer_group(self, key: TaskGroupKey) -> bool: - return any(queued_key != key for queued_key in self._task_groups.values()) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/scheduling_hints.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/scheduling_hints.py deleted file mode 100644 index dea66eeda..000000000 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/scheduling_hints.py +++ /dev/null @@ -1,123 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import logging -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal - -if TYPE_CHECKING: - from data_designer.engine.column_generators.generators.base import ColumnGenerator - -logger = logging.getLogger(__name__) - -SchedulingGroupKind = Literal["local", "model", "custom_model"] - - -@dataclass(frozen=True) -class SchedulingHint: - """Resolved task-scheduling metadata independent of graph flow identity.""" - - group_kind: SchedulingGroupKind - identity_prefix: tuple[str, ...] = () - identity_suffix: tuple[str, ...] = () - weight: int = 1 - - -class SchedulingHintResolver: - """Resolve generator/config/model metadata once for a scheduler run.""" - - def __init__(self, generators: dict[str, ColumnGenerator]) -> None: - self._hints_by_generator_id: dict[int, SchedulingHint] = {} - for column, generator in generators.items(): - generator_id = id(generator) - if generator_id not in self._hints_by_generator_id: - self._hints_by_generator_id[generator_id] = self._resolve_hint(column, generator) - - def hint_for(self, generator: ColumnGenerator) -> SchedulingHint: - return self._hints_by_generator_id[id(generator)] - - def _resolve_hint(self, column: str, generator: ColumnGenerator) -> SchedulingHint: - if not generator.is_llm_bound: - return SchedulingHint(group_kind="local") - - aliases = _model_aliases_for_generator(generator) - if not aliases: - return SchedulingHint(group_kind="model", identity_prefix=("unknown",), weight=1) - - model_parts: list[str] = [] - total_parallel = 0 - primary_alias = getattr(generator.config, "model_alias", None) - for alias in aliases: - try: - model_config = _get_model_config_for_alias(generator, alias) - provider_name = _get_model_provider_name_for_alias(generator, alias) - except Exception: - logger.debug( - "Falling back to custom-model scheduling group for column %r after failing to resolve " - "model alias %r from aliases %r.", - column, - alias, - aliases, - exc_info=True, - ) - return SchedulingHint( - group_kind="custom_model", - identity_suffix=tuple(sorted(aliases)), - weight=max(1, total_parallel), - ) - - max_parallel = getattr(model_config.inference_parameters, "max_parallel_requests", 1) - if not isinstance(max_parallel, int): - max_parallel = 1 - model_parts.extend( - ( - provider_name, - str(model_config.model), - str(model_config.generation_type), - alias, - ) - ) - total_parallel += max_parallel - - weight = max(1, total_parallel) - if len(aliases) == 1 and primary_alias == aliases[0]: - return SchedulingHint( - group_kind="model", - identity_prefix=tuple(model_parts[:3]), - weight=weight, - ) - - return SchedulingHint( - group_kind="custom_model", - identity_suffix=tuple(sorted(aliases)), - weight=weight, - ) - - -def _get_model_config_for_alias(generator: ColumnGenerator, alias: str) -> Any: - get_model_config = getattr(generator, "get_model_config", None) - if callable(get_model_config): - return get_model_config(model_alias=alias) - return generator.resource_provider.model_registry.get_model_config(model_alias=alias) - - -def _get_model_provider_name_for_alias(generator: ColumnGenerator, alias: str) -> str: - get_provider_name = getattr(generator, "get_model_provider_name", None) - if callable(get_provider_name): - return str(get_provider_name(model_alias=alias)) - provider = generator.resource_provider.model_registry.get_model_provider(model_alias=alias) - return str(provider.name) - - -def _model_aliases_for_generator(generator: ColumnGenerator) -> list[str]: - get_aliases = getattr(generator.config, "get_model_aliases", None) - if callable(get_aliases): - aliases = get_aliases() - else: - aliases = [] - if (alias := getattr(generator.config, "model_alias", None)) is not None: - aliases.append(alias) - aliases.extend(getattr(generator.config, "model_aliases", []) or []) - return list(dict.fromkeys(alias for alias in aliases if alias)) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py index df9afc48f..afdd7f616 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py @@ -12,9 +12,8 @@ map_http_status_to_provider_error_kind, ) from data_designer.engine.models.clients.factory import create_model_client +from data_designer.engine.models.clients.model_request_executor import ModelRequestExecutor from data_designer.engine.models.clients.retry import RetryConfig -from data_designer.engine.models.clients.throttle_manager import ThrottleDomain, ThrottleManager -from data_designer.engine.models.clients.throttled import ThrottledModelClient from data_designer.engine.models.clients.types import ( AssistantMessage, ChatCompletionRequest, @@ -40,13 +39,11 @@ "ImageGenerationResponse", "ImagePayload", "ModelClient", + "ModelRequestExecutor", "OpenAICompatibleClient", "ProviderError", "ProviderErrorKind", "RetryConfig", - "ThrottleDomain", - "ThrottleManager", - "ThrottledModelClient", "ToolCall", "Usage", "create_model_client", diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py index 2424d3f8c..b73acffb6 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py @@ -31,8 +31,8 @@ class AnthropicClient(HttpModelClient): """Native HTTP adapter for the Anthropic Messages API. Uses ``httpx`` with ``httpx_retries.RetryTransport`` for resilient HTTP - calls. Concurrency / throttle policy is an orchestration concern and - is not managed here — see ``ThrottleManager`` and ``AsyncTaskScheduler``. + calls. Concurrency and request-admission policy are orchestration concerns + and are not managed here. """ _ROUTE_MESSAGES = "/messages" diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py index 54f01961b..44ab1f1d5 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py @@ -33,8 +33,8 @@ class OpenAICompatibleClient(HttpModelClient): """Native HTTP adapter for OpenAI-compatible provider APIs. Uses ``httpx`` with ``httpx_retries.RetryTransport`` for resilient HTTP - calls. Concurrency / throttle policy is an orchestration concern and - is not managed here — see ``ThrottleManager`` and ``AsyncTaskScheduler``. + calls. Concurrency and request-admission policy are orchestration concerns + and are not managed here. """ _ROUTE_CHAT = "/chat/completions" diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py index 458ebfcad..398d151a4 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py @@ -10,13 +10,15 @@ from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient from data_designer.engine.models.clients.base import ModelClient +from data_designer.engine.models.clients.model_request_executor import ModelRequestExecutor from data_designer.engine.models.clients.retry import RetryConfig -from data_designer.engine.models.clients.throttle_manager import ThrottleManager -from data_designer.engine.models.clients.throttled import ThrottledModelClient from data_designer.engine.models.errors import FormattedLLMErrorMessage +from data_designer.engine.models.request_admission.controller import RequestAdmissionController +from data_designer.engine.observability import RequestAdmissionEventSink from data_designer.engine.secret_resolver import SecretResolver _SUPPORTED_PROVIDER_TYPES = ("openai", "anthropic") +_NO_TRANSPORT_RETRY_CONFIG = RetryConfig(max_retries=0, retryable_status_codes=frozenset()) def create_model_client( @@ -26,7 +28,8 @@ def create_model_client( *, retry_config: RetryConfig | None = None, client_concurrency_mode: ClientConcurrencyMode = ClientConcurrencyMode.SYNC, - throttle_manager: ThrottleManager | None = None, + request_admission: RequestAdmissionController | None = None, + request_event_sink: RequestAdmissionEventSink | None = None, ) -> ModelClient: """Create a ``ModelClient`` for the given model configuration. @@ -40,12 +43,12 @@ def create_model_client( client_concurrency_mode: ``"sync"`` (default) for the sync engine path, ``"async"`` for the async engine path. Native HTTP adapters are constrained to a single concurrency mode. - throttle_manager: Optional throttle manager for per-request AIMD - concurrency control. When provided, the returned client is wrapped - with ``ThrottledModelClient``. + request_admission: Optional request-admission controller for per-request + provider/model/domain admission. When provided, the returned client + is wrapped with ``ModelRequestExecutor``. **Ordering invariant:** the ``(provider_name, model_id)`` pair must - be registered on the ``ThrottleManager`` via ``register()`` before + be registered on the request-admission controller via ``register()`` before the returned client makes its first request. In the standard flow, ``ModelRegistry._get_model()`` calls ``register()`` during model setup, which happens before any generation task invokes the client. @@ -69,13 +72,14 @@ def create_model_client( max_parallel = model_config.inference_parameters.max_parallel_requests raw_timeout = model_config.inference_parameters.timeout timeout_s = float(raw_timeout if raw_timeout is not None else 60) + adapter_retry_config = _NO_TRANSPORT_RETRY_CONFIG if request_admission is not None else retry_config if provider.provider_type == "openai": client: ModelClient = OpenAICompatibleClient( provider_name=provider.name, endpoint=provider.endpoint, api_key=api_key, - retry_config=retry_config, + retry_config=adapter_retry_config, max_parallel_requests=max_parallel, timeout_s=timeout_s, concurrency_mode=client_concurrency_mode, @@ -85,7 +89,7 @@ def create_model_client( provider_name=provider.name, endpoint=provider.endpoint, api_key=api_key, - retry_config=retry_config, + retry_config=adapter_retry_config, max_parallel_requests=max_parallel, timeout_s=timeout_s, concurrency_mode=client_concurrency_mode, @@ -102,12 +106,14 @@ def create_model_client( ) ) - if throttle_manager is not None: - client = ThrottledModelClient( + if request_admission is not None: + client = ModelRequestExecutor( inner=client, - throttle_manager=throttle_manager, + request_admission=request_admission, provider_name=provider.name, model_id=model_config.model, + event_sink=request_event_sink, + retry_config=retry_config, ) return client diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/model_request_executor.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/model_request_executor.py new file mode 100644 index 000000000..6782c16bc --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/model_request_executor.py @@ -0,0 +1,305 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import logging +import time +import uuid +from typing import TYPE_CHECKING, TypeVar + +from data_designer.engine.models.clients.base import ModelClient +from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind +from data_designer.engine.models.clients.retry import RetryConfig +from data_designer.engine.models.clients.types import ( + ChatCompletionRequest, + ChatCompletionResponse, + EmbeddingRequest, + EmbeddingResponse, + ImageGenerationRequest, + ImageGenerationResponse, +) +from data_designer.engine.models.request_admission.controller import ( + RequestAdmissionController, + RequestAdmissionError, + RequestAdmissionLease, +) +from data_designer.engine.models.request_admission.outcomes import RequestReleaseOutcome +from data_designer.engine.models.request_admission.resolver import RequestResourceResolver +from data_designer.engine.models.request_admission.resources import ( + RequestAdmissionItem, + RequestDomain, + RequestEventContext, + RequestGroupSpec, +) +from data_designer.engine.observability import ( + RequestAdmissionEvent, + RequestAdmissionEventSink, + runtime_correlation_provider, +) + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + +_T = TypeVar("_T") + +logger = logging.getLogger(__name__) + + +class ModelRequestExecutor(ModelClient): + """Model-call boundary that acquires/releases request-admission leases.""" + + def __init__( + self, + inner: ModelClient, + request_admission: RequestAdmissionController, + provider_name: str, + model_id: str, + event_sink: RequestAdmissionEventSink | None = None, + resource_resolver: RequestResourceResolver | None = None, + retry_config: RetryConfig | None = None, + ) -> None: + self._inner = inner + self._request_admission = request_admission + self._provider_name = provider_name + self._model_id = model_id + self._event_sink = event_sink + self._resource_resolver = resource_resolver or RequestResourceResolver() + self._retry_config = retry_config or RetryConfig() + self._event_sequence = 0 + + @property + def provider_name(self) -> str: + return self._inner.provider_name + + def supports_chat_completion(self) -> bool: + return self._inner.supports_chat_completion() + + def supports_embeddings(self) -> bool: + return self._inner.supports_embeddings() + + def supports_image_generation(self) -> bool: + return self._inner.supports_image_generation() + + def close(self) -> None: + self._inner.close() + + async def aclose(self) -> None: + await self._inner.aclose() + + def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + return self._execute_sync(RequestDomain.CHAT, lambda: self._inner.completion(request)) + + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + return await self._execute_async(RequestDomain.CHAT, lambda: self._inner.acompletion(request)) + + def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + return self._execute_sync(RequestDomain.EMBEDDING, lambda: self._inner.embeddings(request)) + + async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + return await self._execute_async(RequestDomain.EMBEDDING, lambda: self._inner.aembeddings(request)) + + def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + return self._execute_sync(self._image_domain(request), lambda: self._inner.generate_image(request)) + + async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + return await self._execute_async(self._image_domain(request), lambda: self._inner.agenerate_image(request)) + + def _execute_sync(self, domain: RequestDomain, call: Callable[[], _T]) -> _T: + for attempt in range(self._max_attempts()): + try: + return self._execute_sync_attempt(domain, call) + except ProviderError as exc: + if not self._should_retry(exc, attempt): + raise + self._sleep_before_retry(attempt) + raise RuntimeError("unreachable request retry state") + + def _execute_sync_attempt(self, domain: RequestDomain, call: Callable[[], _T]) -> _T: + item = self._item(domain) + try: + lease = self._request_admission.acquire_sync(item) + except RequestAdmissionError as exc: + raise ProviderError( + kind=ProviderErrorKind.TIMEOUT, + message=str(exc), + provider_name=self._provider_name, + model_name=self._model_id, + ) from exc + try: + self._emit_model_event("model_request_started", item=item, lease=lease) + result = call() + except ProviderError as exc: + self._release_provider_error(lease, exc) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": exc.kind.value} + ) + raise + except TimeoutError: + self._request_admission.release(lease, RequestReleaseOutcome(kind="provider_timeout")) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": "provider_timeout"} + ) + raise + except BaseException as exc: + outcome = "local_cancelled" if isinstance(exc, KeyboardInterrupt) else "unexpected_exception" + self._request_admission.release(lease, RequestReleaseOutcome(kind=outcome)) + self._emit_model_event("model_request_completed", item=item, lease=lease, diagnostics={"outcome": outcome}) + raise + else: + self._request_admission.release(lease, RequestReleaseOutcome(kind="success")) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": "success"} + ) + return result + + async def _execute_async(self, domain: RequestDomain, call: Callable[[], Awaitable[_T]]) -> _T: + for attempt in range(self._max_attempts()): + try: + return await self._execute_async_attempt(domain, call) + except ProviderError as exc: + if not self._should_retry(exc, attempt): + raise + await self._async_sleep_before_retry(attempt) + raise RuntimeError("unreachable request retry state") + + async def _execute_async_attempt(self, domain: RequestDomain, call: Callable[[], Awaitable[_T]]) -> _T: + item = self._item(domain) + try: + lease = await self._request_admission.acquire_async(item) + except RequestAdmissionError as exc: + raise ProviderError( + kind=ProviderErrorKind.TIMEOUT, + message=str(exc), + provider_name=self._provider_name, + model_name=self._model_id, + ) from exc + except asyncio.CancelledError: + raise + try: + self._emit_model_event("model_request_started", item=item, lease=lease) + result = await call() + except asyncio.CancelledError: + self._request_admission.release(lease, RequestReleaseOutcome(kind="local_cancelled")) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": "local_cancelled"} + ) + raise + except ProviderError as exc: + self._release_provider_error(lease, exc) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": exc.kind.value} + ) + raise + except TimeoutError: + self._request_admission.release(lease, RequestReleaseOutcome(kind="provider_timeout")) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": "provider_timeout"} + ) + raise + except BaseException: + self._request_admission.release(lease, RequestReleaseOutcome(kind="unexpected_exception")) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": "unexpected_exception"} + ) + raise + else: + self._request_admission.release(lease, RequestReleaseOutcome(kind="success")) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": "success"} + ) + return result + + def _max_attempts(self) -> int: + return max(1, self._retry_config.max_retries + 1) + + def _should_retry(self, exc: ProviderError, attempt: int) -> bool: + if attempt >= self._max_attempts() - 1: + return False + if isinstance(exc.__cause__, RequestAdmissionError): + return False + if exc.kind == ProviderErrorKind.RATE_LIMIT: + return False + if exc.status_code is not None: + return exc.status_code in self._retry_config.retryable_status_codes + return exc.kind == ProviderErrorKind.API_CONNECTION + + def _sleep_before_retry(self, attempt: int) -> None: + delay = self._retry_delay_seconds(attempt) + if delay > 0.0: + time.sleep(delay) + + async def _async_sleep_before_retry(self, attempt: int) -> None: + delay = self._retry_delay_seconds(attempt) + if delay > 0.0: + await asyncio.sleep(delay) + + def _retry_delay_seconds(self, attempt: int) -> float: + if self._retry_config.backoff_factor <= 0.0: + return 0.0 + delay = self._retry_config.backoff_factor * (2**attempt) + return min(delay, self._retry_config.max_backoff_wait) + + def _release_provider_error(self, lease: RequestAdmissionLease, exc: ProviderError) -> None: + if exc.kind == ProviderErrorKind.RATE_LIMIT: + outcome = RequestReleaseOutcome(kind="rate_limited", retry_after_seconds=exc.retry_after) + elif exc.kind == ProviderErrorKind.TIMEOUT: + outcome = RequestReleaseOutcome(kind="provider_timeout") + else: + outcome = RequestReleaseOutcome(kind="provider_failure") + self._request_admission.release(lease, outcome) + + def _item(self, domain: RequestDomain) -> RequestAdmissionItem: + resolved = self._resource_resolver.resolve( + provider_name=self._provider_name, + model_id=self._model_id, + domain=domain, + ) + resource = resolved.resource + correlation = runtime_correlation_provider.current() + return RequestAdmissionItem( + resource=resource, + group=RequestGroupSpec(key=resource), + event_context=RequestEventContext( + captured_correlation=correlation, + task_execution_id=correlation.task_execution_id if correlation is not None else None, + request_attempt_id=f"request-{uuid.uuid4().hex}", + ), + ) + + @staticmethod + def _image_domain(request: ImageGenerationRequest) -> RequestDomain: + return RequestDomain.CHAT if request.messages is not None else RequestDomain.IMAGE + + def _emit_model_event( + self, + event_kind: str, + *, + item: RequestAdmissionItem, + lease: RequestAdmissionLease, + diagnostics: dict[str, object] | None = None, + ) -> None: + if self._event_sink is None: + return + self._event_sequence += 1 + context = item.event_context + try: + self._event_sink.emit_request_event( + RequestAdmissionEvent.capture( + event_kind, # type: ignore[arg-type] + sequence=self._event_sequence, + correlation=context.captured_correlation + if context is not None + else runtime_correlation_provider.current(), + request_attempt_id=context.request_attempt_id if context is not None else None, + request_lease_id=lease.lease_id, + request_resource_key=item.resource, + request_group_key=item.group.key, + pressure_snapshot=self._request_admission.pressure.snapshot(item.resource), + diagnostics=diagnostics or {}, + ) + ) + except Exception: + logger.warning("Model request event sink raised; dropping event.", exc_info=True) + return diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/retry.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/retry.py index 56aa1eec4..9f51a48b2 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/retry.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/retry.py @@ -14,8 +14,8 @@ logger = logging.getLogger(__name__) -# 429 must not be retried at the transport layer so that rate-limit signals -# propagate to ThrottledModelClient for AIMD backoff. +# 429 must not be retried at the transport layer so rate-limit signals +# propagate to ModelRequestExecutor and request admission for AIMD backoff. _RESERVED_STATUS_CODES: frozenset[int] = frozenset({429}) @@ -25,7 +25,7 @@ class RetryConfig: Retries non-rate-limit transient failures (``502``, ``503``, ``504``) and connection/transport errors. ``429`` is intentionally excluded so that - rate-limit signals reach the ``ThrottledModelClient`` wrapper for AIMD + rate-limit signals reach the ``ModelRequestExecutor`` boundary for AIMD backoff. If a caller includes ``429`` in ``retryable_status_codes``, ``create_retry_transport`` will strip it and log a warning. """ @@ -52,10 +52,8 @@ def create_retry_transport( config: Retry policy. Uses ``RetryConfig()`` defaults when ``None``. strip_rate_limit_codes: When ``True`` (default, used by the async engine), status codes in ``_RESERVED_STATUS_CODES`` (currently ``{429}``) are - stripped so that rate-limit responses reach the ``ThrottledModelClient`` - AIMD feedback loop. When ``False`` (used by the sync engine, which has - no salvage queue), 429 is kept in the retry list so the transport layer - retries it transparently. + stripped so that rate-limit responses reach the request-admission + AIMD feedback loop. transport: Optional pre-configured transport to pass directly to ``RetryTransport``. Pass ``httpx.HTTPTransport`` for sync clients or ``httpx.AsyncHTTPTransport`` for async clients — typically with a custom @@ -70,7 +68,7 @@ def create_retry_transport( if reserved_overlap: logger.warning( "Stripping reserved status codes %s from retryable_status_codes; " - "these must reach ThrottledModelClient for AIMD backoff.", + "these must reach ModelRequestExecutor/request admission for AIMD backoff.", sorted(reserved_overlap), ) status_codes = status_codes - _RESERVED_STATUS_CODES diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/throttle_manager.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/throttle_manager.py deleted file mode 100644 index fb57c32ab..000000000 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/throttle_manager.py +++ /dev/null @@ -1,505 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import asyncio -import logging -import math -import threading -import time -from dataclasses import dataclass, field -from enum import Enum - -from data_designer.config.run_config import ThrottleConfig - -logger = logging.getLogger(__name__) - - -class ThrottleDomain(str, Enum): - CHAT = "chat" - EMBEDDING = "embedding" - IMAGE = "image" - HEALTHCHECK = "healthcheck" - - -# --------------------------------------------------------------------------- -# Constants -# --------------------------------------------------------------------------- - -DEFAULT_MIN_LIMIT: int = 1 -CAPACITY_POLL_INTERVAL: float = 0.05 - - -# --------------------------------------------------------------------------- -# Internal state containers -# --------------------------------------------------------------------------- - - -@dataclass -class DomainThrottleState: - """Per-domain AIMD concurrency state. - - All mutations must be performed while holding the owning - ``ThrottleManager._lock``. - """ - - current_limit: int - in_flight: int = 0 - blocked_until: float = 0.0 - success_streak: int = 0 - waiters: int = 0 - rate_limit_ceiling: int = 0 - consecutive_429s: int = 0 - - -@dataclass -class GlobalCapState: - """Tracks the effective hard cap across aliases sharing a provider+model.""" - - limits_by_alias: dict[str, int] = field(default_factory=dict) - effective_max: int = 0 - - def register_alias(self, alias: str, max_parallel: int) -> None: - self.limits_by_alias[alias] = max_parallel - self.effective_max = min(self.limits_by_alias.values()) - - -# --------------------------------------------------------------------------- -# ThrottleManager -# --------------------------------------------------------------------------- - - -class ThrottleManager: - """Adaptive concurrency manager using AIMD (Additive Increase / - Multiplicative Decrease). - - Keyed at two levels: - - - **Global cap**: ``(provider_name, model_id)`` — shared hard ceiling. - - **Domain**: ``(provider_name, model_id, throttle_domain)`` — per-route - AIMD state that floats between 1 and the global effective max. - - **AIMD behaviour**: - - - *Decrease* — on a 429 / rate-limit signal the domain's concurrency limit - is multiplied by ``reduce_factor`` (default 0.75, i.e. reduced by 25%) - and a cooldown block is applied for ``retry_after`` seconds (or - ``default_cooldown_seconds``). - - *Increase* — after every ``success_window`` consecutive successful - releases the limit grows by ``additive_increase`` (default 1), up to - the *rate-limit ceiling* (or the global effective max if no 429 has - been observed yet). - - *Stabilization* — each 429 records the pre-halving limit as - ``rate_limit_ceiling``. Subsequent additive increases stop at - ``ceiling * (1 + ceiling_overshoot)`` (default 10%) instead of - climbing all the way to ``effective_max``. The overshoot band lets - the system probe whether the endpoint can now handle more traffic - (e.g. after load drops) while dampening the sawtooth. If the probe - succeeds, the ceiling ratchets up; if it triggers another 429, the - ceiling lowers. - - Thread-safe: all state mutations are guarded by a single lock so that - sync and async callers co-throttle correctly. - """ - - def __init__( - self, - config: ThrottleConfig | None = None, - ) -> None: - tc = config or ThrottleConfig() - self._reduce_factor = tc.reduce_factor - self._additive_increase = tc.additive_increase - self._success_window = tc.success_window - self._default_cooldown_seconds = tc.cooldown_seconds - self._ceiling_overshoot = tc.ceiling_overshoot - self._lock = threading.Lock() - self._global_caps: dict[tuple[str, str], GlobalCapState] = {} - self._domains: dict[tuple[str, str, str], DomainThrottleState] = {} - - # ------------------------------------------------------------------- - # Registration - # ------------------------------------------------------------------- - - def register( - self, - *, - provider_name: str, - model_id: str, - alias: str, - max_parallel_requests: int, - ) -> None: - """Register a model alias and its concurrency limit. - - If multiple aliases share the same ``(provider_name, model_id)`` the - effective max is ``min()`` of all registered limits. Existing domain - states are clamped to the new effective max. - - **Ordering invariant:** ``register()`` must be called for a - ``(provider_name, model_id)`` pair *before* any ``try_acquire()`` for - the same key. If ``try_acquire()`` runs first it creates a domain at - ``DEFAULT_MIN_LIMIT`` and ``_clamp_domains`` only *decreases* limits, - so a later ``register()`` will not raise the domain to the intended - capacity. - """ - with self._lock: - global_key = (provider_name, model_id) - cap = self._global_caps.setdefault(global_key, GlobalCapState()) - cap.register_alias(alias, max_parallel_requests) - self._clamp_domains(global_key, cap.effective_max) - logger.debug( - "Throttle registered alias=%r for %s/%s (max_parallel=%d, effective_max=%d)", - alias, - provider_name, - model_id, - max_parallel_requests, - cap.effective_max, - ) - - # ------------------------------------------------------------------- - # Core non-blocking primitives - # ------------------------------------------------------------------- - - def is_registered(self, provider_name: str, model_id: str) -> bool: - """Return ``True`` if ``register()`` has been called for this key.""" - with self._lock: - return (provider_name, model_id) in self._global_caps - - def try_acquire( - self, - *, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - now: float | None = None, - ) -> float: - """Attempt to acquire a concurrency slot. - - Returns ``0.0`` if the slot was acquired, otherwise the number of - seconds the caller should wait before retrying. - - Raises ``RuntimeError`` if the ``(provider_name, model_id)`` pair - has not been registered via ``register()``. - """ - now = now if now is not None else time.monotonic() - with self._lock: - if (provider_name, model_id) not in self._global_caps: - raise RuntimeError( - f"ThrottleManager.try_acquire() called before register() " - f"for ({provider_name!r}, {model_id!r}). " - f"Call register() first to set the concurrency limit." - ) - state = self._get_or_create_domain(provider_name, model_id, domain) - if now < state.blocked_until: - return state.blocked_until - now - if state.in_flight >= state.current_limit: - return CAPACITY_POLL_INTERVAL - state.in_flight += 1 - return 0.0 - - def release_success( - self, - *, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - now: float | None = None, - ) -> None: - with self._lock: - state = self._get_or_create_domain(provider_name, model_id, domain) - state.in_flight = max(0, state.in_flight - 1) - state.consecutive_429s = 0 - state.success_streak += 1 - if state.success_streak >= self._success_window: - effective_max = self._effective_max_for(provider_name, model_id) - cap = self._compute_soft_ceiling(state, effective_max) - if state.current_limit < cap: - prev = state.current_limit - state.current_limit = min(state.current_limit + self._additive_increase, cap) - if state.current_limit >= cap: - if cap < effective_max: - logger.info( - "🔋✅ '%s' [%s] concurrency recovered to %d parallel requests", - model_id, - domain.value, - state.current_limit, - ) - else: - logger.info( - "🔋✅ '%s' [%s] concurrency fully recovered (%d parallel requests)", - model_id, - domain.value, - state.current_limit, - ) - else: - logger.info( - "🪫📈🔥 '%s' [%s] concurrency increased from %d → %d", - model_id, - domain.value, - prev, - state.current_limit, - ) - state.success_streak = 0 - - def release_rate_limited( - self, - *, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - retry_after: float | None = None, - now: float | None = None, - ) -> None: - now = now if now is not None else time.monotonic() - with self._lock: - state = self._get_or_create_domain(provider_name, model_id, domain) - state.in_flight = max(0, state.in_flight - 1) - prev_limit = state.current_limit - is_first_in_cascade = state.consecutive_429s == 0 - state.consecutive_429s += 1 - cooldown_duration = ( - retry_after if retry_after is not None and retry_after > 0 else self._default_cooldown_seconds - ) - state.blocked_until = now + cooldown_duration - state.success_streak = 0 - - if is_first_in_cascade: - state.current_limit = max(DEFAULT_MIN_LIMIT, math.floor(state.current_limit * self._reduce_factor)) - if state.current_limit < prev_limit: - if state.rate_limit_ceiling == 0: - state.rate_limit_ceiling = prev_limit - else: - state.rate_limit_ceiling = min(state.rate_limit_ceiling, prev_limit) - if state.rate_limit_ceiling < prev_limit: - logger.info( - "🪫📉 '%s' [%s] server rate-limited at %d (server limit ~%d) — concurrency reduced to %d (retrying in %.0fs)", - model_id, - domain.value, - prev_limit, - state.rate_limit_ceiling, - state.current_limit, - cooldown_duration, - ) - else: - logger.info( - "🪫📉 '%s' [%s] server rate-limited — concurrency reduced from %d → %d (retrying in %.0fs)", - model_id, - domain.value, - prev_limit, - state.current_limit, - cooldown_duration, - ) - else: - logger.debug( - "Throttle %s [%s] cascade 429 #%d (limit held at %d)", - model_id, - domain.value, - state.consecutive_429s, - state.current_limit, - ) - - def release_failure( - self, - *, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - now: float | None = None, - ) -> None: - with self._lock: - state = self._get_or_create_domain(provider_name, model_id, domain) - state.in_flight = max(0, state.in_flight - 1) - # Non-rate-limit failure breaks the 429 cascade: a sequence like - # 429 → 500 → 429 should treat the second 429 as the start of a - # new cascade. But only after the prior burst has fully drained - # (in_flight == 0) - otherwise mixed responses from a single - # in-flight wave (429 → 500 → 429 with concurrent slots) would - # double-reduce the limit even though the provider hasn't - # recovered between the two 429s. - if state.in_flight == 0: - state.consecutive_429s = 0 - - # ------------------------------------------------------------------- - # Sync / async wrappers - # ------------------------------------------------------------------- - - def acquire_sync( - self, - *, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - timeout: float | None = None, - ) -> None: - """Block until a permit is available. - - ``timeout=None`` (the default) waits indefinitely; the per-request HTTP - timeout (``inference_parameters.timeout``) is the only deadline that bounds - actual work, and queue waits scale naturally with provider speed and - AIMD's adaptive concurrency. Pass an explicit float for tests or for - support cases where a queue-wait deadline is genuinely desired. - """ - deadline = (time.monotonic() + timeout) if timeout is not None else None - wait = self.try_acquire(provider_name=provider_name, model_id=model_id, domain=domain) - if wait == 0.0: - return - with self._lock: - # state is captured once and reused in the finally block; safe - # because DomainThrottleState objects are never replaced after creation. - state = self._get_or_create_domain(provider_name, model_id, domain) - state.waiters += 1 - if state.waiters == 1: - logger.debug( - "Throttle %s/%s [%s] queue forming (in_flight=%d/%d)", - provider_name, - model_id, - domain.value, - state.in_flight, - state.current_limit, - ) - try: - while True: - if deadline is not None: - remaining = deadline - time.monotonic() - if remaining <= 0 or wait > remaining: - raise TimeoutError( - f"Throttle acquire timed out after {timeout:.0f}s " - f"for {provider_name}/{model_id} [{domain.value}]" - ) - sleep_for = min(wait, remaining) - else: - sleep_for = wait - time.sleep(sleep_for) - wait = self.try_acquire(provider_name=provider_name, model_id=model_id, domain=domain) - if wait == 0.0: - return - finally: - with self._lock: - state.waiters -= 1 - if state.waiters == 0: - logger.debug( - "Throttle %s/%s [%s] queue drained", - provider_name, - model_id, - domain.value, - ) - - async def acquire_async( - self, - *, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - timeout: float | None = None, - ) -> None: - """Block until a permit is available. - - ``timeout=None`` (the default) waits indefinitely; the per-request HTTP - timeout (``inference_parameters.timeout``) is the only deadline that bounds - actual work, and queue waits scale naturally with provider speed and - AIMD's adaptive concurrency. Pass an explicit float for tests or for - support cases where a queue-wait deadline is genuinely desired. - """ - deadline = (time.monotonic() + timeout) if timeout is not None else None - wait = self.try_acquire(provider_name=provider_name, model_id=model_id, domain=domain) - if wait == 0.0: - return - with self._lock: - # state is captured once and reused in the finally block; safe - # because DomainThrottleState objects are never replaced after creation. - state = self._get_or_create_domain(provider_name, model_id, domain) - state.waiters += 1 - if state.waiters == 1: - logger.debug( - "Throttle %s/%s [%s] queue forming (in_flight=%d/%d)", - provider_name, - model_id, - domain.value, - state.in_flight, - state.current_limit, - ) - try: - while True: - if deadline is not None: - remaining = deadline - time.monotonic() - if remaining <= 0 or wait > remaining: - raise TimeoutError( - f"Throttle acquire timed out after {timeout:.0f}s " - f"for {provider_name}/{model_id} [{domain.value}]" - ) - sleep_for = min(wait, remaining) - else: - sleep_for = wait - await asyncio.sleep(sleep_for) - wait = self.try_acquire(provider_name=provider_name, model_id=model_id, domain=domain) - if wait == 0.0: - return - finally: - with self._lock: - state.waiters -= 1 - if state.waiters == 0: - logger.debug( - "Throttle %s/%s [%s] queue drained", - provider_name, - model_id, - domain.value, - ) - - # ------------------------------------------------------------------- - # Introspection (useful for tests and telemetry) - # ------------------------------------------------------------------- - - def get_domain_state( - self, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - ) -> DomainThrottleState | None: - with self._lock: - return self._domains.get((provider_name, model_id, domain.value)) - - def get_effective_max(self, provider_name: str, model_id: str) -> int: - with self._lock: - return self._effective_max_for(provider_name, model_id) - - # ------------------------------------------------------------------- - # Private helpers - # ------------------------------------------------------------------- - - def _compute_soft_ceiling(self, state: DomainThrottleState, effective_max: int) -> int: - """Return the upper bound for additive increase. - - If a rate-limit ceiling has been recorded, allow probing up to - ``ceiling * (1 + overshoot)`` (clamped to ``effective_max``). - Otherwise fall back to ``effective_max``. - """ - if state.rate_limit_ceiling <= 0: - return effective_max - soft = state.rate_limit_ceiling + max(1, math.floor(state.rate_limit_ceiling * self._ceiling_overshoot)) - return min(soft, effective_max) - - def _get_or_create_domain( - self, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - ) -> DomainThrottleState: - key = (provider_name, model_id, domain.value) - state = self._domains.get(key) - if state is None: - effective_max = self._effective_max_for(provider_name, model_id) - state = DomainThrottleState(current_limit=effective_max) - self._domains[key] = state - return state - - def _effective_max_for(self, provider_name: str, model_id: str) -> int: - cap = self._global_caps.get((provider_name, model_id)) - if cap is None or cap.effective_max <= 0: - return DEFAULT_MIN_LIMIT - return cap.effective_max - - def _clamp_domains(self, global_key: tuple[str, str], effective_max: int) -> None: - provider_name, model_id = global_key - for (pn, mid, _dom), state in self._domains.items(): - if pn == provider_name and mid == model_id and state.current_limit > effective_max: - state.current_limit = effective_max diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/throttled.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/throttled.py deleted file mode 100644 index 797452c69..000000000 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/throttled.py +++ /dev/null @@ -1,222 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import contextlib -import logging -from typing import TYPE_CHECKING - -from data_designer.engine.models.clients.base import ModelClient -from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind -from data_designer.engine.models.clients.throttle_manager import ThrottleDomain -from data_designer.engine.models.clients.types import ( - ChatCompletionRequest, - ChatCompletionResponse, - EmbeddingRequest, - EmbeddingResponse, - ImageGenerationRequest, - ImageGenerationResponse, -) - -if TYPE_CHECKING: - from collections.abc import AsyncIterator, Iterator - - from data_designer.engine.models.clients.throttle_manager import ThrottleManager - - -logger = logging.getLogger(__name__) - - -class ThrottledModelClient(ModelClient): - """Wraps a ``ModelClient`` with per-request throttle acquire/release. - - Inherits from ``ModelClient`` (a ``Protocol``) so that static type - checkers verify conformance and flag missing methods if the protocol - evolves. - - Every outbound HTTP call acquires a throttle permit from the - ``ThrottleManager`` and releases it on success, rate-limit, or failure. - The ``ThrottleDomain`` is determined by the method: - - - ``completion`` / ``acompletion`` -> ``CHAT`` - - ``embeddings`` / ``aembeddings`` -> ``EMBEDDING`` - - ``generate_image`` / ``agenerate_image`` -> ``IMAGE`` when - ``request.messages is None`` (diffusion), ``CHAT`` when messages are set - """ - - def __init__( - self, - inner: ModelClient, - throttle_manager: ThrottleManager, - provider_name: str, - model_id: str, - ) -> None: - self._inner = inner - self._tm = throttle_manager - self._provider_name = provider_name - self._model_id = model_id - - # --- ModelClient protocol delegation --- - - @property - def provider_name(self) -> str: - return self._inner.provider_name - - def supports_chat_completion(self) -> bool: - return self._inner.supports_chat_completion() - - def supports_embeddings(self) -> bool: - return self._inner.supports_embeddings() - - def supports_image_generation(self) -> bool: - return self._inner.supports_image_generation() - - def close(self) -> None: - self._inner.close() - - async def aclose(self) -> None: - await self._inner.aclose() - - # --- Throttled methods --- - - def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - with self._throttled_sync(ThrottleDomain.CHAT): - return self._inner.completion(request) - - async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - async with self._athrottled(ThrottleDomain.CHAT): - return await self._inner.acompletion(request) - - def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: - with self._throttled_sync(ThrottleDomain.EMBEDDING): - return self._inner.embeddings(request) - - async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: - async with self._athrottled(ThrottleDomain.EMBEDDING): - return await self._inner.aembeddings(request) - - def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: - domain = self._image_domain(request) - with self._throttled_sync(domain): - return self._inner.generate_image(request) - - async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: - domain = self._image_domain(request) - async with self._athrottled(domain): - return await self._inner.agenerate_image(request) - - # --- Context managers --- - - @contextlib.contextmanager - def _throttled_sync(self, domain: ThrottleDomain) -> Iterator[None]: - try: - self._tm.acquire_sync( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - ) - except TimeoutError as exc: - raise ProviderError( - kind=ProviderErrorKind.TIMEOUT, - message=str(exc), - provider_name=self._provider_name, - model_name=self._model_id, - ) from exc - exc_to_reraise: BaseException | None = None - try: - yield - except ProviderError as exc: - exc_to_reraise = exc - try: - self._release_on_provider_error(domain, exc) - except Exception: - logger.exception("ThrottleManager release failed; permit may leak") - except BaseException as exc: - exc_to_reraise = exc - try: - self._tm.release_failure( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - ) - except Exception: - logger.exception("ThrottleManager release failed; permit may leak") - else: - try: - self._tm.release_success( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - ) - except Exception: - logger.exception("ThrottleManager release_success failed") - if exc_to_reraise is not None: - raise exc_to_reraise - - @contextlib.asynccontextmanager - async def _athrottled(self, domain: ThrottleDomain) -> AsyncIterator[None]: - try: - await self._tm.acquire_async( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - ) - except TimeoutError as exc: - raise ProviderError( - kind=ProviderErrorKind.TIMEOUT, - message=str(exc), - provider_name=self._provider_name, - model_name=self._model_id, - ) from exc - exc_to_reraise: BaseException | None = None - try: - yield - except ProviderError as exc: - exc_to_reraise = exc - try: - self._release_on_provider_error(domain, exc) - except Exception: - logger.exception("ThrottleManager release failed; permit may leak") - except BaseException as exc: - exc_to_reraise = exc - try: - self._tm.release_failure( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - ) - except Exception: - logger.exception("ThrottleManager release failed; permit may leak") - else: - try: - self._tm.release_success( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - ) - except Exception: - logger.exception("ThrottleManager release_success failed") - if exc_to_reraise is not None: - raise exc_to_reraise - - # --- Private helpers --- - - def _release_on_provider_error(self, domain: ThrottleDomain, exc: ProviderError) -> None: - if exc.kind == ProviderErrorKind.RATE_LIMIT: - self._tm.release_rate_limited( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - retry_after=exc.retry_after, - ) - else: - self._tm.release_failure( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - ) - - @staticmethod - def _image_domain(request: ImageGenerationRequest) -> ThrottleDomain: - return ThrottleDomain.CHAT if request.messages is not None else ThrottleDomain.IMAGE diff --git a/packages/data-designer-engine/src/data_designer/engine/models/factory.py b/packages/data-designer-engine/src/data_designer/engine/models/factory.py index 6ef2b2727..6531d1dbd 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/factory.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/factory.py @@ -40,9 +40,8 @@ def create_model_registry( client_concurrency_mode: ``"sync"`` (default) or ``"async"``. Forwarded to native HTTP adapters so each client is constrained to a single concurrency mode. - run_config: Optional runtime configuration. The nested - ``run_config.throttle`` (a ``ThrottleConfig``) is forwarded to the - ``ThrottleManager`` constructor. + run_config: Optional runtime configuration. Request admission uses + engine-internal defaults in V1; no public run-config knob is exposed. Returns: A configured ModelRegistry instance. @@ -50,11 +49,12 @@ def create_model_registry( from data_designer.config.run_config import RunConfig from data_designer.engine.models.clients.factory import create_model_client from data_designer.engine.models.clients.retry import RetryConfig - from data_designer.engine.models.clients.throttle_manager import ThrottleManager from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.registry import ModelRegistry + from data_designer.engine.models.request_admission.controller import AdaptiveRequestAdmissionController - throttle_manager = ThrottleManager((run_config or RunConfig()).throttle) + _ = run_config or RunConfig() + request_admission = AdaptiveRequestAdmissionController() def model_facade_factory( model_config: ModelConfig, @@ -68,7 +68,7 @@ def model_facade_factory( model_provider_registry, retry_config=retry_config, client_concurrency_mode=client_concurrency_mode, - throttle_manager=throttle_manager, + request_admission=request_admission, ) return ModelFacade( model_config, @@ -82,6 +82,6 @@ def model_facade_factory( secret_resolver=secret_resolver, model_provider_registry=model_provider_registry, model_facade_factory=model_facade_factory, - throttle_manager=throttle_manager, + request_admission=request_admission, retry_config=RetryConfig(), ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/registry.py b/packages/data-designer-engine/src/data_designer/engine/models/registry.py index a6103ffc3..5b81bb504 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/registry.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/registry.py @@ -16,8 +16,8 @@ from collections.abc import Callable from data_designer.engine.models.clients.retry import RetryConfig - from data_designer.engine.models.clients.throttle_manager import ThrottleManager from data_designer.engine.models.facade import ModelFacade + from data_designer.engine.models.request_admission.controller import AdaptiveRequestAdmissionController ModelFacadeFactory = Callable[ [ModelConfig, SecretResolver, ModelProviderRegistry, RetryConfig | None], @@ -35,13 +35,13 @@ def __init__( model_provider_registry: ModelProviderRegistry, model_configs: list[ModelConfig] | None = None, model_facade_factory: ModelFacadeFactory | None = None, - throttle_manager: ThrottleManager | None = None, + request_admission: AdaptiveRequestAdmissionController | None = None, retry_config: RetryConfig | None = None, ) -> None: self._secret_resolver = secret_resolver self._model_provider_registry = model_provider_registry self._model_facade_factory = model_facade_factory - self._throttle_manager = throttle_manager + self._request_admission = request_admission self._retry_config = retry_config self._model_configs: dict[str, ModelConfig] = {} self._models: dict[str, ModelFacade] = {} @@ -56,8 +56,8 @@ def models(self) -> dict[str, ModelFacade]: return self._models @property - def throttle_manager(self) -> ThrottleManager | None: - return self._throttle_manager + def request_admission(self) -> AdaptiveRequestAdmissionController | None: + return self._request_admission @property def retry_config(self) -> RetryConfig | None: @@ -178,10 +178,9 @@ def get_aggregate_max_parallel_requests(self) -> int: This is a coarse upper bound: it sums over *all* registered aliases, including those not referenced by the current generator set, and does not deduplicate aliases sharing a ``(provider_name, model_id)`` key. - The result is used to size the scheduler's LLM-wait semaphore, which - is a memory-safety cap — oversizing wastes a few coroutine slots but - does not affect correctness because the ``ThrottleManager`` enforces - the real per-key limit. + The result is used to size scheduler task-stage model admission, which + is a memory-safety cap. Concrete provider/model request capacity is + enforced by request admission at model-call time. """ return sum(mc.inference_parameters.max_parallel_requests for mc in self._model_configs.values()) @@ -314,8 +313,8 @@ def _get_model(self, model_config: ModelConfig) -> ModelFacade: self._model_provider_registry, self._retry_config, ) - if self._throttle_manager is not None: - self._throttle_manager.register( + if self._request_admission is not None: + self._request_admission.register( provider_name=facade.model_provider_name, model_id=model_config.model, alias=model_config.alias, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/config.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/config.py new file mode 100644 index 000000000..2796698c5 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/config.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field + +from data_designer.engine.models.request_admission.resources import RequestResourceKey + + +@dataclass(frozen=True) +class RequestAdmissionConfig: + initial_limits: Mapping[RequestResourceKey, int] = field(default_factory=dict) + max_limit_clamps: Mapping[RequestResourceKey, int | None] = field(default_factory=dict) + cooldown_seconds: float = 2.0 + multiplicative_decrease_factor: float = 0.75 + additive_increase_step: int = 1 + increase_after_successes: int = 25 + default_queue_wait_timeout_seconds: float | None = None diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py new file mode 100644 index 000000000..8494d7d69 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py @@ -0,0 +1,741 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import logging +import math +import threading +import time +import uuid +from collections import Counter, deque +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Literal, Protocol + +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.limits import AdaptiveRequestLimitState +from data_designer.engine.models.request_admission.outcomes import ReleaseResult, RequestReleaseOutcome +from data_designer.engine.models.request_admission.pressure import ( + ProviderModelPressureSnapshot, + RequestPressureSnapshot, + RequestPressureSnapshotProvider, +) +from data_designer.engine.models.request_admission.queue import RequestFairQueue, RequestWaiter +from data_designer.engine.models.request_admission.resources import ( + RequestAdmissionItem, + RequestDomain, + RequestResourceKey, +) +from data_designer.engine.models.resources import ProviderModelKey +from data_designer.engine.observability import ( + RequestAdmissionEvent, + RequestAdmissionEventSink, + runtime_correlation_provider, +) + +logger = logging.getLogger(__name__) + +RequestDenyReason = Literal[ + "no_capacity", + "cooldown", + "queue_timeout", + "queued_waiters_ahead", + "cancellation", + "shutdown", + "hard_policy_denial", +] +RELEASED_LEASE_HISTORY_LIMIT = 8192 +_TERMINAL_DENIAL_REASONS: frozenset[RequestDenyReason] = frozenset({"hard_policy_denial", "shutdown"}) + + +@dataclass(frozen=True) +class RequestAdmissionDenied: + item: RequestAdmissionItem + reason: RequestDenyReason + retry_after_seconds: float | None = None + available_after_monotonic: float | None = None + snapshot: object | None = None + diagnostics: Mapping[str, object] = field(default_factory=dict) + + +@dataclass(frozen=True) +class RequestAdmissionLease: + lease_id: str + item: RequestAdmissionItem + acquired_at: float + current_adaptive_limit: int + effective_max: int + controller_generation: str + + +RequestAdmissionDecision = RequestAdmissionLease | RequestAdmissionDenied + + +class RequestAdmissionError(RuntimeError): + """Raised by blocking acquire paths when no request lease is acquired.""" + + def __init__(self, decision: RequestAdmissionDenied) -> None: + super().__init__(f"Request admission failed: {decision.reason}") + self.decision = decision + + +class RequestAdmissionController(Protocol): + def try_acquire(self, item: RequestAdmissionItem) -> RequestAdmissionDecision: ... + + def acquire_sync(self, item: RequestAdmissionItem) -> RequestAdmissionLease: ... + + async def acquire_async(self, item: RequestAdmissionItem) -> RequestAdmissionLease: ... + + def release(self, lease: RequestAdmissionLease, outcome: RequestReleaseOutcome) -> ReleaseResult: ... + + @property + def pressure(self) -> RequestPressureSnapshotProvider: ... + + +@dataclass +class _GlobalCapState: + limits_by_alias: dict[str, int] = field(default_factory=dict) + effective_max: int = 0 + + def register_alias(self, alias: str, max_parallel: int) -> None: + self.limits_by_alias[alias] = max(1, max_parallel) + self.effective_max = min(self.limits_by_alias.values()) + + +class AdaptiveRequestAdmissionController(RequestPressureSnapshotProvider): + """AIMD-backed request admission controller with exact request leases.""" + + def __init__( + self, + config: RequestAdmissionConfig | None = None, + *, + event_sink: RequestAdmissionEventSink | None = None, + ) -> None: + self._config = config or RequestAdmissionConfig() + self._lock = threading.Lock() + self._condition = threading.Condition(self._lock) + self._generation = uuid.uuid4().hex + self._global_caps: dict[ProviderModelKey, _GlobalCapState] = {} + self._domains: dict[RequestResourceKey, AdaptiveRequestLimitState] = {} + self._active_leases: dict[str, RequestAdmissionLease] = {} + self._released: set[str] = set() + self._released_order: deque[str] = deque(maxlen=RELEASED_LEASE_HISTORY_LIMIT) + self._aggregate_in_flight: Counter[ProviderModelKey] = Counter() + self._aggregate_active_leases: Counter[ProviderModelKey] = Counter() + self._sequence = 0 + self._release_diagnostics: Counter[str] = Counter() + self._queue = RequestFairQueue() + self._event_sink = event_sink + + @property + def pressure(self) -> RequestPressureSnapshotProvider: + return self + + @property + def config(self) -> RequestAdmissionConfig: + return self._config + + def register( + self, + *, + provider_name: str, + model_id: str, + alias: str, + max_parallel_requests: int, + ) -> None: + events: list[RequestAdmissionEvent] = [] + with self._lock: + key = ProviderModelKey(provider_name, model_id) + cap = self._global_caps.setdefault(key, _GlobalCapState()) + before = cap.effective_max + cap.register_alias(alias, max_parallel_requests) + self._sequence += 1 + for resource, state in self._domains.items(): + if resource.provider_model_key == key: + effective_max = self._effective_max_for_resource(resource) + state.current_limit = min(state.current_limit, effective_max) + events.append( + self._request_event_locked( + "request_resource_registered", + request_resource_key=RequestResourceKey(provider_name, model_id, RequestDomain.CHAT), + diagnostics={"alias": alias, "provider_model": key, "max_parallel_requests": max_parallel_requests}, + ) + ) + if before != cap.effective_max: + events.append( + self._request_event_locked( + "request_effective_cap_changed", + request_resource_key=RequestResourceKey(provider_name, model_id, RequestDomain.CHAT), + diagnostics={"provider_model": key, "previous": before, "current": cap.effective_max}, + ) + ) + self._admit_waiters_locked(events) + self._condition.notify_all() + self._emit_events(events) + + def try_acquire(self, item: RequestAdmissionItem) -> RequestAdmissionDecision: + now = time.monotonic() + events: list[RequestAdmissionEvent] = [] + decision: RequestAdmissionDecision + with self._lock: + events.append(self._request_event_locked("request_wait_started", item=item)) + if self._queued_waiter_ahead_locked(item, now): + decision = RequestAdmissionDenied( + item=item, + reason="queued_waiters_ahead", + snapshot=self._snapshot_locked(item.resource, now), + ) + events.append(self._request_event_locked("request_acquire_denied", item=item, decision=decision)) + else: + denied = self._denial_for(item, now) + if denied is not None: + decision = denied + events.append(self._request_event_locked("request_acquire_denied", item=item, decision=decision)) + else: + decision = self._acquire_locked(item, now) + events.append(self._request_event_locked("request_wait_completed", item=item, lease=decision)) + events.append(self._request_event_locked("request_lease_acquired", item=item, lease=decision)) + self._emit_events(events) + return decision + + def acquire_sync(self, item: RequestAdmissionItem) -> RequestAdmissionLease: + try: + asyncio.get_running_loop() + except RuntimeError: + pass + else: + raise RuntimeError("acquire_sync would block the running event loop; use acquire_async instead.") + + timeout = ( + item.queue_wait_timeout_seconds + if item.queue_wait_timeout_seconds is not None + else self._config.default_queue_wait_timeout_seconds + ) + now = time.monotonic() + deadline = now + timeout if timeout is not None else None + waiter = RequestWaiter(waiter_id=uuid.uuid4().hex, item=item, enqueued_at=now, deadline_monotonic=deadline) + events: list[RequestAdmissionEvent] = [] + try: + while True: + with self._lock: + if waiter.assigned_lease is not None: + return waiter.assigned_lease + now = time.monotonic() + if deadline is not None and now >= deadline: + self._remove_waiter_locked(waiter) + denied = RequestAdmissionDenied( + item=item, + reason="queue_timeout", + snapshot=self._snapshot_locked(item.resource, now), + ) + events.append(self._request_event_locked("request_wait_timeout", item=item, decision=denied)) + raise RequestAdmissionError(denied) + if not self._queue.contains(waiter.waiter_id) and waiter.assigned_lease is None: + self._enqueue_waiter_locked(waiter, events) + self._admit_waiters_locked(events) + if waiter.assigned_lease is not None: + return waiter.assigned_lease + now = time.monotonic() + if (denied := self._terminal_denial_for(item, now)) is not None: + self._remove_waiter_locked(waiter) + events.append(self._request_event_locked("request_acquire_denied", item=item, decision=denied)) + self._condition.notify_all() + raise RequestAdmissionError(denied) + if deadline is not None and now >= deadline: + self._remove_waiter_locked(waiter) + denied = RequestAdmissionDenied( + item=item, + reason="queue_timeout", + snapshot=self._snapshot_locked(item.resource, now), + ) + events.append(self._request_event_locked("request_wait_timeout", item=item, decision=denied)) + raise RequestAdmissionError(denied) + wait = self._wait_seconds_locked(item, now, deadline) + self._condition.wait(timeout=wait) + finally: + self._emit_events(events) + + async def acquire_async(self, item: RequestAdmissionItem) -> RequestAdmissionLease: + loop = asyncio.get_running_loop() + wakeup = asyncio.Event() + timeout = ( + item.queue_wait_timeout_seconds + if item.queue_wait_timeout_seconds is not None + else self._config.default_queue_wait_timeout_seconds + ) + now = time.monotonic() + deadline = now + timeout if timeout is not None else None + waiter = RequestWaiter( + waiter_id=uuid.uuid4().hex, + item=item, + enqueued_at=now, + deadline_monotonic=deadline, + wakeup=lambda: loop.call_soon_threadsafe(wakeup.set), + ) + events: list[RequestAdmissionEvent] = [] + try: + while True: + with self._lock: + if waiter.assigned_lease is not None: + return waiter.assigned_lease + now = time.monotonic() + if deadline is not None and now >= deadline: + self._remove_waiter_locked(waiter) + denied = RequestAdmissionDenied( + item=item, + reason="queue_timeout", + snapshot=self._snapshot_locked(item.resource, now), + ) + events.append(self._request_event_locked("request_wait_timeout", item=item, decision=denied)) + raise RequestAdmissionError(denied) + if not self._queue.contains(waiter.waiter_id) and waiter.assigned_lease is None: + self._enqueue_waiter_locked(waiter, events) + self._admit_waiters_locked(events) + if waiter.assigned_lease is not None: + return waiter.assigned_lease + now = time.monotonic() + if (denied := self._terminal_denial_for(item, now)) is not None: + self._remove_waiter_locked(waiter) + events.append(self._request_event_locked("request_acquire_denied", item=item, decision=denied)) + self._condition.notify_all() + raise RequestAdmissionError(denied) + if deadline is not None and now >= deadline: + self._remove_waiter_locked(waiter) + denied = RequestAdmissionDenied( + item=item, + reason="queue_timeout", + snapshot=self._snapshot_locked(item.resource, now), + ) + events.append(self._request_event_locked("request_wait_timeout", item=item, decision=denied)) + raise RequestAdmissionError(denied) + wait = self._wait_seconds_locked(item, now, deadline) + try: + await asyncio.wait_for(wakeup.wait(), timeout=wait) + except asyncio.TimeoutError: + pass + wakeup.clear() + except asyncio.CancelledError: + lease_to_release: RequestAdmissionLease | None = None + with self._lock: + lease_to_release = waiter.assigned_lease + if lease_to_release is None: + self._remove_waiter_locked(waiter) + denied = RequestAdmissionDenied(item=item, reason="cancellation") + events.append( + self._request_event_locked( + "request_wait_cancelled", + item=item, + lease=lease_to_release, + decision=denied, + ) + ) + self._condition.notify_all() + if lease_to_release is not None: + self._emit_events(events) + events.clear() + self.release(lease_to_release, RequestReleaseOutcome(kind="local_cancelled")) + raise + finally: + self._emit_events(events) + + def release(self, lease: RequestAdmissionLease, outcome: RequestReleaseOutcome) -> ReleaseResult: + now = time.monotonic() + events: list[RequestAdmissionEvent] = [] + result: ReleaseResult + with self._lock: + if lease.controller_generation != self._generation: + self._release_diagnostics["wrong_controller_generation"] += 1 + result = ReleaseResult(released=False, reason="wrong_controller_generation") + events.append( + self._request_event_locked( + "request_release_diagnostic", item=lease.item, lease=lease, result=result + ) + ) + elif (active := self._active_leases.pop(lease.lease_id, None)) is None: + reason = "duplicate" if lease.lease_id in self._released else "unknown_lease" + self._release_diagnostics[reason] += 1 + result = ReleaseResult(released=False, reason=reason) + events.append( + self._request_event_locked( + "request_release_diagnostic", item=lease.item, lease=lease, result=result + ) + ) + elif active != lease: + self._active_leases[lease.lease_id] = active + self._release_diagnostics["stale_lease"] += 1 + result = ReleaseResult(released=False, reason="stale_lease") + events.append( + self._request_event_locked( + "request_release_diagnostic", item=lease.item, lease=lease, result=result + ) + ) + else: + self._remember_released_locked(lease.lease_id) + resource = active.item.resource + provider_model = resource.provider_model_key + state = self._get_or_create_state(resource) + state.in_flight = max(0, state.in_flight - 1) + state.active_lease_count = max(0, state.active_lease_count - 1) + state.last_outcome = outcome.kind + self._aggregate_in_flight[provider_model] = max(0, self._aggregate_in_flight[provider_model] - 1) + self._aggregate_active_leases[provider_model] = max( + 0, + self._aggregate_active_leases[provider_model] - 1, + ) + self._apply_outcome(state, resource, active.current_adaptive_limit, outcome, now, events) + self._sequence += 1 + result = ReleaseResult(released=True, reason="released") + if outcome.kind == "rate_limited": + events.append(self._request_event_locked("request_rate_limited", item=active.item, lease=active)) + events.append( + self._request_event_locked( + "request_lease_released", + item=active.item, + lease=active, + result=result, + outcome=outcome, + ) + ) + self._admit_waiters_locked(events) + self._condition.notify_all() + self._emit_events(events) + return result + + def snapshot(self, resource: RequestResourceKey) -> RequestPressureSnapshot | None: + with self._lock: + if resource not in self._domains: + return None + return self._snapshot_locked(resource, time.monotonic()) + + def snapshots(self) -> Mapping[RequestResourceKey, RequestPressureSnapshot]: + with self._lock: + now = time.monotonic() + return {resource: self._snapshot_locked(resource, now) for resource in self._domains} + + def global_snapshot(self, provider: str, model: str) -> ProviderModelPressureSnapshot | None: + with self._lock: + key = ProviderModelKey(provider, model) + if key not in self._global_caps: + return None + return self._global_snapshot_locked(key, time.monotonic()) + + def global_snapshots(self) -> Mapping[ProviderModelKey, ProviderModelPressureSnapshot]: + with self._lock: + now = time.monotonic() + return {key: self._global_snapshot_locked(key, now) for key in self._global_caps} + + def _queued_waiter_ahead_locked(self, item: RequestAdmissionItem, now: float) -> bool: + if not self._queue.has_waiters: + return False + self._expire_waiters_locked(now) + selection = self._queue.select_next(lambda waiter, _view: self._denial_for(waiter.item, now) is None) + if selection is None: + return False + selected_key = selection.item.resource.provider_model_key + return selected_key == item.resource.provider_model_key or selection.item.resource == item.resource + + def _enqueue_waiter_locked(self, waiter: RequestWaiter, events: list[RequestAdmissionEvent]) -> None: + if self._queue.enqueue(waiter): + self._get_or_create_state(waiter.item.resource).waiters += 1 + self._sequence += 1 + if self._queue.view().queued_total == 1: + events.append(self._request_event_locked("request_queue_formed", item=waiter.item)) + events.append(self._request_event_locked("request_wait_started", item=waiter.item)) + + def _remove_waiter_locked(self, waiter: RequestWaiter) -> None: + removed = self._queue.remove(waiter.waiter_id) + if removed is None: + return + state = self._get_or_create_state(waiter.item.resource) + state.waiters = max(0, state.waiters - 1) + self._sequence += 1 + + def _expire_waiters_locked(self, now: float) -> None: + for waiter in self._queue.waiters(): + if waiter.deadline_monotonic is not None and now >= waiter.deadline_monotonic: + self._remove_waiter_locked(waiter) + self._wake_waiter_locked(waiter) + + def _admit_waiters_locked(self, events: list[RequestAdmissionEvent]) -> None: + while self._queue.has_waiters: + now = time.monotonic() + self._expire_waiters_locked(now) + if not self._queue.has_waiters: + return + selection = self._queue.select_next(lambda waiter, _view: self._denial_for(waiter.item, now) is None) + if selection is None: + return + waiter = self._queue.commit(selection) + if waiter is None: + return + state = self._get_or_create_state(waiter.item.resource) + state.waiters = max(0, state.waiters - 1) + lease = self._acquire_locked(waiter.item, now) + waiter.assigned_lease = lease + self._wake_waiter_locked(waiter) + events.append(self._request_event_locked("request_wait_completed", item=waiter.item, lease=lease)) + events.append(self._request_event_locked("request_lease_acquired", item=waiter.item, lease=lease)) + if not self._queue.has_waiters: + events.append(self._request_event_locked("request_queue_drained", item=waiter.item)) + + def _wake_waiter_locked(self, waiter: RequestWaiter) -> None: + if waiter.wakeup is None: + return + waiter.wakeup() + + def _wait_seconds_locked( + self, + item: RequestAdmissionItem, + now: float, + deadline: float | None, + ) -> float: + candidates = [0.05] + if deadline is not None: + candidates.append(max(0.0, deadline - now)) + state = self._domains.get(item.resource) + if state is not None and state.blocked_until > now: + candidates.append(max(0.0, state.blocked_until - now)) + return max(0.0, min(candidates)) + + def _denial_for(self, item: RequestAdmissionItem, now: float) -> RequestAdmissionDenied | None: + resource = item.resource + provider_model = resource.provider_model_key + if provider_model not in self._global_caps: + return RequestAdmissionDenied(item=item, reason="hard_policy_denial", diagnostics={"unregistered": True}) + state = self._get_or_create_state(resource) + if now < state.blocked_until: + return RequestAdmissionDenied( + item=item, + reason="cooldown", + retry_after_seconds=state.blocked_until - now, + available_after_monotonic=state.blocked_until, + snapshot=self._snapshot_locked(resource, now), + ) + effective_max = self._effective_max_for_resource(resource) + aggregate_cap = self._global_caps[provider_model].effective_max + if state.in_flight >= min(state.current_limit, effective_max): + return RequestAdmissionDenied( + item=item, reason="no_capacity", snapshot=self._snapshot_locked(resource, now) + ) + if self._aggregate_in_flight[provider_model] >= aggregate_cap: + return RequestAdmissionDenied( + item=item, reason="no_capacity", snapshot=self._snapshot_locked(resource, now) + ) + return None + + def _terminal_denial_for(self, item: RequestAdmissionItem, now: float) -> RequestAdmissionDenied | None: + denied = self._denial_for(item, now) + if denied is None or denied.reason not in _TERMINAL_DENIAL_REASONS: + return None + return denied + + def _remember_released_locked(self, lease_id: str) -> None: + if lease_id in self._released: + return + maxlen = self._released_order.maxlen + if maxlen is not None and len(self._released_order) >= maxlen: + self._released.discard(self._released_order[0]) + self._released.add(lease_id) + self._released_order.append(lease_id) + + def _acquire_locked(self, item: RequestAdmissionItem, now: float) -> RequestAdmissionLease: + resource = item.resource + provider_model = resource.provider_model_key + state = self._get_or_create_state(resource) + state.in_flight += 1 + state.active_lease_count += 1 + self._aggregate_in_flight[provider_model] += 1 + self._aggregate_active_leases[provider_model] += 1 + lease = RequestAdmissionLease( + lease_id=uuid.uuid4().hex, + item=item, + acquired_at=now, + current_adaptive_limit=state.current_limit, + effective_max=self._effective_max_for_resource(resource), + controller_generation=self._generation, + ) + self._active_leases[lease.lease_id] = lease + self._sequence += 1 + return lease + + def _apply_outcome( + self, + state: AdaptiveRequestLimitState, + resource: RequestResourceKey, + admitted_adaptive_limit: int, + outcome: RequestReleaseOutcome, + now: float, + events: list[RequestAdmissionEvent], + ) -> None: + effective_max = self._effective_max_for_resource(resource) + if outcome.kind == "rate_limited": + prev_limit = state.current_limit + should_decrease = admitted_adaptive_limit <= prev_limit + state.consecutive_rate_limits += 1 + cooldown = ( + outcome.retry_after_seconds + if outcome.retry_after_seconds is not None and outcome.retry_after_seconds > 0 + else self._config.cooldown_seconds + ) + state.blocked_until = now + cooldown + state.success_streak = 0 + if should_decrease: + state.current_limit = max( + 1, math.floor(state.current_limit * self._config.multiplicative_decrease_factor) + ) + if state.rate_limit_ceiling == 0: + state.rate_limit_ceiling = max(1, admitted_adaptive_limit) + if state.current_limit != prev_limit: + events.append( + self._request_event_locked( + "request_limit_decreased", + request_resource_key=resource, + diagnostics={"previous": prev_limit, "current": state.current_limit}, + ) + ) + return + if outcome.kind == "success" and now >= state.blocked_until: + prev_limit = state.current_limit + state.consecutive_rate_limits = 0 + state.success_streak += 1 + if state.success_streak >= self._config.increase_after_successes: + state.current_limit = min(effective_max, state.current_limit + self._config.additive_increase_step) + state.success_streak = 0 + if state.current_limit != prev_limit: + events.append( + self._request_event_locked( + "request_limit_increased", + request_resource_key=resource, + diagnostics={"previous": prev_limit, "current": state.current_limit}, + ) + ) + if state.rate_limit_ceiling and state.current_limit > state.rate_limit_ceiling: + events.append( + self._request_event_locked( + "request_soft_ceiling_recovered", + request_resource_key=resource, + diagnostics={"rate_limit_ceiling": state.rate_limit_ceiling}, + ) + ) + if state.current_limit == effective_max and state.blocked_until <= now: + events.append( + self._request_event_locked("request_fully_recovered", request_resource_key=resource) + ) + return + if state.in_flight == 0 and outcome.kind not in {"local_cancelled", "local_timeout"}: + state.consecutive_rate_limits = 0 + + def _increment_waiter(self, item: RequestAdmissionItem) -> None: + with self._lock: + self._get_or_create_state(item.resource).waiters += 1 + self._sequence += 1 + + def _decrement_waiter(self, item: RequestAdmissionItem) -> None: + with self._lock: + state = self._get_or_create_state(item.resource) + state.waiters = max(0, state.waiters - 1) + self._sequence += 1 + + def _get_or_create_state(self, resource: RequestResourceKey) -> AdaptiveRequestLimitState: + state = self._domains.get(resource) + if state is None: + effective_max = self._effective_max_for_resource(resource) + initial = self._config.initial_limits.get(resource, effective_max) + state = AdaptiveRequestLimitState(current_limit=max(1, min(initial, effective_max))) + self._domains[resource] = state + return state + + def _effective_max_for_resource(self, resource: RequestResourceKey) -> int: + provider_model_cap = self._global_caps.get(resource.provider_model_key) + static_cap = provider_model_cap.effective_max if provider_model_cap is not None else 1 + clamp = self._config.max_limit_clamps.get(resource) + return max(1, min(static_cap, clamp if clamp is not None else static_cap)) + + def _snapshot_locked(self, resource: RequestResourceKey, now: float) -> RequestPressureSnapshot: + state = self._get_or_create_state(resource) + blocked_until = state.blocked_until if state.blocked_until > now else None + return RequestPressureSnapshot( + captured_at=now, + sequence=self._sequence, + resource=resource, + effective_max=self._effective_max_for_resource(resource), + current_limit=state.current_limit, + in_flight_count=state.in_flight, + active_lease_count=state.active_lease_count, + waiters=state.waiters, + blocked_until_monotonic=blocked_until, + cooldown_remaining_seconds=max(0.0, state.blocked_until - now), + rate_limit_ceiling=state.rate_limit_ceiling, + consecutive_rate_limits=state.consecutive_rate_limits, + last_outcome=state.last_outcome, + leak_diagnostics=dict(self._release_diagnostics), + ) + + def _global_snapshot_locked(self, key: ProviderModelKey, now: float) -> ProviderModelPressureSnapshot: + cap = self._global_caps[key] + domains = { + resource.domain: state.current_limit + for resource, state in self._domains.items() + if resource.provider_model_key == key + } + return ProviderModelPressureSnapshot( + captured_at=now, + sequence=self._sequence, + provider_model=key, + static_cap=cap.effective_max, + aggregate_in_flight=self._aggregate_in_flight[key], + aggregate_active_lease_count=self._aggregate_active_leases[key], + aliases=tuple(sorted(cap.limits_by_alias)), + raw_caps=dict(cap.limits_by_alias), + domains=domains, + ) + + def _request_event_locked( + self, + event_kind: str, + *, + item: RequestAdmissionItem | None = None, + lease: RequestAdmissionLease | None = None, + decision: RequestAdmissionDenied | None = None, + result: ReleaseResult | None = None, + outcome: RequestReleaseOutcome | None = None, + request_resource_key: RequestResourceKey | None = None, + diagnostics: Mapping[str, object] | None = None, + ) -> RequestAdmissionEvent: + self._sequence += 1 + event_context = item.event_context if item is not None else None + resource = request_resource_key or (item.resource if item is not None else None) + group_key = item.group.key if item is not None else None + reason_or_outcome = None + if decision is not None: + reason_or_outcome = decision.reason + elif outcome is not None: + reason_or_outcome = outcome.kind + elif result is not None: + reason_or_outcome = result.reason + return RequestAdmissionEvent.capture( + event_kind, # type: ignore[arg-type] + sequence=self._sequence, + correlation=event_context.captured_correlation + if event_context is not None + else runtime_correlation_provider.current(), + request_attempt_id=event_context.request_attempt_id if event_context is not None else None, + request_lease_id=lease.lease_id if lease is not None else None, + request_resource_key=resource, + request_group_key=group_key, + reason_or_outcome=reason_or_outcome, + pressure_snapshot=self._snapshot_locked(resource, time.monotonic()) if resource is not None else None, + diagnostics=dict(diagnostics or {}), + ) + + def _emit_events(self, events: list[RequestAdmissionEvent]) -> None: + if self._event_sink is None: + return + for event in events: + try: + self._event_sink.emit_request_event(event) + except Exception: + logger.warning("Request admission event sink raised; dropping event.", exc_info=True) + continue diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/limits.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/limits.py new file mode 100644 index 000000000..ad6f779d7 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/limits.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class AdaptiveRequestLimitState: + current_limit: int + in_flight: int = 0 + blocked_until: float = 0.0 + success_streak: int = 0 + waiters: int = 0 + rate_limit_ceiling: int = 0 + consecutive_rate_limits: int = 0 + active_lease_count: int = 0 + last_outcome: str | None = None diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/outcomes.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/outcomes.py new file mode 100644 index 000000000..3399b07f4 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/outcomes.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Literal + + +@dataclass(frozen=True) +class RequestReleaseOutcome: + kind: Literal[ + "success", + "rate_limited", + "provider_failure", + "provider_timeout", + "local_cancelled", + "local_timeout", + "unexpected_exception", + ] + retry_after_seconds: float | None = None + provider_status: int | None = None + diagnostics: Mapping[str, object] = field(default_factory=dict) + + +@dataclass(frozen=True) +class ReleaseResult: + released: bool + reason: Literal["released", "duplicate", "stale_lease", "wrong_controller_generation", "unknown_lease"] + diagnostics: Mapping[str, object] = field(default_factory=dict) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/pressure.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/pressure.py new file mode 100644 index 000000000..a268f8898 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/pressure.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Protocol + +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.resources import RequestDomain, RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey + + +@dataclass(frozen=True) +class RequestPressureSnapshot: + captured_at: float + sequence: int + resource: RequestResourceKey + effective_max: int + current_limit: int + in_flight_count: int + active_lease_count: int + waiters: int + blocked_until_monotonic: float | None + cooldown_remaining_seconds: float + rate_limit_ceiling: int + consecutive_rate_limits: int + last_outcome: str | None + leak_diagnostics: Mapping[str, int] + + +@dataclass(frozen=True) +class ProviderModelPressureSnapshot: + captured_at: float + sequence: int + provider_model: ProviderModelKey + static_cap: int + aggregate_in_flight: int + aggregate_active_lease_count: int + aliases: tuple[str, ...] + raw_caps: Mapping[str, int | None] + domains: Mapping[RequestDomain, int] + + +class RequestPressureSnapshotProvider(Protocol): + @property + def config(self) -> RequestAdmissionConfig | None: ... + + def snapshot(self, resource: RequestResourceKey) -> RequestPressureSnapshot | None: ... + + def snapshots(self) -> Mapping[RequestResourceKey, RequestPressureSnapshot]: ... + + def global_snapshot(self, provider: str, model: str) -> ProviderModelPressureSnapshot | None: ... + + def global_snapshots(self) -> Mapping[ProviderModelKey, ProviderModelPressureSnapshot]: ... diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/queue.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/queue.py new file mode 100644 index 000000000..cdca7027b --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/queue.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import heapq +from collections import Counter, deque +from collections.abc import Callable, Mapping +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from data_designer.engine.models.request_admission.resources import RequestAdmissionItem, RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey + +if TYPE_CHECKING: + from data_designer.engine.models.request_admission.controller import RequestAdmissionLease + + +@dataclass +class RequestWaiter: + waiter_id: str + item: RequestAdmissionItem + enqueued_at: float + deadline_monotonic: float | None = None + assigned_lease: RequestAdmissionLease | None = None + wakeup: Callable[[], None] | None = None + + +@dataclass(frozen=True) +class RequestQueueView: + queued_total: int + queued_by_group: Mapping[RequestResourceKey, int] + queued_demand_by_resource: Mapping[RequestResourceKey, int] + aggregate_provider_model_waiters: Mapping[ProviderModelKey, int] + + +@dataclass(frozen=True) +class RequestQueueSelection: + waiter: RequestWaiter + item: RequestAdmissionItem + waiter_id: str + queue_view: RequestQueueView + sequence_version: int + + +class RequestFairQueue: + """Weighted fair waiter queue used by request admission.""" + + def __init__(self) -> None: + self._queues: dict[RequestResourceKey, deque[RequestWaiter]] = {} + self._queued: dict[str, RequestWaiter] = {} + self._waiter_groups: dict[str, RequestResourceKey] = {} + self._group_finish: dict[RequestResourceKey, float] = {} + self._heap: list[tuple[float, int, RequestResourceKey]] = [] + self._active_heap_entries: dict[RequestResourceKey, tuple[float, int]] = {} + self._sequence = 0 + self._sequence_version = 0 + self._virtual_time = 0.0 + + @property + def has_waiters(self) -> bool: + return bool(self._queued) + + def contains(self, waiter_id: str) -> bool: + return waiter_id in self._queued + + def waiters(self) -> tuple[RequestWaiter, ...]: + return tuple(self._queued.values()) + + def enqueue(self, waiter: RequestWaiter) -> bool: + if waiter.waiter_id in self._queued: + return False + key = waiter.item.group.key + queue = self._queues.setdefault(key, deque()) + queue.append(waiter) + self._queued[waiter.waiter_id] = waiter + self._waiter_groups[waiter.waiter_id] = key + self._activate_group(key) + self._sequence_version += 1 + return True + + def remove(self, waiter_id: str) -> RequestWaiter | None: + waiter = self._queued.pop(waiter_id, None) + if waiter is None: + return None + self._waiter_groups.pop(waiter_id, None) + self._sequence_version += 1 + return waiter + + def select_next( + self, is_eligible: Callable[[RequestWaiter, RequestQueueView], bool] + ) -> RequestQueueSelection | None: + view = self.view() + heap_copy = list(self._heap) + heapq.heapify(heap_copy) + active_seen: set[RequestResourceKey] = set() + while heap_copy: + finish, sequence, key = heapq.heappop(heap_copy) + if key in active_seen: + continue + if self._active_heap_entries.get(key) != (finish, sequence): + continue + active_seen.add(key) + waiter = self._first_valid_waiter(key) + if waiter is None: + continue + if not is_eligible(waiter, view): + continue + return RequestQueueSelection( + waiter=waiter, + item=waiter.item, + waiter_id=waiter.waiter_id, + queue_view=view, + sequence_version=self._sequence_version, + ) + return None + + def commit(self, selection: RequestQueueSelection) -> RequestWaiter | None: + if selection.sequence_version != self._sequence_version: + return None + key = self._waiter_groups.get(selection.waiter_id) + if key is None or key != selection.item.group.key: + return None + queue = self._queues.get(key) + if queue is None: + return None + self._purge_queue_head(key) + if not queue or queue[0].waiter_id != selection.waiter_id: + return None + + waiter = queue.popleft() + self._queued.pop(waiter.waiter_id, None) + self._waiter_groups.pop(waiter.waiter_id, None) + self._active_heap_entries.pop(key, None) + weight = max(selection.item.group.weight, 1.0) + finish = self._group_finish.get(key, self._virtual_time) + self._virtual_time = max(self._virtual_time, finish) + self._group_finish[key] = self._virtual_time + (1.0 / weight) + self._sequence_version += 1 + self._purge_queue_head(key) + if queue: + self._activate_group(key) + return waiter + + def view(self) -> RequestQueueView: + queued_by_group: Counter[RequestResourceKey] = Counter() + demand_by_resource: Counter[RequestResourceKey] = Counter() + aggregate_waiters: Counter[ProviderModelKey] = Counter() + for waiter in self._queued.values(): + resource = waiter.item.resource + queued_by_group[waiter.item.group.key] += 1 + demand_by_resource[resource] += 1 + aggregate_waiters[resource.provider_model_key] += 1 + return RequestQueueView( + queued_total=len(self._queued), + queued_by_group=dict(queued_by_group), + queued_demand_by_resource=dict(demand_by_resource), + aggregate_provider_model_waiters=dict(aggregate_waiters), + ) + + def _activate_group(self, key: RequestResourceKey) -> None: + self._purge_queue_head(key) + queue = self._queues.get(key) + if not queue or key in self._active_heap_entries: + return + self._sequence += 1 + finish = self._group_finish.get(key, self._virtual_time) + heapq.heappush(self._heap, (finish, self._sequence, key)) + self._active_heap_entries[key] = (finish, self._sequence) + + def _first_valid_waiter(self, key: RequestResourceKey) -> RequestWaiter | None: + queue = self._queues.get(key) + if queue is None: + return None + for waiter in queue: + if waiter.waiter_id in self._queued and self._waiter_groups.get(waiter.waiter_id) == key: + return waiter + return None + + def _purge_queue_head(self, key: RequestResourceKey) -> None: + queue = self._queues.get(key) + if queue is None: + return + while queue: + waiter = queue[0] + if waiter.waiter_id in self._queued and self._waiter_groups.get(waiter.waiter_id) == key: + break + queue.popleft() diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/resolver.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/resolver.py new file mode 100644 index 000000000..462e77427 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/resolver.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + +from data_designer.engine.models.request_admission.resources import RequestDomain, RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey + + +@dataclass(frozen=True) +class ResolvedRequestResource: + provider_model: ProviderModelKey + resource: RequestResourceKey + aliases: tuple[str, ...] = () + generation_kind: str | None = None + + +class RequestResourceResolver: + """Canonical provider/model/domain request-resource identity factory.""" + + def resolve( + self, + *, + provider_name: str, + model_id: str, + domain: RequestDomain, + model_alias: str | None = None, + provider_alias: str | None = None, + generation_kind: str | None = None, + ) -> ResolvedRequestResource: + resource = RequestResourceKey(provider_name=provider_name, model_id=model_id, domain=domain) + aliases = tuple(alias for alias in (provider_alias, model_alias) if alias) + return ResolvedRequestResource( + provider_model=resource.provider_model_key, + resource=resource, + aliases=aliases, + generation_kind=generation_kind, + ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/resources.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/resources.py new file mode 100644 index 000000000..b7b4bd2cd --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/resources.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + +from data_designer.engine.models.resources import ProviderModelKey + + +class RequestDomain(str, Enum): + CHAT = "chat" + EMBEDDING = "embedding" + IMAGE = "image" + HEALTHCHECK = "healthcheck" + + +@dataclass(frozen=True, order=True) +class RequestResourceKey: + provider_name: str + model_id: str + domain: RequestDomain + + @property + def provider_model_key(self) -> ProviderModelKey: + return ProviderModelKey(self.provider_name, self.model_id) + + +@dataclass(frozen=True) +class RequestGroupSpec: + key: RequestResourceKey + weight: float = 1.0 + + +@dataclass(frozen=True) +class RequestEventContext: + captured_correlation: object | None = None + task_execution_id: str | None = None + request_attempt_id: str | None = None + + +@dataclass(frozen=True) +class RequestAdmissionItem: + resource: RequestResourceKey + group: RequestGroupSpec + queue_wait_timeout_seconds: float | None = None + event_context: RequestEventContext | None = None diff --git a/packages/data-designer-engine/src/data_designer/engine/models/resources.py b/packages/data-designer-engine/src/data_designer/engine/models/resources.py new file mode 100644 index 000000000..091e2936b --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/resources.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass + + +@dataclass(frozen=True, order=True) +class ProviderModelKey: + provider_name: str + model_id: str + + +@dataclass +class ProviderModelStaticCap: + cap: int + aliases: tuple[str, ...] + raw_caps: Mapping[str, int | None] + merge_rule: str = "min_same_endpoint" diff --git a/packages/data-designer-engine/src/data_designer/engine/observability.py b/packages/data-designer-engine/src/data_designer/engine/observability.py new file mode 100644 index 000000000..aa04e4a4e --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/observability.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import contextvars +import time +from dataclasses import dataclass, field +from typing import Literal, Protocol + + +@dataclass(frozen=True) +class RuntimeCorrelation: + run_id: str + row_group: int | None + task_column: str | None + task_type: str | None + scheduling_group_kind: str | None + scheduling_group_identity_hash: str | None + task_execution_id: str | None + + +class RuntimeCorrelationProvider: + """Context-variable backed runtime correlation provider.""" + + def __init__(self) -> None: + self._current: contextvars.ContextVar[RuntimeCorrelation | None] = contextvars.ContextVar( + "data_designer_runtime_correlation", + default=None, + ) + + def current(self) -> RuntimeCorrelation | None: + return self._current.get() + + def set(self, correlation: RuntimeCorrelation | None) -> contextvars.Token: + return self._current.set(correlation) + + def reset(self, token: contextvars.Token) -> None: + self._current.reset(token) + + +runtime_correlation_provider = RuntimeCorrelationProvider() + +SchedulerAdmissionEventKind = Literal[ + "scheduler_job_started", + "scheduler_job_completed", + "scheduler_health_snapshot", + "dependency_ready", + "ready_enqueued", + "row_group_admitted", + "row_group_admission_blocked", + "row_group_admission_target_changed", + "row_group_checkpointed", + "selected", + "queue_empty", + "admission_blocked", + "group_capped", + "request_pressure_advisory_skipped", + "task_lease_acquired", + "admission_denied", + "worker_spawned", + "worker_spawn_failed", + "stale_selection", + "retry_deferred", + "non_retryable_dropped", + "cancelled", + "salvage_redispatched", + "queue_drained", + "task_completed", + "task_lease_released", + "release_diagnostic", +] + +RequestAdmissionEventKind = Literal[ + "request_resource_registered", + "request_effective_cap_changed", + "request_queue_formed", + "request_wait_started", + "request_wait_completed", + "request_wait_timeout", + "request_wait_cancelled", + "request_acquire_denied", + "request_lease_acquired", + "model_request_started", + "model_request_completed", + "request_queue_drained", + "request_rate_limited", + "request_limit_decreased", + "request_limit_increased", + "request_soft_ceiling_recovered", + "request_fully_recovered", + "request_lease_released", + "request_release_diagnostic", +] + + +@dataclass(frozen=True) +class SchedulerAdmissionEvent: + event_kind: SchedulerAdmissionEventKind + captured_at_monotonic: float + sequence: int + captured_correlation: RuntimeCorrelation | None = None + task_id: str | None = None + task_execution_id: str | None = None + task_lease_id: str | None = None + scheduler_resource_key: str | None = None + reason_or_result: str | None = None + snapshot: object | None = None + diagnostics: dict[str, object] = field(default_factory=dict) + + @classmethod + def capture( + cls, + event_kind: SchedulerAdmissionEventKind, + *, + sequence: int, + correlation: RuntimeCorrelation | None = None, + **kwargs: object, + ) -> SchedulerAdmissionEvent: + return cls( + event_kind=event_kind, + captured_at_monotonic=time.monotonic(), + sequence=sequence, + captured_correlation=correlation, + **kwargs, + ) + + +@dataclass(frozen=True) +class RequestAdmissionEvent: + event_kind: RequestAdmissionEventKind + captured_at_monotonic: float + sequence: int + captured_correlation: RuntimeCorrelation | None = None + request_attempt_id: str | None = None + request_lease_id: str | None = None + request_resource_key: object | None = None + request_group_key: object | None = None + reason_or_outcome: str | None = None + pressure_snapshot: object | None = None + diagnostics: dict[str, object] = field(default_factory=dict) + + @classmethod + def capture( + cls, + event_kind: RequestAdmissionEventKind, + *, + sequence: int, + correlation: RuntimeCorrelation | None = None, + **kwargs: object, + ) -> RequestAdmissionEvent: + return cls( + event_kind=event_kind, + captured_at_monotonic=time.monotonic(), + sequence=sequence, + captured_correlation=correlation, + **kwargs, + ) + + +class SchedulerAdmissionEventSink(Protocol): + def emit_scheduler_event(self, event: SchedulerAdmissionEvent) -> None: ... + + +class RequestAdmissionEventSink(Protocol): + def emit_request_event(self, event: RequestAdmissionEvent) -> None: ... + + +class InMemoryAdmissionEventSink: + """Small sink used by tests, diagnostics, and benchmark smoke runs.""" + + def __init__(self) -> None: + self.scheduler_events: list[SchedulerAdmissionEvent] = [] + self.request_events: list[RequestAdmissionEvent] = [] + + def emit_scheduler_event(self, event: SchedulerAdmissionEvent) -> None: + self.scheduler_events.append(event) + + def emit_request_event(self, event: RequestAdmissionEvent) -> None: + self.request_events.append(event) + + +@dataclass(frozen=True) +class CorrelatedRuntimeView: + scheduler_events: tuple[SchedulerAdmissionEvent, ...] + request_events: tuple[RequestAdmissionEvent, ...] + + @property + def timeline(self) -> tuple[SchedulerAdmissionEvent | RequestAdmissionEvent, ...]: + return tuple( + sorted( + (*self.scheduler_events, *self.request_events), + key=lambda event: (event.captured_at_monotonic, event.sequence), + ) + ) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py similarity index 80% rename from packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py rename to packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py index 2ec7b4cd3..e647d4ac6 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py @@ -14,9 +14,10 @@ SamplerColumnConfig, ) from data_designer.config.sampler_params import SamplerType -from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker +from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker +from data_designer.engine.dataset_builders.scheduling.resources import stable_task_id +from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef, Task from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph -from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task MODEL_ALIAS = "stub" @@ -189,7 +190,16 @@ def test_get_ready_tasks_seed_frontier(ready_ctx: ReadyTasksFixture) -> None: assert len(ready) == 1 assert ready[0].column == "topic" - assert ready[0].task_type == "batch" + assert ready[0].task_type == "from_scratch" + + +def test_mark_enqueued_uses_scheduler_stable_task_id(ready_ctx: ReadyTasksFixture) -> None: + ready_ctx.tracker.seed_frontier() + task = ready_ctx.tracker.ready_frontier()[0] + + ready_ctx.tracker.mark_enqueued({stable_task_id(task)}) + + assert ready_ctx.tracker.ready_frontier() == () def test_get_ready_tasks_after_seed_complete(ready_ctx: ReadyTasksFixture) -> None: @@ -205,6 +215,53 @@ def test_get_ready_tasks_after_seed_complete(ready_ctx: ReadyTasksFixture) -> No assert delta.removed == () +def test_fan_out_cell_completion_readies_all_children_for_same_row() -> None: + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="heavy", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="child_a", prompt="{{ heavy }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="child_b", prompt="{{ heavy }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="child_c", prompt="{{ heavy }}", model_alias=MODEL_ALIAS), + ] + strategies = {config.name: GenerationStrategy.CELL_BY_CELL for config in configs[1:]} + strategies["topic"] = GenerationStrategy.FULL_COLUMN + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, [(0, 2)]) + tracker.mark_row_range_complete("topic", 0, 2) + + delta = tracker.mark_cell_complete("heavy", 0, 0) + + assert {task.column for task in delta.added} == {"child_a", "child_b", "child_c"} + assert {task.row_index for task in delta.added} == {0} + ready = tracker.get_ready_tasks(set()) + assert not any(task.column.startswith("child_") and task.row_index == 1 for task in ready) + + +def test_fan_in_cell_downstream_waits_for_all_same_row_upstreams() -> None: + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="up_a", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="up_b", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="up_c", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="judge", prompt="{{ up_a }} {{ up_b }} {{ up_c }}", model_alias=MODEL_ALIAS), + ] + strategies = {config.name: GenerationStrategy.CELL_BY_CELL for config in configs[1:]} + strategies["topic"] = GenerationStrategy.FULL_COLUMN + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, [(0, 2)]) + tracker.mark_row_range_complete("topic", 0, 2) + + first_delta = tracker.mark_cell_complete("up_a", 0, 0) + second_delta = tracker.mark_cell_complete("up_b", 0, 0) + final_delta = tracker.mark_cell_complete("up_c", 0, 0) + + assert not any(task.column == "judge" for task in first_delta.added) + assert not any(task.column == "judge" for task in second_delta.added) + assert final_delta.added == (Task(column="judge", row_group=0, row_index=0, task_type="cell"),) + ready = tracker.get_ready_tasks(set()) + assert not any(task.column == "judge" and task.row_index == 1 for task in ready) + + def test_get_ready_tasks_skips_dispatched(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_queue.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_queue.py new file mode 100644 index 000000000..e2a9179f0 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_queue.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections import Counter + +from data_designer.engine.dataset_builders.scheduling.queue import FairTaskQueue, QueueView +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceRequest, + TaskGroupKey, + TaskGroupSpec, + stable_task_id, +) +from data_designer.engine.dataset_builders.scheduling.task_model import Task + + +def _task(column: str, row_index: int) -> Task: + return Task(column=column, row_group=0, row_index=row_index, task_type="cell") + + +def _group(name: str, *, weight: float = 1.0, admitted_limit: int | None = None) -> TaskGroupSpec: + return TaskGroupSpec( + key=TaskGroupKey(kind="local", identity=(name,)), + weight=weight, + admitted_limit=admitted_limit, + ) + + +def _item(column: str, row_index: int, group: TaskGroupSpec | None = None) -> SchedulableTask: + task = _task(column, row_index) + group = group or _group(column) + return SchedulableTask( + task_id=stable_task_id(task), + payload=task, + group=group, + resource_request=SchedulerResourceRequest({"submission": 1}), + ) + + +def _select_and_commit(queue: FairTaskQueue) -> SchedulableTask | None: + selection = queue.select_next(lambda _item, _view: True) + if selection is None: + return None + return queue.commit(selection) + + +def test_fair_task_queue_equal_groups_round_robins() -> None: + queue = FairTaskQueue() + queue.enqueue( + [ + _item("a", 0), + _item("a", 1), + _item("b", 0), + _item("b", 1), + _item("c", 0), + _item("c", 1), + ] + ) + + selected = [_select_and_commit(queue) for _ in range(6)] + + assert [item.payload.column for item in selected if item is not None] == ["a", "b", "c", "a", "b", "c"] + + +def test_fair_task_queue_weighted_groups() -> None: + queue = FairTaskQueue() + queue.enqueue( + [_item("a", i, _group("a", weight=2)) for i in range(6)] + + [_item("b", i, _group("b", weight=1)) for i in range(6)] + ) + + selected = [_select_and_commit(queue) for _ in range(6)] + counts = Counter(item.payload.column for item in selected if item is not None) + + assert counts == {"a": 4, "b": 2} + + +def test_select_next_is_non_mutating_until_commit() -> None: + queue = FairTaskQueue() + first = _item("a", 0) + second = _item("b", 0) + queue.enqueue([first, second]) + + selection = queue.select_next(lambda _item, _view: True) + + assert selection is not None + assert queue.view().queued_total == 2 + committed = queue.commit(selection) + assert committed == first + assert queue.view().queued_total == 1 + + +def test_commit_rejects_stale_selection() -> None: + queue = FairTaskQueue() + first = _item("a", 0) + queue.enqueue([first]) + + selection = queue.select_next(lambda _item, _view: True) + assert selection is not None + queue.enqueue([_item("b", 0)]) + + assert queue.commit(selection) is None + assert queue.view().queued_total == 2 + + +def test_select_next_uses_scheduler_eligibility_callback() -> None: + queue = FairTaskQueue() + queue.enqueue([_item("a", 0), _item("b", 0)]) + + selection = queue.select_next(lambda item, _view: item.payload.column == "b") + + assert selection is not None + assert selection.item.payload.column == "b" + assert queue.commit(selection) == selection.item + + +def test_enqueue_is_idempotent_by_task_id() -> None: + queue = FairTaskQueue() + item = _item("a", 0) + + first = queue.enqueue([item]) + second = queue.enqueue([item]) + + assert first == (item.task_id,) + assert second == () + assert queue.view().queued_total == 1 + + +def test_discard_where_removes_matching_tasks() -> None: + queue = FairTaskQueue() + queue.enqueue([_item(column, i) for column in ["a", "b"] for i in range(2)]) + + queue.discard_where(lambda item: item.payload.column == "a") + selected = [_select_and_commit(queue) for _ in range(2)] + + assert [item.payload.column for item in selected if item is not None] == ["b", "b"] + assert _select_and_commit(queue) is None + + +def test_queue_view_exposes_group_and_resource_demand() -> None: + queue = FairTaskQueue() + group = _group("a") + task = _task("a", 0) + item = SchedulableTask( + task_id=stable_task_id(task), + payload=task, + group=group, + resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 1}), + ) + + queue.enqueue([item]) + view: QueueView = queue.view() + + assert view.queued_total == 1 + assert view.queued_by_group[group.key] == 1 + assert view.queued_resource_demand_by_group[group.key]["llm_wait"] == 1 + assert view.first_candidate_resources_by_group[group.key]["submission"] == 1 diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resolver.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resolver.py new file mode 100644 index 000000000..1804dd272 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resolver.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from data_designer.config.column_configs import ExpressionColumnConfig +from data_designer.config.models import GenerationType +from data_designer.config.scheduling import SchedulingMetadata, SchedulingMetadataError +from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry +from data_designer.engine.dataset_builders.scheduling.resolver import TaskSchedulingResolver +from data_designer.engine.dataset_builders.scheduling.task_model import Task +from data_designer.engine.models.request_admission.resources import RequestDomain, RequestResourceKey + + +class _LocalGenerator: + def get_scheduling_metadata(self) -> SchedulingMetadata: + return SchedulingMetadata.local() + + +class _ModelGenerator: + def __init__(self, metadata: SchedulingMetadata) -> None: + self._metadata = metadata + + def get_scheduling_metadata(self) -> SchedulingMetadata: + return self._metadata + + +class _FallbackGenerator: + def get_scheduling_metadata(self) -> SchedulingMetadata: + raise SchedulingMetadataError( + code="partial", + message="using fallback", + fallback=SchedulingMetadata.local("fallback"), + diagnostics={"reason": "test"}, + ) + + +class _FatalGenerator: + def get_scheduling_metadata(self) -> SchedulingMetadata: + raise SchedulingMetadataError(code="fatal", message="fatal") + + +def _task(column: str = "answer") -> Task: + return Task(column=column, row_group=0, row_index=0, task_type="cell") + + +def test_task_scheduling_resolver_uses_local_default_metadata() -> None: + resolver = TaskSchedulingResolver({"answer": _LocalGenerator()}) # type: ignore[arg-type] + + schedulable = resolver.schedulable_task(_task(), ("answer",)) + + assert schedulable.group.key.kind == "local" + assert schedulable.resource_request.amounts == {"submission": 1} + + +def test_task_scheduling_resolver_maps_model_metadata_to_model_resource() -> None: + metadata = SchedulingMetadata.model("nvidia", "nemotron", "chat", weight=3) + resolver = TaskSchedulingResolver({"answer": _ModelGenerator(metadata)}) # type: ignore[arg-type] + + schedulable = resolver.schedulable_task(_task(), ("answer",)) + + assert schedulable.group.key.kind == "model" + assert schedulable.group.weight == 3.0 + assert schedulable.group.admitted_limit == 6 + assert schedulable.resource_request.amounts == {"submission": 1, "llm_wait": 1} + assert schedulable.request_resource_key == RequestResourceKey("nvidia", "nemotron", RequestDomain.CHAT) + + +def test_task_scheduling_resolver_records_safe_fallback_diagnostics() -> None: + resolver = TaskSchedulingResolver({"answer": _FallbackGenerator()}) # type: ignore[arg-type] + + schedulable = resolver.schedulable_task(_task(), ("answer",)) + + assert schedulable.group.key.identity[:2] == ("local", "fallback") + assert resolver.diagnostics[0]["code"] == "partial" + + +def test_task_scheduling_resolver_raises_fatal_metadata_error() -> None: + with pytest.raises(SchedulingMetadataError): + TaskSchedulingResolver({"answer": _FatalGenerator()}) # type: ignore[arg-type] + + +def test_model_registry_generator_metadata_deduplicates_same_endpoint_aliases() -> None: + class _RegistryGenerator(ColumnGeneratorWithModelRegistry[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> object: + return object() + + def generate(self, data: object) -> object: + return data + + config = ExpressionColumnConfig(name="answer", expr="{{ x }}", dtype="str") + generator = _RegistryGenerator(config=config, resource_provider=MagicMock()) + generator._get_scheduling_model_aliases = lambda: ["primary", "secondary"] # type: ignore[method-assign] + configs = { + "primary": SimpleNamespace( + model="endpoint", + generation_type=GenerationType.CHAT_COMPLETION, + inference_parameters=SimpleNamespace(max_parallel_requests=4), + ), + "secondary": SimpleNamespace( + model="endpoint", + generation_type=GenerationType.CHAT_COMPLETION, + inference_parameters=SimpleNamespace(max_parallel_requests=2), + ), + } + providers = { + "primary": SimpleNamespace(name="nvidia"), + "secondary": SimpleNamespace(name="nvidia"), + } + generator.get_model_config = lambda model_alias: configs[model_alias] # type: ignore[method-assign] + generator.get_model_provider_name = lambda model_alias: providers[model_alias].name # type: ignore[method-assign] + + metadata = generator.get_scheduling_metadata() + + assert metadata.identity == ("model", "nvidia", "endpoint", "chat") + assert metadata.weight == 2 + assert metadata.diagnostics["merge_rule"] == "min_same_endpoint" + + resolver = TaskSchedulingResolver({"answer": generator}) # type: ignore[arg-type] + schedulable = resolver.schedulable_task(_task(), ("answer",)) + assert schedulable.request_resource_key == RequestResourceKey("nvidia", "endpoint", RequestDomain.CHAT) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resources.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resources.py new file mode 100644 index 000000000..935f2c074 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resources.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceRequest, + TaskGroupKey, + TaskGroupSpec, + stable_task_id, +) +from data_designer.engine.dataset_builders.scheduling.task_model import Task + + +def test_scheduler_resource_request_defaults_to_submission() -> None: + request = SchedulerResourceRequest() + + assert request.amounts == {"submission": 1} + + +def test_scheduler_resource_request_rejects_unknown_resource() -> None: + with pytest.raises(ValueError, match="Unknown scheduler resource key"): + SchedulerResourceRequest({"gpu": 1}) # type: ignore[arg-type] + + +def test_scheduler_resource_request_rejects_non_positive_amounts() -> None: + with pytest.raises(ValueError, match="must be a positive integer"): + SchedulerResourceRequest({"submission": 0}) + + +def test_stable_task_id_is_stable_for_task_identity() -> None: + task = Task(column="answer", row_group=3, row_index=8, task_type="cell") + + assert stable_task_id(task) == stable_task_id(task) + assert stable_task_id(task).startswith("task-") + + +def test_stable_task_id_distinguishes_task_identity_fields() -> None: + first = Task(column="answer", row_group=3, row_index=8, task_type="cell") + second = Task(column="answer", row_group=3, row_index=9, task_type="cell") + + assert stable_task_id(first) != stable_task_id(second) + + +def test_schedulable_task_binds_payload_group_and_resource_request() -> None: + task = Task(column="answer", row_group=0, row_index=1, task_type="cell") + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("nvidia", "nemotron")), admitted_limit=2) + request = SchedulerResourceRequest({"submission": 1, "llm_wait": 1}) + + item = SchedulableTask( + task_id=stable_task_id(task), + payload=task, + group=group, + resource_request=request, + ) + + assert item.payload == task + assert item.group == group + assert item.resource_request.amounts["llm_wait"] == 1 diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_admission.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_admission.py new file mode 100644 index 000000000..fbb2fd469 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_admission.py @@ -0,0 +1,275 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from data_designer.engine.dataset_builders.scheduling.queue import FairTaskQueue +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceRequest, + TaskGroupKey, + TaskGroupSpec, + stable_task_id, +) +from data_designer.engine.dataset_builders.scheduling.task_admission import ( + RELEASED_TASK_LEASE_HISTORY_LIMIT, + TaskAdmissionConfig, + TaskAdmissionController, + TaskAdmissionDenied, + TaskAdmissionLease, +) +from data_designer.engine.dataset_builders.scheduling.task_model import Task +from data_designer.engine.dataset_builders.scheduling.task_policies import BoundedBorrowTaskAdmissionPolicyConfig + + +def _item( + column: str, + row: int = 0, + *, + group: TaskGroupSpec | None = None, + resources: dict[str, int] | None = None, +) -> SchedulableTask: + task = Task(column=column, row_group=0, row_index=row, task_type="cell") + group = group or TaskGroupSpec(TaskGroupKey(kind="local", identity=(column,))) + return SchedulableTask( + task_id=stable_task_id(task), + payload=task, + group=group, + resource_request=SchedulerResourceRequest(resources or {"submission": 1}), + ) + + +def _queue_view(*items: SchedulableTask): + queue = FairTaskQueue() + queue.enqueue(items) + return queue.view() + + +def test_task_admission_acquires_and_releases_exact_lease() -> None: + controller = TaskAdmissionController(TaskAdmissionConfig(submission_capacity=1)) + item = _item("a") + + decision = controller.try_acquire(item, _queue_view(item)) + + assert isinstance(decision, TaskAdmissionLease) + assert controller.view().resources_available["submission"] == 0 + result = controller.release(decision) + assert result.released is True + assert controller.view().resources_available["submission"] == 1 + + +def test_task_admission_denies_when_resource_full() -> None: + controller = TaskAdmissionController(TaskAdmissionConfig(submission_capacity=1)) + first = _item("a") + second = _item("b") + lease = controller.try_acquire(first, _queue_view(first, second)) + + assert isinstance(lease, TaskAdmissionLease) + decision = controller.try_acquire(second, _queue_view(second)) + + assert isinstance(decision, TaskAdmissionDenied) + assert decision.reason == "no_capacity" + + +def test_task_admission_duplicate_release_does_not_increase_capacity() -> None: + controller = TaskAdmissionController(TaskAdmissionConfig(submission_capacity=1)) + item = _item("a") + lease = controller.try_acquire(item, _queue_view(item)) + assert isinstance(lease, TaskAdmissionLease) + + first = controller.release(lease) + second = controller.release(lease) + + assert first.released is True + assert second.released is False + assert second.reason == "duplicate" + assert controller.view().resources_available["submission"] == 1 + + +def test_task_admission_released_history_is_bounded() -> None: + controller = TaskAdmissionController(TaskAdmissionConfig(submission_capacity=1)) + first_lease: TaskAdmissionLease | None = None + for index in range(RELEASED_TASK_LEASE_HISTORY_LIMIT + 5): + item = _item(f"task-{index}") + lease = controller.try_acquire(item, _queue_view(item)) + assert isinstance(lease, TaskAdmissionLease) + first_lease = first_lease or lease + controller.release(lease) + + assert len(controller._released) == RELEASED_TASK_LEASE_HISTORY_LIMIT + assert len(controller._released_order) == RELEASED_TASK_LEASE_HISTORY_LIMIT + assert controller._released_order.maxlen == RELEASED_TASK_LEASE_HISTORY_LIMIT + assert first_lease is not None + assert controller.release(first_lease).reason == "unknown_lease" + + +def test_task_admission_group_cap_yields_to_peer_pressure() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + controller = TaskAdmissionController(TaskAdmissionConfig(submission_capacity=2)) + first = _item("a", 0, group=group) + second = _item("a", 1, group=group) + peer = _item("b") + lease = controller.try_acquire(first, _queue_view(first, second, peer)) + assert isinstance(lease, TaskAdmissionLease) + + decision = controller.try_acquire(second, _queue_view(second, peer)) + + assert isinstance(decision, TaskAdmissionDenied) + assert decision.reason == "group_cap" + + +def test_task_admission_group_cap_ignores_non_overlapping_typed_peer_resource() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + controller = TaskAdmissionController( + TaskAdmissionConfig(submission_capacity=3, resource_limits={"llm_wait": 3, "local": 3}) + ) + first = _item("a", 0, group=group, resources={"submission": 1, "llm_wait": 1}) + second = _item("a", 1, group=group, resources={"submission": 1, "llm_wait": 1}) + local_peer = _item("b", resources={"submission": 1, "local": 1}) + lease = controller.try_acquire(first, _queue_view(first, second, local_peer)) + assert isinstance(lease, TaskAdmissionLease) + + decision = controller.try_acquire(second, _queue_view(second, local_peer)) + + assert isinstance(decision, TaskAdmissionLease) + + +def test_task_admission_group_cap_applies_to_overlapping_typed_peer_resource() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + peer_group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "peer")), admitted_limit=1) + controller = TaskAdmissionController(TaskAdmissionConfig(submission_capacity=3, resource_limits={"llm_wait": 3})) + first = _item("a", 0, group=group, resources={"submission": 1, "llm_wait": 1}) + second = _item("a", 1, group=group, resources={"submission": 1, "llm_wait": 1}) + peer = _item("b", group=peer_group, resources={"submission": 1, "llm_wait": 1}) + lease = controller.try_acquire(first, _queue_view(first, second, peer)) + assert isinstance(lease, TaskAdmissionLease) + + decision = controller.try_acquire(second, _queue_view(second, peer)) + + assert isinstance(decision, TaskAdmissionDenied) + assert decision.reason == "group_cap" + assert decision.diagnostics["pressure_resources"] == ("llm_wait",) + + +def test_task_admission_group_cap_ignores_peer_blocked_by_hard_resource_capacity() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + peer_group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "peer")), admitted_limit=1) + controller = TaskAdmissionController( + TaskAdmissionConfig(submission_capacity=4, resource_limits={"llm_wait": 3, "local": 1}) + ) + first = _item("a", 0, group=group, resources={"submission": 1, "llm_wait": 1}) + second = _item("a", 1, group=group, resources={"submission": 1, "llm_wait": 1}) + local_holder = _item("local-holder", resources={"submission": 1, "local": 1}) + blocked_peer = _item("b", group=peer_group, resources={"submission": 1, "llm_wait": 1, "local": 1}) + first_lease = controller.try_acquire(first, _queue_view(first, second, blocked_peer)) + local_lease = controller.try_acquire(local_holder, _queue_view(local_holder, blocked_peer)) + assert isinstance(first_lease, TaskAdmissionLease) + assert isinstance(local_lease, TaskAdmissionLease) + + decision = controller.try_acquire(second, _queue_view(second, blocked_peer)) + + assert isinstance(decision, TaskAdmissionLease) + + +def test_explain_blocked_reports_group_cap_denials() -> None: + first_group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "first")), admitted_limit=1) + second_group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "second")), admitted_limit=1) + controller = TaskAdmissionController(TaskAdmissionConfig(submission_capacity=4)) + first_active = _item("a", 0, group=first_group) + second_active = _item("b", 0, group=second_group) + first_queued = _item("a", 1, group=first_group) + second_queued = _item("b", 1, group=second_group) + first_lease = controller.try_acquire(first_active, _queue_view(first_active, second_active)) + second_lease = controller.try_acquire(second_active, _queue_view(second_active, first_queued)) + assert isinstance(first_lease, TaskAdmissionLease) + assert isinstance(second_lease, TaskAdmissionLease) + queue = FairTaskQueue() + queue.enqueue((first_queued, second_queued)) + + assert queue.select_next(controller.is_eligible) is None + summary = controller.explain_blocked(queue.view()) + + assert summary.dominant_denial_reasons == {"group_cap": 2} + + +def test_task_admission_group_cap_does_not_block_solo_group() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + controller = TaskAdmissionController(TaskAdmissionConfig(submission_capacity=2)) + first = _item("a", 0, group=group) + second = _item("a", 1, group=group) + lease = controller.try_acquire(first, _queue_view(first, second)) + assert isinstance(lease, TaskAdmissionLease) + + decision = controller.try_acquire(second, _queue_view(second)) + + assert isinstance(decision, TaskAdmissionLease) + + +def test_bounded_borrow_limits_solo_group_borrow_debt() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + controller = TaskAdmissionController( + TaskAdmissionConfig( + submission_capacity=3, + bounded_borrow=BoundedBorrowTaskAdmissionPolicyConfig(default_borrow_ceiling=1), + ) + ) + first = _item("a", 0, group=group) + second = _item("a", 1, group=group) + third = _item("a", 2, group=group) + first_lease = controller.try_acquire(first, _queue_view(first, second, third)) + assert isinstance(first_lease, TaskAdmissionLease) + borrowed = controller.try_acquire(second, _queue_view(second, third)) + assert isinstance(borrowed, TaskAdmissionLease) + + denied = controller.try_acquire(third, _queue_view(third)) + + assert isinstance(denied, TaskAdmissionDenied) + assert denied.reason == "borrow_debt" + assert controller.view().policy_debt_by_group_resource[(group.key, "submission")] == 1 + + +def test_bounded_borrow_debt_blocks_under_peer_pressure_and_releases() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + controller = TaskAdmissionController( + TaskAdmissionConfig( + submission_capacity=3, + bounded_borrow=BoundedBorrowTaskAdmissionPolicyConfig(default_borrow_ceiling=1), + ) + ) + first = _item("a", 0, group=group) + borrowed_item = _item("a", 1, group=group) + blocked_item = _item("a", 2, group=group) + peer = _item("b") + first_lease = controller.try_acquire(first, _queue_view(first, borrowed_item)) + borrowed = controller.try_acquire(borrowed_item, _queue_view(borrowed_item)) + assert isinstance(first_lease, TaskAdmissionLease) + assert isinstance(borrowed, TaskAdmissionLease) + + denied = controller.try_acquire(blocked_item, _queue_view(blocked_item, peer)) + + assert isinstance(denied, TaskAdmissionDenied) + assert denied.reason == "borrow_debt" + controller.release(borrowed) + assert (group.key, "submission") not in controller.view().policy_debt_by_group_resource + + +def test_bounded_borrow_release_repayment_is_group_level() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + controller = TaskAdmissionController( + TaskAdmissionConfig( + submission_capacity=3, + bounded_borrow=BoundedBorrowTaskAdmissionPolicyConfig(default_borrow_ceiling=1), + ) + ) + first = _item("a", 0, group=group) + borrowed_item = _item("a", 1, group=group) + first_lease = controller.try_acquire(first, _queue_view(first, borrowed_item)) + borrowed = controller.try_acquire(borrowed_item, _queue_view(borrowed_item)) + assert isinstance(first_lease, TaskAdmissionLease) + assert isinstance(borrowed, TaskAdmissionLease) + assert controller.view().policy_debt_by_group_resource[(group.key, "submission")] == 1 + + controller.release(first_lease) + + assert (group.key, "submission") not in controller.view().policy_debt_by_group_resource + controller.release(borrowed) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_model.py similarity index 96% rename from packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py rename to packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_model.py index 5d5716213..cdc5e6c6a 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_model.py @@ -5,7 +5,7 @@ import pytest -from data_designer.engine.dataset_builders.utils.task_model import Task, TaskResult, TaskTrace +from data_designer.engine.dataset_builders.scheduling.task_model import Task, TaskResult, TaskTrace def test_task_is_frozen() -> None: diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_policies.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_policies.py new file mode 100644 index 000000000..286fdee96 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_policies.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from data_designer.engine.dataset_builders.scheduling.queue import FairTaskQueue, QueueView +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceRequest, + TaskGroupKey, + TaskGroupSpec, + stable_task_id, +) +from data_designer.engine.dataset_builders.scheduling.task_admission import TaskAdmissionLease, TaskAdmissionView +from data_designer.engine.dataset_builders.scheduling.task_model import Task +from data_designer.engine.dataset_builders.scheduling.task_policies import ( + BoundedBorrowTaskAdmissionPolicy, + BoundedBorrowTaskAdmissionPolicyConfig, + StrictFairTaskAdmissionPolicy, +) + + +def _item(column: str, group: TaskGroupSpec) -> SchedulableTask: + task = Task(column=column, row_group=0, row_index=0, task_type="cell") + return SchedulableTask( + task_id=stable_task_id(task), + payload=task, + group=group, + resource_request=SchedulerResourceRequest({"submission": 1}), + ) + + +def _queue_view(*items: SchedulableTask) -> QueueView: + queue = FairTaskQueue() + queue.enqueue(items) + return queue.view() + + +def _admission_view( + *, + running_group: TaskGroupKey, + running_count: int = 1, + debt: int = 0, +) -> TaskAdmissionView: + return TaskAdmissionView( + resource_limits={"submission": 4}, + resources_available={"submission": 3}, + leased_resources={"submission": running_count}, + leased_resources_by_group={running_group: {"submission": running_count}}, + running_counts_by_group={running_group: running_count}, + policy_debt_by_group_resource={(running_group, "submission"): debt} if debt else {}, + ) + + +def _lease(item: SchedulableTask) -> TaskAdmissionLease: + return TaskAdmissionLease( + lease_id="lease", + item=item, + resources={"submission": 1}, + acquired_at=0.0, + controller_generation="generation", + ) + + +def test_strict_fair_policy_allows_group_without_peer_pressure() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + item = _item("a", group) + policy = StrictFairTaskAdmissionPolicy() + + decision = policy.evaluate(item, _queue_view(item), _admission_view(running_group=group.key)) + + assert decision.allowed is True + + +def test_strict_fair_policy_denies_capped_group_with_peer_pressure() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + peer_group = TaskGroupSpec(TaskGroupKey(kind="local", identity=("peer",))) + item = _item("a", group) + peer = _item("b", peer_group) + policy = StrictFairTaskAdmissionPolicy() + + decision = policy.evaluate(item, _queue_view(item, peer), _admission_view(running_group=group.key)) + + assert decision.allowed is False + assert decision.reason == "group_cap" + + +def test_bounded_borrow_policy_records_borrow_without_peer_pressure() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + item = _item("a", group) + policy = BoundedBorrowTaskAdmissionPolicy(BoundedBorrowTaskAdmissionPolicyConfig(default_borrow_ceiling=1)) + + decision = policy.evaluate(item, _queue_view(item), _admission_view(running_group=group.key)) + delta = policy.on_acquire(_lease(item), decision) + + assert decision.allowed is True + assert delta.debt_changes == {(group.key, "submission"): 1} + + +def test_bounded_borrow_policy_denies_existing_debt_under_peer_pressure() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + peer_group = TaskGroupSpec(TaskGroupKey(kind="local", identity=("peer",))) + item = _item("a", group) + peer = _item("b", peer_group) + policy = BoundedBorrowTaskAdmissionPolicy(BoundedBorrowTaskAdmissionPolicyConfig(default_borrow_ceiling=1)) + + decision = policy.evaluate(item, _queue_view(item, peer), _admission_view(running_group=group.key, debt=1)) + + assert decision.allowed is False + assert decision.reason == "borrow_debt" + + +def test_bounded_borrow_policy_releases_debt() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + item = _item("a", group) + policy = BoundedBorrowTaskAdmissionPolicy(BoundedBorrowTaskAdmissionPolicyConfig(default_borrow_ceiling=1)) + + delta = policy.on_release(_lease(item)) + + assert delta.debt_changes == {(group.key, "submission"): -1} diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py index 684c009ba..f01dc1d91 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py @@ -5,10 +5,12 @@ import math import warnings +from types import SimpleNamespace from unittest.mock import MagicMock, Mock import pytest +import data_designer.engine.dataset_builders.dataset_builder as builder_mod import data_designer.lazy_heavy_imports as lazy from data_designer.config.column_configs import ( ExpressionColumnConfig, @@ -24,7 +26,7 @@ ) from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder -from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker, FrontierDelta +from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager from data_designer.engine.resources.resource_provider import ResourceProvider @@ -189,6 +191,39 @@ def finalize_row_group(rg_id: int) -> None: assert tracker.is_row_group_complete(1, 2, all_cols) +def test_prepare_async_run_enables_request_pressure_advisory(monkeypatch: pytest.MonkeyPatch) -> None: + captured_kwargs: dict[str, object] = {} + + class _SpyScheduler: + def __init__(self, **kwargs: object) -> None: + captured_kwargs.update(kwargs) + + monkeypatch.setattr(builder_mod, "AsyncTaskScheduler", _SpyScheduler) + request_admission = object() + model_registry = MagicMock() + model_registry.get_aggregate_max_parallel_requests.return_value = 2 + model_registry.request_admission = request_admission + provider = SimpleNamespace( + model_registry=model_registry, + run_config=SimpleNamespace(progress_interval=5.0, progress_bar=False), + ) + processor_runner = MagicMock() + processor_runner.has_processors_for.return_value = False + config = SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}) + builder = SimpleNamespace( + _column_configs=[config], + _processor_runner=processor_runner, + artifact_storage=MagicMock(), + _resource_provider=provider, + ) + generator = MockSeed(config=_expr_config("seed"), resource_provider=provider) + + DatasetBuilder._prepare_async_run(builder, [generator], num_records=1, buffer_size=1) + + assert captured_kwargs["request_pressure_provider"] is request_admission + assert captured_kwargs["request_pressure_advisory"] is True + + # -- Test that existing sync path is unaffected -------------------------------- diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index 6097232ef..ce78e3141 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -4,6 +4,8 @@ from __future__ import annotations import asyncio +import logging +import time from collections.abc import Callable from types import SimpleNamespace from typing import Any @@ -23,24 +25,42 @@ from data_designer.config.custom_column import custom_column_generator from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig from data_designer.config.sampler_params import SamplerType +from data_designer.config.scheduling import SchedulingMetadata from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, ColumnGeneratorFullColumn, + ColumnGeneratorWithModelRegistry, FromScratchColumnGenerator, ) from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator -from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler, build_llm_bound_lookup +from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler from data_designer.engine.dataset_builders.errors import DatasetGenerationError -from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker, FrontierDelta +from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta +from data_designer.engine.dataset_builders.scheduling.task_admission import TaskAdmissionConfig, TaskAdmissionLease +from data_designer.engine.dataset_builders.scheduling.task_model import Task from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager -from data_designer.engine.dataset_builders.utils.task_model import Task from data_designer.engine.models.errors import ( RETRYABLE_MODEL_ERRORS, ModelInternalServerError, ModelRateLimitError, ModelTimeoutError, ) +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.controller import ( + AdaptiveRequestAdmissionController, + RequestAdmissionLease, +) +from data_designer.engine.models.request_admission.outcomes import RequestReleaseOutcome +from data_designer.engine.models.request_admission.pressure import RequestPressureSnapshot +from data_designer.engine.models.request_admission.resources import ( + RequestAdmissionItem, + RequestDomain, + RequestGroupSpec, + RequestResourceKey, +) +from data_designer.engine.models.resources import ProviderModelKey +from data_designer.engine.observability import InMemoryAdmissionEventSink from data_designer.engine.resources.resource_provider import ResourceProvider MODEL_ALIAS = "stub" @@ -83,6 +103,25 @@ def generate(self, data: dict) -> dict: return data +class MockRootCellGenerator(ColumnGenerator[ExpressionColumnConfig]): + """Root cell generator that records the shape it receives.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.call_types: list[str] = [] + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, data: dict) -> dict: + self.call_types.append(type(data).__name__) + if not isinstance(data, dict): + raise TypeError(f"expected dict, got {type(data).__name__}") + data[self.config.name] = f"root_{len(self.call_types)}" + return data + + class MockFullColumnGenerator(ColumnGeneratorFullColumn[ExpressionColumnConfig]): """Mock full-column generator.""" @@ -152,6 +191,17 @@ def generate(self, data: dict) -> dict: return data +class MockBuggyGenerator(ColumnGenerator[ExpressionColumnConfig]): + """Generator that simulates an internal scheduler/generator bug.""" + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, _data: dict) -> dict: + raise KeyError("missing internal key") + + class MockRateLimitGenerator(ColumnGenerator[ExpressionColumnConfig]): """Generator that fails with rate-limit errors before succeeding. @@ -228,8 +278,8 @@ def generate(self, data: dict) -> dict: class MockRetryableErrorGenerator(ColumnGenerator[ExpressionColumnConfig]): """Generator that raises a parametrizable retryable error then succeeds. - Declares ``is_llm_bound=True`` because it mimics model-call behavior; - the scheduler's degraded-provider WARN window only counts LLM-bound tasks. + Declares model scheduling metadata because it mimics model-call behavior; + the scheduler's degraded-provider WARN window counts model-stage tasks. """ def __init__( @@ -248,9 +298,8 @@ def __init__( def get_generation_strategy() -> GenerationStrategy: return GenerationStrategy.CELL_BY_CELL - @property - def is_llm_bound(self) -> bool: - return True + def get_scheduling_metadata(self) -> SchedulingMetadata: + return SchedulingMetadata.custom_model("test", self.config.name, "v1") def generate(self, data: dict) -> dict: self._calls += 1 @@ -260,6 +309,11 @@ def generate(self, data: dict) -> dict: return data +class _BrokenSchedulerSink: + def emit_scheduler_event(self, _event: object) -> None: + raise RuntimeError("sink boom") + + # -- Helper to build graph + scheduler ---------------------------------------- @@ -270,6 +324,7 @@ def _build_simple_pipeline( generators: dict[str, ColumnGenerator] | None = None, configs: list[SamplerColumnConfig | LLMTextColumnConfig | ExpressionColumnConfig] | None = None, strategies: dict[str, GenerationStrategy] | None = None, + scheduler_event_sink: Any | None = None, ) -> tuple[AsyncTaskScheduler, CompletionTracker]: """Build a simple seed → cell pipeline for testing.""" if configs is None: @@ -308,6 +363,7 @@ def _build_simple_pipeline( tracker=tracker, row_groups=row_groups, trace=trace, + scheduler_event_sink=scheduler_event_sink, ) return scheduler, tracker @@ -377,6 +433,31 @@ async def test_scheduler_dispatches_seeds_first() -> None: assert seed_traces[0].dispatched_at < cell_traces[0].dispatched_at +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_dispatches_root_cell_by_cell_columns_per_row() -> None: + provider = _mock_provider() + generator = MockRootCellGenerator(config=_expr_config("root_cell"), resource_provider=provider) + configs = [SamplerColumnConfig(name="root_cell", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]})] + strategies = {"root_cell": GenerationStrategy.CELL_BY_CELL} + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators={"root_cell": generator}, + graph=graph, + tracker=tracker, + row_groups=row_groups, + trace=True, + ) + + await scheduler.run() + + assert generator.call_types == ["dict", "dict", "dict"] + assert [trace.task_type for trace in scheduler.traces] == ["cell", "cell", "cell"] + assert not any(tracker.is_dropped(0, row_index) for row_index in range(3)) + assert tracker.is_row_group_complete(0, 3, ["root_cell"]) + + @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_with_buffer_manager() -> None: """Scheduler writes results to buffer manager and checkpoints.""" @@ -475,6 +556,90 @@ async def test_scheduler_non_retryable_failure_drops_row() -> None: assert tracker.is_row_group_complete(0, 2, ["seed", "fail_col"]) +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_internal_bug_failure_aborts_instead_of_dropping_row( + caplog: pytest.LogCaptureFixture, +) -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="buggy_col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "buggy_col": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "buggy_col": MockBuggyGenerator(config=_expr_config("buggy_col"), resource_provider=provider), + } + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, [(0, 1)]) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=[(0, 1)], + ) + + with caplog.at_level(logging.ERROR, logger="data_designer.engine.dataset_builders.async_scheduler"): + with pytest.raises(DatasetGenerationError, match="Unexpected internal task failure") as exc_info: + await scheduler.run() + + assert isinstance(exc_info.value.__cause__, KeyError) + assert not tracker.is_dropped(0, 0) + error_records = [ + record for record in caplog.records if "Unexpected fatal Non-retryable failure" in record.getMessage() + ] + assert error_records + assert error_records[0].exc_info is not None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_custom_generator_key_error_drops_row_without_fatal_abort( + caplog: pytest.LogCaptureFixture, +) -> None: + @custom_column_generator() + def failing_custom(row: dict) -> dict: + raise KeyError("missing user field") + + provider = _mock_provider() + custom_config = CustomColumnConfig(name="custom_col", generator_function=failing_custom) + scheduler, tracker = _build_simple_pipeline( + num_records=1, + configs=[ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + custom_config, + ], + strategies={ + "seed": GenerationStrategy.FULL_COLUMN, + "custom_col": GenerationStrategy.CELL_BY_CELL, + }, + generators={ + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "custom_col": CustomColumnGenerator(config=custom_config, resource_provider=provider), + }, + ) + + with caplog.at_level(logging.WARNING): + await scheduler.run() + + assert tracker.is_dropped(0, 0) + assert "This record will be skipped" in caplog.text + assert "Unexpected fatal Non-retryable failure" not in caplog.text + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_logs_sink_failures(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level(logging.WARNING, logger="data_designer.engine.dataset_builders.async_scheduler") + scheduler, tracker = _build_simple_pipeline(num_records=1, scheduler_event_sink=_BrokenSchedulerSink()) + + await scheduler.run() + + assert tracker.is_row_group_complete(0, 1, ["seed", "cell_out"]) + assert "Scheduler admission event sink raised; dropping event." in caplog.text + + @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_stateful_generator_serializes() -> None: """Stateful generators serialize across row groups.""" @@ -1084,10 +1249,10 @@ def _count_degraded_msgs(caplog: pytest.LogCaptureFixture) -> int: @pytest.mark.parametrize( "retryable_failures,num_records,window,interval_s,expected_count", [ - # Above-threshold + zero throttle: at least one WARN should fire. + # Above-threshold + no log interval: at least one WARN should fire. pytest.param(6, 10, 8, 0.0, "at_least_one", id="fires_above_threshold"), - # Above-threshold + 1h throttle: only one WARN despite sustained degradation. - pytest.param(8, 12, 4, 3600.0, 1, id="throttled_to_one"), + # Above-threshold + 1h log interval: only one WARN despite sustained degradation. + pytest.param(8, 12, 4, 3600.0, 1, id="rate_limited_to_one"), ], ) @pytest.mark.asyncio(loop_scope="session") @@ -1397,15 +1562,14 @@ async def test_scheduler_out_of_order_row_group_completion() -> None: assert checkpoint_order.index(1) < checkpoint_order.index(0) -# -- Dual-semaphore / LLM-bound tests ----------------------------------------- +# -- Task-admission / model-stage tests --------------------------------------- class MockLLMBoundCellGenerator(ColumnGenerator[ExpressionColumnConfig]): - """Mock cell-by-cell generator that reports is_llm_bound=True.""" + """Mock cell-by-cell generator that reports model-stage scheduling metadata.""" - @property - def is_llm_bound(self) -> bool: - return True + def get_scheduling_metadata(self) -> SchedulingMetadata: + return SchedulingMetadata.custom_model("test", self.config.name, "v1") @staticmethod def get_generation_strategy() -> GenerationStrategy: @@ -1416,13 +1580,9 @@ def generate(self, data: dict) -> dict: return data -class MockConfiguredModelCellGenerator(ColumnGenerator[LLMTextColumnConfig]): +class MockConfiguredModelCellGenerator(ColumnGeneratorWithModelRegistry[LLMTextColumnConfig]): """Mock cell generator with model-registry helpers.""" - @property - def is_llm_bound(self) -> bool: - return True - @staticmethod def get_generation_strategy() -> GenerationStrategy: return GenerationStrategy.CELL_BY_CELL @@ -1447,9 +1607,8 @@ def __init__(self, *args: Any, rate_limit_failures: int = 0, **kwargs: Any) -> N self._rate_limit_failures = rate_limit_failures self._calls = 0 - @property - def is_llm_bound(self) -> bool: - return True + def get_scheduling_metadata(self) -> SchedulingMetadata: + return SchedulingMetadata.custom_model("test", self.config.name, "v1") @staticmethod def get_generation_strategy() -> GenerationStrategy: @@ -1492,19 +1651,15 @@ async def test_scheduler_llm_bound_one_way_handoff() -> None: tracker=tracker, row_groups=row_groups, max_submitted_tasks=max_submitted, - max_llm_wait_tasks=max_llm_wait, + max_model_task_admission=max_llm_wait, ) await scheduler.run() assert tracker.is_row_group_complete(0, 3, ["seed", "llm_col"]) - sub_available, llm_available = scheduler.get_semaphore_permits() - assert sub_available == max_submitted, ( - f"Submission semaphore leaked after LLM handoff: available={sub_available}, expected={max_submitted}" - ) - assert llm_available == max_llm_wait, ( - f"LLM-wait semaphore leaked after LLM handoff: available={llm_available}, expected={max_llm_wait}" - ) + snapshot = scheduler.task_admission_snapshot() + assert snapshot.resources_available["submission"] == max_submitted + assert snapshot.resources_available["llm_wait"] == max_llm_wait @pytest.mark.asyncio(loop_scope="session") @@ -1535,21 +1690,19 @@ async def test_scheduler_non_llm_holds_submission_slot() -> None: tracker=tracker, row_groups=row_groups, max_submitted_tasks=2, - max_llm_wait_tasks=max_llm_wait, + max_model_task_admission=max_llm_wait, ) await scheduler.run() assert tracker.is_row_group_complete(0, 3, ["seed", "cell_out"]) - _, llm_available = scheduler.get_semaphore_permits() - assert llm_available == max_llm_wait, ( - f"LLM-wait semaphore was consumed by non-LLM task: available={llm_available}, expected={max_llm_wait}" - ) + snapshot = scheduler.task_admission_snapshot() + assert snapshot.resources_available["llm_wait"] == max_llm_wait @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_deadlock_regression() -> None: - """max_submitted_tasks=1, max_llm_wait_tasks=1, two ready LLM tasks completes without deadlock.""" + """max_submitted_tasks=1, max_model_task_admission=1, two ready LLM tasks completes without deadlock.""" provider = _mock_provider() configs = [ SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), @@ -1574,7 +1727,7 @@ async def test_scheduler_deadlock_regression() -> None: tracker=tracker, row_groups=row_groups, max_submitted_tasks=1, - max_llm_wait_tasks=1, + max_model_task_admission=1, ) await asyncio.wait_for(scheduler.run(), timeout=10.0) @@ -1613,9 +1766,12 @@ async def test_drain_frontier_raises_when_ready_but_no_capacity_or_inflight() -> graph=graph, tracker=tracker, row_groups=row_groups, - max_submitted_tasks=0, + task_admission_config=TaskAdmissionConfig(submission_capacity=1), ) scheduler._rg_states[0] = MagicMock(size=1) + blocker = scheduler._schedulable_task(Task(column="cell_out", row_group=0, row_index=99, task_type="cell")) + lease = scheduler._task_admission.try_acquire(blocker, scheduler._fair_queue.view()) + assert isinstance(lease, TaskAdmissionLease) scheduler._apply_frontier_delta(seed_delta) with pytest.raises(RuntimeError, match="Ready frontier is admission-blocked"): @@ -1695,22 +1851,99 @@ def drop_middle_row(row_group: int, row_group_size: int) -> FrontierDelta: assert tracker.is_row_group_complete(0, 3, ["seed", "cell_out"]) -@pytest.mark.asyncio(loop_scope="session") -async def test_scheduler_is_llm_bound_property_drives_lookup() -> None: - """is_llm_bound property on generators drives the lookup, not isinstance.""" +def test_apply_frontier_delta_enqueues_ready_tasks_in_one_queue_operation(monkeypatch: pytest.MonkeyPatch) -> None: provider = _mock_provider() - llm_gen = MockLLMBoundCellGenerator(config=_expr_config("llm_col"), resource_provider=provider) - non_llm_gen = MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider) + configs = [ + LLMTextColumnConfig(name="root", prompt="root", model_alias=MODEL_ALIAS), + ] + strategies = { + "root": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "root": MockCellGenerator(config=_expr_config("root"), resource_provider=provider), + } + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 5)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + ) + scheduler._rg_states[0] = SimpleNamespace(size=5, pre_batch_done=True) - assert llm_gen.is_llm_bound is True - assert non_llm_gen.is_llm_bound is False + enqueue_sizes: list[int] = [] + original_enqueue = scheduler._fair_queue.enqueue - lookup = build_llm_bound_lookup({"llm_col": llm_gen, "cell_out": non_llm_gen}) - assert lookup == {"llm_col": True, "cell_out": False} + def spy_enqueue(items: Any) -> tuple[str, ...]: + materialized = tuple(items) + enqueue_sizes.append(len(materialized)) + return original_enqueue(materialized) + monkeypatch.setattr(scheduler._fair_queue, "enqueue", spy_enqueue) -def test_custom_generator_with_model_aliases_is_llm_bound() -> None: - """CustomColumnGenerator with model_aliases reports is_llm_bound=True.""" + scheduler._apply_frontier_delta(tracker.add_root_tasks(0, 5)) + + assert enqueue_sizes == [5] + assert tracker.ready_frontier() == () + assert scheduler._fair_queue.view().queued_total == 5 + + +def test_pre_batch_flush_batches_pending_ready_and_skips_dropped_rows(monkeypatch: pytest.MonkeyPatch) -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cell_out": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "cell_out": MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider), + } + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3)] + tracker = CompletionTracker.with_graph(graph, row_groups) + sink = InMemoryAdmissionEventSink() + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + on_seeds_complete=lambda row_group, row_group_size: None, + scheduler_event_sink=sink, + ) + state = SimpleNamespace(size=3, pre_batch_done=False) + scheduler._rg_states[0] = state + + enqueue_sizes: list[int] = [] + original_enqueue = scheduler._fair_queue.enqueue + + def spy_enqueue(items: Any) -> tuple[str, ...]: + materialized = tuple(items) + enqueue_sizes.append(len(materialized)) + return original_enqueue(materialized) + + monkeypatch.setattr(scheduler._fair_queue, "enqueue", spy_enqueue) + + scheduler._apply_frontier_delta(tracker.mark_row_range_complete("seed", 0, 3)) + scheduler._apply_frontier_delta(tracker.drop_row(0, 1)) + state.pre_batch_done = True + scheduler._flush_pre_batch_ready(0) + + assert enqueue_sizes == [2] + assert scheduler._fair_queue.view().queued_total == 2 + assert {item.payload.row_index for item in scheduler._fair_queue._queued.values()} == {0, 2} + assert tracker.is_dropped(0, 1) + assert sum(event.event_kind == "ready_enqueued" for event in sink.scheduler_events) == 2 + assert sum(event.event_kind == "dependency_ready" for event in sink.scheduler_events) == 5 + + +def test_custom_generator_with_model_aliases_reports_custom_model_metadata() -> None: + """CustomColumnGenerator with model_aliases reports custom-model metadata.""" @custom_column_generator(model_aliases=["my_model"]) def gen_with_models(row: dict, generator_params: None, models: dict) -> dict: @@ -1729,11 +1962,8 @@ def gen_no_models(row: dict) -> dict: llm_gen = CustomColumnGenerator(config=llm_config, resource_provider=provider) plain_gen = CustomColumnGenerator(config=plain_config, resource_provider=provider) - assert llm_gen.is_llm_bound is True - assert plain_gen.is_llm_bound is False - - lookup = build_llm_bound_lookup({"custom_llm": llm_gen, "custom_plain": plain_gen}) - assert lookup == {"custom_llm": True, "custom_plain": False} + assert llm_gen.get_scheduling_metadata().kind == "custom_model" + assert plain_gen.get_scheduling_metadata().kind == "local" def _provider_with_model_configs(configs: dict[str, ModelConfig]) -> MagicMock: @@ -1762,13 +1992,13 @@ def test_scheduler_model_task_group_spec_uses_model_resource_and_flow() -> None: graph=graph, tracker=tracker, row_groups=[(0, 1)], - max_llm_wait_tasks=5, + max_model_task_admission=5, ) - spec = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=0, task_type="cell")) + spec = scheduler._schedulable_task(Task(column="answer", row_group=0, row_index=0, task_type="cell")).group assert spec.key.kind == "model" - assert spec.key.identity[:2] == ("mock-provider", "model-text") + assert spec.key.identity[:3] == ("model", "mock-provider", "model-text") assert spec.key.identity[-1] == "answer" assert spec.weight == 3.0 assert spec.admitted_limit == 5 @@ -1792,21 +2022,19 @@ def test_scheduler_task_group_spec_is_cached_per_generator() -> None: graph=graph, tracker=tracker, row_groups=[(0, 2)], - max_llm_wait_tasks=5, + max_model_task_admission=5, ) - spec_a = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=0, task_type="cell")) - spec_b = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=1, task_type="cell")) + spec_a = scheduler._schedulable_task(Task(column="answer", row_group=0, row_index=0, task_type="cell")).group + spec_b = scheduler._schedulable_task(Task(column="answer", row_group=0, row_index=1, task_type="cell")).group - assert spec_a is spec_b + assert spec_a == spec_b assert provider.model_registry.get_model_config.call_count == 1 assert provider.model_registry.get_model_provider.call_count == 1 -def test_scheduler_task_group_spec_logs_debug_on_model_resolution_fallback( - caplog: pytest.LogCaptureFixture, -) -> None: - """Direct spec resolution isolates fallback logging without timing-based scheduler traces.""" +def test_scheduler_task_group_spec_raises_on_model_resolution_failure() -> None: + """Model metadata resolution failures are fatal without an explicit fallback.""" provider = MagicMock() provider.model_registry = MagicMock() provider.model_registry.get_model_config.side_effect = RuntimeError("registry unavailable") @@ -1816,29 +2044,14 @@ def test_scheduler_task_group_spec_logs_debug_on_model_resolution_fallback( graph = ExecutionGraph.create([column_config], {"answer": GenerationStrategy.CELL_BY_CELL}) tracker = CompletionTracker.with_graph(graph, [(0, 2)]) - with caplog.at_level("DEBUG", logger="data_designer.engine.dataset_builders.utils.scheduling_hints"): - scheduler = AsyncTaskScheduler( + with pytest.raises(Exception): + AsyncTaskScheduler( generators={"answer": generator}, graph=graph, tracker=tracker, row_groups=[(0, 2)], - max_llm_wait_tasks=5, + max_model_task_admission=5, ) - spec_a = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=0, task_type="cell")) - spec_b = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=1, task_type="cell")) - - assert spec_a is spec_b - assert spec_a.key.kind == "custom_model" - assert spec_a.key.identity == ("answer", MODEL_ALIAS) - assert spec_a.weight == 1.0 - assert provider.model_registry.get_model_config.call_count == 1 - fallback_records = [ - record for record in caplog.records if "Falling back to custom-model scheduling group" in record.getMessage() - ] - assert len(fallback_records) == 1 - assert "answer" in fallback_records[0].getMessage() - assert MODEL_ALIAS in fallback_records[0].getMessage() - assert fallback_records[0].exc_info is not None def test_scheduler_custom_model_task_group_spec_uses_alias_set_weight() -> None: @@ -1874,15 +2087,15 @@ def gen_with_models(row: dict, generator_params: None, models: dict) -> dict: graph=graph, tracker=tracker, row_groups=[(0, 1)], - max_llm_wait_tasks=10, + max_model_task_admission=10, ) - spec = scheduler._task_group_spec(Task(column="custom_llm", row_group=0, row_index=0, task_type="cell")) + spec = scheduler._schedulable_task(Task(column="custom_llm", row_group=0, row_index=0, task_type="cell")).group assert spec.key.kind == "custom_model" - assert spec.key.identity == ("custom_llm", "draft", "judge") - assert spec.weight == 5.0 - assert spec.admitted_limit == 10 + assert spec.key.identity[:3] == ("custom_model", "custom_column", "draft-judge") + assert spec.weight == 2.0 + assert spec.admitted_limit == 4 @pytest.mark.asyncio(loop_scope="session") @@ -1927,33 +2140,28 @@ async def test_scheduler_llm_bound_429_retried_in_salvage() -> None: row_groups=row_groups, buffer_manager=buffer_mgr, max_submitted_tasks=max_submitted, - max_llm_wait_tasks=max_llm_wait, + max_model_task_admission=max_llm_wait, ) await scheduler.run() assert tracker.is_row_group_complete(0, num_records, ["seed", "llm_col"]) - sub_available, llm_available = scheduler.get_semaphore_permits() - assert sub_available == max_submitted, ( - f"Submission semaphore leaked after salvage retry: available={sub_available}, expected={max_submitted}" - ) - assert llm_available == max_llm_wait, ( - f"LLM-wait semaphore leaked after salvage retry: available={llm_available}, expected={max_llm_wait}" - ) + snapshot = scheduler.task_admission_snapshot() + assert snapshot.resources_available["submission"] == max_submitted + assert snapshot.resources_available["llm_wait"] == max_llm_wait @pytest.mark.asyncio(loop_scope="session") -async def test_scheduler_cancellation_releases_semaphores() -> None: - """Cancelling the scheduler while LLM-bound tasks are in-flight releases all semaphore slots.""" +async def test_scheduler_cancellation_releases_task_admission_leases() -> None: + """Cancelling the scheduler while model-stage tasks are in-flight releases task leases.""" provider = _mock_provider() blocked = asyncio.Event() proceed = asyncio.Event() class BlockingLLMGenerator(ColumnGenerator[ExpressionColumnConfig]): - @property - def is_llm_bound(self) -> bool: - return True + def get_scheduling_metadata(self) -> SchedulingMetadata: + return SchedulingMetadata.custom_model("test", self.config.name, "v1") @staticmethod def get_generation_strategy() -> GenerationStrategy: @@ -1987,13 +2195,15 @@ async def agenerate(self, data: dict) -> dict: max_submitted = 4 max_llm_wait = 2 + sink = InMemoryAdmissionEventSink() scheduler = AsyncTaskScheduler( generators=generators, graph=graph, tracker=tracker, row_groups=row_groups, max_submitted_tasks=max_submitted, - max_llm_wait_tasks=max_llm_wait, + max_model_task_admission=max_llm_wait, + scheduler_event_sink=sink, ) run_task = asyncio.create_task(scheduler.run()) @@ -2003,13 +2213,14 @@ async def agenerate(self, data: dict) -> dict: with pytest.raises(asyncio.CancelledError): await run_task - sub_available, llm_available = scheduler.get_semaphore_permits() - assert sub_available == max_submitted, ( - f"Submission semaphore leaked: available={sub_available}, expected={max_submitted}" - ) - assert llm_available == max_llm_wait, ( - f"LLM-wait semaphore leaked: available={llm_available}, expected={max_llm_wait}" - ) + snapshot = scheduler.task_admission_snapshot() + assert snapshot.resources_available["submission"] == max_submitted + assert snapshot.resources_available["llm_wait"] == max_llm_wait + assert "cancelled" in [event.event_kind for event in sink.scheduler_events] + assert all(event.snapshot is not None for event in sink.scheduler_events) + task_events = [event for event in sink.scheduler_events if event.task_id is not None] + assert all("resource_request" in event.diagnostics for event in task_events) + assert any("llm_wait" in event.diagnostics["resource_request"] for event in task_events) @pytest.mark.asyncio(loop_scope="session") @@ -2017,7 +2228,7 @@ async def test_scheduler_rg_semaphore_deadlock_with_transient_failures() -> None """Row groups stalled by transient failures don't block admission of new row groups. Regression test: with max_concurrent_row_groups=1 and 2 row groups, if all - tasks in RG0 fail transiently, the semaphore must still be released so RG1 + tasks in RG0 fail transiently, row-group capacity must still be released so RG1 can be admitted. The scheduler salvages RG0 inline and continues. """ provider = _mock_provider() @@ -2098,31 +2309,6 @@ def test_side_effect_columns_separated_from_completion_tracking() -> None: assert "side_b" in write_cols -# -- TrackingSemaphore tests --------------------------------------------------- - - -def test_tracking_semaphore_try_acquire() -> None: - """try_acquire returns True when permits are available, False when exhausted.""" - from data_designer.engine.dataset_builders.async_scheduler import TrackingSemaphore - - sem = TrackingSemaphore(2) - assert sem.available_permits == 2 - - assert sem.try_acquire() is True - assert sem.available_permits == 1 - - assert sem.try_acquire() is True - assert sem.available_permits == 0 - - assert sem.try_acquire() is False - assert sem.available_permits == 0 - - sem.release() - assert sem.available_permits == 1 - assert sem.try_acquire() is True - assert sem.available_permits == 0 - - # -- Pipeline parallelism (stale dispatch fix, issue #504) --------------------- @@ -2147,11 +2333,80 @@ async def agenerate(self, data: dict) -> dict: class SlowLLMBoundCellGenerator(SlowCellGenerator): - """Slow cell generator that participates in LLM-wait scheduling.""" + """Slow cell generator that participates in model-stage scheduling.""" + + def get_scheduling_metadata(self) -> SchedulingMetadata: + return SchedulingMetadata.custom_model("test", self.config.name, "v1") + + +class SlowModelBoundCellGenerator(SlowCellGenerator): + """Slow cell generator with concrete request-pressure identity.""" + + def __init__( + self, + *args: Any, + provider_name: str = "provider", + model_id: str = "model", + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._provider_name = provider_name + self._model_id = model_id + + def get_scheduling_metadata(self) -> SchedulingMetadata: + return SchedulingMetadata.model( + self._provider_name, + self._model_id, + "chat", + weight=1, + ) + + +class _StaticRequestPressureProvider: + def __init__(self, snapshots: dict[RequestResourceKey, RequestPressureSnapshot]) -> None: + self._snapshots = snapshots @property - def is_llm_bound(self) -> bool: - return True + def config(self) -> RequestAdmissionConfig | None: + return None + + def snapshot(self, resource: RequestResourceKey) -> RequestPressureSnapshot | None: + return self._snapshots.get(resource) + + def snapshots(self) -> dict[RequestResourceKey, RequestPressureSnapshot]: + return dict(self._snapshots) + + def global_snapshot(self, provider: str, model: str) -> None: + return None + + def global_snapshots(self) -> dict[ProviderModelKey, object]: + return {} + + +def _pressure_snapshot( + resource: RequestResourceKey, + *, + current_limit: int = 1, + in_flight: int = 0, + waiters: int = 0, + cooldown: float = 0.0, +) -> RequestPressureSnapshot: + return RequestPressureSnapshot( + captured_at=time.monotonic(), + sequence=1, + resource=resource, + effective_max=max(1, current_limit), + current_limit=current_limit, + in_flight_count=in_flight, + active_lease_count=in_flight, + waiters=waiters, + blocked_until_monotonic=time.monotonic() + cooldown if cooldown > 0.0 else None, + cooldown_remaining_seconds=cooldown, + rate_limit_ceiling=max(1, current_limit), + consecutive_rate_limits=0, + last_outcome=None, + leak_diagnostics={}, + ) @pytest.mark.asyncio(loop_scope="session") @@ -2304,7 +2559,7 @@ async def test_scheduler_fair_llm_group_cap_preserves_peer_admission() -> None: tracker=tracker, row_groups=row_groups, max_submitted_tasks=4, - max_llm_wait_tasks=4, + max_model_task_admission=4, trace=True, ) @@ -2318,7 +2573,9 @@ async def test_scheduler_fair_llm_group_cap_preserves_peer_admission() -> None: assert first_window.count("hot") == 2 assert first_window.count("peer") == 2 assert tracker.is_row_group_complete(0, 8, ["topic", *gen_names]) - assert scheduler.get_semaphore_permits() == (4, 4) + snapshot = scheduler.task_admission_snapshot() + assert snapshot.resources_available["submission"] == 4 + assert snapshot.resources_available["llm_wait"] == 4 @pytest.mark.asyncio(loop_scope="session") @@ -2332,9 +2589,9 @@ async def test_scheduler_downstream_interleaves_with_upstream() -> None: ├── gen_b (slow, 50ms) → judge_b (instant) └── gen_c (slow, 50ms) → judge_c (instant) - With a small semaphore (4) and 10 records, the 30 gen tasks (3 cols x 10 rows) - saturate the semaphore. The dispatch loop must re-query the frontier when the - semaphore is full so that judge tasks from completed gen rows are picked up + With small task admission capacity (4) and 10 records, the 30 gen tasks + saturate admission. The dispatch loop must re-query the frontier when capacity + is full so that judge tasks from completed gen rows are picked up before all gen tasks finish. """ provider = _mock_provider() @@ -2394,6 +2651,387 @@ async def test_scheduler_downstream_interleaves_with_upstream() -> None: ) +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_capacity_plan_observes_buffer_backpressure() -> None: + provider = _mock_provider() + gen_names = ["gen_a", "gen_b", "gen_c"] + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + *[LLMTextColumnConfig(name=g, prompt="{{ topic }}", model_alias=MODEL_ALIAS) for g in gen_names], + ] + strategies: dict[str, GenerationStrategy] = {"topic": GenerationStrategy.FULL_COLUMN} + strategies.update({column: GenerationStrategy.CELL_BY_CELL for column in gen_names}) + generators: dict[str, ColumnGenerator] = { + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + **{ + name: SlowCellGenerator(config=_expr_config(name), resource_provider=provider, delay=0.02) + for name in gen_names + }, + } + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3), (1, 3), (2, 3), (3, 3)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_concurrent_row_groups=2, + max_submitted_tasks=2, + trace=True, + num_records=12, + buffer_size=3, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + plan = scheduler.capacity_plan() + for row_group_index, row_count in row_groups: + assert tracker.is_row_group_complete(row_group_index, row_count, ["topic", *gen_names]) + assert plan.configured.row_group_admission.observed_in_flight == 0 + assert plan.observed_maxima.row_groups_in_flight == 2 + assert plan.observed_maxima.queued_tasks_by_group + assert max(plan.observed_maxima.task_leases_by_resource.values()) <= 2 + + +def test_scheduler_capacity_plan_reports_request_admission_state() -> None: + resource = RequestResourceKey("provider", "model", RequestDomain.CHAT) + request_admission = AdaptiveRequestAdmissionController( + RequestAdmissionConfig(initial_limits={resource: 2}, max_limit_clamps={resource: 3}) + ) + request_admission.register( + provider_name="provider", + model_id="model", + alias="primary", + max_parallel_requests=4, + ) + lease = request_admission.try_acquire(RequestAdmissionItem(resource, RequestGroupSpec(resource))) + assert isinstance(lease, RequestAdmissionLease) + + scheduler, _tracker = _build_simple_pipeline() + scheduler._request_pressure_provider = request_admission + scheduler._record_observed_task_state() + plan = scheduler.capacity_plan() + + assert plan.configured.request_resources.value == (resource,) + assert plan.configured.request_domain_initial_limits.value[resource] == 2 + assert plan.configured.request_admission_config.value is not None + assert plan.configured.provider_model_static_caps.value[ProviderModelKey("provider", "model")].cap == 4 + assert plan.runtime_snapshot.request_domain_current_limits[resource] == 2 + assert plan.runtime_snapshot.request_domain_effective_max[resource] == 3 + assert plan.runtime_snapshot.provider_model_aggregate_in_flight[ProviderModelKey("provider", "model")] == 1 + assert plan.observed_maxima.request_in_flight_by_resource[resource] == 1 + assert plan.observed_maxima.provider_model_aggregate_in_flight[ProviderModelKey("provider", "model")] == 1 + request_admission.release(lease, RequestReleaseOutcome(kind="success")) + + +def test_scheduler_capacity_plan_reports_default_request_initial_limit_after_aimd_drop() -> None: + resource = RequestResourceKey("provider", "model", RequestDomain.CHAT) + request_admission = AdaptiveRequestAdmissionController() + request_admission.register( + provider_name="provider", + model_id="model", + alias="primary", + max_parallel_requests=4, + ) + lease = request_admission.try_acquire(RequestAdmissionItem(resource, RequestGroupSpec(resource))) + assert isinstance(lease, RequestAdmissionLease) + request_admission.release(lease, RequestReleaseOutcome(kind="rate_limited")) + + scheduler, _tracker = _build_simple_pipeline() + scheduler._request_pressure_provider = request_admission + plan = scheduler.capacity_plan() + + assert plan.configured.request_domain_initial_limits.value[resource] == 4 + assert plan.runtime_snapshot.request_domain_effective_max[resource] == 4 + assert plan.runtime_snapshot.request_domain_current_limits[resource] == 3 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_emits_job_health_and_row_group_telemetry() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="model_col", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "model_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 2)] + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, row_groups) + sink = InMemoryAdmissionEventSink() + scheduler = AsyncTaskScheduler( + generators={ + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "model_col": SlowLLMBoundCellGenerator( + config=_expr_config("model_col"), + resource_provider=provider, + delay=0.0, + ), + }, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_concurrent_row_groups=1, + max_submitted_tasks=2, + max_model_task_admission=1, + scheduler_event_sink=sink, + num_records=2, + buffer_size=2, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + kinds = [event.event_kind for event in sink.scheduler_events] + assert "scheduler_job_started" in kinds + assert "scheduler_health_snapshot" in kinds + assert "row_group_checkpointed" in kinds + assert "scheduler_job_completed" in kinds + + started = next(event for event in sink.scheduler_events if event.event_kind == "scheduler_job_started") + assert started.diagnostics["num_records"] == 2 + assert started.diagnostics["buffer_size"] == 2 + assert started.diagnostics["row_group_count"] == 1 + assert started.diagnostics["graph_depth"] == 2 + column_scheduling = started.diagnostics["column_scheduling"] + assert isinstance(column_scheduling, tuple) + model_column = next(item for item in column_scheduling if item["column"] == "model_col") + assert model_column["group_kind"] == "custom_model" + assert model_column["resource_request"] == {"submission": 1, "llm_wait": 1} + + health = next(event for event in sink.scheduler_events if event.event_kind == "scheduler_health_snapshot") + assert "queued_total" in health.diagnostics + assert "leased_resources" in health.diagnostics + assert "request_pressure" in health.diagnostics + + checkpointed = next(event for event in sink.scheduler_events if event.event_kind == "row_group_checkpointed") + assert checkpointed.diagnostics["row_group"] == 0 + assert checkpointed.diagnostics["row_group_size"] == 2 + assert checkpointed.diagnostics["surviving_rows"] == 2 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_adaptive_row_group_admission_expands_target_for_horizon_idle() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="model_col", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "model_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 1), (1, 1), (2, 1), (3, 1)] + generators: dict[str, ColumnGenerator] = { + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "model_col": SlowLLMBoundCellGenerator( + config=_expr_config("model_col"), + resource_provider=provider, + delay=0.04, + ), + } + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, row_groups) + sink = InMemoryAdmissionEventSink() + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_concurrent_row_groups=4, + max_submitted_tasks=4, + max_model_task_admission=4, + adaptive_row_group_admission=True, + adaptive_row_group_initial_target=1, + scheduler_event_sink=sink, + trace=True, + num_records=4, + buffer_size=1, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + plan = scheduler.capacity_plan() + assert tracker.is_row_group_complete(0, 1, ["topic", "model_col"]) + assert plan.configured.row_group_admission.mode == "adaptive" + assert plan.configured.row_group_admission.observed_max_target is not None + assert plan.configured.row_group_admission.observed_max_target > 1 + assert plan.observed_maxima.row_groups_in_flight > 1 + assert any(event.event_kind == "row_group_admission_target_changed" for event in sink.scheduler_events) + + +def test_scheduler_adaptive_row_group_row_guard_blocks_extra_large_groups() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="model_col", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "model_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 5_000), (1, 5_000)] + graph = ExecutionGraph.create(configs, strategies) + scheduler = AsyncTaskScheduler( + generators={ + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "model_col": SlowLLMBoundCellGenerator( + config=_expr_config("model_col"), + resource_provider=provider, + delay=0.0, + ), + }, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + max_concurrent_row_groups=4, + adaptive_row_group_admission=True, + adaptive_row_group_initial_target=4, + num_records=10_000, + buffer_size=1, + ) + + scheduler._rg_states[0] = SimpleNamespace(size=5_000) + + assert scheduler._adaptive_max_admitted_rows == 8_192 + assert not scheduler._row_group_row_guard_allows(5_000) + assert scheduler._row_group_row_guard_allows(1_000) + scheduler._rg_states.clear() + assert scheduler._row_group_row_guard_allows(9_000) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_raises_when_ready_frontier_blocked_without_in_flight() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="model_col", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "model_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 1)] + graph = ExecutionGraph.create(configs, strategies) + scheduler = AsyncTaskScheduler( + generators={ + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "model_col": SlowLLMBoundCellGenerator( + config=_expr_config("model_col"), + resource_provider=provider, + delay=0.0, + ), + }, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + task_admission_config=TaskAdmissionConfig( + submission_capacity=1, + resource_limits={"submission": 1, "local": 1}, + ), + ) + + with pytest.raises(RuntimeError, match="Ready frontier is admission-blocked"): + await asyncio.wait_for(scheduler.run(), timeout=2.0) + + +def test_scheduler_request_pressure_advisory_prefers_pressure_open_peer() -> None: + provider = _mock_provider() + configs = [ + LLMTextColumnConfig(name="pressured", prompt="A", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="open", prompt="B", model_alias=MODEL_ALIAS), + ] + strategies = { + "pressured": GenerationStrategy.CELL_BY_CELL, + "open": GenerationStrategy.CELL_BY_CELL, + } + generators: dict[str, ColumnGenerator] = { + "pressured": SlowModelBoundCellGenerator( + config=_expr_config("pressured"), + resource_provider=provider, + provider_name="provider-a", + model_id="model-a", + ), + "open": SlowModelBoundCellGenerator( + config=_expr_config("open"), + resource_provider=provider, + provider_name="provider-b", + model_id="model-b", + ), + } + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, [(0, 1)]) + pressured_key = RequestResourceKey("provider-a", "model-a", RequestDomain.CHAT) + open_key = RequestResourceKey("provider-b", "model-b", RequestDomain.CHAT) + pressure = _StaticRequestPressureProvider( + { + pressured_key: _pressure_snapshot(pressured_key, current_limit=1, in_flight=1, waiters=1), + open_key: _pressure_snapshot(open_key, current_limit=1, in_flight=0, waiters=0), + } + ) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=[(0, 1)], + request_pressure_provider=pressure, + request_pressure_advisory=True, + scheduler_event_sink=(sink := InMemoryAdmissionEventSink()), + ) + scheduler._rg_states[0] = SimpleNamespace(size=1, pre_batch_done=True) + pressured = scheduler._schedulable_task(Task(column="pressured", row_group=0, row_index=0, task_type="cell")) + open_task = scheduler._schedulable_task(Task(column="open", row_group=0, row_index=0, task_type="cell")) + scheduler._fair_queue.enqueue((pressured, open_task)) + + selection = scheduler._fair_queue.select_next(scheduler._is_dispatch_eligible) + + assert selection is not None + assert selection.item.payload.column == "open" + skip = next(event for event in sink.scheduler_events if event.event_kind == "request_pressure_advisory_skipped") + assert skip.diagnostics["request_resource"] == "provider-a/model-a/chat" + assert skip.diagnostics["pressure_reason"] == "waiters" + assert skip.diagnostics["open_peer_column"] == "open" + assert skip.diagnostics["open_peer_request_resource"] == "provider-b/model-b/chat" + + +def test_scheduler_request_pressure_advisory_preserves_liveness_when_all_candidates_pressured() -> None: + provider = _mock_provider() + configs = [LLMTextColumnConfig(name="pressured", prompt="A", model_alias=MODEL_ALIAS)] + strategies = {"pressured": GenerationStrategy.CELL_BY_CELL} + generators: dict[str, ColumnGenerator] = { + "pressured": SlowModelBoundCellGenerator( + config=_expr_config("pressured"), + resource_provider=provider, + provider_name="provider-a", + model_id="model-a", + ), + } + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, [(0, 1)]) + pressured_key = RequestResourceKey("provider-a", "model-a", RequestDomain.CHAT) + pressure = _StaticRequestPressureProvider( + {pressured_key: _pressure_snapshot(pressured_key, current_limit=1, in_flight=1, waiters=1)} + ) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=[(0, 1)], + request_pressure_provider=pressure, + request_pressure_advisory=True, + ) + scheduler._rg_states[0] = SimpleNamespace(size=1, pre_batch_done=True) + pressured = scheduler._schedulable_task(Task(column="pressured", row_group=0, row_index=0, task_type="cell")) + scheduler._fair_queue.enqueue((pressured,)) + + selection = scheduler._fair_queue.select_next(scheduler._is_dispatch_eligible) + + assert selection is not None + assert selection.item.payload.column == "pressured" + + # -- Skip / conditional generation tests (async engine) ----------------------- diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py index dfd219fd5..6a5b31a51 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py @@ -20,9 +20,9 @@ from data_designer.config.utils.code_lang import CodeLang from data_designer.config.validator_params import CodeValidatorParams from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig +from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph -from data_designer.engine.dataset_builders.utils.task_model import SliceRef MODEL_ALIAS = "stub-model-alias" diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_fair_task_queue.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_fair_task_queue.py deleted file mode 100644 index b929bce4f..000000000 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_fair_task_queue.py +++ /dev/null @@ -1,219 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from collections import Counter - -from data_designer.engine.dataset_builders.utils.fair_task_queue import ( - FairTaskQueue, - TaskGroupKey, - TaskGroupSpec, -) -from data_designer.engine.dataset_builders.utils.task_model import Task - - -def _task(column: str, row_index: int) -> Task: - return Task(column=column, row_group=0, row_index=row_index, task_type="cell") - - -def _group(name: str, *, weight: float = 1.0, admitted_limit: int | None = None) -> TaskGroupSpec: - return TaskGroupSpec( - key=TaskGroupKey(kind="local", identity=(name,)), - weight=weight, - admitted_limit=admitted_limit, - ) - - -def _enqueue(queue: FairTaskQueue, items: list[tuple[Task, TaskGroupSpec]]) -> None: - for task, group in items: - queue.enqueue(task, group) - - -def test_fair_task_queue_equal_groups_round_robins() -> None: - queue = FairTaskQueue() - _enqueue( - queue, - [ - (task, _group(task.column)) - for task in [ - _task("a", 0), - _task("a", 1), - _task("b", 0), - _task("b", 1), - _task("c", 0), - _task("c", 1), - ] - ], - ) - - selected = [queue.admit_next() for _ in range(6)] - - assert [selection.task.column for selection in selected if selection is not None] == ["a", "b", "c", "a", "b", "c"] - - -def test_fair_task_queue_weighted_groups() -> None: - queue = FairTaskQueue() - _enqueue( - queue, - [ - (task, _group(task.column, weight=2 if task.column == "a" else 1)) - for task in [_task("a", i) for i in range(6)] - ] - + [(_task("b", i), _group("b", weight=1)) for i in range(6)], - ) - - selected = [queue.admit_next() for _ in range(6)] - counts = Counter(selection.task.column for selection in selected if selection is not None) - - assert counts == {"a": 4, "b": 2} - - -def test_fair_task_queue_discards_queued_tasks() -> None: - queue = FairTaskQueue() - stale = _task("a", 0) - fresh = _task("a", 1) - - _enqueue(queue, [(stale, _group("a")), (fresh, _group("a"))]) - queue.discard(stale) - - selected = queue.admit_next() - - assert selected is not None - assert selected.task == fresh - assert queue.admit_next() is None - - -def test_fair_task_queue_admitted_cap_skips_saturated_group_with_waiting_peer() -> None: - queue = FairTaskQueue() - capped = _group("a", admitted_limit=1, weight=1_000) - peer = _group("b") - _enqueue( - queue, - [ - (_task("a", 0), capped), - (_task("a", 1), capped), - (_task("b", 0), peer), - (_task("b", 1), peer), - ], - ) - - first = queue.admit_next() - peer_first = queue.admit_next() - selected = queue.admit_next() - - assert first is not None - assert first.task.column == "a" - assert peer_first is not None - assert peer_first.task.column == "b" - assert selected is not None - assert selected.task.column == "b" - - -def test_fair_task_queue_solo_group_can_exceed_admitted_cap() -> None: - queue = FairTaskQueue() - group = _group("a", admitted_limit=1) - first_task = _task("a", 0) - second_task = _task("a", 1) - queue.enqueue(first_task, group) - queue.enqueue(second_task, group) - - first = queue.admit_next() - - assert first is not None - assert first.task == first_task - second = queue.admit_next() - assert second is not None - assert second.task == second_task - assert queue.has_queued_tasks is False - - -def test_fair_task_queue_over_cap_group_yields_to_queued_peer() -> None: - queue = FairTaskQueue() - capped = _group("a", admitted_limit=1) - peer = _group("b") - _enqueue(queue, [(_task("a", i), capped) for i in range(5)]) - - solo_selected = [queue.admit_next() for _ in range(3)] - _enqueue(queue, [(_task("b", i), peer) for i in range(2)]) - peer_selected = [queue.admit_next() for _ in range(2)] - - assert [selection.task.column for selection in solo_selected if selection is not None] == ["a", "a", "a"] - assert [selection.task.column for selection in peer_selected if selection is not None] == ["b", "b"] - - -def test_fair_task_queue_returns_none_when_all_competing_groups_capped() -> None: - queue = FairTaskQueue() - group_a = _group("a", admitted_limit=1) - group_b = _group("b", admitted_limit=1) - _enqueue( - queue, - [ - (_task("a", 0), group_a), - (_task("a", 1), group_a), - (_task("b", 0), group_b), - (_task("b", 1), group_b), - ], - ) - - selected = [queue.admit_next() for _ in range(2)] - - assert [selection.task.column for selection in selected if selection is not None] == ["a", "b"] - assert queue.admit_next() is None - assert queue.has_queued_tasks is True - - -def test_fair_task_queue_release_reopens_saturated_group() -> None: - queue = FairTaskQueue() - group_a = _group("a", admitted_limit=1) - group_b = _group("b", admitted_limit=1) - _enqueue( - queue, - [ - (_task("a", 0), group_a), - (_task("a", 1), group_a), - (_task("b", 0), group_b), - (_task("b", 1), group_b), - ], - ) - first = queue.admit_next() - second = queue.admit_next() - - assert first is not None - assert first.task.column == "a" - assert second is not None - assert second.task.column == "b" - assert queue.admit_next() is None - - queue.release(first.task) - reopened = queue.admit_next() - - assert reopened is not None - assert reopened.task == _task("a", 1) - - -def test_fair_task_queue_no_duplicate_on_repeated_enqueue() -> None: - queue = FairTaskQueue() - task = _task("a", 0) - - queue.enqueue(task, _group("a")) - queue.enqueue(task, _group("a")) - first = queue.admit_next() - - assert first is not None - assert first.task == task - assert queue.admit_next() is None - - -def test_fair_task_queue_discard_where_removes_matching_tasks() -> None: - queue = FairTaskQueue() - _enqueue( - queue, - [(_task(column, i), _group(column)) for column in ["a", "b"] for i in range(2)], - ) - - queue.discard_where(lambda task: task.column == "a") - selected = [queue.admit_next() for _ in range(2)] - - assert [selection.task.column for selection in selected if selection is not None] == ["b", "b"] - assert queue.admit_next() is None diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_scheduling_hints.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_scheduling_hints.py deleted file mode 100644 index 4e46c07b0..000000000 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_scheduling_hints.py +++ /dev/null @@ -1,150 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from data_designer.config.column_configs import ( - CustomColumnConfig, - ExpressionColumnConfig, - GenerationStrategy, - LLMTextColumnConfig, -) -from data_designer.config.custom_column import custom_column_generator -from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig -from data_designer.engine.column_generators.generators.base import ColumnGenerator -from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator -from data_designer.engine.dataset_builders.utils.scheduling_hints import SchedulingHint, SchedulingHintResolver -from data_designer.engine.resources.resource_provider import ResourceProvider - -MODEL_ALIAS = "stub" - - -def _expr_config(name: str = "test") -> ExpressionColumnConfig: - return ExpressionColumnConfig(name=name, expr="{{ x }}", dtype="str") - - -def _provider_with_model_configs(configs: dict[str, ModelConfig]) -> MagicMock: - provider = MagicMock(spec=ResourceProvider) - provider.model_registry = MagicMock() - provider.model_registry.get_model_config.side_effect = lambda model_alias: configs[model_alias] - provider.model_registry.get_model_provider.return_value = SimpleNamespace(name="mock-provider") - return provider - - -class LocalCellGenerator(ColumnGenerator[ExpressionColumnConfig]): - @staticmethod - def get_generation_strategy() -> GenerationStrategy: - return GenerationStrategy.CELL_BY_CELL - - def generate(self, data: dict) -> dict: - data[self.config.name] = "local" - return data - - -class ModelCellGenerator(ColumnGenerator[LLMTextColumnConfig]): - @property - def is_llm_bound(self) -> bool: - return True - - @staticmethod - def get_generation_strategy() -> GenerationStrategy: - return GenerationStrategy.CELL_BY_CELL - - def generate(self, data: dict) -> dict: - data[self.config.name] = "model" - return data - - def get_model_config(self, model_alias: str) -> ModelConfig: - return self.resource_provider.model_registry.get_model_config(model_alias=model_alias) - - def get_model_provider_name(self, model_alias: str) -> str: - provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias) - return str(provider.name) - - -def test_scheduling_hint_resolver_local_hint_does_not_touch_model_registry() -> None: - provider = MagicMock(spec=ResourceProvider) - provider.model_registry = MagicMock() - generator = LocalCellGenerator(config=_expr_config("local_col"), resource_provider=provider) - - resolver = SchedulingHintResolver({"local_col": generator}) - - assert resolver.hint_for(generator) == SchedulingHint(group_kind="local") - provider.model_registry.get_model_config.assert_not_called() - provider.model_registry.get_model_provider.assert_not_called() - - -def test_scheduling_hint_resolver_resolves_primary_model_once_per_generator() -> None: - model_config = ModelConfig( - alias=MODEL_ALIAS, - model="model-text", - inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=3), - provider="mock-provider", - ) - provider = _provider_with_model_configs({MODEL_ALIAS: model_config}) - column_config = LLMTextColumnConfig(name="answer", prompt="hello", model_alias=MODEL_ALIAS) - generator = ModelCellGenerator(config=column_config, resource_provider=provider) - - resolver = SchedulingHintResolver({"answer": generator, "answer_again": generator}) - hint = resolver.hint_for(generator) - - assert hint.group_kind == "model" - assert hint.identity_prefix[:2] == ("mock-provider", "model-text") - assert hint.weight == 3 - assert provider.model_registry.get_model_config.call_count == 1 - assert provider.model_registry.get_model_provider.call_count == 1 - - -def test_scheduling_hint_resolver_falls_back_to_custom_model_hint_with_debug( - caplog: pytest.LogCaptureFixture, -) -> None: - provider = MagicMock(spec=ResourceProvider) - provider.model_registry = MagicMock() - provider.model_registry.get_model_config.side_effect = RuntimeError("registry unavailable") - provider.model_registry.get_model_provider.return_value = SimpleNamespace(name="mock-provider") - column_config = LLMTextColumnConfig(name="answer", prompt="hello", model_alias=MODEL_ALIAS) - generator = ModelCellGenerator(config=column_config, resource_provider=provider) - - with caplog.at_level("DEBUG", logger="data_designer.engine.dataset_builders.utils.scheduling_hints"): - resolver = SchedulingHintResolver({"answer": generator}) - - hint = resolver.hint_for(generator) - - assert hint == SchedulingHint(group_kind="custom_model", identity_suffix=(MODEL_ALIAS,), weight=1) - fallback_records = [ - record for record in caplog.records if "Falling back to custom-model scheduling group" in record.getMessage() - ] - assert len(fallback_records) == 1 - assert "answer" in fallback_records[0].getMessage() - assert MODEL_ALIAS in fallback_records[0].getMessage() - assert fallback_records[0].exc_info is not None - - -def test_scheduling_hint_resolver_partial_alias_fallback_preserves_resolved_weight() -> None: - @custom_column_generator(model_aliases=["resolved", "missing"]) - def gen_with_models(row: dict, generator_params: None, models: dict) -> dict: - row["custom_llm"] = "value" - return row - - provider = _provider_with_model_configs( - { - "resolved": ModelConfig( - alias="resolved", - model="model-resolved", - inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=7), - provider="mock-provider", - ) - } - ) - config = CustomColumnConfig(name="custom_llm", generator_function=gen_with_models) - generator = CustomColumnGenerator(config=config, resource_provider=provider) - - resolver = SchedulingHintResolver({"custom_llm": generator}) - hint = resolver.hint_for(generator) - - assert hint == SchedulingHint(group_kind="custom_model", identity_suffix=("missing", "resolved"), weight=7) diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_factory.py b/packages/data-designer-engine/tests/engine/models/clients/test_factory.py index ffdad291f..f809db8be 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_factory.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_factory.py @@ -18,9 +18,9 @@ from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient from data_designer.engine.models.clients.factory import create_model_client +from data_designer.engine.models.clients.model_request_executor import ModelRequestExecutor from data_designer.engine.models.clients.retry import RetryConfig -from data_designer.engine.models.clients.throttle_manager import ThrottleManager -from data_designer.engine.models.clients.throttled import ThrottledModelClient +from data_designer.engine.models.request_admission.controller import AdaptiveRequestAdmissionController from data_designer.engine.secret_resolver import SecretResolver @@ -178,40 +178,51 @@ def test_concurrency_mode_defaults_to_sync( assert client.concurrency_mode == ClientConcurrencyMode.SYNC -# --- Throttle manager wrapping --- +# --- Request admission wrapping --- -def test_throttle_manager_wraps_openai_client( +def test_request_admission_wraps_openai_client( openai_model_config: ModelConfig, secret_resolver: SecretResolver, openai_registry: ModelProviderRegistry, ) -> None: - tm = ThrottleManager() + controller = AdaptiveRequestAdmissionController() + retry_config = RetryConfig(max_retries=5) client = create_model_client( - openai_model_config, secret_resolver, openai_registry, retry_config=RetryConfig(), throttle_manager=tm + openai_model_config, + secret_resolver, + openai_registry, + retry_config=retry_config, + request_admission=controller, ) - assert isinstance(client, ThrottledModelClient) + assert isinstance(client, ModelRequestExecutor) assert isinstance(client._inner, OpenAICompatibleClient) + assert client._retry_config is retry_config + assert client._inner._retry_config.max_retries == 0 -def test_throttle_manager_wraps_anthropic_client( +def test_request_admission_wraps_anthropic_client( anthropic_model_config: ModelConfig, secret_resolver: SecretResolver, anthropic_registry: ModelProviderRegistry, ) -> None: - tm = ThrottleManager() + controller = AdaptiveRequestAdmissionController() client = create_model_client( - anthropic_model_config, secret_resolver, anthropic_registry, retry_config=RetryConfig(), throttle_manager=tm + anthropic_model_config, + secret_resolver, + anthropic_registry, + retry_config=RetryConfig(), + request_admission=controller, ) - assert isinstance(client, ThrottledModelClient) + assert isinstance(client, ModelRequestExecutor) assert isinstance(client._inner, AnthropicClient) -def test_no_throttle_manager_returns_inner_client_directly( +def test_no_request_admission_returns_inner_client_directly( openai_model_config: ModelConfig, secret_resolver: SecretResolver, openai_registry: ModelProviderRegistry, ) -> None: client = create_model_client(openai_model_config, secret_resolver, openai_registry) assert isinstance(client, OpenAICompatibleClient) - assert not isinstance(client, ThrottledModelClient) + assert not isinstance(client, ModelRequestExecutor) diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py b/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py new file mode 100644 index 000000000..2c44a0c00 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py @@ -0,0 +1,357 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import logging + +import pytest + +from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind +from data_designer.engine.models.clients.model_request_executor import ModelRequestExecutor +from data_designer.engine.models.clients.retry import RetryConfig +from data_designer.engine.models.clients.types import ( + AssistantMessage, + ChatCompletionRequest, + ChatCompletionResponse, + EmbeddingRequest, + EmbeddingResponse, + ImageGenerationRequest, + ImageGenerationResponse, + ImagePayload, +) +from data_designer.engine.models.request_admission.controller import AdaptiveRequestAdmissionController +from data_designer.engine.models.request_admission.resources import RequestDomain +from data_designer.engine.observability import InMemoryAdmissionEventSink + + +class _Client: + provider_name = "nvidia" + + def __init__(self) -> None: + self.error: Exception | None = None + + def supports_chat_completion(self) -> bool: + return True + + def supports_embeddings(self) -> bool: + return True + + def supports_image_generation(self) -> bool: + return True + + def close(self) -> None: + return None + + async def aclose(self) -> None: + return None + + def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + if self.error is not None: + raise self.error + return ChatCompletionResponse(AssistantMessage(content="ok")) + + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + if self.error is not None: + raise self.error + return ChatCompletionResponse(AssistantMessage(content="ok")) + + def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + return EmbeddingResponse(vectors=[[1.0]]) + + async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + return EmbeddingResponse(vectors=[[1.0]]) + + def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + return ImageGenerationResponse(images=[ImagePayload("abc")]) + + async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + return ImageGenerationResponse(images=[ImagePayload("abc")]) + + +class _BrokenSink: + def emit_request_event(self, _event: object) -> None: + raise RuntimeError("sink boom") + + +class _GatedAsyncClient(_Client): + def __init__(self) -> None: + super().__init__() + self.chat_started = asyncio.Event() + self.embedding_started = asyncio.Event() + self.image_started = asyncio.Event() + self.release_chat = asyncio.Event() + self.release_embedding = asyncio.Event() + self.release_image = asyncio.Event() + + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + self.chat_started.set() + await self.release_chat.wait() + return ChatCompletionResponse(AssistantMessage(content="chat")) + + async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + self.embedding_started.set() + await self.release_embedding.wait() + return EmbeddingResponse(vectors=[[1.0]]) + + async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + self.image_started.set() + await self.release_image.wait() + return ImageGenerationResponse(images=[ImagePayload("image")]) + + +class _FlakyClient(_Client): + def __init__( + self, + *, + failures: int, + kind: ProviderErrorKind = ProviderErrorKind.INTERNAL_SERVER, + status_code: int | None = 503, + ) -> None: + super().__init__() + self.failures = failures + self.calls = 0 + self.kind = kind + self.status_code = status_code + + def _maybe_fail(self) -> None: + self.calls += 1 + if self.calls <= self.failures: + raise ProviderError( + kind=self.kind, + message="temporarily unavailable", + status_code=self.status_code, + provider_name="nvidia", + model_name="nemotron", + ) + + def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + self._maybe_fail() + return ChatCompletionResponse(AssistantMessage(content="ok")) + + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + self._maybe_fail() + return ChatCompletionResponse(AssistantMessage(content="ok")) + + +def _executor() -> tuple[ModelRequestExecutor, AdaptiveRequestAdmissionController, _Client]: + controller = AdaptiveRequestAdmissionController() + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + client = _Client() + return ModelRequestExecutor(client, controller, "nvidia", "nemotron"), controller, client + + +def test_model_request_executor_releases_successful_request() -> None: + executor, controller, _client = _executor() + + response = executor.completion(ChatCompletionRequest(model="nemotron", messages=[])) + + assert response.message.content == "ok" + snapshot = controller.pressure.snapshot(next(iter(controller.pressure.snapshots()))) + assert snapshot is not None + assert snapshot.active_lease_count == 0 + assert snapshot.last_outcome == "success" + + +def test_model_request_executor_classifies_rate_limit() -> None: + executor, controller, client = _executor() + client.error = ProviderError( + kind=ProviderErrorKind.RATE_LIMIT, + message="rate limited", + provider_name="nvidia", + model_name="nemotron", + retry_after=1.0, + ) + + with pytest.raises(ProviderError): + executor.completion(ChatCompletionRequest(model="nemotron", messages=[])) + + snapshot = controller.pressure.snapshot(next(iter(controller.pressure.snapshots()))) + assert snapshot is not None + assert snapshot.last_outcome == "rate_limited" + assert snapshot.cooldown_remaining_seconds > 0 + + +def test_model_request_executor_retries_provider_503_with_fresh_leases() -> None: + sink = InMemoryAdmissionEventSink() + controller = AdaptiveRequestAdmissionController(event_sink=sink) + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + client = _FlakyClient(failures=1) + executor = ModelRequestExecutor( + client, + controller, + "nvidia", + "nemotron", + event_sink=sink, + retry_config=RetryConfig(max_retries=1, backoff_factor=0.0), + ) + + response = executor.completion(ChatCompletionRequest(model="nemotron", messages=[])) + + assert response.message.content == "ok" + assert client.calls == 2 + acquired = [event for event in sink.request_events if event.event_kind == "request_lease_acquired"] + released = [event for event in sink.request_events if event.event_kind == "request_lease_released"] + assert len(acquired) == 2 + assert len(released) == 2 + assert {event.request_lease_id for event in acquired} == {event.request_lease_id for event in released} + + +def test_model_request_executor_does_not_retry_provider_timeout_without_status() -> None: + controller = AdaptiveRequestAdmissionController() + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + client = _FlakyClient(failures=2, kind=ProviderErrorKind.TIMEOUT, status_code=None) + executor = ModelRequestExecutor( + client, + controller, + "nvidia", + "nemotron", + retry_config=RetryConfig(max_retries=2, backoff_factor=0.0), + ) + + with pytest.raises(ProviderError) as exc_info: + executor.completion(ChatCompletionRequest(model="nemotron", messages=[])) + + assert exc_info.value.kind == ProviderErrorKind.TIMEOUT + assert client.calls == 1 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_model_request_executor_retries_async_provider_503_with_fresh_leases() -> None: + sink = InMemoryAdmissionEventSink() + controller = AdaptiveRequestAdmissionController(event_sink=sink) + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + client = _FlakyClient(failures=1) + executor = ModelRequestExecutor( + client, + controller, + "nvidia", + "nemotron", + event_sink=sink, + retry_config=RetryConfig(max_retries=1, backoff_factor=0.0), + ) + + response = await executor.acompletion(ChatCompletionRequest(model="nemotron", messages=[])) + + assert response.message.content == "ok" + assert client.calls == 2 + acquired = [event for event in sink.request_events if event.event_kind == "request_lease_acquired"] + released = [event for event in sink.request_events if event.event_kind == "request_lease_released"] + assert len(acquired) == 2 + assert len(released) == 2 + assert {event.request_lease_id for event in acquired} == {event.request_lease_id for event in released} + + +@pytest.mark.asyncio(loop_scope="session") +async def test_model_request_executor_releases_async_cancellation() -> None: + class _SlowClient(_Client): + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + await asyncio.sleep(30) + return ChatCompletionResponse(AssistantMessage(content="late")) + + controller = AdaptiveRequestAdmissionController() + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + executor = ModelRequestExecutor(_SlowClient(), controller, "nvidia", "nemotron") + + task = asyncio.create_task(executor.acompletion(ChatCompletionRequest(model="nemotron", messages=[]))) + await asyncio.sleep(0) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + snapshot = controller.pressure.snapshot(next(iter(controller.pressure.snapshots()))) + assert snapshot is not None + assert snapshot.active_lease_count == 0 + assert snapshot.last_outcome == "local_cancelled" + + +def test_model_request_executor_maps_image_chat_domain() -> None: + executor, controller, _client = _executor() + + executor.generate_image(ImageGenerationRequest(model="nemotron", prompt="p", messages=[])) + + resources = controller.pressure.snapshots() + assert any(resource.domain == RequestDomain.CHAT for resource in resources) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_model_request_executor_shares_provider_model_cap_across_async_domains() -> None: + controller = AdaptiveRequestAdmissionController() + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + client = _GatedAsyncClient() + executor = ModelRequestExecutor(client, controller, "nvidia", "nemotron") + + chat_task = asyncio.create_task(executor.acompletion(ChatCompletionRequest(model="nemotron", messages=[]))) + await asyncio.wait_for(client.chat_started.wait(), timeout=1.0) + embedding_task = asyncio.create_task(executor.aembeddings(EmbeddingRequest(model="nemotron", inputs=["x"]))) + image_task = asyncio.create_task(executor.agenerate_image(ImageGenerationRequest(model="nemotron", prompt="image"))) + await _wait_for_request_waiters(controller, expected=2) + + global_snapshot = controller.pressure.global_snapshot("nvidia", "nemotron") + assert global_snapshot is not None + assert global_snapshot.aggregate_in_flight == 1 + assert not client.embedding_started.is_set() + assert not client.image_started.is_set() + + client.release_chat.set() + await asyncio.wait_for(client.embedding_started.wait(), timeout=1.0) + assert not client.image_started.is_set() + assert (await chat_task).message.content == "chat" + + global_snapshot = controller.pressure.global_snapshot("nvidia", "nemotron") + assert global_snapshot is not None + assert global_snapshot.aggregate_in_flight == 1 + client.release_embedding.set() + await asyncio.wait_for(client.image_started.wait(), timeout=1.0) + assert (await embedding_task).vectors == [[1.0]] + + client.release_image.set() + assert (await image_task).images[0].b64_data == "image" + global_snapshot = controller.pressure.global_snapshot("nvidia", "nemotron") + assert global_snapshot is not None + assert global_snapshot.aggregate_in_flight == 0 + + +async def _wait_for_request_waiters(controller: AdaptiveRequestAdmissionController, *, expected: int) -> None: + for _ in range(50): + waiters = sum(snapshot.waiters for snapshot in controller.pressure.snapshots().values()) + if waiters == expected: + return + await asyncio.sleep(0) + raise AssertionError(f"expected {expected} request waiters") + + +def test_model_request_executor_emits_attempt_events_with_correlation_fields() -> None: + sink = InMemoryAdmissionEventSink() + controller = AdaptiveRequestAdmissionController(event_sink=sink) + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + executor = ModelRequestExecutor(_Client(), controller, "nvidia", "nemotron", event_sink=sink) + + executor.completion(ChatCompletionRequest(model="nemotron", messages=[])) + + kinds = [event.event_kind for event in sink.request_events] + assert "request_wait_started" in kinds + assert "request_lease_acquired" in kinds + assert "model_request_started" in kinds + assert "model_request_completed" in kinds + assert "request_lease_released" in kinds + attempts = {event.request_attempt_id for event in sink.request_events if event.request_attempt_id is not None} + assert len(attempts) == 1 + assert all(event.request_resource_key is not None for event in sink.request_events) + assert all(event.pressure_snapshot is not None for event in sink.request_events) + attempt_events = [event for event in sink.request_events if event.request_attempt_id is not None] + assert attempt_events + assert all(event.request_group_key is not None for event in attempt_events) + assert all(event.pressure_snapshot.resource == event.request_resource_key for event in attempt_events) + + +def test_model_request_executor_logs_sink_failures(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level(logging.WARNING, logger="data_designer.engine.models.clients.model_request_executor") + controller = AdaptiveRequestAdmissionController() + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + executor = ModelRequestExecutor(_Client(), controller, "nvidia", "nemotron", event_sink=_BrokenSink()) + + executor.completion(ChatCompletionRequest(model="nemotron", messages=[])) + + assert "Model request event sink raised; dropping event." in caplog.text diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_throttle_manager.py b/packages/data-designer-engine/tests/engine/models/clients/test_throttle_manager.py deleted file mode 100644 index 11a24edb7..000000000 --- a/packages/data-designer-engine/tests/engine/models/clients/test_throttle_manager.py +++ /dev/null @@ -1,579 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import asyncio -import threading -import time - -import pytest - -from data_designer.config.run_config import ThrottleConfig -from data_designer.engine.models.clients.throttle_manager import ( - CAPACITY_POLL_INTERVAL, - ThrottleDomain, - ThrottleManager, -) - -PROVIDER = "test-provider" -MODEL = "gpt-test" -DOMAIN = ThrottleDomain.CHAT - - -@pytest.fixture -def manager() -> ThrottleManager: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=4) - return tm - - -# --- try_acquire --- - - -def test_acquire_under_limit_returns_zero(manager: ThrottleManager) -> None: - wait = manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert wait == 0.0 - - -def test_acquire_at_capacity_returns_short_poll_interval(manager: ThrottleManager) -> None: - for _ in range(4): - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - wait = manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert wait == pytest.approx(CAPACITY_POLL_INTERVAL) - - -def test_acquire_respects_blocked_until(manager: ThrottleManager) -> None: - """Rate-limit cooldown returns remaining block duration (not the short capacity poll).""" - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, retry_after=5.0, now=1.0) - wait = manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=2.0) - assert wait == pytest.approx(4.0, abs=0.01) - - -def test_acquire_without_registration_raises() -> None: - tm = ThrottleManager() - with pytest.raises(RuntimeError, match="register"): - tm.try_acquire(provider_name="unknown", model_id="m", domain=DOMAIN, now=0.0) - - -# --- release_success --- - - -def test_release_success_frees_slot(manager: ThrottleManager) -> None: - for _ in range(4): - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - wait = manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert wait == 0.0 - - -def test_additive_increase_after_success_window() -> None: - tm = ThrottleManager(ThrottleConfig(success_window=5)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=10) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - limit_after_drop = state.current_limit - - for i in range(5): - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=float(i)) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=float(i)) - - assert state.current_limit == limit_after_drop + 1 - - -def test_additive_increase_uses_configured_step() -> None: - tm = ThrottleManager(ThrottleConfig(success_window=1, additive_increase=3)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=20) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - limit_after_drop = state.current_limit - - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - - assert state.current_limit == limit_after_drop + 3 - - -def test_current_limit_never_exceeds_effective_max() -> None: - tm = ThrottleManager(ThrottleConfig(success_window=1)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=2) - for i in range(20): - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=float(i)) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=float(i)) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit <= 2 - - -def test_additive_increase_clamped_to_effective_max() -> None: - tm = ThrottleManager(ThrottleConfig(success_window=1, additive_increase=100)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=5) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit == 5 - - -# --- release_rate_limited --- - - -def test_rate_limited_reduces_current_limit(manager: ThrottleManager) -> None: - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit == 3 # floor(4 * 0.75) - - -def test_rate_limited_never_drops_below_one() -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit >= 1 - - -def test_rate_limited_resets_success_streak(manager: ThrottleManager) -> None: - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.success_streak == 0 - - -def test_rate_limited_uses_retry_after_for_blocked_until(manager: ThrottleManager) -> None: - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, retry_after=7.0, now=10.0) - state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.blocked_until == pytest.approx(17.0, abs=0.01) - - -def test_rate_limited_uses_default_block_when_no_retry_after(manager: ThrottleManager) -> None: - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) - state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.blocked_until == pytest.approx(10.0 + ThrottleConfig.DEFAULT_COOLDOWN_SECONDS, abs=0.01) - - -# --- release_failure --- - - -def test_failure_releases_slot_without_limit_change(manager: ThrottleManager) -> None: - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - limit_before = state.current_limit - manager.release_failure(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert state.current_limit == limit_before - assert state.in_flight == 0 - - -def test_failure_does_not_reset_cascade_while_burst_in_flight(manager: ThrottleManager) -> None: - """Mixed-response burst (429 → 500 → 429 with multiple slots in-flight) must reduce only once. - - With a real burst of in-flight requests, an interleaved non-rate-limit - failure should NOT break the cascade - otherwise the next 429 from the - same wave would be treated as a new cascade and double-reduce the limit - even though the provider hasn't recovered between the two 429s. - """ - # Saturate to limit (4 concurrent slots). - for _ in range(4): - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.in_flight == 4 - limit_before = state.current_limit - - # First 429 from the burst: limit reduced once. - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - limit_after_first_429 = state.current_limit - assert limit_after_first_429 < limit_before - assert state.consecutive_429s == 1 - assert state.in_flight == 3 - - # Second response from the same burst: 500. With the regression, this - # would reset the cascade to 0; with the fix, in_flight > 0 keeps it at 1. - manager.release_failure(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert state.consecutive_429s == 1, "cascade must not reset while the prior burst is still in-flight" - assert state.in_flight == 2 - - # Third response from the same burst: another 429. With the regression - # this would be treated as a new cascade and reduce the limit again. - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert state.current_limit == limit_after_first_429, "limit must not double-reduce within the same burst" - assert state.in_flight == 1 - - -def test_failure_resets_cascade_after_burst_drains(manager: ThrottleManager) -> None: - """Once the burst fully drains (in_flight == 0), the next non-RL failure breaks the cascade. - - This preserves the original PR intent for the sequential 429 → 500 → 429 - case: provider rate-limited, settled, then rate-limited again. - """ - # Saturate, then drain: one 429 then one 500 with no concurrency. - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.consecutive_429s == 1 - assert state.in_flight == 0 - - # New request after the burst drained. release_failure sees in_flight 1 → 0. - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.release_failure(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert state.consecutive_429s == 0 - assert state.in_flight == 0 - - -# --- Global cap --- - - -def test_two_aliases_effective_max_is_minimum() -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=10) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a2", max_parallel_requests=3) - assert tm.get_effective_max(PROVIDER, MODEL) == 3 - - -def test_domain_clamped_when_new_alias_lowers_cap() -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=10) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit == 10 - - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a2", max_parallel_requests=3) - assert state.current_limit == 3 - - -# --- Domain isolation --- - - -def test_chat_and_embedding_throttle_independently() -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=2) - - for _ in range(2): - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=ThrottleDomain.CHAT, now=0.0) - wait_chat = tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=ThrottleDomain.CHAT, now=0.0) - assert wait_chat > 0.0 - - wait_emb = tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=ThrottleDomain.EMBEDDING, now=0.0) - assert wait_emb == 0.0 - - -# --- 429 lifecycle scenario --- - - -def test_rate_limit_lifecycle_acquire_backoff_recover() -> None: - """End-to-end AIMD lifecycle: steady-state → 429 → backoff → cooldown → recovery. - - Uses the ``now`` parameter to simulate time without real sleeps. - Config: success_window=3, additive_increase=1, max_parallel=4, reduce_factor=0.75. - """ - tm = ThrottleManager(ThrottleConfig(success_window=3, additive_increase=1)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=4) - t = 0.0 - - # Phase 1 — Steady state (t=0): all 4 slots acquired and released successfully. - for _ in range(4): - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 - for _ in range(4): - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit == 4 - - # Phase 2 — 429 hits (t=10): reduce_factor=0.75 → floor(4*0.75)=3. - # Domain is blocked until t=10+5=15. - t = 10.0 - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, retry_after=5.0, now=t) - assert state.current_limit == 3 - assert state.blocked_until == 15.0 - - # Phase 3 — During cooldown (t=12): acquire returns positive wait since 12 < 15. - wait = tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=12.0) - assert wait > 0.0 - - # Phase 4 — Cooldown expires (t=16): acquire succeeds, start accumulating successes. - # Need 3 successes (success_window=3) to bump limit 3 → 4. - t = 16.0 - for _ in range(3): - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - t += 1.0 - - assert state.current_limit == 4 - - -# --- Ceiling stabilization --- - - -def test_ceiling_stabilization_with_overshoot() -> None: - """After a 429, AIMD increase stops at ceiling + overshoot instead of effective_max. - - Config: effective_max=1000, success_window=1, ceiling_overshoot=0.10. - Scenario: 429 at limit 40 → floor(40*0.75)=30 → ceiling=40 → soft cap = 40 + 4 = 44. - Recovery should stop at 44, not climb to 1000. - """ - tm = ThrottleManager(ThrottleConfig(success_window=1, additive_increase=1)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1000) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - state.current_limit = 40 - - # 429 at limit 40 → floor(40*0.75)=30, ceiling recorded as 40. - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) - assert state.current_limit == 30 - assert state.rate_limit_ceiling == 40 - - # Pump success windows to climb back up. soft_cap = 40 + floor(40*0.1) = 44. - t = 20.0 - for _ in range(20): - t += 1.0 - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - - assert state.current_limit == 44, f"Expected stabilization at 44, got {state.current_limit}" - - # Further successes should not increase beyond the soft ceiling. - for _ in range(10): - t += 1.0 - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - - assert state.current_limit == 44, f"Limit crept past soft ceiling: {state.current_limit}" - - -def test_ceiling_lowers_on_repeated_429_after_recovery() -> None: - """A 429 after partial recovery lowers the ceiling, tightening the soft cap. - - Scenario: first 429 at 40 → floor(40*0.75)=30, ceiling=40. - Recovery: set limit to 30, one success bumps to 31 (success_window=1). - Second 429 at 31 → floor(31*0.75)=23, ceiling = min(40, 31) = 31. - Soft cap = 31 + max(1, floor(31*0.1)) = 31 + 3 = 34. - """ - tm = ThrottleManager(ThrottleConfig(success_window=1, additive_increase=1)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1000) - - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - state.current_limit = 40 - - # First 429 at 40 → floor(40*0.75)=30, ceiling=40. - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) - assert state.rate_limit_ceiling == 40 - assert state.current_limit == 30 - - # Recovery: one success bumps 30 → 31. - t = 20.0 - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - assert state.current_limit == 31 - - # Second 429 at 31 → floor(31*0.75)=23, ceiling = min(40, 31) = 31. - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t + 1) - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t + 1) - assert state.rate_limit_ceiling == 31 - assert state.current_limit == 23 - - # Soft cap = 31 + max(1, floor(31*0.1)) = 34. Climb should stop there. - t = 40.0 - for _ in range(15): - t += 1.0 - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - - assert state.current_limit == 34, f"Expected soft cap at 34, got {state.current_limit}" - - -def test_cascade_only_first_429_reduces_limit() -> None: - """Only the first 429 in a cascade reduces the limit; subsequent ones just release permits.""" - tm = ThrottleManager(ThrottleConfig(success_window=1, additive_increase=1)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=100) - - for _ in range(4): - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.in_flight == 4 - - # First 429: limit 100 → 75, ceiling set to 100. - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - assert state.current_limit == 75 - assert state.rate_limit_ceiling == 100 - assert state.in_flight == 3 - - # Subsequent cascade 429s: limit stays at 75, only in_flight decrements. - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - assert state.current_limit == 75 - assert state.rate_limit_ceiling == 100 - assert state.in_flight == 2 - - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - assert state.current_limit == 75 - assert state.in_flight == 1 - - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - assert state.current_limit == 75 - assert state.in_flight == 0 - - -def test_ceiling_does_not_restrict_when_at_effective_max() -> None: - """When effective_max is small (e.g. 4), the ceiling + overshoot should not - prevent recovery to effective_max. - """ - tm = ThrottleManager(ThrottleConfig(success_window=1, additive_increase=1)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=4) - - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) - - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - # floor(4 * 0.75) = 3; ceiling=4, soft_cap = min(4 + max(1, floor(4*0.1)), 4) = 4 - assert state.current_limit == 3 - - t = 20.0 - for _ in range(5): - t += 1.0 - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - - assert state.current_limit == 4, f"Should recover to effective_max=4, got {state.current_limit}" - - -# --- Acquire timeout --- - - -def test_acquire_sync_raises_timeout_when_at_capacity() -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) - # Saturate the single slot so try_acquire returns a positive wait. - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - - with pytest.raises(TimeoutError, match="timed out"): - tm.acquire_sync(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, timeout=0.0) - - -def test_acquire_sync_does_not_overshoot_timeout() -> None: - """When wait > remaining budget, raise immediately instead of sleeping the full wait.""" - tm = ThrottleManager(ThrottleConfig(cooldown_seconds=5.0)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - - # Timeout of 0.5s is less than the 5s block wait — should raise fast, not sleep 5s. - start = time.monotonic() - with pytest.raises(TimeoutError, match="timed out"): - tm.acquire_sync(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, timeout=0.5) - elapsed = time.monotonic() - start - assert elapsed < 2.0, f"acquire_sync overshot timeout: elapsed {elapsed:.1f}s (expected <2s)" - - -@pytest.mark.asyncio -async def test_acquire_async_raises_timeout_when_at_capacity() -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - - with pytest.raises(TimeoutError, match="timed out"): - await tm.acquire_async(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, timeout=0.0) - - -@pytest.mark.asyncio -async def test_acquire_async_default_no_deadline_waits_for_release() -> None: - """``timeout=None`` (the default) waits for the permit instead of raising. - - Issue #551: the previous 300s default produced spurious ``ModelTimeoutError`` - cascades on slow endpoints with deep queues; now queue waits scale with - provider speed and only the HTTP timeout deadlines actual work. The - ``timeout=0.0`` case is covered by ``test_acquire_async_raises_timeout_when_at_capacity``. - """ - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - - async def release_after(delay: float) -> None: - await asyncio.sleep(delay) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - - # Hold a strong reference to the task so the loop's weak-ref bookkeeping - # can't GC it before the inner await observes the release. - release_task = asyncio.create_task(release_after(0.05)) - try: - # asyncio.wait_for caps the test runtime; the inner acquire_async passes None. - await asyncio.wait_for( - tm.acquire_async(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN), - timeout=2.0, - ) - finally: - await release_task - - -def test_acquire_sync_default_no_deadline_waits_for_release() -> None: - """Sync counterpart: ``timeout=None`` default blocks until release.""" - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - - def release_after(delay: float) -> None: - time.sleep(delay) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - - threading.Thread(target=release_after, args=(0.05,), daemon=True).start() - start = time.monotonic() - tm.acquire_sync(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - elapsed = time.monotonic() - start - assert 0.04 < elapsed < 2.0, f"expected ~0.05s wait, got {elapsed:.3f}s" - - -# --- Thread safety --- - - -def test_concurrent_acquire_release_no_errors() -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=4) - errors: list[Exception] = [] - - def worker() -> None: - try: - for _ in range(50): - tm.acquire_sync(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - except Exception as exc: - errors.append(exc) - - threads = [threading.Thread(target=worker) for _ in range(8)] - for t in threads: - t.start() - for t in threads: - t.join(timeout=10) - assert not errors, f"Thread errors: {errors}" - - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.in_flight == 0 diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_throttled_model_client.py b/packages/data-designer-engine/tests/engine/models/clients/test_throttled_model_client.py deleted file mode 100644 index 0bec7231b..000000000 --- a/packages/data-designer-engine/tests/engine/models/clients/test_throttled_model_client.py +++ /dev/null @@ -1,469 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import asyncio -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from data_designer.config.run_config import ThrottleConfig -from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind -from data_designer.engine.models.clients.throttle_manager import DomainThrottleState, ThrottleDomain, ThrottleManager -from data_designer.engine.models.clients.throttled import ThrottledModelClient -from data_designer.engine.models.clients.types import ( - AssistantMessage, - ChatCompletionRequest, - ChatCompletionResponse, - EmbeddingRequest, - EmbeddingResponse, - ImageGenerationRequest, - ImageGenerationResponse, - Usage, -) - -PROVIDER = "test-provider" -MODEL_ID = "test-model" - - -@pytest.fixture -def throttle_manager() -> ThrottleManager: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL_ID, alias="alias", max_parallel_requests=10) - return tm - - -@pytest.fixture -def inner_client() -> MagicMock: - client = MagicMock() - client.provider_name = PROVIDER - client.supports_chat_completion.return_value = True - client.supports_embeddings.return_value = True - client.supports_image_generation.return_value = True - client.completion.return_value = ChatCompletionResponse(message=AssistantMessage(content="ok"), usage=Usage()) - client.acompletion = AsyncMock( - return_value=ChatCompletionResponse(message=AssistantMessage(content="ok"), usage=Usage()) - ) - client.embeddings.return_value = EmbeddingResponse(vectors=[[0.1]], usage=Usage()) - client.aembeddings = AsyncMock(return_value=EmbeddingResponse(vectors=[[0.1]], usage=Usage())) - client.generate_image.return_value = ImageGenerationResponse(images=[]) - client.agenerate_image = AsyncMock(return_value=ImageGenerationResponse(images=[])) - client.close.return_value = None - client.aclose = AsyncMock() - return client - - -@pytest.fixture -def throttled_client(inner_client: MagicMock, throttle_manager: ThrottleManager) -> ThrottledModelClient: - return ThrottledModelClient( - inner=inner_client, - throttle_manager=throttle_manager, - provider_name=PROVIDER, - model_id=MODEL_ID, - ) - - -# --- Protocol delegation --- - - -def test_provider_name_delegates(throttled_client: ThrottledModelClient) -> None: - assert throttled_client.provider_name == PROVIDER - - -def test_supports_methods_delegate(throttled_client: ThrottledModelClient) -> None: - assert throttled_client.supports_chat_completion() is True - assert throttled_client.supports_embeddings() is True - assert throttled_client.supports_image_generation() is True - - -def test_close_delegates(throttled_client: ThrottledModelClient, inner_client: MagicMock) -> None: - throttled_client.close() - inner_client.close.assert_called_once() - - -@pytest.mark.asyncio(loop_scope="session") -async def test_aclose_delegates(throttled_client: ThrottledModelClient, inner_client: MagicMock) -> None: - await throttled_client.aclose() - inner_client.aclose.assert_awaited_once() - - -# --- Sync: acquire/release on success --- - - -def test_completion_success_releases_success( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = ChatCompletionRequest(model=MODEL_ID, messages=[]) - result = throttled_client.completion(request) - assert result.message.content == "ok" - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 - assert state.success_streak == 1 - - -def test_embeddings_success_releases_success( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = EmbeddingRequest(model=MODEL_ID, inputs=["hello"]) - result = throttled_client.embeddings(request) - assert result.vectors == [[0.1]] - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.EMBEDDING) - assert state is not None - assert state.in_flight == 0 - assert state.success_streak == 1 - - -def test_generate_image_diffusion_uses_image_domain( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = ImageGenerationRequest(model=MODEL_ID, prompt="a cat", messages=None) - throttled_client.generate_image(request) - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.IMAGE) - assert state is not None - assert state.success_streak == 1 - - -def test_generate_image_chat_backed_uses_chat_domain( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = ImageGenerationRequest(model=MODEL_ID, prompt="a cat", messages=[{"role": "user", "content": "draw"}]) - throttled_client.generate_image(request) - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.success_streak == 1 - - -# --- Async: acquire/release on success --- - - -@pytest.mark.asyncio(loop_scope="session") -async def test_acompletion_success_releases_success( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = ChatCompletionRequest(model=MODEL_ID, messages=[]) - result = await throttled_client.acompletion(request) - assert result.message.content == "ok" - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 - assert state.success_streak == 1 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_aembeddings_success_releases_success( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = EmbeddingRequest(model=MODEL_ID, inputs=["hello"]) - result = await throttled_client.aembeddings(request) - assert result.vectors == [[0.1]] - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.EMBEDDING) - assert state is not None - assert state.in_flight == 0 - assert state.success_streak == 1 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_agenerate_image_diffusion_uses_image_domain( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = ImageGenerationRequest(model=MODEL_ID, prompt="a cat", messages=None) - await throttled_client.agenerate_image(request) - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.IMAGE) - assert state is not None - assert state.success_streak == 1 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_agenerate_image_chat_backed_uses_chat_domain( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = ImageGenerationRequest(model=MODEL_ID, prompt="a cat", messages=[{"role": "user", "content": "draw"}]) - await throttled_client.agenerate_image(request) - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.success_streak == 1 - - -# --- Rate-limit error: release_rate_limited with retry_after --- - - -def test_completion_rate_limit_calls_release_rate_limited( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - throttled_client._inner.completion.side_effect = ProviderError( - kind=ProviderErrorKind.RATE_LIMIT, - message="429", - status_code=429, - retry_after=5.0, - ) - with pytest.raises(ProviderError, match="429"): - throttled_client.completion(ChatCompletionRequest(model=MODEL_ID, messages=[])) - - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 - assert state.blocked_until > 0 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_acompletion_rate_limit_calls_release_rate_limited( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - throttled_client._inner.acompletion = AsyncMock( - side_effect=ProviderError( - kind=ProviderErrorKind.RATE_LIMIT, - message="429", - status_code=429, - retry_after=3.0, - ) - ) - with pytest.raises(ProviderError, match="429"): - await throttled_client.acompletion(ChatCompletionRequest(model=MODEL_ID, messages=[])) - - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 - assert state.blocked_until > 0 - - -# --- Non-rate-limit ProviderError: release_failure --- - - -def test_completion_non_rate_limit_error_calls_release_failure( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - throttled_client._inner.completion.side_effect = ProviderError( - kind=ProviderErrorKind.INTERNAL_SERVER, - message="500", - status_code=500, - ) - with pytest.raises(ProviderError, match="500"): - throttled_client.completion(ChatCompletionRequest(model=MODEL_ID, messages=[])) - - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 - assert state.success_streak == 0 - - -# --- Non-ProviderError exception: release_failure --- - - -def test_completion_generic_exception_calls_release_failure( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - throttled_client._inner.completion.side_effect = RuntimeError("boom") - with pytest.raises(RuntimeError, match="boom"): - throttled_client.completion(ChatCompletionRequest(model=MODEL_ID, messages=[])) - - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_acompletion_generic_exception_calls_release_failure( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - throttled_client._inner.acompletion = AsyncMock(side_effect=RuntimeError("boom")) - with pytest.raises(RuntimeError, match="boom"): - await throttled_client.acompletion(ChatCompletionRequest(model=MODEL_ID, messages=[])) - - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 - - -# --- Acquire timeout: normalized to ProviderError(kind=TIMEOUT), no release --- - - -def test_sync_acquire_timeout_normalized_to_provider_error(inner_client: MagicMock) -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL_ID, alias="alias", max_parallel_requests=1) - client = ThrottledModelClient(inner=inner_client, throttle_manager=tm, provider_name=PROVIDER, model_id=MODEL_ID) - - with patch.object(tm, "acquire_sync", side_effect=TimeoutError("timed out")): - with pytest.raises(ProviderError) as exc_info: - client.completion(ChatCompletionRequest(model=MODEL_ID, messages=[])) - assert exc_info.value.kind == ProviderErrorKind.TIMEOUT - - inner_client.completion.assert_not_called() - - -@pytest.mark.asyncio(loop_scope="session") -async def test_async_acquire_timeout_normalized_to_provider_error(inner_client: MagicMock) -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL_ID, alias="alias", max_parallel_requests=1) - client = ThrottledModelClient(inner=inner_client, throttle_manager=tm, provider_name=PROVIDER, model_id=MODEL_ID) - - with patch.object(tm, "acquire_async", side_effect=TimeoutError("timed out")): - with pytest.raises(ProviderError) as exc_info: - await client.acompletion(ChatCompletionRequest(model=MODEL_ID, messages=[])) - assert exc_info.value.kind == ProviderErrorKind.TIMEOUT - - inner_client.acompletion.assert_not_awaited() - - -# --- Cancellation: release_failure on CancelledError --- - - -@pytest.mark.asyncio(loop_scope="session") -async def test_acompletion_cancelled_releases_permit(throttle_manager: ThrottleManager) -> None: - """CancelledError during an in-flight async request releases the throttle permit.""" - blocked = asyncio.Event() - - async def slow_acompletion(_request: ChatCompletionRequest) -> ChatCompletionResponse: - blocked.set() - await asyncio.sleep(60) - return ChatCompletionResponse(message=AssistantMessage(content="ok"), usage=Usage()) - - inner = MagicMock() - inner.provider_name = PROVIDER - inner.acompletion = slow_acompletion - - client = ThrottledModelClient( - inner=inner, throttle_manager=throttle_manager, provider_name=PROVIDER, model_id=MODEL_ID - ) - request = ChatCompletionRequest(model=MODEL_ID, messages=[]) - - task = asyncio.create_task(client.acompletion(request)) - await blocked.wait() - - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 1 - - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task - - assert state.in_flight == 0 - - -# --- E2E: full AIMD feedback loop --- - - -@pytest.mark.asyncio(loop_scope="session") -async def test_aimd_feedback_loop_rate_limit_reduces_then_successes_recover() -> None: - """Verify the full AIMD cycle: success -> rate-limit halves limit -> successes recover. - - Uses a real ThrottleManager with aggressive tuning (success_window=2, - additive_increase=1) so the test can drive a full decrease+increase cycle - with a small number of calls. - - Sequence: - 1. Register model with max_parallel_requests=4. - 2. Make 1 successful async completion -> limit stays 4, streak=1. - 3. Hit a 429 with retry_after=0.01s -> limit halves to 2, cooldown applied. - 4. Wait for cooldown to expire. - 5. Make 2 more successes -> streak reaches success_window=2, limit increases to 3. - 6. Make 2 more successes -> limit increases to 4 (full recovery). - """ - tm = ThrottleManager( - ThrottleConfig( - reduce_factor=0.5, - additive_increase=1, - success_window=2, - cooldown_seconds=0.01, - ) - ) - tm.register(provider_name=PROVIDER, model_id=MODEL_ID, alias="a", max_parallel_requests=4) - - call_count = 0 - rate_limit_on_call = 2 - - async def mock_acompletion(request: ChatCompletionRequest) -> ChatCompletionResponse: - nonlocal call_count - call_count += 1 - if call_count == rate_limit_on_call: - raise ProviderError( - kind=ProviderErrorKind.RATE_LIMIT, - message="429 Too Many Requests", - status_code=429, - retry_after=0.01, - ) - return ChatCompletionResponse(message=AssistantMessage(content="ok"), usage=Usage()) - - inner = MagicMock() - inner.provider_name = PROVIDER - inner.acompletion = mock_acompletion - - client = ThrottledModelClient(inner=inner, throttle_manager=tm, provider_name=PROVIDER, model_id=MODEL_ID) - request = ChatCompletionRequest(model=MODEL_ID, messages=[]) - - def get_state() -> DomainThrottleState: - s = tm.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert s is not None - return s - - # Step 1: first success - await client.acompletion(request) - assert get_state().current_limit == 4 - assert get_state().success_streak == 1 - - # Step 2: 429 -> AIMD decrease - with pytest.raises(ProviderError): - await client.acompletion(request) - assert get_state().current_limit == 2 - assert get_state().success_streak == 0 - assert get_state().in_flight == 0 - - # Step 3: wait for cooldown - await asyncio.sleep(0.02) - - # Step 4: two successes -> additive increase (limit 2 -> 3) - await client.acompletion(request) - assert get_state().success_streak == 1 - await client.acompletion(request) - assert get_state().current_limit == 3 - assert get_state().success_streak == 0 - - # Step 5: two more successes -> additive increase (limit 3 -> 4, full recovery) - await client.acompletion(request) - await client.acompletion(request) - assert get_state().current_limit == 4 - assert get_state().success_streak == 0 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_concurrent_requests_bounded_by_throttle_limit() -> None: - """Concurrent async requests are bounded by the throttle limit. - - Registers a model with max_parallel_requests=2, fires 5 concurrent - acompletion calls that each sleep briefly, and verifies that the - ThrottleManager never had more than 2 in-flight at once. - """ - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL_ID, alias="a", max_parallel_requests=2) - - peak_in_flight = 0 - lock = asyncio.Lock() - - async def mock_acompletion(request: ChatCompletionRequest) -> ChatCompletionResponse: - nonlocal peak_in_flight - state = tm.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - if state is not None: - async with lock: - peak_in_flight = max(peak_in_flight, state.in_flight) - await asyncio.sleep(0.02) - return ChatCompletionResponse(message=AssistantMessage(content="ok"), usage=Usage()) - - inner = MagicMock() - inner.provider_name = PROVIDER - inner.acompletion = mock_acompletion - - client = ThrottledModelClient(inner=inner, throttle_manager=tm, provider_name=PROVIDER, model_id=MODEL_ID) - request = ChatCompletionRequest(model=MODEL_ID, messages=[]) - - tasks = [asyncio.create_task(client.acompletion(request)) for _ in range(5)] - results = await asyncio.gather(*tasks, return_exceptions=True) - - successes = [r for r in results if not isinstance(r, Exception)] - assert len(successes) == 5 - assert peak_in_flight <= 2 - - state = tm.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 diff --git a/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py b/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py new file mode 100644 index 000000000..c0c9a5801 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py @@ -0,0 +1,442 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import logging +import threading +import time + +import pytest + +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.controller import ( + RELEASED_LEASE_HISTORY_LIMIT, + AdaptiveRequestAdmissionController, + RequestAdmissionDenied, + RequestAdmissionError, + RequestAdmissionLease, +) +from data_designer.engine.models.request_admission.outcomes import RequestReleaseOutcome +from data_designer.engine.models.request_admission.resources import ( + RequestAdmissionItem, + RequestDomain, + RequestGroupSpec, + RequestResourceKey, +) +from data_designer.engine.observability import InMemoryAdmissionEventSink + + +def _item(domain: RequestDomain = RequestDomain.CHAT, timeout: float | None = None) -> RequestAdmissionItem: + resource = RequestResourceKey("nvidia", "nemotron", domain) + return RequestAdmissionItem( + resource=resource, + group=RequestGroupSpec(resource), + queue_wait_timeout_seconds=timeout, + ) + + +def _controller(cap: int = 2, config: RequestAdmissionConfig | None = None) -> AdaptiveRequestAdmissionController: + controller = AdaptiveRequestAdmissionController(config) + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=cap) + return controller + + +class _BrokenRequestSink: + def emit_request_event(self, _event: object) -> None: + raise RuntimeError("sink boom") + + +def test_request_admission_acquires_and_releases_lease() -> None: + controller = _controller(cap=1) + item = _item() + + decision = controller.try_acquire(item) + + assert isinstance(decision, RequestAdmissionLease) + assert controller.pressure.snapshot(item.resource).in_flight_count == 1 # type: ignore[union-attr] + result = controller.release(decision, RequestReleaseOutcome(kind="success")) + assert result.released is True + assert controller.pressure.snapshot(item.resource).in_flight_count == 0 # type: ignore[union-attr] + + +def test_request_admission_enforces_provider_model_aggregate_cap() -> None: + controller = _controller(cap=1) + chat = _item(RequestDomain.CHAT) + embedding = _item(RequestDomain.EMBEDDING) + lease = controller.try_acquire(chat) + assert isinstance(lease, RequestAdmissionLease) + + denied = controller.try_acquire(embedding) + + assert isinstance(denied, RequestAdmissionDenied) + assert denied.reason == "no_capacity" + + +def test_request_admission_duplicate_release_does_not_corrupt_counts() -> None: + controller = _controller(cap=1) + item = _item() + lease = controller.try_acquire(item) + assert isinstance(lease, RequestAdmissionLease) + + first = controller.release(lease, RequestReleaseOutcome(kind="success")) + second = controller.release(lease, RequestReleaseOutcome(kind="success")) + + assert first.released is True + assert second.released is False + assert second.reason == "duplicate" + assert controller.pressure.snapshot(item.resource).active_lease_count == 0 # type: ignore[union-attr] + + +def test_request_admission_stale_release_requires_exact_lease() -> None: + controller = _controller(cap=1) + item = _item() + lease = controller.try_acquire(item) + assert isinstance(lease, RequestAdmissionLease) + stale = RequestAdmissionLease( + lease_id=lease.lease_id, + item=lease.item, + acquired_at=lease.acquired_at, + current_adaptive_limit=lease.current_adaptive_limit + 1, + effective_max=lease.effective_max, + controller_generation=lease.controller_generation, + ) + + stale_result = controller.release(stale, RequestReleaseOutcome(kind="provider_failure")) + snapshot = controller.pressure.snapshot(item.resource) + + assert stale_result.released is False + assert stale_result.reason == "stale_lease" + assert snapshot is not None + assert snapshot.in_flight_count == 1 + assert snapshot.active_lease_count == 1 + + released = controller.release(lease, RequestReleaseOutcome(kind="success")) + + assert released.released is True + assert controller.pressure.snapshot(item.resource).active_lease_count == 0 # type: ignore[union-attr] + + +def test_request_admission_rate_limit_decreases_and_sets_cooldown() -> None: + controller = _controller( + cap=4, + config=RequestAdmissionConfig( + multiplicative_decrease_factor=0.5, + cooldown_seconds=10, + ), + ) + item = _item() + lease = controller.try_acquire(item) + assert isinstance(lease, RequestAdmissionLease) + + controller.release(lease, RequestReleaseOutcome(kind="rate_limited", retry_after_seconds=1.0)) + denied = controller.try_acquire(item) + snapshot = controller.pressure.snapshot(item.resource) + + assert isinstance(denied, RequestAdmissionDenied) + assert denied.reason == "cooldown" + assert snapshot is not None + assert snapshot.current_limit == 2 + assert snapshot.cooldown_remaining_seconds > 0 + + +def test_request_admission_rate_limit_burst_decreases_once_per_cascade() -> None: + controller = _controller( + cap=8, + config=RequestAdmissionConfig( + multiplicative_decrease_factor=0.5, + cooldown_seconds=10, + ), + ) + item = _item() + leases = [controller.try_acquire(item) for _ in range(8)] + assert all(isinstance(lease, RequestAdmissionLease) for lease in leases) + + for lease in leases: + controller.release(lease, RequestReleaseOutcome(kind="rate_limited")) + snapshot = controller.pressure.snapshot(item.resource) + + assert snapshot is not None + assert snapshot.current_limit == 4 + assert snapshot.rate_limit_ceiling == 8 + assert snapshot.consecutive_rate_limits == 8 + + +def test_request_admission_fresh_rate_limit_after_burst_decreases_again() -> None: + controller = _controller( + cap=8, + config=RequestAdmissionConfig( + multiplicative_decrease_factor=0.5, + cooldown_seconds=0, + ), + ) + item = _item() + leases = [controller.try_acquire(item) for _ in range(8)] + assert all(isinstance(lease, RequestAdmissionLease) for lease in leases) + + for lease in leases: + controller.release(lease, RequestReleaseOutcome(kind="rate_limited")) + snapshot = controller.pressure.snapshot(item.resource) + assert snapshot is not None + assert snapshot.current_limit == 4 + assert snapshot.rate_limit_ceiling == 8 + + fresh_lease = controller.try_acquire(item) + assert isinstance(fresh_lease, RequestAdmissionLease) + assert fresh_lease.current_adaptive_limit == 4 + + controller.release(fresh_lease, RequestReleaseOutcome(kind="rate_limited")) + snapshot = controller.pressure.snapshot(item.resource) + + assert snapshot is not None + assert snapshot.current_limit == 2 + assert snapshot.rate_limit_ceiling == 8 + assert snapshot.consecutive_rate_limits == 9 + + +def test_request_admission_additive_recovery_after_successes() -> None: + item = _item() + controller = _controller( + cap=3, + config=RequestAdmissionConfig( + initial_limits={item.resource: 1}, + increase_after_successes=1, + additive_increase_step=1, + ), + ) + + lease = controller.try_acquire(item) + assert isinstance(lease, RequestAdmissionLease) + controller.release(lease, RequestReleaseOutcome(kind="success")) + + assert controller.pressure.snapshot(item.resource).current_limit == 2 # type: ignore[union-attr] + + +def test_request_admission_blocking_timeout_raises_typed_error() -> None: + controller = _controller(cap=1) + first = _item() + second = _item(timeout=0.01) + lease = controller.try_acquire(first) + assert isinstance(lease, RequestAdmissionLease) + + with pytest.raises(RequestAdmissionError) as exc_info: + controller.acquire_sync(second) + + assert exc_info.value.decision.reason == "queue_timeout" + + +def test_request_admission_zero_sync_timeout_is_immediate() -> None: + controller = _controller(cap=1) + lease = controller.try_acquire(_item()) + assert isinstance(lease, RequestAdmissionLease) + + with pytest.raises(RequestAdmissionError) as exc_info: + controller.acquire_sync(_item(RequestDomain.EMBEDDING, timeout=0.0)) + + assert exc_info.value.decision.reason == "queue_timeout" + snapshot = controller.pressure.snapshot(RequestResourceKey("nvidia", "nemotron", RequestDomain.EMBEDDING)) + assert snapshot is not None + assert snapshot.waiters == 0 + controller.release(lease, RequestReleaseOutcome(kind="success")) + + +def test_request_admission_sync_unregistered_provider_raises_hard_denial() -> None: + controller = AdaptiveRequestAdmissionController() + + with pytest.raises(RequestAdmissionError) as exc_info: + controller.acquire_sync(_item()) + + assert exc_info.value.decision.reason == "hard_policy_denial" + snapshot = controller.pressure.snapshot(_item().resource) + assert snapshot is not None + assert snapshot.waiters == 0 + + +def test_request_admission_logs_sink_failures(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level(logging.WARNING, logger="data_designer.engine.models.request_admission.controller") + controller = AdaptiveRequestAdmissionController(event_sink=_BrokenRequestSink()) + + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + + assert "Request admission event sink raised; dropping event." in caplog.text + + +def test_request_lease_released_event_records_release_outcome() -> None: + sink = InMemoryAdmissionEventSink() + controller = AdaptiveRequestAdmissionController(event_sink=sink) + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + item = _item() + lease = controller.try_acquire(item) + assert isinstance(lease, RequestAdmissionLease) + + controller.release(lease, RequestReleaseOutcome(kind="provider_failure")) + + release_events = [event for event in sink.request_events if event.event_kind == "request_lease_released"] + assert release_events + assert release_events[-1].reason_or_outcome == "provider_failure" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_acquire_sync_rejects_running_event_loop() -> None: + controller = _controller(cap=1) + + with pytest.raises(RuntimeError, match="would block the running event loop"): + controller.acquire_sync(_item()) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_try_acquire_does_not_bypass_queued_waiter_for_same_provider_model() -> None: + controller = _controller(cap=1) + first = _item(RequestDomain.CHAT) + queued = _item(RequestDomain.EMBEDDING, timeout=2) + incoming = _item(RequestDomain.IMAGE) + lease = controller.try_acquire(first) + assert isinstance(lease, RequestAdmissionLease) + + queued_task = asyncio.create_task(controller.acquire_async(queued)) + await asyncio.sleep(0) + + denied = controller.try_acquire(incoming) + + assert isinstance(denied, RequestAdmissionDenied) + assert denied.reason == "no_capacity" + snapshot = controller.pressure.snapshot(queued.resource) + assert snapshot is not None + assert snapshot.waiters == 1 + controller.release(lease, RequestReleaseOutcome(kind="success")) + queued_lease = await queued_task + controller.release(queued_lease, RequestReleaseOutcome(kind="success")) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_request_admission_zero_async_timeout_is_immediate() -> None: + controller = _controller(cap=1) + lease = controller.try_acquire(_item()) + assert isinstance(lease, RequestAdmissionLease) + + with pytest.raises(RequestAdmissionError) as exc_info: + await controller.acquire_async(_item(RequestDomain.EMBEDDING, timeout=0.0)) + + assert exc_info.value.decision.reason == "queue_timeout" + snapshot = controller.pressure.snapshot(RequestResourceKey("nvidia", "nemotron", RequestDomain.EMBEDDING)) + assert snapshot is not None + assert snapshot.waiters == 0 + controller.release(lease, RequestReleaseOutcome(kind="success")) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_acquire_async_does_not_assign_expired_waiter_after_release( + monkeypatch: pytest.MonkeyPatch, +) -> None: + controller = _controller(cap=1) + monkeypatch.setattr(controller, "_wait_seconds_locked", lambda _item, _now, _deadline: 10.0) + lease = controller.try_acquire(_item(RequestDomain.CHAT)) + assert isinstance(lease, RequestAdmissionLease) + queued = _item(RequestDomain.EMBEDDING, timeout=0.01) + + queued_task = asyncio.create_task(controller.acquire_async(queued)) + for _ in range(20): + snapshot = controller.pressure.snapshot(queued.resource) + if snapshot is not None and snapshot.waiters == 1: + break + await asyncio.sleep(0) + else: + raise AssertionError("async waiter did not enqueue") + + def release_after_deadline() -> None: + time.sleep(0.03) + controller.release(lease, RequestReleaseOutcome(kind="success")) + + release_thread = threading.Thread(target=release_after_deadline) + release_thread.start() + try: + time.sleep(0.06) + with pytest.raises(RequestAdmissionError) as exc_info: + await asyncio.wait_for(queued_task, timeout=0.5) + finally: + release_thread.join() + + assert exc_info.value.decision.reason == "queue_timeout" + snapshot = controller.pressure.snapshot(queued.resource) + assert snapshot is not None + assert snapshot.waiters == 0 + assert snapshot.active_lease_count == 0 + assert snapshot.in_flight_count == 0 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_acquire_async_wakes_when_release_assigns_lease(monkeypatch: pytest.MonkeyPatch) -> None: + controller = _controller(cap=1) + monkeypatch.setattr(controller, "_wait_seconds_locked", lambda _item, _now, _deadline: 10.0) + lease = controller.try_acquire(_item(RequestDomain.CHAT)) + assert isinstance(lease, RequestAdmissionLease) + queued = _item(RequestDomain.EMBEDDING, timeout=30.0) + + queued_task = asyncio.create_task(controller.acquire_async(queued)) + for _ in range(20): + snapshot = controller.pressure.snapshot(queued.resource) + if snapshot is not None and snapshot.waiters == 1: + break + await asyncio.sleep(0) + else: + raise AssertionError("async waiter did not enqueue") + + controller.release(lease, RequestReleaseOutcome(kind="success")) + queued_lease = await asyncio.wait_for(queued_task, timeout=0.5) + + controller.release(queued_lease, RequestReleaseOutcome(kind="success")) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_acquire_async_unregistered_provider_raises_hard_denial(monkeypatch: pytest.MonkeyPatch) -> None: + controller = AdaptiveRequestAdmissionController() + monkeypatch.setattr(controller, "_wait_seconds_locked", lambda _item, _now, _deadline: 10.0) + queued = _item(RequestDomain.CHAT, timeout=30.0) + + with pytest.raises(RequestAdmissionError) as exc_info: + await asyncio.wait_for(controller.acquire_async(queued), timeout=0.5) + + assert exc_info.value.decision.reason == "hard_policy_denial" + snapshot = controller.pressure.snapshot(queued.resource) + assert snapshot is not None + assert snapshot.waiters == 0 + + +def test_request_admission_released_history_is_bounded() -> None: + controller = _controller(cap=1) + first_lease: RequestAdmissionLease | None = None + for _ in range(RELEASED_LEASE_HISTORY_LIMIT + 5): + lease = controller.try_acquire(_item()) + assert isinstance(lease, RequestAdmissionLease) + first_lease = first_lease or lease + controller.release(lease, RequestReleaseOutcome(kind="success")) + + assert len(controller._released) == RELEASED_LEASE_HISTORY_LIMIT + assert len(controller._released_order) == RELEASED_LEASE_HISTORY_LIMIT + assert controller._released_order.maxlen == RELEASED_LEASE_HISTORY_LIMIT + assert first_lease is not None + assert controller.release(first_lease, RequestReleaseOutcome(kind="success")).reason == "unknown_lease" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_async_cancellation_after_waiter_assignment_releases_lease() -> None: + controller = _controller(cap=1) + first = _item(RequestDomain.CHAT) + queued = _item(RequestDomain.EMBEDDING, timeout=1.0) + lease = controller.try_acquire(first) + assert isinstance(lease, RequestAdmissionLease) + + queued_task = asyncio.create_task(controller.acquire_async(queued)) + await asyncio.sleep(0) + controller.release(lease, RequestReleaseOutcome(kind="success")) + queued_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await queued_task + + snapshot = controller.pressure.snapshot(queued.resource) + assert snapshot is not None + assert snapshot.active_lease_count == 0 + assert snapshot.in_flight_count == 0 + assert snapshot.waiters == 0 diff --git a/packages/data-designer-engine/tests/engine/test_capacity.py b/packages/data-designer-engine/tests/engine/test_capacity.py new file mode 100644 index 000000000..6fd4daa33 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/test_capacity.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from data_designer.engine.capacity import ( + AsyncCapacityConfigured, + AsyncCapacityObservedMaxima, + AsyncCapacityPlan, + AsyncCapacityRuntimeSnapshot, + CapacityValue, + RequestAdmissionConfigSnapshot, + RowGroupAdmission, +) +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.resources import RequestDomain, RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey, ProviderModelStaticCap + + +def test_request_admission_config_snapshot_records_resources() -> None: + resource = RequestResourceKey("nvidia", "nemotron", RequestDomain.CHAT) + config = RequestAdmissionConfig(initial_limits={resource: 2}, max_limit_clamps={resource: 4}) + + snapshot = RequestAdmissionConfigSnapshot.from_config(config) + + assert snapshot.resources == (resource,) + assert snapshot.initial_limits[resource] == 2 + assert snapshot.max_limit_clamps[resource] == 4 + + +def test_async_capacity_plan_records_configured_runtime_and_maxima() -> None: + resource = RequestResourceKey("nvidia", "nemotron", RequestDomain.CHAT) + provider_model = ProviderModelKey("nvidia", "nemotron") + static_cap = ProviderModelStaticCap(cap=4, aliases=("default",), raw_caps={"default": 4}) + + plan = AsyncCapacityPlan( + configured=AsyncCapacityConfigured( + buffer_size=CapacityValue(value=16, source="run_config"), + row_group_admission=RowGroupAdmission( + row_group_concurrency=CapacityValue(value=2, source="dataset_builder"), + observed_in_flight=1, + ), + submission_capacity=CapacityValue(value=8, source="engine_internal_config"), + task_resource_limits=CapacityValue(value={"submission": 8, "llm_wait": 4}, source="engine_internal_config"), + request_resources=CapacityValue(value=(resource,), source="runtime_snapshot"), + provider_model_static_caps=CapacityValue(value={provider_model: static_cap}, source="model_metadata"), + request_domain_initial_limits=CapacityValue(value={resource: 2}, source="engine_internal_config"), + request_admission_config=CapacityValue( + value=RequestAdmissionConfigSnapshot.from_config(RequestAdmissionConfig(initial_limits={resource: 2})), + source="engine_internal_config", + ), + transport_pool_limits=CapacityValue(value={provider_model: 8}, source="adapter_config"), + ), + runtime_snapshot=AsyncCapacityRuntimeSnapshot( + request_domain_current_limits={resource: 2}, + request_domain_effective_max={resource: 4}, + request_domain_blocked_until={resource: None}, + provider_model_aggregate_in_flight={provider_model: 0}, + ), + observed_maxima=AsyncCapacityObservedMaxima( + row_groups_in_flight=1, + request_in_flight_by_resource={resource: 2}, + provider_model_aggregate_in_flight={provider_model: 2}, + ), + ) + + assert plan.configured.provider_model_static_caps.value[provider_model].merge_rule == "min_same_endpoint" + assert plan.runtime_snapshot.request_domain_current_limits[resource] == 2 + assert plan.observed_maxima.provider_model_aggregate_in_flight[provider_model] == 2 diff --git a/packages/data-designer-engine/tests/engine/test_observability.py b/packages/data-designer-engine/tests/engine/test_observability.py new file mode 100644 index 000000000..55a2896b1 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/test_observability.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from data_designer.engine.observability import ( + CorrelatedRuntimeView, + InMemoryAdmissionEventSink, + RequestAdmissionEvent, + RuntimeCorrelation, + RuntimeCorrelationProvider, + SchedulerAdmissionEvent, +) + + +def _correlation() -> RuntimeCorrelation: + return RuntimeCorrelation( + run_id="run", + row_group=1, + task_column="answer", + task_type="cell", + scheduling_group_kind="model", + scheduling_group_identity_hash="hash", + task_execution_id="task-exec", + ) + + +def test_runtime_correlation_provider_sets_and_resets_context() -> None: + provider = RuntimeCorrelationProvider() + correlation = _correlation() + + token = provider.set(correlation) + assert provider.current() == correlation + + provider.reset(token) + assert provider.current() is None + + +def test_admission_events_capture_correlation_and_diagnostics() -> None: + correlation = _correlation() + + scheduler_event = SchedulerAdmissionEvent.capture( + "task_lease_acquired", + sequence=1, + correlation=correlation, + task_id="task-1", + task_lease_id="lease-1", + diagnostics={"resource": "submission"}, + ) + request_event = RequestAdmissionEvent.capture( + "request_lease_acquired", + sequence=2, + correlation=correlation, + request_attempt_id="request-1", + request_lease_id="lease-2", + diagnostics={"resource": "chat"}, + ) + + assert scheduler_event.captured_correlation == correlation + assert scheduler_event.task_id == "task-1" + assert scheduler_event.diagnostics == {"resource": "submission"} + assert request_event.captured_correlation == correlation + assert request_event.request_attempt_id == "request-1" + assert request_event.diagnostics == {"resource": "chat"} + + +def test_in_memory_admission_event_sink_collects_scheduler_and_request_events() -> None: + sink = InMemoryAdmissionEventSink() + scheduler_event = SchedulerAdmissionEvent.capture("selected", sequence=1) + request_event = RequestAdmissionEvent.capture("request_wait_started", sequence=2) + + sink.emit_scheduler_event(scheduler_event) + sink.emit_request_event(request_event) + + assert sink.scheduler_events == [scheduler_event] + assert sink.request_events == [request_event] + + +def test_correlated_runtime_view_timeline_sorts_events() -> None: + scheduler_event = SchedulerAdmissionEvent(event_kind="selected", captured_at_monotonic=2.0, sequence=1) + first_request_event = RequestAdmissionEvent( + event_kind="request_wait_started", + captured_at_monotonic=1.0, + sequence=3, + ) + second_request_event = RequestAdmissionEvent( + event_kind="request_lease_acquired", + captured_at_monotonic=2.0, + sequence=0, + ) + view = CorrelatedRuntimeView( + scheduler_events=(scheduler_event,), + request_events=(first_request_event, second_request_event), + ) + + assert view.timeline == (first_request_event, second_request_event, scheduler_event) diff --git a/packages/data-designer/src/data_designer/interface/results.py b/packages/data-designer/src/data_designer/interface/results.py index 29f222279..c90070c58 100644 --- a/packages/data-designer/src/data_designer/interface/results.py +++ b/packages/data-designer/src/data_designer/interface/results.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: import pandas as pd - from data_designer.engine.dataset_builders.utils.task_model import TaskTrace + from data_designer.engine.dataset_builders.scheduling.task_model import TaskTrace ExportFormat = Literal["jsonl", "csv", "parquet"] SUPPORTED_EXPORT_FORMATS: tuple[str, ...] = get_args(ExportFormat) diff --git a/tests_e2e/tests/test_mcp_demo.py b/tests_e2e/tests/test_mcp_demo.py index 163e904cb..d7bca2e9e 100644 --- a/tests_e2e/tests/test_mcp_demo.py +++ b/tests_e2e/tests/test_mcp_demo.py @@ -101,25 +101,25 @@ def test_mcp_server_tool_usage_with_nvidia_text(tmp_path: Path) -> None: assert tool_call_messages tool_calls: list[dict[str, object]] = [] - tool_call_indices: dict[str, int] = {} + tool_call_positions: dict[str, tuple[int, int]] = {} for msg_index, msg in enumerate(trace): if not isinstance(msg, dict): continue if msg.get("role") != "assistant": continue - for tool_call in msg.get("tool_calls") or []: + for tool_call_index, tool_call in enumerate(msg.get("tool_calls") or []): if not isinstance(tool_call, dict): continue tool_calls.append(tool_call) function = tool_call.get("function") or {} if isinstance(function, dict): name = function.get("name") - if isinstance(name, str) and name not in tool_call_indices: - tool_call_indices[name] = msg_index + if isinstance(name, str) and name not in tool_call_positions: + tool_call_positions[name] = (msg_index, tool_call_index) - assert tool_call_indices.get("get_fact") is not None - assert tool_call_indices.get("add_numbers") is not None - assert tool_call_indices["get_fact"] < tool_call_indices["add_numbers"] + assert tool_call_positions.get("get_fact") is not None + assert tool_call_positions.get("add_numbers") is not None + assert tool_call_positions["get_fact"] < tool_call_positions["add_numbers"] def _tool_call_to_name_args(tool_call: dict[str, object]) -> tuple[str | None, dict[str, object]]: function = tool_call.get("function")