AD graph transforms for the tensor4all v2 stack.
Provides:
- differentiate / try_differentiate: JVP transform (resolved view → linear fragment)
- transpose / try_transpose: reverse linear flow (linear fragment → linear fragment)
- eager reverse-mode AD helpers around GradNode / backward_dag
Use the try_* APIs when downstream primitive sets can report unsupported
extension AD rules through chainrules::ADRuleError.
The crate is fully generic over Op: PrimitiveOp and references no specific primitives.
tidu provides a small generic recording layer for PyTorch-style eager
reverse-mode AD. Downstream frontends execute concrete primal operations, then
call record_eager_op() with input trace metadata and concrete output values.
tidu allocates stable input aliases and output keys, saves replay data, builds
GradEdges, and shares one GradNode across multi-output operations.
Downstream crates still own backend execution, tensor metadata, gradient slots,
and BackwardCallbacks for concrete forward replay, eager transpose execution,
and cotangent addition.
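To make the recording flow concrete, here is a toy sketch. It is not tidu's API: ToyNode, ToyEdge, Traced, record_toy_op, sin_cos, and sin_cos_backward are hypothetical names invented for illustration, and the real record_eager_op() signature, alias and key allocation, and BackwardCallbacks are not modeled. The sketch only shows the general idea of executing the primal first, recording afterwards, and letting every output of a multi-output operation point at one shared node.

```rust
use std::rc::Rc;

/// Toy stand-in for a recorded operation (hypothetical, not tidu's GradNode).
#[derive(Debug)]
struct ToyNode {
    op_name: &'static str,
    saved_inputs: Vec<f64>, // replay data captured at record time
    num_outputs: usize,
}

/// Toy stand-in for an edge from an output value back to its producer.
#[derive(Debug, Clone)]
struct ToyEdge {
    node: Rc<ToyNode>,
    output_index: usize,
}

/// A concrete value plus optional trace metadata.
#[derive(Debug, Clone)]
struct Traced {
    value: f64,
    edge: Option<ToyEdge>,
}

/// Record an already-executed op: every output gets an edge to the same node.
fn record_toy_op(op_name: &'static str, inputs: &[Traced], outputs: &[f64]) -> Vec<Traced> {
    let node = Rc::new(ToyNode {
        op_name,
        saved_inputs: inputs.iter().map(|t| t.value).collect(),
        num_outputs: outputs.len(),
    });
    outputs
        .iter()
        .enumerate()
        .map(|(output_index, &value)| Traced {
            value,
            edge: Some(ToyEdge { node: Rc::clone(&node), output_index }),
        })
        .collect()
}

/// Frontend side: run the concrete two-output primal, then record it.
fn sin_cos(x: &Traced) -> (Traced, Traced) {
    let (s, c) = x.value.sin_cos();
    let mut outs = record_toy_op("sin_cos", std::slice::from_ref(x), &[s, c]);
    let c_out = outs.pop().unwrap();
    let s_out = outs.pop().unwrap();
    (s_out, c_out)
}

/// Backward rule for the shared node: one cotangent slot per output,
/// combined into a single cotangent for the input x.
fn sin_cos_backward(node: &ToyNode, cts: &[f64]) -> f64 {
    let x = node.saved_inputs[0];
    cts[0] * x.cos() - cts[1] * x.sin()
}
```

Because both outputs hold an Rc to the same ToyNode, a backward pass can collect the cotangents of the sine and cosine outputs first and then invoke the node's rule once, rather than once per output.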
Forward and reverse modes follow the JAX convention:
- differentiate (JVP) computes the full R-linear derivative: df = (∂f/∂z)·dz + (∂f/∂z̄)·conj(dz)
- transpose (VJP) computes the adjoint w.r.t. the real inner product ⟨a, b⟩ = Re(conj(a)·b).
For a general function f: C → C, the VJP cotangent relates to Wirtinger derivatives as:
ct_z = ct_y · conj(∂f/∂z) + conj(ct_y) · (∂f/∂z̄)
Special cases:
| Case | Result |
|---|---|
| Real loss (L: C→R), ct_y=1 | ct_z = 2·(∂L/∂z̄) |
| Holomorphic f, ∂f/∂z̄=0 | ct_z = ct_y · conj(f'(z)) |
| conj(z), ∂f/∂z=0 | ct_z = conj(ct_y) |
For real-valued losses, this differs from PyTorch (which returns ∂L/∂z̄
directly) by a factor of 2. The steepest-descent direction is the same.
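The complex convention can be checked numerically. The following is a minimal sketch, not tidu code: C64, jvp, vjp, and inner are hypothetical helpers written for illustration, with a hand-rolled complex type so that no external crate is needed. It encodes the JVP and VJP formulas directly, which makes it easy to confirm the adjoint identity ⟨ct_y, jvp(dz)⟩ = ⟨vjp(ct_y), dz⟩ and the special cases in the table.

```rust
/// Minimal complex number type (illustrative only).
#[derive(Clone, Copy, Debug, PartialEq)]
struct C64 {
    re: f64,
    im: f64,
}

impl C64 {
    fn new(re: f64, im: f64) -> Self {
        C64 { re, im }
    }
    fn conj(self) -> Self {
        C64::new(self.re, -self.im)
    }
    fn add(self, o: C64) -> C64 {
        C64::new(self.re + o.re, self.im + o.im)
    }
    fn mul(self, o: C64) -> C64 {
        C64::new(
            self.re * o.re - self.im * o.im,
            self.re * o.im + self.im * o.re,
        )
    }
}

/// JVP: df = (∂f/∂z)·dz + (∂f/∂z̄)·conj(dz)
fn jvp(df_dz: C64, df_dzbar: C64, dz: C64) -> C64 {
    df_dz.mul(dz).add(df_dzbar.mul(dz.conj()))
}

/// VJP: ct_z = ct_y·conj(∂f/∂z) + conj(ct_y)·(∂f/∂z̄)
fn vjp(df_dz: C64, df_dzbar: C64, ct_y: C64) -> C64 {
    ct_y.mul(df_dz.conj()).add(ct_y.conj().mul(df_dzbar))
}

/// Real inner product ⟨a, b⟩ = Re(conj(a)·b)
fn inner(a: C64, b: C64) -> f64 {
    a.conj().mul(b).re
}
```

For example, take L = |z|² at z = 1 + 2i, so that ∂L/∂z = conj(z) and ∂L/∂z̄ = z. Calling vjp(z.conj(), z, C64::new(1.0, 0.0)) yields 2 + 4i, which is 2·(∂L/∂z̄) as in the first table row.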
Higher-order derivatives are computed by repeated application of
differentiate (and optionally transpose). Each differentiate call
consumes one tangent vector, so the output shape stays the same as the
original function output regardless of derivative order.
For f: R^n → R^m:
| Order | Computation | New tangent | Output shape |
|---|---|---|---|
| 0 | f(x) | — | R^m |
| 1st (F) | J · dx₁ | dx₁ ∈ R^n | R^m |
| 2nd (FoF) | (∂J/∂x · dx₂) · dx₁ | dx₂ ∈ R^n | R^m |
| 3rd (FoFoF) | ∂³f/∂x³ · dx₁ · dx₂ · dx₃ | dx₃ ∈ R^n | R^m |
Each JVP contracts one tangent vector with the derivative tensor. The k-th forward derivative is a rank-(k+1) tensor contracted with k tangent vectors, producing a result in R^m.
To extract individual tensor components, use unit basis vectors as tangents:
∂³f/∂xᵢ∂xⱼ∂xₖ = FoFoF(eᵢ, eⱼ, eₖ).
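The "one tangent per pass" idea can be illustrated outside the graph machinery with nested dual numbers. This is a sketch under stated assumptions, not how tidu implements differentiate: Dual, f, jvp_f, and fof_f are hypothetical names, and the example function f(x, y) = x²y + y³ (so n = 2, m = 1) is chosen only because its derivatives are easy to verify by hand. Each nesting level carries one tangent vector, and the result keeps the shape of the original output at every order.

```rust
use std::ops::{Add, Mul};

/// Dual number: `val` is the primal, `tan` the tangent carried by one pass.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual<T> {
    val: T,
    tan: T,
}

impl<T: Add<Output = T>> Add for Dual<T> {
    type Output = Dual<T>;
    fn add(self, o: Self) -> Self {
        Dual { val: self.val + o.val, tan: self.tan + o.tan }
    }
}

impl<T: Copy + Add<Output = T> + Mul<Output = T>> Mul for Dual<T> {
    type Output = Dual<T>;
    fn mul(self, o: Self) -> Self {
        // Product rule applied to the tangent component.
        Dual { val: self.val * o.val, tan: self.val * o.tan + self.tan * o.val }
    }
}

/// Example function f(x, y) = x·x·y + y·y·y, generic so it runs on nested duals.
fn f<T: Copy + Add<Output = T> + Mul<Output = T>>(x: T, y: T) -> T {
    x * x * y + y * y * y
}

/// First order (F): J·dx1 using one level of dual numbers.
fn jvp_f(p: [f64; 2], d1: [f64; 2]) -> f64 {
    f(
        Dual { val: p[0], tan: d1[0] },
        Dual { val: p[1], tan: d1[1] },
    )
    .tan
}

/// Second order (FoF): two nested levels, one tangent vector per level.
fn fof_f(p: [f64; 2], d1: [f64; 2], d2: [f64; 2]) -> f64 {
    let lift = |i: usize| Dual {
        val: Dual { val: p[i], tan: d1[i] }, // inner level carries dx1
        tan: Dual { val: d2[i], tan: 0.0 },  // outer level carries dx2
    };
    // The mixed coefficient is the Hessian contracted with dx1 and dx2.
    f(lift(0), lift(1)).tan.tan
}
```

With unit basis vectors as tangents, fof_f returns individual Hessian entries. At the point (2, 3) the Hessian of this f is [[2y, 2x], [2x, 6y]] = [[6, 4], [4, 18]], so fof_f([2.0, 3.0], [1.0, 0.0], [0.0, 1.0]) evaluates to 4.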
- FoF: build → resolve → differentiate → resolve → differentiate → materialize → compile → eval
- FoFoF: build → (resolve → differentiate) × 3 → materialize → compile → eval
- FoR: build → resolve → differentiate → transpose → resolve → differentiate → materialize → compile → eval
Each differentiate call requires a unique DiffPassId and a preceding
resolve to make earlier fragments traceable.
- computegraph-rs: graph engine
- chainrules-rs: PrimitiveOp trait
- tidu-rs: this crate (differentiate, transpose)
- tenferro-rs: concrete primitives