
[Frontend][ONNX] Fix SplitToSequence keepdims=0 and uneven last chunk #19341

Merged
tlopex merged 3 commits into apache:main from OmarAzizi:split-to-sequence-keepdims
Apr 3, 2026
Conversation

@OmarAzizi
Contributor

Summary

Fixes two spec violations in SplitToSequence:

  1. keepdims=0 was raising NotImplementedError. The fix squeezes the split axis from each chunk when split is scalar and keepdims=0. Per spec:

    "If input 'split' is specified [as a 1-D array], this attribute is ignored" —
    verified against ORT.

  2. Uneven last chunk was raising ValueError. The spec states: "The last chunk alone may be smaller than 'split' if the input size is not divisible by 'split'." Fixed by using index-based splitting via range(chunk_size, dim_size, chunk_size) instead of a count.

Reference: https://onnx.ai/onnx/operators/onnx__SplitToSequence.html
Closes part of #18945

- Remove NotImplementedError for keepdims=0
- Track split_is_scalar to correctly apply keepdims logic
- Replace ValueError for uneven chunks with index-based splitting
- Add tests for keepdims=0, uneven last chunk, and keepdims ignored cases

Fixes apache#18945
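The index-based splitting described above can be sketched in plain Python (an illustrative helper, not the actual Relax frontend code; `split_indices` is a hypothetical name):

```python
def split_indices(dim_size: int, chunk_size: int) -> list:
    # Cut points every chunk_size elements along the split axis.
    # The final chunk absorbs the remainder, so it may be smaller
    # than chunk_size, as the ONNX spec allows.
    return list(range(chunk_size, dim_size, chunk_size))

# dim_size=5, chunk_size=2 -> cuts at [2, 4] -> chunk sizes 2, 2, 1
```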

@gemini-code-assist (bot) left a comment


Code Review

This pull request updates the ONNX SplitToSequence implementation in the Relax frontend to support keepdims=0 and uneven last chunks, along with adding corresponding test cases. Feedback highlights a logic error where keepdims is incorrectly applied when a split input is provided, an issue with symbolic dimension handling where chunk size is misinterpreted as the number of sections, and a potential invalid argument error when the dimension size is smaller than the chunk size. Additionally, a redundant test case was identified for removal.

Comment on lines +4336 to +4337

```python
else:
    split = chunk_size
```


high

When dim_size is symbolic, setting split = chunk_size is incorrect because relax.op.split interprets an integer argument as the number of sections, not the chunk size. Additionally, since SplitToSequence returns a relax.Tuple, the number of outputs must be statically known at import time. If the number of chunks depends on a dynamic dimension, it cannot be represented as a fixed-size Tuple. This case should likely raise a NotImplementedError as it is not supported by the current Relax frontend architecture.

Comment on lines +4342 to +4349

```python
if not keepdims and split_is_scalar:
    output = bb.normalize(output)
    n = len(output.struct_info.fields)
    squeezed = [
        relax.op.squeeze(bb.emit(relax.TupleGetItem(output, i)), axis=[axis])
        for i in range(n)
    ]
    return relax.Tuple(squeezed)
```


high

The ONNX specification for SplitToSequence states that the keepdims attribute is ignored if the split input is specified. The current implementation applies keepdims=0 whenever split is a scalar, even if it was provided as an input. This also contradicts the PR's own test case test_split_to_sequence_keepdims_ignored_when_split_provided (lines 5506-5527), which expects keepdims to be ignored for a scalar split input. The logic should only apply keepdims when the split input is NOT provided (i.e., len(inputs) == 1). Additionally, it is recommended to emit the output before performing TupleGetItem to avoid redundant expression nodes in the IR.

Suggested change

```diff
-if not keepdims and split_is_scalar:
-    output = bb.normalize(output)
+# keepdims=0 applies only when split is not provided (defaults to 1)
+if not keepdims and len(inputs) == 1:
+    output = bb.emit(output)
     n = len(output.struct_info.fields)
     squeezed = [
         relax.op.squeeze(bb.emit(relax.TupleGetItem(output, i)), axis=[axis])
         for i in range(n)
     ]
     return relax.Tuple(squeezed)
```
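A small NumPy analogue may clarify the corrected behaviour: when no `split` input is given, the scalar defaults to 1, so `keepdims=0` squeezes the split axis out of each length-1 chunk (NumPy stands in for the Relax ops here; this is illustrative only):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)
axis = 0

# Default scalar split=1: one chunk per slice along `axis`.
chunks = np.split(x, x.shape[axis], axis=axis)  # three chunks of shape (1, 4)

# keepdims=0: squeeze the split axis from every chunk.
squeezed = [np.squeeze(c, axis=axis) for c in chunks]  # three vectors of shape (4,)
```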

Comment on lines +4334 to +4335

```python
indices = list(range(chunk_size, dim_size_int, chunk_size))
split = indices if indices else dim_size_int // chunk_size
```


medium

If dim_size_int is less than chunk_size (e.g., dim=1, chunk=2), indices will be an empty list and dim_size_int // chunk_size will evaluate to 0. Passing 0 as the number of sections to relax.op.split is invalid. It should be 1 in this case to return the entire tensor as a single chunk, which is consistent with the ONNX specification for uneven last chunks.

Suggested change

```diff
 indices = list(range(chunk_size, dim_size_int, chunk_size))
-split = indices if indices else dim_size_int // chunk_size
+split = indices if indices else 1
```
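The suggested fallback can be sketched as a standalone helper (hypothetical name, plain Python, not the actual frontend code):

```python
def sections_or_indices(dim_size_int: int, chunk_size: int):
    indices = list(range(chunk_size, dim_size_int, chunk_size))
    # When dim_size_int <= chunk_size, `indices` is empty; fall back to a
    # single section so the whole tensor is returned as one (smaller) chunk,
    # matching the ONNX rule for an uneven last chunk.
    return indices if indices else 1
```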

Comment on lines +5461 to +5480

```python
@pytest.mark.parametrize("split", [2, [16, 48]])
def test_split_to_sequence(split):
    split_to_sequence_node = helper.make_node(
        "SplitToSequence",
        ["data", "split"],
        ["output"],
        axis=0,
    )
    split_shape = [len(split)] if isinstance(split, list) else ()
    split_node = make_constant_node(
        "split", TensorProto.INT64, split_shape, [split] if isinstance(split, int) else split
    )
    graph = helper.make_graph(
        [split_node, split_to_sequence_node],
        "test_split_to_sequence",
        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [64, 32])],
        outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [32, 32])],
    )
    model = helper.make_model(graph, producer_name="test_split_to_sequence")
    check_correctness(model)
```


medium

This test function is a duplicate of the one already defined at line 3784. It should be removed to avoid redundancy.

…st chunk

Addressed Gemini code review:
- Use len(inputs)==1 instead of split_is_scalar for keepdims condition
- Use bb.emit instead of bb.normalize before TupleGetItem
- Use split=1 for empty indices edge case instead of dim_size//chunk_size
- Raise NotImplementedError for dynamic dim size with scalar split
- Remove duplicate test_split_to_sequence test
@tlopex
Member

tlopex commented Apr 3, 2026

Hi @OmarAzizi
The `indices = list(range(chunk_size, dim_size_int, chunk_size))` conversion is unnecessary here.
`relax.op.split` already accepts an int for `indices_or_sections`, and it handles the case where the last section is smaller when the dimension size is not divisible by that integer.
So the only change needed from the original code is to remove the `ValueError` check. This can be simplified to:

```python
split = int(dim_size) // chunk_size
```

@OmarAzizi
Contributor Author

@tlopex You are right that the integer approach works, but the calculation needed to be `math.ceil(dim_size / chunk_size)` instead of `int(dim_size) // chunk_size`.

For dim=5, chunk_size=2:

  • ceil(5/2) = 3 sections → shapes [(2,4), (2,4), (1,4)] (matches ORT)
  • 5 // 2 = 2 sections → shapes [(3,4), (2,4)] (does not match ORT)

Per the ONNX spec: "If 'split' is a scalar, then 'input' will be split into chunks all of size 'split' if possible. The last chunk alone may be smaller than 'split' if the input size along the given axis is not divisible by 'split'."

Updated the implementation accordingly.
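The ceil-vs-floor difference can be checked with `numpy.array_split`, which, like ORT's chunking here, allows a smaller last section (NumPy is used as an analogue for `relax.op.split`, not the actual frontend code):

```python
import math
import numpy as np

x = np.arange(5)
chunk_size = 2

# floor: 5 // 2 = 2 sections -> sizes [3, 2] (does not match ORT)
floor_chunks = np.array_split(x, 5 // chunk_size)

# ceil: ceil(5 / 2) = 3 sections -> sizes [2, 2, 1] (matches ORT)
ceil_chunks = np.array_split(x, math.ceil(5 / chunk_size))
```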

Member

@tlopex tlopex left a comment


Thank you!

@tlopex tlopex merged commit 427b66d into apache:main Apr 3, 2026
9 checks passed