diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index ab1ea2b292af..d54877646cd2 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -4310,12 +4310,10 @@ def _impl_v11(cls, bb, inputs, attr, params): input_tensor = inputs[0] input_shape = input_tensor.struct_info.shape + split_is_scalar = False - # If split is not provided, we split all values along axis. if len(inputs) == 1: split = _np.array(1) - if not keepdims: - raise NotImplementedError("Only keepdims=1 is supported for now") else: split = inputs[1] if not isinstance(split, relax.Constant): @@ -4326,15 +4324,30 @@ split = _np.cumsum(split) split = list(split[:-1]) else: - chunk_size, dim_size = int(split), input_shape[axis] - if dim_size % chunk_size != 0: - raise ValueError( - f"Dimension of size {dim_size} along axis {axis} must be " - f"evenly divisible by chunk size {chunk_size}" + chunk_size = int(split) + dim_size = input_shape[axis] + + if isinstance(dim_size, (int, tirx.IntImm)): + dim_size_int = int(dim_size) + split = math.ceil(dim_size_int / chunk_size) + else: + raise NotImplementedError( + "SplitToSequence with dynamic dim size and scalar split is not supported." ) - split = dim_size // chunk_size output = relax.op.split(input_tensor, split, axis=axis) + + # keepdims=0 applies only when the 'split' input is omitted (split defaults to scalar 1). + # Per ONNX spec: "If input 'split' is specified, this attribute is ignored."
+ if not keepdims and len(inputs) == 1: + output = bb.emit(output) + n = len(output.struct_info.fields) + squeezed = [ + relax.op.squeeze(bb.emit(relax.TupleGetItem(output, i)), axis=[axis]) + for i in range(n) + ] + return relax.Tuple(squeezed) + return output diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 7f9cd177ad44..8111b95c4bfb 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -5458,5 +5458,75 @@ def test_arg_min_max_select_last_index_ir(op_name): assert "relax.subtract" in call_ops, f"Expected relax.subtract in IR, got {call_ops}" +@pytest.mark.parametrize("axis", [0, 1, 2]) +def test_split_to_sequence_keepdims_0(axis: int): + """keepdims=0, no split input: each chunk of size 1 has the split axis squeezed out.""" + shape = [3, 4, 5] + out_shape = [s for i, s in enumerate(shape) if i != axis] + + split_to_seq_node = helper.make_node( + "SplitToSequence", + ["data"], # no split input — keepdims applies here + ["output"], + axis=axis, + keepdims=0, + ) + graph = helper.make_graph( + [split_to_seq_node], + f"test_split_to_sequence_keepdims_0_axis{axis}", + inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, shape)], + outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, out_shape)], + ) + model = helper.make_model(graph, producer_name="test_split_to_sequence_keepdims_0") + check_correctness(model) + + +def test_split_to_sequence_keepdims_ignored_when_split_provided(): + """Per spec: keepdims is ignored when split input is provided. 
+ TVM follows the spec — output keeps the split axis even with keepdims=0.""" + split_node = make_constant_node("split", TensorProto.INT64, (), [1]) + split_to_seq_node = helper.make_node( + "SplitToSequence", + ["data", "split"], + ["output"], + axis=0, + keepdims=0, + ) + graph = helper.make_graph( + [split_node, split_to_seq_node], + "test_split_to_sequence_keepdims_ignored", + inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [4, 5])], + outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [1, 5])], + ) + model = helper.make_model( + graph, + producer_name="test_split_to_sequence_keepdims_ignored", + opset_imports=[helper.make_opsetid("", 11)], + ) + model.ir_version = 8 + # Cannot use check_correctness here as ORT deviates from the spec for this case + from tvm.relax.frontend.onnx import from_onnx + tvm_model = from_onnx(model, opset=11, keep_params_in_input=True) + assert tvm_model is not None + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_split_to_sequence_uneven_last_chunk(axis: int): + """Spec: last chunk may be smaller if dim is not divisible by scalar split.""" + shape = [5, 4] if axis == 0 else [3, 5] + split_node = make_constant_node("split", TensorProto.INT64, (), [2]) + split_to_seq_node = helper.make_node( + "SplitToSequence", ["data", "split"], ["output"], axis=axis, keepdims=1 + ) + graph = helper.make_graph( + [split_node, split_to_seq_node], + f"test_split_to_sequence_uneven_axis{axis}", + inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, shape)], + outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, None)], + ) + model = helper.make_model(graph, producer_name="test_split_to_sequence_uneven") + check_correctness(model) + + if __name__ == "__main__": tvm.testing.main()