diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4b3ee96fb0..47349527c7 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,8 @@ NVIDIA Model Optimizer Changelog - Add support for rotating the input before quantization for RHT. - Add support for advanced weight scale search for NVFP4 quantization and its export path. - Enable PTQ workflow for Qwen3.5 MoE models. +- Add ``nvfp4_omlp_only`` quantization format for NVFP4 quantization. This is similar to ``nvfp4_mlp_only`` but also quantizes the output projection layer in attention. +- ``pass_through_bwd`` in the quantization config now defaults to ``True``. Set it to ``False`` if you want to use STE with zeroed outlier gradients for potentially better QAT accuracy. **Misc** diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md index 7a9a71f885..6576242152 100755 --- a/examples/llm_ptq/README.md +++ b/examples/llm_ptq/README.md @@ -69,6 +69,8 @@ def forward_loop(model): model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop) ``` +> *For higher NVFP4 PTQ accuracy, we recommend using `mtq.NVFP4_MLP_ONLY_CFG` or `mtq.NVFP4_OMLP_ONLY_CFG` instead of `mtq.NVFP4_DEFAULT_CFG`. `NVFP4_MLP_ONLY_CFG` applies NVFP4 quantization to MLP (and MoE) layers, leaving attention layers unquantized. `NVFP4_OMLP_ONLY_CFG` additionally quantizes the `o_proj` layer. Both preserve accuracy in the sensitive attention QKV projections while still providing significant compression.* + ### 2. Export Quantized Model Once your model is quantized, you can now export that model to a checkpoint for easy deployment. \ @@ -126,7 +128,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http > *7.[PTQ for DeepSeek](../deepseek/README.md)* \ > *8.GLM-4.7 has MTP (Multi-Token Prediction) layers that are automatically loaded and excluded from quantization.* -> *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead.* +> *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead. For NVFP4 quantization specifically, we recommend `nvfp4_mlp_only` or `nvfp4_omlp_only` to achieve higher accuracy by restricting quantization to the MLP layers (and optionally the `o_proj` layer) while keeping the attention QKV projections unquantized.* > You can also create your own custom config using [this](https://nvidia.github.io/Model-Optimizer/guides/_pytorch_quantization.html#custom-calibration-algorithm) guide.
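For readers following the README recommendation above, switching formats is a one-line change to the existing PTQ snippet. A minimal sketch, assuming `model` and the calibration `forward_loop` are set up exactly as in the README's `mtq.quantize` example:

```python
import modelopt.torch.quantization as mtq

# Same PTQ flow as the README example, but using the MLP + o_proj config added
# in this change; mtq.NVFP4_MLP_ONLY_CFG can be substituted to also leave
# o_proj unquantized.
model = mtq.quantize(model, mtq.NVFP4_OMLP_ONLY_CFG, forward_loop)
```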
@@ -144,7 +146,7 @@ For LLM models like [Llama-3](https://huggingface.co/meta-llama): # Install model specific pip dependencies if needed export HF_PATH= -scripts/huggingface_example.sh --model $HF_PATH --quant [fp8|nvfp4|int8_sq|int4_awq|w4a8_awq] --tp [1|2|4|8] +scripts/huggingface_example.sh --model $HF_PATH --quant [fp8|nvfp4|nvfp4_mlp_only|nvfp4_omlp_only|int8_sq|int4_awq|w4a8_awq] --tp [1|2|4|8] ``` > *By default `trust_remote_code` is set to false. Please turn it on if model calibration and eval requires it using `--trust_remote_code`.* @@ -295,7 +297,7 @@ accelerate launch --config_file fsdp2.yaml \ --fsdp_transformer_layer_cls_to_wrap= multinode_ptq.py \ --pyt_ckpt_path \ - --qformat \ + --qformat \ --kv_cache_qformat \ --batch_size \ --calib_size \ @@ -460,4 +462,4 @@ There are many quantization schemes supported in the example scripts: 1. The W4A8 AWQ is an extension of the INT4 AWQ quantization that it also uses FP8 for activation for more speed up and acceleration. -1. The [NVFP4](https://blogs.nvidia.com/blog/generative-ai-studio-ces-geforce-rtx-50-series/) is one of the new FP4 formats supported by NVIDIA Blackwell GPU and demonstrates good accuracy compared with other 4-bit alternatives. NVFP4 can be applied to both model weights as well as activations, providing the potential for both a significant increase in math throughput and reductions in memory footprint and memory bandwidth usage compared to the FP8 data format on Blackwell. +1. The [NVFP4](https://blogs.nvidia.com/blog/generative-ai-studio-ces-geforce-rtx-50-series/) is one of the new FP4 formats supported by NVIDIA Blackwell GPU and demonstrates good accuracy compared with other 4-bit alternatives. NVFP4 can be applied to both model weights as well as activations, providing the potential for both a significant increase in math throughput and reductions in memory footprint and memory bandwidth usage compared to the FP8 data format on Blackwell. For higher accuracy with NVFP4 PTQ, we recommend `nvfp4_mlp_only` or `nvfp4_omlp_only`. `nvfp4_mlp_only` restricts NVFP4 quantization to MLP (and MoE) layers only, leaving attention layers in higher precision. `nvfp4_omlp_only` extends this by also quantizing the `o_proj` layer, providing a middle ground between full NVFP4 and MLP-only quantization. diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 0234e731ea..50ac51aace 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -258,18 +258,6 @@ def build_quant_cfg( quant_cfg["quant_cfg"]["*image*"] = {"enable": False} quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} - if model_type in ["qwen3moe", "qwen3next", "minimax"] and qformat == "nvfp4": - # Disable the attention projection layers to retain accuracy - quant_cfg["quant_cfg"]["model*.*attn*in_proj*"] = {"enable": False} - quant_cfg["quant_cfg"]["model*.*attn*q_proj*"] = {"enable": False} - quant_cfg["quant_cfg"]["model*.*attn*k_proj*"] = {"enable": False} - quant_cfg["quant_cfg"]["model*.*attn*v_proj*"] = {"enable": False} - - if model_type == "deepseek": - # Disable MLA quantization for accuracy. 
- quant_cfg["quant_cfg"]["*self_attn.q*"] = {"enable": False} - quant_cfg["quant_cfg"]["*self_attn.kv*"] = {"enable": False} - return quant_cfg diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 6b29be4eb0..67e7016a82 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -87,6 +87,7 @@ "w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG, "w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG, "nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG, + "nvfp4_omlp_only": mtq.NVFP4_OMLP_ONLY_CFG, "nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG, "mxfp8": mtq.MXFP8_DEFAULT_CFG, } @@ -252,6 +253,7 @@ def auto_quantize( "fp8_pb_wo", "w4a8_mxfp4_fp8", "nvfp4_mlp_only", + "nvfp4_omlp_only", "mxfp8", ] for args.qformat in qformat_list @@ -900,6 +902,7 @@ def quantize_main( "fp8_pb_wo", "w4a8_mxfp4_fp8", "nvfp4_mlp_only", + "nvfp4_omlp_only", "mxfp8", ] or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh index a74a1671a8..d49fb4005a 100755 --- a/examples/llm_ptq/scripts/huggingface_example.sh +++ b/examples/llm_ptq/scripts/huggingface_example.sh @@ -53,9 +53,9 @@ esac IFS="," for qformat in $QFORMAT; do case $qformat in - fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only | nvfp4_svdquant | mxfp8) ;; + fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8) ;; *) - echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only, nvfp4_svdquant, mxfp8]" >&2 + echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8]" >&2 exit 1 ;; esac diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 80a2a68761..a9b3574c4d 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -258,7 +258,11 @@ INT4_BLOCKWISE_WEIGHT_ONLY_CFG = { "quant_cfg": { - "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True}, + "*weight_quantizer": { + "num_bits": 4, + "block_sizes": {-1: 128}, + "enable": True, + }, "*input_quantizer": {"enable": False}, **_default_disabled_quantizer_cfg, }, @@ -286,10 +290,20 @@ W4A8_AWQ_BETA_CFG = { "quant_cfg": { "*weight_quantizer": [ - {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}, "enable": True}, - {"num_bits": (4, 3), "axis": None, "enable": True}, + { + "num_bits": 4, + "block_sizes": {-1: 128, "type": "static"}, + "enable": True, + }, + { + "num_bits": (4, 3), + "enable": True, + }, ], - "*input_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True}, + "*input_quantizer": { + "num_bits": (4, 3), + "enable": True, + }, **_default_disabled_quantizer_cfg, }, "algorithm": "awq_lite", @@ -380,7 +394,6 @@ "quant_cfg": { "*[kv]_bmm_quantizer": { "num_bits": (4, 3), - "axis": None, "enable": True, }, "default": {"enable": False}, @@ -392,7 +405,6 @@ "quant_cfg": { "*[kv]_bmm_quantizer": { "num_bits": (4, 3), - "axis": None, "bias": {-2: None, -4: None, "type": "static"}, }, "default": {"enable": False}, @@ -400,20 +412,16 @@ 
"algorithm": "max", } +_nvfp4_quantizer = { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + "enable": True, +} + NVFP4_DEFAULT_CFG = { "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, + "*weight_quantizer": _nvfp4_quantizer, + "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, }, "algorithm": "max", @@ -424,15 +432,9 @@ "*weight_quantizer": { "num_bits": (2, 1), "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, "enable": True, }, + "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, }, "algorithm": { @@ -446,15 +448,9 @@ "*weight_quantizer": { "num_bits": (2, 1), "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, "enable": True, }, + "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, }, "algorithm": { @@ -465,18 +461,8 @@ MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = { "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, + "*weight_quantizer": _nvfp4_quantizer, + "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, **_mamba_moe_disabled_quantizer_cfg, }, @@ -484,18 +470,8 @@ } MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = { "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, + "*weight_quantizer": _nvfp4_quantizer, + "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, **_mamba_moe_disabled_quantizer_cfg, "*mixer.in_proj*": {"enable": False}, # Skip mamba linear @@ -507,18 +483,8 @@ NVFP4_AWQ_LITE_CFG = { "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, + "*weight_quantizer": _nvfp4_quantizer, + "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, }, "algorithm": "awq_lite", @@ -526,18 +492,8 @@ NVFP4_AWQ_CLIP_CFG = { "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, + "*weight_quantizer": _nvfp4_quantizer, + "*input_quantizer": _nvfp4_quantizer, 
**_default_disabled_quantizer_cfg, }, "algorithm": {"method": "awq_clip"}, @@ -545,18 +501,8 @@ NVFP4_AWQ_FULL_CFG = { "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, + "*weight_quantizer": _nvfp4_quantizer, + "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, }, "algorithm": {"method": "awq_full", "alpha_step": 0.1}, @@ -566,10 +512,7 @@ NVFP4_AFFINE_KV_CFG = { "quant_cfg": { "*[kv]_bmm_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, + **_nvfp4_quantizer, "bias": {-2: None, -4: None, "type": "static"}, }, "default": {"enable": False}, @@ -579,12 +522,7 @@ NVFP4_KV_CFG = { "quant_cfg": { - "*[kv]_bmm_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, + "*[kv]_bmm_quantizer": _nvfp4_quantizer, "default": {"enable": False}, }, "algorithm": "max", @@ -593,38 +531,23 @@ # Moved from examples/diffusers/quantization/config.py to here NVFP4_FP8_MHA_CONFIG = { "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, + "*weight_quantizer": _nvfp4_quantizer, + "*input_quantizer": _nvfp4_quantizer, "*output_quantizer": {"enable": False}, "*q_bmm_quantizer": { "num_bits": (4, 3), - "axis": None, }, "*k_bmm_quantizer": { "num_bits": (4, 3), - "axis": None, }, "*v_bmm_quantizer": { "num_bits": (4, 3), - "axis": None, }, "*softmax_quantizer": { "num_bits": (4, 3), - "axis": None, }, "transformer_blocks*bmm2_output_quantizer": { "num_bits": (4, 3), - "axis": None, }, "default": {"enable": False}, }, @@ -638,36 +561,18 @@ "rotate": True, }, "*k_bmm_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, + **_nvfp4_quantizer, "rotate": True, }, - "*v_bmm_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, + "*v_bmm_quantizer": _nvfp4_quantizer, }, "algorithm": "max", } NVFP4_SVDQUANT_DEFAULT_CFG = { "quant_cfg": { - "*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, - "*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, - "enable": True, - }, + "*weight_quantizer": _nvfp4_quantizer, + "*input_quantizer": _nvfp4_quantizer, **_default_disabled_quantizer_cfg, }, "algorithm": {"method": "svdquant", "lowrank": 32}, @@ -678,12 +583,10 @@ "*weight_quantizer": { "num_bits": (2, 1), "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (4, 3)}, - "axis": None, "enable": True, }, "*input_quantizer": { "num_bits": (4, 3), - "axis": None, "enable": True, }, **_default_disabled_quantizer_cfg, @@ -697,7 +600,11 @@ "num_bits": (2, 1), "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, "enable": True, - "pass_through_bwd": True, + }, 
+ "*block_sparse_moe*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + "enable": True, }, **_default_disabled_quantizer_cfg, }, @@ -714,28 +621,39 @@ "scale_bits": (4, 3), }, # Note: block_size is 32 here "enable": True, - "pass_through_bwd": True, + }, + "*block_sparse_moe*weight_quantizer": { + "num_bits": (2, 1), + "block_sizes": { + -1: 32, + "type": "dynamic", + "scale_bits": (4, 3), + }, # Note: block_size is 32 here + "enable": True, }, **_default_disabled_quantizer_cfg, }, "algorithm": "max", } +_nvfp4_mlp_only_quant_cfg = { + "*mlp*weight_quantizer": _nvfp4_quantizer, + "*mlp*input_quantizer": _nvfp4_quantizer, + "*block_sparse_moe*weight_quantizer": _nvfp4_quantizer, + "*block_sparse_moe*input_quantizer": _nvfp4_quantizer, + **_default_disabled_quantizer_cfg, +} + NVFP4_MLP_ONLY_CFG = { + "quant_cfg": _nvfp4_mlp_only_quant_cfg, + "algorithm": "max", +} + +NVFP4_OMLP_ONLY_CFG = { "quant_cfg": { - "*mlp*weight_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "enable": True, - "pass_through_bwd": True, - }, - "*mlp*input_quantizer": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, - "enable": True, - "pass_through_bwd": True, - }, - **_default_disabled_quantizer_cfg, + "*o_proj*weight_quantizer": _nvfp4_quantizer, + "*o_proj*input_quantizer": _nvfp4_quantizer, + **_nvfp4_mlp_only_quant_cfg, }, "algorithm": "max", } @@ -769,6 +687,7 @@ "NVFP4_MLP_WEIGHT_ONLY_CFG", "MXFP4_MLP_WEIGHT_ONLY_CFG", "NVFP4_MLP_ONLY_CFG", + "NVFP4_OMLP_ONLY_CFG", "MAMBA_MOE_NVFP4_CONSERVATIVE_CFG", "MAMBA_MOE_NVFP4_AGGRESSIVE_CFG", "MAMBA_MOE_FP8_CONSERVATIVE_CFG", @@ -1074,14 +993,14 @@ def validate_calibrator(cls, v, info: ValidationInfo): ) pass_through_bwd: bool = ModeloptField( - default=False, + default=True, title="If set to true, fake quantization will be a pass through for gradient computation.", description=""" Gradient computation where fake quantization is pass through is called 'Straight-Through Estimator (STE)'. STE does not require saving of the input tensor for performing backward pass and hence consumes less memory. - If set to False, we will use STE with zeroed outlier gradients. This setting could + If set to False, we will use STE with zeroed outlier gradients. This setting may yield better QAT accuracy depending on the quantization format. However, this setting requires saving of the input tensor for computing gradients which uses more memory. 
diff --git a/tests/gpu/torch/export/test_fsdp2_export.py b/tests/gpu/torch/export/test_fsdp2_export.py index ba59324dd7..2a493d4412 100644 --- a/tests/gpu/torch/export/test_fsdp2_export.py +++ b/tests/gpu/torch/export/test_fsdp2_export.py @@ -225,6 +225,7 @@ def test_fsdp2_weight_update_context_for_export(dist_workers): # mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, #TODO: Fix unit test for this case mtq.W4A8_MXFP4_FP8_CFG, mtq.NVFP4_MLP_ONLY_CFG, + mtq.NVFP4_OMLP_ONLY_CFG, ], ) @pytest.mark.parametrize("bias", [True, False]) @@ -244,6 +245,7 @@ def test_fsdp2_weight_update_context_for_fuse_layers(dist_workers, quant_config, # mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, #TODO: Fix unit test for this case mtq.W4A8_MXFP4_FP8_CFG, mtq.NVFP4_MLP_ONLY_CFG, + mtq.NVFP4_OMLP_ONLY_CFG, ], ) @pytest.mark.parametrize("bias", [True, False]) diff --git a/tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py b/tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py index ba7b522ae3..af884c878b 100644 --- a/tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py +++ b/tests/gpu/torch/quantization/test_tensor_quantizer_cuda.py @@ -81,10 +81,13 @@ def test_fp4(self): assert fp4_quantizer._get_amax(x) == x.abs().amax() - def test_fp4_backward(self): + @pytest.mark.parametrize("pass_through_bwd", [True, False]) + def test_fp4_backward(self, pass_through_bwd): fp4_quantizer = tensor_quantizer.TensorQuantizer( QuantizerAttributeConfig( - num_bits=(2, 1), block_sizes={-1: 16, "type": "dynamic", "scale_bits": (4, 3)} + num_bits=(2, 1), + block_sizes={-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + pass_through_bwd=pass_through_bwd, ) ).cuda() @@ -96,7 +99,11 @@ def test_fp4_backward(self): loss = fp4_quantizer(x).sum() loss.backward() - assert torch.allclose(x.grad, torch.ones_like(x.grad) * (x.abs() <= fp4_quantizer.amax)) + if pass_through_bwd: + expected_grad = torch.ones_like(x.grad) + else: + expected_grad = torch.ones_like(x.grad) * (x.abs() <= fp4_quantizer.amax) + assert torch.allclose(x.grad, expected_grad) def test_fp4_non_contiguous_input(self): contiguous_tensor = torch.ones(2, 16).cuda()
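As a standalone companion to the parametrized backward test above, the effect of the new `pass_through_bwd` default can be inspected directly on a `TensorQuantizer`. A minimal sketch; the import paths follow the test module's usage and are assumed here, and a CUDA device is required:

```python
import torch

from modelopt.torch.quantization.config import QuantizerAttributeConfig
from modelopt.torch.quantization.nn.modules import tensor_quantizer

for pass_through_bwd in (True, False):
    quantizer = tensor_quantizer.TensorQuantizer(
        QuantizerAttributeConfig(
            num_bits=(2, 1),
            block_sizes={-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            pass_through_bwd=pass_through_bwd,
        )
    ).cuda()
    x = torch.randn(2, 16, device="cuda", requires_grad=True)
    quantizer(x).sum().backward()
    # With pass_through_bwd=True the gradient is 1 everywhere (plain STE);
    # with False, gradients are zeroed wherever the input was clipped at amax.
    print(pass_through_bwd, x.grad.unique().tolist())
```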