Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 29 additions & 33 deletions core/src/ops/change_axes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -964,47 +964,32 @@ pub fn perm_to_ops(input: &[usize]) -> TVec<AxisOp> {

pub fn compute_shape_with_tf_rules(input: &[TDim], shape_spec: &[TDim]) -> TractResult<TVec<TDim>> {
let mut shape: TVec<TDim> = shape_spec.into();
fn deal_with_zero<'a>(
mut input_dims: std::iter::Peekable<impl Iterator<Item = &'a TDim>>,
shape: &mut [TDim],
) -> TractResult<()> {
let mut remaining_dim_input = 1.to_dim();
for slot in shape.iter_mut() {
if *slot == (-1).into() {
break;
}
if *slot == 0.into() {
if remaining_dim_input != TDim::one() {
bail!("Invalid remaining dim");
}
*slot = (*input_dims.peek().context("Invalid")?).clone();
}
loop {
let quotient = remaining_dim_input.maybe_div(slot);
if quotient.is_err() || quotient.as_ref().unwrap().1 != 1 {
remaining_dim_input *= input_dims.next().context("Invalid")?;
} else {
break;
}
}
remaining_dim_input = remaining_dim_input.maybe_div(slot)?.0;
// Replace 0s with corresponding input dims (positional, per ONNX/TF spec)
for (i, s) in shape.iter_mut().enumerate() {
if *s == 0.into() {
*s = input
.get(i)
.with_context(|| {
format!("Reshape: 0 at position {i} but input only has {} dims", input.len())
})?
.clone();
}
Ok(())
}

deal_with_zero(input.iter().peekable(), &mut shape)?;
shape.reverse();
deal_with_zero(input.iter().rev().peekable(), &mut shape)?;
shape.reverse();

let input_vol: TDim = input.iter().product();
if let Some(pos) = shape.iter().position(|d| *d == (-1).into()) {
let input_vol: TDim = input.iter().product();
let shape_vol: TDim = shape.iter().filter(|d| **d != (-1).into()).product();
let div = input_vol.maybe_div(&shape_vol)?;
if div.1 != 1 {
bail!("invalid")
}
shape[pos] = div.0;
} else {
let shape_vol: TDim = shape.iter().product();
if input_vol != shape_vol {
bail!(
"Reshape volume mismatch: input {input:?} (vol={input_vol}) vs shape {shape:?} (vol={shape_vol})"
);
}
}
Ok(shape)
}
Expand Down Expand Up @@ -1046,7 +1031,9 @@ pub fn to_axis_ops_with_tf_rules(
}
}
}
todo!()
bail!(
"Could not find matching reshape grouping: current_input={current_input:?} final_output={final_output:?} common={common}"
)
}
} else if final_output.len() > current_input.len() {
stack.push(AxisOp::Add(current_input.len()));
Expand Down Expand Up @@ -1628,6 +1615,15 @@ mod proptests {
)
}

#[test]
fn compute_zero_with_rank_change() {
// Moonshine RoPE: input rank 4, output rank 5, two leading 0s
assert_eq!(
&*compute_shape_with_tf_rules(s![1, 52, 8, 32], s!(0, 0, 8, 16, 2)).unwrap(),
s![1, 52, 8, 16, 2]
)
}

#[test]
fn axis_op_rm_begin() {
assert_eq!(&*to_axis_ops_with_tf_rules(s![1, 2, 3], s!(2, 3)).unwrap(), &[Rm(0)])
Expand Down
2 changes: 1 addition & 1 deletion onnx/src/ops/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn cast(
node: &NodeProto,
) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
let mut to = node.get_attr::<DatumType>("to")?;
if to == i64::datum_type() {
if to == i64::datum_type() || to == i32::datum_type() {
to = TDim::datum_type();
}
Ok((ElementWiseOp(Box::new(Cast::new(to)), None).into_hir(), vec![]))
Expand Down
Loading