From d5db607593d68e87e493ec400fbb95438aacc50c Mon Sep 17 00:00:00 2001 From: Samnour2 Date: Fri, 8 May 2026 07:33:12 -0600 Subject: [PATCH 1/3] Store inferred output dims; skip refineDims so Y is not overwritten by stale static types on Y. --- .../ONNX/ONNXOps/Quantize/DequantizeLinear.cpp | 5 +++-- src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp | 2 +- test/mlir/onnx/onnx_shape_inference.mlir | 12 ++++++++++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/Dialect/ONNX/ONNXOps/Quantize/DequantizeLinear.cpp b/src/Dialect/ONNX/ONNXOps/Quantize/DequantizeLinear.cpp index 4081b359ce6..9daa4badb03 100644 --- a/src/Dialect/ONNX/ONNXOps/Quantize/DequantizeLinear.cpp +++ b/src/Dialect/ONNX/ONNXOps/Quantize/DequantizeLinear.cpp @@ -79,8 +79,9 @@ LogicalResult ONNXDequantizeLinearOpShapeHelper::computeShape() { } // Get values. - // Save the final result. - setOutputDims(outputDims); + // Store inferred output dims; skip refineDims so Y is not overwritten by + // stale static types on Y. + setOutputDims(outputDims, /*n=*/0, /*refineShape=*/false); return success(); } diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp index eaa04c81f35..ec5d7794a4f 100644 --- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp @@ -594,7 +594,7 @@ bool ONNXBroadcastOpShapeHelper::hasManageableBroadcastForInnerDims( << " & " << d << "; abort\n"); return collapsedInnermostLoops > 0; } // End for all non-scalars, - } // End testing non-scalar compatibility. + } // End testing non-scalar compatibility. // 4) Since we have at least one non-scalar // 4.1) all the scalar inputs are now marked as having a broadcast. diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 1ac46c7a97c..527676373ce 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -2084,6 +2084,18 @@ func.func @test_dequantize_linear_2(%arg0 : tensor<5x?x3x4xi8>, %arg1 : tensor<* // ----- +// COM: inferShapes derives Y's shape from X. +func.func @test_dequantize_linear_stale_output_batch(%arg0 : tensor<2x4x8xi8>, %arg1 : tensor, %arg2 : tensor) -> tensor<1x4x8xf32> { + %1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<2x4x8xi8>, tensor, tensor) -> tensor<1x4x8xf32> + "onnx.Return"(%1) {} : (tensor<1x4x8xf32>) -> () + + // CHECK-LABEL: test_dequantize_linear_stale_output_batch + // CHECK: [[RES:%.+]] = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64, block_size = 0 : si64} : (tensor<2x4x8xi8>, tensor, tensor) -> tensor<2x4x8xf32> + // CHECK: onnx.Return [[RES]] : tensor<2x4x8xf32> +} + +// ----- + //===----------------------------------------------------------------------===// /// Test shape inference for ConvInteger operation and all its attributes. //===----------------------------------------------------------------------===// From 9ad84911ed6554c7f64d458faa40c394922cf84d Mon Sep 17 00:00:00 2001 From: Samnour2 Date: Fri, 8 May 2026 11:18:38 -0600 Subject: [PATCH 2/3] chore: formatting --- src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp index ec5d7794a4f..eaa04c81f35 100644 --- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp @@ -594,7 +594,7 @@ bool ONNXBroadcastOpShapeHelper::hasManageableBroadcastForInnerDims( << " & " << d << "; abort\n"); return collapsedInnermostLoops > 0; } // End for all non-scalars, - } // End testing non-scalar compatibility. + } // End testing non-scalar compatibility. // 4) Since we have at least one non-scalar // 4.1) all the scalar inputs are now marked as having a broadcast. From 6607ae304e19d23c534ee5e4f67b1327cf36f922 Mon Sep 17 00:00:00 2001 From: Samnour2 Date: Fri, 8 May 2026 11:45:39 -0600 Subject: [PATCH 3/3] [test] update test to preserve dynamic dimensions --- .../onnx_to_tosa/NN/DequantizeLinear.mlir | 24 +++++-------------- 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/test/mlir/conversion/onnx_to_tosa/NN/DequantizeLinear.mlir b/test/mlir/conversion/onnx_to_tosa/NN/DequantizeLinear.mlir index a9c09d2fa71..f8477709a86 100644 --- a/test/mlir/conversion/onnx_to_tosa/NN/DequantizeLinear.mlir +++ b/test/mlir/conversion/onnx_to_tosa/NN/DequantizeLinear.mlir @@ -117,15 +117,9 @@ func.func @dynamic(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor } // CHECK-LABEL: func.func @dynamic -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor, [[PARAM_2_:%.+]]: tensor) -> tensor<1xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = tosa.cast [[PARAM_0_]] : (tensor) -> tensor -// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array} : (tensor) -> tensor<1xi8> -// CHECK: [[VAR_2_:%.+]] = tosa.cast [[VAR_1_]] : (tensor<1xi8>) -> tensor<1xf32> -// CHECK-DAG: [[VAR_3_:%.+]] = tosa.sub [[VAR_0_]], [[VAR_2_]] : (tensor, tensor<1xf32>) -> tensor -// CHECK-DAG: [[VAR_4_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor) -> tensor<1xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> -// CHECK: [[VAR_6_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_4_]], [[VAR_5_]] : (tensor, tensor<1xf32>, tensor<1xi8>) -> tensor<1xf32> -// CHECK: return [[VAR_6_]] : tensor<1xf32> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor, [[PARAM_1_:%.+]]: tensor, [[PARAM_2_:%.+]]: tensor) -> tensor { +// CHECK: [[VAR_0_:%.+]] = "onnx.DequantizeLinear"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {{.*}} : (tensor, tensor, tensor) -> tensor +// CHECK: return [[VAR_0_]] : tensor // CHECK: } // ----- @@ -153,15 +147,9 @@ func.func @dynamic3(%arg0 : tensor<2x?xi8>, %arg1 : tensor, %arg2 : tensor< } // CHECK-LABEL: func.func @dynamic3 -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x?xi8>, [[PARAM_1_:%.+]]: tensor, [[PARAM_2_:%.+]]: tensor) -> tensor<2x1xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = tosa.cast [[PARAM_0_]] : (tensor<2x?xi8>) -> tensor<2x?xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array} : (tensor) -> tensor<1x1xi8> -// CHECK: [[VAR_2_:%.+]] = tosa.cast [[VAR_1_]] : (tensor<1x1xi8>) -> tensor<1x1xf32> -// CHECK-DAG: [[VAR_3_:%.+]] = tosa.sub [[VAR_0_]], [[VAR_2_]] : (tensor<2x?xf32>, tensor<1x1xf32>) -> tensor<2x?xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> -// CHECK: [[VAR_6_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_4_]], [[VAR_5_]] : (tensor<2x?xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x1xf32> -// CHECK: return [[VAR_6_]] : tensor<2x1xf32> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x?xi8>, [[PARAM_1_:%.+]]: tensor, [[PARAM_2_:%.+]]: tensor) -> tensor<2x?xf32> { +// CHECK: [[VAR_0_:%.+]] = "onnx.DequantizeLinear"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {{.*}} : (tensor<2x?xi8>, tensor, tensor) -> tensor<2x?xf32> +// CHECK: return [[VAR_0_]] : tensor<2x?xf32> // CHECK: } // -----