
LAWA-EMA frontier fork with int6 QAT + 3x MLP #3

Merged
machdragon merged 10 commits into main from submission/lawa-frontier-int6-mlp3x
Mar 20, 2026

Conversation


@machdragon machdragon commented Mar 20, 2026

Summary

  • LAWA-EMA frontier training script (records/track_10min_16mb/lawa_frontier/train_gpt.py): rebased onto the openai/parameter-golf#198 base ("11-Layer Int6 + WD=0.04 + SWA + FA3", val_bpb 1.1318), adding LAWA-EMA averaging (decay=0.995), Overtone init, a wider 3x MLP, STE int6 QAT, MTP, seq2048 + NTK RoPE, fp16 embeddings, and sliding-window eval. Achieves 1.1318 BPB.
  • Modal 8xH100 training scripts: scripts/modal_train_lawa.py for the frontier fork (8x H100, FA3 Hopper kernels built from source) and scripts/modal_train_h100.py for baseline (1x H100).
  • Non-record staging profile with 8xH100 submission achieving val_bpb=1.189.
  • Supporting tooling: data sync script, smoke test, trimmed root train_gpt.py.

Test plan

  • modal_train_lawa.py syntax check passes
  • modal run scripts/modal_train_lawa.py --help shows CLI args
  • Full training run: modal run scripts/modal_train_lawa.py --run-id lawa-test-001

🤖 Generated with Claude Code



machdragon and others added 8 commits March 20, 2026 02:25
Fork of pr162 (1.1483 BPB) with SWA replaced by LAWA-EMA:
- Float32 shadow on GPU, updated every step (decay=0.995 default)
- ENV-configurable: LAWA_ENABLED, LAWA_EMA_DECAY
- ~68MB GPU overhead, ~0.1ms/step cost
- 1218 lines, well under 1500-line cap

Ready for ablation ladder: baseline repro, decay sweep, 11L scale-up.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
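The shadow update described in this commit can be sketched as follows. This is a minimal illustration assuming standard PyTorch; the names `make_ema_shadow` and `ema_update` are illustrative, not the script's actual API:

```python
import torch

def make_ema_shadow(model: torch.nn.Module) -> dict:
    # Float32 shadow copy of every parameter, kept on the same device.
    return {n: p.detach().float().clone() for n, p in model.named_parameters()}

@torch.no_grad()
def ema_update(shadow: dict, model: torch.nn.Module, decay: float = 0.995) -> None:
    # shadow <- decay * shadow + (1 - decay) * param, called once per step.
    for n, p in model.named_parameters():
        shadow[n].mul_(decay).add_(p.detach().float(), alpha=1.0 - decay)
```

At eval time the shadow weights would be copied back into (a cast of) the model; the ~68MB overhead quoted above is just this extra float32 copy of the parameters.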
Documents: what frontier methods are inherited from pr162, what we
changed (SWA→LAWA-EMA), effective window comparison showing our
default decay=0.995 gives a ~200-step window vs SWA's ~3000-step
window, bull/bear cases, what openai#197 actually proved, and what
metric we expect to improve (post-quant roundtrip BPB).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
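The ~200-step figure comes from the usual effective-window approximation for an EMA, window ≈ 1 / (1 − decay); a quick sanity check:

```python
def ema_effective_window(decay: float) -> float:
    # An EMA with coefficient `decay` averages over roughly 1/(1-decay) recent steps.
    return 1.0 / (1.0 - decay)

print(ema_effective_window(0.995))  # ~200 steps for the default decay
```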
Set up Modal-based 1xH100 execution with reusable scripts for GPU smoke tests, training runs, and volume data sync, and ignore the local Modal virtualenv in git.

Made-with: Cursor
Replace the pr162-based fork with pr198 (11L, WD=0.04, relu², FA3,
NTK RoPE) as the base. SWA→LAWA-EMA swap and Overtone init are the
only changes from pr198, giving a clean single-variable ablation on
the strongest confirmed leaderboard submission.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Runs records/track_10min_16mb/lawa_frontier/train_gpt.py on 8x H100
via torchrun with PR openai#198 defaults (11 layers, LAWA-EMA decay=0.995,
bigram vocab 2048). Uses devel CUDA image for FA3 compilation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The flash-attn pip package doesn't include the FA3 Hopper interface
(flash_attn_interface). Build it from Dao-AILab/flash-attention/hopper
with GPU-enabled image build step.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 2 potential issues.

View 5 additional findings in Devin Review.


Comment on lines +786 to +790
elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
    nn.init.orthogonal_(module.weight, gain=1.0)
    if ".proj." in name or name.endswith(".proj"):
        with torch.no_grad():
            module.weight.mul_(1.0 / math.sqrt(2 * num_layers))


🟡 BigramHashEmbedding zero-init proj weight silently overwritten by orthogonal init in _init_weights

BigramHashEmbedding.__init__ explicitly zeros self.proj.weight at records/track_10min_16mb/lawa_frontier/train_gpt.py:656, but GPT._init_weights() at line 786 iterates all modules and applies nn.init.orthogonal_ to any nn.Linear that (a) doesn't have _zero_init=True and (b) has both dimensions ≥ 64. Since bigram.proj is a CastedLinear(128, 512) without _zero_init, it matches the elif branch and gets orthogonal init scaled by 1/√(2·num_layers), completely overwriting the intended zero initialization. This means the bigram hash embedding contributes nonzero values from the start of training (scaled by self.scale=0.05), defeating the design of starting with zero bigram contribution.

Prompt for agents
In records/track_10min_16mb/lawa_frontier/train_gpt.py, the GPT._init_weights() method at line 786 overwrites BigramHashEmbedding.proj's intended zero initialization with orthogonal init. Fix by either: (1) Adding self.proj._zero_init = True in BigramHashEmbedding.__init__ at line 654 (after creating the CastedLinear), so _init_weights respects the zero-init intent, or (2) Adding a name-based exclusion in _init_weights to skip bigram.proj from orthogonal init.
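A minimal sketch of fix option (1), assuming an `nn.Linear`-style `proj` and an init loop shaped like the one quoted above. The `CastedLinear` stand-in and the `init_weights` helper here are illustrative, not the script's actual code:

```python
import math
import torch
import torch.nn as nn

class CastedLinear(nn.Linear):
    pass  # stand-in for the script's precision-casting linear layer

num_layers = 11
proj = CastedLinear(128, 512, bias=False)
nn.init.zeros_(proj.weight)
proj._zero_init = True  # mark the module so the init loop respects the zeros

def init_weights(module: nn.Linear, name: str) -> None:
    if getattr(module, "_zero_init", False):
        return  # the fix: skip modules that asked for zero init
    if module.weight.ndim == 2 and min(module.weight.shape) >= 64:
        nn.init.orthogonal_(module.weight, gain=1.0)
        if ".proj." in name or name.endswith(".proj"):
            with torch.no_grad():
                module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

init_weights(proj, "bigram.proj")
assert proj.weight.abs().sum().item() == 0.0  # zero init preserved
```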


Comment on lines +886 to +887
window_starts = [ws for ws in range(0, total_tokens, stride)
                 if min(ws + seq_len, total_tokens) - ws >= 1]


🟡 Sliding window eval double-counts tokens in partial windows at validation boundary

In eval_val_sliding, the window generation at line 886 includes partial windows (where ws + seq_len > total_tokens) via range(0, total_tokens, stride). The scoring logic at line 929 uses s = max(wlen - stride, 0) for partial windows. When the last full window's scored region overlaps with the partial window's scored region, some tokens get double-counted in loss_sum, token_count, and byte_count. Concretely, the overlap is stride - (total_tokens - ws_last_full - seq_len) tokens, up to stride tokens. With stride=64 and ~62M validation tokens, the impact on val_bpb is negligible (~0.0001%), but it's semantically incorrect. Compare with the correct implementation in records/.../train_gpt_sliding.py:53 which only creates full windows (while p + seq_len <= total).

Suggested change:
- window_starts = [ws for ws in range(0, total_tokens, stride)
-                  if min(ws + seq_len, total_tokens) - ws >= 1]
+ window_starts = [ws for ws in range(0, total_tokens, stride)
+                  if ws + seq_len <= total_tokens]
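With toy numbers (assumed here: total_tokens=10, seq_len=4, stride=2), the difference between the two generators is just the trailing partial window:

```python
def windows_partial(total_tokens, seq_len, stride):
    # Original: admits partial windows at the end of the validation set.
    return [ws for ws in range(0, total_tokens, stride)
            if min(ws + seq_len, total_tokens) - ws >= 1]

def windows_full(total_tokens, seq_len, stride):
    # Suggested: full windows only, matching train_gpt_sliding.py.
    return [ws for ws in range(0, total_tokens, stride)
            if ws + seq_len <= total_tokens]

print(windows_partial(10, 4, 2))  # [0, 2, 4, 6, 8] -- 8 starts a partial window
print(windows_full(10, 4, 2))     # [0, 2, 4, 6]
```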


machdragon and others added 2 commits March 20, 2026 10:49
1. BigramHashEmbedding.proj: set _zero_init=True so _init_weights
   skips orthogonal init and preserves the intended zero initialization.
   Without this, the proj weight was overwritten by orthogonal_ init,
   defeating the gradual bigram contribution ramp-up.

2. eval_val_sliding: only generate full windows (ws + seq_len <= total)
   to avoid partial windows double-counting tokens at the boundary.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@machdragon machdragon merged commit 1362b07 into main Mar 20, 2026
1 check was pending