From 6da1fa96425353dd6f5eb697773bb581e0fcb8aa Mon Sep 17 00:00:00 2001 From: OmarAzizi Date: Fri, 3 Apr 2026 18:22:56 +0300 Subject: [PATCH 1/3] [ONNX Frontend] Fix SplitToSequence keepdims=0 and uneven last chunk - Remove NotImplementedError for keepdims=0 - Track split_is_scalar to correctly apply keepdims logic - Replace ValueError for uneven chunks with index-based splitting - Add tests for keepdims=0, uneven last chunk, and keepdims ignored cases Fixes #18945 --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 33 ++++--- tests/python/relax/test_frontend_onnx.py | 87 +++++++++++++++++++ 2 files changed, 110 insertions(+), 10 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index ab1ea2b292af..dd7560f08b48 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -4310,31 +4310,44 @@ def _impl_v11(cls, bb, inputs, attr, params): input_tensor = inputs[0] input_shape = input_tensor.struct_info.shape + split_is_scalar = False - # If split is not provided, we split all values along axis. if len(inputs) == 1: split = _np.array(1) - if not keepdims: - raise NotImplementedError("Only keepdims=1 is supported for now") + split_is_scalar = True else: split = inputs[1] if not isinstance(split, relax.Constant): raise ValueError("Only constant split supported for SplitToSequence") split = split.data.numpy() + split_is_scalar = split.ndim == 0 # scalar = tensor of empty shape if len(split.shape) == 1 and split.shape[0] > 1: split = _np.cumsum(split) split = list(split[:-1]) else: - chunk_size, dim_size = int(split), input_shape[axis] - if dim_size % chunk_size != 0: - raise ValueError( - f"Dimension of size {dim_size} along axis {axis} must be " - f"evenly divisible by chunk size {chunk_size}" - ) - split = dim_size // chunk_size + chunk_size = int(split) + dim_size = input_shape[axis] + + if isinstance(dim_size, (int, tirx.IntImm)): + dim_size_int = int(dim_size) + indices = list(range(chunk_size, dim_size_int, chunk_size)) + split = indices if indices else dim_size_int // chunk_size + else: + split = chunk_size output = relax.op.split(input_tensor, split, axis=axis) + + # keepdims=0 applies when split is a scalar (whether provided or defaulted to 1) + if not keepdims and split_is_scalar: + output = bb.normalize(output) + n = len(output.struct_info.fields) + squeezed = [ + relax.op.squeeze(bb.emit(relax.TupleGetItem(output, i)), axis=[axis]) + for i in range(n) + ] + return relax.Tuple(squeezed) + return output diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 7f9cd177ad44..59aa68649ca2 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -5458,5 +5458,92 @@ def test_arg_min_max_select_last_index_ir(op_name): assert "relax.subtract" in call_ops, f"Expected relax.subtract in IR, got {call_ops}" +@pytest.mark.parametrize("split", [2, [16, 48]]) +def test_split_to_sequence(split): + split_to_sequence_node = helper.make_node( + "SplitToSequence", + ["data", "split"], + ["output"], + axis=0, + ) + split_shape = [len(split)] if isinstance(split, list) else () + split_node = make_constant_node( + "split", TensorProto.INT64, split_shape, [split] if isinstance(split, int) else split + ) + graph = helper.make_graph( + [split_node, split_to_sequence_node], + "test_split_to_sequence", + inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [64, 32])], + outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [32, 32])], + ) + model = helper.make_model(graph, producer_name="test_split_to_sequence") + check_correctness(model) + + +@pytest.mark.parametrize("axis", [0, 1, 2]) +def test_split_to_sequence_keepdims_0(axis: int): + """keepdims=0, no split input: each chunk of size 1 has the split axis squeezed out.""" + shape = [3, 4, 5] + out_shape = [s for i, s in enumerate(shape) if i != axis] + + split_to_seq_node = helper.make_node( + "SplitToSequence", + ["data"], # no split input — keepdims applies here + ["output"], + axis=axis, + keepdims=0, + ) + graph = helper.make_graph( + [split_to_seq_node], + f"test_split_to_sequence_keepdims_0_axis{axis}", + inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, shape)], + outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, out_shape)], + ) + model = helper.make_model(graph, producer_name="test_split_to_sequence_keepdims_0") + check_correctness(model) + + +def test_split_to_sequence_keepdims_ignored_when_split_provided(): + """Per spec: keepdims is ignored when split input is provided. + Even with keepdims=0, output chunks keep the split axis because split is provided. + We use scalar split=1 so ORT accepts the model, then verify output shape still has the axis. + """ + split_node = make_constant_node("split", TensorProto.INT64, (), [1]) + split_to_seq_node = helper.make_node( + "SplitToSequence", + ["data", "split"], + ["output"], + axis=0, + keepdims=0, # must be ignored since split is provided — output keeps the split axis + ) + graph = helper.make_graph( + [split_node, split_to_seq_node], + "test_split_to_sequence_keepdims_ignored", + inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [4, 5])], + # shape is [1, 5] not [5] — split axis is kept because split input was provided + outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [1, 5])], + ) + model = helper.make_model(graph, producer_name="test_split_to_sequence_keepdims_ignored") + check_correctness(model) + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_split_to_sequence_uneven_last_chunk(axis: int): + """Spec: last chunk may be smaller if dim is not divisible by scalar split.""" + shape = [5, 4] if axis == 0 else [3, 5] + split_node = make_constant_node("split", TensorProto.INT64, (), [2]) + split_to_seq_node = helper.make_node( + "SplitToSequence", ["data", "split"], ["output"], axis=axis, keepdims=1 + ) + graph = helper.make_graph( + [split_node, split_to_seq_node], + f"test_split_to_sequence_uneven_axis{axis}", + inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, shape)], + outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, None)], + ) + model = helper.make_model(graph, producer_name="test_split_to_sequence_uneven") + check_correctness(model) + + if __name__ == "__main__": tvm.testing.main() From c47c90a08a45c6c2d0bea14efab5d73b9a0f61a1 Mon Sep 17 00:00:00 2001 From: OmarAzizi Date: Fri, 3 Apr 2026 18:45:29 +0300 Subject: [PATCH 2/3] [ONNX Frontend] Fix SplitToSequence: support keepdims=0 and uneven last chunk Addressed Gemini code review: - Use len(inputs)==1 instead of split_is_scalar for keepdims condition - Use bb.emit instead of bb.normalize before TupleGetItem - Use split=1 for empty indices edge case instead of dim_size//chunk_size - Raise NotImplementedError for dynamic dim size with scalar split - Remove duplicate test_split_to_sequence test --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 13 +++--- tests/python/relax/test_frontend_onnx.py | 41 ++++++------------- 2 files changed, 19 insertions(+), 35 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index dd7560f08b48..6a35f5ca1a3d 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -4314,13 +4314,11 @@ def _impl_v11(cls, bb, inputs, attr, params): if len(inputs) == 1: split = _np.array(1) - split_is_scalar = True else: split = inputs[1] if not isinstance(split, relax.Constant): raise ValueError("Only constant split supported for SplitToSequence") split = split.data.numpy() - split_is_scalar = split.ndim == 0 # scalar = tensor of empty shape if len(split.shape) == 1 and split.shape[0] > 1: split = _np.cumsum(split) @@ -4332,15 +4330,18 @@ def _impl_v11(cls, bb, inputs, attr, params): if isinstance(dim_size, (int, tirx.IntImm)): dim_size_int = int(dim_size) indices = list(range(chunk_size, dim_size_int, chunk_size)) - split = indices if indices else dim_size_int // chunk_size + split = indices if indices else 1 else: - split = chunk_size + raise NotImplementedError( + "SplitToSequence with dynamic dim size and scalar split is not supported." + ) output = relax.op.split(input_tensor, split, axis=axis) # keepdims=0 applies when split is a scalar (whether provided or defaulted to 1) - if not keepdims and split_is_scalar: - output = bb.normalize(output) + # Per ONNX spec: "If input 'split' is specified, this attribute is ignored." + if not keepdims and len(inputs) == 1: + output = bb.emit(output) n = len(output.struct_info.fields) squeezed = [ relax.op.squeeze(bb.emit(relax.TupleGetItem(output, i)), axis=[axis]) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 59aa68649ca2..8111b95c4bfb 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -5458,28 +5458,6 @@ def test_arg_min_max_select_last_index_ir(op_name): assert "relax.subtract" in call_ops, f"Expected relax.subtract in IR, got {call_ops}" -@pytest.mark.parametrize("split", [2, [16, 48]]) -def test_split_to_sequence(split): - split_to_sequence_node = helper.make_node( - "SplitToSequence", - ["data", "split"], - ["output"], - axis=0, - ) - split_shape = [len(split)] if isinstance(split, list) else () - split_node = make_constant_node( - "split", TensorProto.INT64, split_shape, [split] if isinstance(split, int) else split - ) - graph = helper.make_graph( - [split_node, split_to_sequence_node], - "test_split_to_sequence", - inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [64, 32])], - outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [32, 32])], - ) - model = helper.make_model(graph, producer_name="test_split_to_sequence") - check_correctness(model) - - @pytest.mark.parametrize("axis", [0, 1, 2]) def test_split_to_sequence_keepdims_0(axis: int): """keepdims=0, no split input: each chunk of size 1 has the split axis squeezed out.""" @@ -5505,26 +5483,31 @@ def test_split_to_sequence_keepdims_0(axis: int): def test_split_to_sequence_keepdims_ignored_when_split_provided(): """Per spec: keepdims is ignored when split input is provided. - Even with keepdims=0, output chunks keep the split axis because split is provided. - We use scalar split=1 so ORT accepts the model, then verify output shape still has the axis. - """ + TVM follows the spec — output keeps the split axis even with keepdims=0.""" split_node = make_constant_node("split", TensorProto.INT64, (), [1]) split_to_seq_node = helper.make_node( "SplitToSequence", ["data", "split"], ["output"], axis=0, - keepdims=0, # must be ignored since split is provided — output keeps the split axis + keepdims=0, ) graph = helper.make_graph( [split_node, split_to_seq_node], "test_split_to_sequence_keepdims_ignored", inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [4, 5])], - # shape is [1, 5] not [5] — split axis is kept because split input was provided outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [1, 5])], ) - model = helper.make_model(graph, producer_name="test_split_to_sequence_keepdims_ignored") - check_correctness(model) + model = helper.make_model( + graph, + producer_name="test_split_to_sequence_keepdims_ignored", + opset_imports=[helper.make_opsetid("", 11)], + ) + model.ir_version = 8 + # Cannot use check_correctness here as ORT deviates from the spec for this case + from tvm.relax.frontend.onnx import from_onnx + tvm_model = from_onnx(model, opset=11, keep_params_in_input=True) + assert tvm_model is not None @pytest.mark.parametrize("axis", [0, 1]) From 61173e002db6ee0b2e62821e662f3259ce31732d Mon Sep 17 00:00:00 2001 From: OmarAzizi Date: Fri, 3 Apr 2026 22:54:29 +0300 Subject: [PATCH 3/3] Use math.ceil for section count in SplitToSequence scalar split --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 6a35f5ca1a3d..d54877646cd2 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -4329,8 +4329,7 @@ def _impl_v11(cls, bb, inputs, attr, params): if isinstance(dim_size, (int, tirx.IntImm)): dim_size_int = int(dim_size) - indices = list(range(chunk_size, dim_size_int, chunk_size)) - split = indices if indices else 1 + split = math.ceil(dim_size_int / chunk_size) else: raise NotImplementedError( "SplitToSequence with dynamic dim size and scalar split is not supported."