Skip to content

Latest commit

 

History

History
581 lines (447 loc) · 31.5 KB

File metadata and controls

581 lines (447 loc) · 31.5 KB

Adding a New Model: The Complete Guide

Target audience: Engineers who need to integrate a new VLA model into FlashRT (e.g. Pi0.6, a fresh open-source VLA).

Time estimate: A single (framework, hardware) combination runs around 800-1200 lines of code, or 1-2 weeks of work, assuming the model's structure is close to Pi0.5 / Pi0. All four combinations (torch/jax × thor/rtx) take roughly 3-4 weeks.

Read in this order (don't skip ahead — each doc assumes the previous):

  1. This doc §0–§1 (you are here) — the repository contract and which files you'll touch
  2. flash_rt/frontends/torch/_template/README.md — the starter package. Open in a separate window before reading further; the rest of this doc references it. The template has 4 stub files (frontend / pipeline / weights_spec / attention) you copy and fill in.
  3. docs/stable_api.md — public API surface and naming conventions you must respect
  4. docs/calibration.md — how FP8 calibration works (read §2 + §10 before writing your _calibrate twin)
  5. docs/kernel_fusion.md — kernel naming + decision tree for which fvk.* to call where (skim §1 + §2; reference the rest as you write pipeline.py)
  6. docs/plugin_model_template.md — only if you're shipping a closed-source model as an external plugin instead of a PR into this repo. Also contains the first-light cosine routing table (Q&A section) — the most useful debugging reference once your model first runs.

Suggested first-week schedule for an ML-infra engineer with the source model already running in PyTorch:

  • Day 1 — read items 1–3, copy the template, list every weight tensor in your checkpoint
  • Day 2–3 — fill in weights_spec.py (declarative, mostly mechanical) and attention.py (~60 lines)
  • Day 4–5 — write pipeline.py (the bulk of the work; encoder_forward first, leave decoder_forward for day 6)
  • Day 6–7 — wire up frontend.py, run first inference, debug cosine using the routing table

0. Repository Contract (hard rules every new model must follow)

Mandatory rules:

1. Every (model, hardware) compute path lives in its own file:
     models/<model>/pipeline_<hw>.py
   — The suffix is required (_thor / _rtx). There is no default
     pipeline.py entry point.
   — No runtime hardware forks such as `if self._has_sm100` or
     `if arch == 'thor'`.
   — Even if two hardware implementations are 90% identical, they must
     still be split. Shared logic is reused through function calls or
     imports, not through in-file branching.

2. Every (model, framework, hardware) IO path = its own frontend:
     frontends/<framework>/<model>_<hw>.py
     class name: <Model><Fw>Frontend<Hw>
   — Example: frontends/torch/pi05_rtx.py contains Pi05TorchFrontendRtx
   — Same rule: split thor and rtx frontends even when they duplicate
     most of their code.

3. hardware/<hw>/shared_primitives.py is a closed set:
   — Only model-agnostic helpers belong here:
       _gpu_* helpers, _measure_scale_gpu,
       siglip_forward        (usable by any model with a SigLIP tower),
       encoder_forward       (usable by any Paligemma-encoder model),
       encoder_forward_calibrate
   — Model-specific forward / decoder functions are not allowed in this
     file. They go into models/<m>/pipeline_<hw>.py.

4. _PIPELINE_MAP is strictly one-to-one:
     ("model", "framework", "hw") -> ("flash_rt.frontends.<fw>.<m>_<hw>",
                                      "ClassName")
   Each tuple points to exactly one file and one class. Multiple tuples
   sharing a class (i.e. runtime forking) is forbidden.

Known historical exception (do NOT copy this pattern):
  pi0fast: frontends/torch/pi0fast.py is a single file with 14 runtime
           `_has_sm100` branches. Deprecated layout — pending split into
           per-hardware files. New models must follow rules 1-4 above.

1. Overview: which files you will touch

Walking through a hypothetical pi06 model (Paligemma backbone) that needs to support both Thor and RTX, under both torch and jax:

flash_rt/
├── hardware/__init__.py                  # 4 new lines in _PIPELINE_MAP
├── hardware/thor/attn_backend.py         # add make_pi06_attention_spec (if shapes change)
├── hardware/rtx/attn_backend.py          # (RTX) same
├── models/pi06/
│   ├── pipeline_thor.py                  # NEW — Thor forward functions
│   └── pipeline_rtx.py                   # NEW — RTX Pi06Pipeline class
├── frontends/torch/
│   ├── _pi06_thor_spec.py                # NEW — Thor torch WEIGHT_SPEC
│   ├── _pi06_rtx_spec.py                 # NEW — RTX torch WEIGHT_SPEC (if dims / iface differ)
│   ├── pi06_thor.py                      # NEW — Thor torch frontend
│   └── pi06_rtx.py                       # NEW — RTX torch frontend
├── frontends/jax/
│   ├── _pi06_thor_spec.py                # NEW
│   ├── _pi06_rtx_spec.py                 # NEW
│   ├── pi06_thor.py                      # NEW
│   └── pi06_rtx.py                       # NEW
├── configs/pi06.yaml                     # metadata
└── tests/test_all_models_precision.py    # add one segment

All four combinations together: ~3500-4500 lines. A single (framework, hardware) combination: ~800-1200 lines (of which ~120 lines are declarative spec).

Reference implementations:

  • pi05 — all four combinations complete: models/pi05/{pipeline_thor.py, pipeline_rtx.py} plus four frontends
  • pi0 — Thor is done, RTX is being refactored in stage 8
  • groot — Thor and RTX are done (jax only on Thor)

1.5. Working from the template

Before reading §2, copy the starter template:

# For a new model called "mymodel" on Thor:
cp -r flash_rt/frontends/torch/_template /tmp/mymodel_thor_work
cd /tmp/mymodel_thor_work
$EDITOR README.md   # 5-min read; explains the file split

Then work file-by-file in this order (each file's docstring tells you exactly what to translate from your source model):

  1. weights_spec.pyflash_rt/frontends/torch/_<mymodel>_thor_spec.py The declarative weight loader. Map each state_dict[...] key from your reference checkpoint to a WEIGHT_SPEC row. Pure mechanical work; ~60-120 lines for a Pi0.5-shape model.

  2. attention.py → append make_<mymodel>_attention_spec() to flash_rt/hardware/thor/attn_backend.py ~60 lines. Declares one add_site() call per distinct attention shape in your model (vision, encoder, decoder, etc.).

  3. pipeline.pyflash_rt/models/<mymodel>/pipeline_thor.py The hard part. Translate your reference model's forward() into a sequence of fvk.* kernel calls. The template's # WHAT YOU TRANSLATE block at the top shows the line-by-line mapping pattern. You'll write two functions per stage: a production forward (FP8, captured into CUDA Graph) and a calibration twin (BF16 + measures activation amax). 200-400 lines per hardware target.

  4. frontend.pyflash_rt/frontends/torch/<mymodel>_thor.py Wires it all together. Owns weight upload, buffer allocation, calibration cache, and CUDA Graph capture. Class name must be <Model>TorchFrontendThor per §0 rule 2. ~400-700 lines.

After all four files compile and your first infer() produces non-NaN output, run cosine vs your PyTorch FP32 reference. Use the first-light cosine routing table in plugin_model_template.md to narrow down where to look — that table maps the cos number you see directly to the most likely root cause.

For RTX, repeat with pipeline_rtx.py / <mymodel>_rtx.py. For JAX, the template covers torch only — copy from frontends/jax/pi05_thor.py for the JAX patterns (Orbax loading, weight cache, etc.).


2. Step-by-step walkthrough

§2 below provides the detailed reference for each step the template guides you through. Use it as a lookup, not a tutorial — you should already have copied the template before reading further.

(1) AttentionSpec — 15-35 lines

File: flash_rt/hardware/thor/attn_backend.py (Thor) or flash_rt/hardware/rtx/attn_backend.py (RTX).

Copy make_pi05_attention_spec and adjust:

def make_pi06_attention_spec(num_views: int, *,
                              enc_total_keys: int, dec_total_keys: int) -> AttentionSpec:
    """Pi0.6: 24L encoder / 24L decoder / H_dim=256 / GQA 8:1."""
    S_sig = num_views * 256
    NH_sig, HD_sig = 16, 72              # SigLIP 1152/16
    NH_enc, HD_enc = 8, 256              # Paligemma 2048/8
    NH_dec, HD_dec = 8, 256              # Action expert

    return AttentionSpec(
        sites=[
            SiteSpec(
                name="siglip", layer_count=27, q_seq_len=S_sig, kv_seq_len=S_sig,
                num_heads=NH_sig, head_dim=HD_sig,
                extra={"kernel": None},  # fmha_strided_full dispatcher
            ),
            SiteSpec(
                name="encoder", layer_count=24, q_seq_len=...,  # Se filled at runtime
                kv_seq_len=enc_total_keys,
                num_heads=NH_enc, head_dim=HD_enc, num_kv_heads=1,
                extra={"kernel": "standard"},
            ),
            SiteSpec(
                name="decoder", layer_count=24, q_seq_len=10,
                kv_seq_len=dec_total_keys,
                num_heads=NH_dec, head_dim=HD_dec, num_kv_heads=1,
                extra={"kernel": "standard"},
            ),
        ],
    )

Allowed values for extra["kernel"] (see backend.py:AttentionBackend for the full table):

kernel value underlying fvk primitive used for
None (siglip only) fmha_strided_full SigLIP 2D vision attention
"standard" attention_qkv_fp16 GQA encoder/decoder, no state padding
"state_masked" attention_qkv_fp16_state_masked Pi0 decoder (the first token is state)
"mha" attention_mha_fp16 GROOT Qwen3 full-MHA plus DiT self/cross

Do not invent new kernel values. If you need a new variant, extend the dispatch branches in hardware/thor/attn_backend.py:ThorFlashAttnBackend.run.


(2) Pipeline forward — 200-400 lines per hardware; the bulk of the hand-written code

Files:

  • flash_rt/models/pi06/pipeline_thor.py (Thor path)
  • flash_rt/models/pi06/pipeline_rtx.py (RTX path)

Each hardware gets its own file, even if the two paths end up looking similar. Genuinely shared code lives in hardware/<hw>/shared_primitives.py (reserved for truly model-agnostic helpers) or is imported between sibling functions.

Recent references to copy from:

Every forward function must obey the pointer-interface contract:

def decoder_forward_pi06(
    gemm: fvk.GemmRunner,
    fvk_module,                    # flash_rt_kernels
    bufs: dict,                    # {'x': ptr, 'xn': ptr, ...}
    weights: dict,                 # {'qw': ptr, 'ow': ptr, 'gu': ptr, 'd': ptr, ...}
    dims: dict,                    # {'S': 10, 'Da': 1024, 'Ha': 4096, 'La': 24, ...}
    scales_dev: int,               # device pointer to fp32 scale array
    *,
    attn=None,                     # AttentionBackend; None = legacy fallback
    stream: int = 0,
):
    """Every argument is a raw pointer or a Python primitive that ctypes can pass
    through — this is what makes the function safe to call repeatedly during
    CUDA Graph capture."""
    ...

Why this interface: CUDA Graph capture requires the same Python function, calling the same sequence of kernels, with the same pointers, on every replay. Tensor objects can be garbage-collected or reallocated between replays; raw .data_ptr() values cannot.

Catalog of kernels available for building forwards: docs/kernel_fusion.md lists all 93 public fvk functions and the legal fusion patterns.

Key rules:

  • All intermediate buffers must be pre-allocated in frontend._load_weights. The forward only reads pointers — no dynamic allocation.
  • Never call .cpu(), .numpy(), torch.empty(), or sync() inside a forward.
  • Attention goes through attn.run(site=..., layer_idx=i, ...). Do not call fvk.attention_qkv_fp16(...) directly.
  • Full rule set: docs/kernel_fusion.md §5 known-failed optimizations

(3) Torch WEIGHT_SPEC — 60-120 lines per (framework, hardware) combo, declarative

Files:

  • flash_rt/frontends/torch/_pi06_thor_spec.py
  • flash_rt/frontends/torch/_pi06_rtx_spec.py (only if dims or weight interface differ)

When the two hardwares share the exact same weight interface (common — both sides read the same safetensors checkpoint), a single spec file can back both frontends via from ._pi06_thor_spec import build_spec. The default is still one spec file per hardware, to avoid a future dim change on one side accidentally regressing the other.

If the backbone is in the Paligemma / Gemma family (very likely):

from flash_rt.executors.weight_loader import Item, LayerBlock, ModelWeightSpec
from flash_rt.executors.torch_weights import (
    FlatCat, FusedGateUp, FusedQKV, Quant, TensorList, ToFp16, tT,
)
from flash_rt.frontends.torch._thor_spec_common import (
    paligemma_encoder_block, paligemma_siglip_block,
)


def _decoder_block():
    dp = "paligemma_with_expert.gemma_expert.model.layers.{i}"
    return LayerBlock(
        prefix_fmt="", num_layers=24, name="decoder",
        items=[
            Item("qkv_w",
                 FusedQKV(q=f"{dp}.self_attn.q_proj.weight",
                          k=f"{dp}.self_attn.k_proj.weight",
                          v=f"{dp}.self_attn.v_proj.weight",
                          norm_fuse=f"{dp}.input_layernorm.weight",
                          interleave_q_heads=8,
                          interleave_k_heads=1),
                 [tT(), Quant()],
                 FlatCat("_dec_qkv_flat"), scale_into="_ae_w_scales"),
            # ... follow the pattern in _pi0_thor_spec.py
        ],
    )


def build_spec() -> ModelWeightSpec:
    return ModelWeightSpec(
        framework="torch",
        blocks=[
            paligemma_siglip_block(),
            paligemma_encoder_block(num_layers=24),
            _decoder_block(),
        ],
    )

If the backbone is a novel architecture (Qwen3, etc.): look at frontends/torch/groot_thor.py::_load_qwen3_weights, which is still a hand-written loop rather than a declarative spec. You will likely need to either:

  • add a new shared block builder to _thor_spec_common.py, or
  • define a new composite source (something like FusedQKV) — see flash_rt/executors/torch_weights.py.

Op order must be byte-identical: compare your spec, op by op, against the legacy hand-written loader — .T.contiguous() vs .t().contiguous(), ToFp16 before or after Quant, exactly when norm_fuse is applied. A single character wrong causes precision regressions.


(4) Frontend — 700-1000 lines per frontend; ~2800-4000 lines across all four

Files:

  • flash_rt/frontends/torch/pi06_thor.py (class: Pi06TorchFrontendThor)
  • flash_rt/frontends/torch/pi06_rtx.py (class: Pi06TorchFrontendRtx)
  • flash_rt/frontends/jax/pi06_thor.py (class: Pi06JaxFrontendThor)
  • flash_rt/frontends/jax/pi06_rtx.py (class: Pi06JaxFrontendRtx)

Class-name rule: <Model><Framework>Frontend<HW> in CamelCase — e.g. Pi05TorchFrontendThor, Pi05TorchFrontendRtx, GrootJaxFrontendThor.

Skeleton: copy the nearest sibling (same framework, same hardware) and edit:

What to change Where Lines
__init__ num_views, checkpoint path a few
_load_norm_stats new norm_stats path (if it moved) 20
_load_weights call _pi06_<hw>_spec.build_spec(); update dim constants; update _sig_weights keys 120
set_prompt tokenizer; time_mlp precompute; call _calibrate and _capture_*_graph 100
_calibrate call encoder_forward_calibrate / decoder_forward_calibrate 150
_capture_*_graph update dims; call models/pi06/pipeline_<hw>.py::{encoder,decoder}_forward_pi06 135
_autotune_enc_ae copy unchanged 50
infer input preprocessing / noise / action decode 80
get_latency_stats copy unchanged 15

Things you must never do:

  • Allocate new tensors inside infer (violates the CUDA Graph contract).
  • Change graph topology inside _calibrate (triggers Myelin tactic drift).
  • Skip .contiguous() (column-major vs row-major layout bugs).
  • Detect required hardware routing at runtime inside a frontend (hasattr(fvk, ...)) and branch on it — this is the pi0fast anti-pattern. New models must ship two independent thor/rtx frontends. Optional fast-path probes are allowed only with an equivalent tested fallback.

(5) JAX-side differences worth calling out

Where the JAX side diverges from torch:

  • The source is OrbaxDictSource(engine_w), where engine_w is the dict[str, ndarray] exported by openpi. The key names are not HF safetensors paths; they follow openpi's internal schema (e.g. vision.layer.{i}.qkv.weight). See _thor_spec_common.py.
  • engine_w typically has QKV and gate_up already fused (openpi does this during export). So the spec does not need FusedQKV / FusedGateUp — plain JaxQuant() is enough.
  • The sink is CudaBufferFlat / CudaBufferAttr plus an explicit cache=... argument (weight caching relies on it).
  • The frontend must set self._cache_blobs = {} before calling WeightLoader.run(...).

Weight cache behavior: the default is weight_cache=True. The first load takes ~30-40s; results are cached to ~/.flash_rt/weights/<hash>_nv<N>.bin, and subsequent loads take ~5s. When modifying a spec you must preserve the cache key schema (sig_wt_fp8.{0..11}, etc.); otherwise the cache format changes and users have to wipe it manually.


(6) _PIPELINE_MAP registration — 4 lines (per hardware × per framework)

File: flash_rt/hardware/__init__.py

_PIPELINE_MAP: dict[...] = {
    ...  # existing entries
    # ── Pi0.6 ──
    ("pi06", "torch", "thor"):
        ("flash_rt.frontends.torch.pi06_thor", "Pi06TorchFrontendThor"),
    ("pi06", "torch", "rtx_sm120"):
        ("flash_rt.frontends.torch.pi06_rtx",  "Pi06TorchFrontendRtx"),
    ("pi06", "jax", "thor"):
        ("flash_rt.frontends.jax.pi06_thor",   "Pi06JaxFrontendThor"),
    ("pi06", "jax", "rtx_sm120"):
        ("flash_rt.frontends.jax.pi06_rtx",    "Pi06JaxFrontendRtx"),
}

One entry, one class. Two entries pointing at the same class is the pi0fast anti-pattern.

Then confirm that config="pi06" is accepted in api.py::load_model — the function validates configs near the top.


(7) Config YAML — 30 lines

File: flash_rt/configs/pi06.yaml

Copy pi05.yaml as a starting point. Fields: num_layers, hidden_dim, num_heads, head_dim, ffn_hidden_dim, num_views, action_horizon, vocab_size, and so on.

This YAML is used only for logging and metadata. Runtime dimensions still come from the constants hard-coded inside the frontend.


(8) Precision test — 80 lines

File: tests/test_all_models_precision.py

  1. Near the top add PI06_SCRIPT = """...""": load the pipe, set a prompt, run 5 warmup iterations, patch the RNG, record 20 latency samples, compute cosine similarity against the pytorch reference.
  2. Add 'pi06': ('Pi0.6', PI06_SCRIPT) to the _configs dict.
  3. Save the pytorch reference to /tmp/pi06_pytorch_ref.npy.

On the 5090 side, additionally add a pi06 segment to your local smoke / cosine / latency benchmark scripts (you'll likely have your own; the public test suite covers smoke and unit tests, latency benchmarks are typically per-team).


3. Validation protocol — run on every commit

# CPU unit tests (seconds)
python tests/test_weight_loader.py            # 16/16
python tests/test_thor_attn_backend.py        # 12/12
python tests/test_thor_groot_attn_backend.py  # 11/11

# 5090 GPU validation (if you added an RTX path)
python examples/quickstart.py --checkpoint <ckpt> --config pi06 \
    --benchmark 200                            # smoke + latency
# Cosine: load the model, run predict() with matched_noise, compare
# against your PyTorch FP32 reference run on the same observation.

# Thor GPU precision sweep (~3-5 minutes)
free -h | head -2   # always check free memory before heavy Thor runs
python tests/test_all_models_precision.py --model pi06

Thresholds:

  • First-time bring-up of a new model: cos ≥ 0.995 (vs pytorch ref), and P50 inside the target latency budget.
  • Models structurally derived from Pi0.5 / Pi0: the "bit-identical" band (~0.9986) indicates the FP8 bytes match exactly.

If cosine falls out of the window:

  1. Don't guess. First check the spec's op order byte-for-byte against the legacy loader.
  2. Use an A/B comparison to rule out Myelin tactic jitter — run 2-3 times back-to-back.
  3. If it really is a regression, revert the commit immediately. Don't try to patch it in a follow-up.

4. Thor-specific pitfalls (must read)

4.1 CUDA Graph non-determinism

Recompiling the same MLIR → Myelin picks a different tactic → ±2ms P50 drift and ~0.001 cos jitter. This is specific to Thor.

Don't:

  • Draw conclusions from a single measurement (always A/B).
  • "Fix" a ±0.001 jitter in a new commit (it's almost certainly tactic drift, not code).
  • Compare latency numbers taken at different times.

Do:

  • Use a timing cache to pin the tactic (though you cannot choose the optimal one directly).
  • Keep a reference timing cache around (see deployment/scripts/l2v2_timing_cache.bin).

4.2 Don't run heavy tasks in parallel

Thor has 122Gi of unified memory. Loading two models concurrently will OOM. Tests must run serially.

4.3 Kernel binding and CMake ownership

Adding a new model should usually reuse existing fvk.* entries. If you do need a new kernel or a new pybind entry, keep the CMake target ownership in lockstep with the binding:

  • An unconditional m.def(...) in csrc/bindings.cpp must link against an implementation that exists in every supported GPU_ARCH build, or it must call an unconditional stub that raises a clear "not built / not supported" error.
  • If the .cu implementation is hardware- or feature-gated, the binding must be gated the same way. Do not let a public flash_rt_kernels symbol depend on a model-specific object target that only builds for one architecture.
  • Shared quantize, layout, RoPE, activation, and utility kernels belong in the main flash_rt_kernels target unless every binding and every caller is gated with the same condition.
  • Model-specific object libraries should contain only model- and hardware-specific kernels.

4.4 Kernel alias, fused helper, and shape contract checks

Many model integrations reuse older fvk.* names as thin wrappers around new shared helpers. Treat those names as ABI: callers may depend on their argument order, tensor shape convention, dtype, rounding, and in-place behavior.

Incident note: #30 introduced this exact class of integration bug when the legacy bias_gelu_bf16(_strict) public names were routed to a shared bias_gelu_inplace_bf16(M, N) helper but the wrapper passed (seq_len * dim, dim) instead of (seq_len, dim). #40 is the reference fix.

Before replacing an old kernel implementation with a shared helper:

  • Write down the public binding contract: pointer arguments, seq_len / dim or M / N meaning, bias shape, output shape, stream argument, and whether the op is in-place.
  • Compare the old kernel indexing with the new helper indexing. A legacy (seq_len, dim) tensor usually maps to helper (M=seq_len, N=dim), not (M=seq_len * dim, N=dim).
  • Preserve suffix semantics. *_strict must keep any explicit intermediate rounding or reference-order behavior; *_bf16 and *_fp16 must not silently swap dtype; *_static must not read or write device-side dynamic scales.
  • Keep optional fast paths optional. hasattr(fvk, "...") is acceptable only when the fallback produces the same result and is still covered by tests. Required hardware routing belongs in _PIPELINE_MAP, frontend selection, or a clear constructor error.
  • Add or update a small binding-level test when the wrapper changes argument mapping. The test does not need a full model checkpoint; a tiny tensor is enough to catch shape expansion, OOB writes, and wrong bias broadcasting.

This check is mandatory for fused replacements such as add_bias + activation, residual + norm, qkv split + rope, and any decoder INT8 / FP8 helper that reuses a legacy public name.

For any CMake / binding ownership change, validate every affected hardware family:

cmake -B build_<arch> -S . -DGPU_ARCH=<arch>
cmake --build build_<arch> -j$(nproc) --target flash_rt_kernels
PYTHONPATH=. python - <<'PY'
from flash_rt import flash_rt_kernels
print(flash_rt_kernels.__file__)
PY
nm -D -u flash_rt/flash_rt_kernels*.so | c++filt | grep 'flash_rt::' || true

When moving a source out of a model-specific object target into flash_rt_kernels, test both sides of the change: the architecture that was missing the symbol must import cleanly, and the architecture that already had the source must not fail with duplicate definitions.

4.5 Don't rebuild the kernel .so

flash_rt/flash_rt_kernels.cpython-312-aarch64-linux-gnu.so (3.6MB) is a production binary. Adding a new model should not trigger a kernel rebuild — every fvk function you need is already in this .so. If you genuinely need a new kernel, that's a separate CUDA development flow, with explicit version backups.


5. Time estimate (realistic)

Assuming the new model is structurally similar to Pi0.5 / Pi0 (Paligemma backbone, flow-matching decoder), for a single (framework, hardware) combination:

Phase Estimate
(1)(6)(7) Skeleton and registration half a day
(2) Pipeline forward — forked from Pi0 with dim-constant edits 1-2 days
(3) WEIGHT_SPEC authoring + byte-diff validation 1 day
(4) Frontend — fork Pi0, edit dims / calibration / graph capture 3-4 days
(8) Tests and debugging 2-3 days
Total per combination ~1-2 weeks

All four combinations (torch/jax × thor/rtx): roughly 3-4 weeks — subsequent frontends reuse a lot of code.

If the backbone is a new architecture (Qwen3-like), add 1-2 more weeks for shared-block extensions, kernel compatibility evaluation, and possibly a new attention variant.


6. Quick checklist

  • (1) New AttentionSpec added to the correct hardware's attn_backend.py; unit tests pass.
  • (2) Pipeline forward functions use the pointer-only interface, do no dynamic allocation, and each hardware has its own pipeline_<hw>.py file.
  • (3) _<model>_<hw>_spec.py smoke-builds via build_spec().
  • (4) Frontend is fully implemented, each (framework, hardware) has its own <m>_<hw>.py file, and all buffers are pre-allocated in _load_weights.
  • No file uses if self._has_sm100 or hasattr(fvk, '...') for required hardware routing.
  • shared_primitives.py has not gained any model-specific functions.
  • Any new csrc/bindings.cpp entry and its .cu implementation have matching CMake guards / target ownership.
  • Any reused or renamed fvk.* alias preserves the old binding shape contract, dtype, strict rounding behavior, and in-place semantics.
  • Any fused helper has been compared with the unfused reference on a tiny tensor for shape, bias broadcasting, and numeric parity.
  • Any hasattr(fvk, "...") branch is an optional fast path with a tested fallback, not required hardware routing.
  • If a kernel source moved between object targets and flash_rt_kernels, every affected GPU_ARCH builds flash_rt_kernels, imports it from Python, and has no missing or duplicate symbols.
  • (6) The four _PIPELINE_MAP entries are one-to-one, with no two rows pointing at the same class.
  • (7) YAML dims match the constants in the code.
  • (8) test_all_models_precision.py passes three consecutive A/B runs.
  • Weight-cache keys remain compatible with legacy (if the JAX spec changed).
  • Commit format: feat(<model>-<framework>-<hw>): ...

7. FAQ

Q: Why are runtime hardware forks like if hasattr(fvk, 'cutlass_fp8_sq') disallowed for required routing? A: Because of the lesson learned from pi0fast. A single file with many branches grows maintenance cost explosively: adding a new hardware means touching 14 spots; adding a new optimization means redoing it on every branch; stack traces no longer tell you which hardware path you were on; and CUDA Graphs capture different kernel sequences per hardware anyway, so if branching cannot actually unify them. Splitting per hardware lets each file focus on exactly one execution path. Optional hasattr(fvk, "...") probes are acceptable only when they select a fast path and the fallback is correct, tested, and documented.

Q: The thor and rtx frontends are 90% identical — wouldn't merging them save a lot of code? A: Short-term, yes. But "shared between two ends" means adding a third hardware requires splitting again, every change risks breaking the other side, and the test matrix becomes N×M. With per-hardware files, adding a new hardware is just adding a new file while the existing files stay stable. The total line count is slightly higher, but maintenance entropy is dramatically lower.

Q: KeyError: ... at load time? A: Some key in your WEIGHT_SPEC doesn't exist in the checkpoint. Inspect the actual safetensors keys:

python -c "from safetensors import safe_open; sf=safe_open('/path/to/model.safetensors', 'pt'); [print(k) for k in list(sf.keys())[:50]]"

Q: After loading, cosine is around 0.5? A: Likely causes: wrong QKV interleave (bad GQA head count), mixing .T.contiguous() with .t().contiguous(), or applying norm_fuse at the wrong point. Start with docs/calibration.md §4 precision journey.

Q: CUDA Graph capture fails? A: Your forward contains a dynamic allocation or a conditional branch that causes capture to take a different kernel path. Details in kernel_fusion.md §6.

Q: JAX loading takes ~40s — too slow? A: That's the expected first-load cost. Keep weight_cache=True (the default); subsequent loads are ~5s. If you changed the WEIGHT_SPEC's cache key, you need rm -rf ~/.flash_rt/weights/ so the cache can be rebuilt.

Q: New model OOMs on Thor? A: Thor has 122Gi of unified memory. Check: (1) free -h shows free memory greater than model size × 1.5; (2) no other pipeline is running concurrently; (3) the previous weight_cache version has been cleaned up.