fix(mha): don't compute FA3 scheduler metadata for non-FA3 backends #276

Closed
qywu wants to merge 1 commit into main from fix/mha-scheduler-metadata-guard


qywu (Collaborator) commented May 27, 2026

Summary

MHAAttnBackend._maybe_compute_scheduler_metadata's docstring says it returns None when the active backend doesn't consume pre-computed scheduler metadata, but the implementation unconditionally calls mha_decode_scheduler_metadata and returns its result. Downstream in forward_decode, the non-None tensor gets passed to the selected kernel — and triton / fa4 / flashinfer reject the unknown scheduler_metadata kwarg:

TypeError: triton_mha_decode_with_kvcache() got an unexpected keyword
argument 'scheduler_metadata'

Guard the call so the behaviour matches the docstring: skip the compute (and therefore the downstream kwarg) when kernel_solution is anything other than "fa3" or None (None means auto-select, which may land on FA3).
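A minimal sketch of the guard described above. Only `kernel_solution`, `_maybe_compute_scheduler_metadata`, `mha_decode_scheduler_metadata`, and the `"fa3"` / `None` values come from this PR. The class shape, the stub body, the `batch_size` parameter, and the `_METADATA_SELECTIONS` name are assumptions for illustration, not the actual code.

```python
from typing import Optional


def mha_decode_scheduler_metadata(batch_size: int) -> list:
    # Hypothetical stand-in: the real routine's signature and its
    # tensor return value are not shown in this PR.
    return [0] * batch_size


# Selections that may end up running FA3. None means auto-select.
# (The constant name is an assumption made for this sketch.)
_METADATA_SELECTIONS = ("fa3", None)


class MHAAttnBackendSketch:
    """Illustrative only; not the real MHAAttnBackend."""

    def __init__(self, kernel_solution: Optional[str] = None):
        self.kernel_solution = kernel_solution

    def _maybe_compute_scheduler_metadata(self, batch_size: int):
        """Return metadata, or None if the backend does not consume it."""
        if self.kernel_solution not in _METADATA_SELECTIONS:
            # Non-FA3 backend: skip the compute entirely.
            return None
        return mha_decode_scheduler_metadata(batch_size)
```

With this shape, a `"triton"`, `"fa4"`, or `"flashinfer"` selection yields `None`, while `"fa3"` and `None` still produce metadata.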

Repro

python -m tokenspeed.cli serve <any-MHA-model> --attention-backend=triton
# the first decode forward pass raises the TypeError shown above

Test plan

  • --attention-backend=triton end-to-end inference on Hopper (Qwen2-1.5B-Instruct) — confirmed locally; produces correct output.
  • --attention-backend=fa3 continues to consume the pre-computed metadata as before.
  • --attention-backend=mha (auto) continues to compute it (auto-select may still land on FA3).
  • --attention-backend=fa4 / flashinfer no longer trip the TypeError.
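The four checks above could be approximated without a GPU by a small stub harness. This is a rough sketch under stated assumptions: the kernel stubs, `maybe_metadata`, and this simplified `forward_decode` are hypothetical stand-ins rather than the real functions, and auto-select (`None`) is assumed to resolve to FA3 here purely for illustration.

```python
def fa3_kernel(q, scheduler_metadata=None):
    # Stub for a kernel that accepts pre-computed scheduler metadata.
    return ("fa3", scheduler_metadata is not None)


def _make_plain_kernel(name):
    # Stub for kernels that, like triton / fa4 / flashinfer in this PR,
    # do not accept a scheduler_metadata keyword.
    def kernel(q):
        return (name, False)
    return kernel


KERNELS = {"fa3": fa3_kernel}
for _name in ("triton", "fa4", "flashinfer"):
    KERNELS[_name] = _make_plain_kernel(_name)


def maybe_metadata(kernel_solution):
    # Guarded compute: only "fa3" or None (auto-select) get metadata.
    if kernel_solution not in ("fa3", None):
        return None
    return object()  # placeholder for the metadata tensor


def forward_decode(kernel_solution, q):
    metadata = maybe_metadata(kernel_solution)
    # Assumption: auto-select resolves to FA3 in this sketch.
    kernel = KERNELS[kernel_solution or "fa3"]
    # Forward the kwarg only when metadata was actually computed.
    kwargs = {} if metadata is None else {"scheduler_metadata": metadata}
    return kernel(q, **kwargs)


results = {
    sel: forward_decode(sel, q=None)
    for sel in ("triton", "fa3", None, "fa4", "flashinfer")
}
```

Each entry in `results` records which stub ran and whether it received metadata; none of the five selections raises a `TypeError` in this harness.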

Related

Encountered while verifying #272 / #273 / #274 / #275 end-to-end on H100.

``_maybe_compute_scheduler_metadata``'s docstring promises ``None`` when
the active backend doesn't consume pre-computed scheduler metadata, but
the implementation unconditionally calls ``mha_decode_scheduler_metadata``
and returns its result. Downstream, ``forward_decode`` then passes the
non-None tensor to whichever kernel is selected — and triton / fa4 /
flashinfer reject the unknown ``scheduler_metadata`` kwarg:

    TypeError: triton_mha_decode_with_kvcache() got an unexpected
    keyword argument 'scheduler_metadata'

Guard the call so the docstring is the truth: skip the compute (and the
downstream kwarg) when ``kernel_solution`` is anything other than
``"fa3"`` or ``None`` (None == auto-select, which may land on FA3).

Found while running --attention-backend=triton end-to-end on H100;
reproduces with any non-FA3 selection on Hopper.

Signed-off-by: Qingyang Wu <willqywu@gmail.com>