fix(inference-logging-client): avoid arrow 2GB overflow in decode_mplog_dataframe #372
dheerajchouhan08 wants to merge 2 commits into
Conversation
… overflow

- Lower default max_records_per_batch from 200 to 50 so multi-MB feature rows can't aggregate past Arrow's 2 GiB per-column-per-batch limit, which was crashing the Python worker before _decode_batch ran.
- Project the input DataFrame to only the columns _decode_batch reads before mapInPandas, so unused columns don't bloat Arrow batches.
- Bump version to 0.3.2.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
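A rough illustration of the projection step this commit describes; the helper name and the column list are assumptions drawn from the PR description, not the actual SDK code:

```python
# Hypothetical sketch; _project_to_decode_columns and the column list are
# assumptions, not the real decode_mplog_dataframe implementation.
from pyspark.sql import DataFrame


def _project_to_decode_columns(df: DataFrame) -> DataFrame:
    # Keep only the columns the decode step reads, so large unused columns
    # never enter the Arrow batches handed to mapInPandas.
    needed = ["features", "metadata", "mp_config_id", "entities"]
    present = [c for c in needed if c in df.columns]
    return df.select(*present)
```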
Walkthrough

Version bumped to 0.3.3. decode_mplog_dataframe now defaults Arrow batch sizing to 50, projects the input DataFrame to required columns before repartitioning, and repartitions the projected DataFrame. SchemaFetchError gains an optional status_code; get_feature_schema adds a thread-safe negative LRU cache for HTTP 400 errors, and clear_schema_cache clears both caches.

Changes

DataFrame Projection and Batch Configuration Optimization
Suggested labels
🚥 Pre-merge checks: ✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@py-sdk/inference_logging_client/inference_logging_client/__init__.py`:
- Around line 544-546: The code sets batch_limit from max_records_per_batch and
then configures Spark without validating input; add an early check that if
max_records_per_batch is not None and is <= 0 you raise a ValueError (clear
message like "max_records_per_batch must be a positive integer"), otherwise
compute batch_limit = max_records_per_batch if provided else 50 and then call
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch",
str(batch_limit)); reference max_records_per_batch, batch_limit and the
spark.conf.set/get calls when making the change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: a7fe3cce-2478-491e-b98e-c0659c22f269
📒 Files selected for processing (2)
- py-sdk/inference_logging_client/inference_logging_client/__init__.py
- py-sdk/inference_logging_client/pyproject.toml
batch_limit = max_records_per_batch if max_records_per_batch is not None else 50
prev_max_records = spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", str(batch_limit))
Validate max_records_per_batch bounds before setting Spark conf.
Line 544 accepts non-positive values from callers; this can fail later with a less clear Spark/Arrow error. Add an early validation check and fail fast with a ValueError.
Suggested patch

- batch_limit = max_records_per_batch if max_records_per_batch is not None else 50
+ batch_limit = max_records_per_batch if max_records_per_batch is not None else 50
+ if batch_limit <= 0:
+     raise ValueError("max_records_per_batch must be a positive integer")
  prev_max_records = spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch")

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
batch_limit = max_records_per_batch if max_records_per_batch is not None else 50
if batch_limit <= 0:
    raise ValueError("max_records_per_batch must be a positive integer")
prev_max_records = spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", str(batch_limit))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@py-sdk/inference_logging_client/inference_logging_client/__init__.py` around
lines 544 - 546, The code sets batch_limit from max_records_per_batch and then
configures Spark without validating input; add an early check that if
max_records_per_batch is not None and is <= 0 you raise a ValueError (clear
message like "max_records_per_batch must be a positive integer"), otherwise
compute batch_limit = max_records_per_batch if provided else 50 and then call
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch",
str(batch_limit)); reference max_records_per_batch, batch_limit and the
spark.conf.set/get calls when making the change.
- Add SchemaFetchError.status_code so callers can distinguish HTTP failure modes (400 bad request vs 5xx vs network errors).
- Add a thread-safe LRU negative cache in get_feature_schema, mirroring the positive cache shape (OrderedDict, max 100 entries, no TTL). Only HTTP 400 responses are negative-cached; transient errors (5xx, network, 401/403/429) are intentionally excluded so they can be retried.
- clear_schema_cache() now clears both caches.
- Bump version to 0.3.3.

Prevents worker tasks from hammering the inference service with the same 400 for every row when a (mp_config_id, version) schema is not registered.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
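A minimal sketch of the negative-cache behavior this commit describes, assuming a requests-based fetch and names such as _schema_neg_cache and _SCHEMA_NEG_CACHE_MAX; the real io.py differs in details:

```python
# Hypothetical sketch based on the commit description, not the actual io.py.
import threading
from collections import OrderedDict
from typing import Optional, Tuple

import requests


class SchemaFetchError(Exception):
    def __init__(self, message: str, status_code: Optional[int] = None) -> None:
        super().__init__(message)
        self.status_code = status_code  # lets callers tell 400s from 5xx/network errors


_SCHEMA_NEG_CACHE_MAX = 100
_schema_neg_cache: "OrderedDict[Tuple[str, int], str]" = OrderedDict()
_schema_neg_cache_lock = threading.Lock()


def get_feature_schema(inference_host: str, api_path: str, mp_config_id: str, version: int) -> dict:
    cache_key = (mp_config_id, version)

    # Replay a previously seen 400 without another network round trip.
    with _schema_neg_cache_lock:
        if cache_key in _schema_neg_cache:
            _schema_neg_cache.move_to_end(cache_key)
            raise SchemaFetchError(_schema_neg_cache[cache_key], status_code=400)

    resp = requests.get(f"{inference_host}{api_path}",
                        params={"mp_config_id": mp_config_id, "version": version})
    if resp.status_code == 400:
        # Only definitive 400s are negative-cached; 5xx, auth, and network
        # errors are left uncached so they can be retried.
        with _schema_neg_cache_lock:
            _schema_neg_cache[cache_key] = resp.text
            _schema_neg_cache.move_to_end(cache_key)
            while len(_schema_neg_cache) > _SCHEMA_NEG_CACHE_MAX:
                _schema_neg_cache.popitem(last=False)
        raise SchemaFetchError(resp.text, status_code=400)
    resp.raise_for_status()
    return resp.json()


def clear_schema_cache() -> None:
    # Clears the negative cache; the real function also clears the positive cache.
    with _schema_neg_cache_lock:
        _schema_neg_cache.clear()
```

Note that the review comment below points out a weakness of keying only on (mp_config_id, version): a 400 from one inference_host/api_path would be replayed for other endpoints unless the host and path are included in the cache key.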
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
py-sdk/inference_logging_client/inference_logging_client/io.py (1)
129-146: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Negative cache key can poison valid fetches across different endpoints.

_schema_neg_cache is keyed only by (model_config_id, version). A 400 from one inference_host/api_path (misconfig, rollout mismatch, bad gateway behavior) will be replayed for other hosts/paths without a network call.

Suggested fix

- _schema_neg_cache: OrderedDict[tuple[str, int], str] = OrderedDict()
+ _schema_neg_cache: OrderedDict[tuple[str, str, str, int], str] = OrderedDict()
  ...
- cache_key = (model_config_id, version)
+ cache_key = (model_config_id, version)
+ neg_cache_key = (inference_host, api_path, model_config_id, version)
  ...
- with _schema_neg_cache_lock:
-     if cache_key in _schema_neg_cache:
-         _schema_neg_cache.move_to_end(cache_key)
-         raise SchemaFetchError(_schema_neg_cache[cache_key], status_code=400)
+ with _schema_neg_cache_lock:
+     if neg_cache_key in _schema_neg_cache:
+         _schema_neg_cache.move_to_end(neg_cache_key)
+         raise SchemaFetchError(_schema_neg_cache[neg_cache_key], status_code=400)
  ...
- _schema_neg_cache[cache_key] = str(e)
+ _schema_neg_cache[neg_cache_key] = str(e)

Also applies to: 151-160
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.

In `@py-sdk/inference_logging_client/inference_logging_client/io.py` around lines 129 - 146: The negative cache currently uses cache_key = (model_config_id, version) and thus replays a 400 from one endpoint across different inference_host/api_path; change the negative-cache key to include the endpoint identifiers (e.g., cache_key_neg = (model_config_id, version, inference_host, api_path)) and use that new key for both lookups and inserts in _schema_neg_cache (where you check membership, move_to_end, and when storing the error); ensure the positive cache behavior remains the same and update all places that reference _schema_neg_cache so lookups and raises replay the endpoint-specific error only.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@py-sdk/inference_logging_client/inference_logging_client/io.py`:
- Around line 129-146: The negative cache currently uses cache_key =
(model_config_id, version) and thus replays a 400 from one endpoint across
different inference_host/api_path; change the negative-cache key to include the
endpoint identifiers (e.g., cache_key_neg = (model_config_id, version,
inference_host, api_path)) and use that new key for both lookups and inserts in
_schema_neg_cache (where you check membership, move_to_end, and when storing the
error); ensure the positive cache behavior remains the same and update all
places that reference _schema_neg_cache so lookups and raises replay the
endpoint-specific error only.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: d249088e-19a6-414c-839f-d49712fb5fe2
📒 Files selected for processing (4)
- py-sdk/inference_logging_client/inference_logging_client/__init__.py
- py-sdk/inference_logging_client/inference_logging_client/exceptions.py
- py-sdk/inference_logging_client/inference_logging_client/io.py
- py-sdk/inference_logging_client/pyproject.toml
✅ Files skipped from review due to trivial changes (1)
- py-sdk/inference_logging_client/pyproject.toml
🚧 Files skipped from review as they are similar to previous changes (1)
- py-sdk/inference_logging_client/inference_logging_client/__init__.py
Summary
- Lower the default max_records_per_batch in decode_mplog_dataframe from 200 → 50 so multi-MB feature rows can't aggregate past Arrow's 2 GiB per-column-per-batch limit (which was crashing the Python worker before _decode_batch even ran).
- Project the input DataFrame to only the columns _decode_batch actually reads (features, metadata, mp_config_id, entities, and row-metadata cols) before mapInPandas, so unused columns don't bloat Arrow batches.

Context

Decoding via decode_mplog_dataframe was failing on Databricks with a PythonException whose traceback bottomed out inside Spark's own Arrow→pandas conversion (pyarrow.lib.check_status), before any user-level decode code ran. Root cause: with max_records_per_batch=200 and multi-MB binary payloads per row, a single Arrow batch's binary column exceeded the 2 GiB 32-bit-offset limit.

Test plan

- decode_mplog_dataframe completes.
- max_records_per_batch still overrides the new default.

🤖 Generated with Claude Code
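As a rough sketch of the batch-size handling described in the Summary and Context above, combined with the validation suggested in the review; function and variable names here are assumptions, not the actual SDK code:

```python
# Hypothetical sketch; decode_projected and _decode_batch are placeholders,
# not the real decode_mplog_dataframe implementation.
from typing import Iterator, Optional, Tuple

import pandas as pd
from pyspark.sql import DataFrame, SparkSession


def _decode_batch(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    # Placeholder for the real per-batch decode logic.
    for pdf in batches:
        yield pdf


def decode_projected(
    spark: SparkSession,
    projected: DataFrame,
    max_records_per_batch: Optional[int] = None,
) -> Tuple[DataFrame, str]:
    batch_limit = max_records_per_batch if max_records_per_batch is not None else 50
    if batch_limit <= 0:
        raise ValueError("max_records_per_batch must be a positive integer")

    # Cap rows per Arrow batch so a single multi-MB binary column can't grow
    # past Arrow's 2 GiB 32-bit-offset limit. The previous value is returned
    # so the caller can restore it after the decoded DataFrame has been
    # materialized (the conf is read at execution time, not at this call).
    prev_max_records = spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch")
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", str(batch_limit))

    decoded = projected.mapInPandas(_decode_batch, schema=projected.schema)
    return decoded, prev_max_records
```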
Summary by CodeRabbit
Chores
Bug Fixes