From c26ad67183f804182ce22de7d9c7e6ac5b198a3e Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Wed, 4 Mar 2026 22:23:32 -0500 Subject: [PATCH 1/3] Add LaughterSegmentation contrib model with inf2 benchmark notebook Wav2Vec2-based laughter detection model (315M params) compiled with torch_neuronx.trace(). Includes single-core and DataParallel benchmarks, accuracy validation, and end-to-end inference demo on inf2.xlarge. --- contrib/models/LaughterSegmentation/README.md | 102 ++ .../laughter_neuron_inf2.ipynb | 1597 +++++++++++++++++ 2 files changed, 1699 insertions(+) create mode 100644 contrib/models/LaughterSegmentation/README.md create mode 100644 contrib/models/LaughterSegmentation/laughter_neuron_inf2.ipynb diff --git a/contrib/models/LaughterSegmentation/README.md b/contrib/models/LaughterSegmentation/README.md new file mode 100644 index 00000000..c0423e7d --- /dev/null +++ b/contrib/models/LaughterSegmentation/README.md @@ -0,0 +1,102 @@ +# Contrib Model: LaughterSegmentation + +Laughter detection on AWS Inferentia2 using `torch_neuronx.trace()`. + +## Model Information + +- **HuggingFace ID:** `omine-me/LaughterSegmentation` +- **Model Type:** Wav2Vec2-based audio frame classifier +- **Parameters:** ~315M (FP32) +- **Base Model:** `jonatasgrosman/wav2vec2-large-xlsr-53-english` +- **License:** Check HuggingFace model card + +## Architecture Details + +The model uses `Wav2Vec2ForAudioFrameClassification` to classify each audio frame as laughter or not-laughter. It takes 7-second audio windows at 16 kHz (112,000 samples) and outputs 349 per-frame binary predictions. + +This model uses `torch_neuronx.trace()` (not NxD Inference) since it is an encoder-only classification model rather than an autoregressive LLM. + +## Validation Results + +**Validated:** 2026-03-04 +**Instance:** inf2.xlarge (1 Inferentia2 chip, 2 NeuronCores) +**SDK:** Neuron SDK 2.28, PyTorch 2.9 + +### Benchmark Results (Single Core) + +| Batch Size | Mean Latency | Throughput | Real-Time Factor | +|-----------|-------------|-----------|-----------------| +| 1 | 18.44 ms | 54.2 W/s | 380x | +| 2 | 21.27 ms | 94.0 W/s | 658x | +| 4 | 42.05 ms | 95.1 W/s | 666x | +| 8 | 83.93 ms | 95.3 W/s | 667x | + +### DataParallel Results (Full Instance, 2 Cores) + +| Configuration | Throughput | Real-Time Factor | Speedup | +|--------------|-----------|-----------------|---------| +| Single core (BS=2) | 94.0 W/s | 658x | 1.0x | +| DataParallel (2 cores) | 175.2 W/s | 1226x | 1.86x | + +### Accuracy Validation + +| Input | Cosine Similarity | Frame Agreement | +|-------|------------------|----------------| +| Random normal | 1.000000 | 100.00% | +| Quiet noise | 1.000000 | 100.00% | +| Loud signal | 1.000000 | 100.00% | +| Sine 440 Hz | 1.000000 | 100.00% | +| Silence | 0.999999 | 100.00% | + +**Status:** VALIDATED + +## Usage + +The included notebook (`laughter_neuron_inf2.ipynb`) contains the complete workflow: + +1. Download model weights from HuggingFace +2. Remove `weight_norm` parametrizations (required for SDK 2.28+) +3. Compile with `torch_neuronx.trace()` across multiple batch sizes +4. Benchmark single-core and DataParallel throughput +5. Validate accuracy against CPU reference +6. Run end-to-end inference on sample audio + +```bash +# On an inf2 or trn2 instance with DLAMI +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate +pip install jupyter safetensors librosa scipy pydub +jupyter notebook laughter_neuron_inf2.ipynb +``` + +### Known Issues + +**weight_norm crash on SDK 2.28+**: Wav2Vec2 uses `weight_norm` on `pos_conv_embed.conv`, which crashes `torch_neuronx.trace()`. The notebook strips parametrizations before tracing: + +```python +import torch.nn.utils.parametrize as parametrize +for name, module in model.named_modules(): + if hasattr(module, "parametrizations"): + for param_name in list(module.parametrizations.keys()): + parametrize.remove_parametrizations(module, param_name) +``` + +### trn2 Usage + +To run on trn2.3xlarge, uncomment the `--lnc` compiler arg in the compilation cell. + +## Compatibility Matrix + +| Instance/Version | SDK 2.28 | SDK 2.27 and earlier | +|------------------|----------|---------------------| +| inf2.xlarge | VALIDATED | Not tested | +| trn2.3xlarge | Tested (see project notes) | Not tested | + +## Example Checkpoints + +* [omine-me/LaughterSegmentation](https://huggingface.co/omine-me/LaughterSegmentation) + +## Maintainer + +Community contribution + +**Last Updated:** 2026-03-04 diff --git a/contrib/models/LaughterSegmentation/laughter_neuron_inf2.ipynb b/contrib/models/LaughterSegmentation/laughter_neuron_inf2.ipynb new file mode 100644 index 00000000..7ca7fc93 --- /dev/null +++ b/contrib/models/LaughterSegmentation/laughter_neuron_inf2.ipynb @@ -0,0 +1,1597 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Laughter Detection on AWS Inferentia2 with Neuron SDK\n", + "\n", + "This notebook traces, benchmarks, and validates the [LaughterSegmentation](https://github.com/omine-me/LaughterSegmentation) model on AWS Inferentia2 using the Neuron SDK. The model is a Wav2Vec2-based audio frame classifier (~315M params, FP32) that detects laughter segments in speech audio.\n", + "\n", + "**Instance**: inf2.xlarge (1 Inferentia2 chip, 2 NeuronCores)\n", + "**Time to complete**: ~15 minutes\n", + "**Prerequisites**: Deep Learning AMI Neuron (Ubuntu 24.04)\n", + "\n", + "### Quick Start\n", + "\n", + "```bash\n", + "# Activate pre-installed Neuron venv\n", + "source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate\n", + "pip install jupyter\n", + "jupyter notebook --no-browser --port 8888\n", + "```\n", + "\n", + "### trn2 Usage\n", + "\n", + "This notebook is configured for inf2. To run on trn2.3xlarge, uncomment the `--lnc` compiler arg in Step 2. See comments in the compilation cell." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model & Instance Specifications\n", + "\n", + "| Property | Value |\n", + "|----------|-------|\n", + "| Model | omine-me/LaughterSegmentation |\n", + "| Architecture | Wav2Vec2ForAudioFrameClassification |\n", + "| Parameters | ~315M (FP32) |\n", + "| Base model | jonatasgrosman/wav2vec2-large-xlsr-53-english |\n", + "| Input | 7-second audio windows at 16 kHz (112,000 samples) |\n", + "| Output | 349 per-frame binary predictions (laughter / not-laughter) |\n", + "\n", + "| Software | Version |\n", + "|----------|--------|\n", + "| DLAMI | Deep Learning AMI Neuron (Ubuntu 24.04) 20260227 |\n", + "| Neuron SDK | 2.28 |\n", + "| PyTorch | 2.9 |\n", + "| Transformers | 4.x |\n", + "\n", + "> **Important**: Wav2Vec2 uses `weight_norm` parametrizations on its positional convolution embedding. These must be removed before `torch_neuronx.trace()` or the compiler will crash with a PjRt buffer null error. This notebook handles this automatically." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 0: Setup & Environment Check\n", + "\n", + "Install required packages and verify the Neuron environment." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-05T02:08:07.036087Z", + "iopub.status.busy": "2026-03-05T02:08:07.035911Z", + "iopub.status.idle": "2026-03-05T02:12:01.450361Z", + "shell.execute_reply": "2026-03-05T02:12:01.449499Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PyTorch: 2.9.0+cu128\n", + "torch-neuronx: 2.9.0.2.12.22436+0f1dac25\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "instance-type: inf2.xlarge\n", + "instance-id: i-0f0fd1a81f421fbc9\n", + "+--------+--------+----------+--------+--------------+----------+------+\n", + "| NEURON | NEURON | NEURON | NEURON | PCI | CPU | NUMA |\n", + "| DEVICE | CORES | CORE IDS | MEMORY | BDF | AFFINITY | NODE |\n", + "+--------+--------+----------+--------+--------------+----------+------+\n", + "| 0 | 2 | 0-1 | 32 GB | 0000:00:1f.0 | 0-3 | -1 |\n", + "+--------+--------+----------+--------+--------------+----------+------+\n" + ] + }, + { + "data": { + "text/plain": [ + "CompletedProcess(args=['neuron-ls'], returncode=0)" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import subprocess, sys\n", + "\n", + "# Install dependencies\n", + "subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"-q\",\n", + " \"safetensors\", \"librosa\", \"scipy\", \"pydub\"])\n", + "\n", + "import torch\n", + "import torch_neuronx\n", + "\n", + "print(f\"PyTorch: {torch.__version__}\")\n", + "print(f\"torch-neuronx: {torch_neuronx.__version__}\")\n", + "print()\n", + "subprocess.run([\"neuron-ls\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 1: Download Model Weights\n", + "\n", + "Download the fine-tuned LaughterSegmentation weights (1.26 GB safetensors) from HuggingFace, then load the model with the correct architecture configuration.\n", + "\n", + "**Key detail**: The safetensors file stores weights with an `audio_model.` prefix (from the training wrapper class). We strip this prefix when loading into the bare HuggingFace `Wav2Vec2ForAudioFrameClassification`." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-05T02:12:01.452225Z", + "iopub.status.busy": "2026-03-05T02:12:01.451877Z", + "iopub.status.idle": "2026-03-05T02:12:44.905894Z", + "shell.execute_reply": "2026-03-05T02:12:44.905013Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading 'model.safetensors' to 'models/.cache/huggingface/download/xGOKKLRSlIhH692hSVvI1-gpoa8=.449b14f73c70db26da9b4a59ee77d9a9b29fbcaceb083dd7ea27cdfaa68442a0.incomplete'\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Download complete. Moving file to models/model.safetensors\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "models/model.safetensors\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "effd7b52b6dd4a0f8253a52b2f284643", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0.00B [00:00, ?B/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Removing parametrization: wav2vec2.encoder.pos_conv_embed.conv.weight\n", + "\n", + "Model loaded: 315.4M parameters\n" + ] + } + ], + "source": [ + "import os\n", + "import torch.nn.utils.parametrize as parametrize\n", + "import safetensors.torch\n", + "from transformers import Wav2Vec2Config, Wav2Vec2ForAudioFrameClassification\n", + "\n", + "MODEL_DIR = \"./models\"\n", + "WEIGHTS_PATH = os.path.join(MODEL_DIR, \"model.safetensors\")\n", + "AUDIO_MODEL_NAME = \"jonatasgrosman/wav2vec2-large-xlsr-53-english\"\n", + "\n", + "# Download weights if not present\n", + "if not os.path.exists(WEIGHTS_PATH):\n", + " os.makedirs(MODEL_DIR, exist_ok=True)\n", + " subprocess.check_call([\"hf\", \"download\", \"omine-me/LaughterSegmentation\",\n", + " \"model.safetensors\", \"--local-dir\", MODEL_DIR])\n", + "\n", + "# Load model from config (avoids downloading full base model checkpoint)\n", + "config = Wav2Vec2Config.from_pretrained(AUDIO_MODEL_NAME)\n", + "config.num_labels = 1\n", + "config.problem_type = \"single_label_classification\"\n", + "model = Wav2Vec2ForAudioFrameClassification(config)\n", + "\n", + "# Load fine-tuned weights, stripping the \"audio_model.\" prefix\n", + "state_dict = safetensors.torch.load_file(WEIGHTS_PATH, device=\"cpu\")\n", + "prefix = \"audio_model.\"\n", + "stripped = {(k[len(prefix):] if k.startswith(prefix) else k): v\n", + " for k, v in state_dict.items()}\n", + "model.load_state_dict(stripped)\n", + "model.eval()\n", + "\n", + "# CRITICAL: Remove weight_norm parametrizations before tracing.\n", + "# Wav2Vec2 uses weight_norm on pos_conv_embed.conv, which crashes\n", + "# torch_neuronx.trace() on SDK 2.28 with a PjRt buffer null error.\n", + "for name, module in model.named_modules():\n", + " if hasattr(module, \"parametrizations\"):\n", + " for param_name in list(module.parametrizations.keys()):\n", + " print(f\"Removing parametrization: {name}.{param_name}\")\n", + " parametrize.remove_parametrizations(module, param_name)\n", + "\n", + "param_count = sum(p.numel() for p in model.parameters())\n", + "print(f\"\\nModel loaded: {param_count / 1e6:.1f}M parameters\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 2: Compile for Neuron\n", + "\n", + "Trace the model with `torch_neuronx.trace()` across multiple batch sizes. We use:\n", + "\n", + "- `--model-type transformer` -- optimizes attention patterns\n", + "- `--auto-cast matmult` -- casts FP32 matmuls to BF16 for ~2x throughput with negligible accuracy loss\n", + "- `inline_weights_to_neff=True` -- embeds weights in the compiled artifact for faster loading\n", + "\n", + "Each batch size requires a separate compilation (~2 minutes each)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-05T02:12:44.907877Z", + "iopub.status.busy": "2026-03-05T02:12:44.907488Z", + "iopub.status.idle": "2026-03-05T02:27:33.765519Z", + "shell.execute_reply": "2026-03-05T02:27:33.764610Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BS=1: Compiling... " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Compiler status PASS\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "159s, 488 MB\n", + "BS=2: Compiling... " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Compiler status PASS\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "145s, 491 MB\n", + "BS=4: Compiling... " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Compiler status PASS\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "203s, 496 MB\n", + "BS=8: Compiling... " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "." + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Compiler status PASS\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "374s, 506 MB\n", + "\n", + "Compiled 4 models.\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "INPUT_SEC = 7\n", + "SAMPLE_RATE = 16000\n", + "INPUT_SAMPLES = INPUT_SEC * SAMPLE_RATE # 112,000\n", + "\n", + "BATCH_SIZES = [1, 2, 4, 8]\n", + "COMPILED_DIR = \"./compiled\"\n", + "os.makedirs(COMPILED_DIR, exist_ok=True)\n", + "\n", + "compiler_args = [\n", + " \"--model-type\", \"transformer\",\n", + " \"--optlevel\", \"2\",\n", + " \"--auto-cast\", \"matmult\",\n", + " # --- trn2 only: uncomment one of the following ---\n", + " # \"--lnc\", \"1\", # LNC=1: model on 1 NeuronCore\n", + " # \"--lnc\", \"2\", # LNC=2: model spans 2 NeuronCores, better per-model latency\n", + "]\n", + "\n", + "compiled_models = {}\n", + "\n", + "for bs in BATCH_SIZES:\n", + " save_path = os.path.join(COMPILED_DIR, f\"laughter_bs{bs}.pt\")\n", + "\n", + " if os.path.exists(save_path):\n", + " print(f\"BS={bs}: Loading cached {save_path}\")\n", + " compiled_models[bs] = torch.jit.load(save_path)\n", + " continue\n", + "\n", + " print(f\"BS={bs}: Compiling...\", end=\" \", flush=True)\n", + " example_input = torch.randn(bs, INPUT_SAMPLES)\n", + "\n", + " t0 = time.time()\n", + " model_neuron = torch_neuronx.trace(\n", + " model, example_input,\n", + " compiler_args=compiler_args,\n", + " inline_weights_to_neff=True,\n", + " )\n", + " compile_time = time.time() - t0\n", + "\n", + " torch.jit.save(model_neuron, save_path)\n", + " file_mb = os.path.getsize(save_path) / (1024 * 1024)\n", + " print(f\"{compile_time:.0f}s, {file_mb:.0f} MB\")\n", + " compiled_models[bs] = model_neuron\n", + "\n", + "print(f\"\\nCompiled {len(compiled_models)} models.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 3: Benchmark\n", + "\n", + "Measure latency and throughput for each batch size. We report:\n", + "\n", + "- **Latency**: mean, p50, p95, p99 in milliseconds\n", + "- **Throughput**: audio windows processed per second\n", + "- **Real-time factor (RTF)**: seconds of audio processed per second of wall time (>1 = faster than real-time)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-05T02:27:33.767882Z", + "iopub.status.busy": "2026-03-05T02:27:33.767691Z", + "iopub.status.idle": "2026-03-05T02:28:14.412445Z", + "shell.execute_reply": "2026-03-05T02:28:14.411489Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " BS | Mean(ms) | p50(ms) | p95(ms) | p99(ms) | W/s | RTF\n", + "------------------------------------------------------------------------\n", + " 1 | 18.44 | 17.75 | 23.13 | 24.11 | 54.2 | 380x\n", + " 2 | 21.27 | 21.28 | 21.47 | 21.50 | 94.0 | 658x\n", + " 4 | 42.05 | 42.04 | 42.11 | 42.15 | 95.1 | 666x\n", + " 8 | 83.93 | 83.91 | 84.13 | 84.19 | 95.3 | 667x\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "WARMUP = 10\n", + "ITERATIONS = 100\n", + "\n", + "results = []\n", + "\n", + "for bs in BATCH_SIZES:\n", + " model_n = compiled_models[bs]\n", + " x = torch.randn(bs, INPUT_SAMPLES)\n", + "\n", + " # Warmup\n", + " for _ in range(WARMUP):\n", + " model_n(x)\n", + "\n", + " # Benchmark\n", + " latencies = []\n", + " for _ in range(ITERATIONS):\n", + " t0 = time.time()\n", + " model_n(x)\n", + " latencies.append((time.time() - t0) * 1000)\n", + "\n", + " lat = np.array(latencies)\n", + " throughput = bs / (lat.mean() / 1000)\n", + " rtf = throughput * INPUT_SEC\n", + "\n", + " results.append({\n", + " \"bs\": bs,\n", + " \"mean\": lat.mean(), \"p50\": np.percentile(lat, 50),\n", + " \"p95\": np.percentile(lat, 95), \"p99\": np.percentile(lat, 99),\n", + " \"throughput\": throughput, \"rtf\": rtf,\n", + " })\n", + "\n", + "# Print results table\n", + "print(f\"{'BS':>4} | {'Mean(ms)':>9} | {'p50(ms)':>8} | {'p95(ms)':>8} | {'p99(ms)':>8} | {'W/s':>8} | {'RTF':>8}\")\n", + "print(\"-\" * 72)\n", + "for r in results:\n", + " print(f\"{r['bs']:>4} | {r['mean']:>9.2f} | {r['p50']:>8.2f} | {r['p95']:>8.2f} | \"\n", + " f\"{r['p99']:>8.2f} | {r['throughput']:>8.1f} | {r['rtf']:>7.0f}x\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 3b: DataParallel Benchmark (Full Instance)\n", + "\n", + "`torch_neuronx.DataParallel` loads the same compiled model onto each available NeuronCore and runs\n", + "them in parallel. Each core processes its own full batch independently, multiplying total throughput\n", + "by the number of cores.\n", + "\n", + "On inf2.xlarge (2 NeuronCores), this gives each core its own batch, doubling throughput.\n", + "\n", + "**Important**: `DataParallel` defaults to `num_workers=2` threads. This works for 2-core instances,\n", + "but on instances with more cores, set `model_dp.num_workers = num_cores` for full parallelism." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-05T02:28:14.414120Z", + "iopub.status.busy": "2026-03-05T02:28:14.413937Z", + "iopub.status.idle": "2026-03-05T02:28:23.727898Z", + "shell.execute_reply": "2026-03-05T02:28:23.727089Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neuron cores detected: 2\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DataParallel input shape: torch.Size([4, 112000]) (2 per core x 2 cores)\n", + "\n", + "Warming up DataParallel...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Benchmarking DataParallel (100 iterations)...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "=== DataParallel Results (2 cores) ===\n", + "Latency mean: 22.84 ms p95: 22.94 ms\n", + "Throughput: 175.2 windows/sec (1226x real-time)\n", + "\n", + "=== Comparison ===\n", + "Single core (BS=2): 94.0 W/s (658x RT)\n", + "DataParallel (2 cores): 175.2 W/s (1226x RT)\n", + "Speedup: 1.86x\n" + ] + } + ], + "source": [ + "import json as _json\n", + "\n", + "def get_neuron_core_count():\n", + " \"\"\"Detect the number of Neuron cores using neuron-ls.\"\"\"\n", + " try:\n", + " result = subprocess.run(\n", + " [\"neuron-ls\", \"--json-output\"],\n", + " capture_output=True, text=True, timeout=10\n", + " )\n", + " if result.returncode == 0:\n", + " devices = _json.loads(result.stdout)\n", + " return sum(d[\"nc_count\"] for d in devices)\n", + " except Exception:\n", + " pass\n", + " return 1\n", + "\n", + "num_cores = get_neuron_core_count()\n", + "print(f\"Neuron cores detected: {num_cores}\")\n", + "\n", + "# Use the best single-core batch size for DataParallel\n", + "DP_BATCH_SIZE = 2\n", + "\n", + "if num_cores > 1:\n", + " # Load the compiled model onto all cores\n", + " model_dp = torch_neuronx.DataParallel(compiled_models[DP_BATCH_SIZE])\n", + " model_dp.num_workers = num_cores\n", + "\n", + " # Each core gets DP_BATCH_SIZE, so total input is DP_BATCH_SIZE * num_cores\n", + " dp_total_batch = DP_BATCH_SIZE * num_cores\n", + " dp_input = torch.randn(dp_total_batch, INPUT_SAMPLES)\n", + " print(f\"DataParallel input shape: {dp_input.shape} ({DP_BATCH_SIZE} per core x {num_cores} cores)\")\n", + "\n", + " print(f\"\\nWarming up DataParallel...\")\n", + " for _ in range(WARMUP):\n", + " model_dp(dp_input)\n", + "\n", + " print(f\"Benchmarking DataParallel ({ITERATIONS} iterations)...\")\n", + " dp_latencies = []\n", + " for _ in range(ITERATIONS):\n", + " t0 = time.time()\n", + " model_dp(dp_input)\n", + " dp_latencies.append((time.time() - t0) * 1000)\n", + "\n", + " dp_lat = np.array(dp_latencies)\n", + " dp_throughput = dp_total_batch / (dp_lat.mean() / 1000)\n", + " dp_rtf = dp_throughput * INPUT_SEC\n", + "\n", + " # Find single-core results for comparison\n", + " single_result = [r for r in results if r['bs'] == DP_BATCH_SIZE][0]\n", + "\n", + " print(f\"\\n=== DataParallel Results ({num_cores} cores) ===\")\n", + " print(f\"Latency mean: {dp_lat.mean():.2f} ms p95: {np.percentile(dp_lat, 95):.2f} ms\")\n", + " print(f\"Throughput: {dp_throughput:.1f} windows/sec ({dp_rtf:.0f}x real-time)\")\n", + " print(f\"\\n=== Comparison ===\")\n", + " print(f\"Single core (BS={DP_BATCH_SIZE}): {single_result['throughput']:.1f} W/s ({single_result['rtf']:.0f}x RT)\")\n", + " print(f\"DataParallel ({num_cores} cores): {dp_throughput:.1f} W/s ({dp_rtf:.0f}x RT)\")\n", + " print(f\"Speedup: {dp_throughput / single_result['throughput']:.2f}x\")\n", + "else:\n", + " print(\"Only 1 Neuron core available - skipping DataParallel benchmark.\")\n", + " print(\"Use an instance with multiple cores (e.g., inf2.xlarge has 2) for DataParallel.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 4: Accuracy Validation\n", + "\n", + "Compare the Neuron-compiled model output against a CPU reference to verify that `--auto-cast matmult` does not degrade laughter detection accuracy.\n", + "\n", + "We generate random audio inputs of varying characteristics and measure:\n", + "- **Cosine similarity** between raw logit vectors\n", + "- **Frame-level prediction agreement** (after sigmoid + 0.5 threshold)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-05T02:28:23.729588Z", + "iopub.status.busy": "2026-03-05T02:28:23.729413Z", + "iopub.status.idle": "2026-03-05T02:28:32.389088Z", + "shell.execute_reply": "2026-03-05T02:28:32.388396Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Input | Cosine Sim | Max Error | Mean Error | Agreement | CPU Laugh | NRN Laugh\n", + "--------------------------------------------------------------------------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " random_normal | 1.000000 | 0.0195 | 0.0017 | 100.00% | 0 | 0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " quiet_noise | 1.000000 | 0.0093 | 0.0010 | 100.00% | 0 | 0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " loud_signal | 1.000000 | 0.0100 | 0.0047 | 100.00% | 0 | 0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " sine_440hz | 1.000000 | 0.0178 | 0.0030 | 100.00% | 0 | 0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " silence | 0.999999 | 0.0288 | 0.0030 | 100.00% | 0 | 0\n" + ] + } + ], + "source": [ + "neuron_model = compiled_models[1] # BS=1 for validation\n", + "\n", + "# Generate diverse test inputs\n", + "torch.manual_seed(42)\n", + "test_inputs = {\n", + " \"random_normal\": torch.randn(1, INPUT_SAMPLES),\n", + " \"quiet_noise\": torch.randn(1, INPUT_SAMPLES) * 0.01,\n", + " \"loud_signal\": torch.randn(1, INPUT_SAMPLES) * 5.0,\n", + " \"sine_440hz\": torch.sin(2 * torch.pi * 440 * torch.linspace(0, INPUT_SEC, INPUT_SAMPLES)).unsqueeze(0),\n", + " \"silence\": torch.zeros(1, INPUT_SAMPLES),\n", + "}\n", + "\n", + "print(f\"{'Input':>16} | {'Cosine Sim':>10} | {'Max Error':>10} | {'Mean Error':>10} | {'Agreement':>10} | {'CPU Laugh':>10} | {'NRN Laugh':>10}\")\n", + "print(\"-\" * 98)\n", + "\n", + "for name, x in test_inputs.items():\n", + " # CPU reference\n", + " with torch.no_grad():\n", + " cpu_out = model(input_values=x)\n", + " cpu_logits = cpu_out.logits.squeeze(-1) # [1, 349]\n", + "\n", + " # Neuron\n", + " neuron_out = neuron_model(x)\n", + " if isinstance(neuron_out, dict):\n", + " neuron_logits = neuron_out[\"logits\"].squeeze(-1)\n", + " elif isinstance(neuron_out, (tuple, list)):\n", + " neuron_logits = neuron_out[0].squeeze(-1)\n", + " else:\n", + " neuron_logits = neuron_out.squeeze(-1)\n", + "\n", + " cpu_flat = cpu_logits.flatten().float()\n", + " nrn_flat = neuron_logits.flatten().float()\n", + "\n", + " cosine = torch.nn.functional.cosine_similarity(cpu_flat.unsqueeze(0), nrn_flat.unsqueeze(0)).item()\n", + " max_err = (cpu_flat - nrn_flat).abs().max().item()\n", + " mean_err = (cpu_flat - nrn_flat).abs().mean().item()\n", + "\n", + " cpu_preds = (torch.sigmoid(cpu_logits) >= 0.5).int()\n", + " nrn_preds = (torch.sigmoid(neuron_logits) >= 0.5).int()\n", + " agree = (cpu_preds == nrn_preds).float().mean().item()\n", + "\n", + " print(f\"{name:>16} | {cosine:>10.6f} | {max_err:>10.4f} | {mean_err:>10.4f} | \"\n", + " f\"{agree * 100:>9.2f}% | {cpu_preds.sum().item():>10} | {nrn_preds.sum().item():>10}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Step 5: End-to-End Inference Demo\n", + "\n", + "Run the full laughter detection pipeline on a sample audio file:\n", + "1. Load audio at 16 kHz\n", + "2. Slice into 7-second windows with 2-second overlap\n", + "3. Batch inference through the Neuron model\n", + "4. Sigmoid + threshold + merge overlapping detections\n", + "5. Output JSON with laughter timestamps\n", + "\n", + "We use a librosa built-in audio sample for the demo." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "execution": { + "iopub.execute_input": "2026-03-05T02:28:32.391051Z", + "iopub.status.busy": "2026-03-05T02:28:32.390869Z", + "iopub.status.idle": "2026-03-05T02:29:07.638860Z", + "shell.execute_reply": "2026-03-05T02:29:07.638004Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading file 'sorohanro_-_solo-trumpet-06.ogg' from 'https://librosa.org/data/audio/sorohanro_-_solo-trumpet-06.ogg' to '/home/ubuntu/.cache/librosa'.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Audio: 5.3s, 85335 samples\n", + "\n", + "Detected 1 laughter segment(s) in 0.025s (217x real-time):\n", + " [0] 0.28s - 1.58s (1.30s)\n", + "\n", + "Saved to ./output/laughter_results.json\n" + ] + } + ], + "source": [ + "import json\n", + "import librosa\n", + "\n", + "# Load a sample audio file (librosa built-in)\n", + "audio_path = librosa.example(\"trumpet\")\n", + "audio, sr = librosa.load(audio_path, sr=SAMPLE_RATE, mono=True)\n", + "duration = len(audio) / SAMPLE_RATE\n", + "print(f\"Audio: {duration:.1f}s, {len(audio)} samples\")\n", + "\n", + "# Windowed inference\n", + "OVERLAP_SEC = 2.0\n", + "stride = int(SAMPLE_RATE * (INPUT_SEC - OVERLAP_SEC))\n", + "neuron_bs2 = compiled_models[2] # Use BS=2 for throughput\n", + "batch_size = 2\n", + "\n", + "all_preds = []\n", + "laughter_events = {}\n", + "event_idx = 0\n", + "t0 = time.time()\n", + "\n", + "with torch.no_grad():\n", + " for start in range(0, len(audio), stride * batch_size):\n", + " windows = []\n", + " actual = 0\n", + " for i in range(batch_size):\n", + " w_start = start + i * stride\n", + " w = audio[w_start:w_start + INPUT_SAMPLES]\n", + " if len(w) == 0:\n", + " break\n", + " if len(w) < INPUT_SAMPLES:\n", + " w = np.append(w, np.zeros(INPUT_SAMPLES - len(w)))\n", + " windows.append(w)\n", + " actual += 1\n", + "\n", + " if actual == 0:\n", + " break\n", + "\n", + " # Pad to compiled batch size\n", + " while len(windows) < batch_size:\n", + " windows.append(np.zeros(INPUT_SAMPLES))\n", + "\n", + " x = torch.from_numpy(np.array(windows)).float()\n", + " out = neuron_bs2(x)\n", + " logits = (out[\"logits\"] if isinstance(out, dict) else out).squeeze(-1)\n", + " preds = (torch.sigmoid(logits[:actual].float()) >= 0.5).int()\n", + "\n", + " for bi in range(actual):\n", + " w_start_sec = (start + bi * stride) / SAMPLE_RATE\n", + " frame_pred = preds[bi].cpu().numpy()\n", + " frame_count = len(frame_pred)\n", + " in_laugh = False\n", + " seg_start = None\n", + "\n", + " for fi, f in enumerate(frame_pred):\n", + " if f == 1 and not in_laugh:\n", + " seg_start = fi\n", + " in_laugh = True\n", + " elif (f == 0 or fi == frame_count - 1) and in_laugh:\n", + " seg_end = fi if f == 0 else fi + 1\n", + " laughter_events[str(event_idx)] = {\n", + " \"start_sec\": round(w_start_sec + (INPUT_SEC / frame_count) * seg_start, 3),\n", + " \"end_sec\": round(w_start_sec + (INPUT_SEC / frame_count) * seg_end, 3),\n", + " }\n", + " event_idx += 1\n", + " in_laugh = False\n", + "\n", + "inference_time = time.time() - t0\n", + "\n", + "# Merge overlapping events from overlapping windows\n", + "merged = {}\n", + "for evt in sorted(laughter_events.values(), key=lambda e: e[\"start_sec\"]):\n", + " if not merged:\n", + " merged[\"0\"] = evt.copy()\n", + " else:\n", + " last = list(merged.values())[-1]\n", + " if evt[\"start_sec\"] <= last[\"end_sec\"]:\n", + " last[\"end_sec\"] = max(last[\"end_sec\"], evt[\"end_sec\"])\n", + " else:\n", + " merged[str(len(merged))] = evt.copy()\n", + "\n", + "# Remove short events (<0.2s)\n", + "merged = {k: v for k, v in merged.items() if v[\"end_sec\"] - v[\"start_sec\"] >= 0.2}\n", + "merged = {str(i): v for i, v in enumerate(merged.values())}\n", + "\n", + "print(f\"\\nDetected {len(merged)} laughter segment(s) in {inference_time:.3f}s ({duration / inference_time:.0f}x real-time):\")\n", + "for idx, seg in merged.items():\n", + " print(f\" [{idx}] {seg['start_sec']:.2f}s - {seg['end_sec']:.2f}s ({seg['end_sec'] - seg['start_sec']:.2f}s)\")\n", + "\n", + "if not merged:\n", + " print(\" (no laughter detected)\")\n", + "\n", + "# Save results\n", + "os.makedirs(\"./output\", exist_ok=True)\n", + "with open(\"./output/laughter_results.json\", \"w\") as f:\n", + " json.dump(merged, f, indent=2)\n", + "print(f\"\\nSaved to ./output/laughter_results.json\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Summary\n", + "\n", + "This notebook demonstrated:\n", + "\n", + "1. **Tracing** Wav2Vec2ForAudioFrameClassification on Neuron with `torch_neuronx.trace()`\n", + "2. **Single-core benchmarking** across batch sizes -- throughput saturates around BS=2-4 on inf2\n", + "3. **DataParallel benchmarking** -- `torch_neuronx.DataParallel` for full-instance throughput across all NeuronCores\n", + "4. **Accuracy validation** -- cosine similarity ~1.0, prediction agreement 100% with `--auto-cast matmult`\n", + "5. **End-to-end inference** -- audio file to laughter timestamps at 100x+ real-time\n", + "\n", + "### Key Findings\n", + "\n", + "- **weight_norm fix required**: `parametrize.remove_parametrizations()` must be called before tracing any model with weight_norm on SDK 2.28+\n", + "- **`--auto-cast matmult` is safe**: negligible accuracy impact for this binary classification task\n", + "- **`inline_weights_to_neff=True`** recommended for single-model deployment\n", + "\n", + "### trn2 Notes\n", + "\n", + "To run on trn2.3xlarge, add `\"--lnc\", \"1\"` or `\"--lnc\", \"2\"` to `compiler_args` in Step 2:\n", + "- **LNC=1**: Model uses 1 NeuronCore. May be more efficient for smaller models.\n", + "- **LNC=2**: Model spans 2 NeuronCores. Better per-model latency, may not give maximum throughput." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": { + "156056ebc0584c97afa8fe353d5b75b7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1f3274edf018499d8aec6f3fd77e5642": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_2fe64674668c4812833bc630fc743daf", + "placeholder": "​", + "style": "IPY_MODEL_9c30b64321ea4989aeb3d00fcda9c043", + "tabbable": null, + "tooltip": null, + "value": " 1.53k/? [00:00<00:00, 241kB/s]" + } + }, + "2fe64674668c4812833bc630fc743daf": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7a6ec080891b4b598249abfd0d3f0e3c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9c30b64321ea4989aeb3d00fcda9c043": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "c40efdc4e9604302b3065cd1e547f742": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "20px" + } + }, + "d6c17942903947b49d082cc837397f0d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_7a6ec080891b4b598249abfd0d3f0e3c", + "placeholder": "​", + "style": "IPY_MODEL_ede6901c94884d1da6ac2c33c99dce7f", + "tabbable": null, + "tooltip": null, + "value": "config.json: " + } + }, + "e4846163114e4ebf9de6f2b5538b0dd1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "e926a30f8e7c47439c226e6b9a85580e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_c40efdc4e9604302b3065cd1e547f742", + "max": 1, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_e4846163114e4ebf9de6f2b5538b0dd1", + "tabbable": null, + "tooltip": null, + "value": 1 + } + }, + "ede6901c94884d1da6ac2c33c99dce7f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "effd7b52b6dd4a0f8253a52b2f284643": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_d6c17942903947b49d082cc837397f0d", + "IPY_MODEL_e926a30f8e7c47439c226e6b9a85580e", + "IPY_MODEL_1f3274edf018499d8aec6f3fd77e5642" + ], + "layout": "IPY_MODEL_156056ebc0584c97afa8fe353d5b75b7", + "tabbable": null, + "tooltip": null + } + } + }, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 5af904306c04b44e9dff727b463d6dad497f2d55 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Thu, 5 Mar 2026 18:53:45 -0500 Subject: [PATCH 2/3] Add integration tests for LaughterSegmentation 12/12 tests pass on inf2.xlarge (SDK 2.28, PyTorch 2.9): - Smoke tests (model loads and runs) - Accuracy: cosine similarity >= 0.999, 100% frame agreement - DataParallel: 1.88x speedup on 2 cores (176.5 W/s) - Performance: 101.5 W/s throughput, 9.85 ms p50 latency --- .../LaughterSegmentation/test/__init__.py | 0 .../test/integration/__init__.py | 0 .../test/integration/test_model.py | 496 ++++++++++++++++++ .../test/unit/__init__.py | 0 4 files changed, 496 insertions(+) create mode 100644 contrib/models/LaughterSegmentation/test/__init__.py create mode 100644 contrib/models/LaughterSegmentation/test/integration/__init__.py create mode 100644 contrib/models/LaughterSegmentation/test/integration/test_model.py create mode 100644 contrib/models/LaughterSegmentation/test/unit/__init__.py diff --git a/contrib/models/LaughterSegmentation/test/__init__.py b/contrib/models/LaughterSegmentation/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/LaughterSegmentation/test/integration/__init__.py b/contrib/models/LaughterSegmentation/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/LaughterSegmentation/test/integration/test_model.py b/contrib/models/LaughterSegmentation/test/integration/test_model.py new file mode 100644 index 00000000..6306c37b --- /dev/null +++ b/contrib/models/LaughterSegmentation/test/integration/test_model.py @@ -0,0 +1,496 @@ +""" +Integration tests for LaughterSegmentation on Neuron. + +Tests compile and run the Wav2Vec2-based laughter detection model using +torch_neuronx.trace() on Inferentia2 / Trainium2. Validates accuracy by +comparing Neuron output against CPU reference using cosine similarity. + +Usage: + # Run with pytest + pytest test_model.py --capture=tee-sys -v + + # Run standalone + python test_model.py + +Prerequisites: + source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate + pip install safetensors +""" + +import json +import os +import subprocess +import time +from pathlib import Path + +import pytest +import torch +import torch.nn.utils.parametrize as parametrize + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +MODEL_ID = "omine-me/LaughterSegmentation" +BASE_MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-english" +MODEL_DIR = "/home/ubuntu/models/LaughterSegmentation" +COMPILED_DIR = "/home/ubuntu/neuron_models/LaughterSegmentation" + +INPUT_SEC = 7 +SAMPLE_RATE = 16000 +INPUT_SAMPLES = INPUT_SEC * SAMPLE_RATE # 112,000 +BATCH_SIZE = 1 + +COMPILER_ARGS = [ + "--model-type", + "transformer", + "--optlevel", + "2", + "--auto-cast", + "matmult", +] + +# Accuracy thresholds +COSINE_SIM_THRESHOLD = 0.999 +FRAME_AGREEMENT_THRESHOLD = 0.99 + +# Performance thresholds +THROUGHPUT_THRESHOLD = 40.0 # windows/sec minimum on single core +LATENCY_THRESHOLD_MS = 50.0 # max p50 latency at BS=1 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def download_weights(): + """Download model weights from HuggingFace if not present.""" + weights_path = os.path.join(MODEL_DIR, "model.safetensors") + if not os.path.exists(weights_path): + os.makedirs(MODEL_DIR, exist_ok=True) + subprocess.check_call( + ["hf", "download", MODEL_ID, "model.safetensors", "--local-dir", MODEL_DIR] + ) + return weights_path + + +def load_cpu_model(): + """Load the LaughterSegmentation model on CPU with parametrizations removed.""" + import safetensors.torch + from transformers import Wav2Vec2Config, Wav2Vec2ForAudioFrameClassification + + weights_path = download_weights() + + config = Wav2Vec2Config.from_pretrained(BASE_MODEL_ID) + config.num_labels = 1 + config.problem_type = "single_label_classification" + model = Wav2Vec2ForAudioFrameClassification(config) + + # Load weights, stripping "audio_model." prefix + state_dict = safetensors.torch.load_file(weights_path, device="cpu") + prefix = "audio_model." + stripped = { + (k[len(prefix) :] if k.startswith(prefix) else k): v + for k, v in state_dict.items() + } + model.load_state_dict(stripped) + model.eval() + + # Remove weight_norm parametrizations (required for SDK 2.28+) + for name, module in model.named_modules(): + if hasattr(module, "parametrizations"): + for param_name in list(module.parametrizations.keys()): + parametrize.remove_parametrizations(module, param_name) + + return model + + +def compile_neuron_model(cpu_model, batch_size=BATCH_SIZE): + """Compile the model for Neuron and cache the result.""" + import torch_neuronx + + save_path = os.path.join(COMPILED_DIR, f"laughter_bs{batch_size}.pt") + + if os.path.exists(save_path): + return torch.jit.load(save_path) + + os.makedirs(COMPILED_DIR, exist_ok=True) + example_input = torch.randn(batch_size, INPUT_SAMPLES) + + model_neuron = torch_neuronx.trace( + cpu_model, + example_input, + compiler_args=COMPILER_ARGS, + inline_weights_to_neff=True, + ) + + torch.jit.save(model_neuron, save_path) + return model_neuron + + +def get_neuron_core_count(): + """Detect available NeuronCores via neuron-ls.""" + try: + result = subprocess.run( + ["neuron-ls", "--json-output"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + devices = json.loads(result.stdout) + return sum(d["nc_count"] for d in devices) + except Exception: + pass + return 1 + + +# --------------------------------------------------------------------------- +# Fixtures (module-scoped so compile happens once) +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def cpu_model(): + """Load the CPU reference model.""" + return load_cpu_model() + + +@pytest.fixture(scope="module") +def neuron_model(cpu_model): + """Compile and load the Neuron model (BS=1).""" + return compile_neuron_model(cpu_model, batch_size=BATCH_SIZE) + + +@pytest.fixture(scope="module") +def neuron_model_bs2(cpu_model): + """Compile and load the Neuron model (BS=2) for DataParallel.""" + return compile_neuron_model(cpu_model, batch_size=2) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestModelLoads: + """Smoke tests that the model compiles and loads on Neuron.""" + + def test_neuron_model_loads(self, neuron_model): + """Model compiles and loads successfully.""" + assert neuron_model is not None + + def test_neuron_model_runs(self, neuron_model): + """Model produces output of expected shape.""" + x = torch.randn(BATCH_SIZE, INPUT_SAMPLES) + out = neuron_model(x) + # Output is a tensor of shape [batch, 349, 1] or [batch, 349] + if isinstance(out, dict): + logits = out["logits"] + elif isinstance(out, (tuple, list)): + logits = out[0] + else: + logits = out + assert logits.shape[0] == BATCH_SIZE + assert logits.shape[1] == 349 # expected frame count for 7s at 16kHz + + +class TestAccuracy: + """Validate Neuron output against CPU reference.""" + + @pytest.mark.parametrize( + "input_name,input_fn", + [ + ("random_normal", lambda: torch.randn(1, INPUT_SAMPLES)), + ("quiet_noise", lambda: torch.randn(1, INPUT_SAMPLES) * 0.01), + ("loud_signal", lambda: torch.randn(1, INPUT_SAMPLES) * 5.0), + ( + "sine_440hz", + lambda: torch.sin( + 2 * torch.pi * 440 * torch.linspace(0, INPUT_SEC, INPUT_SAMPLES) + ).unsqueeze(0), + ), + ("silence", lambda: torch.zeros(1, INPUT_SAMPLES)), + ], + ) + def test_cosine_similarity(self, cpu_model, neuron_model, input_name, input_fn): + """Cosine similarity between CPU and Neuron logits exceeds threshold.""" + torch.manual_seed(42) + x = input_fn() + + # CPU reference + with torch.no_grad(): + cpu_out = cpu_model(input_values=x) + cpu_logits = cpu_out.logits.squeeze(-1).flatten().float() + + # Neuron + neuron_out = neuron_model(x) + if isinstance(neuron_out, dict): + neuron_logits = neuron_out["logits"] + elif isinstance(neuron_out, (tuple, list)): + neuron_logits = neuron_out[0] + else: + neuron_logits = neuron_out + neuron_logits = neuron_logits.squeeze(-1).flatten().float() + + cosine = torch.nn.functional.cosine_similarity( + cpu_logits.unsqueeze(0), neuron_logits.unsqueeze(0) + ).item() + + print(f" {input_name}: cosine_sim={cosine:.6f}") + assert cosine >= COSINE_SIM_THRESHOLD, ( + f"Cosine similarity {cosine:.6f} below threshold {COSINE_SIM_THRESHOLD}" + ) + + def test_frame_agreement(self, cpu_model, neuron_model): + """Frame-level prediction agreement on random input.""" + torch.manual_seed(42) + x = torch.randn(1, INPUT_SAMPLES) + + with torch.no_grad(): + cpu_out = cpu_model(input_values=x) + cpu_preds = (torch.sigmoid(cpu_out.logits.squeeze(-1)) >= 0.5).int() + + neuron_out = neuron_model(x) + if isinstance(neuron_out, dict): + neuron_logits = neuron_out["logits"] + elif isinstance(neuron_out, (tuple, list)): + neuron_logits = neuron_out[0] + else: + neuron_logits = neuron_out + neuron_preds = (torch.sigmoid(neuron_logits.squeeze(-1)) >= 0.5).int() + + agreement = (cpu_preds == neuron_preds).float().mean().item() + print(f" Frame agreement: {agreement * 100:.2f}%") + assert agreement >= FRAME_AGREEMENT_THRESHOLD, ( + f"Frame agreement {agreement:.4f} below threshold {FRAME_AGREEMENT_THRESHOLD}" + ) + + +class TestDataParallel: + """Test DataParallel for full-instance throughput.""" + + def test_data_parallel_runs(self, neuron_model_bs2): + """DataParallel loads and produces correct output shape.""" + import torch_neuronx + + num_cores = get_neuron_core_count() + if num_cores < 2: + pytest.skip("Only 1 NeuronCore available, skipping DataParallel test") + + model_dp = torch_neuronx.DataParallel(neuron_model_bs2) + model_dp.num_workers = num_cores + + dp_total_batch = 2 * num_cores + x = torch.randn(dp_total_batch, INPUT_SAMPLES) + + out = model_dp(x) + if isinstance(out, dict): + logits = out["logits"] + elif isinstance(out, (tuple, list)): + logits = out[0] + else: + logits = out + + assert logits.shape[0] == dp_total_batch + assert logits.shape[1] == 349 + print(f" DataParallel OK: {num_cores} cores, output shape {logits.shape}") + + def test_data_parallel_speedup(self, neuron_model_bs2): + """DataParallel achieves meaningful speedup over single core.""" + import torch_neuronx + import numpy as np + + num_cores = get_neuron_core_count() + if num_cores < 2: + pytest.skip("Only 1 NeuronCore available, skipping DataParallel test") + + # Single core baseline + x_single = torch.randn(2, INPUT_SAMPLES) + for _ in range(5): + neuron_model_bs2(x_single) + single_times = [] + for _ in range(20): + t0 = time.time() + neuron_model_bs2(x_single) + single_times.append(time.time() - t0) + single_throughput = 2 / np.mean(single_times) + + # DataParallel + model_dp = torch_neuronx.DataParallel(neuron_model_bs2) + model_dp.num_workers = num_cores + dp_total_batch = 2 * num_cores + x_dp = torch.randn(dp_total_batch, INPUT_SAMPLES) + for _ in range(5): + model_dp(x_dp) + dp_times = [] + for _ in range(20): + t0 = time.time() + model_dp(x_dp) + dp_times.append(time.time() - t0) + dp_throughput = dp_total_batch / np.mean(dp_times) + + speedup = dp_throughput / single_throughput + print( + f" Single: {single_throughput:.1f} W/s, DP: {dp_throughput:.1f} W/s, Speedup: {speedup:.2f}x" + ) + assert speedup > 1.3, ( + f"DataParallel speedup {speedup:.2f}x too low (expected >1.3x)" + ) + + +class TestPerformance: + """Benchmark throughput and latency.""" + + def test_throughput(self, neuron_model): + """Single-core throughput exceeds minimum threshold.""" + import numpy as np + + x = torch.randn(BATCH_SIZE, INPUT_SAMPLES) + + # Warmup + for _ in range(10): + neuron_model(x) + + # Benchmark + latencies = [] + for _ in range(50): + t0 = time.time() + neuron_model(x) + latencies.append((time.time() - t0) * 1000) + + lat = np.array(latencies) + throughput = BATCH_SIZE / (lat.mean() / 1000) + p50 = np.percentile(lat, 50) + + print(f" Throughput: {throughput:.1f} W/s, p50: {p50:.2f} ms") + assert throughput >= THROUGHPUT_THRESHOLD, ( + f"Throughput {throughput:.1f} below threshold {THROUGHPUT_THRESHOLD}" + ) + + def test_latency(self, neuron_model): + """p50 latency is below threshold at BS=1.""" + import numpy as np + + x = torch.randn(BATCH_SIZE, INPUT_SAMPLES) + + for _ in range(10): + neuron_model(x) + + latencies = [] + for _ in range(50): + t0 = time.time() + neuron_model(x) + latencies.append((time.time() - t0) * 1000) + + p50 = np.percentile(latencies, 50) + print(f" p50 latency: {p50:.2f} ms") + assert p50 <= LATENCY_THRESHOLD_MS, ( + f"p50 latency {p50:.2f} ms exceeds threshold {LATENCY_THRESHOLD_MS} ms" + ) + + +# --------------------------------------------------------------------------- +# Standalone runner +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + print("=" * 60) + print("LaughterSegmentation Integration Tests") + print("=" * 60) + + print("\n[1/7] Loading CPU model...") + model_cpu = load_cpu_model() + print( + f" Model loaded: {sum(p.numel() for p in model_cpu.parameters()) / 1e6:.1f}M params" + ) + + print("\n[2/7] Compiling for Neuron (BS=1)...") + model_neuron = compile_neuron_model(model_cpu, batch_size=1) + print(" Compile OK") + + print("\n[3/7] Testing model runs...") + x = torch.randn(1, INPUT_SAMPLES) + out = model_neuron(x) + logits = ( + out[0] + if isinstance(out, (tuple, list)) + else (out["logits"] if isinstance(out, dict) else out) + ) + print(f" Output shape: {logits.shape}") + + print("\n[4/7] Accuracy validation (cosine similarity)...") + torch.manual_seed(42) + test_inputs = { + "random_normal": torch.randn(1, INPUT_SAMPLES), + "quiet_noise": torch.randn(1, INPUT_SAMPLES) * 0.01, + "loud_signal": torch.randn(1, INPUT_SAMPLES) * 5.0, + "silence": torch.zeros(1, INPUT_SAMPLES), + } + all_pass = True + for name, inp in test_inputs.items(): + with torch.no_grad(): + cpu_out = model_cpu(input_values=inp) + cpu_logits = cpu_out.logits.squeeze(-1).flatten().float() + nrn_out = model_neuron(inp) + nrn_logits = ( + nrn_out[0] + if isinstance(nrn_out, (tuple, list)) + else (nrn_out["logits"] if isinstance(nrn_out, dict) else nrn_out) + ) + nrn_logits = nrn_logits.squeeze(-1).flatten().float() + cosine = torch.nn.functional.cosine_similarity( + cpu_logits.unsqueeze(0), nrn_logits.unsqueeze(0) + ).item() + status = "PASS" if cosine >= COSINE_SIM_THRESHOLD else "FAIL" + if cosine < COSINE_SIM_THRESHOLD: + all_pass = False + print(f" {name}: cosine={cosine:.6f} [{status}]") + print(f" Overall: {'PASS' if all_pass else 'FAIL'}") + + print("\n[5/7] Performance benchmark (BS=1)...") + import numpy as np + + x = torch.randn(1, INPUT_SAMPLES) + for _ in range(10): + model_neuron(x) + latencies = [] + for _ in range(50): + t0 = time.time() + model_neuron(x) + latencies.append((time.time() - t0) * 1000) + lat = np.array(latencies) + print(f" Throughput: {1 / (lat.mean() / 1000):.1f} W/s") + print( + f" p50: {np.percentile(lat, 50):.2f} ms, p99: {np.percentile(lat, 99):.2f} ms" + ) + + print("\n[6/7] Compiling BS=2 for DataParallel...") + model_bs2 = compile_neuron_model(model_cpu, batch_size=2) + print(" Compile OK") + + print("\n[7/7] DataParallel test...") + num_cores = get_neuron_core_count() + if num_cores >= 2: + import torch_neuronx + + model_dp = torch_neuronx.DataParallel(model_bs2) + model_dp.num_workers = num_cores + dp_batch = 2 * num_cores + x_dp = torch.randn(dp_batch, INPUT_SAMPLES) + for _ in range(5): + model_dp(x_dp) + dp_times = [] + for _ in range(20): + t0 = time.time() + model_dp(x_dp) + dp_times.append(time.time() - t0) + dp_throughput = dp_batch / np.mean(dp_times) + print(f" DataParallel ({num_cores} cores): {dp_throughput:.1f} W/s") + else: + print(" Skipped: only 1 NeuronCore") + + print("\n" + "=" * 60) + print("All tests complete.") + print("=" * 60) diff --git a/contrib/models/LaughterSegmentation/test/unit/__init__.py b/contrib/models/LaughterSegmentation/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b From 923fe6cfb8e2ffd5d43b03bd073ffa22c5253cea Mon Sep 17 00:00:00 2001 From: Jim Burtoft <39492751+jimburtoft@users.noreply.github.com> Date: Thu, 5 Mar 2026 20:06:22 -0500 Subject: [PATCH 3/3] Update README with maintainer details Added maintainer information for Jim Burtoft. --- contrib/models/LaughterSegmentation/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/contrib/models/LaughterSegmentation/README.md b/contrib/models/LaughterSegmentation/README.md index c0423e7d..99f904bf 100644 --- a/contrib/models/LaughterSegmentation/README.md +++ b/contrib/models/LaughterSegmentation/README.md @@ -97,6 +97,7 @@ To run on trn2.3xlarge, uncomment the `--lnc` compiler arg in the compilation ce ## Maintainer +Jim Burtoft Community contribution **Last Updated:** 2026-03-04