Conversation

@zhongbozhu (Collaborator) commented Jan 6, 2026

Description

Fixes #2558.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Move the 4-byte tile scheduler workspace allocation out of the fused NVFP4 cast hot path (no more per-call cudaMallocAsync/cudaFreeAsync)
  • Add a quant_workspace argument to nvte_group_hadamard_transform_cast_fusion so the caller provides the buffer; the PyTorch extension now allocates it once per call

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
@zhongbozhu self-assigned this Jan 6, 2026
greptile-apps bot (Contributor) commented Jan 6, 2026

Greptile Summary

Fixes a critical NVFP4 performance regression by eliminating repeated cudaMallocAsync/cudaFreeAsync calls in the hot path. The launch function group_row_col_rht_gemm_ntt_w_sfc was allocating and freeing a 4-byte tile scheduler workspace on every invocation, causing severe performance degradation (from roughly 570 TFLOPS down to 240-400 TFLOPS).

Changes:

  • Modified group_row_col_rht_gemm_ntt_w_sfc to accept pre-allocated workspace pointer instead of managing allocation internally
  • Updated API signature to pass workspace buffer from caller (nvte_group_hadamard_transform_cast_fusion)
  • PyTorch extension now allocates workspace once per call using at::empty and passes it through the call chain
  • Workspace is reused with cudaMemsetAsync instead of repeated malloc/free operations

This aligns with CUDA performance best practices: keep allocation calls out of frequently invoked launch paths.
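The core idea can be sketched in a few lines (the function names here are hypothetical stand-ins; the real changes live in the files listed below). The per-call stream-ordered allocation is replaced by a caller-owned buffer that is simply reset before each launch:

```cpp
#include <cuda_runtime.h>
#include <cstdint>

// BEFORE (sketch): every call pays for an async alloc/free pair,
// even though the tile-scheduler counter is only 4 bytes.
void launch_fused_cast_old(cudaStream_t stream) {
  void* tile_counter = nullptr;
  cudaMallocAsync(&tile_counter, sizeof(uint32_t), stream);    // hot-path allocation
  cudaMemsetAsync(tile_counter, 0, sizeof(uint32_t), stream);
  // ... launch the grouped RHT + cast kernel, passing tile_counter ...
  cudaFreeAsync(tile_counter, stream);                         // hot-path free
}

// AFTER (sketch): the caller owns the workspace; the launcher only clears it.
void launch_fused_cast_new(void* tile_counter, cudaStream_t stream) {
  cudaMemsetAsync(tile_counter, 0, sizeof(uint32_t), stream);  // cheap reuse
  // ... launch the grouped RHT + cast kernel, passing tile_counter ...
}
```

The 4-byte buffer is presumably an atomic counter for a persistent tile scheduler, which is why it only needs to be zeroed, not reallocated, between launches.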

Confidence Score: 5/5

  • This PR is safe to merge - it fixes a critical performance bug with a well-understood solution
  • The fix is straightforward and follows CUDA best practices by moving memory allocation out of the hot path. The change is minimal, focused, and directly addresses the root cause of the performance regression described in issue #2558 (NVFP4 performance regression in TE main branch). The workspace allocation strategy is sound: a small 4-byte buffer allocated once per call and properly validated before use.
  • No files require special attention

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu | Changed from per-call cudaMallocAsync/cudaFreeAsync to a caller-provided workspace, eliminating the performance bottleneck from repeated allocations |
| transformer_engine/common/include/transformer_engine/hadamard_transform.h | Added a quant_workspace parameter to the function signature to support external workspace allocation |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | Allocates the 4-byte tile scheduler workspace once per call and passes it down to the kernel launch |
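
As a rough illustration of the cast.cpp side (everything except at::empty and the function names quoted above is a hypothetical stand-in), the extension can take the tiny buffer from PyTorch's caching allocator, which avoids a real cudaMalloc on the steady-state path:

```cpp
#include <ATen/ATen.h>
#include <cstdint>

// Hypothetical sketch of the caller side: allocate the tiny scratch buffer
// through the PyTorch caching allocator (cheap after warm-up), then pass its
// pointer down the call chain instead of letting the launcher allocate it.
void* make_quant_workspace(at::Tensor& holder) {
  holder = at::empty({static_cast<int64_t>(sizeof(uint32_t))},
                     at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
  return holder.data_ptr();  // handed down to the NVTE call as the workspace
}
```

The at::Tensor must outlive the NVTE call so the pointer stays valid; in the actual extension the buffer is presumably wrapped as the quant_workspace tensor passed to nvte_group_hadamard_transform_cast_fusion.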

Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller as PyTorch Extension
    participant API as nvte_group_hadamard_transform_cast_fusion
    participant Core as group_hadamard_transform_cast_fusion
    participant Kernel as group_row_col_rht_gemm_ntt_w_sfc
    participant GPU as CUDA Kernel

    Note over Caller: BEFORE (Performance Issue)
    Caller->>API: Call without workspace
    API->>Core: Forward call (no workspace)
    Core->>Kernel: Launch kernel
    Note over Kernel: cudaMallocAsync (slow!)
    Kernel->>GPU: memset workspace to 0
    Kernel->>GPU: Launch CUDA kernel
    Note over GPU: Execute computation
    Note over Kernel: cudaFreeAsync (slow!)
    Kernel-->>Caller: Return

    Note over Caller: AFTER (This Fix)
    Caller->>Caller: Allocate 4-byte workspace once
    Caller->>API: Call with workspace parameter
    API->>Core: Forward call with workspace
    Core->>Core: Extract workspace pointer
    Core->>Kernel: Pass workspace to kernel
    Note over Kernel: memset workspace to 0 (reuse!)
    Kernel->>GPU: Launch CUDA kernel
    Note over GPU: Execute computation
    Kernel-->>Caller: Return (no deallocation needed)
```

@zhongbozhu (Collaborator, Author) commented:

/te-ci L1

ksivaman previously approved these changes Jan 6, 2026

@ksivaman (Member) left a comment

There seem to be a lot of CI failures, ptal @zhongbozhu

@zhongbozhu (Collaborator, Author) commented:

@ksivaman I don't see CI failures on my side; maybe GitHub isn't being updated properly with the CI status?

```cpp
/*args=*/kernel_args,
/*rng_state=*/rng_state, /*sm_count=*/sm_count,
/*rng_state=*/rng_state,
/*tile_scheduler_workspace=*/tile_scheduler_workspace,
```
A reviewer (Member) commented:

I would prefer a more generic workspace name, to be honest. Proper handling of this would also require a function that returns the size of the required workspace.

The author (Collaborator) replied:

At the API level, it's called quant_workspace now.

Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

Comment on lines +1403 to +1404

```cpp
NVTE_CHECK(quant_workspace.data.buffer_size_bytes() >= sizeof(uint32_t),
           "Quantization workspace must be at least 4 bytes.");
```
@timmoon10 (Collaborator) commented Jan 6, 2026

If we wanted to be fancy, we could add an option to query the workspace size, similar to how we do it for LayerNorm. If the workspace is not provided, we set the NVTETensor with the required size. This way the caller doesn't need to know the details of the workspace size.

That said, I think this approach is fine for now.
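
For reference, a minimal sketch of the query pattern being described, with purely illustrative names (the Workspace struct below stands in for the real NVTETensor handling):

```cpp
#include <cstddef>
#include <cstdint>

// Illustrative two-phase workspace query, in the spirit of the LayerNorm-style
// handling mentioned above. Not the real NVTE API.
struct Workspace {
  void* data = nullptr;
  size_t size_bytes = 0;
};

void fused_cast_with_query(Workspace& quant_workspace /*, ...real args elided... */) {
  constexpr size_t kRequired = sizeof(uint32_t);  // tile-scheduler counter
  if (quant_workspace.data == nullptr) {
    quant_workspace.size_bytes = kRequired;  // query phase: just report the size
    return;
  }
  // Execution phase: use the caller-provided buffer.
  // ... zero the buffer and launch the kernel ...
}
```

This keeps the required size an implementation detail of the library; the merged PR instead enforces the 4-byte minimum up front with the NVTE_CHECK shown above.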

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
@timmoon10 (Collaborator) left a comment

LGTM, pending CI

@timmoon10 (Collaborator) commented:

/te-ci

@timmoon10 merged commit de51c96 into NVIDIA:main Jan 7, 2026
40 of 42 checks passed

Labels

bug (Something isn't working), fp4, MoE
