
Fix ReshapeFusion dropping allowzero on inferred 0-sized intermediate dims#28349

Open
titaiwangms wants to merge 1 commit into main from fix/reshape-fusion-allowzero

Conversation

@titaiwangms
Contributor

Description

ReshapeFusion::FuseContiguousReshapes collapses a chain of Reshape / Squeeze / Unsqueeze nodes into a single Reshape whose shape data is taken verbatim from the fully-inferred output shape of the last node in the chain. The new node is created without an allowzero attribute, so it defaults to allowzero = 0.

When that inferred shape contains a literal 0 dim (legitimate when the original chain used allowzero=1, or when intermediate tensors had zero-sized dimensions), the fused Reshape misinterprets the 0 as "copy the corresponding dim from the input tensor" — but the input here is the original input of the first reshape in the chain, with unrelated dims. The result is a silently wrong output shape (and a benign-looking MergeShapeInfo warning at graph load).

Repro (before the fix)

import numpy as np, onnx, onnxruntime as ort, onnx.reference
from onnx import helper, TensorProto

X  = helper.make_tensor_value_info("X", TensorProto.FLOAT, [0, 6, 2])
Y  = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None, None, None])
s1 = helper.make_tensor("s1", TensorProto.INT64, [3], [3, 2, -1])
s2 = helper.make_tensor("s2", TensorProto.INT64, [3], [0, 0, 3])

n1 = helper.make_node("Reshape", ["X",   "s1"], ["mid"])             # mid: (3, 2, 0)
n2 = helper.make_node("Reshape", ["mid", "s2"], ["Y"], allowzero=1)  # Y:   (0, 0, 3)
m  = helper.make_model(helper.make_graph([n1, n2], "g", [X], [Y], initializer=[s1, s2]),
                       opset_imports=[helper.make_opsetid("", 18)])

inp = np.random.default_rng(7).random((0, 6, 2), dtype=np.float32)
print("REF:", onnx.reference.ReferenceEvaluator(m).run(None, {"X": inp})[0].shape)
print("ORT:", ort.InferenceSession(m.SerializeToString(),
                                   providers=["CPUExecutionProvider"]).run(None, {"X": inp})[0].shape)

Output on main (40c9f85f69):

REF: (0, 0, 3)
[W ... graph.cc:122 MergeShapeInfo] Error merging shape info for output. 'Y' source:{0,6,3} target:{0,0,3}. Falling back to lenient merge.
ORT: (0, 6, 3)   ❌

Fix

Setting allowzero=1 on the fused node would also fix this, but it requires opset >= 14, which this transformer cannot assume (it accepts Reshape from opset 5 onward). Instead, the fusion now bails out conservatively whenever shape_value contains a literal 0 dim.
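The guard itself is tiny; a Python sketch of the condition (the actual change is C++ in reshape_fusion.cc, and the function name here is illustrative):

```python
def is_safe_to_fuse(inferred_shape):
    # inferred_shape: the fully-inferred output shape of the last node in the
    # chain, which becomes the shape initializer of the fused Reshape.
    # The fused node defaults to allowzero=0, so any literal 0 would be
    # reinterpreted as "copy the dim from the (wrong) input". Bail out.
    return all(dim != 0 for dim in inferred_shape)

print(is_safe_to_fuse([2, 1, 64, 32]))  # True:  existing happy path still fuses
print(is_safe_to_fuse([0, 0, 3]))       # False: the repro above is left unfused
```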

Test

Adds ReshapeFusionContiguousReshapesWithZeroDim that builds the bug repro programmatically and asserts:

  • the two reshapes are NOT collapsed
  • the inferred output shape stays (0, 0, 3)

The existing happy-path test ReshapeFusion_Contiguous_Reshape (added in #22494) is unaffected — its inferred output shape (2, 1, 64, 32) contains no zero dims, so the new guard does not trigger.

Provenance

FuseContiguousReshapes was introduced in #22494 (Feb 2025). The bug has been latent in main since then.

Motivation and Context

Found while reviewing microsoft/onnxscript#2907 — the rewriter rule under test there is semantically correct, but its numerical-equivalence check using ORT as the oracle fails because of this fusion bug.

Fixes #28348.

… dims

FuseContiguousReshapes collapses a chain of Reshape/Squeeze/Unsqueeze
nodes into a single Reshape whose shape data is taken verbatim from the
fully-inferred output shape of the last node in the chain. The new node
is created without an allowzero attribute, so it defaults to allowzero=0.

When that inferred shape contains a literal 0 dim (legitimate when the
original chain used allowzero=1, or when intermediate tensors had
zero-sized dimensions), the fused Reshape misinterprets the 0 as 'copy
the corresponding dim from the input tensor' -- but the input here is
the original input of the *first* reshape in the chain, with unrelated
dims. The result is a silently wrong output shape (and a benign-looking
MergeShapeInfo warning at graph load).

Setting allowzero=1 on the fused node would fix this but requires opset
>= 14, which the transformer cannot assume (it accepts Reshape opset
5+). Bail out of fusion conservatively when the inferred shape contains
any literal 0 dim.

Adds a regression test that builds a Reshape -> Reshape chain whose
inferred intermediate has zero-sized dims and asserts that the fusion
no longer collapses it (and that the inferred output shape stays
correct).

Fixes #28348.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms requested review from Lafi7e and justinchuby May 4, 2026 20:11
@titaiwangms titaiwangms enabled auto-merge (squash) May 4, 2026 21:04
@justinchuby justinchuby requested a review from Copilot May 4, 2026 21:09
// literally 0, fusing into a single Reshape is unsafe: ONNX Reshape with the default
// allowzero=0 would reinterpret the 0 as "copy from input", producing the wrong shape.
// Setting allowzero=1 would fix it but requires opset >= 14, which we cannot assume
// here (this transformer accepts Reshape opset 5+). Bail out conservatively.
Contributor
Is there a way to check opset version?

Contributor

Copilot AI left a comment

Pull request overview

This PR fixes a correctness bug in the Reshape fusion optimizer: when collapsing contiguous Reshape/Squeeze/Unsqueeze chains, the fused node could mis-handle literal 0 dimensions because it did not preserve allowzero semantics. In the optimizer stack, this prevents a silent shape corruption during graph rewriting.

Changes:

  • Add a guard in ReshapeFusion::FuseContiguousReshapes to skip fusion when the inferred fused shape contains any literal 0 dimension.
  • Add a regression test covering a Reshape -> Reshape chain where the second reshape uses allowzero=1 and the correct output shape contains zeros.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description

  • onnxruntime/core/optimizer/reshape_fusion.cc: Adds a conservative early-exit to avoid unsafe contiguous reshape fusion for inferred zero dims.
  • onnxruntime/test/optimizer/graph_transform_test.cc: Adds a regression test asserting the unsafe reshape chain is not fused and the inferred output shape stays correct.


@titaiwangms titaiwangms disabled auto-merge May 4, 2026 21:12


Development

Successfully merging this pull request may close these issues.

ReshapeFusion drops allowzero, producing wrong shape when inferred intermediate has 0-sized dim

3 participants