treetn linsolve: support arbitrary linear-combination operator terms with AnyScalar coefficients (incl AD check) #148

@shinaoka

Goal

Extend crates/tensor4all-treetn/src/treetn/linsolve.rs so the coefficient-matrix part (local linear operator used in GMRES) can be a linear combination with an arbitrary number of terms, and each coefficient can be set as AnyScalar.

Current state

  • linsolve solves only an affine 2-term form:
    \( (a_0 I + a_1 A) x = b \)
  • LinsolveOptions stores coefficients as f64 a0/a1 and LinsolveUpdater::solve_local() converts them via AnyScalar::new_real(...).
  • LocalLinOp already stores coefficients as AnyScalar a0/a1 and applies \( y = a_0 x + a_1 H x \).
  • AnyScalar supports F64, C64, and (with backend-libtorch) TorchF64/TorchC64 which preserve autograd graphs.

Proposed API / design (high-level)

Represent the coefficient matrix as a list of terms:

  • \( y = \sum_{k=0}^{K-1} a_k (H_k x) \)
  • Include identity-shift as either:
    • a dedicated identity term (a_I * x), or
    • treat H_0 = I in the term list.

Suggested Rust types (exact naming TBD):

  • struct LinOpTerm<T,V> { coeff: AnyScalar, op: TreeTN<T,V> }
  • struct LinOpCombination<T,V> { identity_coeff: Option<AnyScalar>, terms: Vec<LinOpTerm<T,V>> }

Implementation direction:

  • Extend local apply in GMRES from fixed (a0 + a1*H) to N-term accumulation.
  • Each H_k needs its own ProjectedOperator / environment cache (unless we later add shared caching).
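
The N-term accumulation described above can be sketched in isolation. This is a minimal illustration only, assuming plain f64 coefficients and boxed closures as stand-ins for AnyScalar and the projected TreeTN operators; the names Term, Combination, diagonal and demo_apply are hypothetical and do not exist in the crate.

```rust
/// Hypothetical stand-in for one term `a_k * H_k`.
/// The real type would hold an `AnyScalar` and a `TreeTN<T, V>`.
struct Term {
    coeff: f64,
    op: Box<dyn Fn(&[f64]) -> Vec<f64>>,
}

/// Hypothetical stand-in for the proposed `LinOpCombination`.
struct Combination {
    identity_coeff: Option<f64>,
    terms: Vec<Term>,
}

impl Combination {
    /// Computes y = a_I * x + sum_k a_k * (H_k x).
    fn apply(&self, x: &[f64]) -> Vec<f64> {
        // Start from the identity shift (zero if no identity term is set).
        let a_i = self.identity_coeff.unwrap_or(0.0);
        let mut y: Vec<f64> = x.iter().map(|xi| a_i * xi).collect();
        // Accumulate each operator term in turn.
        for term in &self.terms {
            let hx = (term.op)(x);
            assert_eq!(hx.len(), x.len(), "operator changed the vector length");
            for (yi, hxi) in y.iter_mut().zip(&hx) {
                *yi += term.coeff * hxi;
            }
        }
        y
    }
}

/// Builds a diagonal operator, used here only to keep the example checkable by hand.
fn diagonal(d: Vec<f64>) -> Box<dyn Fn(&[f64]) -> Vec<f64>> {
    Box::new(move |x: &[f64]| -> Vec<f64> {
        x.iter().zip(&d).map(|(xi, di)| xi * di).collect()
    })
}

/// Applies 2*I + 0.5*D_0 - 1.0*D_1 to a small vector.
fn demo_apply() -> Vec<f64> {
    let comb = Combination {
        identity_coeff: Some(2.0),
        terms: vec![
            Term { coeff: 0.5, op: diagonal(vec![1.0, 2.0, 3.0]) },
            Term { coeff: -1.0, op: diagonal(vec![1.0, 0.0, 1.0]) },
        ],
    };
    comb.apply(&[1.0, 1.0, 2.0])
}
```

Keeping the identity shift as a separate Option avoids materialising an identity operator and its environment cache; treating H_0 = I in the term list would instead keep a single uniform loop.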

AD requirement (coefficients)

Use AnyScalar::TorchF64/TorchC64 with backend-libtorch for reverse-mode AD through coefficients.

Note on SVD degeneracy: No special handling for degenerate singular values is implemented at this time. If numerical issues arise with degenerate singular values during AD, an ad-hoc workaround is to add small noise to the input. A more robust solution (e.g., Lorentzian regularization) can be considered if needed in the future.

Note on forward AD: The PyTorch C++ API does not support custom forward AD rules (jvp is not implemented for torch::autograd::Function in C++). True forward-mode AD for coefficients would require either:

  • Upstream PyTorch changes (see related issues below)
  • Implicit differentiation approach: dx/da = -A⁻¹((dA/da)x)

For now, we rely on reverse-mode AD (backward), which is sufficient for most use cases.
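
For reference, the implicit-differentiation formula mentioned above follows from differentiating the linear system itself. A short derivation, assuming the right-hand side b does not depend on the coefficients and writing the operator in the proposed multi-term form (the symbol a_I for the identity coefficient is notation introduced here):

```latex
\begin{aligned}
A(a)\,x(a) &= b, \qquad A(a) = a_I I + \sum_{k=0}^{K-1} a_k H_k \\
\frac{\partial A}{\partial a_j}\,x + A\,\frac{\partial x}{\partial a_j} &= 0 \\
\frac{\partial x}{\partial a_j} &= -A^{-1}\!\left(\frac{\partial A}{\partial a_j}\,x\right) = -A^{-1}\,(H_j x)
\end{aligned}
```

Each coefficient derivative therefore costs one additional linear solve with the same operator A and right-hand side H_j x (or x itself for the identity coefficient).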

Tests / acceptance criteria

  • Functionality: support K>=1 terms (e.g. 3 terms) in the linsolve local operator; coefficients are AnyScalar (real/complex).
  • Backwards compatibility not required (early dev): remove the old a0/a1-only API once the new API is in place, or keep a thin convenience builder that maps (a0,a1,A) to the new representation.
  • Correctness test (multi-term): for a small system (2-site or 3-site), build operators where we can compute the exact solution (e.g. diagonal/Pauli-type), and verify linsolve solution matches within tolerance.
  • AD check test (reverse-mode):
    • Prefer a diagonal case with analytic dependence: \( (a I + \sum_k b_k D_k) x = \mathrm{rhs} \Rightarrow x_i = \mathrm{rhs}_i / (a + \sum_k b_k d_{k,i}) \)
    • Verify derivative of some scalar function of the solution (e.g. one component, or residual norm) w.r.t. coefficients matches analytic derivative or finite difference.
    • Gate AD test behind a feature if needed (e.g. backend-libtorch).
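
The reference values for the diagonal AD check can be prototyped without any autograd backend. The sketch below is an assumption-laden illustration, not the crate's test: it uses plain f64, the closed-form diagonal solution, and a central finite difference, and every function name in it is hypothetical. A real test would compare the libtorch backward gradient against these numbers.

```rust
/// Denominator a + sum_k b_k * d[k][i] for component i.
fn denominator(i: usize, a: f64, b: &[f64], d: &[Vec<f64>]) -> f64 {
    a + b.iter().zip(d).map(|(bk, dk)| bk * dk[i]).sum::<f64>()
}

/// Closed-form solution of (a I + sum_k b_k D_k) x = rhs for diagonal D_k.
fn solve_diagonal(a: f64, b: &[f64], d: &[Vec<f64>], rhs: &[f64]) -> Vec<f64> {
    rhs.iter()
        .enumerate()
        .map(|(i, r)| r / denominator(i, a, b, d))
        .collect()
}

/// Analytic gradient of x_i: (dx_i/da, [dx_i/db_k]).
fn analytic_grad(i: usize, a: f64, b: &[f64], d: &[Vec<f64>], rhs: &[f64]) -> (f64, Vec<f64>) {
    let denom = denominator(i, a, b, d);
    let da = -rhs[i] / (denom * denom);
    let db = d.iter().map(|dk| -rhs[i] * dk[i] / (denom * denom)).collect();
    (da, db)
}

/// Central finite-difference gradient of x_i with step h.
fn finite_diff_grad(
    i: usize, a: f64, b: &[f64], d: &[Vec<f64>], rhs: &[f64], h: f64,
) -> (f64, Vec<f64>) {
    let da = (solve_diagonal(a + h, b, d, rhs)[i] - solve_diagonal(a - h, b, d, rhs)[i])
        / (2.0 * h);
    let db = (0..b.len())
        .map(|k| {
            // Perturb only the k-th coefficient in each direction.
            let mut plus = b.to_vec();
            plus[k] += h;
            let mut minus = b.to_vec();
            minus[k] -= h;
            (solve_diagonal(a, &plus, d, rhs)[i] - solve_diagonal(a, &minus, d, rhs)[i])
                / (2.0 * h)
        })
        .collect();
    (da, db)
}

/// Returns (analytic, finite-difference) gradients of x_2 for a 3-component example.
fn demo_gradients() -> ((f64, Vec<f64>), (f64, Vec<f64>)) {
    let (a, b) = (2.0, vec![0.5, -1.0]);
    let d = vec![vec![1.0, 2.0, 3.0], vec![1.0, 0.0, 1.0]];
    let rhs = vec![1.0, 1.0, 2.0];
    (
        analytic_grad(2, a, &b, &d, &rhs),
        finite_diff_grad(2, a, &b, &d, &rhs, 1e-6),
    )
}
```

In this example the denominator for component 2 is 2.5, so the analytic derivative with respect to a is -2/6.25 = -0.32; the finite-difference value should agree to well within 1e-6.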

Out of scope / follow-up

Notes / related code locations

  • crates/tensor4all-treetn/src/treetn/linsolve/options.rs (currently f64 a0/a1)
  • crates/tensor4all-treetn/src/treetn/linsolve/local_linop.rs (apply logic)
  • crates/tensor4all-treetn/src/treetn/linsolve/updater.rs (solve_local, coefficient conversion)
  • crates/tensor4all-tensorbackend/src/any_scalar.rs (AnyScalar incl Torch autograd variants)

Related upstream issues

  • pytorch/pytorch#40208 - C++ API for jacobian (open)
  • PyTorch C++ custom Function JVP: not implemented (see extern/pytorch/torch/csrc/autograd/custom_function.h:507-513)
