diff --git a/records/track_10min_16mb/2026-03-19_QAT_Ablation/README.md b/records/track_10min_16mb/2026-03-19_QAT_Ablation/README.md new file mode 100644 index 000000000..0faf735a2 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_QAT_Ablation/README.md @@ -0,0 +1,113 @@ +# 2026-03-19_QAT_Ablation + +*Non-record: Does int8 quantization-aware training improve post-roundtrip val_bpb?* + +**Answer: No — the overhead costs more than it recovers.** + +--- + +## Question + +The baseline loses ~0.007 BPB in the int8+zlib export step because bf16-trained weights are rounded cold onto the int8 grid. Every leaderboard entry so far attacks this gap indirectly — aggressive warmdown for tighter weight distributions, FP16 embedding bypass, or alternative quantization formats (int6). Nobody has trained directly against the int8 quantization grid. + +This submission tests whether QAT (straight-through fake-quantize matching the export pipeline exactly) recovers some of that gap. The experiment isolates QAT as the only variable — baseline architecture, baseline hyperparameters, no other changes. + +--- + +## Method + +A `fake_quantize_int8_per_row` function is inserted into `CastedLinear.forward`. It matches the export pipeline's `quantize_float_tensor` exactly: +- Same `INT8_CLIP_Q = 0.9999984` percentile clipping via `torch.quantile` +- Same per-row scale: `clip_abs / 127.0` +- Same rounding: `round().clamp(-127, 127)` +- Straight-through estimator: gradients pass through as if no quantization happened + +**Schedule:** QAT activates at 30% of training steps (~step 6,000). Training runs bf16-only before that to let the loss landscape stabilize. + +**No other changes.** Architecture is 9L×512d, all hyperparameters are baseline defaults (WARMDOWN_ITERS=1200, MATRIX_LR=0.04, etc). + +--- + +## Results + +| Metric | SlidingWindowEval (no QAT) | This run (QAT) | +|--------|---------------------------|----------------| +| Steps completed | 13,450 | 8,011 | +| step_avg | 44.6ms | 75.2ms (64.5 pre-QAT, 77+ post-QAT) | +| Pre-quant val_bpb (standard eval) | 1.2196 | 1.2327 | +| **Post-quant val_bpb (sliding window)** | **1.1925** | **1.2052** | +| Artifact bytes | 15,874,829 | 15,868,103 | +| Eval time | 70s | 75s | + +**val_bpb: 1.2052 vs 1.1925 — QAT is 0.013 worse.** + +--- + +## Why it didn't work + +The result is **not** evidence that QAT is a bad idea. It's evidence that **exact percentile-matching QAT is too expensive for int8 in this competition format.** + +### The core problem: `torch.quantile` overhead + +Matching the export pipeline exactly requires `torch.quantile(w.abs(), 0.9999984, dim=1)` on every weight matrix, every forward pass. This adds **~20% per-step overhead** (64ms → 77ms after QAT activates). Over a 600-second training budget, that costs ~2,000 training steps — roughly 1B fewer training tokens. + +The lost training tokens hurt more than the quantization gap recovery helps. The int8 quantization gap (~0.007 BPB) is smaller than the convergence loss from 40% fewer training steps. + +### Why this matters for the competition + +| Approach | Per-step cost | Quant gap reduction | Net effect | +|----------|--------------|--------------------|----| +| Aggressive warmdown (WD=20000) | 0% overhead | ~0.009 BPB | **Positive** | +| FP16 tied embedding | 0% overhead, ~500KB artifact | ~0.004 BPB | **Positive** | +| Int8 QAT (this submission) | ~20% overhead → ~2000 fewer steps | ~0.003-0.006 BPB theoretical | **Negative** (overhead > recovery) | +| Int6 QAT (PRs #128, #137) | ~20% overhead | ~0.01+ BPB (larger gap) | **Likely positive** (larger gap to close) | + +### When QAT would work + +1. **With int6 quantization** — the quantization gap is larger (~0.01+ BPB), making the overhead worthwhile. PRs #128 and #137 confirm this with val_bpb 1.1594 and 1.1666 respectively. +2. **With `amax` instead of `torch.quantile`** — near-zero overhead, but doesn't match the export pipeline exactly. The 0.0001% percentile difference may not matter in practice. +3. **With a longer training budget** — if the wallclock cap were 30 minutes instead of 10, the overhead would be amortized over more steps. + +--- + +## Graph priming finding + +An earlier version pre-primed the QAT compiled graph during warmup (running one forward/backward pass with `_qat=True`, then resetting to `_qat=False`). This caused `torch.compile` to use a slower compilation path for the non-QAT forward pass — step_avg was 65ms from step 1, even before QAT activated. Removing the graph priming restored baseline speed for the non-QAT phase. This is a useful finding for anyone implementing conditional code paths under `torch.compile(dynamic=False, fullgraph=True)`. + +--- + +## Reproduction + +```bash +cd /workspace +git clone https://github.com/mrdavtan/parameter-golf.git +cd parameter-golf && git checkout qat-sliding-window +python3 data/cached_challenge_fineweb.py --variant sp1024 + +# Set env vars +export VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 +export MLP_MULT=2 TIE_EMBEDDINGS=1 TRAIN_BATCH_TOKENS=524288 TRAIN_SEQ_LEN=1024 +export ITERATIONS=20000 WARMDOWN_ITERS=1200 WARMUP_STEPS=20 +export MAX_WALLCLOCK_SECONDS=600 TRAIN_LOG_EVERY=200 VAL_LOSS_EVERY=0 +export QAT=1 EVAL_STRIDE=64 EVAL_BATCH_SEQS=32 DOC_ISOLATED_EVAL=0 +export SEED=1337 RUN_ID=ablation_qat_slide64 + +torchrun --standalone --nproc_per_node=8 \ + records/track_10min_16mb/2026-03-19_QAT_Ablation/train_gpt.py +``` + +Hardware: 8×H100 SXM (RunPod), PyTorch 2.9.1+cu128 + +--- + +## Acknowledgments + +- `train_gpt.py` is based on the SlidingWindowEval entry (#50) by @mattqlf, which provides the sliding window evaluation infrastructure +- Analysis informed by the WarmdownQuantization entry by @samuellarson (warmdown vs QAT tradeoffs) and the LoRA TTT ablation by @samacquaviva (doc-isolated eval gains) +- Int6 QAT comparison data from PRs #128 (@rsavitt) and #137 (@abhishekgahlot2) +- Built with [Claude Code](https://claude.com/claude-code) + +## Author + +GitHub: [@mrdavtan](https://github.com/mrdavtan) +Date: 2026-03-20 diff --git a/records/track_10min_16mb/2026-03-19_QAT_Ablation/logs/ablation_qat_30pct_v3.txt b/records/track_10min_16mb/2026-03-19_QAT_Ablation/logs/ablation_qat_30pct_v3.txt new file mode 100644 index 000000000..13d4b6dd2 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_QAT_Ablation/logs/ablation_qat_30pct_v3.txt @@ -0,0 +1,170 @@ +logs/ablation_qat_30pct_v3.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +qat:True (activates at 30% of iterations = step 6000) +model_params:17059912 (unique_layers:9 loops:1 effective_depth:9 lora_rank:0 lora_params:0) +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.9370 train_time:50ms step_avg:50.13ms +step:2/20000 train_loss:16.8366 train_time:99ms step_avg:49.53ms +step:3/20000 train_loss:8.7609 train_time:155ms step_avg:51.57ms +step:4/20000 train_loss:6.6387 train_time:210ms step_avg:52.52ms +step:5/20000 train_loss:6.6117 train_time:288ms step_avg:57.54ms +step:6/20000 train_loss:7.4221 train_time:351ms step_avg:58.54ms +step:7/20000 train_loss:6.3509 train_time:413ms step_avg:58.97ms +step:8/20000 train_loss:6.1583 train_time:493ms step_avg:61.62ms +step:9/20000 train_loss:6.0680 train_time:557ms step_avg:61.92ms +step:10/20000 train_loss:5.9748 train_time:620ms step_avg:62.02ms +step:200/20000 train_loss:2.8544 train_time:12864ms step_avg:64.32ms +step:400/20000 train_loss:2.3536 train_time:25585ms step_avg:63.96ms +step:600/20000 train_loss:2.5528 train_time:38924ms step_avg:64.87ms +step:800/20000 train_loss:2.2956 train_time:52807ms step_avg:66.01ms +step:1000/20000 train_loss:2.3710 train_time:66301ms step_avg:66.30ms +step:1200/20000 train_loss:2.3861 train_time:79773ms step_avg:66.48ms +step:1400/20000 train_loss:2.4330 train_time:92897ms step_avg:66.35ms +step:1600/20000 train_loss:2.1007 train_time:105036ms step_avg:65.65ms +step:1800/20000 train_loss:2.2012 train_time:117744ms step_avg:65.41ms +step:2000/20000 train_loss:2.2521 train_time:131332ms step_avg:65.67ms +step:2200/20000 train_loss:2.0783 train_time:144313ms step_avg:65.60ms +step:2400/20000 train_loss:2.2024 train_time:157336ms step_avg:65.56ms +step:2600/20000 train_loss:2.4112 train_time:169876ms step_avg:65.34ms +step:2800/20000 train_loss:2.2358 train_time:183280ms step_avg:65.46ms +step:3000/20000 train_loss:2.2263 train_time:196108ms step_avg:65.37ms +step:3200/20000 train_loss:2.1873 train_time:209318ms step_avg:65.41ms +step:3400/20000 train_loss:2.1570 train_time:222737ms step_avg:65.51ms +step:3600/20000 train_loss:2.1152 train_time:235103ms step_avg:65.31ms +step:3800/20000 train_loss:2.2241 train_time:247545ms step_avg:65.14ms +step:4000/20000 train_loss:2.1641 train_time:259662ms step_avg:64.92ms +step:4200/20000 train_loss:2.1776 train_time:274509ms step_avg:65.36ms +step:4400/20000 train_loss:2.1126 train_time:287085ms step_avg:65.25ms +step:4600/20000 train_loss:1.9722 train_time:299308ms step_avg:65.07ms +step:4800/20000 train_loss:2.2631 train_time:311773ms step_avg:64.95ms +step:5000/20000 train_loss:2.0304 train_time:324003ms step_avg:64.80ms +step:5200/20000 train_loss:2.1743 train_time:336425ms step_avg:64.70ms +step:5400/20000 train_loss:2.1880 train_time:349255ms step_avg:64.68ms +step:5600/20000 train_loss:2.1843 train_time:362206ms step_avg:64.68ms +step:5800/20000 train_loss:2.1458 train_time:374309ms step_avg:64.54ms +qat_activated step:6000/20000 +step:6000/20000 train_loss:2.2221 train_time:386982ms step_avg:64.50ms +step:6200/20000 train_loss:2.0886 train_time:481881ms step_avg:77.72ms +step:6400/20000 train_loss:2.1616 train_time:494552ms step_avg:77.27ms +step:6600/20000 train_loss:2.1236 train_time:507596ms step_avg:76.91ms +step:6800/20000 train_loss:2.1860 train_time:520990ms step_avg:76.62ms +step:7000/20000 train_loss:2.2116 train_time:534315ms step_avg:76.33ms +step:7200/20000 train_loss:2.1757 train_time:547515ms step_avg:76.04ms +step:7400/20000 train_loss:2.0941 train_time:560062ms step_avg:75.68ms +step:7600/20000 train_loss:1.9693 train_time:572584ms step_avg:75.34ms +step:7800/20000 train_loss:2.1111 train_time:584859ms step_avg:74.98ms +step:8000/20000 train_loss:2.0699 train_time:598143ms step_avg:74.77ms +step:8011/20000 val_loss:2.0814 val_bpb:1.2327 train_time:602064ms step_avg:75.15ms +stopping_early: wallclock_cap train_time:602064ms step:8011/20000 +peak memory allocated: 10119 MiB reserved: 10424 MiB +Serialized model: 67224983 bytes +Code size: 63581 bytes +Total submission size: 67288564 bytes +Serialized model int8+zlib: 15804522 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) +Total submission size int8+zlib: 15868103 bytes +final_eval_mode:sliding_window stride:64 batch_seqs:32 doc_isolated:False +sliding_eval [ 0.0%] 32/121134 windows running_bpb=1.203131 +sliding_eval [ 1.3%] 1632/121134 windows running_bpb=1.195758 +sliding_eval [ 2.7%] 3232/121134 windows running_bpb=1.197809 +sliding_eval [ 4.0%] 4832/121134 windows running_bpb=1.192269 +sliding_eval [ 5.3%] 6432/121134 windows running_bpb=1.204374 +sliding_eval [ 6.6%] 8032/121134 windows running_bpb=1.205619 +sliding_eval [ 8.0%] 9632/121134 windows running_bpb=1.207984 +sliding_eval [ 9.3%] 11232/121134 windows running_bpb=1.203688 +sliding_eval [ 10.6%] 12832/121134 windows running_bpb=1.201027 +sliding_eval [ 11.9%] 14432/121134 windows running_bpb=1.202773 +sliding_eval [ 13.2%] 16032/121134 windows running_bpb=1.211795 +sliding_eval [ 14.6%] 17632/121134 windows running_bpb=1.210786 +sliding_eval [ 15.9%] 19232/121134 windows running_bpb=1.212110 +sliding_eval [ 17.2%] 20832/121134 windows running_bpb=1.210688 +sliding_eval [ 18.5%] 22432/121134 windows running_bpb=1.209239 +sliding_eval [ 19.8%] 24032/121134 windows running_bpb=1.209753 +sliding_eval [ 21.2%] 25632/121134 windows running_bpb=1.210995 +sliding_eval [ 22.5%] 27232/121134 windows running_bpb=1.211682 +sliding_eval [ 23.8%] 28832/121134 windows running_bpb=1.217680 +sliding_eval [ 25.1%] 30432/121134 windows running_bpb=1.214952 +sliding_eval [ 26.4%] 32032/121134 windows running_bpb=1.216245 +sliding_eval [ 27.8%] 33632/121134 windows running_bpb=1.214992 +sliding_eval [ 29.1%] 35232/121134 windows running_bpb=1.214376 +sliding_eval [ 30.4%] 36832/121134 windows running_bpb=1.214029 +sliding_eval [ 31.7%] 38432/121134 windows running_bpb=1.214833 +sliding_eval [ 33.0%] 40032/121134 windows running_bpb=1.212650 +sliding_eval [ 34.4%] 41632/121134 windows running_bpb=1.211820 +sliding_eval [ 35.7%] 43232/121134 windows running_bpb=1.212058 +sliding_eval [ 37.0%] 44832/121134 windows running_bpb=1.211127 +sliding_eval [ 38.3%] 46432/121134 windows running_bpb=1.210929 +sliding_eval [ 39.7%] 48032/121134 windows running_bpb=1.210144 +sliding_eval [ 41.0%] 49632/121134 windows running_bpb=1.211272 +sliding_eval [ 42.3%] 51232/121134 windows running_bpb=1.212395 +sliding_eval [ 43.6%] 52832/121134 windows running_bpb=1.212887 +sliding_eval [ 44.9%] 54432/121134 windows running_bpb=1.212423 +sliding_eval [ 46.3%] 56032/121134 windows running_bpb=1.212843 +sliding_eval [ 47.6%] 57632/121134 windows running_bpb=1.211920 +sliding_eval [ 48.9%] 59232/121134 windows running_bpb=1.208484 +sliding_eval [ 50.2%] 60832/121134 windows running_bpb=1.208299 +sliding_eval [ 51.5%] 62432/121134 windows running_bpb=1.209154 +sliding_eval [ 52.9%] 64032/121134 windows running_bpb=1.209189 +sliding_eval [ 54.2%] 65632/121134 windows running_bpb=1.209028 +sliding_eval [ 55.5%] 67232/121134 windows running_bpb=1.207807 +sliding_eval [ 56.8%] 68832/121134 windows running_bpb=1.207306 +sliding_eval [ 58.1%] 70432/121134 windows running_bpb=1.206638 +sliding_eval [ 59.5%] 72032/121134 windows running_bpb=1.206674 +sliding_eval [ 60.8%] 73632/121134 windows running_bpb=1.206673 +sliding_eval [ 62.1%] 75232/121134 windows running_bpb=1.206874 +sliding_eval [ 63.4%] 76832/121134 windows running_bpb=1.206570 +sliding_eval [ 64.7%] 78432/121134 windows running_bpb=1.207203 +sliding_eval [ 66.1%] 80032/121134 windows running_bpb=1.207558 +sliding_eval [ 67.4%] 81632/121134 windows running_bpb=1.207306 +sliding_eval [ 68.7%] 83232/121134 windows running_bpb=1.208383 +sliding_eval [ 70.0%] 84832/121134 windows running_bpb=1.210262 +sliding_eval [ 71.4%] 86432/121134 windows running_bpb=1.209648 +sliding_eval [ 72.7%] 88032/121134 windows running_bpb=1.210599 +sliding_eval [ 74.0%] 89632/121134 windows running_bpb=1.210889 +sliding_eval [ 75.3%] 91232/121134 windows running_bpb=1.210974 +sliding_eval [ 76.6%] 92832/121134 windows running_bpb=1.210459 +sliding_eval [ 78.0%] 94432/121134 windows running_bpb=1.210782 +sliding_eval [ 79.3%] 96032/121134 windows running_bpb=1.210236 +sliding_eval [ 80.6%] 97632/121134 windows running_bpb=1.213084 +sliding_eval [ 81.9%] 99232/121134 windows running_bpb=1.213028 +sliding_eval [ 83.2%] 100832/121134 windows running_bpb=1.213101 +sliding_eval [ 84.6%] 102432/121134 windows running_bpb=1.212770 +sliding_eval [ 85.9%] 104032/121134 windows running_bpb=1.212278 +sliding_eval [ 87.2%] 105632/121134 windows running_bpb=1.211539 +sliding_eval [ 88.5%] 107232/121134 windows running_bpb=1.211460 +sliding_eval [ 89.8%] 108832/121134 windows running_bpb=1.212012 +sliding_eval [ 91.2%] 110432/121134 windows running_bpb=1.212038 +sliding_eval [ 92.5%] 112032/121134 windows running_bpb=1.211971 +sliding_eval [ 93.8%] 113632/121134 windows running_bpb=1.212494 +sliding_eval [ 95.1%] 115232/121134 windows running_bpb=1.212217 +sliding_eval [ 96.4%] 116832/121134 windows running_bpb=1.211896 +sliding_eval [ 97.8%] 118432/121134 windows running_bpb=1.212275 +sliding_eval [ 99.1%] 120032/121134 windows running_bpb=1.212335 +final_int8_zlib_roundtrip val_loss:2.0349 val_bpb:1.2052 eval_time:74761ms +final_int8_zlib_roundtrip_exact val_loss:2.03485247 val_bpb:1.20515425 diff --git a/records/track_10min_16mb/2026-03-19_QAT_Ablation/run_ablation.sh b/records/track_10min_16mb/2026-03-19_QAT_Ablation/run_ablation.sh new file mode 100755 index 000000000..6d395c8e5 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_QAT_Ablation/run_ablation.sh @@ -0,0 +1,76 @@ +#!/usr/bin/env bash +# QAT Ablation — isolate QAT's effect on post-quantization val_bpb +# +# 4 runs, one variable (QAT on/off), two eval modes: +# 1. Baseline (no QAT, standard eval) — reproduces naive baseline +# 2. Baseline (no QAT, sliding eval) — reproduces SlidingWindowEval entry +# 3. QAT (sliding eval) — measures QAT's contribution +# 4. QAT (sliding eval, doc-isolated) — measures doc isolation on top +# +# Architecture: 9L×512d, default hyperparams throughout. +# No FP16 embed, no warmdown tuning, no leader hyperparams. +# +# Usage: +# cd /workspace/parameter-golf +# bash records/track_10min_16mb/2026-03-19_QAT_Ablation/run_ablation.sh + +set -euo pipefail + +SCRIPT="records/track_10min_16mb/2026-03-19_QAT_Ablation/train_gpt.py" + +# Baseline architecture + training — all defaults +BASE="VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4" +BASE="$BASE MLP_MULT=2 TIE_EMBEDDINGS=1 TRAIN_BATCH_TOKENS=524288 TRAIN_SEQ_LEN=1024" +BASE="$BASE ITERATIONS=20000 WARMDOWN_ITERS=1200 WARMUP_STEPS=20" +BASE="$BASE MAX_WALLCLOCK_SECONDS=600 TRAIN_LOG_EVERY=200 VAL_LOSS_EVERY=0" +BASE="$BASE NUM_LOOPS=1 LORA_RANK=0 FP16_EMBED_EXPORT=0 SEED=1337" + +echo "============================================" +echo "QAT Ablation — 4 runs, 8×H100" +echo "============================================" + +# Run 1: Baseline — no QAT, standard eval (non-overlapping) +echo "" +echo ">>> Run 1/4: Baseline (no QAT, standard eval)" +env $BASE RUN_ID=ablation_baseline QAT=0 EVAL_STRIDE=0 DOC_ISOLATED_EVAL=0 \ + torchrun --standalone --nproc_per_node=8 "$SCRIPT" +cp logs/ablation_baseline.txt records/track_10min_16mb/2026-03-19_QAT_Ablation/logs/ 2>/dev/null || true +echo ">>> Run 1 done: $(grep 'final_int8_zlib_roundtrip_exact' logs/ablation_baseline.txt | tail -1)" + +# Run 2: No QAT, sliding eval (stride=64) +echo "" +echo ">>> Run 2/4: No QAT, sliding eval (stride=64)" +env $BASE RUN_ID=ablation_slide64 QAT=0 EVAL_STRIDE=64 EVAL_BATCH_SEQS=32 DOC_ISOLATED_EVAL=0 \ + torchrun --standalone --nproc_per_node=8 "$SCRIPT" +cp logs/ablation_slide64.txt records/track_10min_16mb/2026-03-19_QAT_Ablation/logs/ 2>/dev/null || true +echo ">>> Run 2 done: $(grep 'final_int8_zlib_roundtrip_exact' logs/ablation_slide64.txt | tail -1)" + +# Run 3: QAT + sliding eval (stride=64) +echo "" +echo ">>> Run 3/4: QAT + sliding eval (stride=64)" +env $BASE RUN_ID=ablation_qat_slide64 QAT=1 EVAL_STRIDE=64 EVAL_BATCH_SEQS=32 DOC_ISOLATED_EVAL=0 \ + torchrun --standalone --nproc_per_node=8 "$SCRIPT" +cp logs/ablation_qat_slide64.txt records/track_10min_16mb/2026-03-19_QAT_Ablation/logs/ 2>/dev/null || true +echo ">>> Run 3 done: $(grep 'final_int8_zlib_roundtrip_exact' logs/ablation_qat_slide64.txt | tail -1)" + +# Run 4: QAT + sliding eval + doc-isolated +echo "" +echo ">>> Run 4/4: QAT + sliding eval + doc-isolated" +env $BASE RUN_ID=ablation_qat_slide64_dociso QAT=1 EVAL_STRIDE=64 EVAL_BATCH_SEQS=32 DOC_ISOLATED_EVAL=1 \ + torchrun --standalone --nproc_per_node=8 "$SCRIPT" +cp logs/ablation_qat_slide64_dociso.txt records/track_10min_16mb/2026-03-19_QAT_Ablation/logs/ 2>/dev/null || true +echo ">>> Run 4 done: $(grep 'final_int8_zlib_roundtrip_exact' logs/ablation_qat_slide64_dociso.txt | tail -1)" + +echo "" +echo "============================================" +echo "ABLATION RESULTS" +echo "============================================" +for LOG in ablation_baseline ablation_slide64 ablation_qat_slide64 ablation_qat_slide64_dociso; do + echo "$LOG: $(grep 'final_int8_zlib_roundtrip_exact' logs/${LOG}.txt 2>/dev/null | tail -1)" +done +echo "" +echo "Expected pattern:" +echo " baseline ~1.2244 (reproduces naive baseline)" +echo " slide64 ~1.1925 (reproduces SlidingWindowEval entry)" +echo " qat+slide64 < slide64 if QAT helps" +echo " qat+slide64+doc < qat+slide64 if doc isolation helps" diff --git a/records/track_10min_16mb/2026-03-19_QAT_Ablation/submission.json b/records/track_10min_16mb/2026-03-19_QAT_Ablation/submission.json new file mode 100644 index 000000000..91e055917 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_QAT_Ablation/submission.json @@ -0,0 +1,33 @@ +{ + "author": "mrdavtan", + "github_id": "mrdavtan", + "name": "Non-record: QAT ablation — int8 QAT overhead exceeds quantization gap recovery", + "blurb": "Clean ablation of per-row int8 QAT with exact INT8_CLIP_Q percentile matching. Finding: torch.quantile adds ~20% per-step overhead, costing ~2000 training steps. The lost training tokens hurt more than the ~0.007 BPB quantization gap recovery helps. QAT likely only pays off with int6 (larger gap) or with a faster approximate quantile.", + "date": "2026-03-20", + "val_loss": 2.03485247, + "val_bpb": 1.20515425, + "pre_quant_val_loss": 2.0814, + "pre_quant_val_bpb": 1.2327, + "step_stop": 8011, + "wallclock_seconds": 602.064, + "eval_time_seconds": 74.761, + "bytes_total": 15868103, + "bytes_model_int8_zlib": 15804522, + "bytes_code": 63581, + "hardware": "8xH100 SXM (RunPod), PyTorch 2.9.1+cu128", + "seed": 1337, + "track": "track_10min_16mb", + "model": { + "num_layers": 9, + "model_dim": 512, + "num_heads": 8, + "num_kv_heads": 4, + "mlp_mult": 2, + "vocab_size": 1024, + "tie_embeddings": true + }, + "changes_from_baseline": [ + "QAT: fake_quantize_int8_per_row (STE, INT8_CLIP_Q percentile, activates at 30% of training)" + ], + "notes": "Non-record submission. Negative result: int8 QAT overhead exceeds recovery. See README for analysis." +} diff --git a/records/track_10min_16mb/2026-03-19_QAT_Ablation/train_gpt.py b/records/track_10min_16mb/2026-03-19_QAT_Ablation/train_gpt.py new file mode 100644 index 000000000..bdead28c3 --- /dev/null +++ b/records/track_10min_16mb/2026-03-19_QAT_Ablation/train_gpt.py @@ -0,0 +1,1475 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_hidden = int(os.environ.get("MLP_HIDDEN", 0)) # 0 = use mlp_mult * model_dim + fp16_embed_export = bool(int(os.environ.get("FP16_EMBED_EXPORT", "0"))) # keep tok_emb in fp16 at export + num_loops = int(os.environ.get("NUM_LOOPS", 1)) + lora_rank = int(os.environ.get("LORA_RANK", 0)) + qat = bool(int(os.environ.get("QAT", "1"))) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + doc_isolated_eval = bool(int(os.environ.get("DOC_ISOLATED_EVAL", "1"))) # eval per-document, no cross-doc context + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + lora_lr = float(os.environ.get("LORA_LR", 0.01)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +# Names of tensors to keep in fp16 at export instead of quantizing to int8. +# Populated at runtime when fp16_embed_export=True. +_FP16_EXPORT_NAMES: set[str] = set() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + # - fp16 passthrough for tensors in _FP16_EXPORT_NAMES (e.g. tied embeddings) + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # FP16 export bypass: keep specified tensors (e.g. tied embeddings) in fp16 + # instead of quantizing to int8. Avoids compounding int8 errors through both + # the input embedding and output projection paths. + if name in _FP16_EXPORT_NAMES: + kept = t.to(dtype=torch.float16).contiguous() + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +def fake_quantize_int8_per_row(w: Tensor) -> Tensor: + """Simulate per-row int8 quantization with straight-through estimator. + + Forward: uses quantized-then-dequantized weights matching the export pipeline exactly + (same INT8_CLIP_Q percentile clip, same per-row scale, same rounding as quantize_float_tensor). + Backward: gradients pass through as if no quantization happened (STE). + """ + w32 = w.float() + clip_abs = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + w_clipped = torch.clamp(w32, -clip_abs[:, None], clip_abs[:, None]) + w_q = torch.clamp(torch.round(w_clipped / scale[:, None]), -127, 127) + w_deq = w_q * scale[:, None] + return w + (w_deq.to(w.dtype) - w).detach() + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + _qat: bool = False + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self._qat and self.training: + w = fake_quantize_int8_per_row(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +class AttentionLoRA(nn.Module): + """Per-iteration LoRA adapters for attention Q, K, V, and output projections. + + Initialized so that the LoRA contribution is zero at the start of training + (B matrices are zeros). During training, the optimizer learns per-iteration + specialization while the base attention weights remain shared across loops. + """ + def __init__(self, dim: int, kv_dim: int, rank: int): + super().__init__() + self.q_A = nn.Parameter(torch.empty(dim, rank)) + self.q_B = nn.Parameter(torch.zeros(rank, dim)) + self.k_A = nn.Parameter(torch.empty(dim, rank)) + self.k_B = nn.Parameter(torch.zeros(rank, kv_dim)) + self.v_A = nn.Parameter(torch.empty(dim, rank)) + self.v_B = nn.Parameter(torch.zeros(rank, kv_dim)) + self.proj_A = nn.Parameter(torch.empty(dim, rank)) + self.proj_B = nn.Parameter(torch.zeros(rank, dim)) + self._init_lora() + + def _init_lora(self) -> None: + for name in ("q_A", "k_A", "v_A", "proj_A"): + nn.init.kaiming_uniform_(getattr(self, name), a=math.sqrt(5)) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor, lora: AttentionLoRA | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + k = self.c_k(x) + v = self.c_v(x) + if lora is not None: + # LoRA delta: (bsz, seqlen, dim) @ (dim, rank) @ (rank, out_dim) + # autocast handles fp32->bf16 cast of LoRA params automatically + q = q + (x @ lora.q_A) @ lora.q_B + k = k + (x @ lora.k_A) @ lora.k_B + v = v + (x @ lora.v_A) @ lora.v_B + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + out = self.proj(y) + if lora is not None: + out = out + (y @ lora.proj_A) @ lora.proj_B + return out + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): + super().__init__() + hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + mlp_hidden: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, mlp_hidden=mlp_hidden) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, lora: AttentionLoRA | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x), lora=lora) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_loops: int = 1, + lora_rank: int = 0, + mlp_hidden: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_unique_layers = num_layers + self.num_loops = num_loops + effective_depth = num_layers * num_loops + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = effective_depth // 2 + self.num_decoder_layers = effective_depth - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + mlp_hidden=mlp_hidden, + ) + for i in range(num_layers) + ] + ) + # Per-(loop, block) LoRA adapters for attention projections. + # Only created when num_loops > 1 and lora_rank > 0. + kv_dim = num_kv_heads * (model_dim // num_heads) + if lora_rank > 0 and num_loops > 1: + self.lora_adapters = nn.ModuleList( + [ + nn.ModuleList( + [AttentionLoRA(model_dim, kv_dim, lora_rank) for _ in range(num_layers)] + ) + for _ in range(num_loops) + ] + ) + else: + self.lora_adapters = None + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # Iterate through effective layers: each unique block is reused across loops. + # First half (encoder) stores skip connections; second half (decoder) pops them. + eff_idx = 0 + for loop_idx in range(self.num_loops): + for block_idx in range(self.num_unique_layers): + lora = self.lora_adapters[loop_idx][block_idx] if self.lora_adapters is not None else None + if eff_idx < self.num_encoder_layers: + x = self.blocks[block_idx](x, x0, lora=lora) + skips.append(x) + else: + dec_idx = eff_idx - self.num_encoder_layers + if dec_idx < self.num_skip_weights and skips: + x = x + self.skip_weights[dec_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[block_idx](x, x0, lora=lora) + eff_idx += 1 + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + eff_idx = 0 + for loop_idx in range(self.num_loops): + for block_idx in range(self.num_unique_layers): + lora = self.lora_adapters[loop_idx][block_idx] if self.lora_adapters is not None else None + if eff_idx < self.num_encoder_layers: + x = self.blocks[block_idx](x, x0, lora=lora) + skips.append(x) + else: + dec_idx = eff_idx - self.num_encoder_layers + if dec_idx < self.num_skip_weights and skips: + x = x + self.skip_weights[dec_idx].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[block_idx](x, x0, lora=lora) + eff_idx += 1 + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +BOS_ID = 1 # SentencePiece BOS token ID + +def _find_docs(all_tokens: Tensor) -> list[tuple[int, int]]: + """Return (start, length) for each document, identified by BOS boundaries. + + Each document starts at a BOS token and extends to just before the next BOS. + The last document extends to the end of the token stream. + """ + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].tolist() + if not bos_positions: + return [(0, all_tokens.numel())] + docs = [] + for i in range(len(bos_positions)): + start = bos_positions[i] + end = bos_positions[i + 1] if i + 1 < len(bos_positions) else all_tokens.numel() + if end - start >= 2: # need at least 2 tokens for (x, y) pair + docs.append((start, end - start)) + return docs + + +def _build_sliding_windows( + total_tokens: int, seq_len: int, stride: int +) -> list[tuple[int, int]]: + """Return (window_start, score_start) pairs covering every token exactly once. + + Every token in [0, total_tokens) is scored by exactly one window. + Full windows score their last `stride` positions (first window scores all seq_len). + One tail-aligned window covers any tokens beyond the last full window's end. + """ + if total_tokens <= 0: + return [] + if total_tokens <= seq_len: + return [(0, 0)] + + windows: list[tuple[int, int]] = [] + last_full_end = 0 + ws = 0 + while ws + seq_len <= total_tokens: + s = 0 if ws == 0 else seq_len - stride + windows.append((ws, s)) + last_full_end = ws + seq_len + ws += stride + + # One tail window ending exactly at total_tokens covers any remainder. + if last_full_end < total_tokens: + tail_ws = total_tokens - seq_len + tail_s = last_full_end - tail_ws # skip already-scored prefix + windows.append((tail_ws, tail_s)) + + return windows + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context. + + Windows of train_seq_len advance by `stride`. Only newly covered tokens + contribute to the score (first window scores all seq_len; non-first full + windows score the last `stride` tokens; one tail window covers any remainder). + Every validation token is counted exactly once. + """ + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + if args.doc_isolated_eval: + # Build windows per document — context never crosses document boundaries. + # Each window's (ws, s) is in absolute token-stream coordinates. + docs = _find_docs(val_tokens) + all_windows: list[tuple[int, int]] = [] + for doc_start, doc_len in docs: + doc_pred_len = doc_len - 1 # number of prediction positions + doc_windows = _build_sliding_windows(doc_pred_len, seq_len, stride) + for ws, s in doc_windows: + all_windows.append((doc_start + ws, s)) + else: + all_windows = _build_sliding_windows(total_tokens, seq_len, stride) + total_windows = len(all_windows) + + # Distribute across ranks + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = all_windows[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_items = my_windows[bi:bi + batch_seqs] + bsz = len(batch_items) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + score_starts: list[int] = [] + + for i, (ws, s) in enumerate(batch_items): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + score_starts.append(s) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, (ws, s) in enumerate(batch_items): + wlen = wlens[i] + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Progress (rank 0 only) + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_loops=args.num_loops, + lora_rank=args.lora_rank, + mlp_hidden=args.mlp_hidden, + ).to(device).bfloat16() + + # FP16 tied embedding export: skip int8 quantization for tok_emb.weight at export time. + # Avoids compounding int8 errors through both input embedding and output projection. + if args.fp16_embed_export and args.tie_embeddings: + _FP16_EXPORT_NAMES.add("tok_emb.weight") + log0(f"fp16_embed_export:enabled (tok_emb.weight kept in fp16, ~{args.vocab_size * args.model_dim * 2 / 1024:.0f}KB)") + + for module in base_model.modules(): + if isinstance(module, (CastedLinear, AttentionLoRA)): + module.float() + restore_low_dim_params_to_fp32(base_model) + log0(f"qat:{args.qat} (activates at 30% of iterations = step {int(args.iterations * 0.30)})") + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lora_adapters is not None: + lora_params = list(base_model.lora_adapters.parameters()) + optimizer_lora = torch.optim.Adam( + [{"params": lora_params, "lr": args.lora_lr, "base_lr": args.lora_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.append(optimizer_lora) + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + n_lora = sum(p.numel() for p in base_model.lora_adapters.parameters()) if base_model.lora_adapters is not None else 0 + effective_depth = args.num_layers * args.num_loops + log0(f"model_params:{n_params} (unique_layers:{args.num_layers} loops:{args.num_loops} effective_depth:{effective_depth} lora_rank:{args.lora_rank} lora_params:{n_lora})") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + # NOTE: QAT graph priming removed — it caused torch.compile to use a slower + # compilation path for the non-QAT forward pass (step_avg jumped from 44ms to 58ms). + # The one-time recompile when QAT activates (~30-90s) is cheaper than the cumulative + # overhead of a slower non-QAT path across thousands of steps. + + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + + # Delayed QAT schedule: activate fake-quantize at 60% of training steps to + # avoid early-training instability before the loss landscape has settled. + if args.qat and step == int(args.iterations * 0.30): + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module._qat = True + log0(f"qat_activated step:{step}/{args.iterations}") + + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs} doc_isolated:{args.doc_isolated_eval}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + stride=args.eval_stride, + batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()