[WIP] Depth-recurrent QAT: 3x4 loops, 768d, 15.6MB artifact #51
hacksurvivor wants to merge 3 commits into openai:main
Conversation
Depth recurrence (3 shared blocks x 4 loops = 12 effective layers), QAT via STE fake-quantize, per-loop LoRA adapters, LAWA weight averaging, fp16 tied embedding export. Validated on A10G. H100 multi-seed results pending. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
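The QAT path mentioned in the description uses STE fake-quantize. A minimal per-tensor symmetric int8 sketch (the function name and scaling scheme are assumptions, not the PR's actual implementation):

```python
import torch

def fake_quantize_int8(w: torch.Tensor) -> torch.Tensor:
    # Per-tensor symmetric int8 fake-quantize with a straight-through
    # estimator (STE): the forward pass sees the quantized values, while
    # the backward pass treats the op as identity.
    scale = w.abs().max().clamp(min=1e-8) / 127.0
    q = (w / scale).round().clamp(-127, 127) * scale
    # Value of q in the forward pass, gradient of w in the backward pass.
    return w + (q - w).detach()
```

Training against the quantized weights this way lets the int8 export at the end match what the network actually optimized for.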
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 9c5e06c38e
```python
    lawa_state = {k: v.detach().cpu().clone().float() for k, v in base_model.state_dict().items()}
    lawa_count = 1
else:
    lawa_count += 1
    for k, v in base_model.state_dict().items():
        lawa_state[k] += (v.detach().cpu().float() - lawa_state[k]) / lawa_count
```
Keep LAWA averaging off the CPU hot path
Once warmdown starts, this code copies the entire state_dict() from GPU to CPU and casts it to float on every step. For the default 13M-parameter model that is tens of MB of host transfer per step, and over a 400-step warmdown it turns into many GB of extra PCIe traffic and Python-side work on every rank. In the challenge’s 10-minute setting, that overhead can easily consume the remaining wallclock budget and force an earlier stop or a worse final score.
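One way to fix this, and the approach the follow-up commit describes, is to keep the running average on-device and update it in place with `lerp_`, which computes exactly the incremental mean `state += (v - state) / n` without any host transfer. A sketch (the helper name and signature are assumptions):

```python
import torch

@torch.no_grad()
def lawa_update(lawa_state: dict, model: torch.nn.Module, lawa_count: int) -> int:
    # GPU-resident incremental mean: state.lerp_(v, 1/n) is
    # state += (v - state) / n, done in place with no CPU copy.
    lawa_count += 1
    for k, v in model.state_dict().items():
        lawa_state[k].lerp_(v.float(), 1.0 / lawa_count)
    return lawa_count
```

The averaged state only needs to leave the GPU once, at export time, instead of on every warmdown step.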
```python
if os.environ.get("SKIP_COMPILE", "0") == "1":
    compiled_model = base_model
else:
    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
```
Make SKIP_COMPILE disable all torch.compile calls
This branch suggests SKIP_COMPILE=1 is enough to bypass Inductor, but main() still unconditionally runs torch.compile(zeropower_via_newtonschulz5) earlier at line 789. On the smoke-test machines where people set SKIP_COMPILE because compile is slow or unsupported, the script can still fail or spend a long time compiling before it ever reaches this branch, so the escape hatch is incomplete.
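A simple way to close the escape hatch, consistent with the follow-up commit that makes SKIP_COMPILE also skip the Newton-Schulz compile, is to route every `torch.compile` call through one gate (the helper name is an assumption):

```python
import os
import torch

def maybe_compile(fn, **kwargs):
    # Single compile gate for the whole script: SKIP_COMPILE=1 returns the
    # function unchanged, so no code path ever reaches Inductor.
    if os.environ.get("SKIP_COMPILE", "0") == "1":
        return fn
    return torch.compile(fn, **kwargs)
```

Both the model and `zeropower_via_newtonschulz5` would then go through `maybe_compile`, so one environment variable controls every compile site.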
INT8_KEEP_FLOAT_MAX_NUMEL=1M accidentally kept attention weights (589K params) in fp16 instead of int8, inflating artifact to 17.6MB. Fix: name-based pattern keeps only tok_emb in fp16, threshold back to 65K. Reduces raw payload from 19.5MB to 14.2MB. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- LAWA averaging now stays on GPU using lerp_() instead of copying to CPU every step. Eliminates PCIe overhead during warmdown. - SKIP_COMPILE=1 now also skips torch.compile(zeropower_via_newtonschulz5) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary
Depth-recurrent transformer with quantization-aware training.
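The depth recurrence from the description (3 shared blocks x 4 loops = 12 effective layers, with per-loop LoRA adapters) can be sketched roughly as follows; the class, adapter layout, and use of `nn.TransformerEncoderLayer` as the block are all assumptions standing in for the PR's actual modules:

```python
import torch
import torch.nn as nn

class DepthRecurrent(nn.Module):
    # 3 shared blocks applied 4 times = 12 effective layers at 3 layers'
    # worth of base parameters; one low-rank adapter per (loop, block)
    # lets each pass specialize without unsharing the base weights.
    def __init__(self, d=768, n_blocks=3, n_loops=4, rank=8):
        super().__init__()
        self.n_loops = n_loops
        self.blocks = nn.ModuleList(
            [nn.TransformerEncoderLayer(d, nhead=8, batch_first=True)
             for _ in range(n_blocks)]
        )
        self.lora_a = nn.ParameterList(
            [nn.Parameter(torch.zeros(d, rank)) for _ in range(n_loops * n_blocks)]
        )
        self.lora_b = nn.ParameterList(
            [nn.Parameter(torch.randn(rank, d) * 0.01) for _ in range(n_loops * n_blocks)]
        )

    def forward(self, x):
        for loop in range(self.n_loops):
            for i, blk in enumerate(self.blocks):
                idx = loop * len(self.blocks) + i
                # Shared block plus this loop's low-rank residual delta.
                x = blk(x) + x @ self.lora_a[idx] @ self.lora_b[idx]
        return x
```

Weight sharing across loops is what keeps the exported artifact small while preserving effective depth.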
Key optimizations
A10G validation (200 iterations, 1 shard)
Full 8xH100 multi-seed results pending compute grant.
Test plan