Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions modelopt/onnx/quantization/qdq_utils.py
Original file line number Diff line number Diff line change
def replace_zero_scale_with_smallest_nonzero(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Replace zero scale values with smallest nonzero fp16 value in the ONNX model.

    Scale tensors are found via the second input of every Q/DQ node (including the
    TRT INT4 variants) and patched wherever they live: graph initializers or the
    ``value`` attribute of a ``Constant`` node. Only floating-point scale tensors
    are touched, and each keeps its original dtype. The model is modified in place
    and also returned for convenience.
    """
    fp16_smallest_nonzero = np.float16(6e-08)

    def _without_zeros(values: np.ndarray) -> np.ndarray:
        # Swap exact zeros for the fp16 subnormal, then restore the input dtype
        # (np.where may have promoted it).
        return np.where(values == 0, fp16_smallest_nonzero, values).astype(values.dtype)

    graph = onnx_model.graph
    qdq_op_types = {
        "QuantizeLinear",
        "DequantizeLinear",
        "TRT_INT4QuantizeLinear",
        "TRT_INT4DequantizeLinear",
    }
    # Second input of a Q/DQ node is its scale tensor; guard against malformed
    # nodes that carry fewer than two inputs.
    scale_tensor_names = set()
    for node in graph.node:
        if node.op_type in qdq_op_types and len(node.input) >= 2:
            scale_tensor_names.add(node.input[1])

    # Scales stored as graph initializers (e.g. INT4_AWQ / TRT_INT4DequantizeLinear exports).
    for init in graph.initializer:
        if init.name not in scale_tensor_names:
            continue
        arr = numpy_helper.to_array(init)
        if arr.dtype.kind == "f":
            init.CopyFrom(numpy_helper.from_array(_without_zeros(arr), init.name))

    # Scales emitted by Constant nodes (legacy QDQ export path).
    for node in graph.node:
        if node.op_type != "Constant" or node.output[0] not in scale_tensor_names:
            continue
        for attr in node.attribute:
            if attr.name != "value":
                continue
            arr = numpy_helper.to_array(attr.t)
            if arr.dtype.kind == "f":
                attr.t.CopyFrom(numpy_helper.from_array(_without_zeros(arr), attr.t.name))
    return onnx_model


Expand Down
87 changes: 87 additions & 0 deletions tests/unit/onnx/quantization/test_qdq_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,3 +1021,90 @@ def test_column_major_gemm_trans_b_flip(self):

print(f"transB flipped: 1 -> {trans_b_value}")
print(f"Transpose nodes: {len(transpose_nodes)}")


def _build_model_with_zero_scale_initializer(dq_op_type: str):
    """Build an ONNX model whose scale initializer feeds a (Quantize|Dequantize)Linear node.

    Mirrors the INT4_AWQ failure mode from NVBug 6110209: scales live in graph initializers
    (not Constant nodes) and feed DequantizeLinear (default or trt:: domain) consumers.
    """
    # INT4-ish weights stored as int8, plus a per-row scale column that contains zeros.
    int4_weights = numpy_helper.from_array(
        np.random.randint(-8, 8, size=(6, 8), dtype=np.int8), "weight"
    )
    zeroed_scales = np.array([1e-3, 0.0, 5e-4, 0.0, 0.0, 2e-3], dtype=np.float16).reshape(6, 1)
    scale_init = numpy_helper.from_array(zeroed_scales, "scale")

    nodes = [
        helper.make_node(
            dq_op_type, inputs=["weight", "scale"], outputs=["dq_output"], name="weight_dq"
        ),
        helper.make_node(
            "MatMul", inputs=["input", "dq_output"], outputs=["output"], name="matmul"
        ),
    ]
    graph = helper.make_graph(
        nodes=nodes,
        name="test_graph",
        inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT16, [None, 6])],
        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT16, [None, 8])],
        initializer=[int4_weights, scale_init],
    )
    return helper.make_model(graph)


class TestReplaceZeroScaleWithSmallestNonzero:
    """Regression tests for ``replace_zero_scale_with_smallest_nonzero`` (NVBug 6110209)."""

    @pytest.mark.parametrize("dq_op_type", ["DequantizeLinear", "TRT_INT4DequantizeLinear"])
    def test_zero_scale_initializer_fed_to_dq_is_patched(self, dq_op_type):
        from modelopt.onnx.quantization.qdq_utils import replace_zero_scale_with_smallest_nonzero

        model = _build_model_with_zero_scale_initializer(dq_op_type)

        def _scale_init(m):
            # The fixture names its scale initializer "scale".
            return next(init for init in m.graph.initializer if init.name == "scale")

        before = numpy_helper.to_array(_scale_init(model))
        assert (before == 0).any(), "fixture must contain zeros to exercise the fix"

        patched = replace_zero_scale_with_smallest_nonzero(model)

        patched_init = _scale_init(patched)
        after = numpy_helper.to_array(patched_init)
        # Every zero must be gone, everything strictly positive, dtype preserved.
        assert not (after == 0).any()
        assert (after > 0).all()
        assert patched_init.data_type == TensorProto.FLOAT16

    def test_constant_node_scale_path_still_patched(self):
        """Legacy Constant-node QDQ path must continue to be patched."""
        from modelopt.onnx.quantization.qdq_utils import replace_zero_scale_with_smallest_nonzero

        nodes = [
            helper.make_node(
                "Constant",
                inputs=[],
                outputs=["scale_out"],
                value=numpy_helper.from_array(np.array([1e-3, 0.0, 2e-3], dtype=np.float16)),
                name="scale_constant",
            ),
            helper.make_node(
                "QuantizeLinear",
                inputs=["input", "scale_out"],
                outputs=["q_output"],
                name="q",
            ),
        ]
        model = helper.make_model(
            helper.make_graph(
                nodes=nodes,
                name="test_graph",
                inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, [3])],
                outputs=[helper.make_tensor_value_info("q_output", TensorProto.INT8, [3])],
                initializer=[],
            )
        )

        patched = replace_zero_scale_with_smallest_nonzero(model)

        const_node = next(n for n in patched.graph.node if n.op_type == "Constant")
        value_attr = next(a for a in const_node.attribute if a.name == "value")
        patched_scales = numpy_helper.to_array(value_attr.t)
        assert not (patched_scales == 0).any()
        assert (patched_scales > 0).all()
Loading