From 9fc8306f992ecf1de2899fc277cd494b940794d7 Mon Sep 17 00:00:00 2001
From: ajrasane <131806219+ajrasane@users.noreply.github.com>
Date: Mon, 27 Apr 2026 15:30:14 +0000
Subject: [PATCH] [Fix]: Patch zero FP16 scales in INT4_AWQ ONNX export (NVBug
 6110209)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

replace_zero_scale_with_smallest_nonzero() in qdq_utils.py only
inspected QuantizeLinear consumers and Constant-node producers, which
made it a no-op for INT4_AWQ exports; those use DequantizeLinear
(default and trt:: domain) consumers and store scales as graph
initializers. Zero scales produced when the FP32→FP16 cast underflows
therefore reached TensorRT, causing trtexec --stronglyTyped to fail
with "Scale coefficients must all be positive".

Extend the sanitizer to also walk DequantizeLinear /
TRT_INT4DequantizeLinear nodes and to patch initializer-backed scales,
while preserving dtype. Add regression tests for both the
initializer + DQ path (default and trt:: domain) and the legacy
Constant + QuantizeLinear path.

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
---
 modelopt/onnx/quantization/qdq_utils.py      | 31 ++++++-
 .../unit/onnx/quantization/test_qdq_utils.py | 87 +++++++++++++++++++
 2 files changed, 114 insertions(+), 4 deletions(-)

diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py
index 0cb1a45f681..265bcf36b2a 100644
--- a/modelopt/onnx/quantization/qdq_utils.py
+++ b/modelopt/onnx/quantization/qdq_utils.py
@@ -1011,14 +1011,37 @@ def replace_zero_scale_with_smallest_nonzero(onnx_model: onnx.ModelProto) -> onn
     """Replace zero scale values with smallest nonzero fp16 value in the ONNX model."""
     graph = onnx_model.graph
     fp16_smallest_nonzero = np.float16(6e-08)
-    scale_nodes = [node.input[1] for node in graph.node if node.op_type == "QuantizeLinear"]
+    qdq_op_types = {
+        "QuantizeLinear",
+        "DequantizeLinear",
+        "TRT_INT4QuantizeLinear",
+        "TRT_INT4DequantizeLinear",
+    }
+    scale_tensor_names = {
+        node.input[1]
+        for node in graph.node
+        if node.op_type in qdq_op_types and len(node.input) >= 2
+    }
+    # Scales stored as graph initializers (e.g. INT4_AWQ / TRT_INT4DequantizeLinear exports).
+    for init in graph.initializer:
+        if init.name in scale_tensor_names:
+            tensor = numpy_helper.to_array(init)
+            if tensor.dtype.kind == "f":
+                new_tensor = np.where(tensor == 0, fp16_smallest_nonzero, tensor).astype(
+                    tensor.dtype
+                )
+                init.CopyFrom(numpy_helper.from_array(new_tensor, init.name))
+    # Scales emitted by Constant nodes (legacy QDQ export path).
     for node in graph.node:
-        if node.op_type == "Constant" and node.output[0] in scale_nodes:
+        if node.op_type == "Constant" and node.output[0] in scale_tensor_names:
             for attr in node.attribute:
                 if attr.name == "value":
                     tensor = numpy_helper.to_array(attr.t)
-                    new_tensor = np.where(tensor == 0, fp16_smallest_nonzero, tensor)
-                    attr.t.CopyFrom(numpy_helper.from_array(new_tensor, attr.t.name))
+                    if tensor.dtype.kind == "f":
+                        new_tensor = np.where(tensor == 0, fp16_smallest_nonzero, tensor).astype(
+                            tensor.dtype
+                        )
+                        attr.t.CopyFrom(numpy_helper.from_array(new_tensor, attr.t.name))
     return onnx_model
 
 
diff --git a/tests/unit/onnx/quantization/test_qdq_utils.py b/tests/unit/onnx/quantization/test_qdq_utils.py
index 42aa317119f..8af5f560dd0 100644
--- a/tests/unit/onnx/quantization/test_qdq_utils.py
+++ b/tests/unit/onnx/quantization/test_qdq_utils.py
@@ -1021,3 +1021,90 @@ def test_column_major_gemm_trans_b_flip(self):
 
         print(f"transB flipped: 1 -> {trans_b_value}")
         print(f"Transpose nodes: {len(transpose_nodes)}")
+
+
+def _build_model_with_zero_scale_initializer(dq_op_type: str):
+    """Build an ONNX model whose scale initializer feeds a (Quantize|Dequantize)Linear node.
+
+    Mirrors the INT4_AWQ failure mode from NVBug 6110209: scales live in graph initializers
+    (not Constant nodes) and feed DequantizeLinear (default or trt:: domain) consumers.
+    """
+    weight_data = np.random.randint(-8, 8, size=(6, 8), dtype=np.int8)
+    weight_tensor = numpy_helper.from_array(weight_data, "weight")
+
+    scale_data = np.array([1e-3, 0.0, 5e-4, 0.0, 0.0, 2e-3], dtype=np.float16).reshape(6, 1)
+    scale_tensor = numpy_helper.from_array(scale_data, "scale")
+
+    input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT16, [None, 6])
+    dq_node = helper.make_node(
+        dq_op_type, inputs=["weight", "scale"], outputs=["dq_output"], name="weight_dq"
+    )
+    matmul_node = helper.make_node(
+        "MatMul", inputs=["input", "dq_output"], outputs=["output"], name="matmul"
+    )
+    graph = helper.make_graph(
+        nodes=[dq_node, matmul_node],
+        name="test_graph",
+        inputs=[input_tensor],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT16, [None, 8])],
+        initializer=[weight_tensor, scale_tensor],
+    )
+    return helper.make_model(graph)
+
+
+class TestReplaceZeroScaleWithSmallestNonzero:
+    """Regression tests for ``replace_zero_scale_with_smallest_nonzero`` (NVBug 6110209)."""
+
+    @pytest.mark.parametrize("dq_op_type", ["DequantizeLinear", "TRT_INT4DequantizeLinear"])
+    def test_zero_scale_initializer_fed_to_dq_is_patched(self, dq_op_type):
+        from modelopt.onnx.quantization.qdq_utils import replace_zero_scale_with_smallest_nonzero
+
+        model = _build_model_with_zero_scale_initializer(dq_op_type)
+        scale_before = numpy_helper.to_array(
+            next(init for init in model.graph.initializer if init.name == "scale")
+        )
+        assert (scale_before == 0).any(), "fixture must contain zeros to exercise the fix"
+
+        patched = replace_zero_scale_with_smallest_nonzero(model)
+
+        scale_after_init = next(init for init in patched.graph.initializer if init.name == "scale")
+        scale_after = numpy_helper.to_array(scale_after_init)
+        assert not (scale_after == 0).any()
+        assert (scale_after > 0).all()
+        assert scale_after_init.data_type == TensorProto.FLOAT16
+
+    def test_constant_node_scale_path_still_patched(self):
+        """Legacy Constant-node QDQ path must continue to be patched."""
+        from modelopt.onnx.quantization.qdq_utils import replace_zero_scale_with_smallest_nonzero
+
+        scale_data = np.array([1e-3, 0.0, 2e-3], dtype=np.float16)
+        scale_const = helper.make_node(
+            "Constant",
+            inputs=[],
+            outputs=["scale_out"],
+            value=numpy_helper.from_array(scale_data),
+            name="scale_constant",
+        )
+        input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [3])
+        q_node = helper.make_node(
+            "QuantizeLinear",
+            inputs=["input", "scale_out"],
+            outputs=["q_output"],
+            name="q",
+        )
+        graph = helper.make_graph(
+            nodes=[scale_const, q_node],
+            name="test_graph",
+            inputs=[input_tensor],
+            outputs=[helper.make_tensor_value_info("q_output", TensorProto.INT8, [3])],
+            initializer=[],
+        )
+        model = helper.make_model(graph)
+
+        patched = replace_zero_scale_with_smallest_nonzero(model)
+
+        const = next(n for n in patched.graph.node if n.op_type == "Constant")
+        value_attr = next(a for a in const.attribute if a.name == "value")
+        scale_arr = numpy_helper.to_array(value_attr.t)
+        assert not (scale_arr == 0).any()
+        assert (scale_arr > 0).all()
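
Note for reviewers (illustration only, not part of the patch): the
underflow this change guards against, and why np.float16(6e-08) is the
replacement value, can be reproduced with nothing but numpy:

    import numpy as np

    # FP32 scales below half of fp16's smallest subnormal (2**-24,
    # ~5.96e-8) round to zero on the FP32->FP16 cast; that zero is what
    # trips TensorRT's "Scale coefficients must all be positive" check.
    scales_fp32 = np.array([1e-3, 1e-8, 5e-4], dtype=np.float32)
    scales_fp16 = scales_fp32.astype(np.float16)
    assert scales_fp16[1] == 0.0  # 1e-8 underflowed to zero

    # The sanitizer's constant rounds to exactly one subnormal ULP,
    # i.e. the smallest positive fp16 value.
    assert np.float16(6e-08) == 2.0**-24

    # Same np.where pattern the fix applies to initializer-backed scales.
    patched = np.where(scales_fp16 == 0, np.float16(6e-08), scales_fp16)
    assert patched.dtype == np.float16 and (patched > 0).all()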