KURE/R2 + Tanh Reparam + Parallel EMA + LoRA TTT#4

Open
machdragon wants to merge 11 commits into main from submission/lawa-kure-r2-ttt

Conversation

@machdragon machdragon commented Mar 20, 2026

Summary

Built on PR openai#201 (LAWA-EMA + Int6 + Overtone + MLP3x, val_bpb=1.1551). Adds four improvements targeting quantization fidelity and eval-time adaptation.

Changes (1857 lines, +329 from base)

1. KURE + R2 regularization (lines 1081-1098, 1635)

  • quant_reg_loss() — penalizes weight kurtosis away from 1.8, plus an outlier penalty on weights beyond 2×std
  • Added to training loop using unwrapped base_model (DDP-safe)
  • eps=1e-8 in kurtosis to prevent NaN when var≈0
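The PR does not show the regularizer itself, so here is a minimal sketch of what a `quant_reg_loss()` matching the description could look like. The exact formula (penalty shape, which tensors are included) is an assumption; only the kurtosis target of 1.8, the 2×std outlier threshold, the two lambdas, and the `eps=1e-8` guard come from the PR text.

```python
import torch
import torch.nn as nn

def quant_reg_loss(model, kure_lambda=0.01, r2_lambda=0.01, eps=1e-8):
    """Hypothetical sketch of the KURE + R2 penalty described above.

    Pushes per-tensor weight kurtosis toward 1.8 (near-uniform, int6-friendly)
    and penalizes weight mass beyond 2x the standard deviation.
    """
    loss = torch.zeros((), device=next(model.parameters()).device)
    for p in model.parameters():
        if p.ndim != 2:  # only matrix weights, as a simplifying assumption
            continue
        w = p.float()
        mu, var = w.mean(), w.var()
        # KURE term: (kurtosis - 1.8)^2; eps keeps the ratio finite when var ~ 0
        kurt = ((w - mu) ** 4).mean() / (var ** 2 + eps)
        loss = loss + kure_lambda * (kurt - 1.8) ** 2
        # R2 term: quadratic penalty on magnitude exceeding 2 * std
        excess = (w.abs() - 2.0 * var.sqrt()).clamp_min(0.0)
        loss = loss + r2_lambda * (excess ** 2).mean()
    return loss
```

In the training loop this would be added to the task loss, computed on the DDP-unwrapped `base_model` as the PR notes.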

2. Tanh reparameterization (lines 546-553)

  • CastedLinear._tanh_reparam class flag; forward applies torch.tanh(w) for 2D weights ≥ 64×64
  • .weight kept as nn.Parameter (safe for all ~15 callers: _init_weights, optimizer setup, TTT adapter sizing, tied embeddings, etc.)
  • LAWA shadows track tanh(P) directly; materialized before export
  • QAT STE still applies after tanh (correct ordering)
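A stripped-down sketch of the reparameterized forward pass, assuming the real `CastedLinear` subclasses `nn.Linear` (the PR confirms `.weight` stays an `nn.Parameter` and the class inherits `nn.Linear`); casting details and the QAT STE step that follows the tanh are omitted here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CastedLinear(nn.Linear):
    """Simplified sketch: tanh reparameterization for large 2-D weights.
    The real class also handles dtype casting and applies the QAT
    straight-through estimator after the tanh."""
    _tanh_reparam = True  # class flag, as in the PR

    def forward(self, x):
        w = self.weight
        if self._tanh_reparam and w.ndim == 2 and min(w.shape) >= 64:
            w = torch.tanh(w)  # effective weight bounded to (-1, 1)
        return F.linear(x, w.type_as(x), self.bias)
```

Because the stored parameter is unbounded while the effective weight is `tanh(P)`, EMA shadows must track `tanh(P)` (as the PR does) rather than the raw parameter, or averaging happens in the wrong space.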

3. Parallel EMA tracks (lines 1569-1582, 1660-1666, 1703-1728)

  • Three decay rates: 0.995, 0.999, 0.9995 (default raised from 0.995 to 0.999)
  • Clone-then-pick via proxy eval at export time
  • Safe: doesn't mutate shadows during eval passes
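The interface below is an assumed sketch of the parallel-track idea, not the PR's LAWA code (which additionally tracks tanh-space shadows and export state): keep one shadow `state_dict` per decay, update all of them each step, and at export evaluate a copy of each so the live shadows are never mutated.

```python
import copy
import torch
import torch.nn as nn

class ParallelEMA:
    """Sketch of parallel EMA tracks with clone-then-pick export selection."""

    def __init__(self, model, decays=(0.995, 0.999, 0.9995)):
        self.decays = decays
        self.shadows = {
            d: {k: v.detach().clone() for k, v in model.state_dict().items()}
            for d in decays
        }

    @torch.no_grad()
    def update(self, model):
        for d, shadow in self.shadows.items():
            for k, v in model.state_dict().items():
                if v.is_floating_point():
                    shadow[k].lerp_(v, 1.0 - d)  # shadow = d*shadow + (1-d)*v

    def pick(self, eval_fn):
        """Clone-then-pick: eval_fn scores a deep copy of each track
        (lower is better), so live shadows stay untouched."""
        best = min(self.decays,
                   key=lambda d: eval_fn(copy.deepcopy(self.shadows[d])))
        return self.shadows[best]
```

`eval_fn` here stands in for the PR's proxy eval at export time.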

4. Causal LoRA TTT (lines 1107-1287, 1838-1849)

  • Ported from PR openai/parameter-golf#77 (sliding window + LoRA TTT, record bpb=1.195): BatchedLinearLoRA, BatchedTTTLoRA, eval_val_ttt_lora()
  • Integrated into GPT.forward() (lora= kwarg), Block.forward() (q/v delta fns), CausalSelfAttention.forward() (q/v deltas)
  • Per-token loss (reduction="none") when lora is set; mean loss otherwise
  • Per-doc LoRA reset, score-before-train, causal attention — leakage-free
  • Runs after int6 roundtrip quantization

Hyperparameters (env vars)

| Variable | Default | Description |
| --- | --- | --- |
| KURE_LAMBDA | 0.01 | Kurtosis regularization strength |
| R2_LAMBDA | 0.01 | Range regularization strength |
| TANH_REPARAM | 1 | Enable tanh weight reparameterization |
| LAWA_EMA_DECAY | 0.999 | Default EMA decay |
| TTT_LORA_ENABLED | 1 | Enable LoRA TTT at eval |
| TTT_LORA_RANK | 8 | LoRA rank |
| TTT_LORA_LR | 0.01 | LoRA learning rate |
| TTT_CHUNK_SIZE | 256 | Chunk size for TTT |
| TTT_EVAL_SEQ_LEN | 1024 | Context window for TTT |
| TTT_BATCH_SIZE | 64 | Batch size for TTT eval |
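The PR does not show how these are parsed; a plausible sketch (variable names and defaults from the table above, parsing style assumed):

```python
import os

# Assumed env-var parsing for the hyperparameters listed above.
KURE_LAMBDA = float(os.environ.get("KURE_LAMBDA", "0.01"))
R2_LAMBDA = float(os.environ.get("R2_LAMBDA", "0.01"))
TANH_REPARAM = os.environ.get("TANH_REPARAM", "1") == "1"
LAWA_EMA_DECAY = float(os.environ.get("LAWA_EMA_DECAY", "0.999"))
TTT_LORA_ENABLED = os.environ.get("TTT_LORA_ENABLED", "1") == "1"
TTT_LORA_RANK = int(os.environ.get("TTT_LORA_RANK", "8"))
TTT_LORA_LR = float(os.environ.get("TTT_LORA_LR", "0.01"))
TTT_CHUNK_SIZE = int(os.environ.get("TTT_CHUNK_SIZE", "256"))
TTT_EVAL_SEQ_LEN = int(os.environ.get("TTT_EVAL_SEQ_LEN", "1024"))
TTT_BATCH_SIZE = int(os.environ.get("TTT_BATCH_SIZE", "64"))
```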

Verification

  • python3 -c "import ast; ast.parse(...)" — syntax OK
  • quant_reg_loss, KURE_LAMBDA, R2_LAMBDA — reg wired in
  • tanh_reparam, torch.tanh(w) — tanh in CastedLinear.forward
  • lawa_decays, lawa_averaged — parallel tracks + safe export
  • BatchedTTTLoRA, eval_val_ttt_lora — TTT ported
  • 1e-8 in kurtosis — eps present
  • base_model in reg loss — DDP unwrap correct
  • .weight still exists on CastedLinear (inherits nn.Linear)
  • Run eval → update submission.json

Test plan

🤖 Generated with Claude Code



machdragon and others added 2 commits March 20, 2026 15:22
Built on PR openai#201 (LAWA-EMA + Int6 + Overtone + MLP3x, val_bpb=1.1551).
Adds four improvements targeting quantization fidelity and eval-time adaptation:

- KURE kurtosis regularization + R2 outlier penalty for int6-friendly weights
- Tanh weight reparameterization bounding effective weights to [-1,1]
- Parallel EMA tracks (0.995/0.999/0.9995) with proxy-eval selection
- Causal LoRA TTT (rank 8) ported from PR openai#77 for eval-time adaptation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

@devin-ai-integration devin-ai-integration bot left a comment


✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no bugs or issues to report.


machdragon and others added 9 commits March 20, 2026 15:37
8xH100 launcher with all new env defaults: KURE_LAMBDA=0.01,
R2_LAMBDA=0.01, TANH_REPARAM=1, LAWA_EMA_DECAY=0.999, TTT_LORA_ENABLED=1.
CLI flags for key hyperparameters.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- LAWA/KURE: windreamer cu126_torch2100 + pytorch 2.10.0 cu12.6 devel base
- Drop FA hopper source build; add Dockerfile.modal-fa3 for GHCR reuse
- modal_fa3_image_smoke.py: quick H100 import check; RUNBOOK updates

Made-with: Cursor
- modal_train_volume_check: shared ensure before LAWA/KURE torchrun
- modal_fa3_image_smoke: mount parameter-golf-data, verify FA3 + SP load + shard count
- Dockerfile.modal-fa3: use /opt/conda/bin/python -m pip (PEP 668)
- RUNBOOK: required sync section + smoke/troubleshooting updates

Made-with: Cursor
- Continue after dataset put failure so tokenizer always uploads
- MODAL_SYNC_FORCE=1 for full dataset re-upload; tokenizer uses --force
- RUNBOOK: document re-run + MODAL_SYNC_FORCE

Made-with: Cursor
- Add modal_image_fa3_pytorch.py: /opt/conda/bin/python -m pip per Modal run_commands
- Smoke + LAWA + KURE use shared builder + add_local_python_source(volume_check)
- RUNBOOK: link Modal images guide + troubleshooting for externally-managed-environment

Made-with: Cursor
