Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ AD trait definitions for the tensor4all v2 stack.
This crate defines:

- **`PrimitiveOp`** — extends `computegraph::GraphOp` with `add()`
(cotangent accumulation constructor), `linearize` (JVP rule), and
`transpose_rule` (reverse-mode rule)
(cotangent accumulation constructor), `linearize` / `try_linearize`
(JVP rule), and `transpose_rule` / `try_transpose_rule`
(reverse-mode rule)
- **`ADRuleError`** — reports missing or unsupported AD rules without forcing
downstream AD transforms to panic
- **`ADKey`** — trait on `GraphOp::InputKey` for generating tangent input
keys during `differentiate`

Expand Down
122 changes: 122 additions & 0 deletions src/ad_rule_error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
use std::error::Error;
use std::fmt;

/// Identifies which AD rule failed or is unavailable.
///
/// # Examples
///
/// ```
/// use chainrules::ADRuleKind;
///
/// assert_eq!(ADRuleKind::Linearize.as_str(), "linearize");
/// assert_eq!(ADRuleKind::Transpose.as_str(), "transpose");
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ADRuleKind {
/// Forward linearization / JVP rule.
Linearize,
/// Transpose / VJP rule for a linear primitive.
Transpose,
}

impl ADRuleKind {
/// Returns a stable human-readable rule name.
///
/// # Examples
///
/// ```
/// use chainrules::ADRuleKind;
///
/// assert_eq!(ADRuleKind::Linearize.as_str(), "linearize");
/// ```
pub const fn as_str(self) -> &'static str {
match self {
Self::Linearize => "linearize",
Self::Transpose => "transpose",
}
}
}

/// Error returned when an AD rule cannot be emitted.
///
/// # Examples
///
/// ```
/// use chainrules::{ADRuleError, ADRuleKind};
///
/// let err = ADRuleError::unsupported("my_crate::fft", ADRuleKind::Linearize);
/// assert_eq!(err.rule(), ADRuleKind::Linearize);
/// assert!(err.to_string().contains("my_crate::fft"));
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ADRuleError {
/// The requested primitive does not provide the requested AD rule.
Unsupported {
/// Stable primitive name or extension family identifier.
op: String,
/// Missing rule kind.
rule: ADRuleKind,
},
}

impl ADRuleError {
/// Constructs an unsupported-rule error.
///
/// # Examples
///
/// ```
/// use chainrules::{ADRuleError, ADRuleKind};
///
/// let err = ADRuleError::unsupported("custom::op", ADRuleKind::Transpose);
/// assert_eq!(err.rule(), ADRuleKind::Transpose);
/// ```
pub fn unsupported(op: impl Into<String>, rule: ADRuleKind) -> Self {
Self::Unsupported {
op: op.into(),
rule,
}
}

/// Returns the AD rule kind associated with this error.
///
/// # Examples
///
/// ```
/// use chainrules::{ADRuleError, ADRuleKind};
///
/// let err = ADRuleError::unsupported("custom::op", ADRuleKind::Linearize);
/// assert_eq!(err.rule(), ADRuleKind::Linearize);
/// ```
pub const fn rule(&self) -> ADRuleKind {
match self {
Self::Unsupported { rule, .. } => *rule,
}
}
}

impl fmt::Display for ADRuleError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Unsupported { op, rule } => {
write!(f, "unsupported {} AD rule for {op}", rule.as_str())
}
}
}
}

impl Error for ADRuleError {}

/// Result type used by fallible AD rule emission.
///
/// # Examples
///
/// ```
/// use chainrules::{ADRuleError, ADRuleKind, ADRuleResult};
///
/// fn missing_rule() -> ADRuleResult<()> {
/// Err(ADRuleError::unsupported("custom::op", ADRuleKind::Transpose))
/// }
///
/// assert!(missing_rule().is_err());
/// ```
pub type ADRuleResult<T> = Result<T, ADRuleError>;
7 changes: 5 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
//! AD trait definitions for the tensor4all v2 stack.
//!
//! This crate defines [`PrimitiveOp`] (extends [`computegraph::GraphOp`] with
//! linearization and transpose rules) and [`ADKey`] (tangent input key
//! generation). It contains no concrete primitives and no graph infrastructure.
//! linearization and transpose rules), [`ADKey`] (tangent input key generation),
//! and [`ADRuleError`] for fallible rule emission. It contains no concrete
//! primitives and no graph infrastructure.

pub mod ad_key;
pub mod ad_rule_error;
pub mod primitive_op;

pub use ad_key::{ADKey, DiffPassId};
pub use ad_rule_error::{ADRuleError, ADRuleKind, ADRuleResult};
pub use primitive_op::PrimitiveOp;
46 changes: 43 additions & 3 deletions src/primitive_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ use computegraph::fragment::FragmentBuilder;
use computegraph::types::{GlobalValKey, LocalValId, OpMode, ValRef};
use computegraph::{GraphOp, OpEmitter};

use crate::ADKey;
use crate::{ADKey, ADRuleResult};

/// Extends `GraphOp` with linearization and transpose rules for AD.
///
/// - `linearize` is called by `tidu::differentiate`
/// - `transpose_rule` is called by `tidu::transpose`
/// - `try_linearize` is called by `tidu::differentiate`
/// - `try_transpose_rule` is called by `tidu::transpose`
///
/// Both methods emit new ops into a `FragmentBuilder`. The downstream
/// implementor (e.g. tenferro-rs) is responsible for ensuring closure:
Expand Down Expand Up @@ -92,6 +92,26 @@ where
where
Self: Sized;

/// Fallible variant of [`PrimitiveOp::linearize`].
///
/// Implementors that can encounter missing extension rules should override
/// this method and return [`crate::ADRuleError`] instead of panicking. The
/// default implementation preserves the infallible contract for existing
/// primitive sets.
fn try_linearize(
&self,
builder: &mut FragmentBuilder<Self>,
primal_in: &[GlobalValKey<Self>],
primal_out: &[GlobalValKey<Self>],
tangent_in: &[Option<LocalValId>],
ctx: &mut Self::ADContext,
) -> ADRuleResult<Vec<Option<LocalValId>>>
where
Self: Sized,
{
Ok(self.linearize(builder, primal_in, primal_out, tangent_in, ctx))
}

/// Emit the transpose rule for this linear primitive.
///
/// Receives cotangent outputs and produces cotangent inputs.
Expand All @@ -109,4 +129,24 @@ where
) -> Vec<Option<LocalValId>>
where
Self: Sized;

/// Fallible variant of [`PrimitiveOp::transpose_rule`].
///
/// Implementors that can encounter missing extension rules should override
/// this method and return [`crate::ADRuleError`] instead of panicking. The
/// default implementation preserves the infallible contract for existing
/// primitive sets.
fn try_transpose_rule(
&self,
emitter: &mut impl OpEmitter<Self>,
cotangent_out: &[Option<LocalValId>],
inputs: &[ValRef<Self>],
mode: &OpMode,
ctx: &mut Self::ADContext,
) -> ADRuleResult<Vec<Option<LocalValId>>>
where
Self: Sized,
{
Ok(self.transpose_rule(emitter, cotangent_out, inputs, mode, ctx))
}
}
149 changes: 149 additions & 0 deletions tests/fallible_rule_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
use chainrules::{ADKey, ADRuleError, ADRuleKind, ADRuleResult, DiffPassId, PrimitiveOp};
use computegraph::fragment::FragmentBuilder;
use computegraph::types::{GlobalValKey, OpMode, ValRef};
use computegraph::{GraphOp, LocalValId, OpEmitter};

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum Key {
Base(&'static str),
Tangent(Box<Key>, DiffPassId),
}

impl ADKey for Key {
fn tangent_of(&self, pass: DiffPassId) -> Self {
Key::Tangent(Box::new(self.clone()), pass)
}
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum Op {
Add,
Missing,
}

impl GraphOp for Op {
type Operand = f64;
type Context = ();
type InputKey = Key;

fn n_inputs(&self) -> usize {
match self {
Op::Add => 2,
Op::Missing => 1,
}
}

fn n_outputs(&self) -> usize {
1
}
}

impl PrimitiveOp for Op {
type ADContext = ();

fn add() -> Self {
Op::Add
}

fn linearize(
&self,
_builder: &mut FragmentBuilder<Self>,
_primal_in: &[GlobalValKey<Self>],
_primal_out: &[GlobalValKey<Self>],
tangent_in: &[Option<LocalValId>],
_ctx: &mut (),
) -> Vec<Option<LocalValId>> {
match self {
Op::Add => vec![tangent_in[0].or(tangent_in[1])],
Op::Missing => panic!("missing AD rule should be reported through try_linearize"),
}
}

fn try_linearize(
&self,
_builder: &mut FragmentBuilder<Self>,
_primal_in: &[GlobalValKey<Self>],
_primal_out: &[GlobalValKey<Self>],
tangent_in: &[Option<LocalValId>],
_ctx: &mut (),
) -> ADRuleResult<Vec<Option<LocalValId>>> {
match self {
Op::Add => Ok(vec![tangent_in[0].or(tangent_in[1])]),
Op::Missing => Err(ADRuleError::unsupported(
"Op::Missing",
ADRuleKind::Linearize,
)),
}
}

fn transpose_rule(
&self,
_emitter: &mut impl OpEmitter<Self>,
cotangent_out: &[Option<LocalValId>],
_inputs: &[ValRef<Self>],
_mode: &OpMode,
_ctx: &mut (),
) -> Vec<Option<LocalValId>> {
match self {
Op::Add => vec![cotangent_out[0], cotangent_out[0]],
Op::Missing => panic!("missing AD rule should be reported through try_transpose_rule"),
}
}

fn try_transpose_rule(
&self,
_emitter: &mut impl OpEmitter<Self>,
cotangent_out: &[Option<LocalValId>],
_inputs: &[ValRef<Self>],
_mode: &OpMode,
_ctx: &mut (),
) -> ADRuleResult<Vec<Option<LocalValId>>> {
match self {
Op::Add => Ok(vec![cotangent_out[0], cotangent_out[0]]),
Op::Missing => Err(ADRuleError::unsupported(
"Op::Missing",
ADRuleKind::Transpose,
)),
}
}
}

#[test]
fn primitive_op_can_report_missing_linearize_rule() {
let mut builder = FragmentBuilder::<Op>::new();
let mut ctx = ();
let dx = builder.add_input(Key::Base("dx"));
let err = Op::Missing
.try_linearize(
&mut builder,
&[GlobalValKey::Input(Key::Base("x"))],
&[GlobalValKey::Input(Key::Base("y"))],
&[Some(dx)],
&mut ctx,
)
.unwrap_err();

assert_eq!(err.rule(), ADRuleKind::Linearize);
assert!(err.to_string().contains("Op::Missing"));
}

#[test]
fn primitive_op_can_report_missing_transpose_rule() {
let mut builder = FragmentBuilder::<Op>::new();
let mut ctx = ();
let ct = builder.add_input(Key::Base("ct"));
let err = Op::Missing
.try_transpose_rule(
&mut builder,
&[Some(ct)],
&[ValRef::External(GlobalValKey::Input(Key::Base("x")))],
&OpMode::Linear {
active_mask: vec![true],
},
&mut ctx,
)
.unwrap_err();

assert_eq!(err.rule(), ADRuleKind::Transpose);
assert!(err.to_string().contains("Op::Missing"));
}
Loading