py-sdk/inference_logging_client/inference_logging_client/__init__.py
@@ -48,7 +48,7 @@
 from .types import FORMAT_TYPE_MAP, DecodedMPLog, FeatureInfo, Format
 from .utils import format_dataframe_floats, get_format_name, unpack_metadata_byte
 
-__version__ = "0.3.1"
+__version__ = "0.3.3"
 
 # Maximum supported schema version (4 bits = 0-15)
 _MAX_SCHEMA_VERSION = 15
@@ -295,8 +295,9 @@ def decode_mplog_dataframe(
         mp_config_id_column: Name of the column containing model proxy config ID (default: "mp_config_id")
         num_partitions: Number of partitions for distributed decode. Default 10000 to keep
             partition size small when rows are large (3-5 MB each). Increase if rows are small.
-        max_records_per_batch: Max rows per Arrow batch in mapInPandas. When set (default 200),
-            applied temporarily during this call to limit memory per batch when rows are large.
+        max_records_per_batch: Max rows per Arrow batch in mapInPandas. Default 50 to keep each
+            batch well under Arrow's 2 GiB per-column limit when rows carry multi-MB payloads
+            (features column can be several MB each). Raise it if your rows are small.
         needed_columns: Optional set or list of feature names to include. If provided, only these
             columns are decoded and returned (reduces memory and output size). If None, all schema columns are returned.
 
@@ -522,10 +523,25 @@ def _decode_batch(iterator):
             out_pdf = pd.DataFrame(out_rows, columns=all_columns_ordered)
             yield out_pdf
 
+    # Project to only the columns _decode_batch actually reads, so Arrow batches
+    # don't carry unused giant columns into the Python worker. This keeps each
+    # mapInPandas batch well under Arrow's 2 GiB per-column limit.
+    projected_cols = [
+        c for c in (
+            [features_column, metadata_column, mp_config_id_column, "entities"]
+            + row_metadata_columns
+        )
+        if c in df_columns
+    ]
+    # Preserve original column order and drop duplicates
+    seen = set()
+    projected_cols = [c for c in projected_cols if not (c in seen or seen.add(c))]
+    df_projected = df.select(*projected_cols)
+
     n_partitions = num_partitions if num_partitions is not None else 10000
-    df_repart = df.repartition(n_partitions)
+    df_repart = df_projected.repartition(n_partitions)
 
-    batch_limit = max_records_per_batch if max_records_per_batch is not None else 200
+    batch_limit = max_records_per_batch if max_records_per_batch is not None else 50
     prev_max_records = spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch")
     spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", str(batch_limit))

Review comment on lines +544 to 546:
⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Validate max_records_per_batch bounds before setting Spark conf.

Line 544 accepts non-positive values from callers, and the mistake only surfaces later as a less clear Spark/Arrow error. Validate early and fail fast with a ValueError.

Suggested patch
-    batch_limit = max_records_per_batch if max_records_per_batch is not None else 50
+    batch_limit = max_records_per_batch if max_records_per_batch is not None else 50
+    if batch_limit <= 0:
+        raise ValueError("max_records_per_batch must be a positive integer")
     prev_max_records = spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@py-sdk/inference_logging_client/inference_logging_client/__init__.py` around
lines 544 - 546, The code sets batch_limit from max_records_per_batch and then
configures Spark without validating input; add an early check that if
max_records_per_batch is not None and is <= 0 you raise a ValueError (clear
message like "max_records_per_batch must be a positive integer"), otherwise
compute batch_limit = max_records_per_batch if provided else 50 and then call
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch",
str(batch_limit)); reference max_records_per_batch, batch_limit and the
spark.conf.set/get calls when making the change.
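
The suggested guard is easy to sanity-check in isolation. A minimal sketch follows: `resolve_batch_limit` is a hypothetical helper, not library code; it just isolates the default-plus-validation logic the review proposes.

def resolve_batch_limit(max_records_per_batch=None):
    # Default mirrors the diff above; the guard mirrors the review suggestion.
    batch_limit = max_records_per_batch if max_records_per_batch is not None else 50
    if batch_limit <= 0:
        raise ValueError("max_records_per_batch must be a positive integer")
    return batch_limit

assert resolve_batch_limit() == 50
assert resolve_batch_limit(200) == 200
try:
    resolve_batch_limit(0)
except ValueError as e:
    print("rejected early:", e)  # clearer than a downstream Spark/Arrow failure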
     try:
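
Taken together, the hunks above implement one pattern: project away unread columns before repartitioning, then cap Arrow batch size while the work runs. The standalone PySpark sketch below reproduces that pattern under stated assumptions: the table layout is invented (`debug_blob` stands in for an unused wide column, and `entities` is absent, which the projection tolerates), and `passthrough` stands in for the real `_decode_batch`.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("mplog-batch-sizing-demo").getOrCreate()

# Toy stand-in for an MPLog table: two columns the decoder reads, plus a wide
# column it never touches ("debug_blob" is invented for this demo).
df = spark.createDataFrame(
    [(i % 7, b"\x00" * 64, "x" * 2000) for i in range(1_000)],
    ["mp_config_id", "features", "debug_blob"],
)

# Project to only the columns the decode step reads, preserving order and
# dropping duplicates, as in the hunk above ("entities" is simply skipped).
wanted = ["features", "mp_config_id", "entities", "mp_config_id"]
projected = [c for c in wanted if c in df.columns]
seen = set()
projected = [c for c in projected if not (c in seen or seen.add(c))]
df_projected = df.select(*projected)

def passthrough(batches):
    # With the conf below, each pandas DataFrame holds at most 50 rows.
    for pdf in batches:
        yield pdf

# Cap rows per Arrow batch so wide payload columns cannot push a single batch
# toward Arrow's 2 GiB per-column limit; restore the previous value after.
prev = spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "50")
try:
    out = df_projected.repartition(100).mapInPandas(passthrough, schema=df_projected.schema)
    print(out.count())
finally:
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", prev)

The save-and-restore around the conf matters because maxRecordsPerBatch is session-wide, not per-query; the diff takes the same care with prev_max_records.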

@@ -1,5 +1,7 @@
 """Custom exceptions for inference-logging-client."""
 
+from typing import Optional
+
 
 class InferenceLoggingError(Exception):
     """Base exception for inference-logging-client errors."""
@@ -8,9 +10,15 @@ class InferenceLoggingError(Exception):
 
 
 class SchemaFetchError(InferenceLoggingError):
-    """Raised when fetching schema from inference service fails."""
+    """Raised when fetching schema from inference service fails.
 
-    pass
+    The optional `status_code` attribute carries the HTTP status when the
+    failure was an HTTP response (None for network/URL errors).
+    """
+
+    def __init__(self, message: str, status_code: Optional[int] = None):
+        super().__init__(message)
+        self.status_code = status_code
 
 
 class SchemaNotFoundError(InferenceLoggingError):
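
For callers, the practical effect of the new attribute is that permanent client errors become distinguishable from transient ones. A minimal usage sketch: the two classes are copied from the diff above, while the raise/except flow is illustrative only.

from typing import Optional


class InferenceLoggingError(Exception):
    """Base exception for inference-logging-client errors."""


class SchemaFetchError(InferenceLoggingError):
    def __init__(self, message: str, status_code: Optional[int] = None):
        super().__init__(message)
        self.status_code = status_code


try:
    raise SchemaFetchError("HTTP Error 400 from inference service: bad id", status_code=400)
except SchemaFetchError as e:
    if e.status_code == 400:
        print("permanent client error, do not retry:", e)
    elif e.status_code is None:
        print("network/URL error, safe to retry:", e)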
43 changes: 38 additions & 5 deletions py-sdk/inference_logging_client/inference_logging_client/io.py
@@ -23,6 +23,14 @@
 _schema_cache_lock = threading.Lock()
 _SCHEMA_CACHE_MAX_SIZE = 100 # Maximum number of cached schemas
 
+# Thread-safe LRU negative cache for schemas. Only HTTP 400 responses are
+# negative-cached; transient failures (5xx, network errors, 401/403/429, etc.)
+# are intentionally excluded so they can be retried on subsequent calls.
+# Values are the error message at the time of the 400, returned on cache hit.
+_schema_neg_cache: OrderedDict[tuple[str, int], str] = OrderedDict()
+_schema_neg_cache_lock = threading.Lock()
+_SCHEMA_NEG_CACHE_MAX_SIZE = 100
+
 # Retry configuration
 _MAX_RETRIES = 3
 _RETRY_BACKOFF_BASE = 0.5 # seconds
@@ -58,9 +66,13 @@ def _fetch_schema_with_retry(url: str, max_retries: int = _MAX_RETRIES) -> dict:
             error_body = e.read().decode("utf-8") if e.fp else str(e)
             # Don't retry on client errors (4xx)
             if 400 <= e.code < 500:
-                raise SchemaFetchError(f"HTTP Error {e.code} from inference service: {error_body}")
+                raise SchemaFetchError(
+                    f"HTTP Error {e.code} from inference service: {error_body}",
+                    status_code=e.code,
+                )
             last_exception = SchemaFetchError(
-                f"HTTP Error {e.code} from inference service: {error_body}"
+                f"HTTP Error {e.code} from inference service: {error_body}",
+                status_code=e.code,
             )
         except urllib.error.URLError as e:
             last_exception = SchemaFetchError(
@@ -125,12 +137,27 @@ def get_feature_schema(
             _schema_cache.move_to_end(cache_key)
             return _schema_cache[cache_key]
 
+    # Thread-safe negative cache lookup (HTTP 400 only). Replay the error
+    # without hitting the network. Same LRU shape as the positive cache.
+    with _schema_neg_cache_lock:
+        if cache_key in _schema_neg_cache:
+            _schema_neg_cache.move_to_end(cache_key)
+            raise SchemaFetchError(_schema_neg_cache[cache_key], status_code=400)
+
     base_url = f"{inference_host}{api_path}"
     params = urllib.parse.urlencode({"model_config_id": model_config_id, "version": str(version)})
     url = f"{base_url}?{params}"
 
-    # Fetch with retry
-    data = _fetch_schema_with_retry(url)
+    # Fetch with retry. Negative-cache HTTP 400 only.
+    try:
+        data = _fetch_schema_with_retry(url)
+    except SchemaFetchError as e:
+        if getattr(e, "status_code", None) == 400:
+            with _schema_neg_cache_lock:
+                while len(_schema_neg_cache) >= _SCHEMA_NEG_CACHE_MAX_SIZE:
+                    _schema_neg_cache.popitem(last=False)
+                _schema_neg_cache[cache_key] = str(e)
+        raise
 
     features = []
     for idx, component in enumerate(data.get("data", [])):
@@ -158,9 +185,15 @@
 
 
 def clear_schema_cache() -> None:
-    """Clear the schema cache. Useful for testing or when schemas have changed."""
+    """Clear both positive and negative schema caches.
+
+    Useful for testing, or when schemas have changed on the inference service
+    and the negative cache should be reset to allow a fresh fetch.
+    """
     with _schema_cache_lock:
         _schema_cache.clear()
+    with _schema_neg_cache_lock:
+        _schema_neg_cache.clear()
 
 
 def parse_mplog_protobuf(data: bytes) -> DecodedMPLog:
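
The negative-cache pattern added to io.py is self-contained enough to demonstrate outside the SDK. A reduced sketch under stated assumptions: `remember_bad_request` and `replay_if_known_bad` are hypothetical names, and `ValueError` stands in for `SchemaFetchError`.

import threading
from collections import OrderedDict

# LRU map guarded by a lock, storing only "permanent" failures (HTTP 400 here),
# mirroring the shape of _schema_neg_cache in the diff.
_NEG_CACHE_MAX = 100
_neg_cache: "OrderedDict[tuple[str, int], str]" = OrderedDict()
_neg_cache_lock = threading.Lock()


def remember_bad_request(key: tuple, message: str) -> None:
    with _neg_cache_lock:
        while len(_neg_cache) >= _NEG_CACHE_MAX:
            _neg_cache.popitem(last=False)  # evict the least recently used entry
        _neg_cache[key] = message


def replay_if_known_bad(key: tuple) -> None:
    with _neg_cache_lock:
        if key in _neg_cache:
            _neg_cache.move_to_end(key)  # refresh LRU position on hit
            raise ValueError(_neg_cache[key])  # stand-in for SchemaFetchError


remember_bad_request(("model-a", 1), "HTTP Error 400: unknown model_config_id")
try:
    replay_if_known_bad(("model-a", 1))
except ValueError as e:
    print("cache hit replays the original 400:", e)

The `move_to_end` on hit plus `popitem(last=False)` on insert is what makes the eviction least-recently-used rather than plain FIFO.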
2 changes: 1 addition & 1 deletion py-sdk/inference_logging_client/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "inference-logging-client"
-version = "0.3.1"
+version = "0.3.3"
 description = "Decode MPLog feature logs from proto, arrow, or parquet format"
 readme = "readme.md"
 requires-python = ">=3.8"