Goal
Extend crates/tensor4all-treetn/src/treetn/linsolve.rs so that the coefficient-matrix part (the local linear operator used in GMRES) can be a linear combination of an arbitrary number of terms, with each coefficient given as an AnyScalar.
Current state (as of now)
linsolve solves only an affine 2-term form: \( (a_0 I + a_1 A)\,x = b \).
LinsolveOptions stores coefficients as f64 a0/a1 and LinsolveUpdater::solve_local() converts them via AnyScalar::new_real(...).
LocalLinOp already stores coefficients as AnyScalar a0/a1 and applies \( y = a_0 x + a_1 H x \).
AnyScalar supports F64, C64, and (with backend-libtorch) TorchF64/TorchC64 which preserve autograd graphs.
Proposed API / design (high-level)
Represent the coefficient matrix as a list of terms:
- \( y = \sum_{k=0}^{K-1} a_k (H_k x) \)
- Include the identity shift as either:
  - a dedicated identity term (a_I * x), or
  - treat H_0 = I as an ordinary entry in the term list.
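For backward compatibility, the existing two-term form is a special case of either option. The block below writes the new term coefficients as c_k only to avoid clashing with the legacy names a_0, a_1; this renaming is for illustration only:

```latex
\begin{aligned}
y &= a_0\,x + a_1\,(A x) \\
  &= a_I\,x + \sum_{k=0}^{K-1} c_k\,(H_k x),
     && a_I = a_0,\; K = 1,\; (c_0, H_0) = (a_1, A) \\
  &= \sum_{k=0}^{K-1} c_k\,(H_k x),
     && K = 2,\; (c_0, H_0) = (a_0, I),\; (c_1, H_1) = (a_1, A)
\end{aligned}
```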
Suggested Rust types (exact naming TBD):
struct LinOpTerm<T,V> { coeff: AnyScalar, op: TreeTN<T,V> }
struct LinOpCombination<T,V> { identity_coeff: Option<AnyScalar>, terms: Vec<LinOpTerm<T,V>> }
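The following is a minimal, self-contained sketch of how the two suggested types could compose. It assumes f64 as a stand-in for AnyScalar and a small dense matrix (the hypothetical DenseOp) as a stand-in for TreeTN<T,V>. It also includes a hypothetical from_affine builder that maps the legacy (a0, a1, A) triple onto the new representation. None of these names are existing crate APIs.

```rust
// Hypothetical stand-ins: f64 replaces AnyScalar, DenseOp replaces TreeTN<T, V>.

/// Dense n x n operator (row-major), standing in for a TreeTN operator.
#[derive(Clone, Debug)]
pub struct DenseOp {
    pub n: usize,
    pub data: Vec<f64>,
}

impl DenseOp {
    /// Computes H x.
    pub fn apply(&self, x: &[f64]) -> Vec<f64> {
        (0..self.n)
            .map(|i| (0..self.n).map(|j| self.data[i * self.n + j] * x[j]).sum::<f64>())
            .collect()
    }
}

/// One term a_k * H_k of the combination.
#[derive(Clone, Debug)]
pub struct LinOpTerm {
    pub coeff: f64,
    pub op: DenseOp,
}

/// y = a_I x + sum_k a_k (H_k x); the identity term is optional.
#[derive(Clone, Debug, Default)]
pub struct LinOpCombination {
    pub identity_coeff: Option<f64>,
    pub terms: Vec<LinOpTerm>,
}

impl LinOpCombination {
    /// Hypothetical convenience builder for the legacy (a0 I + a1 A) form.
    pub fn from_affine(a0: f64, a1: f64, a: DenseOp) -> Self {
        Self {
            identity_coeff: Some(a0),
            terms: vec![LinOpTerm { coeff: a1, op: a }],
        }
    }

    /// Accumulates the identity shift and every term into y.
    pub fn apply(&self, x: &[f64]) -> Vec<f64> {
        let mut y: Vec<f64> = match self.identity_coeff {
            Some(c) => x.iter().map(|xi| c * xi).collect(),
            None => vec![0.0; x.len()],
        };
        for term in &self.terms {
            for (yi, hi) in y.iter_mut().zip(term.op.apply(x)) {
                *yi += term.coeff * hi;
            }
        }
        y
    }
}
```

Keeping identity_coeff separate from terms avoids building and applying an explicit identity operator. Treating H_0 = I as an ordinary term would work equally well at the cost of one extra operator.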
Implementation direction:
- Extend local apply in GMRES from fixed (a0 + a1*H) to N-term accumulation.
- Each H_k needs its own ProjectedOperator / environment cache (unless we later add shared caching).
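The sketch below illustrates N-term accumulation in the local apply, with one cache per term. Everything in it is hypothetical. ProjectedTerm stands in for a per-term ProjectedOperator, and its "environment" is reduced to the local effective matrix P^T H_k P for a fixed projection P. f64 stands in for AnyScalar. The cache is never invalidated here; a real sweep would rebuild it whenever the active region moves.

```rust
/// Hypothetical stand-in for a per-term ProjectedOperator with its own cache.
pub struct ProjectedTerm {
    pub coeff: f64,
    h: Vec<Vec<f64>>,              // full operator H_k (n x n)
    cache: Option<Vec<Vec<f64>>>,  // cached local matrix P^T H_k P (m x m)
    pub builds: usize,             // how many times the cache was built
}

impl ProjectedTerm {
    pub fn new(coeff: f64, h: Vec<Vec<f64>>) -> Self {
        Self { coeff, h, cache: None, builds: 0 }
    }

    /// Builds P^T H_k P on first use, then returns the cached matrix.
    /// `p` is n x m (row i, column a).
    fn local_matrix(&mut self, p: &[Vec<f64>]) -> &[Vec<f64>] {
        if self.cache.is_none() {
            let n = self.h.len();
            let m = p[0].len();
            let mut eff = vec![vec![0.0; m]; m];
            for a in 0..m {
                for b in 0..m {
                    let mut s = 0.0;
                    for i in 0..n {
                        for j in 0..n {
                            s += p[i][a] * self.h[i][j] * p[j][b];
                        }
                    }
                    eff[a][b] = s;
                }
            }
            self.cache = Some(eff);
            self.builds += 1;
        }
        self.cache.as_deref().expect("cache was just filled")
    }
}

/// Local apply: y = a_I x + sum_k a_k (P^T H_k P) x, one cache per term.
pub fn apply_local(
    identity_coeff: Option<f64>,
    terms: &mut [ProjectedTerm],
    p: &[Vec<f64>],
    x: &[f64],
) -> Vec<f64> {
    let mut y: Vec<f64> = match identity_coeff {
        Some(c) => x.iter().map(|v| c * v).collect(),
        None => vec![0.0; x.len()],
    };
    for term in terms.iter_mut() {
        let coeff = term.coeff;
        let eff = term.local_matrix(p);
        for (a, row) in eff.iter().enumerate() {
            let hx: f64 = row.iter().zip(x).map(|(h, v)| h * v).sum();
            y[a] += coeff * hx;
        }
    }
    y
}
```

Repeated applies inside one GMRES solve reuse each term's cache, so the per-apply cost grows linearly with the number of terms, while the memory for environments also grows with K.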
AD requirement (coefficients)
Use AnyScalar::TorchF64/TorchC64 with backend-libtorch for reverse-mode AD through coefficients.
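For reference when validating reverse-mode gradients, the derivative that a converged solve should reproduce follows from implicit differentiation. This assumes real coefficients and a real scalar loss ℓ(x) (the complex case needs conjugates and is omitted). It writes M = a_I I + \sum_k a_k H_k for the full coefficient matrix and takes b to be independent of the coefficients. If gradients are obtained by differentiating through a finite number of GMRES iterations, they may differ from these values when the solve is not fully converged.

```latex
\begin{aligned}
M x = b \;&\Rightarrow\; \frac{\partial x}{\partial a_k} = -M^{-1}(H_k x),
  \qquad \frac{\partial x}{\partial a_I} = -M^{-1} x, \\
M^{\top}\lambda = g := \frac{\partial \ell}{\partial x}
  \;&\Rightarrow\; \frac{\partial \ell}{\partial a_k} = -\lambda^{\top}(H_k x),
  \qquad \frac{\partial \ell}{\partial a_I} = -\lambda^{\top} x.
\end{aligned}
```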
Note on SVD degeneracy: No special handling for degenerate singular values is implemented at this time. If numerical issues arise with degenerate singular values during AD, an ad-hoc workaround is to add small noise to the input. A more robust solution (e.g., Lorentzian regularization) can be considered if needed in the future.
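For reference only, since nothing of this kind is implemented now: one common form of Lorentzian broadening replaces each reciprocal singular-value gap 1/d in the SVD backward rule with d/(d^2 + ε). The result is bounded by 1/(2√ε) and is zero for an exact degeneracy. The function names and the eps parameter below are hypothetical. The exact way the F matrix enters the backward formula, including its sign convention, varies between references.

```rust
/// Lorentzian-broadened reciprocal: approximates 1/d by d / (d^2 + eps).
/// Bounded by 1 / (2 sqrt(eps)); returns 0 when d == 0 and eps > 0.
pub fn lorentzian_inv(d: f64, eps: f64) -> f64 {
    d / (d * d + eps)
}

/// F_ij ~ 1 / (s_j^2 - s_i^2) for i != j, with the reciprocal broadened so
/// that degenerate singular values (s_i == s_j) give a finite entry.
pub fn svd_f_matrix(s: &[f64], eps: f64) -> Vec<Vec<f64>> {
    let k = s.len();
    let mut f = vec![vec![0.0; k]; k];
    for i in 0..k {
        for j in 0..k {
            if i != j {
                f[i][j] = lorentzian_inv(s[j] * s[j] - s[i] * s[i], eps);
            }
        }
    }
    f
}
```

Far from degeneracy (|d| much larger than √ε) the broadened value is essentially 1/d, so eps trades a small gradient bias near degeneracies for finite gradients.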
Note on forward AD: The PyTorch C++ API does not support custom forward-mode AD rules (jvp is not implemented for torch::autograd::Function in C++). True forward-mode AD for the coefficients would therefore require either:
- upstream PyTorch changes (see the related issues below), or
- an implicit-differentiation approach: \( dx/da = -A^{-1}((dA/da)\,x) \), where A here denotes the full coefficient matrix.
For now, we rely on reverse-mode AD (backward), which is sufficient for most use cases.
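The implicit-differentiation formula can be checked numerically on a small case. This sketch uses hypothetical helper names and a 2×2 dense A in place of the tree-tensor-network operator. It solves M(a) x = b with M = a_0 I + a_1 A and computes dx/da_k = -M^{-1}(H_k x) with H_0 = I and H_1 = A, for comparison against central finite differences.

```rust
/// Solves a 2x2 system m x = b by Cramer's rule.
fn solve2(m: [[f64; 2]; 2], b: [f64; 2]) -> [f64; 2] {
    let det = m[0][0] * m[1][1] - m[0][1] * m[1][0];
    [
        (b[0] * m[1][1] - m[0][1] * b[1]) / det,
        (m[0][0] * b[1] - b[0] * m[1][0]) / det,
    ]
}

fn matvec2(m: [[f64; 2]; 2], x: [f64; 2]) -> [f64; 2] {
    [m[0][0] * x[0] + m[0][1] * x[1], m[1][0] * x[0] + m[1][1] * x[1]]
}

/// M(a) = a0 I + a1 A, the two-term coefficient matrix.
fn coeff_matrix(a0: f64, a1: f64, a: [[f64; 2]; 2]) -> [[f64; 2]; 2] {
    [
        [a0 + a1 * a[0][0], a1 * a[0][1]],
        [a1 * a[1][0], a0 + a1 * a[1][1]],
    ]
}

/// x(a) = M(a)^{-1} b.
pub fn solve_x(a0: f64, a1: f64, a: [[f64; 2]; 2], b: [f64; 2]) -> [f64; 2] {
    solve2(coeff_matrix(a0, a1, a), b)
}

/// Implicit derivatives (dx/da0, dx/da1) = (-M^{-1} x, -M^{-1} (A x)).
pub fn implicit_dx(a0: f64, a1: f64, a: [[f64; 2]; 2], b: [f64; 2]) -> ([f64; 2], [f64; 2]) {
    let m = coeff_matrix(a0, a1, a);
    let x = solve2(m, b);
    let dx_da0 = solve2(m, x).map(|v| -v);
    let dx_da1 = solve2(m, matvec2(a, x)).map(|v| -v);
    (dx_da0, dx_da1)
}
```

Each derivative costs one extra linear solve with the same M, which is what a forward-mode implementation along these lines would need per coefficient.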
Tests / acceptance criteria
- … AnyScalar (real/complex).
- … a0/a1-only API once the new API is in place, or keep a thin convenience builder that maps (a0, a1, A) to the new representation.
- … linsolve solution matches within tolerance.
- … backend-libtorch).
Out of scope / follow-up
Notes / related code locations
crates/tensor4all-treetn/src/treetn/linsolve/options.rs (currently f64 a0/a1)
crates/tensor4all-treetn/src/treetn/linsolve/local_linop.rs (apply logic)
crates/tensor4all-treetn/src/treetn/linsolve/updater.rs (solve_local, coefficient conversion)
crates/tensor4all-tensorbackend/src/any_scalar.rs (AnyScalar incl. Torch autograd variants)
Related upstream issues
- pytorch/pytorch#40208 - C++ API for jacobian (open)
- PyTorch C++ custom Function JVP: not implemented (see extern/pytorch/torch/csrc/autograd/custom_function.h:507-513)