diff --git a/nemo_curator/core/client.py b/nemo_curator/core/client.py index 10facab1a2..b85858ae3c 100644 --- a/nemo_curator/core/client.py +++ b/nemo_curator/core/client.py @@ -60,6 +60,8 @@ class RayClient: Args: ray_port: The port number of the Ray GCS. ray_dashboard_port: The port number of the Ray dashboard. + ray_min_worker_port: The first worker port Ray may bind. + ray_max_worker_port: The last worker port Ray may bind. ray_temp_dir: The temporary directory to use for Ray. include_dashboard: Whether to include dashboard integration. If true, adds Ray metrics service discovery. ray_metrics_port: The port number of the Ray metrics. @@ -79,6 +81,8 @@ class RayClient: ray_port: int = DEFAULT_RAY_PORT ray_dashboard_port: int = DEFAULT_RAY_DASHBOARD_PORT ray_client_server_port: int = DEFAULT_RAY_CLIENT_SERVER_PORT + ray_min_worker_port: int | None = None + ray_max_worker_port: int | None = None ray_temp_dir: str = DEFAULT_RAY_TEMP_DIR include_dashboard: bool = True ray_metrics_port: int = DEFAULT_RAY_METRICS_PORT @@ -155,6 +159,8 @@ def start(self) -> None: ray_metrics_port=self.ray_metrics_port, ray_client_server_port=self.ray_client_server_port, ray_dashboard_host=self.ray_dashboard_host, + ray_min_worker_port=self.ray_min_worker_port, + ray_max_worker_port=self.ray_max_worker_port, num_gpus=self.num_gpus, num_cpus=self.num_cpus, object_store_memory=self.object_store_memory, diff --git a/nemo_curator/core/serve/dynamo/backend.py b/nemo_curator/core/serve/dynamo/backend.py index 0ed4ee6dbd..36003f4a06 100644 --- a/nemo_curator/core/serve/dynamo/backend.py +++ b/nemo_curator/core/serve/dynamo/backend.py @@ -290,7 +290,7 @@ def _resolve_effective_router( - ``mode``: honor ``router.mode`` if set; otherwise auto-pick ``"kv"`` when any model uses ``mode="disagg"``, else leave unset so the - Dynamo frontend falls back to its own ``round_robin`` default. + Dynamo frontend falls back to its own ``round-robin`` default. - ``kv_events``: when we auto-pick ``mode="kv"`` we also auto-enable ``kv_events`` so the router consumes what prefill workers publish unconditionally in disagg. If the user set ``router.mode`` explicitly diff --git a/nemo_curator/core/serve/dynamo/config.py b/nemo_curator/core/serve/dynamo/config.py index 3422b40340..708bcfd529 100644 --- a/nemo_curator/core/serve/dynamo/config.py +++ b/nemo_curator/core/serve/dynamo/config.py @@ -36,26 +36,41 @@ def __post_init__(self) -> None: raise ValueError(msg) +DynamoRouterMode = Literal[ + "round-robin", + "round_robin", + "random", + "power-of-two", + "kv", + "direct", + "least-loaded", + "device-aware-weighted", +] + + @dataclass class DynamoRouterConfig: """Frontend router config for Dynamo. ``mode=None`` means "auto": Curator picks ``"kv"`` if any model uses ``mode="disagg"``, else leaves ``--router-mode`` unset so the Dynamo - frontend falls back to its own ``round_robin`` default. ``kv_events`` + frontend falls back to its own ``round-robin`` default. ``kv_events`` only applies when ``mode == "kv"``: pass ``kv_events=True`` to opt into exact ZMQ KV-cache event publishing; the default uses the router's approximate tree-based tracking. Anything else is forwarded to the Dynamo frontend as CLI args via ``router_kwargs``. """ - mode: Literal["round_robin", "random", "kv", "direct"] | None = None + mode: DynamoRouterMode | None = None kv_events: bool = False router_kwargs: dict[str, Any] = field(default_factory=dict) _RESERVED_ROUTER_KWARGS: ClassVar[frozenset[str]] = frozenset({"router_mode", "router_kv_events"}) + _MODE_ALIASES: ClassVar[dict[str, str]] = {"round_robin": "round-robin"} def __post_init__(self) -> None: + if self.mode is not None: + self.mode = self._MODE_ALIASES.get(self.mode, self.mode) # type: ignore[assignment] if self.mode is not None and self.mode != "kv" and self.kv_events: msg = f"kv_events=True is only meaningful when mode='kv'; got mode={self.mode!r}." raise ValueError(msg) diff --git a/nemo_curator/core/serve/dynamo/vllm.py b/nemo_curator/core/serve/dynamo/vllm.py index f6bfcae1e3..eda1961bcb 100644 --- a/nemo_curator/core/serve/dynamo/vllm.py +++ b/nemo_curator/core/serve/dynamo/vllm.py @@ -17,6 +17,7 @@ from __future__ import annotations import json +import os import tempfile from functools import reduce from pathlib import Path @@ -67,12 +68,19 @@ "config": {"setup_timeout_seconds": 600}, } +_USE_DRIVER_ENV_VAR = "NEMO_CURATOR_DYNAMO_USE_DRIVER_ENV" + @ray.remote def _write_actor_overrides_file(path: str, body: str) -> None: Path(path).write_text(body) +def _use_driver_env_for_dynamo() -> bool: + """Return true when Dynamo actors should use the driver's Python env.""" + return os.environ.get(_USE_DRIVER_ENV_VAR, "0").lower() in {"1", "true", "yes", "on"} + + def ensure_actor_overrides_on_all_nodes(*, ignore_head_node: bool = False) -> None: """Write the actor-venv ``--override`` file at a fixed path on every alive node. @@ -109,6 +117,8 @@ def ensure_actor_overrides_on_all_nodes(*, ignore_head_node: bool = False) -> No def dynamo_runtime_env(model_config: DynamoVLLMModelConfig) -> dict[str, Any]: """Merge the user's ``runtime_env`` with the Dynamo-vLLM package pin.""" + if _use_driver_env_for_dynamo(): + return model_config.runtime_env or {} return BaseModelConfig.merge_runtime_envs(DYNAMO_VLLM_RUNTIME_ENV, model_config.runtime_env or None) @@ -116,6 +126,8 @@ def merge_model_runtime_envs(models: list[DynamoVLLMModelConfig]) -> dict[str, A """Merge every model's ``runtime_env`` onto the Dynamo-vLLM pin for the shared frontend actor.""" envs = [m.runtime_env for m in models if m.runtime_env] user_merged = reduce(BaseModelConfig.merge_runtime_envs, envs) if envs else None + if _use_driver_env_for_dynamo(): + return user_merged or {} return BaseModelConfig.merge_runtime_envs(DYNAMO_VLLM_RUNTIME_ENV, user_merged) diff --git a/nemo_curator/core/serve/ray_serve/backend.py b/nemo_curator/core/serve/ray_serve/backend.py index f7da6f21aa..f6b7c5e1a6 100644 --- a/nemo_curator/core/serve/ray_serve/backend.py +++ b/nemo_curator/core/serve/ray_serve/backend.py @@ -70,11 +70,17 @@ def _deploy(self) -> None: llm_configs = [self._to_llm_config(model, quiet_runtime_env=quiet_env) for model in server.models] build_args: dict[str, Any] = {"llm_configs": llm_configs} + ingress_deployment_config = dict(server.backend.ingress_deployment_config) if quiet_env: # Suppress access logs on the OpenAI ingress deployment too. - build_args["ingress_deployment_config"] = { - "ray_actor_options": {"runtime_env": quiet_env}, - } + ray_actor_options = dict(ingress_deployment_config.get("ray_actor_options", {})) + ray_actor_options["runtime_env"] = BaseModelConfig.merge_runtime_envs( + ray_actor_options.get("runtime_env", {}), + quiet_env, + ) + ingress_deployment_config["ray_actor_options"] = ray_actor_options + if ingress_deployment_config: + build_args["ingress_deployment_config"] = ingress_deployment_config from ray import serve from ray.serve.llm import build_openai_app diff --git a/nemo_curator/core/serve/ray_serve/config.py b/nemo_curator/core/serve/ray_serve/config.py index cec5e1d7cb..321c79154f 100644 --- a/nemo_curator/core/serve/ray_serve/config.py +++ b/nemo_curator/core/serve/ray_serve/config.py @@ -31,3 +31,4 @@ class RayServeServerConfig(BaseServerConfig): """Server-level Ray Serve config.""" model_configs: ClassVar[tuple[type[BaseModelConfig], ...]] = (RayServeModelConfig,) + ingress_deployment_config: dict[str, Any] = field(default_factory=dict) diff --git a/nemo_curator/core/utils.py b/nemo_curator/core/utils.py index f36671116a..200cffed3a 100644 --- a/nemo_curator/core/utils.py +++ b/nemo_curator/core/utils.py @@ -139,6 +139,8 @@ def init_cluster( # noqa: PLR0913 ray_metrics_port: int, ray_client_server_port: int, ray_dashboard_host: str, + ray_min_worker_port: int | None = None, + ray_max_worker_port: int | None = None, num_gpus: int | None = None, num_cpus: int | None = None, object_store_memory: int | None = None, @@ -164,6 +166,10 @@ def init_cluster( # noqa: PLR0913 ray_command.extend(["--dashboard-port", str(ray_dashboard_port)]) ray_command.extend(["--ray-client-server-port", str(ray_client_server_port)]) ray_command.extend(["--temp-dir", ray_temp_dir]) + if ray_min_worker_port is not None: + ray_command.extend(["--min-worker-port", str(ray_min_worker_port)]) + if ray_max_worker_port is not None: + ray_command.extend(["--max-worker-port", str(ray_max_worker_port)]) if object_store_memory is not None: ray_command.extend(["--object-store-memory", str(object_store_memory)]) ray_command.extend(["--disable-usage-stats"]) diff --git a/nemo_curator/models/client/llm_client.py b/nemo_curator/models/client/llm_client.py index d406cbed84..2f6532459e 100644 --- a/nemo_curator/models/client/llm_client.py +++ b/nemo_curator/models/client/llm_client.py @@ -15,11 +15,14 @@ import asyncio import secrets from abc import ABC, abstractmethod -from collections.abc import Iterable +from collections.abc import Awaitable, Callable, Iterable from dataclasses import dataclass +from typing import TypeVar from loguru import logger +T = TypeVar("T") + class ConversationFormatter(ABC): """ @@ -116,23 +119,15 @@ async def _query_model_impl( msg = "Subclass of AsyncLLMClient must implement '_query_model_impl'" raise NotImplementedError(msg) - async def query_model( # noqa: C901, PLR0912 - self, - *, - messages: Iterable, - model: str, - conversation_formatter: ConversationFormatter | None = None, - generation_config: GenerationConfig | dict | None = None, - ) -> list[str]: - """ - Query the model with automatic retry and concurrency control. - """ - # Use default config if none provided + @staticmethod + def _coerce_generation_config(generation_config: GenerationConfig | dict | None) -> GenerationConfig: if generation_config is None: - generation_config = GenerationConfig() - elif isinstance(generation_config, dict): - generation_config = GenerationConfig(**generation_config) + return GenerationConfig() + if isinstance(generation_config, dict): + return GenerationConfig(**generation_config) + return generation_config + async def _run_with_retry_and_concurrency(self, operation: Callable[[], Awaitable[T]]) -> T: # noqa: C901, PLR0912 # Initialize semaphore if not already done or if we're in a different event loop current_loop = asyncio.get_running_loop() if self._semaphore is None or self._semaphore_loop != current_loop: @@ -179,12 +174,7 @@ async def query_model( # noqa: C901, PLR0912 # Attempt the query try: - return await self._query_model_impl( - messages=messages, - model=model, - conversation_formatter=conversation_formatter, - generation_config=generation_config, - ) + return await operation() except Exception as e: last_exception = e # If this is the last attempt, provide helpful error message @@ -208,7 +198,27 @@ async def query_model( # noqa: C901, PLR0912 raise last_exception # This should never be reached, but add explicit return for linter - logger.warning( - "Unexpected code path: AsyncLLMClient.query_model completed without returning a result or raising an exception" + msg = "Unexpected code path: AsyncLLMClient operation completed without returning a result or raising" + raise RuntimeError(msg) + + async def query_model( + self, + *, + messages: Iterable, + model: str, + conversation_formatter: ConversationFormatter | None = None, + generation_config: GenerationConfig | dict | None = None, + ) -> list[str]: + """ + Query the model with automatic retry and concurrency control. + """ + # Use default config if none provided + generation_config = self._coerce_generation_config(generation_config) + return await self._run_with_retry_and_concurrency( + lambda: self._query_model_impl( + messages=messages, + model=model, + conversation_formatter=conversation_formatter, + generation_config=generation_config, ) - return [] + ) diff --git a/nemo_curator/models/client/openai_client.py b/nemo_curator/models/client/openai_client.py index 3ca232fa1e..96fd6ce398 100644 --- a/nemo_curator/models/client/openai_client.py +++ b/nemo_curator/models/client/openai_client.py @@ -14,6 +14,8 @@ import warnings from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any from loguru import logger from openai import AsyncOpenAI, OpenAI @@ -21,6 +23,16 @@ from nemo_curator.models.client.llm_client import AsyncLLMClient, ConversationFormatter, GenerationConfig, LLMClient +@dataclass(frozen=True) +class OpenAIChatCompletionResult: + """OpenAI-compatible chat completion content and aggregate usage.""" + + contents: list[str] + prompt_tokens: int | None = None + completion_tokens: int | None = None + total_tokens: int | None = None + + class OpenAIClient(LLMClient): """ A wrapper around OpenAI's Python client for querying models @@ -45,6 +57,21 @@ def query_model( conversation_formatter: ConversationFormatter | None = None, generation_config: GenerationConfig | dict | None = None, ) -> list[str]: + return self.query_model_with_usage( + messages=messages, + model=model, + conversation_formatter=conversation_formatter, + generation_config=generation_config, + ).contents + + def query_model_with_usage( + self, + *, + messages: Iterable, + model: str, + conversation_formatter: ConversationFormatter | None = None, + generation_config: GenerationConfig | dict | None = None, + ) -> OpenAIChatCompletionResult: if conversation_formatter is not None: warnings.warn("conversation_formatter is not used in an OpenAIClient", stacklevel=2) @@ -80,7 +107,7 @@ def query_model( response = self.client.chat.completions.create(**create_kwargs) - return [choice.message.content for choice in response.choices] + return _completion_result_from_response(response) class AsyncOpenAIClient(AsyncLLMClient): @@ -122,6 +149,25 @@ async def _query_model_impl( """ Internal implementation of query_model without retry/concurrency logic. """ + result = await self._query_model_with_usage_impl( + messages=messages, + model=model, + conversation_formatter=conversation_formatter, + generation_config=generation_config, + ) + return result.contents + + async def _query_model_with_usage_impl( + self, + *, + messages: Iterable, + model: str, + conversation_formatter: ConversationFormatter | None = None, + generation_config: GenerationConfig | dict | None = None, + ) -> OpenAIChatCompletionResult: + """ + Internal implementation of query_model_with_usage without retry/concurrency logic. + """ if conversation_formatter is not None: warnings.warn("conversation_formatter is not used in an AsyncOpenAIClient", stacklevel=2) @@ -157,4 +203,50 @@ async def _query_model_impl( response = await self.client.chat.completions.create(**create_kwargs) - return [choice.message.content for choice in response.choices] + return _completion_result_from_response(response) + + async def query_model_with_usage( + self, + *, + messages: Iterable, + model: str, + conversation_formatter: ConversationFormatter | None = None, + generation_config: GenerationConfig | dict | None = None, + ) -> OpenAIChatCompletionResult: + """ + Query the model and keep OpenAI-compatible usage counters when the server returns them. + """ + generation_config = self._coerce_generation_config(generation_config) + return await self._run_with_retry_and_concurrency( + lambda: self._query_model_with_usage_impl( + messages=messages, + model=model, + conversation_formatter=conversation_formatter, + generation_config=generation_config, + ) + ) + + +def _completion_result_from_response(response: Any) -> OpenAIChatCompletionResult: # noqa: ANN401 + usage = getattr(response, "usage", None) + return OpenAIChatCompletionResult( + contents=[choice.message.content for choice in response.choices], + prompt_tokens=_usage_int(usage, "prompt_tokens"), + completion_tokens=_usage_int(usage, "completion_tokens"), + total_tokens=_usage_int(usage, "total_tokens"), + ) + + +def _usage_int(usage: Any, field: str) -> int | None: # noqa: ANN401 + if usage is None: + return None + value = usage.get(field) if isinstance(usage, dict) else getattr(usage, field, None) + if isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float) and value.is_integer(): + return int(value) + if isinstance(value, str) and value.isdigit(): + return int(value) + return None diff --git a/nemo_curator/stages/text/experimental/dripper/__init__.py b/nemo_curator/stages/text/experimental/dripper/__init__.py new file mode 100644 index 0000000000..a356740083 --- /dev/null +++ b/nemo_curator/stages/text/experimental/dripper/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dripper/MinerU-HTML HTML content extraction stages for NeMo Curator.""" + +from nemo_curator.stages.text.experimental.dripper._base_stages import ( + DripperHTMLExtractionStage, + DripperHTMLInferenceStage, + DripperHTMLPostprocessStage, + DripperHTMLPreprocessStage, +) +from nemo_curator.stages.text.experimental.dripper.layout_template import DripperHTMLLayoutTemplateStage +from nemo_curator.stages.text.experimental.dripper.workflow import DripperHTMLWorkflow + +__all__ = [ + "DripperHTMLExtractionStage", + "DripperHTMLInferenceStage", + "DripperHTMLLayoutTemplateStage", + "DripperHTMLPostprocessStage", + "DripperHTMLPreprocessStage", + "DripperHTMLWorkflow", # main user entry point +] diff --git a/nemo_curator/stages/text/experimental/dripper/_base_stages.py b/nemo_curator/stages/text/experimental/dripper/_base_stages.py new file mode 100644 index 0000000000..34845b228a --- /dev/null +++ b/nemo_curator/stages/text/experimental/dripper/_base_stages.py @@ -0,0 +1,1003 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MinerU-HTML extraction stages: DripperHTMLExtractionStage, PreprocessStage, InferenceStage, PostprocessStage.""" + +from __future__ import annotations + +import asyncio +import time +from collections import defaultdict +from dataclasses import dataclass, field, replace +from typing import TYPE_CHECKING, Any, Literal, Protocol, runtime_checkable + +import pandas as pd +from loguru import logger + +from nemo_curator.models.client.llm_client import GenerationConfig +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.text.experimental.translation.utils.async_utils import run_async_safe +from nemo_curator.tasks import DocumentBatch + +if TYPE_CHECKING: + from nemo_curator.backends.base import WorkerMetadata + from nemo_curator.models.client.llm_client import AsyncLLMClient + +from nemo_curator.stages.text.experimental.dripper.stage import ( + _DRIPPER_EMPTY_INPUT_COL, + _DRIPPER_LAYOUT_FINALIZED_COL, + _DRIPPER_NEEDS_LLM_COL, + _DRIPPER_PRIMARY_ERROR_COL, + _DRIPPER_PROMPT_COL, + _STRUCTURED_OUTPUT_MODES, + _append_warning, + _apply_fallback_extraction, + _case_has_item_ids, + _coerce_html, + _coerce_optional_str, + _coerce_usage_int, + _count_item_ids, + _DripperInferenceResult, + _DripperPostResult, + _DripperPrepResult, + _DripperRowResult, + _generation_config_for_item_count, + _get_processed_attr, + _is_empty_document_error, + _load_mineru_html_bindings, + _MinerUHTMLBindings, + _numeric_series_or_zero, + _query_dripper_model, + _rebuild_batch, + _run_dripper_health_check, + _sanitize_case_output_html, + _with_structured_output_config, +) + + +def _col_str_list(df: pd.DataFrame, col: str, n: int) -> list[str]: + return df[col].astype(str).tolist() if col in df else [""] * n + + +def _col_float_list(df: pd.DataFrame, col: str, n: int) -> list[float]: + return pd.to_numeric(df[col], errors="coerce").fillna(0.0).tolist() if col in df else [0.0] * n + + +def _col_int_list(df: pd.DataFrame, col: str, n: int) -> list[int]: + return pd.to_numeric(df[col], errors="coerce").fillna(0).astype(int).tolist() if col in df else [0] * n + + +@runtime_checkable +class _HasDynamicTokenParams(Protocol): + dynamic_max_token_padding: int + dynamic_max_tokens_per_item: int + dynamic_min_max_tokens: int + + +@runtime_checkable +class _HasLLMClientParams(Protocol): + client: Any + model_name: str + max_concurrent_requests: int + structured_output_mode: str + + +def _validate_dynamic_token_params(obj: _HasDynamicTokenParams) -> None: + if obj.dynamic_max_token_padding < 0: + msg = "dynamic_max_token_padding must be non-negative" + raise ValueError(msg) + if obj.dynamic_max_tokens_per_item <= 0: + msg = "dynamic_max_tokens_per_item must be positive" + raise ValueError(msg) + if obj.dynamic_min_max_tokens <= 0: + msg = "dynamic_min_max_tokens must be positive" + raise ValueError(msg) + + +def _validate_llm_client_params(obj: _HasLLMClientParams, class_name: str) -> None: + if obj.client is None: + msg = f"{class_name} requires a non-None 'client' (AsyncLLMClient)" + raise ValueError(msg) + obj.model_name = obj.model_name.strip() + if not obj.model_name: + msg = f"{class_name} requires a non-empty 'model_name'" + raise ValueError(msg) + if obj.max_concurrent_requests <= 0: + msg = "max_concurrent_requests must be positive" + raise ValueError(msg) + if obj.structured_output_mode not in _STRUCTURED_OUTPUT_MODES: + msg = f"structured_output_mode must be one of {sorted(_STRUCTURED_OUTPUT_MODES)}" + raise ValueError(msg) + + +def _apply_conversion_to_row_result(case: object, base: _DripperRowResult, conversion_error: str) -> _DripperRowResult: + output_data = getattr(case, "output_data", None) + main_html = getattr(output_data, "main_html", "") if output_data is not None else "" + main_content = getattr(output_data, "main_content", "") or "" + warning = base.warning + error = "" + if conversion_error: + if _is_empty_document_error(conversion_error) and not str(main_html).strip(): + warning = _append_warning(warning, conversion_error) + else: + error = conversion_error + return replace(base, main_html=main_html, main_content=main_content, error=error, warning=warning) + + +@dataclass(kw_only=True) +class _DripperColumnsMixin: + html_col: str = "html" + url_col: str | None = "url" + raw_response_col: str = "dripper_response" + preprocess_time_col: str = "dripper_preprocess_time_s" + inference_time_col: str = "dripper_inference_time_s" + postprocess_time_col: str = "dripper_postprocess_time_s" + total_time_col: str = "dripper_time_s" + error_col: str = "dripper_error" + warning_col: str = "dripper_warning" + item_count_col: str = "dripper_item_count" + prompt_chars_col: str = "dripper_prompt_chars" + request_max_tokens_col: str = "dripper_request_max_tokens" + prompt_tokens_col: str = "dripper_prompt_tokens" + completion_tokens_col: str = "dripper_completion_tokens" + total_tokens_col: str = "dripper_total_tokens" + simplified_html_col: str = "dripper_simplified_html" + mapped_html_col: str = "dripper_mapped_html" + prompt_version: str = "short_compact" + generation_config: GenerationConfig | None = None + dynamic_max_tokens: bool = False + dynamic_max_token_padding: int = 16 + dynamic_max_tokens_per_item: int = 6 + dynamic_min_max_tokens: int = 32 + + +@dataclass(kw_only=True) +class DripperHTMLExtractionStage(_DripperColumnsMixin, ProcessingStage[DocumentBatch, DocumentBatch]): + name: str = "DripperHTMLExtractionStage" + client: AsyncLLMClient | None + model_name: str + output_html_col: str = "dripper_html" + output_content_col: str = "dripper_content" + output_format: str = "mm_md" + fallback: Literal["trafilatura", "bypass", "empty"] = "trafilatura" + structured_output_mode: Literal["none", "structured_outputs", "guided_regex"] = "none" + max_concurrent_requests: int = 64 + health_check: bool = True + keep_intermediate: bool = False + + _bindings: _MinerUHTMLBindings | None = field(init=False, repr=False, default=None) + _fallback_handler: Any = field(init=False, repr=False, default=None) + _initialized: bool = field(init=False, repr=False, default=False) + + def __post_init__(self) -> None: + _validate_llm_client_params(self, "DripperHTMLExtractionStage") + _validate_dynamic_token_params(self) + + def inputs(self) -> tuple[list[str], list[str]]: + return ["data"], [self.html_col] + + def outputs(self) -> tuple[list[str], list[str]]: + columns = [ + self.output_html_col, + self.output_content_col, + self.raw_response_col, + self.preprocess_time_col, + self.inference_time_col, + self.postprocess_time_col, + self.total_time_col, + self.error_col, + self.warning_col, + self.item_count_col, + self.prompt_chars_col, + self.request_max_tokens_col, + self.prompt_tokens_col, + self.completion_tokens_col, + self.total_tokens_col, + ] + if self.keep_intermediate: + columns.extend([self.simplified_html_col, self.mapped_html_col]) + return ["data"], columns + + def setup(self, worker_metadata: WorkerMetadata | None = None) -> None: # noqa: ARG002 + if self._initialized: + return + self._bindings = _load_mineru_html_bindings() + self._fallback_handler = self._bindings.get_fallback_handler(self.fallback) + self.client.setup() + if self.health_check: + run_async_safe(lambda: _run_dripper_health_check(self.client, self.model_name, self.generation_config)) + self._initialized = True + + def process(self, batch: DocumentBatch) -> DocumentBatch: + if not self._initialized: + self.setup() + + df = batch.to_pandas().copy() + if self.html_col not in df.columns: + msg = f"Input batch is missing required HTML column: {self.html_col!r}" + raise ValueError(msg) + + html_values = df[self.html_col].tolist() + if self.url_col is not None and self.url_col in df.columns: + url_values = df[self.url_col].tolist() + else: + url_values = [None] * len(df) + + results = run_async_safe(lambda: self._extract_all_async(html_values, url_values)) + df[self.output_html_col] = [r.main_html for r in results] + df[self.output_content_col] = [r.main_content for r in results] + df[self.raw_response_col] = [r.raw_response for r in results] + df[self.preprocess_time_col] = [r.preprocess_time_s for r in results] + df[self.inference_time_col] = [r.inference_time_s for r in results] + df[self.postprocess_time_col] = [r.postprocess_time_s for r in results] + df[self.total_time_col] = [r.total_time_s for r in results] + df[self.error_col] = [r.error for r in results] + df[self.warning_col] = [r.warning for r in results] + df[self.item_count_col] = [r.item_count for r in results] + df[self.prompt_chars_col] = [r.prompt_chars for r in results] + df[self.request_max_tokens_col] = [r.request_max_tokens for r in results] + df[self.prompt_tokens_col] = [r.prompt_tokens for r in results] + df[self.completion_tokens_col] = [r.completion_tokens for r in results] + df[self.total_tokens_col] = [r.total_tokens for r in results] + if self.keep_intermediate: + df[self.simplified_html_col] = [r.simplified_html for r in results] + df[self.mapped_html_col] = [r.mapped_html for r in results] + + return _rebuild_batch(batch, df) + + async def _extract_all_async(self, html_values: list[object], url_values: list[object]) -> list[_DripperRowResult]: + sem = asyncio.Semaphore(self.max_concurrent_requests) + + async def _extract_one_throttled(html_value: object, url_value: object) -> _DripperRowResult: + async with sem: + return await self._extract_one_async(html_value, url_value) + + tasks = [ + _extract_one_throttled(html_value, url_value) + for html_value, url_value in zip(html_values, url_values, strict=False) + ] + raw_results = await asyncio.gather(*tasks, return_exceptions=True) + + results: list[_DripperRowResult] = [] + for idx, result in enumerate(raw_results): + if isinstance(result, BaseException): + logger.error("Dripper extraction failed for row {}: {}", idx, result) + results.append(_DripperRowResult(error=str(result))) + else: + results.append(result) + return results + + def _preprocess_case(self, case: object) -> tuple[object, int, str, str, bool]: + case = self._bindings.simplify_single_input(case) + item_count = _count_item_ids(case) + if not _case_has_item_ids(case): + case = self._bindings.extract_main_html_fallback(case, fallback_handler=self._fallback_handler) + return ( + case, + item_count, + "", + "no _item_id attributes after simplification; used fallback without LLM", + False, + ) + case = self._bindings.build_prompt(case, prompt_version=self.prompt_version) + prompt = case.generate_input.full_prompt + return case, item_count, prompt, "", True + + async def _run_inference_async( + self, case: object, prompt: str, item_count: int + ) -> tuple[object, str, int, int, int, int]: + generation_config = _with_structured_output_config( + _generation_config_for_item_count(self, item_count), prompt, self.structured_output_mode + ) + request_max_tokens = generation_config.max_tokens or 0 + raw_response, prompt_tokens, completion_tokens, total_tokens = await _query_dripper_model( + self.client, self.model_name, [{"role": "user", "content": prompt}], generation_config + ) + case.generate_output = self._bindings.generate_output_cls(response=raw_response) + case = self._bindings.parse_result(case) + case = self._bindings.extract_main_html_single(case) + return case, raw_response, request_max_tokens, prompt_tokens, completion_tokens, total_tokens + + async def _extract_one_async(self, html_value: object, url_value: object) -> _DripperRowResult: + start_total = time.perf_counter() + html = _coerce_html(html_value) + if not html.strip(): + return _DripperRowResult(total_time_s=time.perf_counter() - start_total, warning="empty HTML input") + + url = _coerce_optional_str(url_value) + case = self._bindings.case_cls(self._bindings.input_cls(raw_html=html, url=url)) + raw_response = "" + preprocess_time_s = 0.0 + inference_time_s = 0.0 + postprocess_time_s = 0.0 + warning = "" + item_count = 0 + prompt_chars = 0 + request_max_tokens = 0 + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + + try: + t0 = time.perf_counter() + case, item_count, prompt, warning, needs_llm = self._preprocess_case(case) + preprocess_time_s = time.perf_counter() - t0 + if needs_llm: + prompt_chars = len(prompt) + t1 = time.perf_counter() + ( + case, + raw_response, + request_max_tokens, + prompt_tokens, + completion_tokens, + total_tokens, + ) = await self._run_inference_async(case, prompt, item_count) + inference_time_s = time.perf_counter() - t1 + except Exception as exc: # noqa: BLE001 + if preprocess_time_s == 0.0: + preprocess_time_s = time.perf_counter() - start_total + primary_error = str(exc) + logger.debug("Dripper primary extraction failed, applying {} fallback: {}", self.fallback, primary_error) + try: + t2 = time.perf_counter() + case = self._bindings.extract_main_html_fallback(case, fallback_handler=self._fallback_handler) + postprocess_time_s += time.perf_counter() - t2 + warning = primary_error + except Exception as fallback_exc: # noqa: BLE001 + return _DripperRowResult( + raw_response=raw_response, + preprocess_time_s=preprocess_time_s, + inference_time_s=inference_time_s, + postprocess_time_s=postprocess_time_s, + total_time_s=time.perf_counter() - start_total, + error=f"{primary_error}; fallback failed: {fallback_exc}", + warning=primary_error, + simplified_html=_get_processed_attr(case, "simpled_html"), + mapped_html=_get_processed_attr(case, "map_html"), + item_count=item_count, + prompt_chars=prompt_chars, + request_max_tokens=request_max_tokens, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + + partial = _DripperRowResult( + raw_response=raw_response, + warning=warning, + preprocess_time_s=preprocess_time_s, + inference_time_s=inference_time_s, + postprocess_time_s=postprocess_time_s, + item_count=item_count, + prompt_chars=prompt_chars, + request_max_tokens=request_max_tokens, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + simplified_html=_get_processed_attr(case, "simpled_html"), + mapped_html=_get_processed_attr(case, "map_html"), + ) + t3 = time.perf_counter() + conversion_error, case = self._convert_case(case) + postprocess_time_s += time.perf_counter() - t3 + partial = replace( + partial, postprocess_time_s=postprocess_time_s, total_time_s=time.perf_counter() - start_total + ) + return self._apply_conversion_result(case, partial, conversion_error) + + def _convert_case(self, case: object) -> tuple[str, object]: + try: + _sanitize_case_output_html(case) + return "", self._bindings.convert2content(case, output_format=self.output_format) + except Exception as exc: # noqa: BLE001 + conversion_error = str(exc) + logger.debug("Dripper content conversion failed: {}", conversion_error) + return conversion_error, case + + def _apply_conversion_result( + self, case: object, base: _DripperRowResult, conversion_error: str + ) -> _DripperRowResult: + return _apply_conversion_to_row_result(case, base, conversion_error) + + +@dataclass(kw_only=True) +class DripperHTMLPreprocessStage(_DripperColumnsMixin, ProcessingStage[DocumentBatch, DocumentBatch]): + name: str = "DripperHTMLPreprocessStage" + worker_count: int | None = None + + _bindings: _MinerUHTMLBindings | None = field(init=False, repr=False, default=None) + _initialized: bool = field(init=False, repr=False, default=False) + + def __post_init__(self) -> None: + _validate_dynamic_token_params(self) + if self.worker_count is not None and self.worker_count <= 0: + msg = "worker_count must be positive when set" + raise ValueError(msg) + + def num_workers(self) -> int | None: + return self.worker_count + + def inputs(self) -> tuple[list[str], list[str]]: + return ["data"], [self.html_col] + + def outputs(self) -> tuple[list[str], list[str]]: + return ["data"], [ + self.raw_response_col, + self.preprocess_time_col, + self.inference_time_col, + self.postprocess_time_col, + self.total_time_col, + self.error_col, + self.warning_col, + self.item_count_col, + self.prompt_chars_col, + self.request_max_tokens_col, + self.prompt_tokens_col, + self.completion_tokens_col, + self.total_tokens_col, + self.simplified_html_col, + self.mapped_html_col, + _DRIPPER_PROMPT_COL, + _DRIPPER_NEEDS_LLM_COL, + _DRIPPER_PRIMARY_ERROR_COL, + _DRIPPER_EMPTY_INPUT_COL, + ] + + def setup(self, worker_metadata: WorkerMetadata | None = None) -> None: # noqa: ARG002 + if self._initialized: + return + self._bindings = _load_mineru_html_bindings() + self._initialized = True + + def process(self, batch: DocumentBatch) -> DocumentBatch: + if not self._initialized: + self.setup() + + df = batch.to_pandas().copy() + if self.html_col not in df.columns: + msg = f"Input batch is missing required HTML column: {self.html_col!r}" + raise ValueError(msg) + + html_values = df[self.html_col].tolist() + if self.url_col is not None and self.url_col in df.columns: + url_values = df[self.url_col].tolist() + else: + url_values = [None] * len(df) + + results = [ + self._prepare_one(html_value, url_value) + for html_value, url_value in zip(html_values, url_values, strict=False) + ] + + pt = [r.preprocess_time_s for r in results] + df = df.assign( + **{ + self.raw_response_col: "", + self.preprocess_time_col: pt, + self.inference_time_col: 0.0, + self.postprocess_time_col: 0.0, + self.total_time_col: pt, + self.error_col: "", + self.warning_col: [r.warning for r in results], + self.item_count_col: [r.item_count for r in results], + self.prompt_chars_col: [r.prompt_chars for r in results], + self.request_max_tokens_col: [r.request_max_tokens for r in results], + self.prompt_tokens_col: 0, + self.completion_tokens_col: 0, + self.total_tokens_col: 0, + self.simplified_html_col: [r.simplified_html for r in results], + self.mapped_html_col: [r.mapped_html for r in results], + _DRIPPER_PROMPT_COL: [r.prompt for r in results], + _DRIPPER_NEEDS_LLM_COL: [r.needs_llm for r in results], + _DRIPPER_PRIMARY_ERROR_COL: [r.primary_error for r in results], + _DRIPPER_EMPTY_INPUT_COL: [r.empty_input for r in results], + } + ) + + self._log_metrics( + { + "preprocess_rows": float(len(df)), + "preprocess_llm_rows": float(sum(r.needs_llm for r in results)), + "preprocess_fallback_rows": float(sum((not r.needs_llm) and (not r.empty_input) for r in results)), + } + ) + return _rebuild_batch(batch, df) + + def _prepare_one(self, html_value: object, url_value: object) -> _DripperPrepResult: + started = time.perf_counter() + html = _coerce_html(html_value) + if not html.strip(): + return _DripperPrepResult( + empty_input=True, + preprocess_time_s=time.perf_counter() - started, + warning="empty HTML input", + ) + + url = _coerce_optional_str(url_value) + case = self._bindings.case_cls(self._bindings.input_cls(raw_html=html, url=url)) + simplified_html = "" + mapped_html = "" + item_count = 0 + try: + case = self._bindings.simplify_single_input(case) + simplified_html = _get_processed_attr(case, "simpled_html") + mapped_html = _get_processed_attr(case, "map_html") + item_count = _count_item_ids(case) + if not _case_has_item_ids(case): + return _DripperPrepResult( + needs_llm=False, + preprocess_time_s=time.perf_counter() - started, + warning="no _item_id attributes after simplification; used fallback without LLM", + simplified_html=simplified_html, + mapped_html=mapped_html, + item_count=item_count, + ) + + case = self._bindings.build_prompt(case, prompt_version=self.prompt_version) + prompt = case.generate_input.full_prompt + return _DripperPrepResult( + prompt=prompt, + needs_llm=True, + preprocess_time_s=time.perf_counter() - started, + simplified_html=simplified_html, + mapped_html=mapped_html, + item_count=item_count, + prompt_chars=len(prompt), + request_max_tokens=_generation_config_for_item_count(self, item_count).max_tokens or 0, + ) + except Exception as exc: # noqa: BLE001 + primary_error = str(exc) + logger.debug("Dripper preprocessing failed; postprocess stage will apply fallback: {}", primary_error) + return _DripperPrepResult( + needs_llm=False, + preprocess_time_s=time.perf_counter() - started, + primary_error=primary_error, + warning=primary_error, + simplified_html=simplified_html, + mapped_html=mapped_html, + item_count=item_count, + ) + + +@dataclass(kw_only=True) +class DripperHTMLInferenceStage(ProcessingStage[DocumentBatch, DocumentBatch]): + name: str = "DripperHTMLInferenceStage" + client: AsyncLLMClient | None + model_name: str + raw_response_col: str = "dripper_response" + inference_time_col: str = "dripper_inference_time_s" + warning_col: str = "dripper_warning" + item_count_col: str = "dripper_item_count" + request_max_tokens_col: str = "dripper_request_max_tokens" + prompt_tokens_col: str = "dripper_prompt_tokens" + completion_tokens_col: str = "dripper_completion_tokens" + total_tokens_col: str = "dripper_total_tokens" + generation_config: GenerationConfig | None = None + structured_output_mode: Literal["none", "structured_outputs", "guided_regex"] = "none" + max_concurrent_requests: int = 64 + health_check: bool = False + worker_count: int | None = None + + _initialized: bool = field(init=False, repr=False, default=False) + + def __post_init__(self) -> None: + _validate_llm_client_params(self, "DripperHTMLInferenceStage") + if self.worker_count is not None and self.worker_count <= 0: + msg = "worker_count must be positive when set" + raise ValueError(msg) + + def num_workers(self) -> int | None: + return self.worker_count + + def inputs(self) -> tuple[list[str], list[str]]: + return ["data"], [_DRIPPER_PROMPT_COL, _DRIPPER_NEEDS_LLM_COL, self.request_max_tokens_col] + + def outputs(self) -> tuple[list[str], list[str]]: + return ["data"], [ + self.raw_response_col, + self.inference_time_col, + self.warning_col, + self.prompt_tokens_col, + self.completion_tokens_col, + self.total_tokens_col, + _DRIPPER_PRIMARY_ERROR_COL, + ] + + def setup(self, worker_metadata: WorkerMetadata | None = None) -> None: # noqa: ARG002 + if self._initialized: + return + self.client.setup() + if self.health_check: + run_async_safe(lambda: _run_dripper_health_check(self.client, self.model_name, self.generation_config)) + self._initialized = True + + def process(self, batch: DocumentBatch) -> DocumentBatch: + if not self._initialized: + self.setup() + + df = batch.to_pandas().copy() + results = run_async_safe(lambda: self._infer_all_async(df)) + + n = len(df) + needs_llm = df[_DRIPPER_NEEDS_LLM_COL].astype(bool).tolist() + existing_raw_responses = _col_str_list(df, self.raw_response_col, n) + existing_inference_times = _col_float_list(df, self.inference_time_col, n) + existing_prompt_tokens = _col_int_list(df, self.prompt_tokens_col, n) + existing_completion_tokens = _col_int_list(df, self.completion_tokens_col, n) + existing_total_tokens = _col_int_list(df, self.total_tokens_col, n) + existing_warnings = df[self.warning_col].astype(str) if self.warning_col in df else pd.Series([""] * n) + existing_primary_errors = ( + df[_DRIPPER_PRIMARY_ERROR_COL].astype(str) if _DRIPPER_PRIMARY_ERROR_COL in df else pd.Series([""] * n) + ) + df[self.raw_response_col] = [ + r.raw_response if q else e for r, q, e in zip(results, needs_llm, existing_raw_responses, strict=True) + ] + df[self.inference_time_col] = [ + r.inference_time_s if q else e + for r, q, e in zip(results, needs_llm, existing_inference_times, strict=True) + ] + df[self.warning_col] = [ + _append_warning(ew, r.warning) for ew, r in zip(existing_warnings.tolist(), results, strict=True) + ] + df[_DRIPPER_PRIMARY_ERROR_COL] = [ + _append_warning(ee, r.primary_error) + for ee, r in zip(existing_primary_errors.tolist(), results, strict=True) + ] + for col, attr, existing in ( + (self.prompt_tokens_col, "prompt_tokens", existing_prompt_tokens), + (self.completion_tokens_col, "completion_tokens", existing_completion_tokens), + (self.total_tokens_col, "total_tokens", existing_total_tokens), + ): + df[col] = [getattr(r, attr) if q else e for r, q, e in zip(results, needs_llm, existing, strict=True)] + + llm_prompts = [ + str(row.get(_DRIPPER_PROMPT_COL, "") or "") + for _, row in df.iterrows() + if bool(row.get(_DRIPPER_NEEDS_LLM_COL, False)) + ] + non_empty_llm_prompts = [prompt for prompt in llm_prompts if prompt.strip()] + unique_llm_prompts = len(set(non_empty_llm_prompts)) + self._log_metrics( + { + "inference_rows": float(len(df)), + "inference_llm_rows": float(sum(bool(v) for v in df[_DRIPPER_NEEDS_LLM_COL].tolist())), + "inference_unique_llm_prompts": float(unique_llm_prompts), + "inference_dedup_saved_rows": float(len(non_empty_llm_prompts) - unique_llm_prompts), + "inference_errors": float(sum(1 for r in results if r.primary_error)), + } + ) + return _rebuild_batch(batch, df) + + async def _infer_all_async(self, df: pd.DataFrame) -> list[_DripperInferenceResult]: + sem = asyncio.Semaphore(self.max_concurrent_requests) + prompts = df[_DRIPPER_PROMPT_COL].astype(str).tolist() + needs_llm = df[_DRIPPER_NEEDS_LLM_COL].astype(bool).tolist() + request_max_tokens = _col_int_list(df, self.request_max_tokens_col, len(df)) + + async def _infer_one_throttled(prompt: str, row_max_tokens: int) -> _DripperInferenceResult: + async with sem: + return await self._infer_one_async(prompt, row_max_tokens) + + grouped_indexes: dict[tuple[str, int], list[int]] = defaultdict(list) + results: list[_DripperInferenceResult | None] = [None] * len(df) + for idx, (prompt, should_query, row_max_tokens) in enumerate( + zip(prompts, needs_llm, request_max_tokens, strict=True) + ): + if not should_query: + results[idx] = _DripperInferenceResult() + elif not prompt.strip(): + results[idx] = _DripperInferenceResult( + primary_error="empty Dripper prompt", warning="empty Dripper prompt" + ) + else: + grouped_indexes[(prompt, row_max_tokens)].append(idx) + + tasks = {key: _infer_one_throttled(prompt=key[0], row_max_tokens=key[1]) for key in grouped_indexes} + raw_results = await asyncio.gather(*tasks.values(), return_exceptions=True) + + for (_key, indexes), result in zip(grouped_indexes.items(), raw_results, strict=True): + if isinstance(result, BaseException): + logger.error("Dripper inference failed for prompt group {} rows: {}", len(indexes), result) + error = str(result) + first_result = _DripperInferenceResult(primary_error=error, warning=error) + else: + first_result = result + first_idx = indexes[0] + results[first_idx] = first_result + for duplicate_idx in indexes[1:]: + results[duplicate_idx] = replace( + first_result, + inference_time_s=0.0, + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + ) + + return [result if result is not None else _DripperInferenceResult() for result in results] + + async def _infer_one_async(self, prompt: str, row_max_tokens: int) -> _DripperInferenceResult: + started = time.perf_counter() + try: + generation_config = self.generation_config or GenerationConfig() + if row_max_tokens > 0 and generation_config.max_tokens != row_max_tokens: + generation_config = replace(generation_config, max_tokens=row_max_tokens) + generation_config = _with_structured_output_config(generation_config, prompt, self.structured_output_mode) + raw_response, prompt_tokens, completion_tokens, total_tokens = await self._query_model_with_usage( + model=self.model_name, + messages=[{"role": "user", "content": prompt}], + generation_config=generation_config, + ) + except Exception as exc: # noqa: BLE001 + error = str(exc) + logger.debug("Dripper inference failed; postprocess stage will apply fallback: {}", error) + return _DripperInferenceResult( + inference_time_s=time.perf_counter() - started, + primary_error=error, + warning=error, + ) + return _DripperInferenceResult( + raw_response=raw_response, + inference_time_s=time.perf_counter() - started, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + + async def _query_model_with_usage( + self, + *, + model: str, + messages: list[dict[str, str]], + generation_config: GenerationConfig, + ) -> tuple[str, int, int, int]: + query_model_with_usage = getattr(self.client, "query_model_with_usage", None) + if callable(query_model_with_usage): + response = await query_model_with_usage( + model=model, + messages=messages, + generation_config=generation_config, + ) + contents = getattr(response, "contents", []) + return ( + contents[0] if contents else "", + _coerce_usage_int(getattr(response, "prompt_tokens", None)), + _coerce_usage_int(getattr(response, "completion_tokens", None)), + _coerce_usage_int(getattr(response, "total_tokens", None)), + ) + + response = await self.client.query_model( + model=model, + messages=messages, + generation_config=generation_config, + ) + return response[0] if response else "", 0, 0, 0 + + +@dataclass(kw_only=True) +class DripperHTMLPostprocessStage(ProcessingStage[DocumentBatch, DocumentBatch]): + name: str = "DripperHTMLPostprocessStage" + html_col: str = "html" + url_col: str | None = "url" + output_html_col: str = "dripper_html" + output_content_col: str = "dripper_content" + raw_response_col: str = "dripper_response" + preprocess_time_col: str = "dripper_preprocess_time_s" + inference_time_col: str = "dripper_inference_time_s" + postprocess_time_col: str = "dripper_postprocess_time_s" + total_time_col: str = "dripper_time_s" + error_col: str = "dripper_error" + warning_col: str = "dripper_warning" + fallback: Literal["trafilatura", "bypass", "empty"] = "trafilatura" + output_format: str = "mm_md" + keep_intermediate: bool = False + simplified_html_col: str = "dripper_simplified_html" + mapped_html_col: str = "dripper_mapped_html" + worker_count: int | None = None + + _bindings: _MinerUHTMLBindings | None = field(init=False, repr=False, default=None) + _fallback_handler: Any = field(init=False, repr=False, default=None) + _initialized: bool = field(init=False, repr=False, default=False) + + def __post_init__(self) -> None: + if self.worker_count is not None and self.worker_count <= 0: + msg = "worker_count must be positive when set" + raise ValueError(msg) + + def num_workers(self) -> int | None: + return self.worker_count + + def inputs(self) -> tuple[list[str], list[str]]: + return ["data"], [ + self.html_col, + self.raw_response_col, + self.simplified_html_col, + self.mapped_html_col, + _DRIPPER_NEEDS_LLM_COL, + _DRIPPER_PRIMARY_ERROR_COL, + _DRIPPER_EMPTY_INPUT_COL, + ] + + def outputs(self) -> tuple[list[str], list[str]]: + columns = [ + self.output_html_col, + self.output_content_col, + self.postprocess_time_col, + self.total_time_col, + self.error_col, + self.warning_col, + ] + if self.keep_intermediate: + columns.extend([self.simplified_html_col, self.mapped_html_col]) + return ["data"], columns + + def setup(self, worker_metadata: WorkerMetadata | None = None) -> None: # noqa: ARG002 + if self._initialized: + return + self._bindings = _load_mineru_html_bindings() + self._fallback_handler = self._bindings.get_fallback_handler(self.fallback) + self._initialized = True + + def process(self, batch: DocumentBatch) -> DocumentBatch: + if not self._initialized: + self.setup() + + df = batch.to_pandas().copy() + html_values = df[self.html_col].tolist() + if self.url_col is not None and self.url_col in df.columns: + url_values = df[self.url_col].tolist() + else: + url_values = [None] * len(df) + + results = [ + self._postprocess_one(row, html_value, url_value) + for (_, row), html_value, url_value in zip(df.iterrows(), html_values, url_values, strict=True) + ] + + preprocess_times = _numeric_series_or_zero(df, self.preprocess_time_col) + inference_times = _numeric_series_or_zero(df, self.inference_time_col) + postprocess_times = pd.Series([r.postprocess_time_s for r in results], index=df.index) + + df[self.output_html_col] = [r.main_html for r in results] + df[self.output_content_col] = [r.main_content for r in results] + df[self.postprocess_time_col] = postprocess_times + df[self.total_time_col] = preprocess_times + inference_times + postprocess_times + df[self.error_col] = [r.error for r in results] + df[self.warning_col] = [r.warning for r in results] + + drop_cols = [ + _DRIPPER_PROMPT_COL, + _DRIPPER_NEEDS_LLM_COL, + _DRIPPER_PRIMARY_ERROR_COL, + _DRIPPER_EMPTY_INPUT_COL, + _DRIPPER_LAYOUT_FINALIZED_COL, + ] + if not self.keep_intermediate: + drop_cols.extend([self.simplified_html_col, self.mapped_html_col]) + df = df.drop(columns=[col for col in drop_cols if col in df.columns]) + + self._log_metrics( + { + "postprocess_rows": float(len(df)), + "postprocess_errors": float(sum(1 for r in results if r.error)), + "postprocess_warnings": float(sum(1 for r in results if r.warning)), + } + ) + return _rebuild_batch(batch, df) + + def _postprocess_one(self, row: pd.Series, html_value: object, url_value: object) -> _DripperPostResult: + started = time.perf_counter() + warning = str(row.get(self.warning_col, "") or "") + primary_error = str(row.get(_DRIPPER_PRIMARY_ERROR_COL, "") or "") + if bool(row.get(_DRIPPER_LAYOUT_FINALIZED_COL, False)): + return _DripperPostResult( + main_html=str(row.get(self.output_html_col, "") or ""), + main_content=row.get(self.output_content_col, "") or "", + postprocess_time_s=float(row.get(self.postprocess_time_col, 0.0) or 0.0), + error=str(row.get(self.error_col, "") or ""), + warning=warning, + ) + html = _coerce_html(html_value) + if bool(row.get(_DRIPPER_EMPTY_INPUT_COL, False)) or not html.strip(): + return _DripperPostResult( + postprocess_time_s=time.perf_counter() - started, + warning=warning or "empty HTML input", + ) + + url = _coerce_optional_str(url_value) + case = self._build_case( + html=html, + url=url, + simplified_html=str(row.get(self.simplified_html_col, "") or ""), + mapped_html=str(row.get(self.mapped_html_col, "") or ""), + ) + raw_response = str(row.get(self.raw_response_col, "") or "") + needs_llm = bool(row.get(_DRIPPER_NEEDS_LLM_COL, False)) + + case, warning, fallback_error = self._postprocess_prepare_case( + case, + raw_response=raw_response, + needs_llm=needs_llm, + primary_error=primary_error, + warning=warning, + ) + if fallback_error: + return _DripperPostResult( + postprocess_time_s=time.perf_counter() - started, + error=fallback_error, + warning=warning, + ) + + conversion_error = "" + try: + _sanitize_case_output_html(case) + case = self._bindings.convert2content(case, output_format=self.output_format) + except Exception as exc: # noqa: BLE001 + conversion_error = str(exc) + logger.debug("Dripper content conversion failed: {}", conversion_error) + + output_data = getattr(case, "output_data", None) + main_html = getattr(output_data, "main_html", "") if output_data is not None else "" + main_content = getattr(output_data, "main_content", "") if output_data is not None else "" + if main_content is None: + main_content = "" + error = "" + if conversion_error: + if _is_empty_document_error(conversion_error) and not str(main_html).strip(): + warning = _append_warning(warning, conversion_error) + else: + error = conversion_error + + return _DripperPostResult( + main_html=main_html, + main_content=main_content, + postprocess_time_s=time.perf_counter() - started, + error=error, + warning=warning, + ) + + def _postprocess_prepare_case( + self, + case: object, + *, + raw_response: str, + needs_llm: bool, + primary_error: str, + warning: str, + ) -> tuple[object, str, str]: + if needs_llm and raw_response: + try: + case.generate_output = self._bindings.generate_output_cls(response=raw_response) + case = self._bindings.parse_result(case) + case = self._bindings.extract_main_html_single(case) + except Exception as exc: # noqa: BLE001 + primary_error = _append_warning(primary_error, str(exc)) + logger.debug("Dripper parse/extract failed, applying {} fallback: {}", self.fallback, primary_error) + fallback_result = _apply_fallback_extraction( + self._bindings, self._fallback_handler, case, primary_error + ) + warning = _append_warning(warning, fallback_result[1]) + return fallback_result[0], warning, fallback_result[2] + return case, warning, "" + if needs_llm and not primary_error: + primary_error = "empty Dripper response" + fallback_result = _apply_fallback_extraction(self._bindings, self._fallback_handler, case, primary_error) + warning = _append_warning(warning, fallback_result[1]) + return fallback_result[0], warning, fallback_result[2] + + def _build_case(self, *, html: str, url: str | None, simplified_html: str, mapped_html: str) -> object: + case = self._bindings.case_cls(self._bindings.input_cls(raw_html=html, url=url)) + if simplified_html or mapped_html: + case.process_data = self._bindings.process_data_cls(simpled_html=simplified_html, map_html=mapped_html) + return case diff --git a/nemo_curator/stages/text/experimental/dripper/_layout_planning.py b/nemo_curator/stages/text/experimental/dripper/_layout_planning.py new file mode 100644 index 0000000000..c6d0fd7069 --- /dev/null +++ b/nemo_curator/stages/text/experimental/dripper/_layout_planning.py @@ -0,0 +1,913 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Layout-group planning and URL/DOM helpers for DripperHTMLLayoutTemplateStage.""" + +from __future__ import annotations + +import json +import re +from collections import Counter, defaultdict +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal +from urllib.parse import parse_qsl, urlparse + +import pandas as pd # noqa: TC002 — used at runtime (df.iterrows, df.iloc, etc.) +from loguru import logger + +from nemo_curator.stages.text.experimental.dripper.stage import ( + _DRIPPER_NEEDS_LLM_COL, + _coerce_html, + _is_missing, +) + +_LAYOUT_RE_MD5 = re.compile(r"^[0-9a-f]{32}$") +_LAYOUT_RE_SHA1 = re.compile(r"^[0-9a-f]{40}$") +_LAYOUT_RE_UUID = re.compile(r"^[a-z0-9]{8}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{4}-[a-z0-9]{12}$") +_LAYOUT_RE_TIMESTAMP = re.compile(r"^\d{10,13}$") +_LAYOUT_RE_NUM = re.compile(r"\d+") + +_LAYOUT_SEMANTIC_QUERY_VALUE_KEYS = {"hl", "lang", "language", "locale"} +_LAYOUT_EXACT_QUERY_VALUE_KEYS = {"id"} + +_LAYOUT_PAGE_SIGNATURE_MODES = { + "none", + "url_shape", + "url_low_card_query_shape", + "url_semantic_shape", + "item_count_bucket", + "item_count_exact", + "url_shape_item_count_bucket", + "url_shape_item_count_exact", + "url_low_card_query_shape_item_count_bucket", + "url_low_card_query_shape_item_count_exact", + "url_semantic_shape_item_count_bucket", + "url_semantic_shape_item_count_exact", +} + + +def _parse_url(value: object) -> tuple[str, object]: + text = "" if _is_missing(value) else str(value).strip() + if not text: + return "", None + parsed = urlparse(text) + if not parsed.hostname and "://" not in text: + parsed = urlparse(f"//{text}") + return text, parsed + + +def _url_host_key(value: object) -> str: + _text, parsed = _parse_url(value) + if parsed is None: + return "" + host = (parsed.hostname or "").strip().lower().rstrip(".") + try: + return host.encode("idna").decode("ascii") + except UnicodeError: + return host + + +def _normalize_url_path_segment(segment: str) -> str: + segment = segment.lower() + suffix = "" + if "." in segment: + segment, extension = segment.rsplit(".", 1) + suffix = f".{extension}" + if re.search(r"\d", segment): + return f"#num{suffix}" + return f"{segment}{suffix}" + + +def _url_shape_key(value: object) -> str: + _text, parsed = _parse_url(value) + if parsed is None: + return "" + raw_segments = [segment for segment in (parsed.path or "").split("/") if segment] + query_keys = ",".join(sorted({key for key, _value in parse_qsl(parsed.query, keep_blank_values=True)})) + if parsed.query: + normalized_segments = [segment.lower() for segment in raw_segments] + else: + normalized_segments = [_normalize_url_path_segment(segment) for segment in raw_segments] + return f"path={'/'.join(normalized_segments)}|q={query_keys}" + + +def _url_low_card_query_shape_key(value: object, low_card_query_keys: set[str]) -> str: + _text, parsed = _parse_url(value) + if parsed is None: + return "" + raw_segments = [segment for segment in (parsed.path or "").split("/") if segment] + if parsed.query: + normalized_segments = [segment.lower() for segment in raw_segments] + else: + normalized_segments = [_normalize_url_path_segment(segment) for segment in raw_segments] + + include_all_query_values = bool(parsed.query) and not low_card_query_keys + query_parts = [] + for key, query_value in sorted(parse_qsl(parsed.query, keep_blank_values=True)): + lowered_key = key.strip().lower() + if not lowered_key: + continue + if ( + include_all_query_values + or lowered_key in low_card_query_keys + or lowered_key in _LAYOUT_EXACT_QUERY_VALUE_KEYS + ): + query_parts.append(f"{lowered_key}={query_value.strip().lower()}") + else: + query_parts.append(lowered_key) + return f"path={'/'.join(normalized_segments)}|q={','.join(query_parts)}" + + +def _url_semantic_shape_key(value: object) -> str: + def _norm_seg(seg: str) -> str: + seg = seg.lower() + suffix = "" + if "." in seg: + seg, ext = seg.rsplit(".", 1) + suffix = f".{ext}" + if ( + seg.isdigit() + or _LAYOUT_RE_MD5.fullmatch(seg) + or _LAYOUT_RE_SHA1.fullmatch(seg) + or _LAYOUT_RE_UUID.fullmatch(seg) + or _LAYOUT_RE_TIMESTAMP.fullmatch(seg) + ): + return f"#num{suffix}" + return f"{seg}{suffix}" + + def _norm_qval(v: str) -> str: + t = v.strip().lower() + if not t: + return "" + if ( + t.isdigit() + or _LAYOUT_RE_MD5.fullmatch(t) + or _LAYOUT_RE_SHA1.fullmatch(t) + or _LAYOUT_RE_UUID.fullmatch(t) + or _LAYOUT_RE_TIMESTAMP.fullmatch(t) + ): + return "#num" + return t + + _text, parsed = _parse_url(value) + if parsed is None: + return "" + raw_segments = [segment for segment in (parsed.path or "").split("/") if segment] + normalized_segments = [_norm_seg(segment) for segment in raw_segments] + query_parts = [] + for key, query_value in sorted(parse_qsl(parsed.query, keep_blank_values=True)): + lowered_key = key.lower() + if lowered_key in _LAYOUT_SEMANTIC_QUERY_VALUE_KEYS: + query_parts.append(f"{lowered_key}={_norm_qval(query_value)}") + else: + query_parts.append(lowered_key) + return f"path={'/'.join(normalized_segments)}|q={','.join(query_parts)}" + + +def _coerce_item_count(value: object) -> int: + if isinstance(value, bool): + return 0 + if isinstance(value, int): + return value + if isinstance(value, float) and value.is_integer(): + return int(value) + try: + return int(float(str(value))) + except (TypeError, ValueError): + return 0 + + +def _coerce_positive_int(value: object) -> int: + return max(0, _coerce_item_count(value)) + + +# (threshold, label) — label=None → use str(count); count > 128 → "129+" +_ITEM_COUNT_BUCKETS: tuple[tuple[int, str | None], ...] = ( + (8, None), + (16, "9-16"), + (32, "17-32"), + (64, "33-64"), + (128, "65-128"), +) + + +def _item_count_bucket(value: object) -> str: + count = _coerce_item_count(value) + if count <= 0: + return "0" + for threshold, label in _ITEM_COUNT_BUCKETS: + if count <= threshold: + return str(count) if label is None else label + return "129+" + + +def _layout_page_signature_key(url_value: object, item_count_value: object, mode: str) -> str: + return _layout_page_signature_key_with_low_card_queries(url_value, item_count_value, mode, set()) + + +def _layout_page_signature_key_with_low_card_queries( + url_value: object, + item_count_value: object, + mode: str, + low_card_query_keys: set[str], +) -> str: + if not mode or mode == "none": + return "" + parts: list[str] = [] + if "url_low_card_query_shape" in mode: + parts.append(f"url={_url_low_card_query_shape_key(url_value, low_card_query_keys)}") + elif "url_semantic_shape" in mode: + parts.append(f"url={_url_semantic_shape_key(url_value)}") + elif "url_shape" in mode: + parts.append(f"url={_url_shape_key(url_value)}") + if "item_count_exact" in mode: + parts.append(f"items={_coerce_item_count(item_count_value)}") + elif "item_count_bucket" in mode: + parts.append(f"items={_item_count_bucket(item_count_value)}") + return "|".join(parts) + + +def _validation_query_values(url_text: str) -> list[tuple[str, str]]: + _text, parsed = _parse_url(url_text) + if parsed is None: + return [] + return [ + (key.strip().lower(), value.strip().lower()) + for key, value in parse_qsl(parsed.query, keep_blank_values=True) + if key.strip() + ] + + +def _low_card_query_value_keys(url_values: list[Any], max_distinct: int = 16) -> set[str]: + values_by_key: dict[str, set[str]] = defaultdict(set) + for url_value in url_values: + url_text = "" if _is_missing(url_value) else str(url_value) + for key, value in _validation_query_values(url_text): + values_by_key[key].add(value) + return {key for key, values in values_by_key.items() if 1 < len(values) <= max_distinct} + + +_LAYOUT_TAGS_TO_IGNORE = {"script", "style", "meta", "link", "br", "noscript"} +_LAYOUT_TAGS_IGNORE_ATTR = {"a", "i", "b", "li", "tr", "td", "img", "p", "body"} +_TOKEN_RE = re.compile(r"\w+", re.UNICODE) + + +def _normalize_attr_tokens(value: str | None) -> str: + if not value: + return "" + tokens = value.split() + if len(tokens) > 1: + normalized = [token.lower() for token in tokens if not _LAYOUT_RE_NUM.search(token)] + else: + lowered = tokens[0].strip().lower() + normalized_tok = next( + ( + label + for pat, label in ( + (_LAYOUT_RE_MD5, "[MD5]"), + (_LAYOUT_RE_SHA1, "[SHA1]"), + (_LAYOUT_RE_UUID, "[UUID]"), + (_LAYOUT_RE_TIMESTAMP, "[TIMESTAMP]"), + ) + if pat.fullmatch(lowered) + ), + _LAYOUT_RE_NUM.sub("", lowered), + ) + normalized = [normalized_tok] if normalized_tok else [] + return " ".join(token for token in normalized if token) + + +def _walk_dom_element(element: object) -> object: + raw_tag = getattr(element, "tag", None) + if not isinstance(raw_tag, str): + return None + tag = raw_tag.lower() + if tag in _LAYOUT_TAGS_TO_IGNORE: + return None + attrs: list[tuple[str, str]] = [] + if tag not in _LAYOUT_TAGS_IGNORE_ATTR: + class_attr = _normalize_attr_tokens(element.get("class")) + id_attr = _normalize_attr_tokens(element.get("id")) + if class_attr: + attrs.append(("class", class_attr)) + if id_attr: + attrs.append(("id", id_attr)) + children = [child for child in (_walk_dom_element(child) for child in element) if child is not None] + return [tag, attrs, children] + + +def _layout_dom_path_fingerprint(html_text: str) -> str: + try: + from lxml.html import HTMLParser, fromstring + except ModuleNotFoundError: + return "" + try: + parser = HTMLParser(collect_ids=False, encoding="utf-8", remove_comments=True, remove_pis=True) + root = fromstring(html_text.encode("utf-8", errors="ignore"), parser=parser) + body_nodes = root.xpath("//body") + root = body_nodes[0] if body_nodes else root + except Exception: # noqa: BLE001 + return "" + return json.dumps(_walk_dom_element(root), ensure_ascii=False, sort_keys=True, separators=(",", ":")) + + +def _layout_feature_fingerprint(feature: object) -> str: + if not isinstance(feature, dict): + return "" + + def normalize_part(part: str) -> dict[str, list[tuple[str, int]]]: + raw = feature.get(part, {}) + if not isinstance(raw, dict): + return {} + return { + str(layer): sorted(Counter(str(v) for v in vals).items()) + for layer, vals in raw.items() + if isinstance(vals, list) + } + + payload = {"tags": normalize_part("tags"), "attrs": normalize_part("attrs")} + return json.dumps(payload, ensure_ascii=False, sort_keys=True, separators=(",", ":")) + + +def _coerce_optional_float(value: object) -> float | None: + if isinstance(value, bool) or value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def _labels_to_webkit_response(labels: object) -> dict[str, int]: + if not isinstance(labels, dict): + return {} + response: dict[str, int] = {} + for item_id, label in labels.items(): + normalized = str(label).strip().lower() + response[f"item_id {item_id}"] = 1 if normalized in {"main", "1", "true"} else 0 + return response + + +def _item_id_response(all_item_ids: list[str], main_item_ids: set[str]) -> str: + labels = {item_id: ("main" if item_id in main_item_ids else "other") for item_id in all_item_ids} + if all(item_id.isdigit() for item_id in all_item_ids): + return "".join(f"{item_id}{label}" for item_id, label in labels.items()) + return json.dumps(labels, ensure_ascii=False, separators=(",", ":")) + + +def _token_f1(candidate: object, reference: object) -> float: + candidate_tokens = Counter(_TOKEN_RE.findall(str(candidate or "").lower())) + reference_tokens = Counter(_TOKEN_RE.findall(str(reference or "").lower())) + if not candidate_tokens and not reference_tokens: + return 1.0 + if not candidate_tokens or not reference_tokens: + return 0.0 + overlap = sum((candidate_tokens & reference_tokens).values()) + if overlap == 0: + return 0.0 + precision = overlap / sum(candidate_tokens.values()) + recall = overlap / sum(reference_tokens.values()) + return 2 * precision * recall / (precision + recall) + + +if TYPE_CHECKING: + from collections.abc import Callable + + from nemo_curator.stages.text.experimental.dripper.stage import ( + _LLMWebKitBindings, + ) + +# Column name duplicated here to avoid a circular import with layout_template.py. +_DRIPPER_ITEM_COUNT_COL = "dripper_item_count" +_MAX_EXEMPLARS_PER_LAYOUT = 3 + + +@dataclass(kw_only=True) +class DripperLayoutAdvancedConfig: + host_single_cluster_min_pages: int = 0 + host_single_cluster_max_pages: int = 0 + max_exact_host_pages: int = 0 + large_host_mode: Literal["standalone", "feature_hash", "dom_path_hash"] = "standalone" + propagation_concurrency: int = 32 + representative_candidates: int = 1 + defer_fallback_llm: bool = False + defer_propagation: bool = False + failed_host_fallback_signature_mode: str = "none" + failed_layout_fallback_signature_mode: str = "none" + page_signature_mode: str = "none" + validation_signature_mode: str = "none" + + +@dataclass(frozen=True) +class _LayoutGroupPlan: + indexes: list[int] + host_key: str = "" + source: str = "dom" + fallback_groups: tuple[list[int], ...] = () + + +@dataclass(frozen=True) +class _LayoutPlanningConfig: + html_col: str + url_col: str | None + host_col: str | None + layout_id_col: str | None + layout_cluster_threshold: float + min_cluster_size: int + adv: DripperLayoutAdvancedConfig + web_bindings: _LLMWebKitBindings | None + + +def _build_layout_group_plans(cfg: _LayoutPlanningConfig, df: pd.DataFrame) -> list[_LayoutGroupPlan]: + if len(df) < cfg.min_cluster_size: + return [] + precomputed_plans = _build_precomputed_layout_group_plans(cfg, df) + if precomputed_plans is not None: + return precomputed_plans + + samples_by_host = _build_host_samples(cfg, df) + return _build_plans_from_host_samples(cfg, df, samples_by_host) + + +def _build_host_samples(cfg: _LayoutPlanningConfig, df: pd.DataFrame) -> dict[str, list[dict[str, Any]]]: + samples_by_host: dict[str, list[dict[str, Any]]] = defaultdict(list) + for idx, row in df.iterrows(): + if not bool(row.get(_DRIPPER_NEEDS_LLM_COL, False)): + continue + html_text = _coerce_html(row.get(cfg.html_col, "")) + if not html_text.strip(): + continue + try: + feature = cfg.web_bindings.get_feature(html_text) + except Exception as exc: # noqa: BLE001 + logger.debug("Dripper layout feature extraction failed for row {}: {}", idx, exc) + continue + if feature is None: + continue + samples_by_host[_row_host_key(cfg, row)].append({"track_id": str(idx), "html": html_text, "feature": feature}) + return samples_by_host + + +def _build_plans_from_host_samples( + cfg: _LayoutPlanningConfig, + df: pd.DataFrame, + samples_by_host: dict[str, list[dict[str, Any]]], +) -> list[_LayoutGroupPlan]: + plans: list[_LayoutGroupPlan] = [] + adv = cfg.adv + for host_key, samples in samples_by_host.items(): + if len(samples) < cfg.min_cluster_size: + continue + host_indexes = sorted(int(sample["track_id"]) for sample in samples) + fallback_groups = _build_layout_groups_for_host_samples(cfg, df, host_key, samples) + n = len(samples) + try_single = ( + adv.host_single_cluster_min_pages > 0 + and n >= adv.host_single_cluster_min_pages + and not (adv.host_single_cluster_max_pages > 0 and n > adv.host_single_cluster_max_pages) + ) + if try_single: + plans.append( + _LayoutGroupPlan( + indexes=host_indexes, + host_key=host_key, + source="host_single_cluster", + fallback_groups=tuple(fallback_groups), + ) + ) + continue + for indexes in fallback_groups: + plans.append( + _LayoutGroupPlan( + indexes=indexes, + host_key=host_key, + source="dom", + fallback_groups=tuple(_build_failed_layout_fallback_groups(cfg, df, indexes)), + ) + ) + return plans + + +def _build_precomputed_layout_group_plans( + cfg: _LayoutPlanningConfig, df: pd.DataFrame +) -> list[_LayoutGroupPlan] | None: + if not cfg.layout_id_col or cfg.layout_id_col not in df.columns: + return None + + by_layout: dict[tuple[str, str], list[int]] = defaultdict(list) + for idx, row in df.iterrows(): + if not bool(row.get(_DRIPPER_NEEDS_LLM_COL, False)): + continue + html_text = _coerce_html(row.get(cfg.html_col, "")) + if not html_text.strip(): + continue + layout_key = _row_layout_id_key(cfg, row) + if not layout_key: + continue + by_layout[(_row_host_key(cfg, row), layout_key)].append(int(idx)) + + plans: list[_LayoutGroupPlan] = [] + for (host_key, layout_key), indexes in sorted(by_layout.items(), key=lambda item: (min(item[1]), item[0])): + sorted_indexes = sorted(indexes) + if len(sorted_indexes) < cfg.min_cluster_size: + continue + plan_groups = _split_large_precomputed_layout_group(cfg, df, host_key, layout_key, sorted_indexes) + for plan_indexes in plan_groups: + if len(plan_indexes) < cfg.min_cluster_size: + continue + plans.append( + _LayoutGroupPlan( + indexes=plan_indexes, + host_key=host_key, + source=f"precomputed_layout:{layout_key}", + fallback_groups=tuple(_build_failed_layout_fallback_groups(cfg, df, plan_indexes)), + ) + ) + return plans + + +def _split_large_precomputed_layout_group( + cfg: _LayoutPlanningConfig, + df: pd.DataFrame, + host_key: str, + _layout_key: str, + indexes: list[int], +) -> list[list[int]]: + adv = cfg.adv + if not adv.max_exact_host_pages or len(indexes) <= adv.max_exact_host_pages: + return [indexes] + if adv.large_host_mode == "standalone": + return [] + + samples: list[dict[str, Any]] = [] + for idx in indexes: + html_text = _coerce_html(df.iloc[idx].get(cfg.html_col, "")) + if not html_text.strip(): + continue + sample: dict[str, Any] = {"track_id": str(idx), "html": html_text} + if adv.large_host_mode == "feature_hash": + try: + feature = cfg.web_bindings.get_feature(html_text) if cfg.web_bindings else None + except Exception as exc: # noqa: BLE001 + logger.debug("Dripper precomputed layout feature extraction failed for row {}: {}", idx, exc) + continue + if feature is None: + continue + sample["feature"] = feature + samples.append(sample) + fingerprint_fn = ( + (lambda sample: _layout_feature_fingerprint(sample.get("feature"))) + if adv.large_host_mode == "feature_hash" + else (lambda sample: _layout_dom_path_fingerprint(str(sample.get("html") or ""))) + ) + return _build_fingerprint_groups(cfg, df, host_key, samples, fingerprint_fn=fingerprint_fn) + + +def _row_host_key(cfg: _LayoutPlanningConfig, row: pd.Series) -> str: + if cfg.host_col and cfg.host_col in row: + host_key = _url_host_key(row.get(cfg.host_col)) + if host_key: + return host_key + return _url_host_key(row.get(cfg.url_col) if cfg.url_col else None) + + +def _row_layout_id_key(cfg: _LayoutPlanningConfig, row: pd.Series) -> str: + if not cfg.layout_id_col: + return "" + value = row.get(cfg.layout_id_col) + text = "" if _is_missing(value) else str(value).strip() + if not text or text in {"-1", "-2"} or text.endswith(("_-1", "_-2")): + return "" + return text + + +def _build_layout_groups_for_host_samples( + cfg: _LayoutPlanningConfig, + df: pd.DataFrame, + host_key: str, + samples: list[dict[str, Any]], +) -> list[list[int]]: + if len(samples) < cfg.min_cluster_size: + return [] + + # Large-host fast path: skip clustering, use fingerprint bucketing instead. + adv = cfg.adv + if adv.max_exact_host_pages and len(samples) > adv.max_exact_host_pages: + if adv.large_host_mode == "feature_hash": + fingerprint_fn = lambda sample: _layout_feature_fingerprint(sample.get("feature")) # noqa: E731 + elif adv.large_host_mode == "dom_path_hash": + fingerprint_fn = lambda sample: _layout_dom_path_fingerprint(str(sample.get("html") or "")) # noqa: E731 + else: + return [] + return _build_fingerprint_groups(cfg, df, host_key, samples, fingerprint_fn=fingerprint_fn) + + try: + clustered_samples, _layout_ids = cfg.web_bindings.cluster_html_struct( + samples, + threshold=cfg.layout_cluster_threshold, + ) + except Exception as exc: # noqa: BLE001 + logger.debug("Dripper layout clustering failed for host {}: {}", host_key, exc) + return [] + + if not clustered_samples: + return [] + return _build_clustered_host_groups(cfg, df, host_key, clustered_samples) + + +def _build_clustered_host_groups( + cfg: _LayoutPlanningConfig, + df: pd.DataFrame, + _host_key: str, + clustered_samples: list[dict[str, Any]], +) -> list[list[int]]: + max_layer_n = int( + next((s.get("max_layer_n") for s in clustered_samples if int(s.get("layout_id", -1)) >= 0), None) or 5 + ) + exemplars_by_layout: dict[int, list[dict[str, Any]]] = defaultdict(list) + for sample in clustered_samples: + layout_id = int(sample.get("layout_id", -1)) + if layout_id < 0: + continue + if len(exemplars_by_layout[layout_id]) < _MAX_EXEMPLARS_PER_LAYOUT: + exemplars_by_layout[layout_id].append(sample) + + by_layout: dict[tuple[int, str], list[int]] = defaultdict(list) + for sample in clustered_samples: + layout_id = _assign_layout_by_exemplar_similarity(cfg, sample.get("feature"), exemplars_by_layout, max_layer_n) + if layout_id < 0: + continue + row_idx = int(sample["track_id"]) + _row = df.iloc[row_idx] + signature_key = _layout_page_signature_key( + _row.get(cfg.url_col) if cfg.url_col else None, + _row.get(_DRIPPER_ITEM_COUNT_COL), + cfg.adv.page_signature_mode, + ) + by_layout[(layout_id, signature_key)].append(row_idx) + groups: list[list[int]] = [] + for (_layout_id, _signature_key), indexes in sorted(by_layout.items()): + if len(indexes) >= cfg.min_cluster_size: + groups.append(sorted(indexes)) + return groups + + +def _build_failed_layout_fallback_groups( + cfg: _LayoutPlanningConfig, df: pd.DataFrame, indexes: list[int] +) -> list[list[int]]: + mode = cfg.adv.failed_layout_fallback_signature_mode + if mode == "none" or len(indexes) < cfg.min_cluster_size: + return [] + + children = _split_fallback_groups_by_signature(cfg, df, [indexes], mode) + parent_set = set(indexes) + return [child for child in children if set(child) != parent_set] + + +def _assign_layout_by_exemplar_similarity( + cfg: _LayoutPlanningConfig, + feature: object, + exemplars_by_layout: dict[int, list[dict[str, Any]]], + max_layer_n: int, +) -> int: + for layout_id, exemplars in sorted(exemplars_by_layout.items()): + for exemplar in exemplars: + try: + score = cfg.web_bindings.similarity(feature, exemplar.get("feature"), max_layer_n) + except Exception as exc: # noqa: BLE001 + logger.debug("Dripper layout similarity failed for layout {}: {}", layout_id, exc) + continue + if score is not None and score >= cfg.layout_cluster_threshold: + return layout_id + return -2 + + +def _build_fingerprint_groups( + cfg: _LayoutPlanningConfig, + df: pd.DataFrame, + _host_key: str, + samples: list[dict[str, Any]], + *, + fingerprint_fn: Callable[[dict[str, Any]], str], +) -> list[list[int]]: + by_fingerprint: dict[str, list[int]] = defaultdict(list) + for sample in samples: + by_fingerprint[fingerprint_fn(sample)].append(int(sample["track_id"])) + + groups: list[list[int]] = [] + for _fingerprint, indexes in sorted(by_fingerprint.items(), key=lambda item: (min(item[1]), item[0])): + by_signature: dict[str, list[int]] = defaultdict(list) + for row_idx in indexes: + _row = df.iloc[row_idx] + signature_key = _layout_page_signature_key( + _row.get(cfg.url_col) if cfg.url_col else None, + _row.get(_DRIPPER_ITEM_COUNT_COL), + cfg.adv.page_signature_mode, + ) + by_signature[signature_key].append(row_idx) + for _signature_key, signature_indexes in sorted(by_signature.items()): + if len(signature_indexes) < cfg.min_cluster_size: + continue + groups.append(sorted(signature_indexes)) + return groups + + +def _split_fallback_groups_by_signature( + cfg: _LayoutPlanningConfig, + df: pd.DataFrame, + groups: list[list[int]], + mode: str, +) -> list[list[int]]: + split_groups: list[list[int]] = [] + for group in groups: + low_card_query_keys: set[str] = set() + if "url_low_card_query_shape" in mode and cfg.url_col: + low_card_query_keys = _low_card_query_value_keys([df.iloc[row_idx].get(cfg.url_col) for row_idx in group]) + by_signature: dict[str, list[int]] = defaultdict(list) + use_low_card = "url_low_card_query_shape" in mode + for row_idx in group: + row = df.iloc[row_idx] + url = row.get(cfg.url_col) if cfg.url_col else None + if use_low_card: + signature_key = _layout_page_signature_key_with_low_card_queries( + url, row.get(_DRIPPER_ITEM_COUNT_COL), mode, low_card_query_keys + ) + else: + signature_key = _layout_page_signature_key(url, row.get(_DRIPPER_ITEM_COUNT_COL), mode) + by_signature[signature_key].append(row_idx) + for _signature, indexes in sorted(by_signature.items(), key=lambda item: (min(item[1]), item[0])): + if len(indexes) >= cfg.min_cluster_size: + split_groups.append(sorted(indexes)) + return split_groups + + +_QUERY_POSITIONS_THRESHOLD = 8 +_QUERY_POSITIONS_HIGH = 4 +_QUERY_POSITIONS_LOW = 3 + +_ColSpec = tuple[str | None, str] + + +@dataclass +class _SelectorState: + selected: list[int] + selected_set: set[int] + count: int + url_col: str | None + item_count_col: str + + def add(self, idx: int) -> None: + if len(self.selected) >= self.count or idx in self.selected_set: + return + self.selected.append(idx) + self.selected_set.add(idx) + + def is_full(self) -> bool: + return len(self.selected) >= self.count + + +def _select_by_signature( + df: pd.DataFrame, + indexes: list[int], + *, + signature_mode: str, + state: _SelectorState, +) -> bool: + url_col = state.url_col + item_count_col = state.item_count_col + low_card_query_keys: set[str] = set() + if "url_low_card_query_shape" in signature_mode and url_col: + low_card_query_keys = _low_card_query_value_keys([df.iloc[idx].get(url_col) for idx in indexes]) + by_signature: dict[str, list[int]] = defaultdict(list) + for idx in indexes: + row = df.iloc[idx] + signature_key = _layout_page_signature_key_with_low_card_queries( + row.get(url_col) if url_col else None, + row.get(item_count_col) if item_count_col in row else None, + signature_mode, + low_card_query_keys, + ) + by_signature[signature_key].append(idx) + signature_groups = sorted( + by_signature.values(), + key=lambda group: (-len(group), _validation_sample_key(df.iloc[group[0]], group[0], url_col, item_count_col)), + ) + for group in signature_groups: + for idx in _select_validation_indexes(df, sorted(group), 1, (url_col, item_count_col), signature_mode="none"): + state.add(idx) + break + if state.is_full(): + return True + return False + + +def _select_by_url( + df: pd.DataFrame, + indexes: list[int], + *, + state: _SelectorState, +) -> None: + url_col = state.url_col + count = state.count + query_value_rows: dict[str, list[tuple[str, int]]] = defaultdict(list) + for idx in indexes: + url_text = str(df.iloc[idx].get(url_col) or "") + for key, value in _validation_query_values(url_text): + query_value_rows[key].append((value, idx)) + for key in sorted(query_value_rows): + entries = sorted(query_value_rows[key]) + query_positions = _QUERY_POSITIONS_HIGH if count >= _QUERY_POSITIONS_THRESHOLD else _QUERY_POSITIONS_LOW + for position in _spread_positions(len(entries), min(count, query_positions)): + state.add(entries[position][1]) + if state.is_full(): + return + + url_sorted = sorted(indexes, key=lambda idx: (str(df.iloc[idx].get(url_col) or ""), idx)) + for position in _spread_positions(len(url_sorted), count): + state.add(url_sorted[position]) + if state.is_full(): + return + + +def _select_validation_indexes( + df: pd.DataFrame, + indexes: list[int], + count: int, + cols: _ColSpec, + *, + signature_mode: str = "none", +) -> list[int]: + url_col, item_count_col = cols + if count <= 0 or not indexes: + return [] + if count >= len(indexes): + return list(indexes) + if count == 1: + return [indexes[-1]] + + state = _SelectorState( + selected=[], selected_set=set(), count=count, url_col=url_col, item_count_col=item_count_col + ) + + if ( + signature_mode + and signature_mode != "none" + and _select_by_signature(df, indexes, signature_mode=signature_mode, state=state) + ): + return sorted(state.selected) + + state.add(indexes[0]) + state.add(indexes[-1]) + + item_sorted = sorted(indexes, key=lambda idx: (_coerce_item_count(df.iloc[idx].get(item_count_col)), idx)) + state.add(item_sorted[0]) + state.add(item_sorted[-1]) + + if url_col: + _select_by_url(df, indexes, state=state) + if state.is_full(): + return sorted(state.selected) + + remaining = [idx for idx in indexes if idx not in state.selected_set] + remaining.sort(key=lambda idx: _validation_sample_key(df.iloc[idx], idx, url_col, item_count_col)) + for idx in remaining: + state.add(idx) + if state.is_full(): + break + return sorted(state.selected) + + +def _spread_positions(length: int, count: int) -> list[int]: + if length <= 0 or count <= 0: + return [] + if count >= length: + return list(range(length)) + if count == 1: + return [length // 2] + return sorted({round(slot * (length - 1) / (count - 1)) for slot in range(count)}) + + +def _validation_sample_key( + row: pd.Series, + row_index: int, + url_col: str | None, + item_count_col: str, +) -> tuple[int, int]: + import hashlib + + url_text = str(row.get(url_col) or "") if url_col else "" + item_count = str(row.get(item_count_col) or "") + payload = f"{url_text}\0{item_count}\0{row_index}".encode("utf-8", errors="replace") + digest = hashlib.blake2b(payload, digest_size=8).digest() + return int.from_bytes(digest, byteorder="big", signed=False), row_index diff --git a/nemo_curator/stages/text/experimental/dripper/gpu_layout_clustering.py b/nemo_curator/stages/text/experimental/dripper/gpu_layout_clustering.py new file mode 100644 index 0000000000..a2ad75a54d --- /dev/null +++ b/nemo_curator/stages/text/experimental/dripper/gpu_layout_clustering.py @@ -0,0 +1,166 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU-accelerated layout clustering using cuML DBSCAN + cupy cosine similarity. + +Drop-in replacement for llm-webkit's cluster_html_struct (same inputs/outputs). +Falls back to sklearn when CUDA unavailable or cluster < GPU_MIN_SIZE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from loguru import logger + +if TYPE_CHECKING: + from collections.abc import Callable + from types import ModuleType + + import cupy as cp + +GPU_MIN_SIZE = 200 + + +def _gpu_available() -> bool: + try: + import cupy as cp + + _ = cp.cuda.Device(0).compute_capability # raises if no GPU + except Exception: # noqa: BLE001 - any import/runtime error means no usable GPU + return False + return True + + +def _feature_matrices(features_vec: list[dict]) -> tuple[np.ndarray, np.ndarray]: + tags = np.stack([f["tags"] for f in features_vec]).astype(np.float32) + attrs = np.stack([f["attrs"] for f in features_vec]).astype(np.float32) + return tags, attrs + + +def _cosine_similarity_gpu(x: cp.ndarray) -> cp.ndarray: + import cupy as cp + + norms = cp.linalg.norm(x, axis=1, keepdims=True) + norms = cp.maximum(norms, 1e-10) + x_norm = x / norms + return x_norm @ x_norm.T # (N, D) @ (D, N) -> (N, N) cosine similarity + + +def cluster_html_struct_gpu( + sampled_list: list[dict], + threshold: float = 0.95, + gpu_min_size: int = GPU_MIN_SIZE, + tag_weight: float = 0.7, +) -> tuple[list[dict], list[int]]: + n = len(sampled_list) + + import llm_web_kit.html_layout.html_layout_cosin as _cosin_mod + from llm_web_kit.html_layout.html_layout_cosin import cluster_html_struct as _sklearn_cluster + + use_gpu = n >= gpu_min_size and _gpu_available() + + if not use_gpu: + logger.debug( + "cluster_html_struct_gpu: n={} < gpu_min_size={} or no GPU — using sklearn", + n, + gpu_min_size, + ) + return _sklearn_cluster(sampled_list, threshold) + + logger.info("cluster_html_struct_gpu: n={} pages — using GPU (cuML DBSCAN + cupy cosine)", n) + try: + return _cluster_gpu(sampled_list, threshold, tag_weight, _cosin_mod) + except Exception as exc: # noqa: BLE001 - fall back to sklearn on any GPU failure + logger.warning("GPU clustering failed ({}) — falling back to sklearn", exc) + return _sklearn_cluster(sampled_list, threshold) + + +def _cluster_gpu( + sampled_list: list[dict], + threshold: float, + tag_weight: float, + cosin_mod: ModuleType, +) -> tuple[list[dict], list[int]]: + import cuml.cluster + import cupy as cp + + features = [s["feature"] for s in sampled_list] + _simp_features_fn = _get_simp_features(cosin_mod) + layer_n, features_vec = _simp_features_fn(features) + tags, attrs = _feature_matrices(features_vec) + + tags_gpu = cp.asarray(tags) + attrs_gpu = cp.asarray(attrs) + tag_sim = _cosine_similarity_gpu(tags_gpu) + attr_sim = _cosine_similarity_gpu(attrs_gpu) + + attr_norms = cp.linalg.norm(attrs_gpu, axis=1) + no_attr = attr_norms == 0 + sim_matrix = tag_weight * tag_sim + (1 - tag_weight) * attr_sim + if cp.any(no_attr): + sim_matrix[no_attr, :] = tag_sim[no_attr, :] + sim_matrix[:, no_attr] = tag_sim[:, no_attr] + + dist_matrix = 1.0 - cp.clip(sim_matrix, 0, 1) + eps = float(1.0 - threshold) + dist_np = cp.asnumpy(dist_matrix) + + try: + dbscan = cuml.cluster.DBSCAN( + eps=eps, + min_samples=2, + metric="precomputed", + output_type="numpy", + ) + layout_ids = dbscan.fit_predict(dist_np) + except Exception as exc: # noqa: BLE001 + logger.debug("cuML DBSCAN precomputed failed ({}), using sklearn", exc) + layout_ids = _sklearn_dbscan(dist_np, eps) + + layout_ids = [int(x) for x in layout_ids] + + success = [] + for idd, sample in zip(layout_ids, sampled_list, strict=False): + sample["layout_id"] = idd + sample["max_layer_n"] = layer_n + success.append(sample) + + n_clusters = len({x for x in layout_ids if x >= 0}) + n_noise = sum(1 for x in layout_ids if x < 0) + logger.info("cluster_html_struct_gpu: n={} → {} clusters ({} noise)", len(sampled_list), n_clusters, n_noise) + return success, list(set(layout_ids)) + + +def _get_simp_features(cosin_mod: ModuleType) -> Callable: + # llm-webkit's __simp_features is module-private; Python mangles it to ___simp_features. + # We look up both forms so upstream renames surface immediately rather than silently failing. + for name in ("_html_layout_cosin__simp_features", "__simp_features", "simp_features"): + fn = getattr(cosin_mod, name, None) + if callable(fn): + return fn + msg = ( + "Could not find the feature-vectorization helper (__simp_features) in " + "llm_web_kit.html_layout.html_layout_cosin; the GPU clustering path needs it. " + "The llm_web_kit internal API may have changed." + ) + raise RuntimeError(msg) + + +def _sklearn_dbscan(dist_matrix: np.ndarray, eps: float) -> list[int]: + from sklearn.cluster import DBSCAN + + clustering = DBSCAN(eps=eps, min_samples=2, metric="precomputed") + return clustering.fit_predict(dist_matrix).tolist() diff --git a/nemo_curator/stages/text/experimental/dripper/layout_template.py b/nemo_curator/stages/text/experimental/dripper/layout_template.py new file mode 100644 index 0000000000..1ac45b188e --- /dev/null +++ b/nemo_curator/stages/text/experimental/dripper/layout_template.py @@ -0,0 +1,900 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DripperHTMLLayoutTemplateStage: layout clustering + LBP template propagation for CC-scale HTML extraction.""" + +from __future__ import annotations + +import asyncio +import json +import time +from dataclasses import dataclass, field, replace +from typing import TYPE_CHECKING, Any, Literal + +import pandas as pd +from loguru import logger + +from nemo_curator.models.client.llm_client import GenerationConfig +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.text.experimental.dripper._layout_planning import ( + _LAYOUT_PAGE_SIGNATURE_MODES, + DripperLayoutAdvancedConfig, + _build_failed_layout_fallback_groups, + _build_layout_group_plans, + _coerce_optional_float, + _coerce_positive_int, + _item_id_response, + _labels_to_webkit_response, + _LayoutGroupPlan, + _LayoutPlanningConfig, + _select_validation_indexes, + _split_fallback_groups_by_signature, + _token_f1, +) +from nemo_curator.stages.text.experimental.dripper.stage import ( + _DRIPPER_EMPTY_INPUT_COL, + _DRIPPER_LAYOUT_FINALIZED_COL, + _DRIPPER_NEEDS_LLM_COL, + _DRIPPER_PRIMARY_ERROR_COL, + _DRIPPER_PROMPT_COL, + _STRUCTURED_OUTPUT_MODES, + _append_warning, + _apply_fallback_extraction, + _coerce_html, + _coerce_optional_str, + _coerce_usage_int, + _DripperInferenceResult, + _DripperPostResult, + _is_empty_document_error, + _item_ids_in_html, + _LLMWebKitBindings, + _load_llm_web_kit_bindings, + _load_mineru_html_bindings, + _MinerUHTMLBindings, + _numeric_series_or_zero, + _query_dripper_model, + _rebuild_batch, + _run_dripper_health_check, + _sanitize_case_output_html, + _with_structured_output_config, +) +from nemo_curator.stages.text.experimental.translation.utils.async_utils import run_async_safe +from nemo_curator.tasks import DocumentBatch + +if TYPE_CHECKING: + from nemo_curator.backends.base import WorkerMetadata + from nemo_curator.models.client.llm_client import AsyncLLMClient + +_DRIPPER_OUTPUT_HTML_COL = "dripper_html" +_DRIPPER_OUTPUT_CONTENT_COL = "dripper_content" +_DRIPPER_RAW_RESPONSE_COL = "dripper_response" +_DRIPPER_PREPROCESS_TIME_COL = "dripper_preprocess_time_s" +_DRIPPER_INFERENCE_TIME_COL = "dripper_inference_time_s" +_DRIPPER_POSTPROCESS_TIME_COL = "dripper_postprocess_time_s" +_DRIPPER_TOTAL_TIME_COL = "dripper_time_s" +_DRIPPER_ERROR_COL = "dripper_error" +_DRIPPER_WARNING_COL = "dripper_warning" +_DRIPPER_ITEM_COUNT_COL = "dripper_item_count" +_DRIPPER_REQUEST_MAX_TOKENS_COL = "dripper_request_max_tokens" +_DRIPPER_PROMPT_TOKENS_COL = "dripper_prompt_tokens" +_DRIPPER_COMPLETION_TOKENS_COL = "dripper_completion_tokens" +_DRIPPER_TOTAL_TOKENS_COL = "dripper_total_tokens" +_DRIPPER_SIMPLIFIED_HTML_COL = "dripper_simplified_html" +_DRIPPER_MAPPED_HTML_COL = "dripper_mapped_html" + +_LAYOUT_TEMPLATE_LARGE_HOST_MODES = {"standalone", "feature_hash", "dom_path_hash"} +_LAYOUT_TEMPLATE_PROPAGATION_TARGET_MODES = {"raw_html", "mapped_item_ids"} + + +@dataclass(frozen=True) +class _LayoutTemplateRowResult: + raw_response: str = "" + inference_time_s: float = 0.0 + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + main_html: str = "" + main_content: Any = "" + postprocess_time_s: float = 0.0 + error: str = "" + warning: str = "" + primary_error: str = "" + deferred_llm: bool = False + layout_finalized: bool = True + layout_cluster: str = "" + layout_representative: bool = False + layout_propagated: bool = False + layout_propagation_success: bool = False + layout_fallback_llm: bool = False + layout_standalone_llm: bool = False + layout_pending_propagation: bool = False + layout_mapping_json: str = "" + + +@dataclass(frozen=True) +class _LayoutGroupOutcome: + results: dict[int, _LayoutTemplateRowResult] + accepted: bool = True + failure_reason: str = "" + + +@dataclass(frozen=True) +class _LayoutProcessContext: + df: pd.DataFrame + semaphore: asyncio.Semaphore + propagation_semaphore: asyncio.Semaphore + inference_cache: _InferenceCache + inference_cache_lock: asyncio.Lock + needs_llm: list[bool] + + +_InferenceCache = dict[tuple[str, int], asyncio.Task[_DripperInferenceResult]] + + +def _inference_token_fields(r: _DripperInferenceResult) -> dict[str, object]: + return { + "raw_response": r.raw_response, + "inference_time_s": r.inference_time_s, + "prompt_tokens": r.prompt_tokens, + "completion_tokens": r.completion_tokens, + "total_tokens": r.total_tokens, + } + + +@dataclass(kw_only=True) +class DripperHTMLLayoutTemplateStage(ProcessingStage[DocumentBatch, DocumentBatch]): + name: str = "DripperHTMLLayoutTemplateStage" + client: AsyncLLMClient | None + model_name: str + html_col: str = "html" + url_col: str | None = "url" + host_col: str | None = None + layout_id_col: str | None = None + generation_config: GenerationConfig | None = None + structured_output_mode: Literal["none", "structured_outputs", "guided_regex"] = "none" + max_concurrent_requests: int = 64 + fallback: Literal["trafilatura", "bypass", "empty"] = "trafilatura" + output_format: str = "mm_md" + keep_intermediate: bool = False + layout_cluster_threshold: float = 0.95 + layout_template_min_cluster_size: int = 2 + layout_template_fallback_llm: bool = True + layout_template_require_success: bool = True + layout_template_max_selected_item_ratio: float | None = 0.50 + layout_template_more_noise_enable: bool = True + layout_template_validation_rows: int = 0 + layout_template_validation_min_content_f1: float = 0.98 + layout_template_large_cluster_validation_rows: int = 0 + layout_template_large_cluster_min_size: int = 0 + layout_template_propagation_target: Literal["raw_html", "mapped_item_ids"] = "raw_html" + layout_template_min_main_html_sim: float | None = None + layout_template_min_content_length_ratio: float | None = None + layout_template_max_content_length_ratio: float | None = None + dynamic_classid_similarity_threshold: float = 0.85 + layout_host_single_cluster_min_pages: int = 0 + layout_host_single_cluster_max_pages: int = 0 + layout_max_exact_host_pages: int = 0 + layout_large_host_mode: Literal["standalone", "feature_hash", "dom_path_hash"] = "standalone" + layout_propagation_concurrency: int = 32 + layout_representative_candidates: int = 1 + layout_defer_fallback_llm: bool = False + layout_defer_propagation: bool = False + layout_failed_host_fallback_signature_mode: str = "none" + layout_failed_layout_fallback_signature_mode: str = "none" + layout_page_signature_mode: str = "none" + layout_validation_signature_mode: str = "none" + health_check: bool = False + worker_count: int | None = None + + _bindings: _MinerUHTMLBindings | None = field(init=False, repr=False, default=None) + _web_bindings: _LLMWebKitBindings | None = field(init=False, repr=False, default=None) + _fallback_handler: Any = field(init=False, repr=False, default=None) + _initialized: bool = field(init=False, repr=False, default=False) + + @property + def _planning_cfg(self) -> _LayoutPlanningConfig: + adv = DripperLayoutAdvancedConfig(host_single_cluster_min_pages=self.layout_host_single_cluster_min_pages, host_single_cluster_max_pages=self.layout_host_single_cluster_max_pages, max_exact_host_pages=self.layout_max_exact_host_pages, large_host_mode=self.layout_large_host_mode, propagation_concurrency=self.layout_propagation_concurrency, representative_candidates=self.layout_representative_candidates, defer_fallback_llm=self.layout_defer_fallback_llm, defer_propagation=self.layout_defer_propagation, failed_host_fallback_signature_mode=self.layout_failed_host_fallback_signature_mode, failed_layout_fallback_signature_mode=self.layout_failed_layout_fallback_signature_mode, page_signature_mode=self.layout_page_signature_mode, validation_signature_mode=self.layout_validation_signature_mode) # fmt: skip + return _LayoutPlanningConfig(html_col=self.html_col, url_col=self.url_col, host_col=self.host_col, layout_id_col=self.layout_id_col, layout_cluster_threshold=self.layout_cluster_threshold, min_cluster_size=self.layout_template_min_cluster_size, adv=adv, web_bindings=self._web_bindings) # fmt: skip + + def __post_init__(self) -> None: + def _req(cond: bool, msg: str) -> None: + if not cond: + raise ValueError(msg) + + def _enum(val: object, valid: set, name: str) -> None: + if val not in valid: + msg = f"{name} must be one of {sorted(valid)}" + raise ValueError(msg) + + _req(self.client is not None, "DripperHTMLLayoutTemplateStage requires a non-None 'client' (AsyncLLMClient)") + self.model_name = self.model_name.strip() + _req(bool(self.model_name), "DripperHTMLLayoutTemplateStage requires a non-empty 'model_name'") + _req(self.max_concurrent_requests > 0, "max_concurrent_requests must be positive") + min_r = self.layout_template_min_content_length_ratio + max_r = self.layout_template_max_content_length_ratio + _req(0.0 < self.layout_cluster_threshold <= 1.0, "layout_cluster_threshold must be in (0, 1]") + _req(self.layout_template_min_cluster_size > 1, "layout_template_min_cluster_size must be greater than 1") + _max_sir = self.layout_template_max_selected_item_ratio + _req(_max_sir is None or 0.0 < _max_sir <= 1.0, "layout_template_max_selected_item_ratio must be in (0, 1] when set") # fmt: skip + _req(self.layout_representative_candidates > 0, "advanced.representative_candidates must be positive") + _min_sim = self.layout_template_min_main_html_sim + _req(_min_sim is None or 0.0 <= _min_sim <= 1.0, "layout_template_min_main_html_sim must be in [0, 1] when set") # fmt: skip + _f1 = self.layout_template_validation_min_content_f1 + _req(0.0 <= _f1 <= 1.0, "layout_template_validation_min_content_f1 must be in [0, 1]") + _req(self.dynamic_classid_similarity_threshold > 0, "dynamic_classid_similarity_threshold must be positive") + _req(self.layout_template_validation_rows >= 0, "layout_template_validation_rows must be non-negative") + _lcvr = self.layout_template_large_cluster_validation_rows + _req(_lcvr >= 0, "layout_template_large_cluster_validation_rows must be non-negative") + _lcms = self.layout_template_large_cluster_min_size + _req(_lcms >= 0, "layout_template_large_cluster_min_size must be non-negative") + _req(min_r is None or min_r >= 0, "layout_template_min_content_length_ratio must be non-negative when set") + _req(max_r is None or max_r >= 0, "layout_template_max_content_length_ratio must be non-negative when set") + _req(min_r is None or max_r is None or min_r <= max_r, "layout_template_min_content_length_ratio must be <= layout_template_max_content_length_ratio") # fmt: skip + _enum(self.layout_template_propagation_target, _LAYOUT_TEMPLATE_PROPAGATION_TARGET_MODES, "layout_template_propagation_target") # fmt: skip + for _val, _name in [ + (self.layout_validation_signature_mode, "advanced.validation_signature_mode"), + (self.layout_page_signature_mode, "advanced.page_signature_mode"), + (self.layout_failed_host_fallback_signature_mode, "advanced.failed_host_fallback_signature_mode"), + (self.layout_failed_layout_fallback_signature_mode, "advanced.failed_layout_fallback_signature_mode"), + ]: + _enum(_val, _LAYOUT_PAGE_SIGNATURE_MODES, _name) + _enum(self.layout_large_host_mode, _LAYOUT_TEMPLATE_LARGE_HOST_MODES, "advanced.large_host_mode") + _enum(self.structured_output_mode, _STRUCTURED_OUTPUT_MODES, "structured_output_mode") + _min_p, _max_p = self.layout_host_single_cluster_min_pages, self.layout_host_single_cluster_max_pages + _req(_min_p >= 0, "advanced.host_single_cluster_min_pages must be non-negative") + _req(_max_p >= 0, "advanced.host_single_cluster_max_pages must be non-negative") + _req(_max_p == 0 or _min_p <= _max_p, "advanced.host_single_cluster_min_pages must be <= max_pages when max is set") # fmt: skip + _req(self.layout_max_exact_host_pages >= 0, "advanced.max_exact_host_pages must be non-negative") + _req(self.layout_propagation_concurrency > 0, "advanced.propagation_concurrency must be positive") + _req(self.worker_count is None or self.worker_count > 0, "worker_count must be positive when set") + + def num_workers(self) -> int | None: + return self.worker_count + + def inputs(self) -> tuple[list[str], list[str]]: + return ["data"], [ + self.html_col, + _DRIPPER_RAW_RESPONSE_COL, + _DRIPPER_PREPROCESS_TIME_COL, + _DRIPPER_WARNING_COL, + _DRIPPER_ITEM_COUNT_COL, + _DRIPPER_REQUEST_MAX_TOKENS_COL, + _DRIPPER_SIMPLIFIED_HTML_COL, + _DRIPPER_MAPPED_HTML_COL, + _DRIPPER_PROMPT_COL, + _DRIPPER_NEEDS_LLM_COL, + _DRIPPER_PRIMARY_ERROR_COL, + _DRIPPER_EMPTY_INPUT_COL, + ] + + def outputs(self) -> tuple[list[str], list[str]]: + columns = [ + _DRIPPER_OUTPUT_HTML_COL, + _DRIPPER_OUTPUT_CONTENT_COL, + _DRIPPER_RAW_RESPONSE_COL, + _DRIPPER_INFERENCE_TIME_COL, + _DRIPPER_POSTPROCESS_TIME_COL, + _DRIPPER_TOTAL_TIME_COL, + _DRIPPER_ERROR_COL, + _DRIPPER_WARNING_COL, + _DRIPPER_PROMPT_TOKENS_COL, + _DRIPPER_COMPLETION_TOKENS_COL, + _DRIPPER_TOTAL_TOKENS_COL, + "dripper_layout_cluster", + "dripper_layout_representative", + "dripper_layout_propagated", + "dripper_layout_propagation_success", + "dripper_layout_fallback_llm", + "dripper_layout_standalone_llm", + _DRIPPER_LAYOUT_FINALIZED_COL, + ] + if self.layout_defer_propagation: + columns.extend(["dripper_layout_pending_propagation", "dripper_layout_mapping_json"]) + if self.layout_defer_fallback_llm: + columns += [_DRIPPER_SIMPLIFIED_HTML_COL, _DRIPPER_MAPPED_HTML_COL, _DRIPPER_PROMPT_COL, _DRIPPER_NEEDS_LLM_COL, _DRIPPER_PRIMARY_ERROR_COL, _DRIPPER_EMPTY_INPUT_COL] # fmt: skip + if self.keep_intermediate and not self.layout_defer_fallback_llm: + columns.extend([_DRIPPER_SIMPLIFIED_HTML_COL, _DRIPPER_MAPPED_HTML_COL]) + return ["data"], columns + + def setup(self, worker_metadata: WorkerMetadata | None = None) -> None: # noqa: ARG002 + if self._initialized: + return + self._bindings = _load_mineru_html_bindings() + self._web_bindings = _load_llm_web_kit_bindings() + self._fallback_handler = self._bindings.get_fallback_handler(self.fallback) + self.client.setup() # type: ignore[union-attr] + if self.health_check: + run_async_safe(lambda: _run_dripper_health_check(self.client, self.model_name, self.generation_config)) + self._initialized = True + + def process(self, batch: DocumentBatch) -> DocumentBatch: + if not self._initialized: + self.setup() + df = batch.to_pandas().copy() + if self.html_col not in df.columns: + msg = f"Input batch is missing required HTML column: {self.html_col!r}" + raise ValueError(msg) + results = run_async_safe(lambda: self._process_all_async(df)) + preprocess_times = _numeric_series_or_zero(df, _DRIPPER_PREPROCESS_TIME_COL) + inference_times = pd.Series([r.inference_time_s for r in results], index=df.index) + postprocess_times = pd.Series([r.postprocess_time_s for r in results], index=df.index) + for _col, _attr in [ + (_DRIPPER_OUTPUT_HTML_COL, "main_html"), + (_DRIPPER_OUTPUT_CONTENT_COL, "main_content"), + (_DRIPPER_RAW_RESPONSE_COL, "raw_response"), + (_DRIPPER_ERROR_COL, "error"), + (_DRIPPER_PROMPT_TOKENS_COL, "prompt_tokens"), + (_DRIPPER_COMPLETION_TOKENS_COL, "completion_tokens"), + (_DRIPPER_TOTAL_TOKENS_COL, "total_tokens"), + ("dripper_layout_cluster", "layout_cluster"), + ("dripper_layout_representative", "layout_representative"), + ("dripper_layout_propagated", "layout_propagated"), + ("dripper_layout_propagation_success", "layout_propagation_success"), + ("dripper_layout_fallback_llm", "layout_fallback_llm"), + ("dripper_layout_standalone_llm", "layout_standalone_llm"), + (_DRIPPER_LAYOUT_FINALIZED_COL, "layout_finalized"), + ]: + df[_col] = [getattr(r, _attr) for r in results] + df[_DRIPPER_INFERENCE_TIME_COL] = inference_times + df[_DRIPPER_POSTPROCESS_TIME_COL] = postprocess_times + df[_DRIPPER_TOTAL_TIME_COL] = preprocess_times + inference_times + postprocess_times + _existing_w = df.get(_DRIPPER_WARNING_COL, pd.Series([""] * len(df))).tolist() + df[_DRIPPER_WARNING_COL] = [_append_warning(str(e or ""), r.warning) for e, r in zip(_existing_w, results, strict=True)] # fmt: skip + if self.layout_defer_propagation: + df["dripper_layout_pending_propagation"] = [r.layout_pending_propagation for r in results] + df["dripper_layout_mapping_json"] = [r.layout_mapping_json for r in results] + if self.layout_defer_fallback_llm: + existing_primary_errors = df[_DRIPPER_PRIMARY_ERROR_COL].astype(str).tolist() + df[_DRIPPER_NEEDS_LLM_COL] = [r.deferred_llm for r in results] + df[_DRIPPER_PRIMARY_ERROR_COL] = [_append_warning(e, r.primary_error) for e, r in zip(existing_primary_errors, results, strict=True)] # fmt: skip + _base = [_DRIPPER_PROMPT_COL, _DRIPPER_NEEDS_LLM_COL, _DRIPPER_PRIMARY_ERROR_COL, _DRIPPER_EMPTY_INPUT_COL] + drop_cols = [] if self.layout_defer_fallback_llm else [*_base, _DRIPPER_LAYOUT_FINALIZED_COL] + if not self.keep_intermediate and not self.layout_defer_fallback_llm: + drop_cols.extend([_DRIPPER_SIMPLIFIED_HTML_COL, _DRIPPER_MAPPED_HTML_COL]) + df = df.drop(columns=[col for col in drop_cols if col in df.columns]) + _ma = [("layout_template_representative_rows", "layout_representative"), ("layout_template_propagated_rows", "layout_propagated"), ("layout_template_success_rows", "layout_propagation_success"), ("layout_template_fallback_llm_rows", "layout_fallback_llm"), ("layout_template_standalone_llm_rows", "layout_standalone_llm"), ("layout_template_deferred_llm_rows", "deferred_llm"), ("layout_template_finalized_rows", "layout_finalized")] # fmt: skip + self._log_metrics({"layout_template_rows": float(len(df))} | {k: float(sum(getattr(r, a) for r in results)) for k, a in _ma}) # fmt: skip + return _rebuild_batch(batch, df) + + async def _process_all_async(self, df: pd.DataFrame) -> list[_LayoutTemplateRowResult]: + propagation_semaphore = asyncio.Semaphore( + min(self.max_concurrent_requests, self.layout_propagation_concurrency) + ) + ctx = _LayoutProcessContext(df=df, semaphore=asyncio.Semaphore(self.max_concurrent_requests), propagation_semaphore=propagation_semaphore, inference_cache={}, inference_cache_lock=asyncio.Lock(), needs_llm=df[_DRIPPER_NEEDS_LLM_COL].astype(bool).tolist()) # fmt: skip + layout_plans = _build_layout_group_plans(self._planning_cfg, df) + grouped_indexes = {idx for plan in layout_plans for idx in plan.indexes} + + async def _handle_plan(plan_index: int, plan: _LayoutGroupPlan) -> dict[int, _LayoutTemplateRowResult]: + return await self._handle_group_attempt_async( + ctx, + plan.indexes, + f"layout-{plan_index:06d}", + plan.host_key, + plan.fallback_groups, + split_failed_host_fallback=True, + ) + + tasks: list[Any] = [_handle_plan(plan_index, plan) for plan_index, plan in enumerate(layout_plans)] + tasks.extend(self._handle_standalone_async(ctx, idx) for idx in range(len(df)) if idx not in grouped_indexes) + raw_results = await asyncio.gather(*tasks, return_exceptions=True) + results_by_index: dict[int, _LayoutTemplateRowResult] = {} + for raw_result in raw_results: + if isinstance(raw_result, BaseException): + logger.error("Dripper layout-template task failed: {}", raw_result) + continue + if isinstance(raw_result, tuple): + idx, result = raw_result + results_by_index[idx] = result + else: + results_by_index.update(raw_result) + _no_result_err = "layout template task produced no result" + return [results_by_index[idx] if idx in results_by_index else (self._defer_row(df.iloc[idx], primary_error=_no_result_err, layout_fallback_llm=True) if self.layout_defer_fallback_llm else self._fallback_row(df.iloc[idx], primary_error=_no_result_err)) for idx in range(len(df))] # fmt: skip + + async def _handle_standalone_async( + self, ctx: _LayoutProcessContext, idx: int + ) -> tuple[int, _LayoutTemplateRowResult]: + if self.layout_defer_fallback_llm: + return idx, self._defer_row( + ctx.df.iloc[idx], + layout_standalone_llm=ctx.needs_llm[idx], + primary_error="layout template standalone row", + ) + if ctx.needs_llm[idx]: + result = await self._infer_and_postprocess_row(ctx.df.iloc[idx], semaphore=ctx.semaphore, cache=ctx.inference_cache, cache_lock=ctx.inference_cache_lock, layout_standalone_llm=True) # fmt: skip + else: + result = self._fallback_row(ctx.df.iloc[idx]) + return idx, result + + async def _handle_group_attempt_async( # noqa: PLR0913 + self, + ctx: _LayoutProcessContext, + indexes: list[int], + cluster_id: str, + host_key: str, + fallback_groups: tuple[list[int], ...], + *, + split_failed_host_fallback: bool, + ) -> dict[int, _LayoutTemplateRowResult]: + outcome = await self._process_layout_group_with_status( + ctx, + indexes, + cluster_id, + emit_failure_fallback=not fallback_groups, + ) + if outcome.accepted or not fallback_groups: + return outcome.results + child_groups = list(fallback_groups) + if split_failed_host_fallback and self.layout_failed_host_fallback_signature_mode != "none": + child_groups = _split_fallback_groups_by_signature( + self._planning_cfg, ctx.df, child_groups, self.layout_failed_host_fallback_signature_mode + ) + fallback_results: dict[int, _LayoutTemplateRowResult] = {} + fallback_grouped_indexes: set[int] = set() + fallback_tasks = [self._handle_group_attempt_async(ctx, fallback_indexes, f"{cluster_id}-fallback-{fallback_index:06d}", host_key, tuple(_build_failed_layout_fallback_groups(self._planning_cfg, ctx.df, fallback_indexes)), split_failed_host_fallback=False) for fallback_index, fallback_indexes in enumerate(child_groups)] # fmt: skip + if fallback_tasks: + [fallback_results.update(gr) for gr in await asyncio.gather(*fallback_tasks)] + fallback_grouped_indexes = {idx for group in child_groups for idx in group} + standalone_tasks = [self._handle_standalone_async(ctx, idx) for idx in indexes if idx not in fallback_grouped_indexes] # fmt: skip + if standalone_tasks: + fallback_results.update(dict(await asyncio.gather(*standalone_tasks))) + return fallback_results + + async def _process_layout_group_with_status( + self, + ctx: _LayoutProcessContext, + indexes: list[int], + cluster_id: str, + *, + emit_failure_fallback: bool, + ) -> _LayoutGroupOutcome: + df = ctx.df + representative_idx, mapping_data, results, mapping_failures = await self._infer_representative_candidates( + ctx, indexes, cluster_id + ) + if mapping_data is None: + return await self._handle_mapping_failure(ctx, indexes, cluster_id, results, mapping_failures, emit_failure_fallback) # fmt: skip + if representative_idx is None: + msg = "representative_idx must not be None" + raise RuntimeError(msg) + sibling_indexes = [idx for idx in indexes if idx not in results] + validation_rows = self.layout_template_validation_rows + if ( + self.layout_template_large_cluster_validation_rows > 0 + and self.layout_template_large_cluster_min_size > 0 + and len(indexes) >= self.layout_template_large_cluster_min_size + ): + validation_rows = max(validation_rows, self.layout_template_large_cluster_validation_rows) + validation_indexes = _select_validation_indexes(df, sibling_indexes, validation_rows, (self.url_col, _DRIPPER_ITEM_COUNT_COL), signature_mode=self.layout_validation_signature_mode) # fmt: skip + remaining_indexes = [idx for idx in sibling_indexes if idx not in set(validation_indexes)] + validation_failed, validation_error = False, "" + if validation_indexes: + validation_failed, validation_error = await self._run_validation_rows_async( + ctx, validation_indexes, mapping_data, cluster_id, results + ) + if validation_failed: + logger.debug("Dripper layout validation failed for {}: {}", cluster_id, validation_error) + if not emit_failure_fallback: + return _LayoutGroupOutcome(results=results, accepted=False, failure_reason=validation_error) + sibling_outcome = await self._propagate_sibling_rows_async(ctx, remaining_indexes, mapping_data, cluster_id, results, validation_failed, validation_error) # fmt: skip + if sibling_outcome is not None: + return sibling_outcome + return _LayoutGroupOutcome(results=results) + + async def _handle_mapping_failure( # noqa: PLR0913 + self, + ctx: _LayoutProcessContext, + indexes: list[int], + cluster_id: str, + results: dict[int, _LayoutTemplateRowResult], + mapping_failures: list[str], + emit_failure_fallback: bool, + ) -> _LayoutGroupOutcome: + df = ctx.df + warning = "layout template mapping failed" + if mapping_failures: + warning = f"{warning}: {'; '.join(mapping_failures[:3])}" + if not emit_failure_fallback: + return _LayoutGroupOutcome(results=results, accepted=False, failure_reason=warning) + fallback_indexes = [idx for idx in indexes if idx not in results] + if self.layout_defer_fallback_llm: + for idx in fallback_indexes: + results[idx] = self._defer_row(df.iloc[idx], primary_error=warning, layout_cluster=cluster_id, layout_fallback_llm=True) # fmt: skip + elif self.layout_template_fallback_llm: + _fbs = [self._infer_and_postprocess_row(df.iloc[idx], semaphore=ctx.semaphore, cache=ctx.inference_cache, cache_lock=ctx.inference_cache_lock, layout_cluster=cluster_id, layout_fallback_llm=True, primary_error=warning) for idx in fallback_indexes] # fmt: skip + results.update(zip(fallback_indexes, await asyncio.gather(*_fbs), strict=True)) + else: + for idx in fallback_indexes: + results[idx] = replace( + self._fallback_row(df.iloc[idx], primary_error=warning), layout_cluster=cluster_id + ) + return _LayoutGroupOutcome(results=results, accepted=False, failure_reason=warning) + + async def _run_validation_rows_async( + self, + ctx: _LayoutProcessContext, + validation_indexes: list[int], + mapping_data: dict[str, Any], + cluster_id: str, + results: dict[int, _LayoutTemplateRowResult], + ) -> tuple[bool, str]: + _prop_coros = (self._propagate_layout_template_async(ctx.df.iloc[i], mapping_data, cluster_id, ctx.propagation_semaphore) for i in validation_indexes) # fmt: skip + _llm_coros = (self._infer_and_postprocess_row(ctx.df.iloc[i], semaphore=ctx.semaphore, cache=ctx.inference_cache, cache_lock=ctx.inference_cache_lock, layout_cluster=cluster_id, layout_fallback_llm=True, primary_error="layout template validation LLM") for i in validation_indexes) # fmt: skip + validation_propagated, validation_llm_results = await asyncio.gather(asyncio.gather(*_prop_coros), asyncio.gather(*_llm_coros)) # fmt: skip + failed, error = False, "" + for idx, propagated, llm_result in zip( + validation_indexes, validation_propagated, validation_llm_results, strict=True + ): + results[idx] = llm_result + content_f1 = _token_f1(propagated.main_content, llm_result.main_content) + failure_reasons = [] + if propagated.error: + failure_reasons.append(f"propagation_error={propagated.error[:160]}") + if content_f1 < self.layout_template_validation_min_content_f1: + failure_reasons.append(f"content_f1={content_f1:.3f}") + if failure_reasons: + failed = True + error = f"layout template validation failed: {' '.join(failure_reasons)} min={self.layout_template_validation_min_content_f1:.3f}" + return failed, error + + async def _propagate_sibling_rows_async( # noqa: PLR0913 + self, + ctx: _LayoutProcessContext, + remaining_indexes: list[int], + mapping_data: dict[str, Any], + cluster_id: str, + results: dict[int, _LayoutTemplateRowResult], + validation_failed: bool, + validation_error: str, + ) -> _LayoutGroupOutcome | None: + df = ctx.df + propagated_results: list[_LayoutTemplateRowResult] = [] + if remaining_indexes and not validation_failed: + if self.layout_defer_propagation: + for idx in remaining_indexes: + results[idx] = _LayoutTemplateRowResult( + layout_cluster=cluster_id, layout_pending_propagation=True, layout_finalized=False + ) + return _LayoutGroupOutcome(results=results) + propagated_results = list(await asyncio.gather(*(self._propagate_layout_template_async(df.iloc[idx], mapping_data, cluster_id, ctx.propagation_semaphore) for idx in remaining_indexes))) # fmt: skip + fallback_tasks: list[Any] = [] + fallback_indexes: list[int] = [] + for i, idx in enumerate(remaining_indexes): + error = ( + validation_error + if validation_failed + else (propagated_results[i].error if not validation_failed else "") + ) + propagated = None if validation_failed else propagated_results[i] + if validation_failed or (propagated is not None and propagated.error): + if self.layout_defer_fallback_llm: + results[idx] = self._defer_row(df.iloc[idx], primary_error=error, layout_cluster=cluster_id, layout_fallback_llm=True) # fmt: skip + elif self.layout_template_fallback_llm: + fallback_indexes.append(idx) + fallback_tasks.append(self._infer_and_postprocess_row(df.iloc[idx], semaphore=ctx.semaphore, cache=ctx.inference_cache, cache_lock=ctx.inference_cache_lock, layout_cluster=cluster_id, layout_fallback_llm=True, primary_error=error)) # fmt: skip + else: + results[idx] = replace( + self._fallback_row(df.iloc[idx], primary_error=error), layout_cluster=cluster_id + ) + elif propagated is not None: + results[idx] = propagated + if fallback_tasks: + fallback_results_list = await asyncio.gather(*fallback_tasks) + results.update(zip(fallback_indexes, fallback_results_list, strict=True)) + return None + + async def _infer_representative_candidates( + self, ctx: _LayoutProcessContext, indexes: list[int], cluster_id: str + ) -> tuple[int | None, dict[str, Any] | None, dict[int, _LayoutTemplateRowResult], list[str]]: + df = ctx.df + representative_indexes = self._select_representative_indexes(df, indexes) + representative_idx: int | None = None + mapping_data: dict[str, Any] | None = None + candidate_results: dict[int, _LayoutTemplateRowResult] = {} + mapping_failures: list[str] = [] + for candidate_idx in representative_indexes: + candidate_result, candidate_mapping = await self._infer_representative_and_mapping( + df.iloc[candidate_idx], ctx.semaphore, cluster_id, ctx.inference_cache, ctx.inference_cache_lock + ) + candidate_results[candidate_idx] = candidate_result + if candidate_mapping is not None: + representative_idx = candidate_idx + mapping_data = candidate_mapping + break + mapping_failures.append(f"{candidate_idx}:{candidate_result.primary_error or candidate_result.warning or 'mapping failed'}") # fmt: skip + results: dict[int, _LayoutTemplateRowResult] = {} + mapping_json_for_representative = json.dumps(mapping_data, default=str) if self.layout_defer_propagation and mapping_data is not None else "" # fmt: skip + for candidate_idx, candidate_result in candidate_results.items(): + is_rep = candidate_idx == representative_idx + results[candidate_idx] = replace(candidate_result, layout_cluster=cluster_id, layout_representative=is_rep, layout_fallback_llm=not is_rep, layout_mapping_json=mapping_json_for_representative if is_rep else "") # fmt: skip + return representative_idx, mapping_data, results, mapping_failures + + def _select_representative_indexes(self, df: pd.DataFrame, indexes: list[int]) -> list[int]: + candidates = [{"track_id": str(idx), "html": _coerce_html(df.iloc[idx].get(self.html_col, ""))} for idx in indexes] # fmt: skip + try: + rep = self._web_bindings.select_representative_html(candidates) + selected = int(rep["track_id"]) if rep is not None else indexes[0] + except Exception as exc: # noqa: BLE001 + logger.debug("Dripper representative selection failed: {}", exc) + selected = indexes[0] + if selected not in indexes: + selected = indexes[0] + result = [selected] + if self.layout_representative_candidates > 1: + result.extend(_select_validation_indexes(df, [idx for idx in indexes if idx != selected], self.layout_representative_candidates - 1, (self.url_col, _DRIPPER_ITEM_COUNT_COL))) # fmt: skip + return result + + async def _infer_representative_and_mapping( + self, + row: pd.Series, + semaphore: asyncio.Semaphore, + cluster_id: str, + inference_cache: _InferenceCache, + inference_cache_lock: asyncio.Lock, + ) -> tuple[_LayoutTemplateRowResult, dict[str, Any] | None]: + inference_result = await self._infer_row_cached(row, semaphore, inference_cache, inference_cache_lock) + started = time.perf_counter() + + def _make_fallback_result(primary_error: str, *, elapsed: float | None = None) -> _LayoutTemplateRowResult: + fb = self._fallback_and_convert(row, primary_error=primary_error) + return _LayoutTemplateRowResult(**_inference_token_fields(inference_result), main_html=fb.main_html, main_content=fb.main_content, postprocess_time_s=elapsed if elapsed is not None else fb.postprocess_time_s, error=fb.error, warning=fb.warning, primary_error=primary_error, layout_cluster=cluster_id) # fmt: skip + + if inference_result.primary_error: + return _make_fallback_result(_append_warning("", inference_result.primary_error)), None + html_text = _coerce_html(row.get(self.html_col, "")) + mapped_html = str(row.get(_DRIPPER_MAPPED_HTML_COL, "") or "") + case = self._build_case(row) + mapping_failure_reason = "" + try: + case.generate_output = self._bindings.generate_output_cls(response=inference_result.raw_response) + case = self._bindings.parse_result(case) + webkit_response = _labels_to_webkit_response(getattr(case.parse_result, "item_label", {})) + case = self._bindings.extract_main_html_single(case) + mapping_data = self._web_bindings.map_parser_cls({}).parse({"typical_raw_tag_html": mapped_html, "typical_raw_html": html_text, "llm_response": webkit_response}) # fmt: skip + if self.layout_template_require_success and mapping_data.get("typical_main_html_success") is False: + mapping_failure_reason = "typical_main_html_success=false" + mapping_data = None + except Exception as exc: # noqa: BLE001 + primary_error = str(exc) + logger.debug("Dripper representative mapping failed: {}", primary_error) + return _make_fallback_result(primary_error, elapsed=time.perf_counter() - started), None + post_result = self._convert_case(case) + warning = post_result.warning + if mapping_data is None: + primary_error = f"layout template mapping failed: {mapping_failure_reason or 'template unusable'}" + warning = _append_warning(warning, primary_error) + else: + primary_error = "" + mapping_data = dict(mapping_data) + mapping_data["_dripper_representative_content_len"] = len(str(post_result.main_content or "")) + return _LayoutTemplateRowResult(**_inference_token_fields(inference_result), main_html=post_result.main_html, main_content=post_result.main_content, postprocess_time_s=time.perf_counter() - started, error=post_result.error, warning=warning, primary_error=primary_error, layout_cluster=cluster_id), mapping_data # fmt: skip + + async def _propagate_layout_template_async( + self, + row: pd.Series, + mapping_data: dict[str, Any], + cluster_id: str, + semaphore: asyncio.Semaphore, + ) -> _LayoutTemplateRowResult: + async with semaphore: + return await asyncio.to_thread(self._propagate_layout_template, row, mapping_data, cluster_id) + + def _propagate_layout_template( + self, + row: pd.Series, + mapping_data: dict[str, Any], + cluster_id: str, + ) -> _LayoutTemplateRowResult: + started = time.perf_counter() + html_text = _coerce_html(row.get(self.html_col, "")) + mapped_html = str(row.get(_DRIPPER_MAPPED_HTML_COL, "") or "") + use_mapped_item_ids = ( + self.layout_template_propagation_target == "mapped_item_ids" and "_item_id" in mapped_html + ) + html_source = mapped_html if use_mapped_item_ids else html_text + try: + task_data = dict(mapping_data) | { + "html_source": html_source, + "dynamic_id_enable": True, + "dynamic_classid_enable": True, + "more_noise_enable": self.layout_template_more_noise_enable, + "dynamic_classid_similarity_threshold": self.dynamic_classid_similarity_threshold, + } + parts = self._web_bindings.layout_parser_cls({}).parse(task_data) + if self.layout_template_require_success and parts.get("main_html_success") is False: + raise RuntimeError(f"layout propagation similarity below threshold: {parts.get('main_html_sim')}") # noqa: TRY301, EM102 + if self.layout_template_min_main_html_sim is not None: + main_html_sim = _coerce_optional_float(parts.get("main_html_sim")) + if main_html_sim is not None and main_html_sim < self.layout_template_min_main_html_sim: + msg = f"layout propagation main_html_sim {main_html_sim:.3f} below {self.layout_template_min_main_html_sim:.3f}" + raise RuntimeError(msg) # noqa: TRY301 + main_html = str(parts.get("main_html_body") or "") + raw_response = "" + if use_mapped_item_ids: + all_item_ids = _item_ids_in_html(mapped_html) + main_item_ids = set(_item_ids_in_html(main_html)) + if not all_item_ids: + raise RuntimeError("layout propagation target mapped HTML has no item ids") # noqa: TRY301, EM101 + if not main_item_ids: + raise RuntimeError("layout propagation produced no target item ids") # noqa: TRY301, EM101 + selected_item_ratio = len(main_item_ids) / len(all_item_ids) + if ( + self.layout_template_max_selected_item_ratio is not None + and selected_item_ratio > self.layout_template_max_selected_item_ratio + ): + msg = f"layout propagation selected item ratio {selected_item_ratio:.3f} exceeds {self.layout_template_max_selected_item_ratio:.3f}" + raise RuntimeError(msg) # noqa: TRY301 + raw_response = _item_id_response(all_item_ids, main_item_ids) + post_result = self._postprocess_raw_response(row, raw_response) + else: + _case = self._build_case(row) + _case.output_data = self._bindings.output_cls(main_html=main_html) + post_result = self._convert_case(_case) + content_ratio_error = self._propagated_content_length_ratio_error(post_result.main_content, mapping_data) + if content_ratio_error: + raise RuntimeError(content_ratio_error) # noqa: TRY301 + return _LayoutTemplateRowResult(raw_response=raw_response, main_html=post_result.main_html, main_content=post_result.main_content, postprocess_time_s=time.perf_counter() - started, error=post_result.error, warning=post_result.warning, layout_cluster=cluster_id, layout_propagated=True, layout_propagation_success=not bool(post_result.error)) # fmt: skip + except Exception as exc: # noqa: BLE001 + primary_error = str(exc) + logger.debug("Dripper layout propagation failed: {}", primary_error) + fallback_result = self._fallback_and_convert(row, primary_error=primary_error) + return _LayoutTemplateRowResult(main_html=fallback_result.main_html, main_content=fallback_result.main_content, postprocess_time_s=time.perf_counter() - started, error=fallback_result.error or primary_error, warning=fallback_result.warning, primary_error=primary_error, layout_cluster=cluster_id, layout_propagated=True) # fmt: skip + + def _propagated_content_length_ratio_error(self, propagated_content: object, mapping_data: dict[str, Any]) -> str: + min_r, max_r = self.layout_template_min_content_length_ratio, self.layout_template_max_content_length_ratio + if min_r is None and max_r is None: + return "" + rep_len = _coerce_positive_int(mapping_data.get("_dripper_representative_content_len")) + if rep_len <= 0: + return "" + ratio = len(str(propagated_content or "")) / rep_len + if min_r is not None and ratio < min_r: + return f"layout propagation content length ratio {ratio:.3f} below {min_r:.3f}" + if max_r is not None and ratio > max_r: + return f"layout propagation content length ratio {ratio:.3f} exceeds {max_r:.3f}" + return "" + + async def _infer_and_postprocess_row( # noqa: PLR0913 + self, + row: pd.Series, + *, + semaphore: asyncio.Semaphore | None = None, + cache: _InferenceCache | None = None, + cache_lock: asyncio.Lock | None = None, + layout_cluster: str = "", + layout_fallback_llm: bool = False, + layout_standalone_llm: bool = False, + primary_error: str = "", + ) -> _LayoutTemplateRowResult: + if cache is None or cache_lock is None: + prompt = str(row.get(_DRIPPER_PROMPT_COL, "") or "") + row_max_tokens = _coerce_usage_int(row.get(_DRIPPER_REQUEST_MAX_TOKENS_COL, 0)) + inference_result = await self._infer_prompt(prompt, row_max_tokens, semaphore) + else: + inference_result = await self._infer_row_cached(row, semaphore, cache, cache_lock) + if inference_result.primary_error: + merged_primary = _append_warning(primary_error, inference_result.primary_error) + fb = self._fallback_and_convert(row, primary_error=merged_primary) + return _LayoutTemplateRowResult(**_inference_token_fields(inference_result), main_html=fb.main_html, main_content=fb.main_content, postprocess_time_s=fb.postprocess_time_s, error=fb.error, warning=fb.warning, primary_error=merged_primary, layout_cluster=layout_cluster, layout_fallback_llm=layout_fallback_llm, layout_standalone_llm=layout_standalone_llm) # fmt: skip + post_result = self._postprocess_raw_response(row, inference_result.raw_response) + return _LayoutTemplateRowResult(**_inference_token_fields(inference_result), main_html=post_result.main_html, main_content=post_result.main_content, postprocess_time_s=post_result.postprocess_time_s, error=post_result.error, warning=_append_warning(primary_error, post_result.warning), layout_cluster=layout_cluster, layout_fallback_llm=layout_fallback_llm, layout_standalone_llm=layout_standalone_llm) # fmt: skip + + async def _infer_row_cached( + self, + row: pd.Series, + semaphore: asyncio.Semaphore, + inference_cache: _InferenceCache, + inference_cache_lock: asyncio.Lock, + ) -> _DripperInferenceResult: + prompt = str(row.get(_DRIPPER_PROMPT_COL, "") or "") + row_max_tokens = _coerce_usage_int(row.get(_DRIPPER_REQUEST_MAX_TOKENS_COL, 0)) + if not prompt.strip(): + return _DripperInferenceResult(primary_error="empty Dripper prompt", warning="empty Dripper prompt") + key = (prompt, row_max_tokens) + async with inference_cache_lock: + task = inference_cache.get(key) + owns_request = task is None + if task is None: + task = asyncio.create_task(self._infer_prompt(prompt, row_max_tokens, semaphore)) + inference_cache[key] = task + result = await task + if owns_request: + return result + return replace(result, inference_time_s=0.0, prompt_tokens=0, completion_tokens=0, total_tokens=0) + + async def _infer_prompt( + self, + prompt: str, + row_max_tokens: int, + semaphore: asyncio.Semaphore, + ) -> _DripperInferenceResult: + if not prompt.strip(): + return _DripperInferenceResult(primary_error="empty Dripper prompt", warning="empty Dripper prompt") + async with semaphore: + started = time.perf_counter() + try: + generation_config = self.generation_config or GenerationConfig() + if row_max_tokens > 0 and generation_config.max_tokens != row_max_tokens: + generation_config = replace(generation_config, max_tokens=row_max_tokens) + generation_config = _with_structured_output_config(generation_config, prompt, self.structured_output_mode) # fmt: skip + raw_response, prompt_tokens, completion_tokens, total_tokens = await _query_dripper_model(self.client, self.model_name, [{"role": "user", "content": prompt}], generation_config) # fmt: skip + except Exception as exc: # noqa: BLE001 + error = str(exc) + logger.debug("Dripper inference failed; postprocess stage will apply fallback: {}", error) + return _DripperInferenceResult(inference_time_s=time.perf_counter() - started, primary_error=error, warning=error) # fmt: skip + return _DripperInferenceResult(raw_response=raw_response, inference_time_s=time.perf_counter() - started, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens) # fmt: skip + + def _postprocess_raw_response(self, row: pd.Series, raw_response: str) -> _DripperPostResult: + started = time.perf_counter() + case = self._build_case(row) + try: + case.generate_output = self._bindings.generate_output_cls(response=raw_response) + case = self._bindings.parse_result(case) + case = self._bindings.extract_main_html_single(case) + except Exception as exc: # noqa: BLE001 + pe = str(exc) + logger.debug("Dripper parse/extract failed, applying {} fallback: {}", self.fallback, pe) + return replace( + self._fallback_and_convert(row, primary_error=pe), postprocess_time_s=time.perf_counter() - started + ) + return replace(self._convert_case(case), postprocess_time_s=time.perf_counter() - started) + + def _fallback_row(self, row: pd.Series, *, primary_error: str = "") -> _LayoutTemplateRowResult: + r = self._fallback_and_convert(row, primary_error=_append_warning(primary_error, str(row.get(_DRIPPER_PRIMARY_ERROR_COL, "") or ""))) # fmt: skip + return _LayoutTemplateRowResult(main_html=r.main_html, main_content=r.main_content, postprocess_time_s=r.postprocess_time_s, error=r.error, warning=r.warning, primary_error=primary_error) # fmt: skip + + def _defer_row(self, row: pd.Series, *, primary_error: str = "", layout_cluster: str = "", layout_fallback_llm: bool = False, layout_standalone_llm: bool = False) -> _LayoutTemplateRowResult: # fmt: skip + nlm = bool(row.get(_DRIPPER_NEEDS_LLM_COL, False)) + return _LayoutTemplateRowResult(raw_response=str(row.get(_DRIPPER_RAW_RESPONSE_COL, "") or ""), inference_time_s=float(row.get(_DRIPPER_INFERENCE_TIME_COL, 0.0) or 0.0), prompt_tokens=_coerce_usage_int(row.get(_DRIPPER_PROMPT_TOKENS_COL, 0)), completion_tokens=_coerce_usage_int(row.get(_DRIPPER_COMPLETION_TOKENS_COL, 0)), total_tokens=_coerce_usage_int(row.get(_DRIPPER_TOTAL_TOKENS_COL, 0)), error=str(row.get(_DRIPPER_ERROR_COL, "") or ""), warning=_append_warning(str(row.get(_DRIPPER_WARNING_COL, "") or ""), primary_error), primary_error=primary_error, deferred_llm=nlm, layout_finalized=False, layout_cluster=layout_cluster, layout_fallback_llm=layout_fallback_llm and nlm, layout_standalone_llm=layout_standalone_llm and nlm) # fmt: skip + + def _build_case(self, row: pd.Series) -> object: + html_text = _coerce_html(row.get(self.html_col, "")) + url = _coerce_optional_str(row.get(self.url_col) if self.url_col else None) + case = self._bindings.case_cls(self._bindings.input_cls(raw_html=html_text, url=url)) + simplified_html = str(row.get(_DRIPPER_SIMPLIFIED_HTML_COL, "") or "") + mapped_html = str(row.get(_DRIPPER_MAPPED_HTML_COL, "") or "") + if simplified_html or mapped_html: + case.process_data = self._bindings.process_data_cls(simpled_html=simplified_html, map_html=mapped_html) + return case + + def _fallback_and_convert(self, row: pd.Series, *, primary_error: str = "") -> _DripperPostResult: + started = time.perf_counter() + case = self._build_case(row) + if bool(row.get(_DRIPPER_EMPTY_INPUT_COL, False)) or not _coerce_html(row.get(self.html_col, "")).strip(): + return _DripperPostResult(postprocess_time_s=time.perf_counter() - started, warning=_append_warning(primary_error, "empty HTML input")) # fmt: skip + fallback_result = _apply_fallback_extraction(self._bindings, self._fallback_handler, case, primary_error) + case = fallback_result[0] + if fallback_result[2]: + return _DripperPostResult(postprocess_time_s=time.perf_counter() - started, error=fallback_result[2], warning=fallback_result[1]) # fmt: skip + result = self._convert_case(case, warning=fallback_result[1]) + return replace(result, postprocess_time_s=time.perf_counter() - started) + + def _convert_case(self, case: object, *, warning: str = "") -> _DripperPostResult: + conversion_error = "" + try: + _sanitize_case_output_html(case) + case = self._bindings.convert2content(case, output_format=self.output_format) + except (TypeError, AttributeError, ValueError, RuntimeError) as exc: + conversion_error = str(exc) + logger.debug("Dripper content conversion failed: {}", conversion_error) + output_data = getattr(case, "output_data", None) + main_html = getattr(output_data, "main_html", "") if output_data is not None else "" + main_content = getattr(output_data, "main_content", "") if output_data is not None else "" + main_content = "" if main_content is None else main_content + error = "" + if conversion_error: + if _is_empty_document_error(conversion_error) and not str(main_html).strip(): + warning = _append_warning(warning, conversion_error) + else: + error = conversion_error + return _DripperPostResult(main_html=main_html, main_content=main_content, error=error, warning=warning) diff --git a/nemo_curator/stages/text/experimental/dripper/propagation_stage.py b/nemo_curator/stages/text/experimental/dripper/propagation_stage.py new file mode 100644 index 0000000000..02eafa500e --- /dev/null +++ b/nemo_curator/stages/text/experimental/dripper/propagation_stage.py @@ -0,0 +1,365 @@ +from __future__ import annotations + +import contextlib +import json +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from loguru import logger + +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.text.experimental.dripper._layout_planning import _token_f1 +from nemo_curator.stages.text.experimental.dripper.stage import ( + _coerce_html, + _convert_main_html, + _load_llm_web_kit_bindings, + _load_mineru_html_bindings, + _MinerUHTMLBindings, + _rebuild_batch, + _strip_xml_incompatible_chars, +) +from nemo_curator.tasks import DocumentBatch + +if TYPE_CHECKING: + import pandas as pd + + +_PENDING_COL = "dripper_layout_pending_propagation" +_MAPPING_COL = "dripper_layout_mapping_json" +_CLUSTER_COL = "dripper_layout_cluster" +_REPRESENTATIVE_COL = "dripper_layout_representative" + +# Number of siblings sampled to validate static-LBP trustworthiness per cluster. +_K_SAMPLE_SIBLINGS = 3 + +# Maximum HTML bytes forwarded to the content converter (guards against OOM). +_MAX_CONTENT_HTML_BYTES = 200_000 + + +@dataclass +class _StaticTrustConfig: + memo: dict[str, bool] + lbp_fn: Any # (html, mapping_data, dynamic) -> (str, str) + content_fn: Any # (main_html, url) -> (str, str) + threshold: float + + +@dataclass +class _PropagationConfig: + lbp_fn: Any # (html, mapping_data, dynamic) -> (str, str) + content_fn: Any # (main_html, url) -> (str, str) + min_ratio: float + max_ratio: float + + +def _run_lbp( + params: dict[str, Any], + html: str, + mapping_data: dict[str, Any], + dynamic: bool, + _parser_cache: dict | None = None, +) -> tuple[str, str]: + html_source = html.strip() + if not html_source: + return "", "empty_html" + try: + from llm_web_kit.main_html_parser.parser.layout_batch_parser import LayoutBatchParser + + task_data = dict(mapping_data) + if "_parsed_element_dict" in task_data: + task_data["html_element_dict"] = task_data.pop("_parsed_element_dict") + task_data["html_source"] = html_source + task_data["dynamic_id_enable"] = task_data["dynamic_classid_enable"] = dynamic + task_data["more_noise_enable"] = params.get("more_noise_enable", True) + task_data["dynamic_classid_similarity_threshold"] = params.get("dynamic_classid_similarity_threshold", 0.70) + element_dict = task_data.get("html_element_dict") + cache_key = id(element_dict) if element_dict is not None else None + if _parser_cache is not None and cache_key is not None: + if cache_key not in _parser_cache: + _parser_cache[cache_key] = LayoutBatchParser({}) + parser = _parser_cache[cache_key] + else: + parser = LayoutBatchParser({}) + parts = parser.parse(task_data) + except Exception as exc: # noqa: BLE001 + return "", f"layout_parser_error={exc!s:.200}" + main_html = str(parts.get("main_html_body") or "") + if not main_html.strip(): + if parts.get("main_html_success") is False: + return "", f"main_html_success_false sim={parts.get('main_html_sim', 'n/a')}" + return "", "layout_parser_empty_output" + return main_html, "" + + +def _run_content_convert( + bindings: _MinerUHTMLBindings, + main_html: str, + url: str, +) -> tuple[str, str]: + if len(main_html) > _MAX_CONTENT_HTML_BYTES: + main_html = main_html[:_MAX_CONTENT_HTML_BYTES] + try: + sanitized = _strip_xml_incompatible_chars(main_html) + content = _convert_main_html(bindings, sanitized, url) + return str(content or ""), "" + except Exception as exc: # noqa: BLE001 + return "", f"content_conversion_error={exc!s:.150}" + + +def _cluster_static_trustworthy( + cluster_id: object, + sample_rows: list[dict[str, Any]], + mapping_data: dict[str, Any], + cfg: _StaticTrustConfig, +) -> bool: + if mapping_data is None: + return False + key = str(cluster_id) + if key in cfg.memo: + return cfg.memo[key] + f1s: list[float] = [] + for row in sample_rows[:_K_SAMPLE_SIBLINGS]: + html = _coerce_html(row.get("html", "")) + if not html.strip(): + continue + sh, se = cfg.lbp_fn(html, mapping_data, False) + dh, de = cfg.lbp_fn(html, mapping_data, True) + if not dh or de: + continue + url = row.get("url", "") + if not sh or se: + f1s.append(0.0) + else: + sc, _ = cfg.content_fn(sh, url) + dc, _ = cfg.content_fn(dh, url) + f1s.append(_token_f1(sc, dc)) + ok = bool(f1s) and (sum(f1s) / len(f1s) >= cfg.threshold) + cfg.memo[key] = ok + return ok + + +def _lbp_once( + html: str, + url: str, + mapping_data: dict[str, Any], + dynamic: bool, + prop_cfg: _PropagationConfig, +) -> tuple[str, str, str]: + lh, le = prop_cfg.lbp_fn(html, mapping_data, dynamic) + if not lh or le: + return "", "", le + rc, ce = prop_cfg.content_fn(lh, url) + if ce: + return "", "", ce + rep_len = (mapping_data or {}).get("_dripper_representative_content_len") + if rep_len and rep_len > 0: + ratio = len(rc) / rep_len + if ratio < prop_cfg.min_ratio: + return "", "", f"content_length_ratio_low={ratio:.3f}" + if ratio > prop_cfg.max_ratio: + return "", "", f"content_length_ratio_high={ratio:.3f}" + return lh, rc, "" + + +def _sibling_propagate( + row: dict[str, Any], + mapping_data: dict[str, Any] | None, + use_static: bool, + prop_cfg: _PropagationConfig, +) -> tuple[str, str, str, str]: + url = row.get("url", "") + html = _coerce_html(row.get("html", "")) + method, main_html, content, error = "fallback", "", "", "" + + if mapping_data is not None: + if use_static: + main_html, content, error = _lbp_once(html, url, mapping_data, False, prop_cfg) + if main_html: + method = "lbp_static" + if not main_html: + dh, dc, de = _lbp_once(html, url, mapping_data, True, prop_cfg) + if dh: + main_html, method, content, error = dh, "layout_batch_parser", dc, "" + elif de: + error = f"static_failed({error}); dynamic_failed({de})" if error else de + + if not main_html: + method = "fallback" + error = error or "no_template_available" + + return main_html, content, error, method + + +@dataclass(kw_only=True) +class DripperHTMLLayoutPropagationStage(ProcessingStage[DocumentBatch, DocumentBatch]): + html_col: str = "html" + output_html_col: str = "dripper_html" + output_content_col: str = "dripper_content" + postprocess_time_col: str = "dripper_postprocess_time_s" + error_col: str = "dripper_error" + url_col: str = "url" + + dynamic_classid_similarity_threshold: float = 0.85 + more_noise_enable: bool = True + layout_template_validation_min_content_f1: float = 0.95 + layout_template_min_content_length_ratio: float | None = 0.25 + layout_template_max_content_length_ratio: float | None = 4.0 + propagation_target: str = "raw_html" + use_static_lbp: bool = True + static_validation_min_f1: float = 0.97 + + _bindings: Any = field(init=False, repr=False, default=None) + _web_bindings: Any = field(init=False, repr=False, default=None) + _cluster_static_ok: dict = field(init=False, repr=False, default_factory=dict) + + def outputs(self) -> tuple[list[str], list[str]]: + return ["data"], [ + self.output_html_col, + self.output_content_col, + self.postprocess_time_col, + self.error_col, + "dripper_layout_propagated", + "dripper_layout_propagation_success", + "dripper_layout_propagation_method", + _PENDING_COL, + ] + + def setup(self, worker_metadata: Any = None) -> None: # noqa: ANN401, ARG002 + if self._bindings is not None: + return + self._bindings = _load_mineru_html_bindings() + self._web_bindings = _load_llm_web_kit_bindings() + self._cluster_static_ok = {} + + def _make_lbp_fn(self, parser_cache: dict | None = None) -> Any: # noqa: ANN401 # returns Callable[[str, dict, bool], tuple[str, str]] + params = { + "more_noise_enable": self.more_noise_enable, + "dynamic_classid_similarity_threshold": self.dynamic_classid_similarity_threshold, + } + + def _lbp(html: str, mapping_data: dict, dynamic: bool = True) -> tuple[str, str]: + return _run_lbp(params, html, mapping_data, dynamic, _parser_cache=parser_cache) + + return _lbp + + def _make_content_fn(self) -> Any: # noqa: ANN401 # returns Callable[[str, str], tuple[str, str]] + bindings = self._bindings + + def _content(main_html: str, url: str) -> tuple[str, str]: + return _run_content_convert(bindings, main_html, url) + + return _content + + def _make_prop_cfg(self, parser_cache: dict | None = None) -> _PropagationConfig: + return _PropagationConfig( + lbp_fn=self._make_lbp_fn(parser_cache), + content_fn=self._make_content_fn(), + min_ratio=self.layout_template_min_content_length_ratio or 0.0, + max_ratio=self.layout_template_max_content_length_ratio or float("inf"), + ) + + def _make_trust_cfg(self, parser_cache: dict | None = None) -> _StaticTrustConfig: + return _StaticTrustConfig( + memo=self._cluster_static_ok, + lbp_fn=self._make_lbp_fn(parser_cache), + content_fn=self._make_content_fn(), + threshold=self.static_validation_min_f1, + ) + + def process(self, batch: DocumentBatch) -> DocumentBatch: # noqa: C901, PLR0912, PLR0915 + if self._bindings is None: + self.setup() + + df = batch.to_pandas() + + if _PENDING_COL not in df.columns: + return batch + + pending_mask = df[_PENDING_COL].astype(bool) + if not pending_mask.any(): + return batch + + mapping_by_cluster: dict[str, dict[str, Any]] = {} + if _MAPPING_COL in df.columns and _REPRESENTATIVE_COL in df.columns: + rep_rows = df[df[_REPRESENTATIVE_COL].astype(bool)] + for _, row in rep_rows.iterrows(): + mapping_json = str(row.get(_MAPPING_COL) or "") + cluster = str(row.get(_CLUSTER_COL) or "") + if mapping_json and cluster: + with contextlib.suppress(Exception): + mapping_by_cluster[cluster] = json.loads(mapping_json) + + cluster_pending: dict[str, list] = {} + for idx in df.index[pending_mask]: + cid = str(df.loc[idx, _CLUSTER_COL] if _CLUSTER_COL in df.columns else "") + cluster_pending.setdefault(cid, []).append(idx) + + for cid, idxs in cluster_pending.items(): + mapping_data = mapping_by_cluster.get(cid) + parser_cache: dict = {} + prop_cfg = self._make_prop_cfg(parser_cache) + + # memoised: validate static-LBP trustworthiness once per cluster + use_static = False + if self.use_static_lbp and mapping_data is not None: + sample_rows = [df.loc[i].to_dict() for i in idxs[:_K_SAMPLE_SIBLINGS]] + trust_cfg = self._make_trust_cfg(parser_cache) + use_static = _cluster_static_trustworthy(cid, sample_rows, mapping_data, trust_cfg) + + for idx in idxs: + row = df.loc[idx] + t0 = time.perf_counter() + propagated_html = "" + propagated_content = "" + error = "" + success = False + method = "fallback" + + if mapping_data is None: + error = f"no_mapping_data_for_cluster={cid}" + else: + try: + row_dict = row.to_dict() + propagated_html, propagated_content, error, method = _sibling_propagate( + row_dict, mapping_data, use_static, prop_cfg + ) + if propagated_html and not error: + success = True + except Exception as exc: # noqa: BLE001 + error = f"propagation_exception={exc!s:.200}" + + elapsed = time.perf_counter() - t0 + df.loc[idx, self.output_html_col] = propagated_html + df.loc[idx, self.output_content_col] = propagated_content + df.loc[idx, self.postprocess_time_col] = elapsed + df.loc[idx, self.error_col] = error + df.loc[idx, "dripper_layout_propagated"] = True + df.loc[idx, "dripper_layout_propagation_success"] = success + df.loc[idx, "dripper_layout_propagation_method"] = method + df.loc[idx, _PENDING_COL] = False # consumed + + n_pending = int(pending_mask.sum()) + n_success = ( + int(df["dripper_layout_propagation_success"].sum()) + if "dripper_layout_propagation_success" in df.columns + else 0 + ) + logger.info( + "DripperHTMLLayoutPropagationStage: propagated {}/{} rows in batch", + n_success, + n_pending, + ) + return _rebuild_batch(batch, df) + + def _run_propagation( + self, + row: pd.Series, + mapping_data: dict[str, Any], + ) -> tuple[str, str, str]: + if self._bindings is None: + self.setup() + row_dict = row.to_dict() if hasattr(row, "to_dict") else dict(row) + prop_cfg = self._make_prop_cfg() + main_html, content, error, _ = _sibling_propagate(row_dict, mapping_data, False, prop_cfg) + return main_html, content, error diff --git a/nemo_curator/stages/text/experimental/dripper/stage.py b/nemo_curator/stages/text/experimental/dripper/stage.py new file mode 100644 index 0000000000..331ceb761a --- /dev/null +++ b/nemo_curator/stages/text/experimental/dripper/stage.py @@ -0,0 +1,424 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import re +from dataclasses import dataclass, replace +from typing import TYPE_CHECKING, Any + +import pandas as pd +from loguru import logger + +from nemo_curator.models.client.llm_client import GenerationConfig +from nemo_curator.tasks import DocumentBatch + +if TYPE_CHECKING: + from collections.abc import Callable + + from nemo_curator.models.client.llm_client import AsyncLLMClient + + +@dataclass(frozen=True) +class _MinerUHTMLBindings: + input_cls: type + case_cls: type + output_cls: type + process_data_cls: type + generate_output_cls: type + simplify_single_input: Callable[[Any], Any] + build_prompt: Callable[..., Any] + parse_result: Callable[[Any], Any] + extract_main_html_single: Callable[[Any], Any] + extract_main_html_fallback: Callable[..., Any] + convert2content: Callable[..., Any] + get_fallback_handler: Callable[[str], Any] + + +def _always_similar(_left: object, _right: object, _max_layer_n: int) -> float: + return 1.0 + + +@dataclass(frozen=True) +class _LLMWebKitBindings: + get_feature: Callable[[str], Any] + cluster_html_struct: Callable[..., Any] + select_representative_html: Callable[[list[dict[str, str]]], dict[str, str] | None] + map_parser_cls: type + layout_parser_cls: type + similarity: Callable[..., float] = _always_similar + + +@dataclass(frozen=True) +class _DripperRowResult: + main_html: str = "" + main_content: Any = "" + raw_response: str = "" + preprocess_time_s: float = 0.0 + inference_time_s: float = 0.0 + postprocess_time_s: float = 0.0 + total_time_s: float = 0.0 + error: str = "" + warning: str = "" + simplified_html: str = "" + mapped_html: str = "" + item_count: int = 0 + prompt_chars: int = 0 + request_max_tokens: int = 0 + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + + +@dataclass(frozen=True) +class _DripperInferenceResult: + raw_response: str = "" + inference_time_s: float = 0.0 + primary_error: str = "" + warning: str = "" + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + + +@dataclass(frozen=True) +class _DripperPostResult: + main_html: str = "" + main_content: Any = "" + postprocess_time_s: float = 0.0 + error: str = "" + warning: str = "" + + +@dataclass(frozen=True) +class _DripperPrepResult: + empty_input: bool = False + needs_llm: bool = False + preprocess_time_s: float = 0.0 + warning: str = "" + primary_error: str = "" + simplified_html: str = "" + mapped_html: str = "" + item_count: int = 0 + prompt: str = "" + prompt_chars: int = 0 + request_max_tokens: int = 0 + + +_DRIPPER_PROMPT_COL = "_dripper_prompt" +_DRIPPER_NEEDS_LLM_COL = "_dripper_needs_llm" +_DRIPPER_PRIMARY_ERROR_COL = "_dripper_primary_error" +_DRIPPER_EMPTY_INPUT_COL = "_dripper_empty_input" +_DRIPPER_LAYOUT_FINALIZED_COL = "_dripper_layout_finalized" + + +def _load_mineru_html_bindings() -> _MinerUHTMLBindings: + from mineru_html.base import ( + MinerUHTMLCase, + MinerUHTMLGenerateOutput, + MinerUHTMLInput, + MinerUHTMLOutput, + MinerUHTMLProcessData, + ) + from mineru_html.process import ( + build_prompt, + convert2content, + extract_main_html_fallback, + extract_main_html_single, + get_fallback_handler, + parse_result, + simplify_single_input, + ) + + return _MinerUHTMLBindings( + input_cls=MinerUHTMLInput, + case_cls=MinerUHTMLCase, + output_cls=MinerUHTMLOutput, + process_data_cls=MinerUHTMLProcessData, + generate_output_cls=MinerUHTMLGenerateOutput, + simplify_single_input=simplify_single_input, + build_prompt=build_prompt, + parse_result=parse_result, + extract_main_html_single=extract_main_html_single, + extract_main_html_fallback=extract_main_html_fallback, + convert2content=convert2content, + get_fallback_handler=get_fallback_handler, + ) + + +def _load_llm_web_kit_bindings() -> _LLMWebKitBindings: + from llm_web_kit.html_layout.html_layout_cosin import get_feature, similarity + from llm_web_kit.main_html_parser.parser.layout_batch_parser import LayoutBatchParser + from llm_web_kit.main_html_parser.parser.tag_mapping import MapItemToHtmlTagsParser + from llm_web_kit.main_html_parser.typical_html.typical_html import select_representative_html + + # Use GPU-accelerated DBSCAN when available (cuML + cupy), falls back to sklearn + from nemo_curator.stages.text.experimental.dripper.gpu_layout_clustering import ( + cluster_html_struct_gpu, + ) + + return _LLMWebKitBindings( + get_feature=get_feature, + cluster_html_struct=cluster_html_struct_gpu, + select_representative_html=select_representative_html, + map_parser_cls=MapItemToHtmlTagsParser, + layout_parser_cls=LayoutBatchParser, + similarity=similarity, + ) + + +async def _run_dripper_health_check( + client: AsyncLLMClient, + model_name: str, + generation_config: GenerationConfig | None, +) -> None: + extra_kwargs = generation_config.extra_kwargs if generation_config is not None else None + hc_config = GenerationConfig(max_tokens=8, temperature=0.0, top_p=1.0, extra_kwargs=extra_kwargs) + try: + response = await client.query_model( + model=model_name, + generation_config=hc_config, + messages=[{"role": "user", "content": 'Return exactly: "1main"'}], + ) + except RuntimeError: + raise + except Exception as exc: + msg = f"Dripper LLM health check failed: {exc}. Ensure the inference server is reachable." + raise RuntimeError(msg) from exc + result = response[0] if response else "" + if not result: + msg = "Dripper LLM health check returned an empty response" + raise RuntimeError(msg) + logger.info("Dripper LLM health check passed") + + +async def _query_dripper_model( + client: AsyncLLMClient, + model_name: str, + messages: list[dict[str, str]], + generation_config: GenerationConfig, +) -> tuple[str, int, int, int]: + query_model_with_usage = getattr(client, "query_model_with_usage", None) + if callable(query_model_with_usage): + response = await query_model_with_usage( + model=model_name, + messages=messages, + generation_config=generation_config, + ) + contents = getattr(response, "contents", []) + return ( + contents[0] if contents else "", + _coerce_usage_int(getattr(response, "prompt_tokens", None)), + _coerce_usage_int(getattr(response, "completion_tokens", None)), + _coerce_usage_int(getattr(response, "total_tokens", None)), + ) + + response = await client.query_model( + model=model_name, + messages=messages, + generation_config=generation_config, + ) + return response[0] if response else "", 0, 0, 0 + + +def _rebuild_batch(batch: DocumentBatch, df: pd.DataFrame) -> DocumentBatch: + new_batch = DocumentBatch( + dataset_name=batch.dataset_name, + data=df, + _metadata=batch._metadata, + _stage_perf=batch._stage_perf, + ) + new_batch.task_id = batch.task_id + return new_batch + + +def _sanitize_case_output_html(case: object) -> None: + output_data = getattr(case, "output_data", None) + if output_data is None: + return + main_html = getattr(output_data, "main_html", None) + if isinstance(main_html, str): + output_data.main_html = _strip_xml_incompatible_chars(main_html) + + +def _get_processed_attr(case: object, attr: str) -> str: + process_data = getattr(case, "process_data", None) + value = getattr(process_data, attr, "") if process_data is not None else "" + return value if isinstance(value, str) else "" + + +def _case_has_item_ids(case: object) -> bool: + return "_item_id" in _get_processed_attr(case, "simpled_html") or "_item_id" in _get_processed_attr( + case, "map_html" + ) + + +def _count_item_ids(case: object) -> int: + html = _get_processed_attr(case, "simpled_html") or _get_processed_attr(case, "map_html") + return len(set(_ITEM_ID_RE.findall(html))) + + +def _coerce_html(value: object) -> str: + if _is_missing(value): + return "" + if isinstance(value, bytes | bytearray): + raw_bytes = bytes(value) + decoded: str | None = None + try: + decoded = raw_bytes.decode("utf-8") + except UnicodeDecodeError: + try: + from charset_normalizer import detect as _detect + + enc = _detect(raw_bytes)["encoding"] + if enc and enc != "utf-8": + decoded = raw_bytes.decode(enc) + except Exception: # noqa: BLE001 + decoded = None + if decoded is None: + decoded = raw_bytes.decode("utf-8", errors="replace") + return _strip_xml_incompatible_chars(decoded or "") + return _strip_xml_incompatible_chars(str(value)) + + +def _coerce_optional_str(value: object) -> str | None: + if _is_missing(value): + return None + text = str(value) + return text if text else None + + +def _is_empty_document_error(error: str) -> bool: + normalized = error.lower() + return "document is empty" in normalized or "empty html tree" in normalized or "empty html input" in normalized + + +def _generation_config_for_item_count(stage: Any, item_count: int) -> GenerationConfig: # noqa: ANN401 + base = stage.generation_config or GenerationConfig() + if not stage.dynamic_max_tokens or base.max_tokens is None or item_count <= 0: + return base + dynamic_max_tokens = max( + stage.dynamic_min_max_tokens, + item_count * stage.dynamic_max_tokens_per_item + stage.dynamic_max_token_padding, + ) + return replace(base, max_tokens=min(base.max_tokens, dynamic_max_tokens)) + + +def _apply_fallback_extraction( + bindings: object, fallback_handler: object, case: object, primary_error: str +) -> tuple[object, str, str]: + try: + case = bindings.extract_main_html_fallback(case, fallback_handler=fallback_handler) + except Exception as fallback_exc: # noqa: BLE001 + if primary_error: + return case, primary_error, f"{primary_error}; fallback failed: {fallback_exc}" + return case, "", f"fallback failed: {fallback_exc}" + else: + return case, primary_error, "" + + +def _numeric_series_or_zero(df: pd.DataFrame, column: str) -> pd.Series: + if column not in df.columns: + return pd.Series([0.0] * len(df), index=df.index) + return pd.to_numeric(df[column], errors="coerce").fillna(0.0) + + +def _append_warning(existing: str, new_warning: str) -> str: + if not existing: + return new_warning + if not new_warning: + return existing + return f"{existing}; {new_warning}" + + +def _convert_main_html(bindings: _MinerUHTMLBindings, main_html: str, url: object) -> str: + case = bindings.case_cls(bindings.input_cls(raw_html="", url=_coerce_optional_str(url))) + case.output_data = bindings.output_cls(main_html=main_html) + _sanitize_case_output_html(case) + case = bindings.convert2content(case, output_format="mm_md") + output_data = getattr(case, "output_data", None) + return str(getattr(output_data, "main_content", "") or "") if output_data else "" + + +def _is_missing(value: object) -> bool: + if value is None: + return True + try: + missing = pd.isna(value) + except (TypeError, ValueError): + return False + return bool(missing) if isinstance(missing, bool) else False + + +_XML_CHAR_SINGLE = {0x09, 0x0A, 0x0D} +_XML_CHAR_RANGES = ((0x20, 0xD7FF), (0xE000, 0xFFFD), (0x10000, 0x10FFFF)) + + +def _strip_xml_incompatible_chars(value: str) -> str: + return "".join( + c for c in value if (cp := ord(c)) in _XML_CHAR_SINGLE or any(lo <= cp <= hi for lo, hi in _XML_CHAR_RANGES) + ) + + +def _coerce_usage_int(value: object) -> int: + if isinstance(value, bool): + return 0 + if isinstance(value, int): + return value + if isinstance(value, float) and value.is_integer(): + return int(value) + if isinstance(value, str) and value.isdigit(): + return int(value) + return 0 + + +def _with_structured_output_config( + generation_config: GenerationConfig, + prompt: str, + mode: str, +) -> GenerationConfig: + if mode == "none": + return generation_config + item_ids = _item_ids_in_html(prompt) + if not item_ids or not all(item_id.isdigit() for item_id in item_ids): + return generation_config + + item_pattern = "".join(f"{re.escape(i)}(main|other)" for i in item_ids) + regex = f"\\s*{item_pattern}\\s*" + extra_kwargs = dict(generation_config.extra_kwargs or {}) + raw_extra_body = extra_kwargs.get("extra_body") + if raw_extra_body is not None and not isinstance(raw_extra_body, dict): + logger.warning("Skipping Dripper structured output because extra_body is not a dict") + return generation_config + extra_body: dict[str, Any] = dict(raw_extra_body) if isinstance(raw_extra_body, dict) else {} + + if mode == "structured_outputs": + extra_body["structured_outputs"] = {"regex": regex} + elif mode == "guided_regex": + extra_body["guided_regex"] = regex + else: + return generation_config + extra_kwargs["extra_body"] = extra_body + return replace(generation_config, extra_kwargs=extra_kwargs) + + +def _item_ids_in_html(html: str) -> list[str]: + # dict.fromkeys preserves insertion order and deduplicates + return list(dict.fromkeys(_ITEM_ID_RE.findall(html))) + + +_ITEM_ID_RE = re.compile(r"""_item_id\s*=\s*["']?([^"'\s>]+)""") + +_STRUCTURED_OUTPUT_MODES = {"none", "structured_outputs", "guided_regex"} diff --git a/nemo_curator/stages/text/experimental/dripper/workflow.py b/nemo_curator/stages/text/experimental/dripper/workflow.py new file mode 100644 index 0000000000..49e13af4e3 --- /dev/null +++ b/nemo_curator/stages/text/experimental/dripper/workflow.py @@ -0,0 +1,141 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DripperHTMLWorkflow — end-to-end HTML content extraction pipeline.""" + +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from loguru import logger + +from nemo_curator.pipeline import Pipeline +from nemo_curator.pipeline.workflow import WorkflowRunResult +from nemo_curator.stages.text.experimental.dripper._base_stages import ( + DripperHTMLInferenceStage, + DripperHTMLPostprocessStage, + DripperHTMLPreprocessStage, +) +from nemo_curator.stages.text.experimental.dripper.layout_template import DripperHTMLLayoutTemplateStage + +if TYPE_CHECKING: + from nemo_curator.backends.base import BaseExecutor + from nemo_curator.models.client.llm_client import AsyncLLMClient + from nemo_curator.stages.base import ProcessingStage + from nemo_curator.tasks import Task + + +@dataclass(kw_only=True) +class DripperHTMLWorkflow: + """End-to-end HTML content extraction pipeline (layout clustering + LLM inference).""" + + client: AsyncLLMClient | None + model_name: str + html_col: str = "html" + url_col: str | None = "url" + output_col: str = "dripper_content" + perform_layout_clustering: bool = True + layout_cluster_threshold: float = 0.95 + fallback: str = "trafilatura" + output_format: str = "mm_md" + max_concurrent_requests: int = 64 + health_check: bool = True + verbose: bool = True + + def __post_init__(self) -> None: + if self.client is None: + msg = "DripperHTMLWorkflow requires a non-None 'client' (AsyncLLMClient)" + raise ValueError(msg) + self.model_name = self.model_name.strip() + if not self.model_name: + msg = "DripperHTMLWorkflow requires a non-empty 'model_name'" + raise ValueError(msg) + if not (0.0 < self.layout_cluster_threshold <= 1.0): + msg = "layout_cluster_threshold must be in (0, 1]" + raise ValueError(msg) + if self.max_concurrent_requests <= 0: + msg = "max_concurrent_requests must be positive" + raise ValueError(msg) + + def run(self, executor: BaseExecutor, initial_tasks: list[Task] | None = None) -> WorkflowRunResult: + start = time.time() + + if self.verbose: + logger.info( + "DripperHTMLWorkflow starting — model={}, layout_clustering={}", + self.model_name, + self.perform_layout_clustering, + ) + + stages = self._build_stages() + pipeline = Pipeline(name="dripper_html_extraction") + for stage in stages: + pipeline.add_stage(stage) + + output_tasks = pipeline.run(executor=executor, initial_tasks=initial_tasks) + + elapsed = time.time() - start + + if self.verbose: + logger.info( + "DripperHTMLWorkflow complete in {:.1f}s", + elapsed, + ) + + result = WorkflowRunResult(workflow_name="dripper_html_extraction") + result.add_metadata("elapsed_s", elapsed) + result.add_metadata("stages", [s.name for s in stages]) + result.add_pipeline_tasks("dripper_html_extraction", output_tasks) + return result + + def _build_stages(self) -> list[ProcessingStage]: + preprocess = DripperHTMLPreprocessStage(html_col=self.html_col, url_col=self.url_col) + + if self.perform_layout_clustering: + # Preprocess → LayoutTemplate handles clustering + representative LLM + sibling propagation + # (DripperHTMLLayoutTemplateStage also handles singletons/standalone pages internally) + return [ + preprocess, + DripperHTMLLayoutTemplateStage( + client=self.client, + model_name=self.model_name, + html_col=self.html_col, + url_col=self.url_col, + layout_cluster_threshold=self.layout_cluster_threshold, + layout_template_fallback_llm=True, + fallback=self.fallback, + output_format=self.output_format, + max_concurrent_requests=self.max_concurrent_requests, + health_check=self.health_check, + ), + ] + + # Standalone extraction path: Preprocess → Inference → Postprocess + return [ + preprocess, + DripperHTMLInferenceStage( + client=self.client, + model_name=self.model_name, + max_concurrent_requests=self.max_concurrent_requests, + ), + DripperHTMLPostprocessStage( + html_col=self.html_col, + url_col=self.url_col, + fallback=self.fallback, + output_format=self.output_format, + output_content_col=self.output_col, + ), + ] diff --git a/pyproject.toml b/pyproject.toml index bd10a5337b..e899c50f56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -194,6 +194,7 @@ text_cpu = [ "s5cmd", "trafilatura==2.0.0", "warcio", + "xxhash", # Filters "fasttext==0.9.3", "sentencepiece", @@ -276,6 +277,7 @@ sdg_cuda12 = [ "nemo_curator[inference_server]", ] + # All dependencies all = [ "nemo_curator[audio_cuda12]", @@ -390,6 +392,7 @@ source = ["nemo_curator", "/opt/Curator/nemo_curator", "/home/runner/work/Curato [tool.ruff] line-length = 119 +extend-exclude = ["**/*.ipynb"] # notebooks checked separately [tool.ruff.lint] select = ["ALL"] ignore = [ @@ -425,11 +428,26 @@ fixable = ["ALL"] "INP001", # no __init__.py is required ] "tests/**/*.py" = [ - "S101", # asserts allowed in tests - "ANN201", # allow methods to not return something - "ARG002", # allow unused method args (mock.patch decorator injects args not always referenced) + "S101", # asserts allowed in tests + "ANN201", # allow methods to not return something + "ARG002", # allow unused method args (mock.patch decorator injects args not always referenced) "PLR2004", # magic value used in comparison - "ERA001", # allow commented-out code + "ERA001", # allow commented-out code +] +# Broader ignores for the dripper experimental test files, which use complex mock +# objects, intentional error message literals, and un-annotated helper functions. +"tests/stages/text/experimental/dripper/**" = [ + "ANN", # type annotations not required in test helpers + "BLE001", # broad exception catch fine in test helpers + "C901", # complex test-fixture functions are necessary for full mock coverage + "EM101", # exception string literals fine in test helpers + "EM102", # exception f-string literals fine in test helpers + "PLR0913", # too-many-args fine in test helper factories + "ARG001", # unused function args fine in mock callbacks (fallback_handler, etc.) + "PD101", # series.nunique() is fine for correctness assertions in tests + "PLW0603", # global statements for test module-level state + "INP001", # no __init__.py for sub-scripts loaded via importlib + "TCH", # no TYPE_CHECKING blocks needed in test helpers ] "benchmarking/**" = [ "BLE001", # allow catching blind exceptions (benchmark runners need catch-all error handling) @@ -438,9 +456,72 @@ fixable = ["ALL"] "BLE001", # allow catching blind exceptions (Sphinx extensions need robust error handling) ] "tutorials/**" = [ - "INP001", # no __init__.py is required + "INP001", # no __init__.py is required "PLE2515", # ignore \u200b complaint ] +# pipeline logic, and intentional script patterns not suitable for library code. +"tutorials/text/dripper-common-crawl/**" = [ + "ANN", # type annotations not required in tutorial scripts + "BLE001", # allow catching blind exceptions in scripts + "S101", # allow asserts in scripts + "S603", # subprocess calls with shell=False are fine in tutorials + "S607", # partial executable paths fine in tutorials + "TRY", # try/except style is tutorial-appropriate + "PERF", # micro-perf rules too strict for tutorials + "ERA001", # allow commented-out code in tutorials + "PLR2004", # magic values fine in scripts + "TCH", # no need to move typing imports to TYPE_CHECKING blocks + "C901", # complexity checks too strict for scripts + "PLR0912", # too-many-branches fine in scripts + "PLR0913", # too-many-args fine in scripts + "PLR0915", # too-many-statements fine in scripts + "EM", # error messages don't need separate variable in scripts + "ANN401", # Any type fine in tutorial scripts + "SIM", # simplification suggestions too strict for tutorial scripts + "RUF001", # unicode chars fine in comments/strings in tutorials + "RUF002", # unicode chars fine in docstrings in tutorials + "RUF003", # unicode chars fine in comments + "N806", # UPPER_CASE constants inside functions are conventional in scripts + "PLW0602", # global without assignment fine in module-level state pattern + "PLW0603", # global statement for module-level worker caches is intentional pattern + "PLW1508", # int defaults for os.environ.get are cast immediately; fine in scripts + "S301", # pickle use is intentional (lossless template serialization) + "S302", # marshal use not present but suppress + "PT018", # composite assert fine in tests helper + "B023", # loop variable capture fine in tutorial closures + "B007", # unused loop var fine + "E741", # ambiguous variable names fine in compact scripts + "F841", # unused assignments fine in scripts (often defensive) + "A004", # import shadowing builtin fine in tutorial notebooks + "B905", # zip without strict= fine in tutorial visualization code + "E402", # module-level import not at top fine in notebook cells + "PLW2901", # loop variable overwrite fine in tutorial scripts + "B904", # raise-without-from-cause fine in script error handlers + "PLR0911", # too-many-return-statements fine in scripts with guard clauses + "S110", # try/except/pass fine in optional-feature guards in scripts + "ICN001", # lazy internal imports may use non-canonical alias (e.g. _pa) + "EXE001", # shebang without executable bit is fine in repo scripts + "PD008", # .at vs .loc performance hint irrelevant in tutorial data-processing scripts + "C408", # dict() vs {} literal style — fine in tutorials + "S112", # try/except/continue with no logging fine in optional-feature guards + "E702", # semicolon-separated statements fine in compact tutorial scripts + "E701", # colon-separated one-liners fine in compact tutorial scripts + "PD002", # inplace=True fine in tutorial data-processing scripts + "RET504", # intermediate variable before return is a common readable pattern in scripts + "ARG001", # unused function argument fine in callback/hook signatures in scripts + "ARG002", # unused method argument fine in interface-conforming methods in scripts + "N803", # UpperCase argument names are conventional for class-like params in scripts + "N802", # function name casing fine in dunder/mangled methods in scripts + "S105", # PASS/FAIL/SKIP ANSI-color constants are not passwords + "RUF059", # unpacked-but-unused variable fine in scripts that need side effects + "C401", # generator vs set-comprehension style is fine in tutorial scripts + "PD011", # .values is conventional shorthand in tutorial notebooks/scripts +] +"tutorials/text/dripper-common-crawl/dashboard_server.py" = [ + "S108", # /tmp/nbx.sh is a deliberately temporary helper script + "S103", # os.chmod 0o755 is intentional for the helper script + "ASYNC221", # subprocess.run in async context is acceptable for SSH polling +] "fern/**/*.py" = [ "INP001", # Fern CLI helper scripts; not an installable package ] diff --git a/tests/stages/text/experimental/dripper/__init__.py b/tests/stages/text/experimental/dripper/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/tests/stages/text/experimental/dripper/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/stages/text/experimental/dripper/test_stage.py b/tests/stages/text/experimental/dripper/test_stage.py new file mode 100644 index 0000000000..5e7f8ba512 --- /dev/null +++ b/tests/stages/text/experimental/dripper/test_stage.py @@ -0,0 +1,489 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Behavioral unit tests for Dripper stages.""" + +from __future__ import annotations + +import re +from collections.abc import Iterable +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any + +import pandas as pd +import pytest + +from nemo_curator.models.client.llm_client import AsyncLLMClient, GenerationConfig +from nemo_curator.stages.text.experimental.dripper import ( + DripperHTMLExtractionStage, + DripperHTMLInferenceStage, + DripperHTMLLayoutTemplateStage, + DripperHTMLPreprocessStage, +) +from nemo_curator.stages.text.experimental.dripper import stage as stage_mod +from nemo_curator.tasks import DocumentBatch + +# --------------------------------------------------------------------------- +# Fake types / helpers +# --------------------------------------------------------------------------- + + +@dataclass +class FakeInput: + raw_html: str + url: str | None = None + + +@dataclass +class FakeOutput: + main_html: str + main_content: str | None = None + + +@dataclass +class FakeCase: + input_data: FakeInput + case_id: str = "fake-case" + process_data: object = None + generate_input: object = None + generate_output: object = None + parse_result: object = None + output_data: object = None + + +class RecordingAsyncClient(AsyncLLMClient): + def __init__(self, responses: list[str]) -> None: + super().__init__(max_concurrent_requests=8, max_retries=0, base_delay=0.0) + self.responses = responses + self.calls: list[dict[str, Any]] = [] + self.setup_calls = 0 + + def setup(self) -> None: + self.setup_calls += 1 + + async def _query_model_impl( + self, + *, + messages: Iterable, + model: str, + conversation_formatter: object = None, + generation_config: GenerationConfig | dict | None = None, + ) -> list[str]: + self.calls.append({"messages": list(messages), "model": model, "generation_config": generation_config}) + return [self.responses.pop(0)] + + +def _make_mineru_bindings(label_aware: bool = False) -> stage_mod._MinerUHTMLBindings: + def simplify_single_input(case: FakeCase) -> FakeCase: + if "preprocess-fails" in case.input_data.raw_html: + raise RuntimeError("preprocess failed") + body = ( + "
No item ids
" + if "no-items" in case.input_data.raw_html + else f'
{case.input_data.raw_html}
' + ) + case.process_data = SimpleNamespace( + simpled_html=body, map_html=f"{case.input_data.raw_html}" + ) + return case + + def parse_result(case: FakeCase) -> FakeCase: + if case.generate_output.response == "bad-response": + raise RuntimeError("parse failed") + if label_aware: + case.parse_result = SimpleNamespace( + item_label=dict(re.findall(r"(\d+)(main|other)", case.generate_output.response)) + ) + else: + case.parse_result = SimpleNamespace(item_label={"1": "main"}) + return case + + def extract_main_html_single(case: FakeCase) -> FakeCase: + if label_aware: + labels = getattr(case.parse_result, "item_label", {}) + main_ids = [iid for iid, lbl in labels.items() if lbl == "main"] + case.output_data = FakeOutput(main_html="|".join(f"main:{iid}" for iid in main_ids)) + else: + main_html = ( + "" if "empty-main" in case.input_data.raw_html else f"
{case.input_data.raw_html}
" + ) + case.output_data = FakeOutput(main_html=main_html) + return case + + def extract_main_html_fallback(case: FakeCase, fallback_handler: object) -> FakeCase: + main_html = ( + "" if "empty-main" in case.input_data.raw_html else f"{case.input_data.raw_html}" + ) + case.output_data = FakeOutput(main_html=main_html) + return case + + def convert2content(case: FakeCase, output_format: str) -> FakeCase: + if not case.output_data.main_html: + raise RuntimeError("ExtractorChain base exception#Error during extraction: Document is empty") + case.output_data.main_content = f"{output_format}:{case.output_data.main_html}" + return case + + return stage_mod._MinerUHTMLBindings( + input_cls=FakeInput, + case_cls=FakeCase, + output_cls=FakeOutput, + process_data_cls=SimpleNamespace, + generate_output_cls=lambda response: SimpleNamespace(response=response), + simplify_single_input=simplify_single_input, + build_prompt=lambda case, v: setattr( + case, "generate_input", SimpleNamespace(full_prompt=f"{v}:{case.process_data.simpled_html}") + ) + or case, + parse_result=parse_result, + extract_main_html_single=extract_main_html_single, + extract_main_html_fallback=extract_main_html_fallback, + convert2content=convert2content, + get_fallback_handler=lambda fb: SimpleNamespace(name=fb), + ) + + +def _make_llm_web_kit_bindings( + *, map_parser_cls=None, layout_parser_cls=None, get_feature=None, cluster_html_struct=None +) -> stage_mod._LLMWebKitBindings: + class _DefaultMapParser: + def __init__(self, template_data: dict) -> None: + pass + + def parse(self, typical_data: dict) -> dict: + return { + "html_element_dict": {"labels": typical_data["llm_response"]}, + "typical_dict_html": typical_data["typical_raw_tag_html"], + "typical_main_html": "
template
", + "similarity_layer": 3, + "typical_main_html_success": True, + } + + class _DefaultLayoutParser: + def __init__(self, template_data: dict) -> None: + pass + + def parse(self, task_data: dict) -> dict: + return { + "main_html_body": f"{task_data['html_source']}", + "main_html_success": True, + } + + def _default_cluster( + samples: list[dict[str, Any]], threshold: float = 0.95 + ) -> tuple[list[dict[str, Any]], list[int]]: + for s in samples: + s["layout_id"] = 0 + return samples, [0] + + return stage_mod._LLMWebKitBindings( + get_feature=get_feature or (lambda html: {"tags": {1: ["body"], 2: [html]}}), + cluster_html_struct=cluster_html_struct or _default_cluster, + select_representative_html=lambda candidates: candidates[0] if candidates else None, + map_parser_cls=map_parser_cls or _DefaultMapParser, + layout_parser_cls=layout_parser_cls or _DefaultLayoutParser, + ) + + +def _batch(data: dict) -> DocumentBatch: + return DocumentBatch(task_id="t", dataset_name="d", data=pd.DataFrame(data)) + + +@pytest.fixture(autouse=True) +def patch_mineru_bindings(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(stage_mod, "_load_mineru_html_bindings", _make_mineru_bindings) + + +# --------------------------------------------------------------------------- +# DripperHTMLExtractionStage +# --------------------------------------------------------------------------- + + +def test_extraction_stage_runs_pipeline_with_async_client() -> None: + client = RecordingAsyncClient(["1main"]) + stage = DripperHTMLExtractionStage( + client=client, + model_name="dripper", + html_col="html", + health_check=False, + keep_intermediate=True, + generation_config=GenerationConfig(max_tokens=2048), + ) + out = stage.process(_batch({"url": ["https://example.test/a"], "html": ["Hello"]})).to_pandas() + + assert client.setup_calls == 1 + assert out["dripper_response"].tolist() == ["1main"] + assert out["dripper_html"].tolist() == ["
Hello
"] + assert out["dripper_simplified_html"].str.contains("_item_id").all() + assert client.calls[0]["model"] == "dripper" + + +def test_extraction_stage_error_paths_use_fallback_and_warnings() -> None: + def _run(html: str, responses: list[str]) -> pd.Series: + client = RecordingAsyncClient(responses) + stage = DripperHTMLExtractionStage(client=client, model_name="dripper", html_col="html", health_check=False) + return stage.process(_batch({"html": [html]})).to_pandas().iloc[0] + + row = _run("Fallback", ["bad-response"]) + assert row["dripper_html"] == "Fallback" + assert "parse failed" in row["dripper_warning"] + + row2 = _run("no-items", []) + assert "no _item_id attributes" in row2["dripper_warning"] + + row3 = _run("", []) + assert row3["dripper_warning"] == "empty HTML input" + + row4 = _run("empty-main", ["1main"]) + assert "Document is empty" in row4["dripper_warning"] + assert row4["dripper_content"] == "" + + +def test_extraction_stage_decodes_bytes(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(stage_mod, "_decode_html_bytes", lambda _: None) + client = RecordingAsyncClient(["1main"]) + stage = DripperHTMLExtractionStage(client=client, model_name="dripper", html_col="html", health_check=False) + out = stage.process(_batch({"html": [b"Bad\xffByte"]})).to_pandas() + assert out.loc[0, "dripper_error"] == "" + assert client.calls + + +def test_extraction_stage_missing_bindings_raises(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + stage_mod, "_load_mineru_html_bindings", lambda: (_ for _ in ()).throw(RuntimeError("missing mineru")) + ) + stage = DripperHTMLExtractionStage( + client=RecordingAsyncClient(["1main"]), model_name="dripper", html_col="html", health_check=False + ) + with pytest.raises(RuntimeError, match="missing mineru"): + stage.setup() + + +# --------------------------------------------------------------------------- +# DripperHTMLInferenceStage +# --------------------------------------------------------------------------- + + +def test_inference_stage_deduplicates_identical_prompts() -> None: + client = RecordingAsyncClient(["1main", "1other"]) + preprocess = DripperHTMLPreprocessStage(html_col="html", generation_config=GenerationConfig(max_tokens=2048)) + inference = DripperHTMLInferenceStage( + client=client, model_name="dripper", health_check=False, generation_config=GenerationConfig(max_tokens=2048) + ) + batch = _batch({"html": ["Same", "Same", "Different"]}) + out = inference.process(preprocess.process(batch)).to_pandas() + assert len(client.calls) == 2 + assert out["dripper_response"].tolist() == ["1main", "1main", "1other"] + assert out["dripper_inference_time_s"].iloc[1] == 0.0 + + +# --------------------------------------------------------------------------- +# DripperHTMLLayoutTemplateStage +# --------------------------------------------------------------------------- + + +def test_layout_stage_uses_precomputed_layout_id_column() -> None: + stage = DripperHTMLLayoutTemplateStage( + client=RecordingAsyncClient(["1main"]), + model_name="dripper", + health_check=False, + host_col="url_host_name", + layout_id_col="dripper_layout_id", + ) + stage._web_bindings = _make_llm_web_kit_bindings() + df = pd.DataFrame( + { + "url": [f"https://a.example/{i}" for i in range(5)] + ["https://b.example/1", "https://b.example/2"], + "url_host_name": ["a.example"] * 5 + ["b.example"] * 2, + "dripper_layout_id": [ + "a.example_0", + "a.example_0", + "a.example_1", + "a.example_1", + "-1", + "a.example_0", + "a.example_0", + ], + "html": ["

x

"] * 7, + stage_mod._DRIPPER_NEEDS_LLM_COL: [True] * 7, + } + ) + plans = stage._build_layout_group_plans(df) + assert [(p.host_key, p.source, p.indexes) for p in plans] == [ + ("a.example", "precomputed_layout:a.example_0", [0, 1]), + ("a.example", "precomputed_layout:a.example_1", [2, 3]), + ("b.example", "precomputed_layout:a.example_0", [5, 6]), + ] + + +def test_layout_stage_propagates_siblings(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(stage_mod, "_load_llm_web_kit_bindings", _make_llm_web_kit_bindings) + client = RecordingAsyncClient(["1main"]) + preprocess = DripperHTMLPreprocessStage( + html_col="html", url_col="url", generation_config=GenerationConfig(max_tokens=2048) + ) + layout = DripperHTMLLayoutTemplateStage( + client=client, + model_name="dripper", + health_check=False, + generation_config=GenerationConfig(max_tokens=2048), + layout_template_fallback_llm=True, + layout_template_require_success=True, + ) + + def _no_fallback(*_a, **_kw): + raise AssertionError("fallback should not run") + + monkeypatch.setattr(layout, "_fallback_row", _no_fallback) + batch = _batch( + { + "url": ["https://example.test/a", "https://example.test/b", "https://example.test/c"], + "html": ["Rep", "Sib1", "Sib2"], + } + ) + out = layout.process(preprocess.process(batch)).to_pandas() + assert len(client.calls) == 1 + assert out["dripper_layout_representative"].tolist() == [True, False, False] + assert out["dripper_layout_propagated"].tolist() == [False, True, True] + assert out["dripper_layout_propagation_success"].tolist() == [False, True, True] + + +def test_layout_stage_validation_falls_back_to_llm(monkeypatch: pytest.MonkeyPatch) -> None: + class _DivergingLayoutParser: + def __init__(self, template_data: dict) -> None: + pass + + def parse(self, task_data: dict) -> dict: + return {"main_html_body": '
propagated sibling
', "main_html_success": True} + + class _LabelMapParser: + def __init__(self, template_data: dict) -> None: + pass + + def parse(self, typical_data: dict) -> dict: + return { + "html_element_dict": {"labels": typical_data["llm_response"]}, + "typical_dict_html": typical_data["typical_raw_tag_html"], + "typical_main_html": '
template
', + "similarity_layer": 3, + "typical_main_html_success": True, + } + + monkeypatch.setattr(stage_mod, "_load_mineru_html_bindings", lambda: _make_mineru_bindings(label_aware=True)) + monkeypatch.setattr( + stage_mod, + "_load_llm_web_kit_bindings", + lambda: _make_llm_web_kit_bindings(map_parser_cls=_LabelMapParser, layout_parser_cls=_DivergingLayoutParser), + ) + client = RecordingAsyncClient(["1main", "1main", "1main"]) + preprocess = DripperHTMLPreprocessStage(html_col="html", url_col="url") + layout = DripperHTMLLayoutTemplateStage( + client=client, + model_name="dripper", + health_check=False, + layout_template_fallback_llm=True, + layout_template_require_success=True, + layout_template_max_selected_item_ratio=1.0, + layout_template_validation_rows=1, + layout_template_validation_min_content_f1=0.98, + ) + batch = _batch( + { + "url": [f"https://example.test/{c}" for c in "abc"], + "html": [ + '

Rep main

Rep nav

', + '

Val main

Val nav

', + '

Rem main

Rem nav

', + ], + } + ) + out = layout.process(preprocess.process(batch)).to_pandas() + assert len(client.calls) == 3 + assert out["dripper_layout_fallback_llm"].tolist() == [False, True, True] + assert "layout template validation failed" in out.loc[1, "dripper_warning"] + + +def test_layout_stage_splits_by_url_shape(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(stage_mod, "_load_llm_web_kit_bindings", lambda: _make_llm_web_kit_bindings()) + client = RecordingAsyncClient(["1main", "1main"]) + layout = DripperHTMLLayoutTemplateStage( + client=client, + model_name="dripper", + health_check=False, + layout_template_max_selected_item_ratio=1.0, + layout_page_signature_mode="url_shape", + ) + preprocess = DripperHTMLPreprocessStage(html_col="html", url_col="url") + batch = _batch( + { + "url": [ + "https://x.test/archive.html?start=10", + "https://x.test/archive.html?start=20", + "https://x.test/news/123.html", + "https://x.test/news/456.html", + ], + "html": ["

Archive 1

", "

Archive 2

", "

Article 1

", "

Article 2

"], + } + ) + out = layout.process(preprocess.process(batch)).to_pandas() + assert len(client.calls) == 2 + assert out["dripper_layout_cluster"].nunique() == 2 + + +def test_layout_stage_uses_feature_hash_for_large_hosts(monkeypatch: pytest.MonkeyPatch) -> None: + def _get_feature(html: str) -> dict: + if "same" in html: + return {"tags": {1: ["body"], 2: ["article", "nav"]}} + return {"tags": {1: ["body"], 2: ["aside"]}} + + def _no_dbscan(samples: list, threshold: float = 0.95): + raise AssertionError("feature_hash mode should not call exact DBSCAN") + + monkeypatch.setattr( + stage_mod, + "_load_llm_web_kit_bindings", + lambda: _make_llm_web_kit_bindings(get_feature=_get_feature, cluster_html_struct=_no_dbscan), + ) + client = RecordingAsyncClient(["1main", "1main"]) + layout = DripperHTMLLayoutTemplateStage( + client=client, + model_name="dripper", + health_check=False, + layout_template_max_exact_host_pages=2, + layout_template_large_host_mode="feature_hash", + ) + preprocess = DripperHTMLPreprocessStage(html_col="html", url_col="url") + batch = _batch( + { + "url": [f"https://x.test/{c}" for c in "abcd"], + "html": [ + "same rep", + "same sib", + "other lone", + "same sib2", + ], + } + ) + out = layout.process(preprocess.process(batch)).to_pandas() + assert len(client.calls) == 2 + assert out["dripper_layout_representative"].tolist() == [True, False, False, False] + assert out["dripper_layout_standalone_llm"].tolist() == [False, False, True, False] + + +def test_layout_stage_validation_indexes_cover_strata() -> None: + df = pd.DataFrame({"url": [f"https://t.test/{i}" for i in range(10)], "dripper_item_count": list(range(10))}) + cols = ("url", "dripper_item_count") + assert stage_mod._select_validation_indexes(df, [], 2, cols) == [] + assert stage_mod._select_validation_indexes(df, [1, 2, 3, 4], 2, cols) == [1, 4] + assert stage_mod._select_validation_indexes(df, list(range(10)), 4, cols) == [0, 3, 6, 9] diff --git a/tests/stages/text/experimental/dripper/test_workflow.py b/tests/stages/text/experimental/dripper/test_workflow.py new file mode 100644 index 0000000000..f33c632fc1 --- /dev/null +++ b/tests/stages/text/experimental/dripper/test_workflow.py @@ -0,0 +1,164 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for DripperHTMLWorkflow — no GPU, Ray, or LLM server required.""" + +from __future__ import annotations + +from collections.abc import Iterable + +import pytest + +from nemo_curator.models.client.llm_client import AsyncLLMClient, GenerationConfig +from nemo_curator.pipeline.workflow import WorkflowRunResult +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.text.experimental.dripper import DripperHTMLWorkflow + + +class _StubLLMClient(AsyncLLMClient): + def __init__(self) -> None: + super().__init__(max_concurrent_requests=1, max_retries=0, base_delay=0.0) + + def setup(self) -> None: + pass + + async def _query_model_impl( + self, + *, + messages: Iterable, + model: str, + conversation_formatter: object = None, + generation_config: GenerationConfig | dict | None = None, + ) -> list[str]: + return [""] + + +@pytest.fixture +def stub_client() -> _StubLLMClient: + return _StubLLMClient() + + +@pytest.fixture +def base_workflow(stub_client: _StubLLMClient) -> DripperHTMLWorkflow: + return DripperHTMLWorkflow( + client=stub_client, model_name="test-model", perform_layout_clustering=False, health_check=False + ) + + +class TestDripperHTMLWorkflow: + def test_instantiation_defaults(self, stub_client: _StubLLMClient) -> None: + wf = DripperHTMLWorkflow(client=stub_client, model_name="test-model") + assert wf.perform_layout_clustering is True + assert wf.layout_cluster_threshold == pytest.approx(0.95) + assert wf.fallback == "trafilatura" + assert wf.output_format == "mm_md" + assert wf.max_concurrent_requests == 64 + assert wf.health_check is True + assert wf.verbose is True + assert wf.html_col == "html" + assert wf.url_col == "url" + assert wf.output_col == "dripper_content" + + def test_custom_fields_stored(self, stub_client: _StubLLMClient) -> None: + wf = DripperHTMLWorkflow( + client=stub_client, + model_name="custom-model", + layout_cluster_threshold=0.85, + perform_layout_clustering=False, + fallback="bypass", + output_format="text", + max_concurrent_requests=32, + health_check=False, + verbose=False, + ) + assert wf.model_name == "custom-model" + assert wf.layout_cluster_threshold == pytest.approx(0.85) + assert wf.fallback == "bypass" + assert wf.output_format == "text" + assert wf.max_concurrent_requests == 32 + + @pytest.mark.parametrize("with_clustering", [True, False]) + def test_build_stages_returns_processing_stages(self, stub_client: _StubLLMClient, with_clustering: bool) -> None: + wf = DripperHTMLWorkflow( + client=stub_client, model_name="test-model", perform_layout_clustering=with_clustering, health_check=False + ) + stages = wf._build_stages() + assert len(stages) > 0 + assert all(isinstance(s, ProcessingStage) for s in stages) + assert all(s.name.strip() for s in stages) + + def test_layout_clustering_toggle(self, stub_client: _StubLLMClient) -> None: + with_clust = DripperHTMLWorkflow( + client=stub_client, model_name="test-model", perform_layout_clustering=True, health_check=False + ) + without_clust = DripperHTMLWorkflow( + client=stub_client, model_name="test-model", perform_layout_clustering=False, health_check=False + ) + assert len(with_clust._build_stages()) > len(without_clust._build_stages()) + with_names = [s.name for s in with_clust._build_stages()] + without_names = [s.name for s in without_clust._build_stages()] + assert any("Layout" in n for n in with_names) + assert not any("Layout" in n for n in without_names) + + def test_core_stage_order(self, base_workflow: DripperHTMLWorkflow) -> None: + names = [s.name for s in base_workflow._build_stages()] + assert "DripperHTMLPreprocessStage" in names + assert "DripperHTMLInferenceStage" in names + assert "DripperHTMLPostprocessStage" in names + assert names.index("DripperHTMLPreprocessStage") < names.index("DripperHTMLInferenceStage") + assert names.index("DripperHTMLInferenceStage") < names.index("DripperHTMLPostprocessStage") + + def test_custom_column_names_propagate(self, stub_client: _StubLLMClient) -> None: + wf = DripperHTMLWorkflow( + client=stub_client, + model_name="test-model", + html_col="raw_html", + url_col="page_url", + output_col="extracted_text", + perform_layout_clustering=False, + health_check=False, + ) + stages = wf._build_stages() + preprocess = next(s for s in stages if s.name == "DripperHTMLPreprocessStage") + postprocess = next(s for s in stages if s.name == "DripperHTMLPostprocessStage") + assert preprocess.html_col == "raw_html" + assert preprocess.url_col == "page_url" + assert postprocess.output_content_col == "extracted_text" + + def test_post_init_validation_raises_for_none_client(self) -> None: + with pytest.raises(ValueError, match="non-None"): + DripperHTMLWorkflow(client=None, model_name="test-model") + + def test_post_init_validation_raises_for_empty_model(self, stub_client: _StubLLMClient) -> None: + with pytest.raises(ValueError, match="non-empty"): + DripperHTMLWorkflow(client=stub_client, model_name=" ") + + def test_post_init_validation_raises_for_bad_threshold(self, stub_client: _StubLLMClient) -> None: + with pytest.raises(ValueError, match="layout_cluster_threshold"): + DripperHTMLWorkflow(client=stub_client, model_name="m", layout_cluster_threshold=1.5) + + def test_run_returns_workflow_run_result( + self, base_workflow: DripperHTMLWorkflow, monkeypatch: pytest.MonkeyPatch + ) -> None: + from nemo_curator.pipeline import Pipeline + + monkeypatch.setattr(Pipeline, "run", lambda _self, _executor, _initial_tasks=None: []) + + from nemo_curator.backends.xenna import XennaExecutor + + result = base_workflow.run(executor=XennaExecutor()) + assert isinstance(result, WorkflowRunResult) + assert result.get_metadata("elapsed_s") >= 0.0 + assert isinstance(result.get_metadata("stages"), list) + assert len(result.get_metadata("stages")) > 0 diff --git a/tutorials/text/dripper-common-crawl/README.md b/tutorials/text/dripper-common-crawl/README.md new file mode 100644 index 0000000000..2caa2740c4 --- /dev/null +++ b/tutorials/text/dripper-common-crawl/README.md @@ -0,0 +1,49 @@ +# Dripper Common Crawl Smoke + +This tutorial runs Dripper/MinerU-HTML through NeMo Curator's inference server +path on a bounded Common Crawl sample. It is intended for single-node H100 +smoke runs before scaling to a full snapshot. + +The Python runner: + +1. Streams WARC records from `CC-MAIN-2025-26`. +2. Starts Ray through Curator's `SlurmRayClient` on SLURM, or `RayClient` + outside SLURM. +3. Starts a Curator `InferenceServer` with the Dripper model. +4. Points `AsyncOpenAIClient` at the server endpoint. +5. Optionally runs warmup pages, then runs `DripperHTMLExtractionStage`. +6. Writes extracted rows plus steady-state and end-to-end H100-hour metrics. + +Run the standalone baseline directly (single node, 8 GPUs): + +```bash +python tutorials/text/dripper-common-crawl/run_mineru_html_standalone.py \ + --input-manifest-path /path/to/manifest.parquet \ + --output-dir /path/to/output --replicas 8 --max-concurrent-requests 64 +``` + +Useful overrides: `--max-pages`, `--replicas`, `--max-concurrent-requests`, +`--warmup-pages`. Wrap this in your scheduler's job script (e.g. an `sbatch` +wrapper) for your cluster. + +Throughput knobs that should not change Dripper extraction semantics: + +- `ENABLE_PREFIX_CACHING=1` is the default and reuses identical prompt prefixes + in vLLM. +- `DISABLE_THINKING=1` is the default and passes + `chat_template_kwargs={"enable_thinking": false, "thinking": false}` through + the OpenAI-compatible vLLM request. Dripper expects JSON/compact labels, so + disabling thinking avoids `...` text that MinerU-HTML cannot parse. +- `MAX_CONCURRENT_REQUESTS`, `MAX_NUM_SEQS`, and `MAX_NUM_BATCHED_TOKENS` tune + request batching. +- `GPU_MEMORY_UTILIZATION` defaults to `0.9` in the Nebius wrapper to increase + KV-cache capacity. +- `WARMUP_PAGES` excludes cold first-request overhead from the steady-state + `h100_hours_per_page` metric while still reporting end-to-end timing. + +Use `ENFORCE_EAGER=1` for short debug runs where startup time matters more than +steady-state throughput. Leave it unset for cost estimation runs. + +The submit script expects PBSS/Common Crawl credentials to be available from +the environment or from the user's remote cache environment file. It does not +print secret values. diff --git a/tutorials/text/dripper-common-crawl/STYLE_GAPS.md b/tutorials/text/dripper-common-crawl/STYLE_GAPS.md new file mode 100644 index 0000000000..60bc497a7b --- /dev/null +++ b/tutorials/text/dripper-common-crawl/STYLE_GAPS.md @@ -0,0 +1,637 @@ +# Style Gaps: SemanticDedup Tutorial vs Dripper Tutorial + +## Swarm Results (2026-06-14) + +### Fixed in 4-agent swarm + +**Agent 1 (P1 Critical Bugs)** +- Added `_convert_main_html()` to stage.py (was missing, broke propagation_stage.py) +- Fixed `DripperHTMLExtractionStage._coerce_html` → module-level `_coerce_html()` in stage.py +- Replaced assert statements with explicit RuntimeError in propagation_stage.py +- Added missing `@dataclass(kw_only=True)` to DripperHTMLPreprocessStage +- Fixed test_stage.py import paths (were importing deleted symbols from stage.py) + +**Agent 2 (Field Reduction)** +- DripperHTMLLayoutTemplateStage: 61 → 30 fields +- Created DripperLayoutAdvancedConfig for 12 CC-scale tuning knobs +- Fixed 14 output column name overrides (now use _DRIPPER_*_COL constants) + +**Agent 3 (Tutorial → Library Migration)** +- LBP static/dynamic split logic moved to propagation_stage.py +- stage3_cpu_propagation.py: 795 → 674 lines +- stage_gpu_pipeline.py: 648 → 541 lines (uses DripperHTMLPostprocessStage) + +**Agent 4 (layout_template.py Size)** +- layout_template.py: 1,872 → 1,569 lines (-303 lines) +- Planning functions extracted to module level (_layout_planning.py: 431 lines) +- Exception handling tightened + +### New gaps identified (Iteration 7+) + +**Gap 7.1 — stage3_ray_propagation.py reimplements 6 helpers already in the library** +- File: `tutorials/text/dripper-common-crawl/stage3_ray_propagation.py` lines 81–210 +- `_coerce_html` (line 81), `_parse_mapping_json` (line 104), `_token_f1` (line 135), + `_load_cluster_manifest_shard` (line 153), `_load_inference_results` (line 183), + `_atomic_write_parquet` (line 207) are all re-implemented locally. +- The library already exports `_coerce_html`, `_token_f1`, `_atomic_write_parquet`-equivalent + from `nemo_curator.stages.text.experimental.dripper.stage` and `_url_helpers`. +- The local `_coerce_html` (line 81–84) skips `_strip_xml_incompatible_chars` and + `_decode_html_bytes` that the library version applies, creating a silent divergence. +- **Fix:** Replace all 6 local copies with imports from the library. The local + `_coerce_html` divergence is a correctness risk — the library version must be used. + Estimated removal: ~60 lines. + +**Gap 7.2 — stage3_ray_propagation.py uses stdlib `logging` not loguru (1,080 lines)** +- File: `tutorials/text/dripper-common-crawl/stage3_ray_propagation.py` line 44, 58 +- `import logging` + `logger = logging.getLogger(__name__)` — not loguru. +- stage3_cpu_propagation.py already uses `from loguru import logger` (line 46). +- The two Stage 3 variants have inconsistent logging: structured loguru in the + ProcessPoolExecutor variant, stdlib in the Ray variant. +- **Fix:** Replace `import logging` / `logging.getLogger` with `from loguru import logger` + at line 44/58. This is a one-line swap; loguru is already in the project deps. + +**Gap 7.3 — `_make_stage_cls` in stage_gpu_pipeline.py still uses the anonymous factory pattern** +- File: `tutorials/text/dripper-common-crawl/stage_gpu_pipeline.py` lines 122–154 +- Despite Agent 3 migrating postprocessing to `DripperHTMLPostprocessStage`, Stage 1c and + Stage 2 are still wrapped via `_make_stage_cls(stage_name, setup_fn, process_fn)` which + produces anonymous classes with no stable `name` attribute and no import path. +- The `process_batch` override (line 144–151) reconstructs a `DocumentBatch` without + preserving `_metadata` or `_stage_perf`, silently dropping pipeline telemetry. +- **Fix:** Replace the Stage 1c anonymous stage with `DripperHTMLPreprocessStage` (already + in `preprocessing.py`) and the Stage 2 LLM call with `DripperHTMLInferenceStage` from + `inference.py`. `_make_stage_cls` can then be deleted entirely (~33 lines removed). + +**Gap 7.4 — layout_template.py `process()` carries 3 noqa complexity suppressions** +- File: `nemo_curator/stages/text/experimental/dripper/layout_template.py` line 498 +- `def process(...)` is decorated `# noqa: C901, PLR0912, PLR0915` (too-complex, + too-many-branches, too-many-statements). +- The method dispatches plan execution, collects results, writes output columns, and + handles timing — all in one function body that was only partially split by Agent 4. +- **Fix:** Extract the output-column assembly loop (currently lines ~580–625) into + `_assemble_output_df(df, row_results) -> pd.DataFrame` and the plan-dispatch loop into + `_execute_plans_async(ctx, plans) -> dict`. This should remove all three noqa suppressions. + +**Gap 7.5 — `stage.py` `_run_dripper_health_check` silently accepts `RuntimeError` re-raise without re-raise guard** +- File: `nemo_curator/stages/text/experimental/dripper/stage.py` lines 219–226 +- The health-check catches all non-RuntimeError exceptions and re-raises as `RuntimeError`, + but the `except RuntimeError: raise` guard (line 219–220) is a bare re-raise that lets + `RuntimeError` from `client.query_model` propagate with no additional context. +- The empty-response guard (line 226) uses a no-`EM101` string literal directly in + `raise RuntimeError(...)` without assigning to a variable first — ruff `EM101` is + suppressed via the `# noqa: EM101` comment rather than fixed. +- **Fix:** Assign the error string to `msg` before raising (matching the pattern used + elsewhere in the file). Add `f"Dripper LLM health check timed out or returned no data " + f"(model={model_name!r})"` as the RuntimeError message so the caller sees the model name. + +### PR Status +- Total Python LOC: 13,957 (8,755 tutorial + 5,012 library + 190 workflow) +- F1 (5 retests): 0.8442–0.8443 stable +- Ruff: All checks passed + +--- + +## Status Update (2026-06-14) + +### Completed ✅ +- Priority 1 (quickstart): ✅ 344→145 lines +- Priority 2 (loguru): ✅ 43 print() eliminated +- Priority 3 (DripperConfig): ✅ dataclass + YAML bridge +- Priority 4 (test_workflow): ✅ 10 synthetic tests, 152 lines +- Priority 5 (type annotations): ✅ completed +- Item 6 (WorkflowRunResult): ✅ typed return + +--- + +## Iteration 2-4 Architectural Improvements + +- stage.py split: 3,776→489 lines (-87%) +- layout_template.py extracted: 2,356 lines focused file +- stage.py now only 489 lines (shared utilities) +- workflow.py: WorkflowRunResult return type +- quickstart.py: 344→145 lines +- test_workflow.py: new, 152 lines +- 4 consecutive cluster retests: F1=0.8442~0.8443 confirmed stable + +--- + +## Remaining Gaps (Iter 5+) + +- layout_template.py still 2,356 lines (SemanticDedup equivalent: ~322) +- stage3_cpu_propagation.py: 902 lines +- run_pipeline.py: 723 lines (Slurm orchestrator, inherently cluster-specific) +- pipeline_metrics.py: 265 lines (could use Curator's built-in metric tracking) + +--- + +**Date:** 2026-06-14 +**Scope:** Code style and maintainability comparison between `SemanticDeduplicationWorkflow` +(the established pattern in `nemo_curator/stages/deduplication/semantic/workflow.py` and its +image tutorial `tutorials/image/getting-started/image_dedup_example.py`) versus the Dripper +CC-scale tutorial scripts under `tutorials/text/dripper-common-crawl/`. + +--- + +## 1. Entry Point / User API + +**SemanticDedup approach:** +```python +# tutorials/image/getting-started/image_dedup_example.py — 8 lines to run the full pipeline +pipeline = SemanticDeduplicationWorkflow( + input_path=args.embeddings_dir, + output_path=args.removal_parquets_dir, + id_field="image_id", + embedding_field="embedding", + n_clusters=100, + eps=0.01, +) +pipeline.run(pairwise_executor=executor) # single call; returns WorkflowRunResult +``` + +**Dripper current approach:** +```bash +# To run the full pipeline the user must: +# 1. Edit configs/template.yaml with cluster paths, model params, resource overrides +# 2. python run_pipeline.py --config configs/template.yaml +# → SSH to a login node, generate 7+ sbatch scripts, submit them one by one via aftercorr +# 3. Monitor 7 Slurm array jobs (stage1a/1b/gpu_pipeline/stage3/stage3b_build/3b_gpu/3b_merge) +# 4. Optionally call: python compare_f1.py --baseline ... --pipeline ... +``` + +**Gap:** The Dripper tutorial has no single Python entry point that a developer can call +in a local or CI environment. The "entry point" (`run_pipeline.py`) is a Slurm-SSH +orchestrator that requires a live cluster with hardcoded Lustre paths, not a composable +Python API. A reviewer cannot run `python tutorial.py` to see the pipeline work. + +**Fix:** Mirror the `DripperHTMLWorkflow` class (already in +`nemo_curator/stages/text/experimental/dripper/workflow.py`) in the tutorial by adding a +`demo.py` or `quickstart.py` that instantiates `DripperHTMLWorkflow` and calls +`workflow.run(executor)` — the same one-liner pattern the SemanticDedup image tutorial +uses. + +--- + +## 2. Stage Construction Pattern + +**SemanticDedup approach:** +```python +# Internally, SemanticDeduplicationWorkflow builds stages in _run_kmeans_stage / +# _run_pairwise_stage via named, typed constructors: +kmeans_stage = KMeansStage( + n_clusters=self.n_clusters, + id_field=self.id_field, + embedding_field=self.embedding_field, + input_path=self.input_path, + output_path=self.kmeans_output_path, + ... +) +pipeline.add_stage(kmeans_stage) +``` + +**Dripper current approach:** +```python +# stage_gpu_pipeline.py — stages are constructed dynamically via a factory function +# that builds anonymous ProcessingStage subclasses closed over free callables: +def _make_stage_cls(stage_name: str, setup_fn: Callable, process_fn: Callable) -> type: + """Build a NeMo ProcessingStage class, cached by stage_name.""" + class _Stage(ProcessingStage[_DocumentBatch, _DocumentBatch]): + name = stage_name + resources = Resources(cpus=1.0) + batch_size = 1 + def setup(self, _worker_metadata=None): setup_fn() + def process_batch(self, tasks): ... + _STAGE_CLS_CACHE[stage_name] = _Stage + return _Stage +``` + +**Gap:** The dynamic `_make_stage_cls` pattern produces anonymous, unconfigurable stage +classes that are harder to introspect, test, and reuse. There is no stable class name to +`isinstance`-check or import in tests. The SemanticDedup pattern uses named, first-class +`ProcessingStage` subclasses (`KMeansStage`, `PairwiseStage`) that can be imported and +composed independently. + +**Fix:** Replace `_make_stage_cls` with proper named `ProcessingStage` subclasses +(e.g. `DripperHTML1cPreprocessStage`) that live in `nemo_curator/stages/`. The workflow +file already does this correctly for the library-level stages; the tutorial should import +them rather than reinvent them. + +--- + +## 3. Configuration + +**SemanticDedup approach:** +```python +# All configuration is expressed as typed __init__ parameters with defaults: +# nemo_curator/stages/deduplication/semantic/workflow.py +class SemanticDeduplicationWorkflow(WorkflowBase): + def __init__( + self, + input_path: str | list[str], + output_path: str, + n_clusters: int, + eps: float | None = None, + distance_metric: Literal["cosine", "l2"] = "cosine", + which_to_keep: Literal["hard", "easy", "random"] = "hard", + verbose: bool = True, + ... + ): +``` + +**Dripper current approach:** +```yaml +# configs/template.yaml — resource and model params in YAML +resources: + gpu_pipeline: + model: "opendatalab/MinerU-HTML-v1.1-hunyuan0.5B-compact" + max_tokens: 2048 + gpu_mem_util: 0.90 + max_num_seqs: 512 +``` +```python +# stage_gpu_pipeline.py — same params duplicated as argparse arguments: +p.add_argument("--model", default="opendatalab/MinerU-HTML-v1.1-hunyuan0.5B-compact") +p.add_argument("--max-tokens", type=int, default=2048) +p.add_argument("--gpu-mem-util", type=float, default=0.90) +p.add_argument("--max-num-seqs", type=int, default=512) +``` + +**Gap:** Model and resource parameters are defined twice: once in `configs/template.yaml` +and once as `argparse` defaults in each stage script. There is no single authoritative +source of truth. Adding a new parameter requires editing both files; defaults can silently +diverge. The YAML schema is also undocumented (no schema validation or dataclass mapping). + +**Fix:** Map the YAML config directly onto the `DripperHTMLWorkflow` dataclass fields. +Provide a `DripperConfig.from_yaml(path)` classmethod that validates types, so the YAML +becomes a serialization of the typed Python config rather than a separate parallel format. + +--- + +## 4. LOC Comparison + +| File | LOC | Purpose | +|---|---|---| +| `image_dedup_example.py` (SemanticDedup tutorial) | 301 | Full runnable image dedup pipeline | +| `nemo_curator/stages/deduplication/semantic/workflow.py` | 431 | Library workflow class | +| **SemanticDedup total** | **732** | | +| `stage_gpu_pipeline.py` | 660 | Combined stages 1c+2+2b | +| `stage3_cpu_propagation.py` | 858 | Stage 3 propagation | +| `run_pipeline.py` | 718 | Slurm orchestrator | +| `compare_f1.py` | 143 | Validation script | +| `stage1b_gpu_dbscan.py` | 357 | Stage 1b clustering | +| `stage1c_cpu_preprocess.py` | 137 | Stage 1c preprocessing | +| `stage3b_fallback_llm.py` | 135 | Stage 3b fallback | +| `pipeline_metrics.py` | 265 | Metrics tracking | +| **Dripper tutorial total** | **3,273** | (tutorial scripts only) | +| **Total dripper lines added in PR** | **~9,114** | (git diff stat) | + +**Gap:** The Dripper tutorial is 4.5x larger than the SemanticDedup tutorial to express a +conceptually similar "run pipeline, get output" operation. Much of this LOC lives in +bespoke SSH/Slurm orchestration, inline subprocess management, and duplicated argparse +boilerplate that the SemanticDedup pattern encapsulates in reusable library classes. + +**Fix:** Move the reusable logic (stage classes, argparse defaults, metrics) into the +library (`nemo_curator/stages/text/experimental/dripper/`). The tutorial should thin down +to ~150–200 LOC, importing from the library rather than reimplementing it. + +--- + +## 5. Error Handling + +**SemanticDedup approach:** +```python +# nemo_curator/stages/deduplication/semantic/workflow.py +def run(self, ...): + try: + self._setup_directories() + ... + return workflow_result + except Exception as e: + logger.error(f"Semantic deduplication pipeline failed: {e}") + raise # re-raise so the caller sees the original exception and traceback +``` +Configuration errors are caught eagerly in `_validate_config()` with typed `ValueError` / +`TypeError` before any compute begins. + +**Dripper current approach:** +```python +# stage_gpu_pipeline.py — bare except swallows errors into the output record +try: + case = _b.case_cls(_b.input_cls(raw_html=html, url=url)) + ... +except Exception as exc: + out["prompt"] = f"ERROR:{type(exc).__name__}:{str(exc)[:100]}" + +# stage3_cpu_propagation.py — similar pattern +try: + ... +except Exception as exc: + logger.debug("loader failed; trying next") + +# stage3_cpu_propagation.py — corrupt-file recovery silently unlinks +try: + meta = pq.read_metadata(str(out_path)) +except OSError: + out_path.unlink(missing_ok=True) # corrupt file — remove and reprocess +``` + +**Gap:** Dripper tutorials use broad `except Exception` guards in many hot-path functions, +converting errors into silent per-record error strings or log-only debug messages. This +means a systematic misconfiguration (wrong model path, missing column) can process +millions of pages and only be detected by inspecting `dripper_error` fields in output +parquet files rather than raising at startup. The SemanticDedup pattern validates eagerly +and re-raises so CI detects failures immediately. + +**Fix:** Add a `validate()` method (or call it from `DripperHTMLWorkflow.__post_init__`) +that checks required inputs before any Ray workers are spawned. Reserve broad per-record +exception capture only for the innermost HTML-parsing call, and surface aggregate error +counts via the `WorkflowRunResult` metadata rather than silent sentinel strings. + +--- + +## 6. Type Annotation Completeness + +**SemanticDedup approach:** +``` +nemo_curator/stages/deduplication/semantic/workflow.py: 5/7 functions annotated (71%) +nemo_curator/stages/text/experimental/dripper/workflow.py: 2/2 functions annotated (100%) +``` +All public methods have full return-type annotations. `__init__` parameters use +`str | list[str]`, `Literal[...]`, typed defaults throughout. + +**Dripper current approach:** +``` +tutorials/text/dripper-common-crawl/stage_gpu_pipeline.py: 19/21 annotated (90%) +tutorials/text/dripper-common-crawl/stage3_cpu_propagation.py: 20/31 annotated (65%) +tutorials/text/dripper-common-crawl/compare_f1.py: 5/5 annotated (100%) +``` +Notable unannotated functions in `stage3_cpu_propagation.py`: + +```python +# Missing return type on several private helpers (31 total, 11 unannotated): +def _apply_ratio_guard(content, url, prop_config): # no -> annotation +def _try_lbp_once(row, prop_config): # no -> annotation +def _sibling_propagate(siblings, gpu_row, ...): # no -> annotation +def _make_rep_or_singleton_row(row, role): # no -> annotation +def _make_fallback_row(row, role, error): # no -> annotation +``` + +**Gap:** `stage3_cpu_propagation.py` has 65% annotation coverage — a 35-point gap from +the SemanticDedup library style. Missing annotations on functions with complex return +types (`dict[str, Any]`, `list[dict]`) make it harder for mypy and IDE tooling to catch +bugs at authorship time. + +**Fix:** Add `-> dict[str, Any]` / `-> list[dict[str, Any]]` / `-> None` to the 11 +unannotated public and private helpers in `stage3_cpu_propagation.py`. Enable `mypy` in +CI for the tutorial directory with `--ignore-missing-imports`. + +--- + +## 7. Logging Style + +**SemanticDedup approach:** +```python +# nemo_curator/stages/deduplication/semantic/workflow.py +from loguru import logger # single consistent import + +logger.info("Starting K-means clustering stage (RayActorPoolExecutor)...") +logger.success(f"K-means clustering completed in {kmeans_time:.2f} seconds") +logger.warning( + f"n_clusters={self.n_clusters} is less than {MIN_RECOMMENDED_N_CLUSTERS}. ..." +) +logger.error(f"Semantic deduplication pipeline failed: {e}") +# 38 logger.* calls; 0 print() calls in the workflow +``` + +**Dripper current approach (mixed, inconsistent):** +```python +# stage_gpu_pipeline.py — uses print() with flush=True, no logger at all +print(f"[gpu-pipeline] Stage 1c: {ok:,}/{len(df):,} prompts in {elapsed:.1f}s", flush=True) +print(f"[gpu-pipeline] Stage 2: {len(df):,} pages over {n_gpus} GPUs", flush=True) +print(f"[gpu-pipeline] ALL DONE: ...", flush=True) +# 0 logger.* calls + +# stage3_cpu_propagation.py — uses stdlib logging.getLogger AND print() in the same file +logger = logging.getLogger(__name__) # stdlib, not loguru +... +logger.debug("pickle.loads from bytes failed; trying string decode") +print(f"[stage3] shard {shard_index}: {len(tasks):,} cluster tasks...", flush=True) +# 2 logger.* calls, 12 print() calls + +# compare_f1.py — print() only, 19 calls +print("[f1] loading baseline...", flush=True) + +# run_pipeline.py — logging.getLogger AND 5 print() calls +logger = logging.getLogger(__name__) +``` + +**Gap:** Across the four main Dripper tutorial files there are 43 `print()` calls and +only 7 `logger.*` calls (all using stdlib `logging`, not `loguru`). The `[stage-prefix]` +convention embedded in print strings is a manual workaround for the structured context +loguru provides natively. This makes it impossible to globally adjust log levels, redirect +to files, or suppress output in tests without patching `sys.stdout`. + +**Fix:** Replace all `print(f"[stage3] ...", flush=True)` calls with +`logger.info("...")` using `loguru` (matching the library convention). In test code, use +`loguru`'s `caplog`/`capfd` sink rather than patching stdout. + +--- + +## 8. Test Coverage Style + +**SemanticDedup approach:** +```python +# tests/stages/deduplication/semantic/test_workflow.py +class TestSemanticDeduplicationWorkflow: + def setup_method(self): + # Creates synthetic blobs in memory; no Slurm, no cluster needed + self.X, _ = make_blobs(n_samples=..., n_features=3, random_state=42) + self.df = pd.DataFrame({"id": ..., "embeddings": self.X.tolist()}) + + def test_semantic_deduplication_with_duplicate_identification(self, tmpdir, ...): + pipeline = SemanticDeduplicationWorkflow( + input_path=input_dir, output_path=output_dir, + n_clusters=self.n_clusters, eps=0.01, ... + ) + results = pipeline.run(pairwise_executor=executor) + assert results.get_metadata("total_time") > 0 + assert duplicates_identified == expected_removed # exact count verified +``` +Tests exercise the full Python API end-to-end; no subprocess spawning, no SSH, no Slurm. + +**Dripper current approach:** +```python +# tests/stages/text/experimental/dripper/test_stage.py +# Tests the underlying stage classes (good), but tests the tutorial-level +# orchestration only via the test_pipeline_correctness.py which: +# - Requires a running Ray cluster +# - Reads from filesystem paths set via environment variables +# - Has no synthetic data generation (needs pre-existing parquet files) +# tutorials/text/dripper-common-crawl/test_pipeline_correctness.py: +# "Run full pipeline on a small subset and verify F1 > threshold" +# → this is an integration test masquerading as a unit test +``` + +**Gap:** The Dripper library-level stage tests are good (`test_stage.py`), but the +tutorial has no self-contained unit test for the orchestration layer (the equivalent of +`test_workflow.py` for SemanticDedup). The only end-to-end test requires a live cluster. +SemanticDedup's test synthesizes data in-process and verifies exact duplicate counts, +giving immediate CI feedback. + +**Fix:** Add a `tests/stages/text/experimental/dripper/test_workflow.py` that instantiates +`DripperHTMLWorkflow` with a `FakeAsyncLLMClient`, generates a tiny in-memory HTML +dataset, runs the pipeline via `XennaExecutor`, and asserts on output column presence and +content length > 0. Mirror the `setup_method` / `tmpdir` pattern from +`test_workflow.py`. + +--- + +## 9. Documentation and Docstrings + +**SemanticDedup approach:** +```python +# nemo_curator/stages/deduplication/semantic/workflow.py — class-level docstring: +class SemanticDeduplicationWorkflow(WorkflowBase): + """ + End-to-End Semantic Deduplication Workflow. + It consists of the following stages: + - KMeansStage: ... + - PairwiseStage: ... + - IdentifyDuplicatesStage (optional): ... + """ + + def __init__(self, ...): + """ + Initialize the semantic deduplication workflow. + + Args: + input_path: Directory or list of directories containing input files with embeddings + output_path: Directory to write output files (i.e. ids to remove) + n_clusters: Number of clusters for K-means + eps: Epsilon value for duplicate identification + ... # every parameter documented + """ +``` + +**Dripper current approach:** +```python +# stage_gpu_pipeline.py — module docstring only, no class or __init__ docstrings +"""Combined Stage 1c + Stage 2 + Stage 2b in a single GPU job. + +Eliminates two intermediate parquet round-trips and two Slurm queue waits. +INPUT: Stage 1b output dir. OUTPUT: combined parquet with Stage 2b schema. +RUNS ON: batch GPU partition (8xH100). Replaces JOB1c + JOB2 + JOB2b. +""" +# _WorkerConfig dataclass has no field-level docstring: +@dataclass +class _WorkerConfig: + model: str + gpu_mem_util: float + max_model_len: int + max_num_seqs: int + max_num_batched_tokens: int + max_tokens: int + kv_cache_dtype: str + # No description of what each field does + +# DripperHTMLWorkflow (in nemo_curator/stages/text/experimental/dripper/workflow.py) +# has good class + field docstrings — but the tutorial files that call it do not. +``` + +**Gap:** The tutorial stage scripts (`stage_gpu_pipeline.py`, `stage3_cpu_propagation.py`) +have module-level docstrings and per-function docstrings on most private helpers, but no +`Args:` / `Returns:` sections in the Google/NumPy style used by the SemanticDedup +workflow. The `_WorkerConfig` and `_HyperParams` dataclasses lack field-level +documentation. A newcomer cannot tell which fields are required vs. optional or what the +units are (e.g. `gpu_mem_util` is a fraction 0.0–1.0, not a percentage). + +**Fix:** Add `Args:` / `Returns:` sections to the 10 public-facing functions in the +tutorial scripts. Add field comments (`#: fraction of GPU memory, 0.0–1.0`) to +`_WorkerConfig` and `_HyperParams`. + +--- + +## 10. Overall LOC in PR vs SemanticDedup Baseline + +```bash +# git diff origin/main --stat | grep -E "dripper|tutorial" | tail -5 + tutorials/text/dripper-common-crawl/stage3b_fallback_llm.py | 135 + + tutorials/text/dripper-common-crawl/stage_gpu_pipeline.py | 660 ++++ + tutorials/text/dripper-common-crawl/run_pipeline.py | 718 ++++ + tutorials/text/dripper-common-crawl/stage3_cpu_propagation.py | 858 +++++ + Total lines added (dripper + tutorial): ~9,114 +``` + +Compared to SemanticDedup (library + tutorial) which totals **732 lines** for full +end-to-end coverage, the Dripper PR adds **12.4x** more code to express a pipeline that +could theoretically be expressed in the same idiom. A large fraction of this overhead is: + +- Slurm/SSH orchestration that belongs in a cluster-specific runner, not the tutorial +- Bespoke argparse blocks repeated across 6 stage scripts (instead of one config dataclass) +- Inline `sys.path` manipulation (`sys.path.insert(0, str(Path(__file__).parent))`) +- `print(flush=True)` plumbing repeated instead of a shared logger + +--- + +## Prioritized TODO List + +### Priority 1 — Add a self-contained quickstart entry point +**Impact: Discoverability, testability** +Create `tutorials/text/dripper-common-crawl/quickstart.py` (~100 LOC) that: +- Instantiates `DripperHTMLWorkflow` from the library +- Uses a `FakeAsyncLLMClient` or a local model for smoke-test +- Calls `workflow.run(XennaExecutor())` +- Prints a summary table of results +This eliminates the "must have a Slurm cluster to try Dripper" barrier for new +contributors. + +### Priority 2 — Unify logging to loguru +**Impact: Debuggability, test isolation** +Replace all 43 `print(f"[stage-prefix] ...", flush=True)` calls in the four main tutorial +files with `from loguru import logger; logger.info(...)`. Remove `logging.getLogger` +usage in tutorial files (keep it only where stdlib `logging` is truly required for a +third-party library). This makes it possible to suppress output in tests and redirect to +files in production with a one-line sink configuration. + +### Priority 3 — Eliminate YAML/argparse configuration duplication +**Impact: Maintainability, correctness** +Add a `DripperConfig` dataclass (or extend `DripperHTMLWorkflow` fields) that can be +serialized to/from YAML. Remove the parallel argparse defaults in each stage script that +duplicate `configs/template.yaml`. A single `DripperConfig.from_yaml(path)` classmethod +provides one authoritative source of truth for all parameters. + +### Priority 4 — Add a `test_workflow.py` with synthetic data +**Impact: CI coverage, regression prevention** +Mirror `tests/stages/deduplication/semantic/test_workflow.py` for Dripper: a +`TestDripperHTMLWorkflow` class that builds a 10-row HTML dataset in memory, runs the +full pipeline with a fake client, and asserts on output columns and non-empty content. +This gives the same level of API coverage that SemanticDedup has without requiring a +Slurm cluster. + +### Priority 5 — Complete type annotations in `stage3_cpu_propagation.py` +**Impact: Type safety, IDE support** +Add return-type annotations to the 11 unannotated functions +(`_apply_ratio_guard`, `_try_lbp_once`, `_sibling_propagate`, +`_make_rep_or_singleton_row`, `_make_fallback_row`, and 6 others). Add +field-level docstrings to `_WorkerConfig` and `_HyperParams`. Enable `mypy` in CI for +the tutorial directory. This closes the 35-point annotation gap relative to the +SemanticDedup library style and will catch the next `dict` vs `list` confusion at +type-check time rather than at runtime. + +--- + +## 6. Return Type from workflow.run() + +**SemanticDedup approach:** +```python +result = workflow.run(executor) +result.get_metadata("final_output_path") # WorkflowRunResult with typed methods +``` + +**Dripper current approach:** +```python +result = workflow.run(executor) +result["output_tasks"] # plain dict — no typed access, no metadata protocol +``` + +**Gap:** DripperHTMLWorkflow.run() returns a plain dict instead of WorkflowRunResult. + +**Fix:** Return `WorkflowRunResult` from `nemo_curator.pipeline.workflow`. diff --git a/tutorials/text/dripper-common-crawl/compare_f1.py b/tutorials/text/dripper-common-crawl/compare_f1.py new file mode 100644 index 0000000000..ab77dbb7f1 --- /dev/null +++ b/tutorials/text/dripper-common-crawl/compare_f1.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""compare_f1.py — token-level F1: clustering pipeline vs standalone Dripper. + +Treats standalone Dripper (run B) as reference, Stage 3 output as prediction. +Reports F1 distribution overall and by cluster_role (multiset token overlap). +Both-empty → F1=1.0; one-empty → F1=0.0. +""" + +import argparse +import glob +import re +from collections import Counter + +import pyarrow.parquet as pq + +_TOK = re.compile(r"\w+", re.UNICODE) +_F1_HIGH = 0.80 + + +def tokenize(text: str) -> Counter: + return Counter(_TOK.findall(text.lower())) if text else Counter() + + +def f1(pred: str, ref: str) -> float: + pc, rc = tokenize(pred), tokenize(ref) + if not pc and not rc: + return 1.0 + if not pc or not rc: + return 0.0 + common = sum((pc & rc).values()) + if common == 0: + return 0.0 + p = common / sum(pc.values()) + r = common / sum(rc.values()) + return 2 * p * r / (p + r) + + +def load_url_content(path_glob: str, content_col: str) -> dict: + out = {} + for f in sorted(glob.glob(path_glob)): + pf = pq.ParquetFile(f) + cols = [c for c in ["url", content_col, "cluster_role"] if c in pf.schema_arrow.names] + for batch in pf.iter_batches(batch_size=4000, columns=cols): + for r in batch.to_pylist(): + u = r.get("url") + if u is None: + continue + out[str(u)] = (str(r.get(content_col) or ""), str(r.get("cluster_role") or "")) + return out + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--baseline", required=True, help="standalone dripper_results.parquet") + ap.add_argument("--pipeline", required=True, help="Stage 3 output dir (shard_*.parquet)") + ap.add_argument("--baseline-col", default="dripper_content") + ap.add_argument("--pipeline-col", default="dripper_content") + args = ap.parse_args() + + bglob = args.baseline if args.baseline.endswith(".parquet") else f"{args.baseline.rstrip('/')}/*.parquet" + pglob = args.pipeline if args.pipeline.endswith(".parquet") else f"{args.pipeline.rstrip('/')}/*.parquet" + base = load_url_content(bglob, args.baseline_col) + pipe = load_url_content(pglob, args.pipeline_col) + print(f"[f1] baseline={len(base):,} pipeline={len(pipe):,}", flush=True) + + common_urls = set(base) & set(pipe) + print( + f"[f1] common={len(common_urls):,} baseline-only={len(set(base) - set(pipe)):,} pipeline-only={len(set(pipe) - set(base)):,}", + flush=True, + ) + + scores: list[float] = [] + by_role: dict = {} + n_both_empty = 0 + for u in common_urls: + pred, role = pipe[u] + ref, _ = base[u] + s = f1(pred, ref) + scores.append(s) + by_role.setdefault(role or "unknown", []).append(s) + if not pred and not ref: + n_both_empty += 1 + + scores.sort() + n = len(scores) + mean = sum(scores) / n if n else 0.0 + median = scores[n // 2] if n else 0.0 + p10 = scores[int(0.10 * n)] if n else 0.0 + p25 = scores[int(0.25 * n)] if n else 0.0 + n_high = sum(1 for s in scores if s >= _F1_HIGH) + n_zero = sum(1 for s in scores if s == 0.0) + + print("\n" + "=" * 64) + print(" F1: clustering pipeline vs standalone Dripper (reference)") + print("=" * 64) + print(f" pages compared: {n:,}") + print(f" mean / median F1: {mean:.4f} / {median:.4f}") + print(f" p25 / p10 F1: {p25:.4f} / {p10:.4f}") + print(f" pages F1 >= {_F1_HIGH}: {n_high:,} ({n_high / max(n, 1) * 100:.1f}%)") + print(f" pages F1 == 0: {n_zero:,} ({n_zero / max(n, 1) * 100:.1f}%)") + print(f" both-empty (agree): {n_both_empty:,}") + print(" " + "-" * 60) + print(f" {'role':<16}{'pages':>10}{'mean F1':>10}{'>=0.80':>10}{'F1==0':>10}") + for role, ss in sorted(by_role.items()): + m = sum(ss) / len(ss) + ge = sum(1 for x in ss if x >= _F1_HIGH) / len(ss) * 100 + z = sum(1 for x in ss if x == 0.0) / len(ss) * 100 + print(f" {role:<16}{len(ss):>10,}{m:>10.4f}{ge:>9.1f}%{z:>9.1f}%") + print("=" * 64) + + +if __name__ == "__main__": + main() diff --git a/tutorials/text/dripper-common-crawl/configs/template.yaml b/tutorials/text/dripper-common-crawl/configs/template.yaml new file mode 100644 index 0000000000..94be4b92ba --- /dev/null +++ b/tutorials/text/dripper-common-crawl/configs/template.yaml @@ -0,0 +1,107 @@ +# ============================================================ +# Dripper CC Clustering Pipeline — Config Template +# Usage: python run_pipeline.py --config configs/my_run.yaml +# ============================================================ + +cluster: + login_node: "vjawa@nb-hel-cs-001-vscode-01.nvidia.com" + dc_node: "vjawa@nb-hel-cs-001-dc-01.nvidia.com" # fast transfer node + account: "nemotron_n4_pre" + venv: "/lustre/fsw/portfolios/llmservice/users/vjawa/dripper_cc_main_2025_26_smoke/.venv" + cached_venv: "/lustre/fsw/portfolios/llmservice/users/vjawa/dripper_cached_venv" + hf_cache: "/lustre/fsw/portfolios/llmservice/users/vjawa/hf_cache" + # repo root on cluster — must contain tutorials/text/dripper-common-crawl/ + remote_repo: "/lustre/fsw/portfolios/llmservice/projects/llmservice_fm_text/users/vjawa/nemo_curator_dripper_layout_clustering_20260611_194849/curator" + +# Output base — {snapshot} and {ts} (YYYYMMDD_HHMMSS) are expanded at runtime. +output_base: "/lustre/fsw/portfolios/llmservice/users/vjawa/cc_pipeline_{snapshot}_{ts}" + +# ── Snapshots to process ────────────────────────────────────── +snapshots: + - name: "CC-MAIN-2025-26" + manifest: "/lustre/fsw/portfolios/llmservice/users/vjawa/nemo_curator_dripper_sorted_host_buckets_20260611" + # Set to a pre-existing standalone output for validation (optional). + # Leave empty ("") to skip F1 validation for this snapshot. + validation_baseline: "" + + # Uncomment to add another snapshot: + # - name: "CC-MAIN-2024-51" + # manifest: "/lustre/.../cc_main_2024_51_manifest.parquet" + # validation_baseline: "" + +# ── Sharding ────────────────────────────────────────────────── +# All array stages must have the same shard count so aftercorr works. +sharding: + num_shards: 80 # total shards for stage1a, stage1b, stage3 + gpu_pipeline_shards: 80 # shards for stage 1c+2+2b GPU array + +# ── Validation ──────────────────────────────────────────────── +validation: + enabled: true + f1_threshold: 0.85 # warn/halt if mean F1 falls below this + halt_on_failure: false # if true, cancel stage3b downstream on F1 failure + sample_size: 10000 # sample N URLs for fast validation (full run is slow) + +# ── Resources per stage ─────────────────────────────────────── +resources: + stage1a: + partition: "cpu_short" + cpus: 64 + mem: "230G" + time: "04:00:00" + cpus_per_actor: 1 # 64 actors with 1 CPU each + + stage1b: + partition: "batch" + gpus_per_node: 1 + cpus: 4 + mem: "32G" + time: "12:00:00" + batch_size: 16 # hosts per actor call + gpu_min_size: 5 # min cluster size for GPU path + + gpu_pipeline: + partition: "batch" + gpus_per_node: 8 + cpus: 64 + mem: "240G" + time: "08:00:00" + model: "opendatalab/MinerU-HTML-v1.1-hunyuan0.5B-compact" + max_tokens: 2048 + gpu_mem_util: 0.90 + max_model_len: 32768 + max_num_seqs: 512 + max_num_batched_tokens: 16384 + kv_cache_dtype: "fp8" + + stage3: + partition: "cpu_short" + cpus: 64 + mem: "230G" + time: "01:00:00" + num_workers: 64 + + stage3b_build: + partition: "cpu_short" + cpus: 8 + mem: "64G" + time: "00:15:00" + + stage3b_gpu: + partition: "batch" + gpus_per_node: 8 + cpus: 64 + mem: "240G" + time: "01:00:00" + + stage3b_merge: + partition: "cpu_short" + cpus: 4 + mem: "32G" + time: "00:15:00" + + validation: + partition: "cpu_short" + cpus: 4 + mem: "16G" + time: "00:30:00" diff --git a/tutorials/text/dripper-common-crawl/quickstart.py b/tutorials/text/dripper-common-crawl/quickstart.py new file mode 100644 index 0000000000..433ffbd20f --- /dev/null +++ b/tutorials/text/dripper-common-crawl/quickstart.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dripper quickstart: DripperHTMLWorkflow on 20 synthetic pages. Use --dry-run for no-GPU mode.""" + +from __future__ import annotations + +import argparse +import sys + +import pandas as pd +from loguru import logger + + +def _make_synthetic_df(n: int = 20) -> pd.DataFrame: + templates = [ + "

{t}

{b}

", + "

{t}

{b}

", + "

{t}

{b}

", + ] + bodies = [ + "The quick brown fox jumps over the lazy dog.", + "Scientists discover a new method to improve efficiency.", + "Community gathers to celebrate the annual harvest festival.", + "Regular exercise improves cognitive function, study finds.", + "Markets close higher on strong earnings reports this quarter.", + ] + rows = [] + for i in range(n): + t, b = f"Article {i}", bodies[i % len(bodies)] + rows.append( + { + "url": f"https://example{i % 3}.com/page-{i:04d}", + "url_host_name": f"example{i % 3}.com", + "html": templates[i % len(templates)].format(t=t, b=b), + } + ) + return pd.DataFrame(rows) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Dripper quickstart — DripperHTMLWorkflow on synthetic data", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--server-url", default="http://localhost:8000/v1", help="Base URL of an OpenAI-compatible inference server." + ) + parser.add_argument( + "--model-name", + default="opendatalab/MinerU-HTML-v1.1-hunyuan0.5B-compact", + help="Model ID served at --server-url.", + ) + parser.add_argument("--dry-run", action="store_true", help="Skip LLM inference (no server needed).") + args = parser.parse_args() + + try: + from nemo_curator.backends.xenna import XennaExecutor + from nemo_curator.models.client.openai_client import OpenAIClient + from nemo_curator.stages.text.experimental.dripper import DripperHTMLWorkflow + from nemo_curator.tasks import DocumentBatch + except ImportError as exc: + logger.error("Run: pip install 'nemo-curator[dripper]'\n {}", exc) + sys.exit(1) + + # Build the LLM client (or a no-op stub for --dry-run) + if args.dry_run: + from nemo_curator.models.client.llm_client import AsyncLLMClient + + class _DryRunClient(AsyncLLMClient): + def __init__(self): + super().__init__(max_concurrent_requests=1, max_retries=0, base_delay=0.0) + + def setup(self): + pass + + async def _query_model_impl( + self, *, messages, model, conversation_formatter=None, generation_config=None + ) -> list[str]: + return [""] + + client = _DryRunClient() + logger.info("Dry-run mode: LLM inference skipped.") + else: + client = OpenAIClient(model=args.model_name, base_url=args.server_url, api_key="EMPTY") + logger.info("Using OpenAI-compatible client at {}", args.server_url) + + # Construct the workflow + workflow = DripperHTMLWorkflow( + client=client, + model_name=args.model_name, + perform_layout_clustering=True, + layout_cluster_threshold=0.95, + fallback="trafilatura", + output_format="mm_md", + ) + + # Build input tasks from a 20-row in-memory DataFrame + df = _make_synthetic_df(n=20) + initial_tasks = [DocumentBatch(task_id="quickstart-0", dataset_name="synthetic", data=df)] + logger.info("Running DripperHTMLWorkflow on {} synthetic pages...", len(df)) + + # Run + result = workflow.run(executor=XennaExecutor(), initial_tasks=initial_tasks) + + # Show results + output_tasks = result.pipeline_tasks.get("dripper_html_extraction") or [] + if output_tasks: + out_df = output_tasks[0].to_pandas() + sample_cols = [c for c in ["url", "dripper_content", "dripper_error"] if c in out_df.columns] + print(out_df[sample_cols].head(5).to_string()) + else: + logger.warning("No output tasks returned — check your pipeline configuration.") + + +if __name__ == "__main__": + main() diff --git a/tutorials/text/dripper-common-crawl/run_pipeline.py b/tutorials/text/dripper-common-crawl/run_pipeline.py new file mode 100644 index 0000000000..2259fd2e12 --- /dev/null +++ b/tutorials/text/dripper-common-crawl/run_pipeline.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Single-command Dripper pipeline: input parquet(s) → output parquet with extracted content. + +Usage (recommended — layout clustering for host-chunked input): + + python run_pipeline.py \\ + --input /data/host_pages.parquet \\ + --output /data/output/ \\ + --server-url http://localhost:8000/v1 + +Usage (standalone — no clustering, every page gets its own LLM call): + + python run_pipeline.py --input /data/pages.parquet --output /data/output/ \\ + --server-url http://localhost:8000/v1 --no-clustering + +Input parquet must have: url, html (url_host_name recommended for clustering) +Output adds: dripper_content, dripper_html, dripper_error + +Pipeline stages: + With clustering (default): Preprocess → LayoutTemplate (cluster + LLM reps + propagate siblings) + Without clustering: Preprocess → Inference → Postprocess +""" + +from __future__ import annotations + +import argparse +import os +import sys +import time +from pathlib import Path + +import pandas as pd +import pyarrow.parquet as pq +from loguru import logger + + +def _load_input(path: str) -> pd.DataFrame: + p = Path(path) + if p.is_dir(): + files = sorted(p.glob("*.parquet")) + if not files: + raise FileNotFoundError(f"No parquet files in {path}") + return pd.concat([pq.read_table(f).to_pandas() for f in files], ignore_index=True) + return pq.read_table(str(p)).to_pandas() + + +def run(args: argparse.Namespace) -> int: + from nemo_curator.backends.ray_actor_pool import RayActorPoolExecutor + from nemo_curator.models.client.openai_client import OpenAIClient + from nemo_curator.stages.text.experimental.dripper import DripperHTMLWorkflow + from nemo_curator.tasks import DocumentBatch + + t0 = time.perf_counter() + df = _load_input(args.input) + logger.info("Loaded {:,} pages from {}", len(df), args.input) + + missing = {"url", "html"} - set(df.columns) + if missing: + logger.error("Input missing required columns: {}", sorted(missing)) + return 1 + + client = OpenAIClient(model=args.model_name, base_url=args.server_url, api_key="EMPTY") + workflow = DripperHTMLWorkflow( + client=client, + model_name=args.model_name, + html_col=args.html_col, + url_col=args.url_col, + output_col=args.output_col, + perform_layout_clustering=not args.no_clustering, + layout_cluster_threshold=args.cluster_threshold, + fallback=args.fallback, + output_format=args.output_format, + max_concurrent_requests=args.max_concurrent_requests, + health_check=not args.no_health_check, + ) + + chunk = max(1, len(df) // max(1, args.workers)) + tasks = [ + DocumentBatch(dataset_name="dripper", data=df.iloc[i : i + chunk].reset_index(drop=True)) + for i in range(0, len(df), chunk) + ] + result = workflow.run(executor=RayActorPoolExecutor(), initial_tasks=tasks) + output_tasks = result.pipeline_tasks.get("dripper_html_extraction", []) + if not output_tasks: + logger.error("Pipeline returned no output — check server and logs") + return 1 + + out_df = pd.concat([t.to_pandas() for t in output_tasks], ignore_index=True) + + # Summary + n = len(out_df) + ok = int(out_df.get(args.output_col, pd.Series()).astype(str).str.len().gt(10).sum()) + elapsed = time.perf_counter() - t0 + logger.info( + "Done — pages={:,} content_ok={} ({:.0f}%) elapsed={:.1f}s ({:.0f} p/s)", + n, + ok, + 100 * ok / max(1, n), + elapsed, + n / max(elapsed, 0.001), + ) + + # Write output + out_dir = Path(args.output) + out_dir.mkdir(parents=True, exist_ok=True) + stem = Path(args.input).stem if not Path(args.input).is_dir() else "output" + out_path = out_dir / f"{stem}.parquet" + tmp = out_path.with_suffix(f".tmp_{os.getpid()}.parquet") + out_df.to_parquet(str(tmp), index=False, compression="snappy") + tmp.rename(out_path) + logger.info("Output → {}", out_path) + return 0 + + +def main() -> int: + p = argparse.ArgumentParser( + description="Dripper HTML extraction: input parquet → output parquet with extracted content", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p.add_argument("--input", required=True, help="Input parquet file or directory (url, html required)") + p.add_argument("--output", required=True, help="Output directory") + p.add_argument("--server-url", default="http://localhost:8000/v1", help="OpenAI-compatible server URL") + p.add_argument("--model-name", default="opendatalab/MinerU-HTML-v1.1-hunyuan0.5B-compact") + p.add_argument("--no-clustering", action="store_true", help="Standalone extraction (no layout clustering)") + p.add_argument("--cluster-threshold", type=float, default=0.95, help="DOM similarity threshold") + p.add_argument("--fallback", default="trafilatura", choices=["trafilatura", "bypass", "empty"]) + p.add_argument("--output-format", default="mm_md") + p.add_argument("--output-col", default="dripper_content", help="Name of output content column") + p.add_argument("--html-col", default="html") + p.add_argument("--url-col", default="url") + p.add_argument("--max-concurrent-requests", type=int, default=64) + p.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 4) - 2)) + p.add_argument("--no-health-check", action="store_true") + p.add_argument("--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"]) + args = p.parse_args() + + logger.remove() + logger.add(sys.stdout, level=args.log_level.upper()) + return run(args) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tutorials/text/dripper-common-crawl/stage1a_feature_extraction.py b/tutorials/text/dripper-common-crawl/stage1a_feature_extraction.py new file mode 100644 index 0000000000..ea8f7845ab --- /dev/null +++ b/tutorials/text/dripper-common-crawl/stage1a_feature_extraction.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stage 1a: CPU-only DOM feature extraction via llm_web_kit get_feature().""" + +import argparse +import json +import os +from pathlib import Path + +import pandas as pd +import pyarrow.parquet as pq +from loguru import logger + +from nemo_curator.backends.ray_actor_pool import RayActorPoolExecutor +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.resources import Resources +from nemo_curator.tasks import DocumentBatch + +OUTPUT_COLS = [ + "url", + "url_host_name", + "html", + "dom_feature", + "warc_filename", + "warc_record_offset", + "warc_record_length", +] + + +class DOMFeatureExtractionStage(ProcessingStage[DocumentBatch, DocumentBatch]): + name: str = "DOMFeatureExtractionStage" + + def __init__(self, cpus_per_actor: int = 4) -> None: + super().__init__() + self._resources = Resources(cpus=float(cpus_per_actor)) + self._web = None + + def setup(self, _worker_metadata: object = None) -> None: + from nemo_curator.stages.text.experimental.dripper.stage import _load_llm_web_kit_bindings + + self._web = _load_llm_web_kit_bindings() + + def process(self, batch: DocumentBatch) -> DocumentBatch: + df = batch.to_pandas().copy() + + def _extract(html: object) -> str: + if isinstance(html, bytes): + html = html.decode("utf-8", errors="replace") + if not isinstance(html, str) or not html.strip(): + return "" + try: + return json.dumps(self._web.get_feature(html)) + except Exception: + return "" + + df["dom_feature"] = [_extract(h) for h in df["html"]] + return DocumentBatch(dataset_name=batch.dataset_name, data=df) + + +def run(args: argparse.Namespace) -> None: + inp = Path(args.input) + if inp.is_dir(): + exact = inp / f"shard_{args.shard_index:04d}.parquet" + if exact.exists(): + inp = exact + else: + candidates = sorted(inp.glob("*.parquet")) + if not candidates: + raise FileNotFoundError(f"No parquet files in {args.input}") + inp = candidates[0] + + pf = pq.ParquetFile(str(inp)) + total = pf.metadata.num_rows + start = total * args.shard_index // args.num_shards + end = total * (args.shard_index + 1) // args.num_shards + need = ["url", "url_host_name", "html", "warc_filename", "warc_record_offset", "warc_record_length"] + cols = [c for c in need if c in pf.schema_arrow.names] + rows_seen, parts = 0, [] + for batch in pf.iter_batches(batch_size=65_536, columns=cols): + df_b = batch.to_pandas() + lo, hi = max(0, start - rows_seen), min(len(df_b), end - rows_seen) + rows_seen += len(df_b) + if lo < hi: + parts.append(df_b.iloc[lo:hi]) + if rows_seen >= end: + break + shard_df = pd.concat(parts, ignore_index=True) if parts else pd.DataFrame(columns=cols) + logger.info("shard {}/{}: {:,} pages", args.shard_index, args.num_shards, len(shard_df)) + if len(shard_df) == 0: + return + + n_actors = max(1, (os.cpu_count() or 4) // max(1, args.cpus_per_actor)) + chunk = max(1, len(shard_df) // n_actors) + tasks = [ + DocumentBatch(dataset_name="stage1a", data=shard_df.iloc[i : i + chunk].reset_index(drop=True)) + for i in range(0, len(shard_df), chunk) + ] + stage = DOMFeatureExtractionStage(cpus_per_actor=args.cpus_per_actor) + pipeline = Pipeline(name="stage1a") + pipeline.add_stage(stage) + result_tasks = pipeline.run(executor=RayActorPoolExecutor(), initial_tasks=tasks) or [] + + out_df = ( + pd.concat([t.to_pandas() for t in result_tasks if hasattr(t, "to_pandas")], ignore_index=True) + if result_tasks + else pd.DataFrame(columns=OUTPUT_COLS) + ) + for col in OUTPUT_COLS: + if col not in out_df.columns: + out_df[col] = None + + out = Path(args.output) + out.mkdir(parents=True, exist_ok=True) + out_path = out / (f"shard_{args.shard_index:04d}.parquet" if args.num_shards > 1 else "shard_0000.parquet") + tmp = out_path.with_suffix(".parquet.tmp") + out_df.to_parquet(str(tmp), index=False, compression="snappy") + tmp.rename(out_path) + + feat_ok = int((out_df["dom_feature"].astype(str) != "").sum()) + logger.info("feature_ok={}/{} output -> {}", feat_ok, len(out_df), out_path) + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument("--input", required=True) + p.add_argument("--output", required=True) + p.add_argument("--shard-index", type=int, default=int(os.environ.get("SLURM_ARRAY_TASK_ID", "0"))) + p.add_argument("--num-shards", type=int, default=1) + p.add_argument("--cpus-per-actor", type=int, default=4) + run(p.parse_args()) + + +if __name__ == "__main__": + main() diff --git a/tutorials/text/dripper-common-crawl/stage1b_gpu_dbscan.py b/tutorials/text/dripper-common-crawl/stage1b_gpu_dbscan.py new file mode 100644 index 0000000000..32fc86f107 --- /dev/null +++ b/tutorials/text/dripper-common-crawl/stage1b_gpu_dbscan.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stage 1b: GPU DBSCAN clustering of DOM layout features → cluster assignments.""" + +from __future__ import annotations + +import argparse +import json +import os +import time +from collections import defaultdict +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +from loguru import logger + +from nemo_curator.backends.ray_actor_pool import RayActorPoolExecutor +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.resources import Resources +from nemo_curator.tasks import DocumentBatch + +OUTPUT_COLS = [ + "url", + "url_host_name", + "html", + "cluster_id", + "cluster_role", + "layout_cluster_id", + "is_representative", + "cluster_size", + "warc_filename", + "warc_record_offset", + "warc_record_length", +] + + +def _singleton_row(url: str, host: str, html: object, src: dict, include_html: bool = True) -> dict: + row: dict[str, Any] = { + "url": url, + "url_host_name": host, + "cluster_id": "", + "cluster_role": "singleton", + "layout_cluster_id": "", + "is_representative": False, + "cluster_size": 1, + "warc_filename": src.get("warc_filename"), + "warc_record_offset": src.get("warc_record_offset"), + "warc_record_length": src.get("warc_record_length"), + } + if include_html: + row["html"] = html + return row + + +@dataclass(kw_only=True) +class HostDBSCANStage(ProcessingStage[DocumentBatch, DocumentBatch]): + """GPU DBSCAN clustering — one DocumentBatch per host, one GPU per Ray actor.""" + + name: str = "host_dbscan" + resources: Resources = field(default_factory=lambda: Resources(cpus=4.0, gpus=1.0)) + threshold: float = 0.95 + min_cluster_size: int = 2 + gpu_min_size: int = 5 + max_host_size: int = 3000 + _cluster_gpu: Any = field(init=False, repr=False, default=None) + _has_gpu: bool = field(init=False, repr=False, default=False) + _web: Any = field(init=False, repr=False, default=None) + + def setup(self, _worker_metadata: object = None) -> None: + from nemo_curator.stages.text.experimental.dripper.gpu_layout_clustering import ( + _gpu_available, + cluster_html_struct_gpu, + ) + from nemo_curator.stages.text.experimental.dripper.stage import _load_llm_web_kit_bindings + + self._cluster_gpu = cluster_html_struct_gpu + self._has_gpu = _gpu_available() + self._web = _load_llm_web_kit_bindings() + logger.info( + "actor setup: has_gpu={} CUDA_VISIBLE_DEVICES={}", + self._has_gpu, + os.environ.get("CUDA_VISIBLE_DEVICES", "unset"), + ) + + def process(self, batch: DocumentBatch) -> DocumentBatch: + samples = batch.to_pandas().to_dict("records") + return DocumentBatch( + dataset_name=batch.dataset_name, data=pd.DataFrame(self._cluster_host(batch.dataset_name, samples)) + ) + + def _run_clustering(self, chunk: list[dict], chunk_idx: int | None = None) -> list[dict]: + try: + if self._cluster_gpu and self._has_gpu and len(chunk) >= self.gpu_min_size: + cc, _ = self._cluster_gpu(chunk, threshold=self.threshold, gpu_min_size=self.gpu_min_size) + elif self._web: + cc, _ = self._web.cluster_html_struct(chunk, threshold=self.threshold) + else: + cc = chunk + for i, s in enumerate(cc): + s["layout_id"] = 0 if i == 0 else -1 + if chunk_idx is not None: + for s in cc: + lid = s.get("layout_id", -1) + if lid >= 0: + s["layout_id"] = chunk_idx * 100_000 + lid + except Exception as exc: + logger.warning("{} failed: {}", f"chunk {chunk_idx}" if chunk_idx is not None else "DBSCAN", exc) + cc = chunk + return cc + + def _cluster_host(self, host: str, samples: list[dict]) -> list[dict]: + if len(samples) > self.max_host_size: + clustered: list[dict] = [] + for ci, s in enumerate(range(0, len(samples), self.max_host_size)): + clustered.extend(self._run_clustering(samples[s : s + self.max_host_size], chunk_idx=ci)) + else: + clustered = self._run_clustering(samples) + by_lid: dict[int, list] = defaultdict(list) + for s in clustered: + by_lid[int(s.get("layout_id", -1))].append(s) + rows = [] + for lid, members in by_lid.items(): + if lid < 0 or len(members) < self.min_cluster_size: + for m in members: + rows.append(_singleton_row(m["url"], host, None, m, include_html=False)) + continue + cid = f"{host}:cluster_{lid}" + try: + rep_url = ( + self._web.select_representative_html( + [{"track_id": m["url"], "html": m.get("html", "")} for m in members] + )["track_id"] + if self._web + else members[0]["url"] + ) + except Exception: + rep_url = members[0]["url"] + for m in members: + is_rep = m["url"] == rep_url + rows.append( + { + "url": m["url"], + "url_host_name": host, + "cluster_id": cid, + "cluster_role": "representative" if is_rep else "sibling", + "layout_cluster_id": cid, + "is_representative": is_rep, + "cluster_size": len(members), + "warc_filename": m.get("warc_filename"), + "warc_record_offset": m.get("warc_record_offset"), + "warc_record_length": m.get("warc_record_length"), + } + ) + return rows + + +def run(args: argparse.Namespace) -> None: + inp = Path(args.input) + if inp.is_dir(): + exact = inp / f"shard_{args.shard_index:04d}.parquet" + inp = exact if exact.exists() else sorted(inp.glob("shard_*.parquet"))[0] + pf = pq.ParquetFile(str(inp)) + total = pf.metadata.num_rows + start = total * args.shard_index // args.num_shards + end = total * (args.shard_index + 1) // args.num_shards + need = ["url", "url_host_name", "dom_feature", "html", "warc_filename", "warc_record_offset", "warc_record_length"] + cols = [c for c in need if c in pf.schema_arrow.names] + rows_seen, parts = 0, [] + for batch in pf.iter_batches(batch_size=65_536, columns=cols): + df = batch.to_pandas() + lo, hi = max(0, start - rows_seen), min(len(df), end - rows_seen) + rows_seen += len(df) + if lo < hi: + parts.append(df.iloc[lo:hi]) + if rows_seen >= end: + break + shard_df = pd.concat(parts, ignore_index=True) if parts else pd.DataFrame() + logger.info("shard {}/{}: {:,} pages", args.shard_index, args.num_shards, len(shard_df)) + if len(shard_df) == 0: + return + + html_lookup = {rec["url"]: rec.get("html") for rec in shard_df.to_dict("records")} + by_host: dict[str, list] = defaultdict(list) + singleton_rows: list[dict] = [] + for rec in shard_df.to_dict("records"): + feat_json = rec.get("dom_feature", "") + if not feat_json: + singleton_rows.append(_singleton_row(rec["url"], rec.get("url_host_name", ""), rec.get("html"), rec)) + continue + try: + feat = json.loads(feat_json) + except Exception: + feat = None + if feat is None: + continue + host = str(rec.get("url_host_name") or "") + by_host[host].append( + { + "track_id": rec["url"], + "url": rec["url"], + "html": rec.get("html", ""), + "feature": feat, + "warc_filename": rec.get("warc_filename"), + "warc_record_offset": rec.get("warc_record_offset"), + "warc_record_length": rec.get("warc_record_length"), + } + ) + + host_tasks = [DocumentBatch(dataset_name=h, data=pd.DataFrame(s)) for h, s in by_host.items()] + t0 = time.perf_counter() + stage = HostDBSCANStage( + threshold=args.threshold, + min_cluster_size=args.min_cluster_size, + gpu_min_size=args.gpu_min_size, + max_host_size=int(os.environ.get("STAGE1B_MAX_HOST_SIZE", "3000")), + ) + pipeline = Pipeline(name="stage1b_dbscan") + pipeline.add_stage(stage) + output_tasks = pipeline.run(executor=RayActorPoolExecutor(), initial_tasks=host_tasks) if host_tasks else [] + elapsed = time.perf_counter() - t0 + + out_dir = Path(args.output) + out_dir.mkdir(parents=True, exist_ok=True) + out_path = out_dir / (f"shard_{args.shard_index:04d}.parquet" if args.num_shards > 1 else "shard_0000.parquet") + frames = [] + for task in output_tasks: + df = task.to_pandas() + if not df.empty: + if "html" not in df.columns: + df["html"] = df["url"].map(html_lookup) + frames.append(df[[c for c in OUTPUT_COLS if c in df.columns]]) + if singleton_rows: + sing_df = pd.DataFrame(singleton_rows) + if "html" not in sing_df.columns or sing_df["html"].isna().all(): + sing_df["html"] = sing_df["url"].map(html_lookup) + frames.append(sing_df[[c for c in OUTPUT_COLS if c in sing_df.columns]]) + out_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame(columns=OUTPUT_COLS) + tmp = out_path.with_suffix(".parquet.tmp") + pq.write_table(pa.Table.from_pandas(out_df, preserve_index=False), str(tmp), compression="snappy") + tmp.rename(out_path) + n_reps = int((out_df["cluster_role"] == "representative").sum()) + n_sing = int((out_df["cluster_role"] == "singleton").sum()) + logger.info( + "GPU DBSCAN done in {:.1f}s reps={} singletons={} call_reduction={:.1%}", + elapsed, + n_reps, + n_sing, + 1.0 - (n_reps + n_sing) / max(len(out_df), 1), + ) + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument("--input", required=True) + p.add_argument("--output", required=True) + p.add_argument("--shard-index", type=int, default=int(os.environ.get("SLURM_ARRAY_TASK_ID", "0"))) + p.add_argument("--num-shards", type=int, default=1) + p.add_argument("--threshold", type=float, default=0.95) + p.add_argument("--min-cluster-size", type=int, default=2) + p.add_argument("--gpu-min-size", type=int, default=200) + run(p.parse_args()) + + +if __name__ == "__main__": + main() diff --git a/tutorials/text/dripper-common-crawl/stage1c_cpu_preprocess.py b/tutorials/text/dripper-common-crawl/stage1c_cpu_preprocess.py new file mode 100644 index 0000000000..e7f3f98e31 --- /dev/null +++ b/tutorials/text/dripper-common-crawl/stage1c_cpu_preprocess.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stage 1c: CPU preprocessing for Stage 2 GPU inference (thin wrapper around DripperHTMLPreprocessStage).""" + +import argparse +import glob as _g +import os +from pathlib import Path + +import pandas as pd +import pyarrow.parquet as pq +from loguru import logger + +from nemo_curator.backends.ray_actor_pool import RayActorPoolExecutor +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.text.experimental.dripper import DripperHTMLPreprocessStage +from nemo_curator.tasks import DocumentBatch + +OUTPUT_COLS = [ + "url", + "url_host_name", + "cluster_id", + "cluster_role", + "dripper_simplified_html", + "dripper_mapped_html", + "_dripper_prompt", + "_dripper_needs_llm", + "dripper_item_count", + "html", + "warc_filename", + "warc_record_offset", + "warc_record_length", +] + + +def run(args: argparse.Namespace) -> None: + inp = Path(args.input) + if inp.is_dir(): + files = sorted(_g.glob(str(inp / f"shard_{args.shard_index:04d}.parquet"))) + if not files: + files = sorted(_g.glob(str(inp / "shard_*.parquet"))) + inp = Path(files[0]) if files else inp + + df = pq.ParquetFile(str(inp)).read().to_pandas() + + # Filter to representatives and singletons only + if "cluster_role" in df.columns: + mask = df["cluster_role"].isin(["representative", "singleton"]) + elif "is_representative" in df.columns: + mask = df["is_representative"].astype(bool) + else: + mask = pd.Series(True, index=df.index) + df = df[mask].reset_index(drop=True) + + logger.info("{:,} representative/singleton pages to preprocess", len(df)) + + out = Path(args.output) + out.mkdir(parents=True, exist_ok=True) + out_path = out / (f"shard_{args.shard_index:04d}.parquet" if args.num_shards > 1 else "shard_0000.parquet") + + if len(df) == 0: + pd.DataFrame(columns=OUTPUT_COLS).to_parquet(str(out_path), index=False) + return + + n_workers = args.workers + chunk = max(1, len(df) // n_workers) + tasks = [ + DocumentBatch(dataset_name="stage1c", data=df.iloc[i : i + chunk].reset_index(drop=True)) + for i in range(0, len(df), chunk) + ] + + # Simple Curator pattern: construct library stage, build pipeline, call run() + stage = DripperHTMLPreprocessStage( + html_col="html", + url_col="url", + worker_count=n_workers, + ) + pipeline = Pipeline(name="stage1c") + pipeline.add_stage(stage) + result_tasks = pipeline.run(executor=RayActorPoolExecutor(), initial_tasks=tasks) or [] + + result_df = pd.concat([t.to_pandas() for t in result_tasks], ignore_index=True) if result_tasks else df + + tmp = out_path.with_suffix(".parquet.tmp") + result_df.to_parquet(str(tmp), index=False, compression="snappy") + tmp.rename(out_path) + + # Count prompts successfully built (non-empty _dripper_prompt for rows that need LLM) + if "_dripper_prompt" in result_df.columns: + ok = int((result_df["_dripper_prompt"].astype(str).str.len() > 10).sum()) + else: + ok = 0 + logger.info("prompts_ok={}/{} output -> {}", ok, len(result_df), out_path) + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument("--input", required=True, help="Stage 1b output dir or parquet") + p.add_argument("--output", required=True, help="Output dir") + p.add_argument("--shard-index", type=int, default=int(os.environ.get("SLURM_ARRAY_TASK_ID", "0"))) + p.add_argument("--num-shards", type=int, default=1) + p.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 4) - 2)) + run(p.parse_args()) + + +if __name__ == "__main__": + main() diff --git a/tutorials/text/dripper-common-crawl/stage2b_cpu_postprocess.py b/tutorials/text/dripper-common-crawl/stage2b_cpu_postprocess.py new file mode 100644 index 0000000000..aa5ffa6070 --- /dev/null +++ b/tutorials/text/dripper-common-crawl/stage2b_cpu_postprocess.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stage 2b: CPU postprocessing from LLM responses (thin wrapper around DripperHTMLPostprocessStage).""" + +import argparse +import os +from pathlib import Path + +import pandas as pd +import pyarrow.parquet as pq +from loguru import logger + +from nemo_curator.backends.ray_actor_pool import RayActorPoolExecutor +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.text.experimental.dripper import DripperHTMLPostprocessStage +from nemo_curator.tasks import DocumentBatch + +_MIN_NONEMPTY_LEN: int = 5 +_MIN_ERROR_LEN: int = 2 + + +def run(args: argparse.Namespace) -> None: + inp = Path(args.input) + if inp.is_dir(): + files = sorted(inp.glob(f"shard_{args.shard_index:04d}.parquet")) or sorted(inp.glob("*.parquet")) + inp = files[0] if files else inp + + df = pq.ParquetFile(str(inp)).read().to_pandas() + logger.info("{:,} pages to postprocess ({} workers)", len(df), args.workers) + + n_workers = args.workers + chunk = max(1, len(df) // n_workers) + tasks = [ + DocumentBatch(dataset_name="stage2b", data=df.iloc[i : i + chunk].reset_index(drop=True)) + for i in range(0, len(df), chunk) + ] + + # Simple Curator pattern: construct library stage, build pipeline, call run() + stage = DripperHTMLPostprocessStage( + html_col="html", + url_col="url", + fallback="trafilatura", + output_format="mm_md", + worker_count=n_workers, + ) + pipeline = Pipeline(name="stage2b") + pipeline.add_stage(stage) + result_tasks = pipeline.run(executor=RayActorPoolExecutor(), initial_tasks=tasks) or [] + + result_df = pd.concat([t.to_pandas() for t in result_tasks], ignore_index=True) if result_tasks else df + + out = Path(args.output) + out.mkdir(parents=True, exist_ok=True) + out_path = out / ( + f"shard_{args.shard_index:04d}.parquet" if args.num_shards > 1 else "postprocess_results.parquet" + ) + tmp = out_path.with_suffix(".parquet.tmp") + result_df.to_parquet(str(tmp), index=False, compression="snappy") + tmp.rename(out_path) + + content_ok = int( + (result_df["dripper_content"].astype(str).str.len() > _MIN_NONEMPTY_LEN).sum() + if "dripper_content" in result_df.columns + else 0 + ) + errors = int( + (result_df["dripper_error"].astype(str).str.len() > _MIN_ERROR_LEN).sum() + if "dripper_error" in result_df.columns + else 0 + ) + logger.info( + "content_ok={}/{} errors={} output -> {}", + content_ok, + len(result_df), + errors, + out_path, + ) + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument("--input", required=True, help="Stage 2 output dir") + p.add_argument("--output", required=True, help="Output dir") + p.add_argument("--shard-index", type=int, default=int(os.environ.get("SLURM_ARRAY_TASK_ID", "0"))) + p.add_argument("--num-shards", type=int, default=1) + p.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 4) - 2)) + run(p.parse_args()) + + +if __name__ == "__main__": + main() diff --git a/tutorials/text/dripper-common-crawl/stage3_cpu_propagation.py b/tutorials/text/dripper-common-crawl/stage3_cpu_propagation.py new file mode 100644 index 0000000000..04f6c47454 --- /dev/null +++ b/tutorials/text/dripper-common-crawl/stage3_cpu_propagation.py @@ -0,0 +1,301 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stage 3: CPU propagation sharding wrapper (logic in DripperHTMLLayoutPropagationStage).""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +import time +from pathlib import Path + +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +from loguru import logger + +from nemo_curator.backends.ray_actor_pool import RayActorPoolExecutor +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.text.experimental.dripper.propagation_stage import DripperHTMLLayoutPropagationStage +from nemo_curator.tasks import DocumentBatch + +OUTPUT_COLUMNS = [ + "url", + "url_host_name", + "cluster_id", + "cluster_role", + "dripper_content", + "dripper_html", + "dripper_error", + "dripper_time_s", + "propagation_success", + "propagation_method", +] +_MANIFEST_META_COLS = [ + "url", + "url_host_name", + "cluster_id", + "cluster_role", + "warc_filename", + "warc_record_offset", + "warc_record_length", +] +_INFERENCE_COLS = [ + "cluster_id", + "layout_cluster_id", + "url", + "llm_output_raw", + "inference_time_s", + "error", + "dripper_error", + "dripper_content", + "dripper_html", + "mapping_json", +] +_NULL_VALS = frozenset(("none", "null", "nan", "")) +_DEFAULT_NUM_SHARDS = 80 +_DEFAULT_NUM_WORKERS = int(os.environ.get("SLURM_CPUS_PER_TASK", "64")) + + +def _load_cluster_manifest_shard(path: str) -> pd.DataFrame: + sn = pq.read_schema(path).names + df = pq.read_table(path, columns=[c for c in _MANIFEST_META_COLS if c in sn]).to_pandas() + df.setdefault("cluster_id", None) + if "cluster_id" not in df.columns: + df["cluster_id"] = None + if "cluster_role" not in df.columns: + df["cluster_role"] = "singleton" + df["html"] = None + if "html" in sn: + smask = df["cluster_role"] == "sibling" + if smask.any(): + hdf = pq.read_table(path, columns=["url", "html"]).to_pandas().drop_duplicates("url", keep="first") + df.loc[smask, "html"] = df.loc[smask, "url"].map(hdf.set_index("url")["html"]) + return df + + +def _load_inference_results(path: str) -> pd.DataFrame: + sn = pq.read_schema(path).names + df = pq.read_table(path, columns=[c for c in _INFERENCE_COLS if c in sn]).to_pandas() + if "cluster_id" not in df.columns and "layout_cluster_id" in df.columns: + df = df.rename(columns={"layout_cluster_id": "cluster_id"}) + if "error" not in df.columns and "dripper_error" in df.columns: + df = df.rename(columns={"dripper_error": "error"}) + return df + + +def _load_gpu_df(gpu_dir: Path, shard_index: int, cluster_ids: set, urls: set) -> pd.DataFrame: + exact = gpu_dir / f"shard_{shard_index:04d}.parquet" + files = ( + [exact] if exact.exists() else (sorted(gpu_dir.glob("shard_*.parquet")) or sorted(gpu_dir.glob("*.parquet"))) + ) + if not files: + raise FileNotFoundError(f"No GPU inference result files found in {gpu_dir}") + frames = [] + for f in files: + try: + sdf = _load_inference_results(str(f)) + if sdf.empty: + continue + mask = pd.Series(False, index=sdf.index) + if "cluster_id" in sdf.columns and cluster_ids: + mask |= sdf["cluster_id"].astype(str).isin(cluster_ids) + if "url" in sdf.columns and urls: + null_cid = sdf["cluster_id"].isna() | sdf["cluster_id"].astype(str).isin(_NULL_VALS) + mask |= null_cid & sdf["url"].astype(str).isin(urls) + if not (filt := sdf[mask]).empty: + frames.append(filt) + except OSError as exc: + logger.warning("could not read GPU shard {}: {}", f, exc) + gpu_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame() + logger.info("{:,} GPU result rows loaded ({} files)", len(gpu_df), len(files)) + return gpu_df + + +def process_shard( + cluster_manifest_dir: str, + inference_results_dir: str, + output_dir: str, + shard_index: int, + num_shards: int, + num_workers: int, +) -> dict: + t_start = time.perf_counter() + out_dir = Path(output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + out_path = out_dir / f"shard_{shard_index:04d}.parquet" + if out_path.exists(): + try: + meta = pq.read_metadata(str(out_path)) + if meta.num_rows > 0: + logger.info("SKIP shard {} — already exists ({:,} rows)", shard_index, meta.num_rows) + return {"status": "skipped", "shard": shard_index, "rows": meta.num_rows} + out_path.unlink(missing_ok=True) + except OSError: + out_path.unlink(missing_ok=True) + + manifest_dir = Path(cluster_manifest_dir) + all_files = sorted(manifest_dir.glob("shard_*.parquet")) or sorted(manifest_dir.glob("*.parquet")) + if not all_files: + raise FileNotFoundError(f"No manifest shards found in {manifest_dir}") + n = len(all_files) + my_files = all_files[n * shard_index // num_shards : n * (shard_index + 1) // num_shards] + if not my_files: + logger.info("shard {}: no manifest files — writing empty shard", shard_index) + pq.write_table(pa.table({c: [] for c in OUTPUT_COLUMNS}), str(out_path)) + return {"status": "empty", "shard": shard_index, "rows": 0} + + manifest_df = pd.concat([_load_cluster_manifest_shard(str(f)) for f in my_files], ignore_index=True) + logger.info("shard {}/{}: {:,} rows from {} file(s)", shard_index, num_shards, len(manifest_df), len(my_files)) + + cluster_ids = {str(r) for r in manifest_df["cluster_id"].dropna() if str(r).lower() not in _NULL_VALS} + urls = set(manifest_df["url"].astype(str)) + gpu_df = _load_gpu_df(Path(inference_results_dir), shard_index, cluster_ids, urls) + + mapping_by_cluster: dict = {} + for rec in gpu_df.to_dict("records"): + cid = str(rec.get("cluster_id") or "") + if cid and cid.lower() not in _NULL_VALS: + mapping_by_cluster.setdefault(cid, rec.get("mapping_json") or rec.get("llm_output_raw", "")) + + manifest_df["dripper_layout_cluster"] = manifest_df["cluster_id"].astype(str) + manifest_df["dripper_layout_representative"] = manifest_df["cluster_role"].isin(["representative", "singleton"]) + manifest_df["dripper_layout_mapping_json"] = ( + manifest_df["cluster_id"] + .astype(str) + .map(lambda cid: mapping_by_cluster.get(cid, "") if cid and cid.lower() not in _NULL_VALS else "") + ) + manifest_df["dripper_layout_pending_propagation"] = manifest_df["cluster_role"] == "sibling" + + stage = DripperHTMLLayoutPropagationStage(use_static_lbp=True) + pipeline = Pipeline(name="stage3_cpu_propagation") + pipeline.add_stage(stage) + chunk = max(1, len(manifest_df) // max(1, num_workers)) + doc_tasks = [ + DocumentBatch(dataset_name="stage3", data=manifest_df.iloc[i : i + chunk].reset_index(drop=True)) + for i in range(0, len(manifest_df), chunk) + ] + logger.info("submitting {:,} tasks ({} actors)...", len(doc_tasks), num_workers) + output_doc_tasks = pipeline.run(executor=RayActorPoolExecutor(), initial_tasks=doc_tasks) or [] + + frames = [t.to_pandas() for t in output_doc_tasks] + result_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame(columns=OUTPUT_COLUMNS) + result_df = result_df.rename( + columns={ + "dripper_layout_html": "dripper_html", + "dripper_layout_content": "dripper_content", + "dripper_layout_error": "dripper_error", + "dripper_layout_postprocess_time_s": "dripper_time_s", + "dripper_layout_propagation_success": "propagation_success", + "dripper_layout_propagation_method": "propagation_method", + } + ) + for col in OUTPUT_COLUMNS: + if col not in result_df.columns: + result_df[col] = None + + tmp = out_path.with_suffix(f".tmp_{os.getpid()}.parquet") + pq.write_table( + pa.Table.from_pandas(result_df[OUTPUT_COLUMNS], preserve_index=False), str(tmp), compression="snappy" + ) + tmp.rename(out_path) + + elapsed = time.perf_counter() - t_start + ns = int(result_df.get("propagation_success", pd.Series()).fillna(False).sum()) + logger.info( + "shard {} done pages={:,} success={} elapsed={:.1f}s output={}", + shard_index, + len(result_df), + ns, + elapsed, + out_path, + ) + metrics = { + "shard_index": shard_index, + "num_shards": num_shards, + "total_pages": len(result_df), + "success_pages": ns, + "elapsed_s": elapsed, + "output_path": str(out_path), + } + (out_dir / f"metrics_shard_{shard_index:04d}.json").write_text(json.dumps(metrics, indent=2)) + return metrics + + +def _apply_config_defaults(args: argparse.Namespace) -> argparse.Namespace: + if args.config is None: + return args + _configs_dir = Path(__file__).parent / "configs" + if str(_configs_dir) not in sys.path: + sys.path.insert(0, str(_configs_dir)) + from dripper_config import DripperConfig + + cfg = DripperConfig.from_yaml(args.config) + if args.num_shards == _DEFAULT_NUM_SHARDS: + args.num_shards = cfg.num_shards + if args.num_workers == _DEFAULT_NUM_WORKERS: + stage_res = cfg.resources.get("stage3", {}) + args.num_workers = int(stage_res.get("num_workers", stage_res.get("cpus", args.num_workers))) + return args + + +def main() -> int: + p = argparse.ArgumentParser( + description="Stage 3: CPU template propagation", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + p.add_argument("--config", default=None) + p.add_argument("--cluster-manifest", required=True) + p.add_argument("--inference-results", required=True) + p.add_argument("--output-dir", required=True) + p.add_argument("--shard-index", type=int, default=int(os.environ.get("SLURM_ARRAY_TASK_ID", "0"))) + p.add_argument("--num-shards", type=int, default=_DEFAULT_NUM_SHARDS) + p.add_argument("--num-workers", type=int, default=_DEFAULT_NUM_WORKERS) + p.add_argument("--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"]) + args = _apply_config_defaults(p.parse_args()) + logger.remove() + logger.add(sys.stdout, level=args.log_level.upper()) + logger.info( + "manifest={} gpu={} out={} shard={}/{} workers={}", + args.cluster_manifest, + args.inference_results, + args.output_dir, + args.shard_index, + args.num_shards, + args.num_workers, + ) + metrics = process_shard( + args.cluster_manifest, + args.inference_results, + args.output_dir, + args.shard_index, + args.num_shards, + args.num_workers, + ) + status = metrics.get("status", "done") + logger.info( + "Shard {} {}", + args.shard_index, + {"skipped": "already complete — skipped.", "empty": "had no input — wrote empty shard."}.get( + status, "complete." + ), + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tutorials/text/dripper-common-crawl/stage3b_gpu_llm_fallback.py b/tutorials/text/dripper-common-crawl/stage3b_gpu_llm_fallback.py new file mode 100644 index 0000000000..097e4aa1b8 --- /dev/null +++ b/tutorials/text/dripper-common-crawl/stage3b_gpu_llm_fallback.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stage 3b: GPU LLM fallback for siblings where Stage 3 propagation failed. + +Without this stage, F1 is ~0.84. With it, F1 reaches ~0.92 (above the 0.90 target). + +Siblings where DripperHTMLLayoutPropagationStage returned propagation_success=False +(content ratio too high/low, no template, etc.) are re-run through the full LLM +extraction pipeline (DripperHTMLPreprocessStage -> GPU inference -> PostprocessStage). + +INPUT: Stage 3 propagation results (shard_*.parquet) + Stage 1b cluster manifest (for html column) +OUTPUT: Updated shard with failed siblings replaced by LLM extraction results +""" + +from __future__ import annotations + +import argparse +import os +import time +from pathlib import Path + +import pandas as pd +import pyarrow.parquet as pq +from loguru import logger + +_DEFAULT_SHARD_INDEX = int(os.environ.get("SLURM_ARRAY_TASK_ID", "0")) +_DEFAULT_NUM_SHARDS = 80 + + +def _load_failed_siblings( + propagation_dir: Path, + manifest_dir: Path, + shard_index: int, + num_shards: int, +) -> pd.DataFrame: + """Load siblings where propagation failed and attach their html for LLM re-inference.""" + prop_files = sorted(propagation_dir.glob("shard_*.parquet")) or sorted(propagation_dir.glob("*.parquet")) + if not prop_files: + raise FileNotFoundError(f"No propagation result files in {propagation_dir}") + + n = len(prop_files) + my_files = prop_files[n * shard_index // num_shards : n * (shard_index + 1) // num_shards] + if not my_files: + logger.info("shard {}: no propagation files — nothing to do", shard_index) + return pd.DataFrame() + + prop_df = pd.concat([pq.read_table(f).to_pandas() for f in my_files], ignore_index=True) + + # Select only siblings where propagation failed + failed_mask = ~prop_df.get("propagation_success", pd.Series(True, index=prop_df.index)).fillna(True).astype( + bool + ) & (prop_df.get("cluster_role", pd.Series("singleton", index=prop_df.index)) == "sibling") + failed_df = prop_df[failed_mask].copy() + if failed_df.empty: + logger.info("shard {}: no failed siblings — all propagation succeeded", shard_index) + return pd.DataFrame() + + logger.info("shard {}: {:,} / {:,} siblings need LLM fallback", shard_index, len(failed_df), len(prop_df)) + + # Load html from manifest for the failed siblings + manifest_files = sorted(manifest_dir.glob("shard_*.parquet")) or sorted(manifest_dir.glob("*.parquet")) + if not manifest_files: + raise FileNotFoundError(f"No manifest files in {manifest_dir}") + + failed_urls = set(failed_df["url"].astype(str)) + html_parts = [] + for mf in manifest_files: + schema = pq.read_schema(str(mf)).names + if "html" not in schema: + continue + cols = [c for c in ["url", "html"] if c in schema] + mdf = pq.read_table(str(mf), columns=cols).to_pandas() + matched = mdf[mdf["url"].astype(str).isin(failed_urls)] + if not matched.empty: + html_parts.append(matched) + + if not html_parts: + logger.warning("No html found for failed siblings — cannot run LLM fallback") + return pd.DataFrame() + + html_df = pd.concat(html_parts, ignore_index=True).drop_duplicates("url", keep="first") + failed_df = failed_df.merge(html_df[["url", "html"]], on="url", how="inner") + logger.info("shard {}: {:,} siblings with html for LLM fallback", shard_index, len(failed_df)) + return failed_df + + +def run_llm_fallback( + failed_df: pd.DataFrame, + model_name: str, + server_url: str, + max_concurrent_requests: int, + num_workers: int, +) -> pd.DataFrame: + """Run LLM extraction on failed siblings using library stages.""" + from nemo_curator.backends.ray_actor_pool import RayActorPoolExecutor + from nemo_curator.models.client.openai_client import OpenAIClient + from nemo_curator.pipeline import Pipeline + from nemo_curator.stages.text.experimental.dripper import ( + DripperHTMLPostprocessStage, + DripperHTMLPreprocessStage, + ) + from nemo_curator.stages.text.experimental.dripper._base_stages import DripperHTMLInferenceStage + from nemo_curator.tasks import DocumentBatch + + client = OpenAIClient(model=model_name, base_url=server_url, api_key="EMPTY") + + preprocess = DripperHTMLPreprocessStage(html_col="html", url_col="url", worker_count=num_workers) + inference = DripperHTMLInferenceStage( + client=client, + model_name=model_name, + max_concurrent_requests=max_concurrent_requests, + health_check=False, + ) + postprocess = DripperHTMLPostprocessStage( + html_col="html", + url_col="url", + fallback="trafilatura", + output_format="mm_md", + worker_count=num_workers, + ) + + pipeline = Pipeline(name="stage3b_llm_fallback") + pipeline.add_stage(preprocess) + pipeline.add_stage(inference) + pipeline.add_stage(postprocess) + + chunk = max(1, len(failed_df) // max(1, num_workers)) + tasks = [ + DocumentBatch(dataset_name="stage3b", data=failed_df.iloc[i : i + chunk].reset_index(drop=True)) + for i in range(0, len(failed_df), chunk) + ] + result_tasks = pipeline.run(executor=RayActorPoolExecutor(), initial_tasks=tasks) or [] + + frames = [t.to_pandas() for t in result_tasks] + return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame() + + +def process_shard(args: argparse.Namespace) -> dict: + t0 = time.perf_counter() + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + out_path = out_dir / f"shard_{args.shard_index:04d}.parquet" + + if out_path.exists(): + meta = pq.read_metadata(str(out_path)) + if meta.num_rows > 0: + logger.info("SKIP shard {} — already done ({:,} rows)", args.shard_index, meta.num_rows) + return {"status": "skipped", "shard": args.shard_index} + + failed_df = _load_failed_siblings( + Path(args.propagation_results), + Path(args.cluster_manifest), + args.shard_index, + args.num_shards, + ) + if failed_df.empty: + pq.write_table( + pq.read_schema(str(next(Path(args.propagation_results).glob("*.parquet")))).empty_table(), str(out_path) + ) + return {"status": "empty", "shard": args.shard_index, "fallback_rows": 0} + + result_df = run_llm_fallback( + failed_df, args.model_name, args.server_url, args.max_concurrent_requests, args.workers + ) + + tmp = out_path.with_suffix(f".tmp_{os.getpid()}.parquet") + result_df.to_parquet(str(tmp), index=False, compression="snappy") + tmp.rename(out_path) + + elapsed = time.perf_counter() - t0 + ok = ( + int(result_df["dripper_content"].astype(str).str.len().gt(5).sum()) + if "dripper_content" in result_df.columns + else 0 + ) + logger.info( + "shard {} done fallback_rows={:,} ok={} elapsed={:.1f}s output={}", + args.shard_index, + len(result_df), + ok, + elapsed, + out_path, + ) + return {"status": "done", "shard": args.shard_index, "fallback_rows": len(result_df), "ok": ok} + + +def main() -> int: + p = argparse.ArgumentParser(description="Stage 3b: GPU LLM fallback for failed propagation siblings") + p.add_argument("--propagation-results", required=True, help="Stage 3 output dir") + p.add_argument("--cluster-manifest", required=True, help="Stage 1b cluster assignment dir (needs html column)") + p.add_argument("--output-dir", required=True, help="Output dir for stage3b results") + p.add_argument("--model-name", default="opendatalab/MinerU-HTML-v1.1-hunyuan0.5B-compact") + p.add_argument("--server-url", default="http://localhost:8000/v1") + p.add_argument("--shard-index", type=int, default=_DEFAULT_SHARD_INDEX) + p.add_argument("--num-shards", type=int, default=_DEFAULT_NUM_SHARDS) + p.add_argument("--max-concurrent-requests", type=int, default=64) + p.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 4) - 2)) + p.add_argument("--log-level", default="INFO") + args = p.parse_args() + + import sys + + from loguru import logger as _log + + _log.remove() + _log.add(sys.stdout, level=args.log_level.upper()) + + process_shard(args) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tutorials/text/dripper-common-crawl/stage_gpu_pipeline.py b/tutorials/text/dripper-common-crawl/stage_gpu_pipeline.py new file mode 100644 index 0000000000..023372c66e --- /dev/null +++ b/tutorials/text/dripper-common-crawl/stage_gpu_pipeline.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Stage 1c + Stage 2 (vLLM) + Stage 2b GPU pipeline. Input: Stage 1b parquet.""" + +from __future__ import annotations + +import argparse +import os +import subprocess +import sys +import time +from dataclasses import dataclass +from pathlib import Path + +import pandas as pd +import pyarrow.parquet as pq +from loguru import logger + +_REPO_ROOT = str(Path(__file__).parent.parent.parent.parent) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from pipeline_metrics import StageMetrics + +OUTPUT_COLS = [ + "url", + "url_host_name", + "cluster_id", + "cluster_role", + "mapping_json", + "dripper_content", + "dripper_html", + "dripper_error", + "inference_time_s", +] +_GPU_SLICE_COLS = ["url", "prompt", "item_count", "cluster_id", "cluster_role", "url_host_name"] +_MIN_CONTENT_LEN, _MIN_ERROR_LEN, _MIN_PROMPT_LEN = 5, 2, 10 + + +def run_stage1c(df: pd.DataFrame) -> pd.DataFrame: + from nemo_curator.stages.text.experimental.dripper.preprocessing import DripperHTMLPreprocessStage + + from nemo_curator.backends.ray_actor_pool import RayActorPoolExecutor + from nemo_curator.pipeline import Pipeline + from nemo_curator.tasks import DocumentBatch + + t0 = time.perf_counter() + n_workers = max(1, (os.cpu_count() or 4) - 2) + chunk = max(1, len(df) // n_workers) + tasks = [ + DocumentBatch(dataset_name="stage1c", data=df.iloc[i : i + chunk].reset_index(drop=True)) + for i in range(0, len(df), chunk) + ] + stage = DripperHTMLPreprocessStage(html_col="html", url_col="url", worker_count=n_workers) + pipeline = Pipeline(name="stage1c") + pipeline.add_stage(stage) + result_tasks = pipeline.run(executor=RayActorPoolExecutor(), initial_tasks=tasks) or [] + out = pd.concat([t.to_pandas() for t in result_tasks], ignore_index=True) + ok = (out.get("prompt", out.get("_dripper_prompt", pd.Series())).astype(str).str.len() > _MIN_PROMPT_LEN).sum() + logger.info("Stage 1c: {:,}/{:,} prompts in {:.1f}s", ok, len(df), time.perf_counter() - t0) + return out + + +@dataclass +class _Cfg: + model: str + gpu_mem_util: float + max_model_len: int + max_num_seqs: int + max_num_batched_tokens: int + max_tokens: int + kv_cache_dtype: str + + +def _build_worker_prompts(rows, tok, max_model_len, max_tokens): + from vllm import SamplingParams + + supports_think: list[bool] = [True] + prompts, samplings, ridx, results, n_trunc = [], [], [], [None] * len(rows), 0 + for i, r in enumerate(rows): + p = str(r.get("prompt", "") or "") + if not p or p.startswith("ERROR:"): + results[i] = { + **r, + "llm_response": "", + "dripper_error": p if p.startswith("ERROR:") else "empty_prompt", + "inference_time_s": 0.0, + } + continue + ic = max(0, int(r.get("item_count", 0) or 0)) + max_tok = min(max_tokens, max(32, ic * 6 + 16) if ic > 0 else max_tokens) + msgs = [{"role": "user", "content": p}] + if supports_think[0]: + try: + text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True, enable_thinking=False) + except TypeError: + supports_think[0] = False + text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True) + else: + text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True) + ids = tok(text, add_special_tokens=False)["input_ids"] + cap = max_model_len - max_tok - 8 + if len(ids) > cap: + ids = ids[:cap] + n_trunc += 1 + prompts.append({"prompt_token_ids": ids}) + samplings.append(SamplingParams(temperature=0.0, max_tokens=max_tok)) + ridx.append(i) + return prompts, samplings, ridx, results, n_trunc + + +def run_stage2_worker(gpu_id: int, slice_path: str, out_path: str, cfg: _Cfg) -> None: + os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + from transformers import AutoTokenizer + from vllm import LLM + + from nemo_curator.utils.vllm_utils import pick_free_port, resolve_local_model_path + + local_model = resolve_local_model_path(cfg.model) + tok = AutoTokenizer.from_pretrained(local_model, trust_remote_code=True) + llm_kw: dict = { + "model": local_model, + "tensor_parallel_size": 1, + "gpu_memory_utilization": cfg.gpu_mem_util, + "max_model_len": cfg.max_model_len, + "max_num_seqs": cfg.max_num_seqs, + "max_num_batched_tokens": cfg.max_num_batched_tokens, + "enable_chunked_prefill": True, + "enable_prefix_caching": True, + "enforce_eager": False, + "trust_remote_code": True, + "disable_log_stats": True, + } + if cfg.kv_cache_dtype and cfg.kv_cache_dtype != "auto": + llm_kw["kv_cache_dtype"] = cfg.kv_cache_dtype + os.environ["MASTER_PORT"] = str(pick_free_port()) + t_setup = time.perf_counter() + llm = LLM(**llm_kw) + setup_s = time.perf_counter() - t_setup + rows = pq.ParquetFile(slice_path).read().to_pandas().to_dict("records") + prompts, samplings, ridx, results, n_trunc = _build_worker_prompts(rows, tok, cfg.max_model_len, cfg.max_tokens) + t1 = time.perf_counter() + outs = llm.generate(prompts, samplings) if prompts else [] + infer_s = time.perf_counter() - t1 + for j, o in enumerate(outs): + i = ridx[j] + resp = o.outputs[0].text if o.outputs else "" + results[i] = { + **rows[i], + "llm_response": resp, + "dripper_error": "" if resp else "empty_response", + "inference_time_s": infer_s / max(len(outs), 1), + } + pd.DataFrame([x for x in results if x is not None]).to_parquet(out_path, index=False, compression="snappy") + logger.info( + "gpu{} DONE {} prompts ({} trunc) setup={:.1f}s infer={:.1f}s {:.1f} pages/s", + gpu_id, + len(prompts), + n_trunc, + setup_s, + infer_s, + len(prompts) / max(infer_s, 1e-6), + ) + + +def _detect_gpus() -> int: + n = os.environ.get("SLURM_GPUS_ON_NODE") or os.environ.get("SLURM_GPUS_PER_NODE", "") + if n: + try: + return int(n.split(":")[-1]) + except ValueError: + pass + try: + r = subprocess.run(["nvidia-smi", "-L"], check=False, capture_output=True, text=True, timeout=5) + return max(1, sum(1 for ln in r.stdout.splitlines() if ln.startswith("GPU"))) + except OSError: + return 1 + + +def run_stage2(df: pd.DataFrame, args: argparse.Namespace) -> pd.DataFrame: + n_gpus = args.replicas if args.replicas > 0 else _detect_gpus() + logger.info("Stage 2: {:,} pages over {} GPUs", len(df), n_gpus) + tmp = Path(args.output) / "_gpu_slices" + tmp.mkdir(parents=True, exist_ok=True) + cost = df["prompt"].astype(str).str.len().to_numpy() + order = sorted(range(len(df)), key=lambda i: -cost[i]) + bins: list[list[int]] = [[] for _ in range(n_gpus)] + load = [0] * n_gpus + for i in order: + g = min(range(n_gpus), key=lambda k: load[k]) + bins[g].append(i) + load[g] += int(cost[i]) + sl = [str(tmp / f"slice_{g}.parquet") for g in range(n_gpus)] + ol = [str(tmp / f"out_{g}.parquet") for g in range(n_gpus)] + cols = [c for c in _GPU_SLICE_COLS if c in df.columns] + for g in range(n_gpus): + df[cols].iloc[bins[g]].to_parquet(sl[g], index=False) + w_base = [ + sys.executable, + os.path.abspath(__file__), + "--worker", + "--model", + args.model, + "--max-tokens", + str(args.max_tokens), + "--gpu-mem-util", + str(args.gpu_mem_util), + "--max-model-len", + str(args.max_model_len), + "--max-num-seqs", + str(args.max_num_seqs), + "--max-num-batched-tokens", + str(args.max_num_batched_tokens), + "--kv-cache-dtype", + args.kv_cache_dtype, + ] + t0 = time.perf_counter() + procs = [ + subprocess.Popen([*w_base, "--gpu", str(g), "--slice", sl[g], "--slice-out", ol[g]]) for g in range(n_gpus) + ] + rcs = [p.wait() for p in procs] + logger.info("Stage 2 workers done in {:.1f}s codes={}", time.perf_counter() - t0, rcs) + frames = [pq.ParquetFile(o).read().to_pandas() for o in ol if Path(o).exists()] + return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame() + + +def run_stage2b(df: pd.DataFrame) -> pd.DataFrame: + from nemo_curator.stages.text.experimental.dripper.preprocessing import DripperHTMLPostprocessStage + + from nemo_curator.backends.ray_actor_pool import RayActorPoolExecutor + from nemo_curator.pipeline import Pipeline + from nemo_curator.tasks import DocumentBatch + + t0 = time.perf_counter() + n_workers = max(1, (os.cpu_count() or 4) - 2) + stage_df = df.copy() + if "dripper_response" not in stage_df.columns and "llm_response" in stage_df.columns: + stage_df["dripper_response"] = stage_df["llm_response"] + stage = DripperHTMLPostprocessStage(html_col="html", url_col="url", worker_count=n_workers) + pipeline = Pipeline(name="stage2b") + pipeline.add_stage(stage) + chunks = [ + DocumentBatch(dataset_name="stage2b", data=stage_df.iloc[i : i + 1000].reset_index(drop=True)) + for i in range(0, len(stage_df), 1000) + ] + output = pipeline.run(executor=RayActorPoolExecutor(), initial_tasks=chunks) or [] + out = pd.concat([t.to_pandas() for t in output], ignore_index=True) if output else stage_df + if "mapping_json" not in out.columns: + out["mapping_json"] = "" + logger.info( + "Stage 2b: content_ok={:,} mapping_ok={:,} in {:.1f}s", + (out["dripper_content"].astype(str).str.len() > _MIN_CONTENT_LEN).sum(), + (out["mapping_json"].astype(str).str.len() > _MIN_CONTENT_LEN).sum(), + time.perf_counter() - t0, + ) + return out + + +def run(args: argparse.Namespace) -> None: + tracker = StageMetrics( + "stage_gpu_pipeline", + shard_index=args.shard_index, + num_shards=args.num_shards, + n_gpus=args.replicas or _detect_gpus(), + ) + tracker.start() + t_total = time.perf_counter() + inp = Path(args.input) + if inp.is_dir(): + exact = inp / f"shard_{args.shard_index:04d}.parquet" + inp = exact if exact.exists() else sorted(inp.glob("shard_*.parquet"))[0] + all_df = pq.ParquetFile(str(inp)).read().to_pandas() + rep_df = ( + all_df[all_df["cluster_role"].isin(["representative", "singleton"])] + if "cluster_role" in all_df.columns + else all_df + ).reset_index(drop=True) + logger.info( + "{:,}/{:,} pages sent to LLM ({:.1f}%)", len(rep_df), len(all_df), len(rep_df) / max(len(all_df), 1) * 100 + ) + _t = time.perf_counter() + rep_df = run_stage1c(rep_df) + t1c_s = time.perf_counter() - _t + _t = time.perf_counter() + infer_df = run_stage2(rep_df, args) + t2_s = time.perf_counter() - _t + _t = time.perf_counter() + passthrough = rep_df[["url"] + [c for c in ["simp_html", "map_html", "html"] if c in rep_df.columns]] + infer_df = infer_df.merge(passthrough, on="url", how="left", suffixes=("", "_1c")) + for c in ["simp_html", "map_html", "html"]: + if f"{c}_1c" in infer_df.columns: + infer_df[c] = infer_df[c].fillna(infer_df[f"{c}_1c"]) + infer_df = infer_df.drop(columns=[f"{c}_1c"]) + result_df = run_stage2b(infer_df) + t2b_s = time.perf_counter() - _t + out_dir = Path(args.output) + out_dir.mkdir(parents=True, exist_ok=True) + fname = f"shard_{args.shard_index:04d}.parquet" if args.num_shards > 1 else "pipeline_results.parquet" + out_path = out_dir / fname + for col in OUTPUT_COLS: + if col not in result_df.columns: + result_df[col] = None + tmp = out_path.with_suffix(".parquet.tmp") + result_df.to_parquet(str(tmp), index=False, compression="snappy") + tmp.rename(out_path) + total_s = time.perf_counter() - t_total + ok = int((result_df["dripper_content"].astype(str).str.len() > _MIN_CONTENT_LEN).sum()) + errs = int((result_df["dripper_error"].astype(str).str.len() > _MIN_ERROR_LEN).sum()) + logger.info( + "ALL DONE: {:,} pages ok={} total={:.1f}s (1c={:.1f}s 2={:.1f}s 2b={:.1f}s) -> {}", + len(result_df), + ok, + total_s, + t1c_s, + t2_s, + t2b_s, + out_path, + ) + tracker.finish(total_pages=len(result_df), errors=errs) + tracker.extra = { + "stage1c_s": round(t1c_s, 1), + "stage2_s": round(t2_s, 1), + "stage2b_s": round(t2b_s, 1), + "content_ok": ok, + } + tracker.save(args.output) + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument("--worker", action="store_true") + p.add_argument("--gpu", type=int, default=0) + p.add_argument("--slice") + p.add_argument("--slice-out") + p.add_argument("--input") + p.add_argument("--output") + p.add_argument("--shard-index", type=int, default=int(os.environ.get("SLURM_ARRAY_TASK_ID", "0"))) + p.add_argument("--num-shards", type=int, default=1) + p.add_argument("--replicas", type=int, default=int(os.environ.get("N_GPU_REPLICAS", "0"))) + p.add_argument("--model", default="opendatalab/MinerU-HTML-v1.1-hunyuan0.5B-compact") + p.add_argument("--hf-cache", default=os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))) + p.add_argument("--max-tokens", type=int, default=2048) + p.add_argument("--gpu-mem-util", type=float, default=0.90) + p.add_argument("--max-model-len", type=int, default=32768) + p.add_argument("--max-num-seqs", type=int, default=512) + p.add_argument("--max-num-batched-tokens", type=int, default=16384) + p.add_argument("--kv-cache-dtype", default="fp8") + args = p.parse_args() + os.environ.setdefault("HF_HOME", args.hf_cache) + if args.worker: + run_stage2_worker( + args.gpu, + args.slice, + args.slice_out, + _Cfg( + args.model, + args.gpu_mem_util, + args.max_model_len, + args.max_num_seqs, + args.max_num_batched_tokens, + args.max_tokens, + args.kv_cache_dtype, + ), + ) + else: + if not args.input or not args.output: + p.error("--input and --output required in main mode") + run(args) + + +if __name__ == "__main__": + main() diff --git a/uv.lock b/uv.lock index 7509d39c76..6ce966bfbe 100644 --- a/uv.lock +++ b/uv.lock @@ -5195,6 +5195,7 @@ all = [ { name = "vllm", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, { name = "warcio" }, { name = "whisperx", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, + { name = "xxhash" }, ] audio-common = [ { name = "accelerate" }, @@ -5336,6 +5337,7 @@ math-cpu = [ { name = "sentencepiece" }, { name = "trafilatura" }, { name = "warcio" }, + { name = "xxhash" }, ] math-cuda12 = [ { name = "beautifulsoup4" }, @@ -5363,6 +5365,7 @@ math-cuda12 = [ { name = "trafilatura" }, { name = "vllm", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, { name = "warcio" }, + { name = "xxhash" }, ] sdg-cpu = [ { name = "data-designer" }, @@ -5392,6 +5395,7 @@ text-cpu = [ { name = "sentencepiece" }, { name = "trafilatura" }, { name = "warcio" }, + { name = "xxhash" }, ] text-cuda12 = [ { name = "beautifulsoup4" }, @@ -5418,6 +5422,7 @@ text-cuda12 = [ { name = "trafilatura" }, { name = "vllm", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, { name = "warcio" }, + { name = "xxhash" }, ] translation-all = [ { name = "aiohttp" }, @@ -5669,6 +5674,7 @@ requires-dist = [ { name = "vllm", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'vllm'", specifier = ">=0.14.1" }, { name = "warcio", marker = "extra == 'text-cpu'" }, { name = "whisperx", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'audio-common'", specifier = ">=3.8.4" }, + { name = "xxhash", marker = "extra == 'text-cpu'" }, ] provides-extras = ["cuda12", "vllm", "inference-server", "deduplication-cuda12", "audio-common", "audio-cpu", "audio-cuda12", "image-cpu", "image-cuda12", "translation-common", "translation-metrics", "translation-segmentation", "translation-aws", "translation-google", "translation-nmt", "translation-all", "text-cpu", "text-cuda12", "video-cpu", "video-cuda12", "math-cpu", "math-cuda12", "interleaved-cpu", "interleaved-cuda12", "sdg-cpu", "sdg-cuda12", "all"] @@ -11623,16 +11629,24 @@ sdist = { url = "https://files.pythonhosted.org/packages/02/84/30869e01909fb37a6 wheels = [ { url = "https://files.pythonhosted.org/packages/a5/86/cf2c0321dc3940a7aa73076f4fd677a0fb3e405cb297ead7d864fd90847e/xxhash-3.6.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:297b7fbf86c82c550e12e8fb71968b3f033d27b874276ba3624ea868c11165a8", size = 193880, upload-time = "2025-10-02T14:34:22.431Z" }, { url = "https://files.pythonhosted.org/packages/ba/b3/5a4241309217c5c876f156b10778f3ab3af7ba7e3259e6d5f5c7d0129eb2/xxhash-3.6.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:51312c768403d8540487dbbfb557454cfc55589bbde6424456951f7fcd4facb3", size = 191409, upload-time = "2025-10-02T14:34:29.696Z" }, + { url = "https://files.pythonhosted.org/packages/c0/01/99bfbc15fb9abb9a72b088c1d95219fc4782b7d01fc835bd5744d66dd0b8/xxhash-3.6.0-cp311-cp311-win32.whl", hash = "sha256:d1927a69feddc24c987b337ce81ac15c4720955b667fe9b588e02254b80446fd", size = 30574, upload-time = "2025-10-02T14:34:31.028Z" }, { url = "https://files.pythonhosted.org/packages/65/79/9d24d7f53819fe301b231044ea362ce64e86c74f6e8c8e51320de248b3e5/xxhash-3.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:26734cdc2d4ffe449b41d186bbeac416f704a482ed835d375a5c0cb02bc63fef", size = 31481, upload-time = "2025-10-02T14:34:32.062Z" }, { url = "https://files.pythonhosted.org/packages/11/4f/426f91b96701ec2f37bb2b8cec664eff4f658a11f3fa9d94f0a887ea6d2b/xxhash-3.6.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:49e03e6fe2cac4a1bc64952dd250cf0dbc5ef4ebb7b8d96bce82e2de163c82a2", size = 193883, upload-time = "2025-10-02T14:34:43.249Z" }, { url = "https://files.pythonhosted.org/packages/23/07/63ffb386cd47029aa2916b3d2f454e6cc5b9f5c5ada3790377d5430084e7/xxhash-3.6.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:418daf3db71e1413cfe211c2f9a528456936645c17f46b5204705581a45390ae", size = 191431, upload-time = "2025-10-02T14:34:50.798Z" }, + { url = "https://files.pythonhosted.org/packages/0f/93/14fde614cadb4ddf5e7cebf8918b7e8fac5ae7861c1875964f17e678205c/xxhash-3.6.0-cp312-cp312-win32.whl", hash = "sha256:50fc255f39428a27299c20e280d6193d8b63b8ef8028995323bf834a026b4fbb", size = 30617, upload-time = "2025-10-02T14:34:51.954Z" }, { url = "https://files.pythonhosted.org/packages/13/5d/0d125536cbe7565a83d06e43783389ecae0c0f2ed037b48ede185de477c0/xxhash-3.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:c0f2ab8c715630565ab8991b536ecded9416d615538be8ecddce43ccf26cbc7c", size = 31534, upload-time = "2025-10-02T14:34:53.276Z" }, { url = "https://files.pythonhosted.org/packages/5e/1e/3c3d3ef071b051cc3abbe3721ffb8365033a172613c04af2da89d5548a87/xxhash-3.6.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:42c36dd7dbad2f5238950c377fcbf6811b1cdb1c444fab447960030cea60504d", size = 193936, upload-time = "2025-10-02T14:35:05.013Z" }, { url = "https://files.pythonhosted.org/packages/af/3c/0bb129170ee8f3650f08e993baee550a09593462a5cddd8e44d0011102b1/xxhash-3.6.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f205badabde7aafd1a31e8ca2a3e5a763107a71c397c4481d6a804eb5063d8bd", size = 191495, upload-time = "2025-10-02T14:35:12.971Z" }, + { url = "https://files.pythonhosted.org/packages/e9/3a/6797e0114c21d1725e2577508e24006fd7ff1d8c0c502d3b52e45c1771d8/xxhash-3.6.0-cp313-cp313-win32.whl", hash = "sha256:2577b276e060b73b73a53042ea5bd5203d3e6347ce0d09f98500f418a9fcf799", size = 30620, upload-time = "2025-10-02T14:35:14.129Z" }, { url = "https://files.pythonhosted.org/packages/86/15/9bc32671e9a38b413a76d24722a2bf8784a132c043063a8f5152d390b0f9/xxhash-3.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:757320d45d2fbcce8f30c42a6b2f47862967aea7bf458b9625b4bbe7ee390392", size = 31542, upload-time = "2025-10-02T14:35:15.21Z" }, { url = "https://files.pythonhosted.org/packages/d7/6b/33e21afb1b5b3f46b74b6bd1913639066af218d704cc0941404ca717fc57/xxhash-3.6.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fba27a198363a7ef87f8c0f6b171ec36b674fe9053742c58dd7e3201c1ab30ee", size = 196070, upload-time = "2025-10-02T14:35:26.586Z" }, { url = "https://files.pythonhosted.org/packages/dc/6c/5cbde9de2cd967c322e651c65c543700b19e7ae3e0aae8ece3469bf9683d/xxhash-3.6.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5f059d9faeacd49c0215d66f4056e1326c80503f51a1532ca336a385edadd033", size = 193787, upload-time = "2025-10-02T14:35:33.827Z" }, + { url = "https://files.pythonhosted.org/packages/19/fa/0172e350361d61febcea941b0cc541d6e6c8d65d153e85f850a7b256ff8a/xxhash-3.6.0-cp313-cp313t-win32.whl", hash = "sha256:1244460adc3a9be84731d72b8e80625788e5815b68da3da8b83f78115a40a7ec", size = 30916, upload-time = "2025-10-02T14:35:35.107Z" }, { url = "https://files.pythonhosted.org/packages/ad/e6/e8cf858a2b19d6d45820f072eff1bea413910592ff17157cabc5f1227a16/xxhash-3.6.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b1e420ef35c503869c4064f4a2f2b08ad6431ab7b229a05cce39d74268bca6b8", size = 31799, upload-time = "2025-10-02T14:35:36.165Z" }, + { url = "https://files.pythonhosted.org/packages/56/15/064b197e855bfb7b343210e82490ae672f8bc7cdf3ddb02e92f64304ee8a/xxhash-3.6.0-cp313-cp313t-win_arm64.whl", hash = "sha256:ec44b73a4220623235f67a996c862049f375df3b1052d9899f40a6382c32d746", size = 28044, upload-time = "2025-10-02T14:35:37.195Z" }, + { url = "https://files.pythonhosted.org/packages/93/1e/8aec23647a34a249f62e2398c42955acd9b4c6ed5cf08cbea94dc46f78d2/xxhash-3.6.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0f7b7e2ec26c1666ad5fc9dbfa426a6a3367ceaf79db5dd76264659d509d73b0", size = 30662, upload-time = "2025-10-02T14:37:01.743Z" }, + { url = "https://files.pythonhosted.org/packages/b8/0b/b14510b38ba91caf43006209db846a696ceea6a847a0c9ba0a5b1adc53d6/xxhash-3.6.0-pp311-pypy311_pp73-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5dc1e14d14fa0f5789ec29a7062004b5933964bb9b02aae6622b8f530dc40296", size = 41056, upload-time = "2025-10-02T14:37:02.879Z" }, + { url = "https://files.pythonhosted.org/packages/50/55/15a7b8a56590e66ccd374bbfa3f9ffc45b810886c8c3b614e3f90bd2367c/xxhash-3.6.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:881b47fc47e051b37d94d13e7455131054b56749b91b508b0907eb07900d1c13", size = 36251, upload-time = "2025-10-02T14:37:04.44Z" }, { url = "https://files.pythonhosted.org/packages/62/b2/5ac99a041a29e58e95f907876b04f7067a0242cb85b5f39e726153981503/xxhash-3.6.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c6dc31591899f5e5666f04cc2e529e69b4072827085c1ef15294d91a004bc1bd", size = 32481, upload-time = "2025-10-02T14:37:05.869Z" }, { url = "https://files.pythonhosted.org/packages/7b/d9/8d95e906764a386a3d3b596f3c68bb63687dfca806373509f51ce8eea81f/xxhash-3.6.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:15e0dac10eb9309508bfc41f7f9deaa7755c69e35af835db9cb10751adebc35d", size = 31565, upload-time = "2025-10-02T14:37:06.966Z" }, ]