43 changes: 43 additions & 0 deletions py-sdk/inference_logging_client/examples/decode_csv_to_csv.py
@@ -0,0 +1,43 @@
"""
Decode an inference-log CSV directly to another CSV using the caller-supplied
schema. No Spark required; pure-Python (csv + json + base64 + the proto decoder).

Usage:
    python decode_csv_to_csv.py [input.csv] [output.csv]

Defaults:
    input  = /Users/dheerajchouhan/Downloads/test_new.csv
    output = /tmp/decoded_test_new.csv
"""

import sys

from inference_logging_client import decode_mplog_proto_csv

# Import the full 256-feature schema from the sibling script.
from decode_single_row import SCHEMA


DEFAULT_INPUT = "/Users/dheerajchouhan/Downloads/test_new.csv"
DEFAULT_OUTPUT = "/tmp/decoded_test_new.csv"


def main():
    input_csv = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_INPUT
    output_csv = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_OUTPUT

    print(f"input : {input_csv}")
    print(f"output : {output_csv}")
    print(f"schema : {len(SCHEMA['data'])} features")

    n = decode_mplog_proto_csv(
        input_csv=input_csv,
        output_csv=output_csv,
        schema=SCHEMA,
    )

    print(f"decoded rows written: {n}")


if __name__ == "__main__":
    main()
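
The module docstring describes the CSV path as pure Python (csv + json + base64 plus the proto decoder). A rough sketch of what the csv + base64 reading step could look like follows. It assumes, without confirmation from this diff, that the input CSV has a `features` column holding standard base64 text; `iter_payloads` is a hypothetical helper and not part of `inference_logging_client`.

```python
import base64
import csv
import io
from typing import Iterator, TextIO

# The csv module rejects fields longer than 131072 characters by default,
# which multi-MB payload cells would exceed, so raise the limit first.
csv.field_size_limit(2**31 - 1)


def iter_payloads(handle: TextIO, column: str = "features") -> Iterator[bytes]:
    """Yield the raw payload bytes of each row (hypothetical helper).

    Assumption: `column` holds standard base64 text. The encoding that
    decode_mplog_proto_csv actually expects may differ.
    """
    for row in csv.DictReader(handle):
        cell = row.get(column) or ""
        if not cell:
            continue  # skip rows without a payload instead of failing
        yield base64.b64decode(cell)


# In-memory demo: two rows with payloads and one row with an empty cell.
sample = io.StringIO(
    "mp_config_id,features\n"
    f"demo,{base64.b64encode(b'abc').decode()}\n"
    "demo,\n"
    f"demo,{base64.b64encode(bytes([0, 1, 2])).decode()}\n"
)
payloads = list(iter_payloads(sample))
```

Streaming rows through a generator keeps memory flat for large files; the decoded bytes would then be handed to the proto decoder one row at a time.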
@@ -0,0 +1,285 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# `decode_mplog_proto_dataframe`: example notebook\n",
"\n",
"This notebook demonstrates the `decode_mplog_proto_dataframe` API added in `inference-logging-client` 0.3.4.\n",
"\n",
"Use this method when:\n",
"- All rows in the input DataFrame are encoded as **proto** (not arrow / parquet).\n",
"- You already have the feature schema in hand (from your own API, a cached JSON, or a prior `get_feature_schema` call).\n",
"- You want to avoid contacting the inference service at decode time (no schema fetch, no positive cache, no negative cache, no per-worker fallback).\n",
"\n",
"Compared to `decode_mplog_dataframe`, this method skips the driver-side `distinct().collect()` for schema discovery, per-row metadata-byte parsing, format dispatch, and the per-row `schema_cache.get()`. It keeps the same distributed `mapInPandas` pipeline, the same Arrow 2 GiB safeguard (default `max_records_per_batch=50`), and the same input-column projection."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Install\n",
"\n",
"On Databricks, install at the cluster or notebook level:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade inference-logging-client==0.3.4 zstandard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dbutils.library.restartPython() # Databricks only"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Imports and Spark session"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pyspark.sql import SparkSession\n",
"from pyspark.sql import functions as F\n",
"\n",
"from inference_logging_client import decode_mplog_proto_dataframe\n",
"\n",
"spark = SparkSession.builder.appName(\"decode_mplog_proto_example\").getOrCreate()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Load the encoded logs DataFrame\n",
"\n",
"Adjust the table name / path and the partition filter for your environment. The DataFrame must have at least a `features` column (the encoded payloads) and an `mp_config_id` column. Optional columns that are passed through if present: `entities`, `parent_entity`, `prism_ingested_at`, `prism_extracted_at`, `created_at`, `tracking_id`, `user_id`, `year`, `month`, `day`, `hour`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"logs_df = (\n",
"    spark.table(\"silver.ML_Platform__model_proxy_inference_logs\")\n",
"    .filter(F.col(\"mp_config_id\") == \"my-model-proxy-id\")\n",
"    .filter(F.concat_ws(\"-\", \"year\", \"month\", \"day\") == \"2026-05-09\")\n",
")\n",
"\n",
"logs_df.printSchema()\n",
"logs_df.limit(3).show(truncate=80)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Provide the feature schema\n",
"\n",
"The `schema` argument accepts three shapes — pick whichever your source already produces:\n",
"\n",
"**Option A — inference-service JSON response shape (most common):**\n",
"```python\n",
"schema = {\"data\": [{\"feature_name\": \"...\", \"feature_type\": \"DataTypeFP32\", \"feature_size\": 1}, ...]}\n",
"```\n",
"\n",
"**Option B — plain list of dicts:**\n",
"```python\n",
"schema = [{\"feature_name\": \"...\", \"feature_type\": \"DataTypeFP32\"}, ...]\n",
"```\n",
"\n",
"**Option C — typed `FeatureInfo` list (returned by `get_feature_schema`):**\n",
"```python\n",
"from inference_logging_client import get_feature_schema\n",
"schema = get_feature_schema(\"my-model-proxy-id\", 1)\n",
"```\n",
"\n",
"Array order is preserved and used as the proto field index — do not reorder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example: schema fetched from your own API, in the inference-service JSON shape.\n",
"schema = {\n",
"    \"data\": [\n",
"        {\"feature_name\": \"user:derived_2_fp32:log_views_56day\", \"feature_type\": \"DataTypeFP32\", \"feature_size\": 1},\n",
"        {\"feature_name\": \"user:derived_2_fp32:orders_by_clicks_laplace_56day\", \"feature_type\": \"DataTypeFP32\", \"feature_size\": 1},\n",
"        {\"feature_name\": \"user:derived_2_fp32:avg_click_catalog_nqd_30day\", \"feature_type\": \"DataTypeFP32\", \"feature_size\": 1},\n",
"        {\"feature_name\": \"user:derived_2_fp32:avg_order__catalog_arp_sscat_percentile__90day\", \"feature_type\": \"DataTypeFP32\", \"feature_size\": 1},\n",
"        {\"feature_name\": \"user:derived_2_fp32:browse_time_last_7day\", \"feature_type\": \"DataTypeFP32\", \"feature_size\": 1},\n",
"        {\"feature_name\": \"user:derived_2_fp32:clicks_by_views_laplace_28day\", \"feature_type\": \"DataTypeFP32\", \"feature_size\": 1},\n",
"        {\"feature_name\": \"user:derived_2_fp32:engagement_click_percent\", \"feature_type\": \"DataTypeFP32\", \"feature_size\": 1},\n",
"        {\"feature_name\": \"user:derived_2_fp32:retention_90_days\", \"feature_type\": \"DataTypeFP32\", \"feature_size\": 1},\n",
"        {\"feature_name\": \"user:derived_2_fp32:user__nqp\", \"feature_type\": \"DataTypeFP32\", \"feature_size\": 1},\n",
"        {\"feature_name\": \"user:derived_2_fp32:user__nqp_by_nqd\", \"feature_type\": \"DataTypeFP32\", \"feature_size\": 1},\n",
"    ]\n",
"}\n",
"\n",
"print(f\"schema has {len(schema['data'])} features\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Decode\n",
"\n",
"Defaults that matter:\n",
"- `decompress=True` — automatically zstd-decompresses each payload if needed.\n",
"- `num_partitions=10000` — keeps each worker task small when rows carry multi-MB payloads.\n",
"- `max_records_per_batch=50` — keeps each Arrow batch under the 2 GiB per-column limit.\n",
"- `needed_columns=None` — decode all schema columns; pass a list/set to project early."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"decoded_df = decode_mplog_proto_dataframe(\n",
"    df=logs_df,\n",
"    spark=spark,\n",
"    schema=schema,\n",
")\n",
"\n",
"decoded_df.printSchema()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"decoded_df.limit(5).show(truncate=80)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Decode only the columns you need (optional)\n",
"\n",
"Pass `needed_columns` so workers skip decoding and emitting features you don't need. This reduces output size and worker memory when the schema is wide."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"subset_df = decode_mplog_proto_dataframe(\n",
"    df=logs_df,\n",
"    spark=spark,\n",
"    schema=schema,\n",
"    needed_columns={\n",
"        \"user:derived_2_fp32:log_views_56day\",\n",
"        \"user:derived_2_fp32:retention_90_days\",\n",
"    },\n",
")\n",
"\n",
"subset_df.printSchema()\n",
"subset_df.limit(5).show(truncate=80)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Tuning knobs (optional)\n",
"\n",
"Adjust these parameters if your rows are unusually small or unusually large:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tuned_df = decode_mplog_proto_dataframe(\n",
"    df=logs_df,\n",
"    spark=spark,\n",
"    schema=schema,\n",
"    num_partitions=20000,  # raise if executors are CPU-starved\n",
"    max_records_per_batch=20,  # lower further if you still hit Arrow overflow\n",
"    decompress=True,\n",
")\n",
"\n",
"tuned_df.limit(3).show(truncate=80)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Persist the decoded output"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"(\n",
"    decoded_df\n",
"    .write\n",
"    .mode(\"overwrite\")\n",
"    .partitionBy(\"year\", \"month\", \"day\")\n",
"    .parquet(\"s3://your-bucket/decoded_mplog/my-model-proxy-id/\")\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Notes\n",
"\n",
"- **Format is always PROTO.** If your logs use arrow or parquet encoding, use `decode_mplog_dataframe` instead.\n",
"- **Schema is applied to every row.** All rows in the input DataFrame must have been encoded against the schema you pass. If you have multiple `(mp_config_id, version)` combos in the same DataFrame and they use different schemas, filter and decode each group separately.\n",
"- **Type strings.** `DataTypeFP32`, `FP32`, `fp32` — all work. The decoder strips the `DataType` prefix and case-normalizes internally.\n",
"- **`feature_size`.** Ignored. The decoder infers scalar vs vector from the type name (`FP32` vs `FP32Vector`)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
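
The notebook states that `schema` may arrive in several shapes, that array order is used as the proto field index, and that type strings such as `DataTypeFP32`, `FP32`, and `fp32` are treated alike. A small normalizer can illustrate that behaviour. This is only a sketch of what the notebook describes, not the library's implementation: `normalize_type` and `normalize_schema` are hypothetical names, upper-casing is an assumed choice of case normalization, the index is assumed to be the 0-based array position, and the typed `FeatureInfo` shape (Option C) is left out because its fields are not shown in this diff.

```python
from typing import Any, Dict, List, Tuple, Union

# Option A is a dict with a "data" list; Option B is the bare list.
SchemaInput = Union[Dict[str, Any], List[Dict[str, Any]]]


def normalize_type(type_name: str) -> str:
    """Strip an optional 'DataType' prefix, then upper-case (assumed casing)."""
    prefix = "DataType"
    if type_name.startswith(prefix):
        type_name = type_name[len(prefix):]
    return type_name.upper()


def normalize_schema(schema: SchemaInput) -> List[Tuple[int, str, str]]:
    """Return (position, feature_name, normalized_type) in the original order.

    `feature_size` is deliberately not read, matching the note that the
    decoder ignores it.
    """
    entries = schema["data"] if isinstance(schema, dict) else schema
    return [
        (position, entry["feature_name"], normalize_type(entry["feature_type"]))
        for position, entry in enumerate(entries)
    ]


# The same two features expressed in both dict-based shapes.
option_a = {"data": [
    {"feature_name": "f_a", "feature_type": "DataTypeFP32", "feature_size": 1},
    {"feature_name": "f_b", "feature_type": "fp32", "feature_size": 1},
]}
option_b = [
    {"feature_name": "f_a", "feature_type": "FP32"},
    {"feature_name": "f_b", "feature_type": "DataTypeFP32"},
]
normalized_a = normalize_schema(option_a)
normalized_b = normalize_schema(option_b)
```

Because the position comes from `enumerate`, sorting or deduplicating the list before decoding would silently change which proto field each name maps to, which is why the notebook warns against reordering.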