From 5a58f34907124095e4f72163fb0828fb30a04693 Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Thu, 25 Sep 2025 05:12:24 +0000 Subject: [PATCH 01/12] Enable nvfp4 quantization * create NVFP4Quantizer at TE cpp side * modify mxfp8_quantize/cast_mxfp8_2D_kernel for nvfp4 generalization * temporary hijack mxfp8 torch side to call to nvfp4 quantization, will revert --- transformer_engine/common/common.h | 50 +++- .../transformer_engine/transformer_engine.h | 10 + .../common/util/cast_gated_kernels.cuh | 4 +- .../common/util/cast_kernels.cuh | 245 ++++++++++++------ .../common/util/dequantize_kernels.cuh | 4 +- .../common/util/pybind_helper.h | 3 +- transformer_engine/common/utils.cuh | 8 + transformer_engine/pytorch/csrc/common.h | 20 +- .../pytorch/csrc/extensions/cast.cpp | 7 +- transformer_engine/pytorch/csrc/pybind.h | 2 +- transformer_engine/pytorch/csrc/quantizer.cpp | 74 ++++++ 11 files changed, 329 insertions(+), 98 deletions(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 22b448a001..64d9c09ad9 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -189,6 +189,7 @@ struct Tensor { } break; case NVTE_MXFP8_1D_SCALING: + case NVTE_NVFP4_1D_SCALING: case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: if (!has_data() && has_columnwise_data()) { return columnwise_data.shape; @@ -561,6 +562,25 @@ struct TypeInfo { NVTE_ERROR("Invalid type."); \ } +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8FP4ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat4E2M1: { \ + using type = fp4e2m1; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ +} + #define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ @@ -604,19 +624,23 @@ struct TypeInfo { NVTE_ERROR("Invalid type for 16 bit."); \ } -#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \ - switch (SCALE_DIM) { \ - case 1: { \ - constexpr size_t DIM = 1; \ - { __VA_ARGS__ } \ - } break; \ - case 32: { \ - constexpr size_t DIM = 32; \ - { __VA_ARGS__ } \ - } break; \ - default: { \ - NVTE_ERROR("Invalid size of the MX scaling factor."); \ - } \ +#define TRANSFORMER_ENGINE_MXNV_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \ + switch (SCALE_DIM) { \ + case 1: { \ + constexpr size_t DIM = 1; \ + { __VA_ARGS__ } \ + } break; \ + case 16: { \ + constexpr size_t DIM = 16; \ + { __VA_ARGS__ } \ + } break; \ + case 32: { \ + constexpr size_t DIM = 32; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Invalid size of the MX/NV scaling factor."); \ + } \ } #define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index dab4fcfe75..6bda3c5d0d 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -92,6 +92,7 @@ enum NVTEScalingMode { and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD). 
*/ NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4, + NVTE_NVFP4_1D_SCALING = 5, NVTE_INVALID_SCALING = 100 }; @@ -431,6 +432,15 @@ inline bool is_fp8_dtype(const DType t) { */ inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; } +/*! \brief Check if TE datatype is FP8 or FP4 + * + * Return true if TE datatype is FP8 or FP4 + * \param[in] DType TE Datatype of interest + */ +inline bool is_narrow_dtype(const DType t) { + return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2 || t == DType::kFloat4E2M1; +} + /*! \struct TensorWrapper * \brief C++ wrapper for the NVTETensor class. */ diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index e2d9ecc519..0e0c481dda 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -831,9 +831,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const dim3 block_dim(THREADS_PER_CHUNK); const dim3 grid_dim(blocks_X, blocks_Y); - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + TRANSFORMER_ENGINE_MXNV_SCALE_DIM_SWITCH( scale_dim_Y_colwise, SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + TRANSFORMER_ENGINE_MXNV_SCALE_DIM_SWITCH( scale_dim_X_rowwise, SCALE_DIM_X, TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( gated_input.dtype(), IType, diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 610cbf41fa..9638a0595d 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -56,13 +56,13 @@ static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM); template + size_t SCALE_DIM_X, typename ScaleType> __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) - cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + cast_mxnv_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_act_input, const __grid_constant__ CUtensorMap tensor_map_output_rowwise, const __grid_constant__ CUtensorMap tensor_map_output_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + ScaleType *const scales_rowwise, ScaleType *const scales_colwise, const float *noop, float *const dbias_workspace, float *const amax_ptr, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { @@ -75,23 +75,25 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; - constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y; // 2 = 64 / 32 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y; // 64 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // mx:2, nv:4 constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = - SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 64 constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = - SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // mx:2, nv:4 - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; // 
64 = 64 / 1 + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // mx:2, nv:4 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; // 64 constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = - SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // mx:2, nv:4 constexpr size_t SCALES_COLWISE_PER_BLOCK_X = - SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 constexpr size_t THREADS_PER_SCALE_X_ROWWISE = - DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 - constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // mx:2, nv:1 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // mx:2, nv:1 + + constexpr size_t SCALES_COLWISE_PER_BUFFER_Y = MXFP8_BUFFER_DIM_Y / SCALE_DIM_Y; // mx:1, nv:2 const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; @@ -132,14 +134,24 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) } } } + const size_t packing = [&] { + if constexpr (std::is_same_v || std::is_same_v) { + return 1; + } else if constexpr (std::is_same_v) { + return 2; + } else { + static_assert(!std::is_same_v, "Unsupported OType"); + return 0; + } + }(); // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned __shared__ alignas(128) IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; __shared__ alignas(128) IType act_in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; __shared__ alignas(128) - OType out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + OType out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X/packing]; __shared__ alignas(128) - OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y/packing][MXFP8_SHMEM_DIM_X]; constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; @@ -275,10 +287,26 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __builtin_assume(thread_amax >= 0); block_amax = fmaxf(block_amax, thread_amax); - const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); - const e8m0_t biased_exponent = - float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); + float subwarp_amax; + if constexpr (SUBWARP_WIDTH == 1) { + // 1 thread has amax of 16 elements, therefore if width is 1 (block size of 16), we have the amax for the block + subwarp_amax = thread_amax; + } else { + // block size 32 would have SUBWARP_WIDTH=2, need to choose larger betweent the two + subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); + } + + const float x = subwarp_amax * Quantized_Limits::max_norm_rcp; + ScaleType thread_scales_rowwise{}; + if constexpr (std::is_same_v) { + thread_scales_rowwise = float_to_e8m0(x); // power of 2 values + } else if constexpr (std::is_same_v) { + thread_scales_rowwise = ScaleType(x); + } else { + static_assert(!std::is_same_v, "Unsupported ScaleType"); + } + // Only single thread writes the computed scaling factor if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { const int global_scales_offset_Y = @@ -287,16 +315,37 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; 
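+          // Note on the generalized scale handling: for MXFP8 (ScaleType == e8m0_t)
+          // the block scale is a power-of-two biased exponent whose inverse comes
+          // from exp2f_rcp, whereas for NVFP4 (ScaleType == fp8e4m3) amax/6 is
+          // stored directly as an E4M3 value and inverted with __frcp_rn below.
+          // NVFP4 outputs are additionally packed two E2M1 values per byte via
+          // __nv_cvt_float2_to_fp4x2, which is why the shared-memory output
+          // buffers divide one dimension by `packing`.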
const int scale_idx = global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent; + scales_rowwise[scale_idx] = thread_scales_rowwise; } - const float block_scale_inverse = exp2f_rcp(biased_exponent); + const float block_scale_inverse = [&] { + if constexpr (std::is_same_v) { + return exp2f_rcp(thread_scales_rowwise); + } else if constexpr (std::is_same_v) { + return __frcp_rn((float)thread_scales_rowwise); + } else { + static_assert(!std::is_same_v, "Unsupported ScaleType"); + return 0.0f; + } + }(); + if constexpr (std::is_same_v) { + // pack 2-fp4 as a byte + uint8_t *s_mem = reinterpret_cast(&out_rowwise_sh[buff][shmem_offset_y][0]); + const int packed_shmem_offset_x = shmem_offset_x / 2; #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; ++j) { - out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + for (int j = 0; j < ELEMS_PER_THREAD / 2; ++j) { + // Use the CUDA intrinsic to convert and pack float2 -> fp4x2 (uint8_t) + const float2 f2 = {in_compute[2*j]*block_scale_inverse, in_compute[2*j+1] * block_scale_inverse}; + s_mem[packed_shmem_offset_x + j] = __nv_cvt_float2_to_fp4x2(f2, __NV_E2M1, cudaRoundNearest); + } + } else { +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + } + out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); } - out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); } } @@ -305,53 +354,88 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) float in_compute[SCALE_DIM_Y]; float amax = 0; + + for (int j = 0; j < SCALES_COLWISE_PER_BUFFER_Y; ++j) { + // if block size 16, total j = 2 #pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - const size_t row = row_base + i; - const bool row_out_of_bounds = (row >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const size_t row = row_base + j*SCALE_DIM_Y + i; + const bool row_out_of_bounds = (row >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); - float elt = static_cast(in_sh[buff][i][tid_colwise_X]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[buff][i][tid_colwise_X]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - if (!out_of_bounds) { - partial_dbias_colwise[chunk_X] += elt; + float elt = static_cast(in_sh[buff][i][tid_colwise_X]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); } - } - in_compute[i] = elt; - if constexpr (IS_ACT || IS_DACT) { - if (!out_of_bounds) { + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[buff][i][tid_colwise_X]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if (!out_of_bounds) { + partial_dbias_colwise[chunk_X] += elt; + } + } + in_compute[i] = elt; + if constexpr (IS_ACT || IS_DACT) { + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this amax = fmaxf(amax, fabsf(elt)); } - } else { - // If no activation, elt is 0 so we can safely do this - amax = fmaxf(amax, fabsf(elt)); } - } - __builtin_assume(block_amax >= 0); - __builtin_assume(amax >= 0); - block_amax = fmaxf(block_amax, amax); + __builtin_assume(block_amax >= 0); + __builtin_assume(amax >= 0); + block_amax = fmaxf(block_amax, amax); - const e8m0_t biased_exponent = float_to_e8m0(amax * 
Quantized_Limits::max_norm_rcp); + const float x = amax * Quantized_Limits::max_norm_rcp; - const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; + ScaleType thread_scales_colwise{}; + if constexpr (std::is_same_v) { + thread_scales_colwise = float_to_e8m0(x); + } else if constexpr (std::is_same_v) { + thread_scales_colwise = ScaleType(x); + } else { + static_assert(!std::is_same_v, "Unsupported ScaleType"); + } - const float block_scale_inverse = exp2f_rcp(biased_exponent); -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - out_colwise_sh[buff][i][tid_colwise_X] = - static_cast(in_compute[i] * block_scale_inverse); + const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter*SCALES_COLWISE_PER_BUFFER_Y + j; // iter is buffer steps + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = thread_scales_colwise; + + const float block_scale_inverse = [&] { + if constexpr (std::is_same_v) { + return exp2f_rcp(thread_scales_colwise); + } else if constexpr (std::is_same_v) { + return __frcp_rn(x); + } else { + static_assert(!std::is_same_v, "Unsupported ScaleType"); + return 0.0f; + } + }(); + + if constexpr (std::is_same_v) { + uint8_t *s_mem = + reinterpret_cast(&out_colwise_sh[buff][0][tid_colwise_X]); + for (int i = 0; i < SCALE_DIM_Y / 2; ++i) { + // float2 {x, y} + // __nv_cvt_float2_to_fp4x2 packs FP4(y), FP4(x) into a byte + const float2 f2 = {in_compute[2*i] * block_scale_inverse, + in_compute[2*i+1] * block_scale_inverse}; + s_mem[i] = __nv_cvt_float2_to_fp4x2(f2, __NV_E2M1, cudaRoundNearest); + } + } else { + #pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + out_colwise_sh[buff][i][tid_colwise_X] = + static_cast(in_compute[i] * block_scale_inverse); + } + } } } @@ -917,15 +1001,15 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T } template -void mxfp8_quantize(const Tensor &input, const Tensor *act_input, + float (*OP)(float, const ParamOP &), const size_t ScaleDim, typename ScaleType> +void mxnv_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, // TODO (ksivamani) Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { bool use_rowwise_scaling = output->has_data(); bool use_colwise_scaling = output->has_columnwise_data(); checkCuDriverContext(stream); NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(is_narrow_dtype(output->dtype()), "Output must have FP8/FP4 type."); if (use_rowwise_scaling) { NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); @@ -937,8 +1021,8 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, CheckNoopTensor(*noop, "cast_noop"); // TODO: Make more general - const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; - const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; + const size_t scale_dim_X_rowwise = use_rowwise_scaling ? ScaleDim : 1; + const size_t scale_dim_Y_colwise = use_colwise_scaling ? 
ScaleDim : 1; const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); @@ -951,10 +1035,10 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const size_t scale_stride_colwise = use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; - e8m0_t *const scales_rowwise_ptr = - use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; - e8m0_t *const scales_colwise_ptr = - use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + ScaleType *const scales_rowwise_ptr = + use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + ScaleType *const scales_colwise_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; const size_t dbias_rows = blocks_Y; const size_t dbias_cols = cols; @@ -976,13 +1060,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const dim3 block(MXFP8_THREADS_PER_CHUNK); const dim3 grid(blocks_X, blocks_Y); - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + TRANSFORMER_ENGINE_MXNV_SCALE_DIM_SWITCH( scale_dim_Y_colwise, SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + TRANSFORMER_ENGINE_MXNV_SCALE_DIM_SWITCH( scale_dim_X_rowwise, SCALE_DIM_X, TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8FP4ONLY( output->dtype(), OType, alignas(64) CUtensorMap tensor_map_input{}; @@ -1011,8 +1095,8 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, typeToNumBits(output->dtype())); } - cast_mxfp8_2D_kernel<<>>( + cast_mxnv_2D_kernel<<>>( tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, @@ -1144,8 +1228,9 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons break; } case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize(input, act_input, noop, output, dbias, - workspace, stream); + mxnv_quantize(input, act_input, noop, + output, dbias, workspace, + stream); break; } default: @@ -1260,7 +1345,13 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o break; } case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize( + mxnv_quantize( + *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + break; + } + case NVTE_NVFP4_1D_SCALING: { + mxnv_quantize( *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); break; diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index e716065abd..48bc3e6b6f 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -306,9 +306,9 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s const dim3 block(THREADS_PER_CHUNK); const dim3 grid(chunks_X, chunks_Y); - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + TRANSFORMER_ENGINE_MXNV_SCALE_DIM_SWITCH( scale_dim_Y_colwise, SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + TRANSFORMER_ENGINE_MXNV_SCALE_DIM_SWITCH( scale_dim_X_rowwise, SCALE_DIM_X, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( input.dtype(), IType, diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index a1cd85ba2a..d1ff7d6de2 100644 --- 
a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -22,7 +22,8 @@ .value("kFloat16", transformer_engine::DType::kFloat16) \ .value("kBFloat16", transformer_engine::DType::kBFloat16) \ .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ + .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \ pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index e6a54108ed..62bf5618de 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -7,6 +7,7 @@ #ifndef TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ #define TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ +#include #include #include #include @@ -903,6 +904,7 @@ __device__ __forceinline__ void reciprocal(float *value_inv, const float using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; +using fp4e2m1 = __nv_fp4_e2m1; using e8m0_t = uint8_t; constexpr uint32_t FP32_MANTISSA_BITS = 23; @@ -925,6 +927,12 @@ struct Numeric_Traits { static constexpr double maxNorm = 57344; }; +template <> +struct Numeric_Traits { + static constexpr int maxUnbiasedExponent = 2; + static constexpr double maxNorm = 6; +}; + template struct Quantized_Limits { static constexpr int max_unbiased_exponent = Numeric_Traits::maxUnbiasedExponent; diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index d8c08651f2..1271cde1f5 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -207,7 +207,7 @@ class MXFP8Quantizer : public Quantizer { explicit MXFP8Quantizer(const py::handle& quantizer); - NVTEScalingMode get_scaling_mode() const override { return NVTE_MXFP8_1D_SCALING; } + NVTEScalingMode get_scaling_mode() const override { return NVTE_NVFP4_1D_SCALING; } void set_quantization_params(TensorWrapper* tensor) const override; @@ -216,6 +216,24 @@ class MXFP8Quantizer : public Quantizer { std::optional rowwise_data = std::nullopt) const override; }; + +class NVFP4Quantizer final : public MXFP8Quantizer { + public: + explicit NVFP4Quantizer(const py::handle& quantizer) + : MXFP8Quantizer(quantizer) { + } + + NVTEScalingMode get_scaling_mode() const override { + return NVTE_NVFP4_1D_SCALING; + } + + // use MXFP8Quantizer set_quantization_params since it is common param setting operation + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + std::unique_ptr convert_quantizer(py::handle quantizer); std::vector getTensorShape(at::Tensor t); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 4be2a8880e..a60108cd23 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -25,7 +25,12 @@ namespace { std::vector get_tensor_shape(const TensorWrapper &tensor) { const auto &shape = tensor.shape(); - return std::vector(shape.data, shape.data + shape.ndim); + std::vector logical_shape(shape.data, shape.data + shape.ndim); + + if (tensor.scaling_mode() == NVTE_NVFP4_1D_SCALING) { + logical_shape[1] *= 2; // we only support 2D tensor 
for NVFP4 for now, and shape() + } + return logical_shape; } void quantize_impl(const TensorWrapper &input, py::handle &quantizer_py, diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 9fd1ae4de9..301081a15e 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -98,7 +98,7 @@ constexpr std::array custom_types_converters = { std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor, CreateQuantizer), std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor, - CreateQuantizer), + CreateQuantizer), std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer)}; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index dc4d55d2fc..41ae7f9134 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -13,6 +13,7 @@ namespace transformer_engine::pytorch { constexpr size_t MXFP8_BLOCK_SIZE = 32; +constexpr size_t NVFP4_BLOCK_SIZE = 16; Quantizer::Quantizer(const py::handle& quantizer) { if (quantizer.is_none()) { @@ -534,4 +535,77 @@ std::pair MXFP8Quantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } + +std::pair NVFP4Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + + NVTE_CHECK(shape.size() == 2, "NVFP4 currently only support 2D tensor, got ndim=", shape.size(), ")"); + + size_t numel = 1; + for (auto s : shape) { + numel *= s; + } + + auto last_dim = shape.back(); + NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE== 0 && (numel / last_dim) % NVFP4_BLOCK_SIZE == 0, + "NVFP4 requires tensor dims that are divisble by ", NVFP4_BLOCK_SIZE, + " (got shape=", shape, ")"); + + std::vector rowwise_data_shape = {static_cast(shape[0]) , static_cast(shape[1]/2)}; + std::vector columnwise_data_shape = {static_cast(shape[0])/2, static_cast(shape[1]) }; + + TensorWrapper tensor(NVTE_NVFP4_1D_SCALING); + at::TensorOptions opts; + at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv, columnwise_scale_inv; // TODO(pgadzinski) - change + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + + at::Tensor data; + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data = std::move(*rowwise_data); + } else { + data = at::empty(rowwise_data_shape, opts); + } + auto sinv0 = roundup(numel / last_dim, 128); + auto sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4); + rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts); + tensor.set_rowwise_data(data.data_ptr(), this->dtype, std::vector{shape[0], shape[1]/2}); + tensor.set_rowwise_scale_inv( + rowwise_scale_inv.data_ptr(), DType::kFloat8E4M3, + std::vector{static_cast(sinv0), static_cast(sinv1)}); + } + + if (columnwise_usage) { + auto sinv0 = roundup(numel / (last_dim * NVFP4_BLOCK_SIZE), 4); + auto sinv1 = roundup(last_dim, 128); + columnwise_data = at::empty(columnwise_data_shape, opts); + columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts); + + tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, std::vector{shape[0]/2, shape[1]}); + tensor.set_columnwise_scale_inv( + columnwise_scale_inv.data_ptr(), DType::kFloat8E4M3, + std::vector{static_cast(sinv0), static_cast(sinv1)}); + } + this->set_quantization_params(&tensor); + + py::object ret; + if (internal) { + py::handle 
MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); + ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, + "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + } else { + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); + ret = MXFP8TensorClass("shape"_a = rowwise_data_shape, "dtype"_a = GetATenDType(dtype), // TODO; should we use logical shape? or + "rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, + "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + } + + return {std::move(tensor), std::move(ret)}; +} + } // namespace transformer_engine::pytorch From beebe775ffc3db76cfc6781fcd57087c6f7ede5d Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Thu, 25 Sep 2025 21:12:50 +0000 Subject: [PATCH 02/12] Enable nvfp4 dequantization * generalize dequantize_mxfp8_kernel to dequantize_mxnv_kernel --- transformer_engine/common/common.h | 2 + .../common/transformer_engine.cpp | 23 +-- .../common/util/dequantize_kernels.cuh | 137 +++++++++++++----- .../pytorch/csrc/extensions/cast.cpp | 6 +- transformer_engine/pytorch/csrc/pybind.h | 4 +- .../pytorch/csrc/type_converters.cpp | 32 ++++ 6 files changed, 157 insertions(+), 47 deletions(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 64d9c09ad9..0c95686983 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -50,6 +50,8 @@ inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } +inline bool is_nvfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; } + inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", end, " in a vector with ", shape.size(), " entries"); diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 6c395837fb..6d67d0ee84 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -131,23 +131,26 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { void CheckInputTensor(const Tensor &t, const std::string &name) { const DType type = t.dtype(); - if (is_fp8_dtype(type)) { - // FP8 input needs to have scale_inv + if (is_narrow_dtype(type)) { + // FP8/FP4 input needs to have scale_inv if (t.has_data()) { - NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor input ", name, + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8/FP4 scaling factor input ", name, "_scale_inverse must be allocated"); - NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0, - "FP8 scaling factor input ", name, + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || + t.scale_inv.dtype == DType::kFloat8E8M0 || + t.scale_inv.dtype == DType::kFloat8E4M3 , + "FP8/FP4 scaling factor input ", name, "_scale_inverse has invalid dtype " "(expected Float32 or Byte, got ", to_string(t.scale_inv.dtype), ")"); } if (t.has_columnwise_data()) { - NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor input ", name, + 
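+      // NVFP4 tensors carry one FP8-E4M3 scale per block of 16 elements
+      // (rather than MXFP8's E8M0 biased exponents), so kFloat8E4M3 is now an
+      // accepted scale_inv dtype in the checks below.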
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8/FP4 scaling factor input ", name, "_columnwise_scale_inverse must be allocated"); NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 || - t.columnwise_scale_inv.dtype == DType::kFloat8E8M0, - "FP8 scaling factor input ", name, + t.columnwise_scale_inv.dtype == DType::kFloat8E8M0 || + t.columnwise_scale_inv.dtype == DType::kFloat8E4M3 , + "FP8/FP4 scaling factor input ", name, "_columnwise_scale_inverse has invalid dtype " "(expected Float32 or Byte, got ", to_string(t.columnwise_scale_inv.dtype), ")"); @@ -166,8 +169,8 @@ void CheckInputTensor(const Tensor &t, const std::string &name) { void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) { const DType type = t.dtype(); - if (is_fp8_dtype(type)) { - // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax + if (is_narrow_dtype(type)) { + // FP8/FP4 output needs to have scale, scale_inv and (if delayed scaling) amax if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) { NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ", to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")"); diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index 48bc3e6b6f..c56a27d290 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -48,26 +48,36 @@ constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16 static_assert(ITERATIONS >= 1); -template +template __global__ void __launch_bounds__(THREADS_PER_CHUNK) - dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - const e8m0_t *const scales_ptr, const size_t rows, const size_t cols, - const size_t scales_stride) { + dequantize_mxnv_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const ScaleType *const scales_ptr, const size_t rows, const size_t cols, + const size_t scales_stride) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + const size_t packing = [&] { + if constexpr (std::is_same_v || std::is_same_v) { + return 1; + } else if constexpr (std::is_same_v) { + return 2; + } else { + static_assert(!std::is_same_v, "Unsupported OType"); + return 0; + } + }(); constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // mx:4, nv:8 - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // mx:4, nv:8 constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 constexpr size_t THREADS_PER_SCALE_X_ROWWISE = - DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // mx:2, nv:1 const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X / packing; const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; const int scales_rowwise_chunk_offset_X = blockIdx.x * 
SCALES_ROWWISE_PER_CHUNK_X; @@ -84,7 +94,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // const int thread_offset_X_colwise = tid_colwise_X; // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - __shared__ alignas(128) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + __shared__ alignas(128) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X/packing]; __shared__ alignas(128) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM; @@ -165,8 +175,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) : (scales_colwise_chunk_offset_X + tid_colwise_X); const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X; - const e8m0_t biased_exponent = scales_ptr[scale_idx]; - const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + + const float block_scale = [&] { + if constexpr (std::is_same_v) { + const e8m0_t biased_exponent = scales_ptr[scale_idx]; + return exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + } else if constexpr (std::is_same_v) { + return static_cast(scales_ptr[scale_idx]); + } else { + static_assert(!std::is_same_v, "Unsupported ScaleType"); + return 0.0f; + } + }(); if constexpr (USE_ROWWISE_SCALING) { Vec in; @@ -177,15 +197,37 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; ++j) { - out.data.elt[j] = static_cast(block_scale * static_cast(in.data.elt[j])); + for (int j = 0; j < ELEMS_PER_THREAD/packing; ++j) { + if constexpr (std::is_same_v || std::is_same_v) { + out.data.elt[j] = static_cast(block_scale * static_cast(in.data.elt[j])); + } else if constexpr (std::is_same_v) { + // fp4(y), fp4(x) -> fp16.x, fp16.y (no need special handling, just reversing the convention of how we pack) + __half2_raw hfraw2 = __nv_cvt_fp4x2_to_halfraw2(in.data.elt[j].__x, __NV_E2M1); + __half2 h2; + memcpy(&h2, &hfraw2, sizeof(h2)); + out.data.elt[j*2] = static_cast(block_scale * static_cast(h2.x)); + out.data.elt[j*2+1] = static_cast(block_scale * static_cast(h2.y)); + } else { + static_assert(!std::is_same_v, "Unsupported IType"); + } } out.store_to(&out_sh[buff][shmem_offset_y][shmem_offset_x]); } else { #pragma unroll - for (int i = 0; i < BUFFER_DIM_Y; ++i) { - const float elt = static_cast(in_sh[buff][i][tid_colwise_X]); - out_sh[buff][i][tid_colwise_X] = static_cast(block_scale * elt); + for (int i = 0; i < BUFFER_DIM_Y/packing; ++i) { + if constexpr (std::is_same_v || std::is_same_v) { + const float elt = static_cast(in_sh[buff][i][tid_colwise_X]); + out_sh[buff][i][tid_colwise_X] = static_cast(block_scale * elt); + } else if constexpr (std::is_same_v) { + // fp4(y), fp4(x) -> fp16.x, fp16.y (no need special handling, just reversing the convention of how we pack) + __half2_raw hfraw2 = __nv_cvt_fp4x2_to_halfraw2(in_sh[buff][i][tid_colwise_X].__x, __NV_E2M1); + __half2 h2; + memcpy(&h2, &hfraw2, sizeof(h2)); + out_sh[buff][i*2][tid_colwise_X] = static_cast(block_scale * static_cast(h2.x)); + out_sh[buff][i*2+1][tid_colwise_X] = static_cast(block_scale * static_cast(h2.y)); + } else { + static_assert(!std::is_same_v, "Unsupported IType"); + } } } @@ -247,33 +289,49 @@ static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t str ); // NOLINT(*) } -static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +template +static void mxnv_dequantize(const 
Tensor &input, Tensor *output, cudaStream_t stream) { bool use_rowwise_scaling = input.has_data(); bool use_colwise_scaling = input.has_columnwise_data(); checkCuDriverContext(stream); - const auto &input_shape = input.data.shape; + auto input_shape = input.data.shape; NVTE_CHECK(input_shape.size() >= 2, "Input must have at least 2 dimensions."); + if (input.scaling_mode == NVTE_NVFP4_1D_SCALING) { + input_shape[1] *= 2; // tensor.shape() use rowwise data shape. + } if (use_rowwise_scaling) { NVTE_CHECK(input.has_data(), "Cannot dequantize tensor without rowwise data."); - NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); + NVTE_CHECK(is_narrow_dtype(input.data.dtype), "Input must have FP8 type."); } if (use_colwise_scaling) { NVTE_CHECK(input.has_columnwise_data(), "Cannot dequantize tensor without columnwise data."); - NVTE_CHECK(is_fp8_dtype(input.columnwise_data.dtype), "Input must have FP8 type."); + NVTE_CHECK(is_narrow_dtype(input.columnwise_data.dtype), "Input must have FP8 type."); } - NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + NVTE_CHECK(!is_narrow_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input_shape, "Input and output shapes need to match."); // TODO: Make more general - const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; - const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; + const size_t scale_dim_X_rowwise = use_rowwise_scaling ? ScaleDim : 1; + const size_t scale_dim_Y_colwise = use_colwise_scaling ? ScaleDim : 1; + + // t.flat_first_dim()/.flat_first_dim() depends on t.shape() which has a rather odd design + // t.shape returns rowwise data shape when it exists and + // shape of colwise data shape iff only rowwise data does not exist + // when both exist, rowwise data shape gets returned. + // rows, cols are logical dim. + size_t rows, cols; + if (input.has_data()) { + rows = input.flat_first_dim(); + cols = input.scaling_mode == NVTE_NVFP4_1D_SCALING ? input.flat_last_dim() * 2 : input.flat_last_dim(); + } else if (input.has_columnwise_data()) { + rows = input.scaling_mode == NVTE_NVFP4_1D_SCALING ? input.flat_first_dim() * 2 : input.flat_first_dim(); + cols = input.flat_last_dim(); + } - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X); @@ -295,9 +353,9 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s DIVUP(unpadded_scales_X_colwise, scale_tensor_alignment_X_colwise) * scale_tensor_alignment_X_colwise; - const e8m0_t *const scales_ptr = - use_rowwise_scaling ? reinterpret_cast(input.scale_inv.dptr) - : reinterpret_cast(input.columnwise_scale_inv.dptr); + const ScaleType *const scales_ptr = + use_rowwise_scaling ? reinterpret_cast(input.scale_inv.dptr) + : reinterpret_cast(input.columnwise_scale_inv.dptr); const size_t scales_stride = use_rowwise_scaling ? 
scales_X_rowwise : scales_X_colwise; @@ -310,7 +368,7 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s scale_dim_Y_colwise, SCALE_DIM_Y, TRANSFORMER_ENGINE_MXNV_SCALE_DIM_SWITCH( scale_dim_X_rowwise, SCALE_DIM_X, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8FP4ONLY( input.dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( output->dtype(), OType, @@ -323,7 +381,7 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, cols, 0, typeToNumBits(output->dtype())); - dequantize_mxfp8_kernel + dequantize_mxnv_kernel <<>>(tensor_map_input, tensor_map_output, scales_ptr, rows, cols, scales_stride);); // NOLINT(*) ); // NOLINT(*) @@ -340,11 +398,20 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) if (is_tensor_scaling(input.scaling_mode)) { dequantization::fp8_dequantize(input, output, stream); - } else if (is_mxfp_scaling(input.scaling_mode)) { + } else if (is_mxfp_scaling(input.scaling_mode) || is_nvfp_scaling(input.scaling_mode)) { if (is_supported_by_CC_100()) { - dequantization::mxfp8_dequantize(input, output, stream); + switch (input.scaling_mode) { + case NVTE_MXFP8_1D_SCALING: + dequantization::mxnv_dequantize<32, e8m0_t>(input, output, stream); + break; + case NVTE_NVFP4_1D_SCALING: + dequantization::mxnv_dequantize<16, fp8e4m3>(input, output, stream); + break; + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + } } else { - NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + NVTE_ERROR("MXFP8/NVFP4 Dequantization is NOT supported by architectures < 10.0"); } } else { // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index a60108cd23..29657f0290 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -131,7 +131,11 @@ py::object dequantize(const py::handle &input, transformer_engine::DType otype) NoneQuantizer q(none); - const auto &shape = convertShape(input_tensor.shape()); + auto shape = convertShape(input_tensor.shape()); + + if (input_tensor.scaling_mode() == NVTE_NVFP4_1D_SCALING) { + shape[1] *= 2; // assumption: always 2D input, corresponding to rowwise data + } auto [out_tensor, out] = q.create_tensor(shape, otype); diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 301081a15e..079081e96e 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -83,6 +83,8 @@ std::unique_ptr CreateQuantizer(const py::handle quantizer) { TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantization_params); +TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantization_params); + std::unique_ptr CreateMXFP8Params(const py::handle params); TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, @@ -97,7 +99,7 @@ constexpr std::array custom_types_converters = { CreateQuantizer), std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor, CreateQuantizer), - std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor, + std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, 
NVTETensorFromNVFP4Tensor, CreateQuantizer), std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer)}; diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index cb2121a457..7b7e2d433d 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -84,6 +84,38 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) return ret; } +TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) { + auto ret = TensorWrapper(NVTE_NVFP4_1D_SCALING); + + bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + + NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for MXFP8 Tensor."); + + // Row-scaled data + const DType fp8_dtype = tensor.attr("_fp8_dtype").cast(); + if (rowwise_usage) { + const auto &data = tensor.attr("_rowwise_data").cast(); + const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); + ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); + } + + // Column-scaled data + if (columnwise_usage) { + const auto &data = tensor.attr("_columnwise_data").cast(); + const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); + ret.set_columnwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, + getTensorShape(scale_inv)); + } + + // Quantizer state + quantizer->set_quantization_params(&ret); + + return ret; +} + TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantizer) { const DType dtype = tensor.attr("_fp8_dtype").cast(); bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); From 0fc33285ed79e2d7d807bcfa1f6998d2adf64f2b Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Thu, 25 Sep 2025 21:57:41 +0000 Subject: [PATCH 03/12] De-hijack/restoring mxfp8 block scaling recipe * create nvfp4 extension interface but not fully enabled. * mxfp8 trainablility restored. 
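
For reference, a minimal PyTorch sketch of the NVFP4 numerics introduced in this
series (rowwise path). Helper names such as nvfp4_quantize_ref are illustrative
only and not part of TE; the sketch assumes a 2D input whose dims are divisible
by 16 and a PyTorch build with float8 dtypes. Each block of 16 contiguous values
shares one FP8-E4M3 scale (block_amax / 6); elements are divided by the decoded
scale and rounded to the nearest E2M1 value. The CUDA kernels additionally pack
two E2M1 values per byte via __nv_cvt_float2_to_fp4x2 (hence the [rows, cols/2]
rowwise data shape), which this sketch leaves unpacked; rounding ties may differ
from the hardware intrinsic.

    import torch

    FP4_E2M1_MAX = 6.0  # max-normal magnitude of E2M1 (see Numeric_Traits<fp4e2m1>)

    def nvfp4_quantize_ref(x: torch.Tensor, block: int = 16):
        """Rowwise NVFP4 reference: returns (unpacked E2M1 values, decoded E4M3 scales)."""
        rows, cols = x.shape
        xb = x.float().reshape(rows, cols // block, block)
        # One scale per block of 16 contiguous elements, stored as FP8-E4M3.
        scale = xb.abs().amax(dim=-1, keepdim=True) / FP4_E2M1_MAX
        scale = scale.to(torch.float8_e4m3fn).float().clamp_min(2.0**-20)
        # Round each scaled element to the nearest representable E2M1 magnitude.
        values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=x.device)
        q = xb / scale
        idx = (q.abs().unsqueeze(-1) - values).abs().argmin(dim=-1)
        q = values[idx] * torch.sign(q)
        return q.reshape(rows, cols), scale.reshape(rows, cols // block)

    def nvfp4_dequantize_ref(q: torch.Tensor, scale: torch.Tensor, block: int = 16):
        rows, cols = q.shape
        return (q.reshape(rows, cols // block, block) * scale.unsqueeze(-1)).reshape(rows, cols)
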
--- transformer_engine/pytorch/csrc/common.h | 2 +- .../pytorch/csrc/extensions/pybind.cpp | 19 +++++++++++++++++++ transformer_engine/pytorch/csrc/pybind.h | 14 +++++++++++++- 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 1271cde1f5..17a81b2c13 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -207,7 +207,7 @@ class MXFP8Quantizer : public Quantizer { explicit MXFP8Quantizer(const py::handle& quantizer); - NVTEScalingMode get_scaling_mode() const override { return NVTE_NVFP4_1D_SCALING; } + NVTEScalingMode get_scaling_mode() const override { return NVTE_MXFP8_1D_SCALING; } void set_quantization_params(TensorWrapper* tensor) const override; diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 83f5291177..67a8af4ecd 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -32,6 +32,9 @@ PyTypeObject *MXFP8QuantizerClass = nullptr; PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr; PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; +PyTypeObject *NVFP4TensorPythonClass = nullptr; /// TODO Remove +PyTypeObject *NVFP4TensorBasePythonClass = nullptr; +PyTypeObject *NVFP4QuantizerClass = nullptr; void init_float8_extension() { if (Float8TensorPythonClass) return; @@ -65,6 +68,21 @@ void init_mxfp8_extension() { "Internal error: could not initialize pyTorch MXFP8 extension."); } +void init_nvfp4_extension() { + if (NVFP4TensorPythonClass) return; + auto fp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor"); + NVFP4QuantizerClass = + reinterpret_cast(PyObject_GetAttrString(fp4_module.ptr(), "NVFP4Quantizer")); + NVFP4TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(fp4_module.ptr(), "NVFP4Tensor")); + auto fp4_base_module = + py::module_::import("transformer_engine.pytorch.tensor._internal.nvfp4_tensor_base"); // TODO(VS) need to organize python side later + NVFP4TensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp4_base_module.ptr(), "NVFP4TensorBase")); + NVTE_CHECK(NVFP4TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch NVFP4 extension."); +} + void init_float8blockwise_extension() { if (Float8BlockwiseQTensorBasePythonClass) return; auto fp8_module = @@ -89,6 +107,7 @@ void init_float8blockwise_extension() { void init_extension() { init_float8_extension(); init_mxfp8_extension(); +// init_nvfp4_extension(); init_float8blockwise_extension(); } diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 079081e96e..254150f818 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -40,6 +40,9 @@ extern PyTypeObject *MXFP8QuantizerClass; extern PyTypeObject *Float8BlockwiseQTensorPythonClass; extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass; extern PyTypeObject *Float8BlockwiseQuantizerClass; +extern PyTypeObject *NVFP4TensorPythonClass; +extern PyTypeObject *NVFP4TensorBasePythonClass; +extern PyTypeObject *NVFP4QuantizerClass; void init_extension(); @@ -47,6 +50,7 @@ void init_float8_extension(); void init_mxfp8_extension(); +void init_nvfp4_extension(); namespace detail { inline bool IsFloat8Quantizers(PyObject *obj) { return 
Py_TYPE(obj) == Float8QuantizerClass; } @@ -65,6 +69,12 @@ inline bool IsMXFP8Tensor(PyObject *obj) { return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; } +inline bool IsNVFP4Quantizers(PyObject *obj) { return Py_TYPE(obj) == NVFP4QuantizerClass; } + +inline bool IsNVFP4Tensor(PyObject *obj) { + return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorBasePythonClass; +} + inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) { return Py_TYPE(obj) == Float8BlockwiseQuantizerClass; } @@ -99,7 +109,9 @@ constexpr std::array custom_types_converters = { CreateQuantizer), std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor, CreateQuantizer), - std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromNVFP4Tensor, + std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor, + CreateQuantizer), + std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor, CreateQuantizer), std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer)}; From 654efc1173521756cfb9f484920bbae3493bd1e5 Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Thu, 25 Sep 2025 23:09:12 +0000 Subject: [PATCH 04/12] Create NVFP4BlockScaling Recipe * create NVFP4BlockScaling, NVFP4BlockScalingRecipeState class * subclassing: - NVFP4TensorBase(MXFP8TensorBase) - NVFP4Quantizer(MXFP8Quantizer) - NVFP4Tensor(MXFP8Tensor) --- transformer_engine/common/recipe/__init__.py | 31 +- transformer_engine/pytorch/constants.py | 1 + .../pytorch/csrc/extensions/pybind.cpp | 4 +- transformer_engine/pytorch/csrc/quantizer.cpp | 8 +- transformer_engine/pytorch/fp8.py | 38 +++ .../tensor/_internal/mxfp8_tensor_base.py | 7 + .../pytorch/tensor/nvfp4_tensor.py | 264 ++++++++++++++++++ 7 files changed, 346 insertions(+), 7 deletions(-) create mode 100644 transformer_engine/pytorch/tensor/nvfp4_tensor.py diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index fc8d73a136..13e6a00635 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -22,10 +22,12 @@ class _FormatHelper(NamedTuple): class Format(Enum): """ - Supported FP8 formats. + Supported FP8/FP4 formats. Values ------ + E2M1 : + All FP8 tensors are in e2m1 format E4M3 : All FP8 tensors are in e4m3 format E5M2 : @@ -35,6 +37,7 @@ class Format(Enum): FP8 tensors in the backward pass are in e5m2 format """ + E2M1 = _FormatHelper(max_fwd=6, max_bwd=6) E4M3 = _FormatHelper(max_fwd=448, max_bwd=448) E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344) HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) @@ -66,6 +69,10 @@ class Recipe: Base recipe class. """ + def nvfp4(self): + """Whether the given recipe is NVFP4 block scaling.""" + return isinstance(self, NVFP4BlockScaling) + def mxfp8(self): """Whether the given recipe is MXFP8 block scaling.""" return isinstance(self, MXFP8BlockScaling) @@ -232,6 +239,28 @@ def __repr__(self) -> str: ) +@dataclass() +class NVFP4BlockScaling(Recipe): + """ + TODO(VS): documentation + abusing fp8 prefix now as refactoring requires broader changes + """ + margin: int = 0 + fp8_format: Format = Format.E2M1 + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.fp8_format == Format.E2M1, "Only E2M1 training is supported, fwd and bwd in E2M1." 
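+        # NVFP4 block scaling stores tensors as packed E2M1 values (two per byte)
+        # with one FP8-E4M3 scale per block of 16 contiguous elements; both the
+        # forward and backward tensors use the E2M1 format.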
+ + def __repr__(self) -> str: + return ( + f"recipe_type={self.__class__.__name__}, " + f"margin={self.margin}, " + f"format={str(self.fp8_format).split('.')[1]}" + ) + + @dataclass() class MXFP8BlockScaling(Recipe): """ diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index d1470e22e3..935f7578e8 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -89,3 +89,4 @@ dist_group_type = torch.distributed.ProcessGroup MXFP8_BLOCK_SCALING_SIZE = 32 +NVFP4_BLOCK_SCALING_SIZE = 16 diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 67a8af4ecd..f0bbe0277e 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -76,7 +76,7 @@ void init_nvfp4_extension() { NVFP4TensorPythonClass = reinterpret_cast(PyObject_GetAttrString(fp4_module.ptr(), "NVFP4Tensor")); auto fp4_base_module = - py::module_::import("transformer_engine.pytorch.tensor._internal.nvfp4_tensor_base"); // TODO(VS) need to organize python side later + py::module_::import("transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base"); // NVFP4TensorBase is currently lives here, TODO: generalize MXFP8TensorBase NVFP4TensorBasePythonClass = reinterpret_cast( PyObject_GetAttrString(fp4_base_module.ptr(), "NVFP4TensorBase")); NVTE_CHECK(NVFP4TensorPythonClass != nullptr, @@ -107,7 +107,7 @@ void init_float8blockwise_extension() { void init_extension() { init_float8_extension(); init_mxfp8_extension(); -// init_nvfp4_extension(); + init_nvfp4_extension(); init_float8blockwise_extension(); } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 41ae7f9134..ff8c076136 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -591,14 +591,14 @@ std::pair NVFP4Quantizer::create_tensor( py::object ret; if (internal) { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); - ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorBasePythonClass)); + ret = NVFP4TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, "rowwise_scale_inv"_a = rowwise_scale_inv, "columnwise_scale_inv"_a = columnwise_scale_inv, "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); } else { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); - ret = MXFP8TensorClass("shape"_a = rowwise_data_shape, "dtype"_a = GetATenDType(dtype), // TODO; should we use logical shape? or + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorPythonClass)); + ret = NVFP4TensorClass("shape"_a = rowwise_data_shape, "dtype"_a = GetATenDType(dtype), // TODO; should we use logical shape? 
or "rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, "rowwise_scale_inv"_a = rowwise_scale_inv, "columnwise_scale_inv"_a = columnwise_scale_inv, diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 5ef5132c8b..fb6cf3061f 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -19,6 +19,7 @@ DelayedScaling, Format, MXFP8BlockScaling, + NVFP4BlockScaling, Float8CurrentScaling, Float8BlockScaling, ) @@ -87,6 +88,12 @@ def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType return tex.DType.kFloat8E5M2 +def get_fp4_te_dtype(fp4_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: + """Get fp4 data type according to recipe and tensor""" + # TODO(VS) recipe still abusing fp8 prefix, change in future + if fp4_recipe.fp8_format == Format.E2M1: + return tex.DType.kFloat4E2M1 + def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get max representible FP8 value.""" if fp8_recipe.fp8_format == Format.E4M3 or ( @@ -813,6 +820,8 @@ def create( cls = DelayedScalingRecipeState elif recipe.mxfp8(): cls = MXFP8BlockScalingRecipeState + elif recipe.nvfp4(): + cls = NVFP4BlockScalingRecipeState elif recipe.float8_current_scaling(): cls = Float8CurrentScalingRecipeState elif recipe.float8_block_scaling(): @@ -961,6 +970,35 @@ def make_quantizers(self) -> list: return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] +class NVFP4BlockScalingRecipeState(RecipeState): + """TODO(VS) documentation""" + recipe: NVFP4BlockScaling + mode: str + dtype: tex.DType + + def __init__( + self, + recipe: NVFP4BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp4_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + + def make_quantizers(self) -> list: + # TODO; Find better design for this, adding here to avoid circular import. + from .tensor.nvfp4_tensor import NVFP4Quantizer + + return [NVFP4Quantizer(self.dtype) for i in range(self.num_quantizers)] + class Float8BlockScalingRecipeState(RecipeState): """Configuration for Float8BlockScaling quantization. diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index ae00a4d72b..4ec49392b7 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -200,3 +200,10 @@ def update_usage( else: self._columnwise_data = None self._columnwise_scale_inv = None + +class NVFP4TensorBase(MXFP8TensorBase): + pass + # Note: probably good to generalize MXFP8TensorBase as MXNVTensorBase in the future + # but too much change at this point, we just subclass it for NVFP4 + # fp8_dtype: TE_DType attr name is misleading for NVFP4 but all fp8 recipe uses it, + # would require some refactoring to change it. 
\ No newline at end of file diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py new file mode 100644 index 0000000000..ed035e9c43 --- /dev/null +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -0,0 +1,264 @@ +from __future__ import annotations +from collections.abc import Iterable +import math +from typing import Optional, Tuple, Union + +import torch +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from transformer_engine.common.recipe import NVFP4BlockScaling, Recipe +from ..constants import NVFP4_BLOCK_SCALING_SIZE +from ..utils import devices_match, round_up_to_nearest_multiple +from .quantized_tensor import QuantizedTensor, Quantizer +from .mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor + +aten = torch.ops.aten + +class NVFP4Quantizer(MXFP8Quantizer): + # no override on + # __init__ + # calibrate + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + + assert isinstance(dst, NVFP4Tensor), f"Cannot store quantized NVFP4 in {type(dst)} type." + + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + + # Update FP8 dtype + dst._fp8_dtype = self.dtype + + return dst + + def is_quantizable(self, inp: torch.Tensor) -> bool: + """Returns whether or not given inp can be quantized""" + if inp.ndim < 2: + return False + if inp.shape[-1] % NVFP4_BLOCK_SCALING_SIZE != 0: + return False + if math.prod(inp.shape[:-1]) % NVFP4_BLOCK_SCALING_SIZE != 0: + return False + return True + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> NVFP4Tensor: + + # Canonicalize tensor attributes + if device is None: + device = torch.device("cuda") + + assert ( + shape[-1] % NVFP4_BLOCK_SCALING_SIZE == 0 + and math.prod(shape[:-1]) % NVFP4_BLOCK_SCALING_SIZE == 0 + ), ( + f"Incorrect shape {shape} for NVFP4. 
Tensor dims must be divisible by" + f" {NVFP4_BLOCK_SCALING_SIZE}" + ) + + # Allocate FP8 data TODO(VS), do we pack fp4 + data = torch.empty(shape, dtype=torch.uint8, device=device) + scale_inv = torch.zeros( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // NVFP4_BLOCK_SCALING_SIZE, 4), + dtype=torch.uint8, + device=device, + ) + + # Allocate FP8 data transpose if needed TODO(VS), do we pack fp4 + columnwise_data = None + columnwise_scale_inv = None + if self.columnwise_usage: + columnwise_data = torch.empty_like(data) + columnwise_scale_inv = torch.zeros( + round_up_to_nearest_multiple(math.prod(shape[:-1]) // NVFP4_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(shape[-1], 128), + dtype=torch.uint8, + device=device, + ) + + # Construct NVFP4 tensor + return NVFP4Tensor( + shape=shape, + dtype=dtype, + fp8_dtype=self.dtype, + rowwise_data=data, + rowwise_scale_inv=scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + quantizer=self, + requires_grad=requires_grad, + ) + + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return NVFP4BlockScaling + +class NVFP4Tensor(MXFP8Tensor): + # no override on + # quantize_ + # dequantize (low priority) + # clone + # view + # reshape + # contiguous + + def __repr__(self, *, tensor_contents=None): + return f"NVFP4Tensor(fp4_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})" + + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + Quantizer can be used for in-place operations. + """ + if self._quantizer is not None: + return self._quantizer + return NVFP4Quantizer( + fp8_dtype=self._fp8_dtype, + ) + + def detach(self) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + # TODO(ksivamani): Fix the detach bug + return NVFP4Tensor.make_like(self) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # View op + if func == aten.view.default: + tensor = args[0] + data = tensor._rowwise_data + out_data = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + out_shape = out_data.size() + return NVFP4Tensor( + shape=out_shape, + dtype=tensor.dtype, + rowwise_data=out_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=tensor._columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + quantizer=tensor._quantizer, + requires_grad=False, + fp8_dtype=tensor._fp8_dtype, + ) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + dtype: torch.dtype, + shape: torch.shape, + ) -> NVFP4Tensor: + """Build NVFP4Tensor, for use in __reduce__ + __reduce_ex__ assumes object constructor has positional + arguments. 
+ """ + return NVFP4Tensor( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + fp8_dtype=fp8_dtype, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + dtype=dtype, + shape=shape, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling""" + return ( + NVFP4Tensor._make_in_reduce_ex, + ( + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._fp8_dtype, + self.dtype, + self.shape, + ), + ) + + def _get_data(self) -> NVFP4Tensor: + """Get tensor data property""" + return super()._get_data() + + @torch.no_grad() + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + Just takes FP8 data if setting from a MXFP8Tensor. Otherwise + casts to FP8. + """ + + # Tensor device + new_device = tensor.device if tensor.is_cuda else self.device + if not devices_match(new_device, tensor.device): + tensor = tensor.to(device=new_device) + + # Just copy FP8 data if other tensor is MXFP8Tensor + if isinstance(tensor, NVFP4Tensor): + if ( # pylint: disable=too-many-boolean-expressions + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.storage_offset() != tensor.storage_offset() + or self.dtype != tensor.dtype + or self.layout != tensor.layout + or not devices_match(self.device, new_device) + ): + dummy_tensor = torch.Tensor._make_wrapper_subclass( + NVFP4Tensor, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + device=new_device, + ) + # pylint: disable=unnecessary-dunder-call + super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data + self._columnwise_data = tensor._columnwise_data + self._quantizer = tensor._quantizer + self._fp8_dtype = tensor._fp8_dtype + self._rowwise_scale_inv = tensor._rowwise_scale_inv + self._columnwise_scale_inv = tensor._columnwise_scale_inv + return + + # Quantize to FP8 + assert self._quantizer is not None, "Can't quantize without a quantizer" + self._quantizer.internal = False + self.data = self._quantizer.quantize(tensor) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) + + # Cast to FP8 when setting MXFP8Tensor.data + data = property(_get_data, _set_data) \ No newline at end of file From ff705cfb00b4b1e62c27939d865be5704a51969a Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Fri, 26 Sep 2025 00:37:08 +0000 Subject: [PATCH 05/12] Enable cublasLtMatmul for nvfp4 * forward pass functional, backward raise exception due to only TN layout allowed in cublaslt nvfp4 --- .../common/gemm/cublaslt_gemm.cu | 83 +++++++++++++------ transformer_engine/common/swizzle/swizzle.cu | 4 +- transformer_engine/pytorch/module/linear.py | 4 + 3 files changed, 65 insertions(+), 26 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index fa8785dcc7..1f03d8eacf 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -35,6 +35,8 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { return CUDA_R_8F_E4M3; case DType::kFloat8E5M2: return CUDA_R_8F_E5M2; + case DType::kFloat4E2M1: + return CUDA_R_4F_E2M1; default: NVTE_ERROR("Invalid type"); } @@ -119,8 +121,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla 
NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } } - } else if (is_mxfp_scaling(A.scaling_mode)) { - // MXFP8 + } else if (is_mxfp_scaling(A.scaling_mode) || is_nvfp_scaling(A.scaling_mode)) { + // MXFP8/NVFP4 // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). if (is_A_transposed) { @@ -178,7 +180,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } } - } else if (is_mxfp_scaling(B.scaling_mode)) { + } else if (is_mxfp_scaling(B.scaling_mode) || is_nvfp_scaling(B.scaling_mode)) { // MXFP8 // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). @@ -233,6 +235,12 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { + + const bool are_ab_nvfp = is_nvfp_scaling(inputA->scaling_mode) && is_nvfp_scaling(inputB->scaling_mode); + if (are_ab_nvfp) { + NVTE_CHECK(transa == CUBLAS_OP_T && transb == CUBLAS_OP_N, + "NVFP4 of Cublaslt supports only TN layout, i.e. transposed A and non-transposed B"); + } // Tensor dims in row-major order const int A0 = inputA->flat_first_dim(); const int A1 = inputA->flat_last_dim(); @@ -242,8 +250,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, // GEMM dims in column-major order const int m = transa == CUBLAS_OP_T ? A0 : A1; const int n = transb == CUBLAS_OP_T ? B1 : B0; - const int k = transa == CUBLAS_OP_T ? A1 : A0; - NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k, + const auto dim = (transa == CUBLAS_OP_T) ? A1 : A0; + const auto el_per_byte = are_ab_nvfp ? 2 : 1; + const int k = dim * el_per_byte; + + NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == dim, "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1, ")"); const int ldd = m; @@ -269,26 +280,28 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } const bool gelu = pre_gelu_out != nullptr; const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype); + const bool use_fp4 = is_fp4_dtype(param.Atype) || is_fp4_dtype(param.Btype); + NVTE_CHECK(!(use_fp4 && use_fp8), "A and B must use the same precision (both FP4 or both FP8)"); const cudaDataType_t A_type = get_cuda_dtype(param.Atype); const cudaDataType_t B_type = get_cuda_dtype(param.Btype); const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype); const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype); - NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr, - "FP8 input to GEMM requires inverse of scale!"); - NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr, - "FP8 input to GEMM requires inverse of scale!"); + NVTE_CHECK(!is_narrow_dtype(param.Atype) || param.A_scale_inv != nullptr, + "FP8/FP4 input to GEMM requires inverse of scale!"); + NVTE_CHECK(!is_narrow_dtype(param.Btype) || param.B_scale_inv != nullptr, + "FP8/FP4 input to GEMM requires inverse of scale!"); // check consistency of arguments: - // if fp8 is desired, context cannot be null - // fp8 + gelu fusion + fp8 aux is unavailable right now. 
- if (use_fp8 && gelu) { - NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype), - "fp8 Aux output for gemm + gelu fusion not supported!"); + // if fp8/fp4 is desired, context cannot be null + // fp8/fp4 + gelu fusion + fp8 aux is unavailable right now. + if ((use_fp8 || use_fp4) && gelu) { + NVTE_CHECK(!is_narrow_dtype(outputPreGelu->data.dtype), + "fp8/fp4 Aux output for gemm + gelu fusion not supported!"); } - if (is_fp8_dtype(outputD->data.dtype)) { - NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!"); + if (is_narrow_dtype(outputD->data.dtype)) { + NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8/FP4 GEMM output!"); } float one = 1.0; @@ -335,12 +348,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, // set fp8 attributes -- input and output types should already be set to fp8 as appropriate // Note: gelu fusion isn't available right now, and we don't need // amax(D) either (next op is high precision). - if (use_fp8) { - // Split accumulator. - const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1; - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, - &fastAccuMode, sizeof(fastAccuMode))); - + if (use_fp8 || use_fp4) { + if (use_fp8) { + // Fast accumulator mode can be only set for FP8 problems + // Split accumulator. + const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, + &fastAccuMode, sizeof(fastAccuMode))); + } // Scaling factors. #if CUDA_VERSION >= 12080 cublasLtMatmulMatrixScale_t scaling_mode_a; @@ -377,6 +392,26 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, sizeof(dummy_a_vec_stride))); } + } else if ((is_nvfp_scaling(inputA->scaling_mode) && is_nvfp_scaling(inputB->scaling_mode))) { + fp8e4m3 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); + fp8e4m3 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; + // TODO(VS): do we need this? duplicate like mxfp8 + // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. + // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. 
+ if (cublasLtGetVersion() <= 120803) { + const int64_t dummy_a_vec_stride = 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, + sizeof(dummy_a_vec_stride))); + } } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && (inputB->scaling_mode == NVTE_BLOCK_SCALING_1D || @@ -414,8 +449,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b))); #endif - if (is_fp8_dtype(outputD->data.dtype)) { - // Accumulation mode not supported for FP8 output + if (is_narrow_dtype(outputD->data.dtype)) { + // Accumulation mode not supported for FP8/FP4 output C = nullptr; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale))); diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index cea0e5080b..e80909e29c 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -201,7 +201,7 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons namespace transformer_engine { void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { - if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { + if (!is_narrow_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { NVTE_ERROR("Not implemented caling mode " + to_string(input->scaling_mode) + "."); } @@ -216,7 +216,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s auto& scaling_mode = input->scaling_mode; // 1D block scaling, row-wise or colum-wise - if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + if (scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING) { const int m = input->has_data() ? input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1]; const int k = diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 8a7c0ce2d1..fe31facdcd 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -201,6 +201,8 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + if input_quantizer.__class__.__name__ == "NVFP4Quantizer": + inputmat = inputmat.reshape(-1, inputmat.shape[-1]) inputmat = input_quantizer(inputmat) own_quantized_input = True else: @@ -295,6 +297,8 @@ def forward( ub_type=ub_type, extra_output=reduce_scatter_out, ) + if input_quantizer.__class__.__name__ == "NVFP4Quantizer": + gemm_out = gemm_out.reshape(inp.shape[:2] + (-1,)) nvtx_range_pop(f"{nvtx_label}.gemm") # ------------------------------------------------------ # Finished forward GEMM... 
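As a quick illustration of the dimension handling in this patch (a sketch, not part of the patch; the helper name and example shapes are made up): with two E2M1 values packed per byte, the logical k is recovered from the byte-level shared dimension, and only the TN layout is accepted.

    def nvfp4_gemm_dims(a_bytes_shape, b_bytes_shape, transa=True, transb=False, elems_per_byte=2):
        """Derive (m, n, k) the way cublas_gemm does for packed-FP4 inputs (row-major byte shapes)."""
        A0, A1 = a_bytes_shape
        B0, B1 = b_bytes_shape
        # cuBLASLt NVFP4 currently accepts only the TN layout: A transposed, B non-transposed.
        assert transa and not transb, "NVFP4 GEMM requires TN layout"
        m = A0 if transa else A1
        n = B1 if transb else B0
        dim = A1 if transa else A0                 # shared dimension counted in bytes
        assert (B0 if transb else B1) == dim, "incompatible byte-level shapes"
        k = dim * elems_per_byte                   # logical k counts FP4 elements, not bytes
        return m, n, k

    # Example: A stores 4096x2048 FP4 values in 4096x1024 bytes, B stores 2048x2048 in 2048x1024 bytes.
    print(nvfp4_gemm_dims((4096, 1024), (2048, 1024)))   # -> (4096, 2048, 2048)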
From de9e65044918d59b8a5611c35de370b9902bd698 Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Fri, 26 Sep 2025 04:10:00 +0000 Subject: [PATCH 06/12] Fix nvfp4 dequantization --- .../common/util/dequantize_kernels.cuh | 14 ++++++++++---- .../pytorch/tensor/_internal/mxfp8_tensor_base.py | 13 +++++++++++-- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index c56a27d290..7ed1cca5bb 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -45,7 +45,7 @@ constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; // 8 = 128 / 16 constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; // 128 -constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16 +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16 static_assert(ITERATIONS >= 1); template @@ -194,15 +194,21 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int shmem_offset_y = thread_offset_Y; const int shmem_offset_x = thread_offset_X_rowwise; - in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); - + const int shmem_offset_x_in = (tid_rowwise_X/packing) * ELEMS_PER_THREAD; + in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x_in]); + + // only used for packed fp4 + // every 2 threads loading the same 16 elements + // thread 0 unpacked and dequantized to first 16, + // thread 1 unpacked and dequantized to second 16 + const int base_idx = (tid_rowwise_X % packing) * (ELEMS_PER_THREAD/packing); //only used for packed fp4 #pragma unroll for (int j = 0; j < ELEMS_PER_THREAD/packing; ++j) { if constexpr (std::is_same_v || std::is_same_v) { out.data.elt[j] = static_cast(block_scale * static_cast(in.data.elt[j])); } else if constexpr (std::is_same_v) { // fp4(y), fp4(x) -> fp16.x, fp16.y (no need special handling, just reversing the convention of how we pack) - __half2_raw hfraw2 = __nv_cvt_fp4x2_to_halfraw2(in.data.elt[j].__x, __NV_E2M1); + __half2_raw hfraw2 = __nv_cvt_fp4x2_to_halfraw2(in.data.elt[base_idx+j].__x, __NV_E2M1); __half2 h2; memcpy(&h2, &hfraw2, sizeof(h2)); out.data.elt[j*2] = static_cast(block_scale * static_cast(h2.x)); diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index 4ec49392b7..34f9930e3c 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -202,8 +202,17 @@ def update_usage( self._columnwise_scale_inv = None class NVFP4TensorBase(MXFP8TensorBase): - pass # Note: probably good to generalize MXFP8TensorBase as MXNVTensorBase in the future # but too much change at this point, we just subclass it for NVFP4 # fp8_dtype: TE_DType attr name is misleading for NVFP4 but all fp8 recipe uses it, - # would require some refactoring to change it. \ No newline at end of file + # would require some refactoring to change it. 
+ def __repr__(self): + data_rowwise = self.dequantize() + + return ( + "NVFP4TensorBase(" + f"fp8_dtype={self._fp8_dtype}, " + f"rowwise_scaled_data={data_rowwise}" + f"rowwise_scale_inv={self._rowwise_scale_inv}, " + ")" + ) From 33bfd4fb551f026af74ec2063eb83c07cb72e227 Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Fri, 26 Sep 2025 18:44:54 +0000 Subject: [PATCH 07/12] nvfp4 scale calculation changes - clamp to between min subnormal and max normal of e4m3 range --- transformer_engine/common/util/cast_kernels.cuh | 12 ++++++++++-- transformer_engine/common/utils.cuh | 1 + 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 9638a0595d..db8dc5bee3 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -296,12 +296,16 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); } - const float x = subwarp_amax * Quantized_Limits::max_norm_rcp; + float x = subwarp_amax * Quantized_Limits::max_norm_rcp; ScaleType thread_scales_rowwise{}; if constexpr (std::is_same_v) { thread_scales_rowwise = float_to_e8m0(x); // power of 2 values } else if constexpr (std::is_same_v) { + // (amax/E2M1_NORM_MAX).clamp(E4M3_SUBNORM_MIN, E4M3_NORM_MAX) + const float maxNorm = static_cast(Numeric_Traits::maxNorm); + const float minSubNorm = static_cast(Numeric_Traits::minSubNorm); + x = fminf(fmaxf(x, minSubNorm), maxNorm); thread_scales_rowwise = ScaleType(x); } else { static_assert(!std::is_same_v, "Unsupported ScaleType"); @@ -391,12 +395,16 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __builtin_assume(amax >= 0); block_amax = fmaxf(block_amax, amax); - const float x = amax * Quantized_Limits::max_norm_rcp; + float x = amax * Quantized_Limits::max_norm_rcp; ScaleType thread_scales_colwise{}; if constexpr (std::is_same_v) { thread_scales_colwise = float_to_e8m0(x); } else if constexpr (std::is_same_v) { + // (amax/E2M1_NORM_MAX).clamp(E4M3_SUBNORM_MIN, E4M3_NORM_MAX) + const float maxNorm = static_cast(Numeric_Traits::maxNorm); + const float minSubNorm = static_cast(Numeric_Traits::minSubNorm); + x = fminf(fmaxf(x, minSubNorm), maxNorm); thread_scales_colwise = ScaleType(x); } else { static_assert(!std::is_same_v, "Unsupported ScaleType"); diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 62bf5618de..e552abaa17 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -919,6 +919,7 @@ template <> struct Numeric_Traits { static constexpr int maxUnbiasedExponent = 8; static constexpr double maxNorm = 448; + static constexpr double minSubNorm = 1.0 / 512.0; //2e-9 }; template <> From 393b7963f666df7cdb7476a08439b37d6000f37f Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Sat, 27 Sep 2025 18:09:29 +0000 Subject: [PATCH 08/12] Enable swizzling of NVFP4 scaling factor --- transformer_engine/pytorch/csrc/util.cpp | 52 ++++++++++++------------ 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index a878345ffc..fdd4a7b4bc 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -11,14 +11,21 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) { using namespace transformer_engine::pytorch; + using 
DType = transformer_engine::DType; - if (input.scaling_mode() == NVTE_INVALID_SCALING) { - NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { - return std::nullopt; + switch (input.scaling_mode()) { + case NVTE_MXFP8_1D_SCALING: + case NVTE_NVFP4_1D_SCALING: + break; + case NVTE_INVALID_SCALING: + NVTE_ERROR("Invalid scaling mode for swizzle."); + break; + default: + return std::nullopt; } - NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + const size_t el_bit_size = input.element_size_bits(); + NVTE_CHECK((el_bit_size == 8 || el_bit_size == 4), "8/4-bit input type required for swizzling scaling factors."); NVTEBasicTensor scale_inv; if (rowwise) { @@ -42,35 +49,30 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap // Reconstruct input only to avoid swizzling both directions if not needed. // Use any 8 bit type, it's irrelevant. - transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); - transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + transformer_engine::TensorWrapper input_cu(input.scaling_mode()); + transformer_engine::TensorWrapper output_cu(input.scaling_mode()); + const DType scale_type = input.scaling_mode() == NVTE_NVFP4_1D_SCALING ? DType::kFloat8E4M3 : DType::kFloat8E8M0; + const DType data_type = input.scaling_mode() == NVTE_NVFP4_1D_SCALING ? DType::kFloat4E2M1 : DType::kFloat8E4M3; + if (rowwise) { - input_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); - input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); - output_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input_cu.set_rowwise_data(input.dptr(), data_type, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_type, scale_inv_shape); + output_cu.set_rowwise_data(input.dptr(), data_type, input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_type, scale_inv_shape); } else { - input_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); - output_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, - transformer_engine::DType::kFloat8E8M0, scale_inv_shape); + input_cu.set_columnwise_data(input.columnwise_dptr(), data_type, input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_type, scale_inv_shape); + output_cu.set_columnwise_data(input.columnwise_dptr(), data_type, input_shape); + output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_type, scale_inv_shape); } // Launch kernel nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_type, scale_inv_shape); } else { - input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_type, 
scale_inv_shape);
  }

  return swizzled_scale_inv;

From f39a4922893683ed644bd95214e9391aff75d0ab Mon Sep 17 00:00:00 2001
From: Vui Seng Chua
Date: Sat, 27 Sep 2025 18:17:08 +0000
Subject: [PATCH 09/12] Create new NVFP4FwdMXFP8BwdScaling recipe

* motivation: due to current TN-only layout for cublaslt NVFP4 matmul
* this recipe uses TN NVFP4 forward, and NN/NT MXFP8 backward, avoiding
  tensor relayout which can be costly to materialize for large models
* piggyback NVFP4Quantizer for shadow MXFP8Quantizer needed for backward pass
---
 transformer_engine/common/recipe/__init__.py | 25 +++++++++++
 .../pytorch/csrc/extensions/gemm.cpp         |  2 +-
 transformer_engine/pytorch/fp8.py            | 41 +++++++++++++++++++
 transformer_engine/pytorch/module/linear.py  |  7 +++-
 .../pytorch/tensor/nvfp4_tensor.py           | 23 ++++++++++-
 5 files changed, 94 insertions(+), 4 deletions(-)

diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py
index 13e6a00635..feb6ac7d42 100644
--- a/transformer_engine/common/recipe/__init__.py
+++ b/transformer_engine/common/recipe/__init__.py
@@ -77,6 +77,10 @@ def mxfp8(self):
         """Whether the given recipe is MXFP8 block scaling."""
         return isinstance(self, MXFP8BlockScaling)
 
+    def nvfp4_fwd_mxfp8_bwd(self):
+        """Whether the given recipe is NVFP4 forward and MXFP8 backward scaling."""
+        return isinstance(self, NVFP4FwdMXFP8BwdScaling)
+
     def delayed(self):
         """Whether the given recipe is delayed scaling."""
         return isinstance(self, DelayedScaling)
@@ -252,6 +256,7 @@ class NVFP4BlockScaling(Recipe):
 
     def __post_init__(self) -> None:
         assert self.fp8_format == Format.E2M1, "Only E2M1 training is supported, fwd and bwd in E2M1."
+        raise NotImplementedError("NVFP4BlockScaling is not fully implemented yet, limited by TN layout cublaslt gemm support. Please consider NVFP4FwdMXFP8BwdScaling()")
 
     def __repr__(self) -> str:
         return (
@@ -260,6 +265,26 @@ def __repr__(self) -> str:
             f"format={str(self.fp8_format).split('.')[1]}"
         )
 
+@dataclass()
+class NVFP4FwdMXFP8BwdScaling(Recipe):
+    """
+    TODO(VS): documentation
+    abusing fp8 prefix now as refactoring requires broader changes
+    """
+    margin: int = 0
+    fp8_format: Format = Format.E2M1
+    fp8_dpa: bool = False
+    fp8_mha: bool = False
+
+    def __post_init__(self) -> None:
+        assert self.fp8_format == Format.E2M1, "Only E2M1 training is supported, fwd and bwd in E2M1."
+ + def __repr__(self) -> str: + return ( + f"recipe_type={self.__class__.__name__}, " + f"margin={self.margin}, " + f"format={str(self.fp8_format).split('.')[1]}" + ) @dataclass() class MXFP8BlockScaling(Recipe): diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 99bb4e69fd..56099eb13b 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -37,7 +37,7 @@ namespace transformer_engine::pytorch { namespace detail { bool is_low_precision(const DType type) { - return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; + return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat4E2M1; } std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool transa, diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index fb6cf3061f..79a16a3884 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -20,6 +20,7 @@ Format, MXFP8BlockScaling, NVFP4BlockScaling, + NVFP4FwdMXFP8BwdScaling, Float8CurrentScaling, Float8BlockScaling, ) @@ -822,6 +823,8 @@ def create( cls = MXFP8BlockScalingRecipeState elif recipe.nvfp4(): cls = NVFP4BlockScalingRecipeState + elif recipe.nvfp4_fwd_mxfp8_bwd(): + cls = NVFP4FwdMXFP8BwdRecipeState elif recipe.float8_current_scaling(): cls = Float8CurrentScalingRecipeState elif recipe.float8_block_scaling(): @@ -999,6 +1002,44 @@ def make_quantizers(self) -> list: return [NVFP4Quantizer(self.dtype) for i in range(self.num_quantizers)] +class NVFP4FwdMXFP8BwdRecipeState(RecipeState): + """TODO(VS) documentation""" + recipe: NVFP4FwdMXFP8BwdScaling + mode: str + dtype: tex.DType + + def __init__( + self, + recipe: NVFP4FwdMXFP8BwdScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp4_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + + def make_quantizers(self) -> list: + # TODO; Find better design for this, adding here to avoid circular import. + from .tensor.nvfp4_tensor import NVFP4Quantizer, MXFP8Quantizer + if self.mode == "forward" and self.num_quantizers == 3: + return [ + NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=True), # input + NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=True), # weight + NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=False), # output (unused) + ] + elif self.mode == "backward" and self.num_quantizers == 2: + # grad_output and grad_input (unused) + return [MXFP8Quantizer(tex.DType.kFloat8E4M3) for i in range(self.num_quantizers)] + else: + raise NotImplementedError("Unexpected entry, pls debug") + class Float8BlockScalingRecipeState(RecipeState): """Configuration for Float8BlockScaling quantization. 
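The recipe state above leans on the forward NVFP4 quantizer carrying a shadow MXFP8 copy that the backward GEMMs read through bw_tensor. A minimal stand-alone sketch of that pattern (illustration only, with hypothetical quantize callables; not TE code):

    class HybridQuantized:
        """Forward NVFP4 payload plus an optional MXFP8 shadow for the backward pass."""
        def __init__(self, fwd_tensor, bw_tensor=None):
            self.fwd_tensor = fwd_tensor   # consumed by the TN NVFP4 forward GEMM
            self.bw_tensor = bw_tensor     # consumed by the dgrad/wgrad MXFP8 GEMMs

    def hybrid_quantize(x, quantize_nvfp4, quantize_mxfp8, needs_backward=True):
        # One quantization per precision; no NVFP4 transpose/relayout is materialized.
        out = HybridQuantized(quantize_nvfp4(x))
        if needs_backward:
            out.bw_tensor = quantize_mxfp8(x)
        return out

    def pick_backward_operand(t):
        # Mirrors the "x.bw_tensor if hasattr(x, 'bw_tensor') else x" pattern used in linear.py below.
        return t.bw_tensor if getattr(t, "bw_tensor", None) is not None else t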
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index fe31facdcd..56208e50d3 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -624,7 +624,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") gemm_out, *_, reduce_scatter_out = general_gemm( - weight_fp8, + weight_fp8.bw_tensor if hasattr(weight_fp8, "bw_tensor") else weight_fp8, grad_output, get_workspace(), layout="NN", @@ -792,7 +792,10 @@ def wgrad_gemm( else: # Call wgrad GEMM now - wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output) + wgrad, grad_bias_ = wgrad_gemm( + inputmat_total.bw_tensor if hasattr(inputmat_total, "bw_tensor") else inputmat_total, + grad_output + ) # Update grad bias if needed if grad_bias is None: diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index ed035e9c43..1ec6571cf0 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -17,8 +17,29 @@ class NVFP4Quantizer(MXFP8Quantizer): # no override on - # __init__ # calibrate + def __init__( + self, + fp8_dtype: TE_DType = TE_DType.kFloat4E2M1, # default to NVFP4 + *, + rowwise: bool = True, + columnwise: bool = True, + mxfp8_bw_quantize: bool = False, + ) -> None: + super().__init__(fp8_dtype, rowwise=rowwise, columnwise=columnwise) + self.dtype = fp8_dtype + self.mxfp8_bw_quantize = mxfp8_bw_quantize + if mxfp8_bw_quantize: + self.mxfp8_quantizer = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3, rowwise=rowwise, columnwise=columnwise) + + def quantize(self, tensor, *, out = None, dtype = None): + nvfp4_quantized = super().quantize(tensor, out=out, dtype=dtype) + if self.mxfp8_bw_quantize: + # Use MXFP8 quantizer for backward pass quantization + nvfp4_quantized.bw_quantizer = self.mxfp8_quantizer + nvfp4_quantized.bw_tensor = self.mxfp8_quantizer(tensor) + return nvfp4_quantized + def update_quantized( self, src: torch.Tensor, From d4aed4de427484e1b0b84471e991b36a33287dae Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Sat, 27 Sep 2025 20:20:38 +0000 Subject: [PATCH 10/12] Improve efficiency of NVFP4FwdMXFP8BwdScaling Recipe * remove redundant quantization, step elapse improves --- transformer_engine/pytorch/fp8.py | 6 +++--- transformer_engine/pytorch/module/linear.py | 15 +++++++++------ transformer_engine/pytorch/tensor/nvfp4_tensor.py | 3 ++- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 79a16a3884..fa826c18c4 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -1030,9 +1030,9 @@ def make_quantizers(self) -> list: from .tensor.nvfp4_tensor import NVFP4Quantizer, MXFP8Quantizer if self.mode == "forward" and self.num_quantizers == 3: return [ - NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=True), # input - NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=True), # weight - NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=False), # output (unused) + NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=True, rowwise=True, columnwise=False), # input + NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=True, rowwise=True, columnwise=False), # weight + NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=False, rowwise=False, columnwise=True), # output (unused) ] elif self.mode == "backward" and self.num_quantizers == 2: 
# grad_output and grad_input (unused) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 56208e50d3..db18630a58 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -200,7 +200,8 @@ def forward( else: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + input_quantizer.set_usage( + rowwise=True, columnwise=False if hasattr(input_quantizer, "mxfp8_bw_quantize") else backward_needs_input) if input_quantizer.__class__.__name__ == "NVFP4Quantizer": inputmat = inputmat.reshape(-1, inputmat.shape[-1]) inputmat = input_quantizer(inputmat) @@ -226,7 +227,8 @@ def forward( is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase() ) - weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + weight_quantizer.set_usage( + rowwise=True, columnwise=False if hasattr(weight_quantizer, "mxfp8_bw_quantize") else columnwise_usage) # Get quantized weight update_workspace = is_first_microbatch is None or is_first_microbatch @@ -349,13 +351,14 @@ def forward( isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase)) or not ctx.backward_input_needs_gather ): - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + inputmat.update_usage( + rowwise_usage=False, columnwise_usage=False if hasattr(weight_quantizer, "mxfp8_bw_quantize") else True) saved_inputmat = inputmat # Weight with column-wise usage is needed for dgrad GEMM. if inp.requires_grad: if isinstance(weightmat, QuantizedTensorBase): - weightmat.update_usage(columnwise_usage=True) + weightmat.update_usage(columnwise_usage=False if hasattr(weight_quantizer, "mxfp8_bw_quantize") else True) if cpu_offloading and saved_inputmat is not None: mark_activation_offload(saved_inputmat) @@ -597,7 +600,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(grad_output, QuantizedTensorBase): grad_output.update_usage(rowwise_usage=True) if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase): - weight_fp8.update_usage(columnwise_usage=True) + weight_fp8.update_usage(columnwise_usage=False if hasattr(ctx.weight_quantizer, "mxfp8_bw_quantize") else True) # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD @@ -680,7 +683,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total_work = None if ctx.fp8 or ctx.debug: if isinstance(inputmat_total, QuantizedTensorBase): - inputmat_total.update_usage(columnwise_usage=True) + inputmat_total.update_usage(columnwise_usage=False if hasattr(ctx.weight_quantizer, "mxfp8_bw_quantize") else True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) inputmat_total = ctx.input_quantizer(inputmat_total) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 1ec6571cf0..356c615b33 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -30,7 +30,8 @@ def __init__( self.dtype = fp8_dtype self.mxfp8_bw_quantize = mxfp8_bw_quantize if mxfp8_bw_quantize: - self.mxfp8_quantizer = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3, rowwise=rowwise, columnwise=columnwise) + # tensor are typically quantized in different axis between forward and backward + 
self.mxfp8_quantizer = MXFP8Quantizer(fp8_dtype=TE_DType.kFloat8E4M3, rowwise=not rowwise, columnwise=not columnwise) def quantize(self, tensor, *, out = None, dtype = None): nvfp4_quantized = super().quantize(tensor, out=out, dtype=dtype) From 010e04279aaff051d57e7cb0fdffef10aeae98fd Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Tue, 14 Oct 2025 17:57:15 +0000 Subject: [PATCH 11/12] Enable LayerNormMLP and LayerNormLinear for NVFP4FwdMXFP8BwdScaling recipe --- transformer_engine/pytorch/fp8.py | 18 ++++---- .../pytorch/module/layernorm_linear.py | 25 +++++++---- .../pytorch/module/layernorm_mlp.py | 44 ++++++++++++------- 3 files changed, 55 insertions(+), 32 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index fa826c18c4..12f13ce353 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -1028,17 +1028,19 @@ def __init__( def make_quantizers(self) -> list: # TODO; Find better design for this, adding here to avoid circular import. from .tensor.nvfp4_tensor import NVFP4Quantizer, MXFP8Quantizer - if self.mode == "forward" and self.num_quantizers == 3: - return [ - NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=True, rowwise=True, columnwise=False), # input - NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=True, rowwise=True, columnwise=False), # weight - NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=False, rowwise=False, columnwise=True), # output (unused) - ] - elif self.mode == "backward" and self.num_quantizers == 2: + if self.mode == "forward" and self.num_quantizers % 3 == 0: + quantizers = [] + for _ in range(self.num_quantizers//3): + quantizers.append(NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=True, rowwise=True, columnwise=False)) # input + quantizers.append(NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=True, rowwise=True, columnwise=False)) # weight + quantizers.append(NVFP4Quantizer(self.dtype, mxfp8_bw_quantize=False, rowwise=False, columnwise=True)) + return quantizers + elif self.mode == "backward": # grad_output and grad_input (unused) return [MXFP8Quantizer(tex.DType.kFloat8E4M3) for i in range(self.num_quantizers)] else: - raise NotImplementedError("Unexpected entry, pls debug") + raise NotImplementedError("pls debug.") + class Float8BlockScalingRecipeState(RecipeState): """Configuration for Float8BlockScaling quantization. 
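The make_quantizers change above hands out forward quantizers in (input, weight, output) triples, one triple per GEMM, which is what lets a two-GEMM module such as LayerNormMLP request six of them. A plain-Python restatement of that grouping (illustration only; PlaceholderQuantizer stands in for NVFP4Quantizer/MXFP8Quantizer):

    from dataclasses import dataclass

    @dataclass
    class PlaceholderQuantizer:
        role: str
        mxfp8_bw_quantize: bool
        rowwise: bool
        columnwise: bool

    def make_forward_quantizers(num_quantizers):
        assert num_quantizers % 3 == 0, "forward quantizers come in (input, weight, output) triples"
        quantizers = []
        for _ in range(num_quantizers // 3):
            quantizers.append(PlaceholderQuantizer("input", True, rowwise=True, columnwise=False))
            quantizers.append(PlaceholderQuantizer("weight", True, rowwise=True, columnwise=False))
            quantizers.append(PlaceholderQuantizer("output", False, rowwise=False, columnwise=True))
        return quantizers

    print([q.role for q in make_forward_quantizers(6)])
    # ['input', 'weight', 'output', 'input', 'weight', 'output']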
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index b99952ad2a..ddbd85b2e1 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -177,7 +177,8 @@ def forward( if fp8: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + input_quantizer.set_usage( + rowwise=True, columnwise=False if hasattr(input_quantizer, "mxfp8_bw_quantize") else backward_needs_input) if with_input_all_gather and isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): @@ -191,6 +192,7 @@ def forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not hasattr(input_quantizer, "mxfp8_bw_quantize") # Disable quantized norm for NVFP4 forward and MXFP8 backward due to serial (two) quantization design at torch level ) # Apply normalization @@ -271,7 +273,8 @@ def forward( # Configure quantizer if weight_quantizer is not None: - weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + weight_quantizer.set_usage( + rowwise=True, columnwise=False if hasattr(weight_quantizer, "mxfp8_bw_quantize") else is_grad_enabled) # Get quantized weight update_workspace = is_first_microbatch is None or is_first_microbatch @@ -397,11 +400,12 @@ def forward( isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase)) or not ctx.ln_out_needs_gather ): - ln_out.update_usage(rowwise_usage=False) + ln_out.update_usage( + rowwise_usage=False, columnwise_usage=False if hasattr(weight_quantizer, "mxfp8_bw_quantize") else True) # Weight with column-wise usage is needed for dgrad GEMM. 
if isinstance(weightmat, QuantizedTensorBase): - weightmat.update_usage(columnwise_usage=True) + weightmat.update_usage(columnwise_usage=False if hasattr(weight_quantizer, "mxfp8_bw_quantize") else True) if cpu_offloading: mark_activation_offload(inputmat, mu, rsigma, ln_out) @@ -667,7 +671,7 @@ def backward( if isinstance(grad_output, QuantizedTensorBase): grad_output.update_usage(rowwise_usage=True) if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase): - weight.update_usage(columnwise_usage=True) + weight.update_usage(columnwise_usage=False if hasattr(ctx.weight_quantizer, "mxfp8_bw_quantize") else True) # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD @@ -694,7 +698,7 @@ def backward( # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") gemm_out, *_, reduce_scatter_out = general_gemm( - weight, + weight.bw_tensor if hasattr(weight, "bw_tensor") else weight, grad_output, get_workspace(), layout="NN", @@ -777,7 +781,7 @@ def backward( ln_out_total_work = None if ctx.fp8 or ctx.debug: if isinstance(ln_out_total, QuantizedTensorBase): - ln_out_total.update_usage(columnwise_usage=True) + ln_out_total.update_usage(columnwise_usage=False if hasattr(ctx.weight_quantizer, "mxfp8_bw_quantize") else True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) @@ -864,7 +868,9 @@ def wgrad_gemm( else: # Call wgrad GEMM now - wgrad, grad_bias_ = wgrad_gemm(ln_out_total, grad_output) + wgrad, grad_bias_ = wgrad_gemm( + ln_out_total.bw_tensor if hasattr(ln_out_total, "bw_tensor") else ln_out_total, + grad_output) # Update grad bias if needed if grad_bias is None: @@ -1374,6 +1380,9 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + def extra_repr(self): + return f"{self.normalization}" + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 375db477b0..91311cebe9 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -96,7 +96,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "qgeglu": (tex.qgeglu, tex.dqgeglu, None), "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), } - if recipe.delayed() or recipe.mxfp8(): + if recipe.delayed() or recipe.mxfp8() or recipe.nvfp4_fwd_mxfp8_bwd(): # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] return { @@ -236,7 +236,9 @@ def forward( if fp8: if fc1_input_quantizer is None: raise ValueError("Missing quantizer for FC1 input tensor") - fc1_input_quantizer.set_usage(rowwise=True, columnwise=backwards_needs_fc1_input) + fc1_input_quantizer.set_usage( + rowwise=True, columnwise=False if hasattr(fc1_input_quantizer, "mxfp8_bw_quantize") else backwards_needs_fc1_input) + if sequence_parallel and isinstance( fc1_input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): @@ -254,6 +256,7 @@ def forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not hasattr(fc1_input_quantizer, "mxfp8_bw_quantize") # Disable 
quantized norm for NVFP4 forward and MXFP8 backward due to serial (two) quantization design at torch level ) # Apply normalization @@ -325,8 +328,7 @@ def forward( # which handles weight caching etc. # FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch - fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=False if hasattr(fc1_weight_quantizer, "mxfp8_bw_quantize") else is_grad_enabled) fc1_weight_final = module.get_weight_workspace( tensor=fc1_weight, quantizer=fc1_weight_quantizer, @@ -388,6 +390,8 @@ def forward( gemm_gelu_fusion = False if debug: gemm_gelu_fusion = False + if fc2_input_quantizer is not None and hasattr(fc2_input_quantizer, "mxfp8_bw_quantize"): + fc2_input_quantizer.set_usage(rowwise=True, columnwise=False) fc1_outputs = general_gemm( fc1_weight_final, ln_out_total, @@ -431,10 +435,11 @@ def forward( act_out = fc2_input_quantizer(act_out) else: fc1_out, *_ = fc1_outputs - if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): + _recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + if fp8 and (_recipe.float8_block_scaling() or _recipe.nvfp4_fwd_mxfp8_bwd()): # tex.quantize does not support GELU fusion for blockwise. act_out = activation_func(fc1_out, None) - act_out = tex.quantize(act_out, fc2_input_quantizer) + act_out = fc2_input_quantizer(act_out) else: act_out = activation_func(fc1_out, fc2_input_quantizer) @@ -501,9 +506,9 @@ def forward( # Weight with column-wise usage is needed for dgrad GEMM. if isinstance(fc1_weight_final, QuantizedTensorBase): - fc1_weight_final.update_usage(columnwise_usage=True) + fc1_weight_final.update_usage(columnwise_usage=False if hasattr(fc1_weight_quantizer, "mxfp8_bw_quantize") else True) if isinstance(fc2_weight_final, QuantizedTensorBase): - fc2_weight_final.update_usage(columnwise_usage=True) + fc2_weight_final.update_usage(columnwise_usage=False if hasattr(fc2_weight_quantizer, "mxfp8_bw_quantize") else True) if cpu_offloading: mark_activation_offload( @@ -792,11 +797,11 @@ def backward( if ctx.fc2_weight_quantizer is not None and isinstance( ctx.fc2_weight, QuantizedTensorBase ): - ctx.fc2_weight.update_usage(columnwise_usage=True) + ctx.fc2_weight.update_usage(columnwise_usage=False if hasattr(ctx.fc2_weight_quantizer, "mxfp8_bw_quantize") else True) # Perform GEMM gemm_output, *_ = general_gemm( - fc2_weight, + fc2_weight.bw_tensor if hasattr(fc2_weight, "bw_tensor") else fc2_weight, grad_output, get_workspace(), layout="NN", @@ -862,7 +867,7 @@ def backward( # make sure required data is available if ctx.fp8 or ctx.debug: if isinstance(act_out, QuantizedTensorBase): - act_out.update_usage(columnwise_usage=True) + act_out.update_usage(columnwise_usage=False if hasattr(ctx.fc2_weight_quantizer, "mxfp8_bw_quantize") else True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) @@ -915,7 +920,9 @@ def fc2_wgrad_gemm( else: # Call wgrad GEMM now - fc2_wgrad, fc2_bias_grad_ = fc2_wgrad_gemm(act_out, grad_output) + fc2_wgrad, fc2_bias_grad_ = fc2_wgrad_gemm( + act_out.bw_tensor if hasattr(act_out, "bw_tensor") else act_out, + grad_output) # Update grad bias if needed if fc2_bias_grad is None: @@ -1024,7 +1031,7 @@ def fc2_wgrad_gemm( if ctx.fc1_weight_quantizer is not None and isinstance( ctx.fc1_weight_quantizer, QuantizedTensorBase ): - 
ctx.fc1_weight.update_usage(columnwise_usage=True) + ctx.fc1_weight.update_usage(columnwise_usage=False if hasattr(ctx.fc1_weight_quantizer, "mxfp8_bw_quantize") else True) # Output buffers for Userbuffers reduce-scatter gemm_out = None @@ -1038,7 +1045,7 @@ def fc2_wgrad_gemm( # dgrad GEMM gemm_out, *_, reduce_scatter_out = general_gemm( - fc1_weight, + fc1_weight.bw_tensor if hasattr(fc1_weight, "bw_tensor") else fc1_weight, dact, get_workspace(), out=gemm_out, @@ -1094,7 +1101,7 @@ def fc2_wgrad_gemm( ln_out_total_work = None if ctx.fp8 or ctx.debug: if isinstance(ln_out_total, QuantizedTensorBase): - ln_out_total.update_usage(columnwise_usage=True) + ln_out_total.update_usage(columnwise_usage=False if hasattr(ctx.fc1_weight_quantizer, "mxfp8_bw_quantize") else True) else: ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.fc1_input_quantizer(ln_out_total) @@ -1171,7 +1178,9 @@ def fc1_wgrad_gemm( else: # Call wgrad GEMM now - fc1_wgrad_outputs = fc1_wgrad_gemm(ln_out_total, dact) + fc1_wgrad_outputs = fc1_wgrad_gemm( + ln_out_total.bw_tensor if hasattr(ln_out_total, "bw_tensor") else ln_out_total, + dact) if fuse_gemm_and_bias_fc1_wgrad: fc1_wgrad, fc1_bias_grad = fc1_wgrad_outputs else: @@ -1682,6 +1691,9 @@ def reset_parameters(self, defer_init=False): if self.set_parallel_mode: setattr(self.fc2_bias, "sequence_parallel", self.sequence_parallel) + def extra_repr(self): + return f"{self.normalization}, {self.activation}" + @no_torch_dynamo() def forward( self, From c2056e01b7c1d458bb3a86e895c11eb401c181a0 Mon Sep 17 00:00:00 2001 From: Vui Seng Chua Date: Mon, 20 Oct 2025 19:49:42 +0000 Subject: [PATCH 12/12] Make NVFP4Quantizer to be compatible with NVFP4FwdMXFP8BwdScaling by default --- transformer_engine/pytorch/tensor/nvfp4_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 356c615b33..b8f21113db 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -7,7 +7,7 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from transformer_engine.common.recipe import NVFP4BlockScaling, Recipe +from transformer_engine.common.recipe import NVFP4FwdMXFP8BwdScaling, Recipe from ..constants import NVFP4_BLOCK_SCALING_SIZE from ..utils import devices_match, round_up_to_nearest_multiple from .quantized_tensor import QuantizedTensor, Quantizer @@ -131,7 +131,7 @@ def make_empty( ) def _get_compatible_recipe(self) -> Union[type[Recipe], None]: - return NVFP4BlockScaling + return NVFP4FwdMXFP8BwdScaling class NVFP4Tensor(MXFP8Tensor): # no override on
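For reference, the per-block scale rule used by the NVFP4 cast path — amax over each 16-element block, divided by the E2M1 maximum of 6, clamped to the E4M3 range [2**-9, 448] as added in patch 07 — can be restated in plain PyTorch as below (a reference sketch, not the CUDA kernel; the function name is illustrative, and the real kernels additionally store the scale itself in E4M3):

    import torch

    NVFP4_BLOCK = 16
    E2M1_MAX = 6.0            # largest normal FP4 e2m1 value
    E4M3_MAX = 448.0          # largest normal FP8 e4m3 value
    E4M3_MIN_SUBNORM = 2.0 ** -9

    def nvfp4_block_scales(x: torch.Tensor) -> torch.Tensor:
        """One scale per 16 contiguous elements along the last dimension."""
        assert x.shape[-1] % NVFP4_BLOCK == 0
        blocks = x.reshape(*x.shape[:-1], -1, NVFP4_BLOCK)
        amax = blocks.abs().amax(dim=-1)
        return (amax / E2M1_MAX).clamp(E4M3_MIN_SUBNORM, E4M3_MAX)

    scales = nvfp4_block_scales(torch.randn(4, 64))
    print(scales.shape)       # torch.Size([4, 4]): one scale per 16-element block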