diff --git a/core/src/ops/change_axes.rs b/core/src/ops/change_axes.rs
index 7fcd52e966..a1c597379d 100644
--- a/core/src/ops/change_axes.rs
+++ b/core/src/ops/change_axes.rs
@@ -964,47 +964,32 @@ pub fn perm_to_ops(input: &[usize]) -> TVec<AxisOp> {
 
 pub fn compute_shape_with_tf_rules(input: &[TDim], shape_spec: &[TDim]) -> TractResult<TVec<TDim>> {
     let mut shape: TVec<TDim> = shape_spec.into();
-    fn deal_with_zero<'a>(
-        mut input_dims: std::iter::Peekable<impl Iterator<Item = &'a TDim>>,
-        shape: &mut [TDim],
-    ) -> TractResult<()> {
-        let mut remaining_dim_input = 1.to_dim();
-        for slot in shape.iter_mut() {
-            if *slot == (-1).into() {
-                break;
-            }
-            if *slot == 0.into() {
-                if remaining_dim_input != TDim::one() {
-                    bail!("Invalid remaining dim");
-                }
-                *slot = (*input_dims.peek().context("Invalid")?).clone();
-            }
-            loop {
-                let quotient = remaining_dim_input.maybe_div(slot);
-                if quotient.is_err() || quotient.as_ref().unwrap().1 != 1 {
-                    remaining_dim_input *= input_dims.next().context("Invalid")?;
-                } else {
-                    break;
-                }
-            }
-            remaining_dim_input = remaining_dim_input.maybe_div(slot)?.0;
+    // Replace 0s with corresponding input dims (positional, per ONNX/TF spec)
+    for (i, s) in shape.iter_mut().enumerate() {
+        if *s == 0.into() {
+            *s = input
+                .get(i)
+                .with_context(|| {
+                    format!("Reshape: 0 at position {i} but input only has {} dims", input.len())
+                })?
+                .clone();
         }
-        Ok(())
     }
-
-    deal_with_zero(input.iter().peekable(), &mut shape)?;
-    shape.reverse();
-    deal_with_zero(input.iter().rev().peekable(), &mut shape)?;
-    shape.reverse();
-
+    let input_vol: TDim = input.iter().product();
     if let Some(pos) = shape.iter().position(|d| *d == (-1).into()) {
-        let input_vol: TDim = input.iter().product();
         let shape_vol: TDim = shape.iter().filter(|d| **d != (-1).into()).product();
         let div = input_vol.maybe_div(&shape_vol)?;
         if div.1 != 1 {
             bail!("invalid")
         }
         shape[pos] = div.0;
+    } else {
+        let shape_vol: TDim = shape.iter().product();
+        if input_vol != shape_vol {
+            bail!(
+                "Reshape volume mismatch: input {input:?} (vol={input_vol}) vs shape {shape:?} (vol={shape_vol})"
+            );
+        }
     }
     Ok(shape)
 }
@@ -1046,7 +1031,9 @@ pub fn to_axis_ops_with_tf_rules(
                         }
                     }
                 }
-                todo!()
+                bail!(
+                    "Could not find matching reshape grouping: current_input={current_input:?} final_output={final_output:?} common={common}"
+                )
             }
         } else if final_output.len() > current_input.len() {
             stack.push(AxisOp::Add(current_input.len()));
@@ -1628,6 +1615,15 @@ mod proptests {
         )
     }
 
+    #[test]
+    fn compute_zero_with_rank_change() {
+        // Moonshine RoPE: input rank 4, output rank 5, two leading 0s
+        assert_eq!(
+            &*compute_shape_with_tf_rules(s![1, 52, 8, 32], s!(0, 0, 8, 16, 2)).unwrap(),
+            s![1, 52, 8, 16, 2]
+        )
+    }
+
     #[test]
     fn axis_op_rm_begin() {
         assert_eq!(&*to_axis_ops_with_tf_rules(s![1, 2, 3], s!(2, 3)).unwrap(), &[Rm(0)])
diff --git a/onnx/src/ops/cast.rs b/onnx/src/ops/cast.rs
index 7ad2471f73..c96a4380d6 100644
--- a/onnx/src/ops/cast.rs
+++ b/onnx/src/ops/cast.rs
@@ -14,7 +14,7 @@ fn cast(
     node: &NodeProto,
 ) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
     let mut to = node.get_attr::<DatumType>("to")?;
-    if to == i64::datum_type() {
+    if to == i64::datum_type() || to == i32::datum_type() {
         to = TDim::datum_type();
     }
     Ok((ElementWiseOp(Box::new(Cast::new(to)), None).into_hir(), vec![]))