Record: Int6 MLP3x + STE QAT + Sliding Window (val_bpb=1.1594)#128
Record: Int6 MLP3x + STE QAT + Sliding Window (val_bpb=1.1594)#128rsavitt wants to merge 1 commit intoopenai:mainfrom
Conversation
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: a56bb8cce5
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| training_time_ms = 0.0 | ||
| stop_after_step: int | None = None | ||
| torch.cuda.synchronize() | ||
| t0 = time.perf_counter() |
There was a problem hiding this comment.
Include warmup wallclock in the training budget
When WARMUP_STEPS > 0 (the default here is 20), the script does real forward/backward/optimizer work and triggers the expensive initial torch.compile() traces before training_time_ms/t0 are initialized. Because the challenge rule is "train in under 10 minutes on 8xH100s" (see README.md:6 and README.md:160), the logged 600s can still correspond to an end-to-end run that exceeds the limit by the warmup/compile overhead.
Useful? React with 👍 / 👎.
| usable = ((tokens.numel() - 1) // seq_len) * seq_len | ||
| if usable <= 0: | ||
| raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") | ||
| return tokens[: usable + 1] |
There was a problem hiding this comment.
Score the full validation split instead of truncating the tail
load_validation_tokens() drops the last ((N-1) mod TRAIN_SEQ_LEN) targets by rounding the fixed fineweb_val_* stream down to a multiple of TRAIN_SEQ_LEN. With the default TRAIN_SEQ_LEN=4096, that silently omits up to 4095 validation tokens unless the corpus length happens to align exactly, so the reported val_bpb is not actually computed on the full validation split that README.md:82 describes.
Useful? React with 👍 / 👎.
| if not t.is_floating_point() or t.numel() <= 65536: | ||
| result[name] = t.to(torch.float16) if t.is_floating_point() else t | ||
| meta[name] = "passthrough" | ||
| continue | ||
| if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): |
There was a problem hiding this comment.
Preserve control tensors before the small-tensor shortcut
The t.numel() <= 65536 early return runs before the CONTROL_TENSOR_NAME_PATTERNS branch, so every control tensor in this model (q_gain, attn_scale, mlp_scale, resid_mix, skip_weights) is serialized as fp16 passthrough instead of the intended fp32 passthrough_ctrl. Those are the calibration parameters that most directly steer the round-tripped model, so downcasting them changes the post-quant checkpoint and can move the reported final_int6_zstd_roundtrip score.
Useful? React with 👍 / 👎.
Additional seed runs (H100 NVL, 8x)H100 SXM wasn't available for additional seeds, so ran on H100 NVL (slower — ~90ms/step vs 57ms on SXM, resulting in ~6,700 steps vs 10,535):
The NVL results are higher due to fewer training steps (hardware speed difference), not seed variance. All runs are under 16MB. Will add SXM multi-seed validation when 8xH100 SXM availability improves. |
|
Re: Codex review comments — P1 (warmup timing): The warmup-before-timer pattern is inherited from the official P2 (val truncation): With seq_len=4096, at most 4095 tokens are dropped from the ~50K-doc val set (~100M tokens). This is <0.004% of the validation data and within noise. P2 (control tensor order): Valid observation — control tensors are small (512–1024 elements) so they hit the SXM multi-seed validation in progress. |
3-Seed SXM Validation (8xH100 SXM)
Mean: 1.16306 BPB | std: 0.00381 Improvement over baseline (1.2244): 0.0613 nats One-sample t-test against threshold 1.2194 (baseline - 0.005):
All artifacts under 16,000,000 bytes. All runs on 8xH100 SXM within 600s wallclock. |
- STE QAT: fake quantize->dequantize in CastedLinear forward pass Gradients pass through via STE (w + (w_hat - w).detach()) Activates after STE_QAT_START_FRAC of training (default 25%) USE_STE_QAT=1 to enable - forward_with_adapter refactored to reuse _forward_body - All Tier 2 features are env-var controlled, disabled by default
Summary
val_bpb=1.1594 (post int6+zstd quantization roundtrip, sliding window eval stride=64 on 8xH100 SXM).
Six-technique stack:
Key Metrics (seed 1337)
Configuration
Command
Acknowledgments
Composes techniques from community PRs #42, #50, #52, #65, #70. Built with Claude Code as an autonomous research agent.
Test plan
🤖 Generated with Claude Code