52 changes: 39 additions & 13 deletions transformer_engine/common/common.h
@@ -50,6 +50,8 @@ inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) {

inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }

inline bool is_nvfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; }

inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
end, " in a vector with ", shape.size(), " entries");
@@ -189,6 +191,7 @@ struct Tensor {
}
break;
case NVTE_MXFP8_1D_SCALING:
case NVTE_NVFP4_1D_SCALING:
case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape;
@@ -561,6 +564,25 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \
}

#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8FP4ONLY(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat8E5M2: { \
using type = fp8e5m2; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E4M3: { \
using type = fp8e4m3; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat4E2M1: { \
using type = fp4e2m1; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}

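For context, a usage sketch of the new switch (not part of this diff): it binds the runtime DType to a compile-time element type, so one templated call site covers both FP8 formats and packed FP4. The names dispatch_cast and launch_cast_kernel are hypothetical and used only for illustration.

// Hypothetical dispatch helper; OType becomes fp8e5m2, fp8e4m3 or fp4e2m1.
void dispatch_cast(const Tensor &input, Tensor *output, cudaStream_t stream) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8FP4ONLY(
      output->dtype(), OType,
      launch_cast_kernel<OType>(input, output, stream););
}
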
#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
Expand Down Expand Up @@ -604,19 +626,23 @@ struct TypeInfo {
NVTE_ERROR("Invalid type for 16 bit."); \
}

#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \
switch (SCALE_DIM) { \
case 1: { \
constexpr size_t DIM = 1; \
{ __VA_ARGS__ } \
} break; \
case 32: { \
constexpr size_t DIM = 32; \
{ __VA_ARGS__ } \
} break; \
default: { \
NVTE_ERROR("Invalid size of the MX scaling factor."); \
} \
#define TRANSFORMER_ENGINE_MXNV_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \
switch (SCALE_DIM) { \
case 1: { \
constexpr size_t DIM = 1; \
{ __VA_ARGS__ } \
} break; \
case 16: { \
constexpr size_t DIM = 16; \
{ __VA_ARGS__ } \
} break; \
case 32: { \
constexpr size_t DIM = 32; \
{ __VA_ARGS__ } \
} break; \
default: { \
NVTE_ERROR("Invalid size of the MX/NV scaling factor."); \
} \
}

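A usage sketch for the renamed switch (illustration only): the scaling block length is chosen at runtime — 1 for per-tensor scaling, 16 for NVFP4, 32 for MXFP8 — and lowered to a constexpr so the kernel can be specialized on it. quantize_blockwise is a placeholder, not an existing function.

// Hypothetical caller: pick the scale block length at runtime, specialize at compile time.
const size_t scale_dim = is_nvfp_scaling(output->scaling_mode)   ? 16
                         : is_mxfp_scaling(output->scaling_mode) ? 32
                                                                 : 1;
TRANSFORMER_ENGINE_MXNV_SCALE_DIM_SWITCH(
    scale_dim, kBlockLen,
    quantize_blockwise<kBlockLen>(input, output, stream););
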
#define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \
83 changes: 59 additions & 24 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -35,6 +35,8 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
return CUDA_R_8F_E4M3;
case DType::kFloat8E5M2:
return CUDA_R_8F_E5M2;
case DType::kFloat4E2M1:
return CUDA_R_4F_E2M1;
default:
NVTE_ERROR("Invalid type");
}
@@ -119,8 +121,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
}
}
} else if (is_mxfp_scaling(A.scaling_mode)) {
// MXFP8
} else if (is_mxfp_scaling(A.scaling_mode) || is_nvfp_scaling(A.scaling_mode)) {
// MXFP8/NVFP4
// Note: Row-wise and column-wise data are scaled along different
// dimensions (with matrix interpreted in row-major order).
if (is_A_transposed) {
@@ -178,7 +180,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
}
}
} else if (is_mxfp_scaling(B.scaling_mode)) {
} else if (is_mxfp_scaling(B.scaling_mode) || is_nvfp_scaling(B.scaling_mode)) {
// MXFP8/NVFP4
// Note: Row-wise and column-wise data are scaled along different
// dimensions (with matrix interpreted in row-major order).
@@ -233,6 +235,12 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split,
int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) {

const bool are_ab_nvfp = is_nvfp_scaling(inputA->scaling_mode) && is_nvfp_scaling(inputB->scaling_mode);
if (are_ab_nvfp) {
NVTE_CHECK(transa == CUBLAS_OP_T && transb == CUBLAS_OP_N,
"cuBLASLt NVFP4 GEMM supports only the TN layout, i.e. transposed A and non-transposed B");
}
// Tensor dims in row-major order
const int A0 = inputA->flat_first_dim();
const int A1 = inputA->flat_last_dim();
@@ -242,8 +250,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
// GEMM dims in column-major order
const int m = transa == CUBLAS_OP_T ? A0 : A1;
const int n = transb == CUBLAS_OP_T ? B1 : B0;
const int k = transa == CUBLAS_OP_T ? A1 : A0;
NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == k,
const auto dim = (transa == CUBLAS_OP_T) ? A1 : A0;
const auto el_per_byte = are_ab_nvfp ? 2 : 1;
const int k = dim * el_per_byte;

NVTE_CHECK((transb == CUBLAS_OP_T ? B0 : B1) == dim,
"GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1,
")");
const int ldd = m;
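NVFP4 packs two E2M1 values per byte, so the tensor's flattened last dimension counts bytes while the GEMM contraction dimension counts elements. A standalone sketch of the arithmetic above, with made-up sizes:

#include <cassert>

int main() {
  const bool are_ab_nvfp = true;             // both operands NVFP4, TN layout
  const int A0 = 4096, A1 = 1024;            // A stored row-major as m x (k/2) bytes
  const int B0 = 8192, B1 = 1024;            // B stored row-major as n x (k/2) bytes
  const int m = A0;                          // transa == CUBLAS_OP_T
  const int n = B0;                          // transb == CUBLAS_OP_N
  const int dim = A1;                        // packed (byte-level) contraction dim
  const int el_per_byte = are_ab_nvfp ? 2 : 1;
  const int k = dim * el_per_byte;           // 2048 logical E2M1 elements
  assert(B1 == dim);                         // operands must agree on the packed dim
  assert(m == 4096 && n == 8192 && k == 2048);
  return 0;
}
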
@@ -269,26 +280,28 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
}
const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype);
const bool use_fp4 = is_fp4_dtype(param.Atype) || is_fp4_dtype(param.Btype);
NVTE_CHECK(!(use_fp4 && use_fp8), "A and B must use the same precision (both FP4 or both FP8)");

const cudaDataType_t A_type = get_cuda_dtype(param.Atype);
const cudaDataType_t B_type = get_cuda_dtype(param.Btype);
const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype);
const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype);

NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr,
"FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr,
"FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_narrow_dtype(param.Atype) || param.A_scale_inv != nullptr,
"FP8/FP4 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_narrow_dtype(param.Btype) || param.B_scale_inv != nullptr,
"FP8/FP4 input to GEMM requires inverse of scale!");

// check consistency of arguments:
// if fp8 is desired, context cannot be null
// fp8 + gelu fusion + fp8 aux is unavailable right now.
if (use_fp8 && gelu) {
NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
"fp8 Aux output for gemm + gelu fusion not supported!");
// if fp8/fp4 is desired, context cannot be null
// fp8/fp4 + gelu fusion + fp8 aux is unavailable right now.
if ((use_fp8 || use_fp4) && gelu) {
NVTE_CHECK(!is_narrow_dtype(outputPreGelu->data.dtype),
"fp8/fp4 Aux output for gemm + gelu fusion not supported!");
}
if (is_fp8_dtype(outputD->data.dtype)) {
NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
if (is_narrow_dtype(outputD->data.dtype)) {
NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8/FP4 GEMM output!");
}

float one = 1.0;
@@ -335,12 +348,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
// set fp8 attributes -- input and output types should already be set to fp8 as appropriate
// Note: gelu fusion isn't available right now, and we don't need
// amax(D) either (next op is high precision).
if (use_fp8) {
// Split accumulator.
const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
&fastAccuMode, sizeof(fastAccuMode)));

if (use_fp8 || use_fp4) {
if (use_fp8) {
// Fast accumulator mode can only be set for FP8 problems.
// Split accumulator.
const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
&fastAccuMode, sizeof(fastAccuMode)));
}
// Scaling factors.
#if CUDA_VERSION >= 12080
cublasLtMatmulMatrixScale_t scaling_mode_a;
@@ -377,6 +392,26 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
sizeof(dummy_a_vec_stride)));
}
} else if (are_ab_nvfp) {
fp8e4m3 *A_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.A_scale_inv);
fp8e4m3 *B_scale_inverse = reinterpret_cast<fp8e4m3 *>(param.B_scale_inv);
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&A_scale_inverse, sizeof(A_scale_inverse)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse, sizeof(B_scale_inverse)));
scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
// TODO(VS): confirm whether this workaround is needed here as well; duplicated from the MXFP8 path.
// Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
// CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
if (cublasLtGetVersion() <= 120803) {
const int64_t dummy_a_vec_stride = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
sizeof(dummy_a_vec_stride)));
}
} else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D ||
inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) &&
(inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
Expand Down Expand Up @@ -414,8 +449,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode_b, sizeof(scaling_mode_b)));
#endif
if (is_fp8_dtype(outputD->data.dtype)) {
// Accumulation mode not supported for FP8 output
if (is_narrow_dtype(outputD->data.dtype)) {
// Accumulation mode not supported for FP8/FP4 output
C = nullptr;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale)));
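Pulling the NVFP4 branch together, a minimal sketch of the descriptor attributes this path ends up setting for an NVFP4 TN GEMM (requires CUDA 12.8+; NVTE_CHECK_CUBLAS error handling elided). It mirrors the code above rather than adding anything new:

// Both operands use 16-element blocks with E4M3 scale factors.
cublasLtMatmulMatrixScale_t scale_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                               &A_scale_inverse, sizeof(A_scale_inverse));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                               &B_scale_inverse, sizeof(B_scale_inverse));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
                               &scale_mode, sizeof(scale_mode));
cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE,
                               &scale_mode, sizeof(scale_mode));
// Note: CUBLASLT_MATMUL_DESC_FAST_ACCUM is not set here; fast accumulation is FP8-only.
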
transformer_engine/common/include/transformer_engine/transformer_engine.h
@@ -92,6 +92,7 @@ enum NVTEScalingMode {
and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD).
*/
NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4,
/*! Single NVFP4 scale (stored as E4M3) per block of 16 contiguous elements. */
NVTE_NVFP4_1D_SCALING = 5,
NVTE_INVALID_SCALING = 100
};

@@ -431,6 +432,15 @@ inline bool is_fp8_dtype(const DType t) {
*/
inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; }

/*! \brief Check if TE datatype is FP8 or FP4
*
* Return true if TE datatype is FP8 or FP4
* \param[in] DType TE Datatype of interest
*/
inline bool is_narrow_dtype(const DType t) {
return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2 || t == DType::kFloat4E2M1;
}
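A small sketch of how the new predicate is meant to be used, mirroring the scale checks in the GEMM path; input_dtype and input_scale_inv are hypothetical variables used only for illustration:

// One guard now covers FP8 and FP4 inputs.
if (is_narrow_dtype(input_dtype)) {
  NVTE_CHECK(input_scale_inv != nullptr,
             "FP8/FP4 input to GEMM requires inverse of scale!");
}
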

/*! \struct TensorWrapper
* \brief C++ wrapper for the NVTETensor class.
*/
56 changes: 55 additions & 1 deletion transformer_engine/common/recipe/__init__.py
@@ -22,10 +22,12 @@ class _FormatHelper(NamedTuple):

class Format(Enum):
"""
Supported FP8 formats.
Supported FP8/FP4 formats.

Values
------
E2M1 :
All FP4 tensors are in e2m1 format
E4M3 :
All FP8 tensors are in e4m3 format
E5M2 :
@@ -35,6 +37,7 @@ class Format(Enum):
FP8 tensors in the backward pass are in e5m2 format
"""

E2M1 = _FormatHelper(max_fwd=6, max_bwd=6)
E4M3 = _FormatHelper(max_fwd=448, max_bwd=448)
E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344)
HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)
@@ -66,10 +69,18 @@ class Recipe:
Base recipe class.
"""

def nvfp4(self):
"""Whether the given recipe is NVFP4 block scaling."""
return isinstance(self, NVFP4BlockScaling)

def mxfp8(self):
"""Whether the given recipe is MXFP8 block scaling."""
return isinstance(self, MXFP8BlockScaling)

def nvfp4_fwd_mxfp8_bwd(self):
"""Whether the given recipe is NVFP4 forward and MXFP8 backward scaling."""
return isinstance(self, NVFP4FwdMXFP8BwdScaling)

def delayed(self):
"""Whether the given recipe is delayed scaling."""
return isinstance(self, DelayedScaling)
@@ -232,6 +243,49 @@ def __repr__(self) -> str:
)


@dataclass()
class NVFP4BlockScaling(Recipe):
"""
NVFP4 block scaling recipe.

TODO(VS): full documentation. The ``fp8_`` field prefix is kept for now
because renaming it requires broader refactoring.
"""
margin: int = 0
fp8_format: Format = Format.E2M1
fp8_dpa: bool = False
fp8_mha: bool = False

def __post_init__(self) -> None:
assert self.fp8_format == Format.E2M1, "Only E2M1 training is supported, fwd and bwd in E2M1."
raise NotImplementedError("NVFP4BlockScaling is not fully implemented yet (cuBLASLt GEMM currently supports only the TN layout). Please consider NVFP4FwdMXFP8BwdScaling() instead.")

def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"margin={self.margin}, "
f"format={str(self.fp8_format).split('.')[1]}"
)

@dataclass()
class NVFP4FwdMXFP8BwdScaling(Recipe):
"""
NVFP4 forward / MXFP8 backward scaling recipe.

TODO(VS): full documentation. The ``fp8_`` field prefix is kept for now
because renaming it requires broader refactoring.
"""
margin: int = 0
fp8_format: Format = Format.E2M1
fp8_dpa: bool = False
fp8_mha: bool = False

def __post_init__(self) -> None:
assert self.fp8_format == Format.E2M1, "Only E2M1 training is supported, fwd and bwd in E2M1."

def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"margin={self.margin}, "
f"format={str(self.fp8_format).split('.')[1]}"
)

@dataclass()
class MXFP8BlockScaling(Recipe):
"""
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/common/swizzle/swizzle.cu
@@ -201,7 +201,7 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
namespace transformer_engine {

void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) {
if (!is_narrow_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode " + to_string(input->scaling_mode) + ".");
}

@@ -216,7 +216,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
auto& scaling_mode = input->scaling_mode;

// 1D block scaling, row-wise or column-wise
if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
if (scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING) {
const int m =
input->has_data() ? input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1];
const int k =