
MXFP8 uses logical block size 32, but some ROCm training paths still hard-require block_len == 128 #539

@ramgandikota

Description

Describe the bug

MXFP8 appears to use a logical block size of 32 on ROCm/gfx950, but some ROCm training paths in the installed package still hard-require block_len == 128.

From the installed package:

  • transformer_engine/pytorch/constants.py defines MXFP8_BLOCK_SCALING_SIZE = 32
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py checks MXFP8 quantizability using divisibility by 32
  • but transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast_hip.cpp still has:
    TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
  • and transformer_engine/pytorch/distributed.py has:
    if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128): raise NotImplementedError(...)

From the training side, this makes the current MXFP8 support boundary hard to interpret on MI350X/gfx950.

Steps/Code to reproduce bug

  1. Install ROCm Transformer Engine v2.8_rocm on MI350X/gfx950.
  2. Inspect the installed package:
# transformer_engine/pytorch/constants.py
MXFP8_BLOCK_SCALING_SIZE = 32

# transformer_engine/pytorch/tensor/mxfp8_tensor.py
if inp.shape[-1] % MXFP8_BLOCK_SCALING_SIZE != 0:
    return False
if math.prod(inp.shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE != 0:
    return False

// transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast_hip.cpp
TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");

# transformer_engine/pytorch/distributed.py
if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128):
    raise NotImplementedError("Only 1D blockwise quantization is supported for allgather")
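
As a quick runtime sanity check (it relies only on the constants module quoted above), the installed value and its location can be printed directly:

# Confirm the installed constant and where the installed package lives;
# only the constants module quoted above is assumed to exist.
import transformer_engine.pytorch.constants as te_constants
print(te_constants.MXFP8_BLOCK_SCALING_SIZE)  # prints 32 on this install
print(te_constants.__file__)                  # path to the installed constants.py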

One of the shapes I am testing has hidden_size=2880, which is divisible by 32 but not by 128:

  • 2880 % 32 == 0
  • 2880 % 128 == 64

So it is not clear whether this shape should be expected to work for MXFP8 training on gfx950.
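
For concreteness, here is the quoted mxfp8_tensor.py condition restated as a standalone check (the helper name is mine, not a TE API); it accepts a (4096, 2880) activation under the 32-wide rule, while 2880 does not tile into 128-wide blocks:

import math

# Standalone restatement of the quoted 32-wide condition; the function name
# is illustrative only and not part of the TE API.
def shape_is_mxfp8_quantizable(shape, block=32):  # block = MXFP8_BLOCK_SCALING_SIZE
    return shape[-1] % block == 0 and math.prod(shape[:-1]) % block == 0

print(shape_is_mxfp8_quantizable((4096, 2880)))  # True under the 32-wide rule
print(2880 % 128)                                # 64, so no 128-wide tiling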

Expected behavior

Either:

  1. MXFP8 training paths on gfx950 should consistently support shapes that satisfy the logical MXFP8 32-wide constraints, or
  2. the docs / package behavior should clearly document that some current ROCm training paths still require 128-based packing/tiling even when MXFP8 itself is exposed as a 32-wide logical format.

Environment overview (please complete the following information)

  • Environment location: Docker
  • Method of Transformer Engine install: from source

Exact install command used:

pip uninstall -y transformer-engine transformer_engine 2>/dev/null || true
pip install --no-build-isolation "git+https://github.com/ROCm/TransformerEngine.git@v2.8_rocm"
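
To confirm the installed build matches the version reported below (assuming the package exposes __version__ the way upstream Transformer Engine does):

# Verify the installed build; assumes transformer_engine exposes __version__,
# as upstream TE does.
import transformer_engine
print(transformer_engine.__version__)  # 2.8.0+a365f2de on this install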

Base image used in the custom Docker build:

docker pull docker.io/rocm/primus:v25.11

The final runtime image is a custom image built on top of that base, so there is no single docker run command from upstream to share.

Environment details

  • OS version: Ubuntu 22.04.5 LTS
  • PyTorch version: 2.10.0.dev20251112+rocm7.1
  • Python version: 3.10.12
  • Transformer Engine version: 2.8.0+a365f2de
  • CUDA version: N/A (ROCm environment)
  • CUDNN version: N/A

Device details

  • GPU model: MI350X / gfx950

Additional context

I also see a runtime failure in a higher-level training stack when fp8_recipe=mxfp8 is enabled, but it occurs in a downstream MoE grouped-GEMM path rather than in a direct TE dense-GEMM traceback.

I am filing this issue mainly to clarify the TE-side support boundary, because the installed package currently exposes both:

  • a logical MXFP8 block size of 32, and
  • some ROCm training/helper paths that still require block_len == 128

If helpful, I can provide more environment details or a smaller reproducer.
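
As a starting point, a smaller reproducer would look roughly like the sketch below. It assumes the standard TE PyTorch API names (te.Linear, te.fp8_autocast, recipe.MXFP8BlockScaling); the ROCm fork may differ, and this single-GPU sketch only exercises the dense-GEMM path, not the allgather check in distributed.py.

# Rough single-GPU reproducer sketch, not yet verified on this setup.
# hidden_size=2880 is divisible by 32 but not by 128.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.Linear(2880, 2880, bias=False).cuda()
x = torch.randn(4096, 2880, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe.MXFP8BlockScaling()):
    out = layer(x)
out.sum().backward()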
