[WIP] Depth-recurrent QAT: 3x4 loops, 768d, 15.6MB artifact#51

Closed
hacksurvivor wants to merge 3 commits into openai:main from hacksurvivor:depth-recurrent-qat
Conversation

@hacksurvivor

Summary

Depth-recurrent transformer with quantization-aware training.

Key optimizations

  • Depth recurrence: 3 shared blocks looped 4x = 12 effective layers (vs baseline 9)
  • Wider model: dim 768, 12 heads, 6 KV heads
  • Per-loop LoRA: rank-4 adapters on Q/V for loop specialization
  • QAT: STE fake-quantization in CastedLinear, for minimal post-quant degradation
  • LAWA: Weight averaging during warmdown
  • fp16 tied embedding export
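A minimal sketch of the depth-recurrent core with per-loop adapters, assuming a PyTorch-style model. For brevity the rank-4 adapter is added residually around a stock encoder layer rather than inside the Q/V projections as the PR does, and all class and parameter names here are illustrative:

```python
import torch
import torch.nn as nn

class LoopLoRA(nn.Module):
    """Illustrative rank-4 adapter; the PR attaches these to the Q/V projections."""
    def __init__(self, dim, rank=4):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)
        self.up = nn.Linear(rank, dim, bias=False)
        nn.init.zeros_(self.up.weight)  # adapter starts as a no-op

    def forward(self, x):
        return self.up(self.down(x))

class DepthRecurrent(nn.Module):
    """n_blocks shared blocks unrolled n_loops times (3 x 4 = 12 effective layers)."""
    def __init__(self, dim=768, n_heads=12, n_blocks=3, n_loops=4):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(dim, n_heads, batch_first=True)
            for _ in range(n_blocks)
        )
        # one adapter per (loop, block) pair, so each pass over the
        # shared weights can specialize
        self.adapters = nn.ModuleList(
            LoopLoRA(dim) for _ in range(n_loops * n_blocks)
        )
        self.n_loops = n_loops

    def forward(self, x):
        for loop in range(self.n_loops):
            for i, block in enumerate(self.blocks):
                x = block(x) + self.adapters[loop * len(self.blocks) + i](x)
        return x
```

The key property is that only the tiny adapters grow with loop count; the block weights are stored once, which is what keeps the artifact small.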

A10G validation (200 iterations, 1 shard)

Metric               Value
val_bpb              2.0799
Post-quant val_bpb   2.0485
Artifact size        15.6 MB
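The small post-quant gap comes from training through the quantizer. A generic sketch of an STE fake-quantize op (not the PR's actual CastedLinear) shows the idea: round to int8 in the forward pass, pass gradients through unchanged in the backward pass:

```python
import torch

class FakeQuantSTE(torch.autograd.Function):
    """Int8 fake-quantize: rounds in forward, straight-through gradient in backward."""
    @staticmethod
    def forward(ctx, w):
        scale = w.abs().max().clamp(min=1e-8) / 127.0
        return torch.clamp(torch.round(w / scale), -127, 127) * scale

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out  # straight-through estimator: ignore the rounding

w = torch.randn(4, 4, requires_grad=True)
wq = FakeQuantSTE.apply(w)
wq.sum().backward()
assert torch.allclose(w.grad, torch.ones_like(w))  # gradient passed through unchanged
```

Because the weights see quantization noise during training, the final int8 export changes the loss far less than post-hoc quantization would.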

Full 8xH100 multi-seed results pending compute grant.

Test plan

  • CPU forward/backward pass
  • Int8+zlib roundtrip under 16MB
  • A10G CUDA training (200 steps)
  • Full 8xH100 10-min training run
  • Multi-seed validation (p < 0.01)
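The int8+zlib roundtrip and the 16 MB budget can be smoke-tested without the model. A hedged sketch of symmetric per-tensor int8 quantization plus zlib compression; the function names and packing format are invented here and may differ from the PR's exporter:

```python
import io
import zlib
import numpy as np

def int8_quantize(w):
    """Symmetric per-tensor int8 quantization; returns (values, scale)."""
    scale = np.abs(w).max() / 127.0 or 1.0  # avoid div-by-zero for all-zero tensors
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def pack_artifact(tensors):
    """Quantize each tensor to int8 and zlib-compress the concatenated payload."""
    buf = io.BytesIO()
    for name, w in tensors.items():
        q, scale = int8_quantize(w)
        np.save(buf, q)                   # int8 values
        np.save(buf, np.float32(scale))   # per-tensor scale for dequantization
    return zlib.compress(buf.getvalue(), level=9)

# roundtrip sanity check on random weights
tensors = {"w": np.random.randn(256, 256).astype(np.float32)}
payload = pack_artifact(tensors)
q, scale = int8_quantize(tensors["w"])
err = np.abs(q.astype(np.float32) * scale - tensors["w"]).max()
assert err <= scale                       # dequant error bounded by one quant step
assert len(payload) < 16 * 1024 * 1024   # artifact budget check
```

A check like this on the real state dict is what the "Int8+zlib roundtrip under 16MB" item above would assert.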

Depth recurrence (3 shared blocks x 4 loops = 12 effective layers),
QAT via STE fake-quantize, per-loop LoRA adapters, LAWA weight
averaging, fp16 tied embedding export. Validated on A10G.
H100 multi-seed results pending.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 9c5e06c38e


Comment on lines +1116 to +1121
```python
    lawa_state = {k: v.detach().cpu().clone().float()
                  for k, v in base_model.state_dict().items()}
    lawa_count = 1
else:
    lawa_count += 1
    for k, v in base_model.state_dict().items():
        lawa_state[k] += (v.detach().cpu().float() - lawa_state[k]) / lawa_count
```


P1: Keep LAWA averaging off the CPU hot path

Once warmdown starts, this code copies the entire state_dict() from GPU to CPU and casts it to float on every step. For the default 13M-parameter model that is tens of MB of host transfer per step, and over a 400-step warmdown it turns into many GB of extra PCIe traffic and Python-side work on every rank. In the challenge’s 10-minute setting, that overhead can easily consume the remaining wallclock budget and force an earlier stop or a worse final score.
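The fix the follow-up commit describes (keeping the running average on the GPU via lerp_) could look like this sketch; the helper name and calling convention are assumptions:

```python
import torch

@torch.no_grad()
def lawa_update(lawa_state, model, lawa_count):
    """Running LAWA average kept on the model's device (no per-step CPU copies).

    The incremental mean  avg += (w - avg) / n  is exactly  avg.lerp_(w, 1 / n).
    """
    if lawa_state is None:
        # first warmdown step: snapshot weights in fp32 on the same device
        return {k: v.detach().clone().float()
                for k, v in model.state_dict().items()}, 1
    lawa_count += 1
    for k, v in model.state_dict().items():
        lawa_state[k].lerp_(v.detach().float(), 1.0 / lawa_count)
    return lawa_state, lawa_count
```

Only the final averaged state needs a single device-to-host transfer at export time, instead of one full-state transfer per step.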


Comment on lines +898 to +901
```python
if os.environ.get("SKIP_COMPILE", "0") == "1":
    compiled_model = base_model
else:
    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
```


P2: Make SKIP_COMPILE disable all torch.compile calls

This branch suggests SKIP_COMPILE=1 is enough to bypass Inductor, but main() still unconditionally runs torch.compile(zeropower_via_newtonschulz5) earlier at line 789. On the smoke-test machines where people set SKIP_COMPILE because compile is slow or unsupported, the script can still fail or spend a long time compiling before it ever reaches this branch, so the escape hatch is incomplete.
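One way to make the escape hatch complete is to route every compile site through a single wrapper; a sketch, with the helper name invented here:

```python
import os
import torch

def maybe_compile(fn, **kwargs):
    """Single escape hatch: SKIP_COMPILE=1 bypasses every torch.compile call."""
    if os.environ.get("SKIP_COMPILE", "0") == "1":
        return fn
    return torch.compile(fn, **kwargs)

# hypothetical call sites, so one env var covers both:
# zeropower_via_newtonschulz5 = maybe_compile(zeropower_via_newtonschulz5)
# compiled_model = maybe_compile(base_model, dynamic=False, fullgraph=True)
```

With all sites funneled through one function, adding a new torch.compile call can no longer silently escape the flag.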


hacksurvivor and others added 2 commits March 19, 2026 16:10
INT8_KEEP_FLOAT_MAX_NUMEL=1M accidentally kept attention weights
(589K params) in fp16 instead of int8, inflating artifact to 17.6MB.
Fix: name-based pattern keeps only tok_emb in fp16, threshold back to
65K. Reduces raw payload from 19.5MB to 14.2MB.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
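The name-based selection this commit describes might look like the following sketch; the pattern list and threshold value are illustrative, not the PR's exact code:

```python
KEEP_FP16_PATTERNS = ("tok_emb",)   # name-based: only the tied embedding stays fp16
KEEP_FLOAT_MAX_NUMEL = 65_536       # small tensors (norms, biases) not worth int8

def export_dtype(name, numel):
    """Pick the export dtype for one tensor: by name first, then by size."""
    if any(p in name for p in KEEP_FP16_PATTERNS):
        return "fp16"
    if numel <= KEEP_FLOAT_MAX_NUMEL:
        return "fp16"
    return "int8"

# the old numel-only 1M threshold wrongly kept ~589K-param attention
# weights in fp16; with the name check they fall through to int8:
assert export_dtype("blocks.0.attn.qkv.weight", 589_000) == "int8"
assert export_dtype("tok_emb.weight", 768 * 50_304) == "fp16"
```

Matching on tensor names rather than size alone makes the fp16 allow-list explicit, so growing the model cannot silently pull more tensors into the fp16 bucket.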
- LAWA averaging now stays on GPU using lerp_() instead of copying to
  CPU every step. Eliminates PCIe overhead during warmdown.
- SKIP_COMPILE=1 now also skips torch.compile(zeropower_via_newtonschulz5)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
