LAWA-EMA frontier fork with int6 QAT + 3x MLP#3
Conversation
Made-with: Cursor
Made-with: Cursor
Fork of pr162 (1.1483 BPB) with SWA replaced by LAWA-EMA: - Float32 shadow on GPU, updated every step (decay=0.995 default) - ENV-configurable: LAWA_ENABLED, LAWA_EMA_DECAY - ~68MB GPU overhead, ~0.1ms/step cost - 1218 lines, well under 1500-line cap Ready for ablation ladder: baseline repro, decay sweep, 11L scale-up. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Documents: what frontier methods are inherited from pr162, what we changed (SWA→LAWA-EMA), effective window comparison showing our default decay=0.995 gives a ~200-step window vs SWA's ~3000-step window, bull/bear cases, what openai#197 actually proved, and what metric we expect to improve (post-quant roundtrip BPB). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Set up Modal-based 1xH100 execution with reusable scripts for GPU smoke tests, training runs, and volume data sync, and ignore the local Modal virtualenv in git. Made-with: Cursor
Replace the pr162-based fork with pr198 (11L, WD=0.04, relu², FA3, NTK RoPE) as the base. SWA→LAWA-EMA swap and Overtone init are the only changes from pr198, giving a clean single-variable ablation on the strongest confirmed leaderboard submission. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Runs records/track_10min_16mb/lawa_frontier/train_gpt.py on 8x H100 via torchrun with PR openai#198 defaults (11 layers, LAWA-EMA decay=0.995, bigram vocab 2048). Uses devel CUDA image for FA3 compilation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The flash-attn pip package doesn't include the FA3 Hopper interface (flash_attn_interface). Build it from Dao-AILab/flash-attention/hopper with GPU-enabled image build step. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
| elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: | ||
| nn.init.orthogonal_(module.weight, gain=1.0) | ||
| if ".proj." in name or name.endswith(".proj"): | ||
| with torch.no_grad(): | ||
| module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) |
There was a problem hiding this comment.
🟡 BigramHashEmbedding zero-init proj weight silently overwritten by orthogonal init in _init_weights
BigramHashEmbedding.__init__ explicitly zeros self.proj.weight at records/track_10min_16mb/lawa_frontier/train_gpt.py:656, but GPT._init_weights() at line 786 iterates all modules and applies nn.init.orthogonal_ to any nn.Linear that (a) doesn't have _zero_init=True and (b) has both dimensions ≥ 64. Since bigram.proj is a CastedLinear(128, 512) without _zero_init, it matches the elif branch and gets orthogonal init scaled by 1/√(2·num_layers), completely overwriting the intended zero initialization. This means the bigram hash embedding contributes nonzero values from the start of training (scaled by self.scale=0.05), defeating the design of starting with zero bigram contribution.
Prompt for agents
In records/track_10min_16mb/lawa_frontier/train_gpt.py, the GPT._init_weights() method at line 786 overwrites BigramHashEmbedding.proj's intended zero initialization with orthogonal init. Fix by either: (1) Adding self.proj._zero_init = True in BigramHashEmbedding.__init__ at line 654 (after creating the CastedLinear), so _init_weights respects the zero-init intent, or (2) Adding a name-based exclusion in _init_weights to skip bigram.proj from orthogonal init.
Was this helpful? React with 👍 or 👎 to provide feedback.
| window_starts = [ws for ws in range(0, total_tokens, stride) | ||
| if min(ws + seq_len, total_tokens) - ws >= 1] |
There was a problem hiding this comment.
🟡 Sliding window eval double-counts tokens in partial windows at validation boundary
In eval_val_sliding, the window generation at line 886 includes partial windows (where ws + seq_len > total_tokens) via range(0, total_tokens, stride). The scoring logic at line 929 uses s = max(wlen - stride, 0) for partial windows. When the last full window's scored region overlaps with the partial window's scored region, some tokens get double-counted in loss_sum, token_count, and byte_count. Concretely, the overlap is stride - (total_tokens - ws_last_full - seq_len) tokens, up to stride tokens. With stride=64 and ~62M validation tokens, the impact on val_bpb is negligible (~0.0001%), but it's semantically incorrect. Compare with the correct implementation in records/.../train_gpt_sliding.py:53 which only creates full windows (while p + seq_len <= total).
| window_starts = [ws for ws in range(0, total_tokens, stride) | |
| if min(ws + seq_len, total_tokens) - ws >= 1] | |
| window_starts = [ws for ws in range(0, total_tokens, stride) | |
| if ws + seq_len <= total_tokens] |
Was this helpful? React with 👍 or 👎 to provide feedback.
1. BigramHashEmbedding.proj: set _zero_init=True so _init_weights skips orthogonal init and preserves the intended zero initialization. Without this, the proj weight was overwritten by orthogonal_ init, defeating the gradual bigram contribution ramp-up. 2. eval_val_sliding: only generate full windows (ws + seq_len <= total) to avoid partial windows double-counting tokens at the boundary. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
records/track_10min_16mb/lawa_frontier/train_gpt.py): Rebased onto PR 11-Layer Int6 + WD=0.04 + SWA + FA3 (val_bpb: 1.1318) openai/parameter-golf#198 base with LAWA-EMA averaging (decay=0.995), Overtone init, wider 3x MLP, STE int6 QAT, MTP, seq2048 + NTK RoPE, fp16 embed, and sliding window eval. Achieves 1.1318 BPB.scripts/modal_train_lawa.pyfor the frontier fork (8x H100, FA3 Hopper kernels built from source) andscripts/modal_train_h100.pyfor baseline (1x H100).train_gpt.py.Test plan
modal_train_lawa.pysyntax check passesmodal run scripts/modal_train_lawa.py --helpshows CLI argsmodal run scripts/modal_train_lawa.py --run-id lawa-test-001🤖 Generated with Claude Code