From bc46e8d3745f85cc8047fd8e7b484f8e5ed54d3c Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Tue, 5 May 2026 16:40:39 +0900 Subject: [PATCH] feat: add fallible AD rule hooks --- README.md | 7 +- src/ad_rule_error.rs | 122 ++++++++++++++++++++++++++++ src/lib.rs | 7 +- src/primitive_op.rs | 46 ++++++++++- tests/fallible_rule_tests.rs | 149 +++++++++++++++++++++++++++++++++++ 5 files changed, 324 insertions(+), 7 deletions(-) create mode 100644 src/ad_rule_error.rs create mode 100644 tests/fallible_rule_tests.rs diff --git a/README.md b/README.md index e47cd42..62a4bd7 100644 --- a/README.md +++ b/README.md @@ -5,8 +5,11 @@ AD trait definitions for the tensor4all v2 stack. This crate defines: - **`PrimitiveOp`** — extends `computegraph::GraphOp` with `add()` - (cotangent accumulation constructor), `linearize` (JVP rule), and - `transpose_rule` (reverse-mode rule) + (cotangent accumulation constructor), `linearize` / `try_linearize` + (JVP rule), and `transpose_rule` / `try_transpose_rule` + (reverse-mode rule) +- **`ADRuleError`** — reports missing or unsupported AD rules without forcing + downstream AD transforms to panic - **`ADKey`** — trait on `GraphOp::InputKey` for generating tangent input keys during `differentiate` diff --git a/src/ad_rule_error.rs b/src/ad_rule_error.rs new file mode 100644 index 0000000..7288176 --- /dev/null +++ b/src/ad_rule_error.rs @@ -0,0 +1,122 @@ +use std::error::Error; +use std::fmt; + +/// Identifies which AD rule failed or is unavailable. +/// +/// # Examples +/// +/// ``` +/// use chainrules::ADRuleKind; +/// +/// assert_eq!(ADRuleKind::Linearize.as_str(), "linearize"); +/// assert_eq!(ADRuleKind::Transpose.as_str(), "transpose"); +/// ``` +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum ADRuleKind { + /// Forward linearization / JVP rule. + Linearize, + /// Transpose / VJP rule for a linear primitive. + Transpose, +} + +impl ADRuleKind { + /// Returns a stable human-readable rule name. 
+ /// + /// # Examples + /// + /// ``` + /// use chainrules::ADRuleKind; + /// + /// assert_eq!(ADRuleKind::Linearize.as_str(), "linearize"); + /// ``` + pub const fn as_str(self) -> &'static str { + match self { + Self::Linearize => "linearize", + Self::Transpose => "transpose", + } + } +} + +/// Error returned when an AD rule cannot be emitted. +/// +/// # Examples +/// +/// ``` +/// use chainrules::{ADRuleError, ADRuleKind}; +/// +/// let err = ADRuleError::unsupported("my_crate::fft", ADRuleKind::Linearize); +/// assert_eq!(err.rule(), ADRuleKind::Linearize); +/// assert!(err.to_string().contains("my_crate::fft")); +/// ``` +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ADRuleError { + /// The requested primitive does not provide the requested AD rule. + Unsupported { + /// Stable primitive name or extension family identifier. + op: String, + /// Missing rule kind. + rule: ADRuleKind, + }, +} + +impl ADRuleError { + /// Constructs an unsupported-rule error. + /// + /// # Examples + /// + /// ``` + /// use chainrules::{ADRuleError, ADRuleKind}; + /// + /// let err = ADRuleError::unsupported("custom::op", ADRuleKind::Transpose); + /// assert_eq!(err.rule(), ADRuleKind::Transpose); + /// ``` + pub fn unsupported(op: impl Into<String>, rule: ADRuleKind) -> Self { + Self::Unsupported { + op: op.into(), + rule, + } + } + + /// Returns the AD rule kind associated with this error. + /// + /// # Examples + /// + /// ``` + /// use chainrules::{ADRuleError, ADRuleKind}; + /// + /// let err = ADRuleError::unsupported("custom::op", ADRuleKind::Linearize); + /// assert_eq!(err.rule(), ADRuleKind::Linearize); + /// ``` + pub const fn rule(&self) -> ADRuleKind { + match self { + Self::Unsupported { rule, .. 
} => *rule, + } + } +} + +impl fmt::Display for ADRuleError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Unsupported { op, rule } => { + write!(f, "unsupported {} AD rule for {op}", rule.as_str()) + } + } + } +} + +impl Error for ADRuleError {} + +/// Result type used by fallible AD rule emission. +/// +/// # Examples +/// +/// ``` +/// use chainrules::{ADRuleError, ADRuleKind, ADRuleResult}; +/// +/// fn missing_rule() -> ADRuleResult<()> { +/// Err(ADRuleError::unsupported("custom::op", ADRuleKind::Transpose)) +/// } +/// +/// assert!(missing_rule().is_err()); +/// ``` +pub type ADRuleResult<T> = Result<T, ADRuleError>; diff --git a/src/lib.rs b/src/lib.rs index fedcc44..203e1b9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,14 @@ //! AD trait definitions for the tensor4all v2 stack. //! //! This crate defines [`PrimitiveOp`] (extends [`computegraph::GraphOp`] with -//! linearization and transpose rules) and [`ADKey`] (tangent input key -//! generation). It contains no concrete primitives and no graph infrastructure. +//! linearization and transpose rules), [`ADKey`] (tangent input key generation), +//! and [`ADRuleError`] for fallible rule emission. It contains no concrete +//! primitives and no graph infrastructure. pub mod ad_key; +pub mod ad_rule_error; pub mod primitive_op; pub use ad_key::{ADKey, DiffPassId}; +pub use ad_rule_error::{ADRuleError, ADRuleKind, ADRuleResult}; pub use primitive_op::PrimitiveOp; diff --git a/src/primitive_op.rs b/src/primitive_op.rs index deedf47..551cc6c 100644 --- a/src/primitive_op.rs +++ b/src/primitive_op.rs @@ -2,12 +2,12 @@ use computegraph::fragment::FragmentBuilder; use computegraph::types::{GlobalValKey, LocalValId, OpMode, ValRef}; use computegraph::{GraphOp, OpEmitter}; -use crate::ADKey; +use crate::{ADKey, ADRuleResult}; /// Extends `GraphOp` with linearization and transpose rules for AD. 
/// -/// - `linearize` is called by `tidu::differentiate` -/// - `transpose_rule` is called by `tidu::transpose` +/// - `try_linearize` is called by `tidu::differentiate` +/// - `try_transpose_rule` is called by `tidu::transpose` /// /// Both methods emit new ops into a `FragmentBuilder`. The downstream /// implementor (e.g. tenferro-rs) is responsible for ensuring closure: @@ -92,6 +92,26 @@ where where Self: Sized; + /// Fallible variant of [`PrimitiveOp::linearize`]. + /// + /// Implementors that can encounter missing extension rules should override + /// this method and return [`crate::ADRuleError`] instead of panicking. The + /// default implementation preserves the infallible contract for existing + /// primitive sets. + fn try_linearize( + &self, + builder: &mut FragmentBuilder, + primal_in: &[GlobalValKey], + primal_out: &[GlobalValKey], + tangent_in: &[Option], + ctx: &mut Self::ADContext, + ) -> ADRuleResult>> + where + Self: Sized, + { + Ok(self.linearize(builder, primal_in, primal_out, tangent_in, ctx)) + } + /// Emit the transpose rule for this linear primitive. /// /// Receives cotangent outputs and produces cotangent inputs. @@ -109,4 +129,24 @@ where ) -> Vec> where Self: Sized; + + /// Fallible variant of [`PrimitiveOp::transpose_rule`]. + /// + /// Implementors that can encounter missing extension rules should override + /// this method and return [`crate::ADRuleError`] instead of panicking. The + /// default implementation preserves the infallible contract for existing + /// primitive sets. 
+ fn try_transpose_rule( + &self, + emitter: &mut impl OpEmitter, + cotangent_out: &[Option], + inputs: &[ValRef], + mode: &OpMode, + ctx: &mut Self::ADContext, + ) -> ADRuleResult>> + where + Self: Sized, + { + Ok(self.transpose_rule(emitter, cotangent_out, inputs, mode, ctx)) + } } diff --git a/tests/fallible_rule_tests.rs b/tests/fallible_rule_tests.rs new file mode 100644 index 0000000..e11de72 --- /dev/null +++ b/tests/fallible_rule_tests.rs @@ -0,0 +1,149 @@ +use chainrules::{ADKey, ADRuleError, ADRuleKind, ADRuleResult, DiffPassId, PrimitiveOp}; +use computegraph::fragment::FragmentBuilder; +use computegraph::types::{GlobalValKey, OpMode, ValRef}; +use computegraph::{GraphOp, LocalValId, OpEmitter}; + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +enum Key { + Base(&'static str), + Tangent(Box, DiffPassId), +} + +impl ADKey for Key { + fn tangent_of(&self, pass: DiffPassId) -> Self { + Key::Tangent(Box::new(self.clone()), pass) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +enum Op { + Add, + Missing, +} + +impl GraphOp for Op { + type Operand = f64; + type Context = (); + type InputKey = Key; + + fn n_inputs(&self) -> usize { + match self { + Op::Add => 2, + Op::Missing => 1, + } + } + + fn n_outputs(&self) -> usize { + 1 + } +} + +impl PrimitiveOp for Op { + type ADContext = (); + + fn add() -> Self { + Op::Add + } + + fn linearize( + &self, + _builder: &mut FragmentBuilder, + _primal_in: &[GlobalValKey], + _primal_out: &[GlobalValKey], + tangent_in: &[Option], + _ctx: &mut (), + ) -> Vec> { + match self { + Op::Add => vec![tangent_in[0].or(tangent_in[1])], + Op::Missing => panic!("missing AD rule should be reported through try_linearize"), + } + } + + fn try_linearize( + &self, + _builder: &mut FragmentBuilder, + _primal_in: &[GlobalValKey], + _primal_out: &[GlobalValKey], + tangent_in: &[Option], + _ctx: &mut (), + ) -> ADRuleResult>> { + match self { + Op::Add => Ok(vec![tangent_in[0].or(tangent_in[1])]), + Op::Missing => 
Err(ADRuleError::unsupported( + "Op::Missing", + ADRuleKind::Linearize, + )), + } + } + + fn transpose_rule( + &self, + _emitter: &mut impl OpEmitter, + cotangent_out: &[Option], + _inputs: &[ValRef], + _mode: &OpMode, + _ctx: &mut (), + ) -> Vec> { + match self { + Op::Add => vec![cotangent_out[0], cotangent_out[0]], + Op::Missing => panic!("missing AD rule should be reported through try_transpose_rule"), + } + } + + fn try_transpose_rule( + &self, + _emitter: &mut impl OpEmitter, + cotangent_out: &[Option], + _inputs: &[ValRef], + _mode: &OpMode, + _ctx: &mut (), + ) -> ADRuleResult>> { + match self { + Op::Add => Ok(vec![cotangent_out[0], cotangent_out[0]]), + Op::Missing => Err(ADRuleError::unsupported( + "Op::Missing", + ADRuleKind::Transpose, + )), + } + } +} + +#[test] +fn primitive_op_can_report_missing_linearize_rule() { + let mut builder = FragmentBuilder::::new(); + let mut ctx = (); + let dx = builder.add_input(Key::Base("dx")); + let err = Op::Missing + .try_linearize( + &mut builder, + &[GlobalValKey::Input(Key::Base("x"))], + &[GlobalValKey::Input(Key::Base("y"))], + &[Some(dx)], + &mut ctx, + ) + .unwrap_err(); + + assert_eq!(err.rule(), ADRuleKind::Linearize); + assert!(err.to_string().contains("Op::Missing")); +} + +#[test] +fn primitive_op_can_report_missing_transpose_rule() { + let mut builder = FragmentBuilder::::new(); + let mut ctx = (); + let ct = builder.add_input(Key::Base("ct")); + let err = Op::Missing + .try_transpose_rule( + &mut builder, + &[Some(ct)], + &[ValRef::External(GlobalValKey::Input(Key::Base("x")))], + &OpMode::Linear { + active_mask: vec![true], + }, + &mut ctx, + ) + .unwrap_err(); + + assert_eq!(err.rule(), ADRuleKind::Transpose); + assert!(err.to_string().contains("Op::Missing")); +}