Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4310,12 +4310,10 @@ def _impl_v11(cls, bb, inputs, attr, params):

input_tensor = inputs[0]
input_shape = input_tensor.struct_info.shape
split_is_scalar = False

# If split is not provided, we split all values along axis.
if len(inputs) == 1:
split = _np.array(1)
if not keepdims:
raise NotImplementedError("Only keepdims=1 is supported for now")
else:
split = inputs[1]
if not isinstance(split, relax.Constant):
Expand All @@ -4326,15 +4324,30 @@ def _impl_v11(cls, bb, inputs, attr, params):
split = _np.cumsum(split)
split = list(split[:-1])
else:
chunk_size, dim_size = int(split), input_shape[axis]
if dim_size % chunk_size != 0:
raise ValueError(
f"Dimension of size {dim_size} along axis {axis} must be "
f"evenly divisible by chunk size {chunk_size}"
chunk_size = int(split)
dim_size = input_shape[axis]

if isinstance(dim_size, (int, tirx.IntImm)):
dim_size_int = int(dim_size)
split = math.ceil(dim_size_int / chunk_size)
else:
raise NotImplementedError(
"SplitToSequence with dynamic dim size and scalar split is not supported."
)
split = dim_size // chunk_size

output = relax.op.split(input_tensor, split, axis=axis)

# keepdims=0 applies when split is a scalar (whether provided or defaulted to 1)
# Per ONNX spec: "If input 'split' is specified, this attribute is ignored."
if not keepdims and len(inputs) == 1:
output = bb.emit(output)
n = len(output.struct_info.fields)
squeezed = [
relax.op.squeeze(bb.emit(relax.TupleGetItem(output, i)), axis=[axis])
for i in range(n)
]
return relax.Tuple(squeezed)

return output


Expand Down
70 changes: 70 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5458,5 +5458,75 @@ def test_arg_min_max_select_last_index_ir(op_name):
assert "relax.subtract" in call_ops, f"Expected relax.subtract in IR, got {call_ops}"


@pytest.mark.parametrize("axis", [0, 1, 2])
def test_split_to_sequence_keepdims_0(axis: int):
    """With keepdims=0 and no 'split' input, every size-1 chunk along `axis`
    has that axis squeezed away in the resulting sequence elements."""
    shape = [3, 4, 5]
    # Expected per-element shape: the input shape with the split axis removed.
    out_shape = shape[:axis] + shape[axis + 1 :]

    node = helper.make_node(
        "SplitToSequence",
        inputs=["data"],  # no split input — keepdims applies here
        outputs=["output"],
        axis=axis,
        keepdims=0,
    )
    graph = helper.make_graph(
        [node],
        f"test_split_to_sequence_keepdims_0_axis{axis}",
        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, shape)],
        outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, out_shape)],
    )
    model = helper.make_model(graph, producer_name="test_split_to_sequence_keepdims_0")
    check_correctness(model)


def test_split_to_sequence_keepdims_ignored_when_split_provided():
    """Per the ONNX spec, keepdims is ignored whenever the 'split' input is
    present: TVM keeps the split axis in each output even with keepdims=0."""
    # Local import: ORT deviates from the spec here, so we bypass
    # check_correctness and only verify that TVM can import the model.
    from tvm.relax.frontend.onnx import from_onnx

    const_split = make_constant_node("split", TensorProto.INT64, (), [1])
    seq_node = helper.make_node(
        "SplitToSequence",
        ["data", "split"],
        ["output"],
        axis=0,
        keepdims=0,
    )
    graph = helper.make_graph(
        [const_split, seq_node],
        "test_split_to_sequence_keepdims_ignored",
        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [4, 5])],
        outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [1, 5])],
    )
    model = helper.make_model(
        graph,
        producer_name="test_split_to_sequence_keepdims_ignored",
        opset_imports=[helper.make_opsetid("", 11)],
    )
    model.ir_version = 8
    # Importing the model must succeed; no numeric comparison against ORT.
    assert from_onnx(model, opset=11, keep_params_in_input=True) is not None


@pytest.mark.parametrize("axis", [0, 1])
def test_split_to_sequence_uneven_last_chunk(axis: int):
    """Spec: with a scalar 'split', the final chunk may be smaller when the
    dimension along `axis` is not evenly divisible by the chunk size."""
    # Size 5 along the split axis, chunk size 2 -> chunks of 2, 2, 1.
    shape = {0: [5, 4], 1: [3, 5]}[axis]
    const_split = make_constant_node("split", TensorProto.INT64, (), [2])
    seq_node = helper.make_node(
        "SplitToSequence",
        ["data", "split"],
        ["output"],
        axis=axis,
        keepdims=1,
    )
    graph = helper.make_graph(
        [const_split, seq_node],
        f"test_split_to_sequence_uneven_axis{axis}",
        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, shape)],
        outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, None)],
    )
    model = helper.make_model(graph, producer_name="test_split_to_sequence_uneven")
    check_correctness(model)


# Allow running this test file directly (outside a pytest invocation)
# via TVM's own test entry point.
if __name__ == "__main__":
    tvm.testing.main()
Loading