From 31ccf36c0a91662ebabd5cd36bd3a11429f024e4 Mon Sep 17 00:00:00 2001
From: dheerajchouhan08
Date: Mon, 11 May 2026 14:30:21 +0530
Subject: [PATCH 1/2] fix: shrink mapInPandas batches and project inputs to
 avoid Arrow 2 GiB overflow

- Lower default max_records_per_batch from 200 to 50 so multi-MB feature rows
  can't aggregate past Arrow's 2 GiB per-column-per-batch limit, which was
  crashing the Python worker before _decode_batch ran.
- Project the input DataFrame to only the columns _decode_batch reads before
  mapInPandas, so unused columns don't bloat Arrow batches.
- Bump version to 0.3.2.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 .../inference_logging_client/__init__.py      | 28 ++++++++++++++++----
 .../inference_logging_client/pyproject.toml   |  2 +-
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/py-sdk/inference_logging_client/inference_logging_client/__init__.py b/py-sdk/inference_logging_client/inference_logging_client/__init__.py
index 7efa8f2d..74c5f125 100644
--- a/py-sdk/inference_logging_client/inference_logging_client/__init__.py
+++ b/py-sdk/inference_logging_client/inference_logging_client/__init__.py
@@ -48,7 +48,7 @@
 from .types import FORMAT_TYPE_MAP, DecodedMPLog, FeatureInfo, Format
 from .utils import format_dataframe_floats, get_format_name, unpack_metadata_byte
 
-__version__ = "0.3.1"
+__version__ = "0.3.2"
 
 # Maximum supported schema version (4 bits = 0-15)
 _MAX_SCHEMA_VERSION = 15
@@ -295,8 +295,9 @@ def decode_mplog_dataframe(
         mp_config_id_column: Name of the column containing model proxy config ID (default: "mp_config_id")
         num_partitions: Number of partitions for distributed decode. Default 10000 to keep partition
             size small when rows are large (3-5 MB each). Increase if rows are small.
-        max_records_per_batch: Max rows per Arrow batch in mapInPandas. When set (default 200),
-            applied temporarily during this call to limit memory per batch when rows are large.
+        max_records_per_batch: Max rows per Arrow batch in mapInPandas. Default 50 to keep each
+            batch well under Arrow's 2 GiB per-column limit when rows carry multi-MB payloads
+            (features column can be several MB each). Raise it if your rows are small.
         needed_columns: Optional set or list of feature names to include. If provided, only these
             columns are decoded and returned (reduces memory and output size). If None,
             all schema columns are returned.
@@ -522,10 +523,27 @@ def _decode_batch(iterator):
         out_pdf = pd.DataFrame(out_rows, columns=all_columns_ordered)
         yield out_pdf
 
+    # Project to only the columns _decode_batch actually reads, so Arrow batches
+    # don't carry unused giant columns into the Python worker. This keeps each
+    # mapInPandas batch well under Arrow's 2 GiB per-column limit.
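+    # NOTE: row_metadata_columns may overlap with the fixed columns above,
+    # hence the order-preserving dedup below.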
+    projected_cols = [
+        c for c in (
+            [features_column, metadata_column, mp_config_id_column, "entities"]
+            + row_metadata_columns
+        )
+        if c in df_columns
+    ]
+    # Drop duplicates while preserving first-occurrence order
+    seen = set()
+    projected_cols = [c for c in projected_cols if not (c in seen or seen.add(c))]
+    df_projected = df.select(*projected_cols)
+
     n_partitions = num_partitions if num_partitions is not None else 10000
-    df_repart = df.repartition(n_partitions)
+    df_repart = df_projected.repartition(n_partitions)
 
-    batch_limit = max_records_per_batch if max_records_per_batch is not None else 200
+    batch_limit = max_records_per_batch if max_records_per_batch is not None else 50
     prev_max_records = spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch")
     spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", str(batch_limit))
     try:
diff --git a/py-sdk/inference_logging_client/pyproject.toml b/py-sdk/inference_logging_client/pyproject.toml
index 13366434..d921c087 100644
--- a/py-sdk/inference_logging_client/pyproject.toml
+++ b/py-sdk/inference_logging_client/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "inference-logging-client"
-version = "0.3.1"
+version = "0.3.2"
 description = "Decode MPLog feature logs from proto, arrow, or parquet format"
 readme = "readme.md"
 requires-python = ">=3.8"

From db4cd7d4ed0ddd9bf0b67fa730aaf10feb0ddae6 Mon Sep 17 00:00:00 2001
From: dheerajchouhan08
Date: Mon, 11 May 2026 14:43:00 +0530
Subject: [PATCH 2/2] feat: negative-cache schema fetches on HTTP 400 only

- Add SchemaFetchError.status_code so callers can distinguish HTTP failure
  modes (400 bad request vs 5xx vs network errors).
- Add a thread-safe LRU negative cache in get_feature_schema, mirroring the
  positive cache shape (OrderedDict, max 100 entries, no TTL). Only HTTP 400
  responses are negative-cached; transient errors (5xx, network, 401/403/429)
  are intentionally excluded so they can be retried.
- clear_schema_cache() now clears both caches.
- Bump version to 0.3.3.

Prevents worker tasks from hammering the inference service with the same 400
for every row when a (mp_config_id, version) schema is not registered.
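A caller-visible sketch (the get_feature_schema call below is illustrative;
argument names follow io.py's urlencode params, not a confirmed signature):

    from inference_logging_client.exceptions import SchemaFetchError
    from inference_logging_client.io import get_feature_schema

    try:
        schema = get_feature_schema(model_config_id, version)
    except SchemaFetchError as e:
        if e.status_code == 400:
            # Negative-cached: repeat calls replay this error without a
            # network round-trip until clear_schema_cache() is called.
            ...
        else:
            # 5xx / network / 401 / 403 / 429: not cached; the next call
            # retries the fetch.
            ...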
Co-Authored-By: Claude Opus 4.7 (1M context)
---
 .../inference_logging_client/__init__.py      |  2 +-
 .../inference_logging_client/exceptions.py    | 12 +++++-
 .../inference_logging_client/io.py            | 45 +++++++++++++++++-----
 .../inference_logging_client/pyproject.toml   |  2 +-
 4 files changed, 52 insertions(+), 9 deletions(-)

diff --git a/py-sdk/inference_logging_client/inference_logging_client/__init__.py b/py-sdk/inference_logging_client/inference_logging_client/__init__.py
index 74c5f125..f643186c 100644
--- a/py-sdk/inference_logging_client/inference_logging_client/__init__.py
+++ b/py-sdk/inference_logging_client/inference_logging_client/__init__.py
@@ -48,7 +48,7 @@
 from .types import FORMAT_TYPE_MAP, DecodedMPLog, FeatureInfo, Format
 from .utils import format_dataframe_floats, get_format_name, unpack_metadata_byte
 
-__version__ = "0.3.2"
+__version__ = "0.3.3"
 
 # Maximum supported schema version (4 bits = 0-15)
 _MAX_SCHEMA_VERSION = 15
diff --git a/py-sdk/inference_logging_client/inference_logging_client/exceptions.py b/py-sdk/inference_logging_client/inference_logging_client/exceptions.py
index 682f1206..46b9c7ed 100644
--- a/py-sdk/inference_logging_client/inference_logging_client/exceptions.py
+++ b/py-sdk/inference_logging_client/inference_logging_client/exceptions.py
@@ -1,5 +1,7 @@
 """Custom exceptions for inference-logging-client."""
 
+from typing import Optional
+
 
 class InferenceLoggingError(Exception):
     """Base exception for inference-logging-client errors."""
@@ -8,9 +10,15 @@
 
 
 class SchemaFetchError(InferenceLoggingError):
-    """Raised when fetching schema from inference service fails."""
+    """Raised when fetching schema from inference service fails.
 
-    pass
+    The optional `status_code` attribute carries the HTTP status when the
+    failure was an HTTP response (None for network/URL errors).
+    """
+
+    def __init__(self, message: str, status_code: Optional[int] = None):
+        super().__init__(message)
+        self.status_code = status_code
 
 
 class SchemaNotFoundError(InferenceLoggingError):
diff --git a/py-sdk/inference_logging_client/inference_logging_client/io.py b/py-sdk/inference_logging_client/inference_logging_client/io.py
index dea3d5b2..608a59a6 100644
--- a/py-sdk/inference_logging_client/inference_logging_client/io.py
+++ b/py-sdk/inference_logging_client/inference_logging_client/io.py
@@ -23,6 +23,16 @@
 _schema_cache_lock = threading.Lock()
 _SCHEMA_CACHE_MAX_SIZE = 100  # Maximum number of cached schemas
 
+# Thread-safe LRU negative cache for schemas. Only HTTP 400 responses are
+# negative-cached; transient failures (5xx, network errors, 401/403/429, etc.)
+# are intentionally excluded so they can be retried on subsequent calls.
+# Values are the error message at the time of the 400, returned on cache hit.
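+# NOTE: entries never expire (no TTL); clear_schema_cache() resets this cache
+# once a previously-missing schema has been registered.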
+_schema_neg_cache: "OrderedDict[tuple[str, int], str]" = OrderedDict()  # annotation quoted: bare tuple[...] generics need Python 3.9+, package supports 3.8
+_schema_neg_cache_lock = threading.Lock()
+_SCHEMA_NEG_CACHE_MAX_SIZE = 100  # Maximum number of negative-cached schemas
+
 # Retry configuration
 _MAX_RETRIES = 3
 _RETRY_BACKOFF_BASE = 0.5  # seconds
@@ -58,9 +68,13 @@ def _fetch_schema_with_retry(url: str, max_retries: int = _MAX_RETRIES) -> dict:
             error_body = e.read().decode("utf-8") if e.fp else str(e)
             # Don't retry on client errors (4xx)
             if 400 <= e.code < 500:
-                raise SchemaFetchError(f"HTTP Error {e.code} from inference service: {error_body}")
+                raise SchemaFetchError(
+                    f"HTTP Error {e.code} from inference service: {error_body}",
+                    status_code=e.code,
+                )
             last_exception = SchemaFetchError(
-                f"HTTP Error {e.code} from inference service: {error_body}"
+                f"HTTP Error {e.code} from inference service: {error_body}",
+                status_code=e.code,
             )
         except urllib.error.URLError as e:
             last_exception = SchemaFetchError(
@@ -125,12 +139,27 @@ def get_feature_schema(
             _schema_cache.move_to_end(cache_key)
             return _schema_cache[cache_key]
 
+    # Thread-safe negative cache lookup (HTTP 400 only). Replay the error
+    # without hitting the network. Same LRU shape as the positive cache.
+    with _schema_neg_cache_lock:
+        if cache_key in _schema_neg_cache:
+            _schema_neg_cache.move_to_end(cache_key)
+            raise SchemaFetchError(_schema_neg_cache[cache_key], status_code=400)
+
     base_url = f"{inference_host}{api_path}"
     params = urllib.parse.urlencode({"model_config_id": model_config_id, "version": str(version)})
     url = f"{base_url}?{params}"
 
-    # Fetch with retry
-    data = _fetch_schema_with_retry(url)
+    # Fetch with retry. Negative-cache HTTP 400 only.
+    try:
+        data = _fetch_schema_with_retry(url)
+    except SchemaFetchError as e:
+        if getattr(e, "status_code", None) == 400:
+            with _schema_neg_cache_lock:
+                while len(_schema_neg_cache) >= _SCHEMA_NEG_CACHE_MAX_SIZE:
+                    _schema_neg_cache.popitem(last=False)
+                _schema_neg_cache[cache_key] = str(e)
+        raise
 
     features = []
     for idx, component in enumerate(data.get("data", [])):
@@ -158,9 +187,15 @@
 
 
 def clear_schema_cache() -> None:
-    """Clear the schema cache. Useful for testing or when schemas have changed."""
+    """Clear both positive and negative schema caches.
+
+    Useful for testing, or when schemas have changed on the inference service
+    and the negative cache should be reset to allow a fresh fetch.
+    """
     with _schema_cache_lock:
         _schema_cache.clear()
+    with _schema_neg_cache_lock:
+        _schema_neg_cache.clear()
 
 
 def parse_mplog_protobuf(data: bytes) -> DecodedMPLog:
diff --git a/py-sdk/inference_logging_client/pyproject.toml b/py-sdk/inference_logging_client/pyproject.toml
index d921c087..c27d8fd4 100644
--- a/py-sdk/inference_logging_client/pyproject.toml
+++ b/py-sdk/inference_logging_client/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "inference-logging-client"
-version = "0.3.2"
+version = "0.3.3"
 description = "Decode MPLog feature logs from proto, arrow, or parquet format"
 readme = "readme.md"
 requires-python = ">=3.8"