From 7fb973308d41d08e1ba72cd9fe7720ad75ecf60e Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Mon, 13 Apr 2026 19:41:55 +0200 Subject: [PATCH 1/3] Remove the use of shape_inference operator_options --- src/python/graph_builder.rs | 212 +++++++++++++----------------------- 1 file changed, 77 insertions(+), 135 deletions(-) diff --git a/src/python/graph_builder.rs b/src/python/graph_builder.rs index 5a27264..b3d8d0c 100644 --- a/src/python/graph_builder.rs +++ b/src/python/graph_builder.rs @@ -322,9 +322,7 @@ impl PyMLGraphBuilder { filter_layout: Option<&str>, bias: Option<&PyMLOperand>, ) -> PyResult { - use rustnn::shape_inference::{ - infer_conv2d_shape, Conv2dFilterLayout, Conv2dInputLayout, Conv2dOptions, - }; + use rustnn::shape_inference::infer_conv2d_shape; // Default values matching WebNN spec let strides = strides.unwrap_or_else(|| vec![1, 1]); @@ -332,10 +330,9 @@ impl PyMLGraphBuilder { let pads = pads.unwrap_or_else(|| vec![0, 0, 0, 0]); let groups = groups.unwrap_or(1); - // Parse layout strings - let input_layout_enum = match input_layout.unwrap_or("nchw") { - "nchw" => Conv2dInputLayout::Nchw, - "nhwc" => Conv2dInputLayout::Nhwc, + let input_layout_s = match input_layout.unwrap_or("nchw") { + "nchw" => "nchw", + "nhwc" => "nhwc", other => { return Err(pyo3::exceptions::PyValueError::new_err(format!( "Invalid input_layout '{}', must be 'nchw' or 'nhwc'", @@ -344,11 +341,11 @@ impl PyMLGraphBuilder { } }; - let filter_layout_enum = match filter_layout.unwrap_or("oihw") { - "oihw" => Conv2dFilterLayout::Oihw, - "hwio" => Conv2dFilterLayout::Hwio, - "ohwi" => Conv2dFilterLayout::Ohwi, - "ihwo" => Conv2dFilterLayout::Ihwo, + let filter_layout_s = match filter_layout.unwrap_or("oihw") { + "oihw" => "oihw", + "hwio" => "hwio", + "ohwi" => "ohwi", + "ihwo" => "ihwo", other => { return Err(pyo3::exceptions::PyValueError::new_err(format!( "Invalid filter_layout '{}', must be 'oihw', 'hwio', 'ohwi', or 'ihwo'", @@ -357,21 +354,22 @@ impl PyMLGraphBuilder { } }; - // Create options for shape inference - let options = Conv2dOptions { - strides: strides.clone(), - dilations: dilations.clone(), - pads: pads.clone(), + let conv2d_options = MLConv2dOptions { + label: String::new(), + padding: pads, + strides, + dilations, groups, - input_layout: input_layout_enum, - filter_layout: filter_layout_enum, + input_layout: input_layout_s.to_string(), + filter_layout: filter_layout_s.to_string(), + bias: bias.map(|b| b.id), }; // Infer output shape let output_shape = infer_conv2d_shape( &input.descriptor.static_or_max_shape(), &filter.descriptor.static_or_max_shape(), - &options, + &conv2d_options, ) .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; @@ -384,17 +382,6 @@ impl PyMLGraphBuilder { let output_id = self.next_operand_id; self.next_operand_id += 1; - let conv2d_options = MLConv2dOptions { - label: String::new(), - padding: pads, - strides, - dilations, - groups, - input_layout: input_layout.unwrap_or("nchw").to_string(), - filter_layout: filter_layout.unwrap_or("oihw").to_string(), - bias: bias.map(|b| b.id), - }; - self.push_op(Operation::Conv2d { input: input.id, filter: filter.id, @@ -431,10 +418,7 @@ impl PyMLGraphBuilder { filter_layout: Option<&str>, bias: Option<&PyMLOperand>, ) -> PyResult { - use rustnn::shape_inference::{ - infer_conv_transpose2d_shape, Conv2dFilterLayout, Conv2dInputLayout, - ConvTranspose2dOptions, - }; + use rustnn::shape_inference::infer_conv_transpose2d_shape; // Default values matching WebNN spec let strides = strides.unwrap_or_else(|| vec![1, 1]); @@ -443,10 +427,9 @@ impl PyMLGraphBuilder { let output_padding = output_padding.unwrap_or_else(|| vec![0, 0]); let groups = groups.unwrap_or(1); - // Parse layout strings - let input_layout_enum = match input_layout.unwrap_or("nchw") { - "nchw" => Conv2dInputLayout::Nchw, - "nhwc" => Conv2dInputLayout::Nhwc, + let input_layout_s = match input_layout.unwrap_or("nchw") { + "nchw" => "nchw", + "nhwc" => "nhwc", other => { return Err(pyo3::exceptions::PyValueError::new_err(format!( "Invalid input_layout '{}', must be 'nchw' or 'nhwc'", @@ -455,11 +438,11 @@ impl PyMLGraphBuilder { } }; - let filter_layout_enum = match filter_layout.unwrap_or("iohw") { - "iohw" => Conv2dFilterLayout::Oihw, // Input-Output-Height-Width (reinterpreted for transpose) - "hwoi" => Conv2dFilterLayout::Ihwo, // Height-Width-Output-Input (reinterpreted for transpose) - "ohwi" => Conv2dFilterLayout::Ohwi, // Output-Height-Width-Input - "oihw" => Conv2dFilterLayout::Hwio, // Output-Input-Height-Width (reinterpreted for transpose) + let filter_layout_s = match filter_layout.unwrap_or("iohw") { + "iohw" => "iohw", + "hwoi" => "hwoi", + "ohwi" => "ohwi", + "oihw" => "oihw", other => { return Err(pyo3::exceptions::PyValueError::new_err(format!( "Invalid filter_layout '{}', must be 'iohw', 'hwoi', 'ohwi', or 'oihw'", @@ -468,23 +451,24 @@ impl PyMLGraphBuilder { } }; - // Create options for shape inference - let options = ConvTranspose2dOptions { - strides: strides.clone(), - dilations: dilations.clone(), - pads: pads.clone(), - output_padding: output_padding.clone(), + let conv_t_options = MLConvTranspose2dOptions { + label: String::new(), + padding: pads, + strides, + dilations, + output_padding, output_sizes: output_sizes.clone(), groups, - input_layout: input_layout_enum, - filter_layout: filter_layout_enum, + input_layout: input_layout_s.to_string(), + filter_layout: filter_layout_s.to_string(), + bias: bias.map(|b| b.id), }; // Infer output shape let output_shape = infer_conv_transpose2d_shape( &input.descriptor.static_or_max_shape(), &filter.descriptor.static_or_max_shape(), - &options, + &conv_t_options, ) .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; @@ -497,19 +481,6 @@ impl PyMLGraphBuilder { let output_id = self.next_operand_id; self.next_operand_id += 1; - let conv_t_options = MLConvTranspose2dOptions { - label: String::new(), - padding: pads, - strides, - dilations, - output_padding, - output_sizes: output_sizes.clone(), - groups, - input_layout: input_layout.unwrap_or("nchw").to_string(), - filter_layout: filter_layout.unwrap_or("iohw").to_string(), - bias: bias.map(|b| b.id), - }; - self.push_op(Operation::ConvTranspose2d { input: input.id, filter: filter.id, @@ -552,7 +523,7 @@ impl PyMLGraphBuilder { pads: Option>, layout: Option<&str>, ) -> PyResult { - use rustnn::shape_inference::{infer_pool2d_shape, Conv2dInputLayout, Pool2dOptions}; + use rustnn::shape_inference::infer_pool2d_shape; // Default values matching WebNN spec let window_dimensions = window_dimensions.unwrap_or_else(|| vec![1, 1]); @@ -560,10 +531,9 @@ impl PyMLGraphBuilder { let dilations = dilations.unwrap_or_else(|| vec![1, 1]); let pads = pads.unwrap_or_else(|| vec![0, 0, 0, 0]); - // Parse layout string - let layout_enum = match layout.unwrap_or("nchw") { - "nchw" => Conv2dInputLayout::Nchw, - "nhwc" => Conv2dInputLayout::Nhwc, + let layout_s = match layout.unwrap_or("nchw") { + "nchw" => "nchw", + "nhwc" => "nhwc", other => { return Err(pyo3::exceptions::PyValueError::new_err(format!( "Invalid layout '{}', must be 'nchw' or 'nhwc'", @@ -572,17 +542,19 @@ impl PyMLGraphBuilder { } }; - // Create options for shape inference - let options = Pool2dOptions { - window_dimensions: window_dimensions.clone(), - strides: strides.clone(), - dilations: dilations.clone(), - pads: pads.clone(), - layout: layout_enum, + let pool_opts = MLPool2dOptions { + label: String::new(), + window_dimensions: Some(window_dimensions), + padding: pads, + strides, + dilations, + layout: layout_s.to_string(), + output_shape_rounding: String::new(), + output_sizes: None, }; // Infer output shape - let output_shape = infer_pool2d_shape(&input.descriptor.static_or_max_shape(), &options) + let output_shape = infer_pool2d_shape(&input.descriptor.static_or_max_shape(), &pool_opts) .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; let output_descriptor = OperandDescriptor { @@ -594,17 +566,6 @@ impl PyMLGraphBuilder { let output_id = self.next_operand_id; self.next_operand_id += 1; - let pool_opts = MLPool2dOptions { - label: String::new(), - window_dimensions: Some(window_dimensions), - padding: pads, - strides, - dilations, - layout: layout.unwrap_or("nchw").to_string(), - output_shape_rounding: String::new(), - output_sizes: None, - }; - self.push_op(Operation::AveragePool2d { input: input.id, options: Some(pool_opts), @@ -646,7 +607,7 @@ impl PyMLGraphBuilder { pads: Option>, layout: Option<&str>, ) -> PyResult { - use rustnn::shape_inference::{infer_pool2d_shape, Conv2dInputLayout, Pool2dOptions}; + use rustnn::shape_inference::infer_pool2d_shape; // Default values matching WebNN spec let window_dimensions = window_dimensions.unwrap_or_else(|| vec![1, 1]); @@ -654,10 +615,9 @@ impl PyMLGraphBuilder { let dilations = dilations.unwrap_or_else(|| vec![1, 1]); let pads = pads.unwrap_or_else(|| vec![0, 0, 0, 0]); - // Parse layout string - let layout_enum = match layout.unwrap_or("nchw") { - "nchw" => Conv2dInputLayout::Nchw, - "nhwc" => Conv2dInputLayout::Nhwc, + let layout_s = match layout.unwrap_or("nchw") { + "nchw" => "nchw", + "nhwc" => "nhwc", other => { return Err(pyo3::exceptions::PyValueError::new_err(format!( "Invalid layout '{}', must be 'nchw' or 'nhwc'", @@ -666,17 +626,19 @@ impl PyMLGraphBuilder { } }; - // Create options for shape inference - let options = Pool2dOptions { - window_dimensions: window_dimensions.clone(), - strides: strides.clone(), - dilations: dilations.clone(), - pads: pads.clone(), - layout: layout_enum, + let pool_opts = MLPool2dOptions { + label: String::new(), + window_dimensions: Some(window_dimensions), + padding: pads, + strides, + dilations, + layout: layout_s.to_string(), + output_shape_rounding: String::new(), + output_sizes: None, }; // Infer output shape - let output_shape = infer_pool2d_shape(&input.descriptor.static_or_max_shape(), &options) + let output_shape = infer_pool2d_shape(&input.descriptor.static_or_max_shape(), &pool_opts) .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; let output_descriptor = OperandDescriptor { @@ -688,17 +650,6 @@ impl PyMLGraphBuilder { let output_id = self.next_operand_id; self.next_operand_id += 1; - let pool_opts = MLPool2dOptions { - label: String::new(), - window_dimensions: Some(window_dimensions), - padding: pads, - strides, - dilations, - layout: layout.unwrap_or("nchw").to_string(), - output_shape_rounding: String::new(), - output_sizes: None, - }; - self.push_op(Operation::MaxPool2d { input: input.id, options: Some(pool_opts), @@ -732,14 +683,11 @@ impl PyMLGraphBuilder { input: &PyMLOperand, layout: Option<&str>, ) -> PyResult { - use rustnn::shape_inference::{ - infer_global_pool_shape, Conv2dInputLayout, GlobalPoolOptions, - }; + use rustnn::shape_inference::{infer_global_pool_shape, GlobalPoolOptions, InputLayout}; - // Parse layout string let layout_enum = match layout.unwrap_or("nchw") { - "nchw" => Conv2dInputLayout::Nchw, - "nhwc" => Conv2dInputLayout::Nhwc, + "nchw" => InputLayout::Nchw, + "nhwc" => InputLayout::Nhwc, other => { return Err(pyo3::exceptions::PyValueError::new_err(format!( "Invalid layout '{}', must be 'nchw' or 'nhwc'", @@ -748,14 +696,13 @@ impl PyMLGraphBuilder { } }; - // Create options for shape inference - let options = GlobalPoolOptions { + let shape_opts = GlobalPoolOptions { layout: layout_enum, }; // Infer output shape let output_shape = - infer_global_pool_shape(&input.descriptor.static_or_max_shape(), &options) + infer_global_pool_shape(&input.descriptor.static_or_max_shape(), &shape_opts) .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; let output_descriptor = OperandDescriptor { @@ -811,14 +758,11 @@ impl PyMLGraphBuilder { input: &PyMLOperand, layout: Option<&str>, ) -> PyResult { - use rustnn::shape_inference::{ - infer_global_pool_shape, Conv2dInputLayout, GlobalPoolOptions, - }; + use rustnn::shape_inference::{infer_global_pool_shape, GlobalPoolOptions, InputLayout}; - // Parse layout string let layout_enum = match layout.unwrap_or("nchw") { - "nchw" => Conv2dInputLayout::Nchw, - "nhwc" => Conv2dInputLayout::Nhwc, + "nchw" => InputLayout::Nchw, + "nhwc" => InputLayout::Nhwc, other => { return Err(pyo3::exceptions::PyValueError::new_err(format!( "Invalid layout '{}', must be 'nchw' or 'nhwc'", @@ -827,14 +771,13 @@ impl PyMLGraphBuilder { } }; - // Create options for shape inference - let options = GlobalPoolOptions { + let shape_opts = GlobalPoolOptions { layout: layout_enum, }; // Infer output shape let output_shape = - infer_global_pool_shape(&input.descriptor.static_or_max_shape(), &options) + infer_global_pool_shape(&input.descriptor.static_or_max_shape(), &shape_opts) .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; let output_descriptor = OperandDescriptor { @@ -3670,14 +3613,13 @@ impl PyMLGraphBuilder { ) -> PyResult { use rustnn::shape_inference::{infer_reduce_shape, ReduceOptions}; - // Create reduction options - let options = ReduceOptions { + let infer_opts = ReduceOptions { axes: axes.clone().unwrap_or_default(), keep_dimensions, }; // Infer output shape - let output_shape = infer_reduce_shape(&input.descriptor.static_or_max_shape(), &options) + let output_shape = infer_reduce_shape(&input.descriptor.static_or_max_shape(), &infer_opts) .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; let output_descriptor = OperandDescriptor { From 8add9353b07cc829ecf2477e15baf23aa4372d15 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Mon, 13 Apr 2026 20:14:51 +0200 Subject: [PATCH 2/3] remove GlobalOptions --- src/python/graph_builder.rs | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/src/python/graph_builder.rs b/src/python/graph_builder.rs index b3d8d0c..4cc4678 100644 --- a/src/python/graph_builder.rs +++ b/src/python/graph_builder.rs @@ -683,7 +683,7 @@ impl PyMLGraphBuilder { input: &PyMLOperand, layout: Option<&str>, ) -> PyResult { - use rustnn::shape_inference::{infer_global_pool_shape, GlobalPoolOptions, InputLayout}; + use rustnn::shape_inference::{infer_global_pool_shape, InputLayout}; let layout_enum = match layout.unwrap_or("nchw") { "nchw" => InputLayout::Nchw, @@ -696,14 +696,11 @@ impl PyMLGraphBuilder { } }; - let shape_opts = GlobalPoolOptions { - layout: layout_enum, - }; - - // Infer output shape - let output_shape = - infer_global_pool_shape(&input.descriptor.static_or_max_shape(), &shape_opts) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; + let output_shape = infer_global_pool_shape( + &input.descriptor.static_or_max_shape(), + layout_enum, + ) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; let output_descriptor = OperandDescriptor { data_type: input.descriptor.data_type, @@ -758,7 +755,7 @@ impl PyMLGraphBuilder { input: &PyMLOperand, layout: Option<&str>, ) -> PyResult { - use rustnn::shape_inference::{infer_global_pool_shape, GlobalPoolOptions, InputLayout}; + use rustnn::shape_inference::{infer_global_pool_shape, InputLayout}; let layout_enum = match layout.unwrap_or("nchw") { "nchw" => InputLayout::Nchw, @@ -771,14 +768,11 @@ impl PyMLGraphBuilder { } }; - let shape_opts = GlobalPoolOptions { - layout: layout_enum, - }; - - // Infer output shape - let output_shape = - infer_global_pool_shape(&input.descriptor.static_or_max_shape(), &shape_opts) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; + let output_shape = infer_global_pool_shape( + &input.descriptor.static_or_max_shape(), + layout_enum, + ) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; let output_descriptor = OperandDescriptor { data_type: input.descriptor.data_type, From bc0eccacfed7c6a0975af763c60e0a803db0c9d0 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Mon, 13 Apr 2026 20:16:44 +0200 Subject: [PATCH 3/3] cargo fmt --- src/python/graph_builder.rs | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/python/graph_builder.rs b/src/python/graph_builder.rs index 4cc4678..5cb348f 100644 --- a/src/python/graph_builder.rs +++ b/src/python/graph_builder.rs @@ -696,11 +696,9 @@ impl PyMLGraphBuilder { } }; - let output_shape = infer_global_pool_shape( - &input.descriptor.static_or_max_shape(), - layout_enum, - ) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; + let output_shape = + infer_global_pool_shape(&input.descriptor.static_or_max_shape(), layout_enum) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; let output_descriptor = OperandDescriptor { data_type: input.descriptor.data_type, @@ -768,11 +766,9 @@ impl PyMLGraphBuilder { } }; - let output_shape = infer_global_pool_shape( - &input.descriptor.static_or_max_shape(), - layout_enum, - ) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; + let output_shape = + infer_global_pool_shape(&input.descriptor.static_or_max_shape(), layout_enum) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; let output_descriptor = OperandDescriptor { data_type: input.descriptor.data_type,