From f2222b3b8e3fe89372212d5c23c21cb8b714bd03 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Fri, 17 Apr 2026 08:50:16 +0000 Subject: [PATCH 1/2] Fix Reshape with 0-dims and rank change (issue #2104) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit compute_shape_with_tf_rules used a volume-matching approach to replace 0s in the shape spec, which failed when consecutive input dims of value 1 caused the iterator to not advance. Replace with simple positional substitution per the ONNX spec: shape[i]=0 means copy input[i]. Fixes Moonshine TTS model loading where RoPE reshapes like [1,52,8,32] → [0,0,8,16,2] produced [1,1,8,16,2] instead of [1,52,8,16,2]. --- core/src/ops/change_axes.rs | 62 +++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 33 deletions(-) diff --git a/core/src/ops/change_axes.rs b/core/src/ops/change_axes.rs index 7fcd52e966..a1c597379d 100644 --- a/core/src/ops/change_axes.rs +++ b/core/src/ops/change_axes.rs @@ -964,47 +964,32 @@ pub fn perm_to_ops(input: &[usize]) -> TVec { pub fn compute_shape_with_tf_rules(input: &[TDim], shape_spec: &[TDim]) -> TractResult> { let mut shape: TVec = shape_spec.into(); - fn deal_with_zero<'a>( - mut input_dims: std::iter::Peekable>, - shape: &mut [TDim], - ) -> TractResult<()> { - let mut remaining_dim_input = 1.to_dim(); - for slot in shape.iter_mut() { - if *slot == (-1).into() { - break; - } - if *slot == 0.into() { - if remaining_dim_input != TDim::one() { - bail!("Invalid remaining dim"); - } - *slot = (*input_dims.peek().context("Invalid")?).clone(); - } - loop { - let quotient = remaining_dim_input.maybe_div(slot); - if quotient.is_err() || quotient.as_ref().unwrap().1 != 1 { - remaining_dim_input *= input_dims.next().context("Invalid")?; - } else { - break; - } - } - remaining_dim_input = remaining_dim_input.maybe_div(slot)?.0; + // Replace 0s with corresponding input dims (positional, per ONNX/TF spec) + for (i, s) in shape.iter_mut().enumerate() { + if *s == 0.into() { + *s = input + .get(i) + .with_context(|| { + format!("Reshape: 0 at position {i} but input only has {} dims", input.len()) + })? + .clone(); } - Ok(()) } - - deal_with_zero(input.iter().peekable(), &mut shape)?; - shape.reverse(); - deal_with_zero(input.iter().rev().peekable(), &mut shape)?; - shape.reverse(); - + let input_vol: TDim = input.iter().product(); if let Some(pos) = shape.iter().position(|d| *d == (-1).into()) { - let input_vol: TDim = input.iter().product(); let shape_vol: TDim = shape.iter().filter(|d| **d != (-1).into()).product(); let div = input_vol.maybe_div(&shape_vol)?; if div.1 != 1 { bail!("invalid") } shape[pos] = div.0; + } else { + let shape_vol: TDim = shape.iter().product(); + if input_vol != shape_vol { + bail!( + "Reshape volume mismatch: input {input:?} (vol={input_vol}) vs shape {shape:?} (vol={shape_vol})" + ); + } } Ok(shape) } @@ -1046,7 +1031,9 @@ pub fn to_axis_ops_with_tf_rules( } } } - todo!() + bail!( + "Could not find matching reshape grouping: current_input={current_input:?} final_output={final_output:?} common={common}" + ) } } else if final_output.len() > current_input.len() { stack.push(AxisOp::Add(current_input.len())); @@ -1628,6 +1615,15 @@ mod proptests { ) } + #[test] + fn compute_zero_with_rank_change() { + // Moonshine RoPE: input rank 4, output rank 5, two leading 0s + assert_eq!( + &*compute_shape_with_tf_rules(s![1, 52, 8, 32], s!(0, 0, 8, 16, 2)).unwrap(), + s![1, 52, 8, 16, 2] + ) + } + #[test] fn axis_op_rm_begin() { assert_eq!(&*to_axis_ops_with_tf_rules(s![1, 2, 3], s!(2, 3)).unwrap(), &[Rm(0)]) From e9f40cdbfd8fbbfff6a6fea09f1c085fa4350b04 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Fri, 17 Apr 2026 17:06:37 +0000 Subject: [PATCH 2/2] onnx: promote Cast(to=i32) to Cast(to=TDim) like i64 ONNX Shape/Size outputs are declared TDim in tract so symbolic dims survive shape-plumbing chains. The Cast loader already rewrites Cast(to=i64) to Cast(to=TDim) to keep that invariant when the exporter inserts an explicit int64 round-trip. Some exporters (e.g. Moonshine) cast shape values to int32 instead. Widen the same rewrite to i32 so those chains also stay in TDim; without this, the first Cast(TDim->i32) loses the symbol and downstream Reshape lowering bails with "shape input is variable". --- onnx/src/ops/cast.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx/src/ops/cast.rs b/onnx/src/ops/cast.rs index 7ad2471f73..c96a4380d6 100644 --- a/onnx/src/ops/cast.rs +++ b/onnx/src/ops/cast.rs @@ -14,7 +14,7 @@ fn cast( node: &NodeProto, ) -> TractResult<(Box, Vec)> { let mut to = node.get_attr::("to")?; - if to == i64::datum_type() { + if to == i64::datum_type() || to == i32::datum_type() { to = TDim::datum_type(); } Ok((ElementWiseOp(Box::new(Cast::new(to)), None).into_hir(), vec![]))