diff --git a/.gitignore b/.gitignore index 9d712a10..8126c168 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ build/ Manifest.toml LocalPreferences.toml .julia-tests/ +benchmarks/results/*.h5 # Avoid accidentally committing ad-hoc analysis notes at repo root /*.md diff --git a/Cargo.toml b/Cargo.toml index 9ccdb347..f7f72c9a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ bnum = "0.12" uint = "0.10" libc = "0.2" paste = "1.0" +smallvec = "1.15" thiserror = "2.0" rand = "0.9" rand_chacha = "0.9" @@ -53,7 +54,7 @@ hdf5-metno = { version = "0.12", default-features = false } ndarray = "0.17" quanticsgrids = { git = "https://github.com/tensor4all/quanticsgrids-rs", rev = "a76b8fb" } hdf5-rt = { git = "https://github.com/tensor4all/hdf5-rt", default-features = false } -tenferro = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "925e2511bc6bd019432f66596de39389dfced754", default-features = false } -tenferro-device = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "925e2511bc6bd019432f66596de39389dfced754" } -tenferro-einsum = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "925e2511bc6bd019432f66596de39389dfced754", default-features = false } -tenferro-tensor = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "925e2511bc6bd019432f66596de39389dfced754", default-features = false } +tenferro = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "6283c6f1a8a56f920a2600f6c27e9e6ade28beed", default-features = false } +tenferro-device = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "6283c6f1a8a56f920a2600f6c27e9e6ade28beed" } +tenferro-einsum = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "6283c6f1a8a56f920a2600f6c27e9e6ade28beed", default-features = false } +tenferro-tensor = { git = "https://github.com/tensor4all/tenferro-rs.git", rev = "6283c6f1a8a56f920a2600f6c27e9e6ade28beed", default-features = false } diff --git a/benchmarks/README.md 
b/benchmarks/README.md index 51616246..16334f1a 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -6,6 +6,7 @@ This directory contains benchmark code for comparing Rust and Julia implementati - `rust/`: Rust benchmark code using `tensor4all-rs` - `julia/`: Julia benchmark code using `ITensors.jl` and `ITensorMPS.jl` +- `results/`: saved benchmark commands and representative local outputs ## Running Benchmarks @@ -16,6 +17,36 @@ cd crates/tensor4all-itensorlike cargo run --release --example benchmark_contract ``` +Projected local-operator apply: + +```bash +RAYON_NUM_THREADS=1 cargo run -p tensor4all-treetn --example benchmark_projected_apply --release -- 38 32 32 3 0 +``` + +Prepared local linsolve: + +```bash +RAYON_NUM_THREADS=1 cargo run -p tensor4all-treetn --example benchmark_local_linsolve --release -- 38 32 32 1 10 30 0 +``` + +Non-AD local tensor operations: + +```bash +RAYON_NUM_THREADS=1 cargo run -p tensor4all-core --example benchmark_tensor_ops --release -- 20000 6 2 2 6 +``` + +TensorTrain-level operations against ITensorMPS: + +```bash +RAYON_NUM_THREADS=1 cargo run -p tensor4all-itensorlike --example benchmark_tt_ops --release -- --L 32 --zipup-L 10 --chis 4,8,16,32,64 +``` + +Inspect Julia-dumped local linsolve inputs: + +```bash +cargo run -p tensor4all-hdf5 --example inspect_mps_inputs --release -- benchmarks/results/local_linsolve_inputs_N38_b32_o32.h5 +``` + ### Julia ```bash @@ -23,6 +54,36 @@ cd external/ITensorMPS.jl julia benchmark/benchmark_contract.jl ``` +Projected local-operator apply: + +```bash +BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/benchmark_projected_apply.jl 38 32 32 3 0 +``` + +Prepared local linsolve: + +```bash +BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/benchmark_local_linsolve.jl 38 32 32 1 1 10 +``` + +Non-AD local tensor operations: + +```bash +BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/benchmark_tensor_ops.jl 20000 6 2 2 6 +``` + 
+TensorTrain-level operations against tensor4all: + +```bash +BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/benchmark_tt_ops.jl --L 32 --zipup-L 10 --chis 4,8,16,32,64 +``` + +Dump local linsolve inputs as ITensorMPS-compatible HDF5: + +```bash +BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/dump_local_linsolve_inputs.jl benchmarks/results/local_linsolve_inputs_N38_b32_o32.h5 38 32 32 +``` + ## Benchmark Details Both benchmarks perform MPO-MPO contraction using zip-up algorithm: @@ -31,3 +92,36 @@ Both benchmarks perform MPO-MPO contraction using zip-up algorithm: - Bond dimension: 50 - Max rank: 50 - Includes orthogonalization time in measurements + +The projected local-operator apply benchmarks isolate the local matvec hot path +used by two-site TreeTN/ITensor-style local solves. The Rust source of truth is +`benchmarks/rust/benchmark_projected_apply.rs`, included by the cargo example +target under `tensor4all-treetn`; the Julia counterpart is +`benchmarks/julia/benchmark_projected_apply.jl`. + +The prepared local linsolve benchmarks construct the operator, right-hand side, +and initial state once, then time the local solve body. They also report local +GMRES/apply/RHS/factorization buckets. Use Julia `maxiter=1, krylovdim=10` as a +rough match to Rust's `krylov_maxiter=10` total-iteration cap; KrylovKit's +`maxiter=10, krylovdim=30` performs far more local operator applications. + +The non-AD local tensor operation benchmarks isolate `inner`, `norm`, affine +addition, and explicit `conj`-then-contract on a small dense tensor. The default +shape `[6, 2, 2, 6]` mirrors the small two-site local tensors observed in the +QuanticsNEGF long-time local Krylov test, where dispatch/allocation overhead can +dominate floating-point work. + +The TensorTrain-level operation benchmarks compare tensor4all's +`TensorTrain::inner`, strict direct-sum MPS addition, and prepared MPO×MPO zipup +contraction against ITensorMPS.jl. 
They use deterministic Complex64 fixtures and +print CSV-style rows with sample counts, min/median/mean/max milliseconds, +result max bond dimension, and a checksum. The Rust source of truth is +`benchmarks/rust/benchmark_tt_ops.rs`, included by +`tensor4all-itensorlike/examples/benchmark_tt_ops.rs`; the Julia counterpart is +`benchmarks/julia/benchmark_tt_ops.jl`. + +`dump_local_linsolve_inputs.jl` writes the prepared local operator as +`operator_as_mps`, plus `rhs` and `init`, in one HDF5 file. The operator is a +Julia `MPO` stored through the `MPS` schema by saving its site tensors as +`MPS([H[i] for i in 1:length(H)])`; Rust reads all three groups with +`tensor4all_hdf5::load_mps`. diff --git a/benchmarks/julia/Project.toml b/benchmarks/julia/Project.toml index d37edcca..a44ed07b 100644 --- a/benchmarks/julia/Project.toml +++ b/benchmarks/julia/Project.toml @@ -1,5 +1,12 @@ [deps] FastMPOContractions = "f6e391d2-8ffa-4d7a-98cd-7e70024481ca" +HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" ITensorMPS = "0d1a4710-d33b-49a5-8f18-73bdf49b47e2" +ITensorTDVP = "25707e16-a4db-4a07-99d9-4d67b7af0342" +ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" + +[compat] +ITensorTDVP = "^0.4" +ITensors = "^0.6" diff --git a/benchmarks/julia/benchmark_local_linsolve.jl b/benchmarks/julia/benchmark_local_linsolve.jl new file mode 100644 index 00000000..b91403e9 --- /dev/null +++ b/benchmarks/julia/benchmark_local_linsolve.jl @@ -0,0 +1,287 @@ +# Benchmark prepared local linsolve using ITensorTDVP.linsolve. 
+#
+# Run:
+#   BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/benchmark_local_linsolve.jl
+#
+# Optional args:
+#   BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/benchmark_local_linsolve.jl [N] [state_bond_dim] [operator_bond_dim] [nsweeps] [krylov_maxiter] [krylov_dim]
+
+import Pkg
+
+Pkg.activate(@__DIR__)
+Pkg.instantiate()
+
+using ITensors
+using ITensorTDVP
+using KrylovKit
+using LinearAlgebra
+using Printf
+using Random
+
+ITensors.disable_warn_order()
+
+mutable struct LocalSolveStats  # counters and accumulated seconds updated by the timed updater
+    local_updates::Int
+    krylov_iterations::Int
+    krylov_ops::Int
+    rhs_time::Float64
+    gmres_time::Float64
+    apply_time::Float64
+end
+
+LocalSolveStats() = LocalSolveStats(0, 0, 0, 0.0, 0.0, 0.0)  # all counters and timers start at zero
+
+function parse_positive_int_arg(args::Vector{String}, index::Int, default::Int, name::String)::Int  # args[index] if present, else default; must be > 0
+    value = index <= length(args) ? parse(Int, args[index]) : default
+    value > 0 || error("$name must be greater than zero")
+    return value
+end
+
+function maybe_set_blas_threads_from_env!()
+    haskey(ENV, "BLAS_NUM_THREADS") || return  # variable unset: leave the BLAS thread count as it is
+    nthreads = parse(Int, ENV["BLAS_NUM_THREADS"])
+    nthreads > 0 || error("BLAS_NUM_THREADS must be greater than zero")
+    BLAS.set_num_threads(nthreads)
+end
+
+function state_indices(
+    site::Int,
+    nsites::Int,
+    acted_sites::Vector{Index{Int64}},
+    spectator_sites::Vector{Index{Int64}},
+    state_links::Vector{Index{Int64}},
+)
+    sites = Index{Int64}[spectator_sites[site], acted_sites[site]]  # spectator index first, then the acted-on index
+    if nsites == 1
+        return sites  # single-site chain: no link indices
+    elseif site == 1
+        return vcat(sites, Index{Int64}[state_links[site]])
+    elseif site == nsites
+        return vcat(Index{Int64}[state_links[site - 1]], sites)
+    else
+        return vcat(Index{Int64}[state_links[site - 1]], sites, Index{Int64}[state_links[site]])
+    end
+end
+
+function operator_indices(site::Int, nsites::Int, operator_links::Vector{Index{Int64}})  # link indices only, left link before right link
+    if nsites == 1
+        return Index{Int64}[]
+    elseif site == 1
+        return Index{Int64}[operator_links[site]]
+    elseif site == nsites
+        return Index{Int64}[operator_links[site - 1]]
+    else
+        return
Index{Int64}[operator_links[site - 1], operator_links[site]] + end +end + +function make_state_mps( + rng::AbstractRNG, + nsites::Int, + state_bond_dim::Int, + acted_sites::Vector{Index{Int64}}, + spectator_sites::Vector{Index{Int64}}, +)::MPS + state_links = [Index(state_bond_dim, "Link,psi,l=$site") for site in 1:(nsites - 1)] + tensors = Vector{ITensor}(undef, nsites) + for site in 1:nsites + indices = state_indices(site, nsites, acted_sites, spectator_sites, state_links) + tensors[site] = random_itensor(rng, indices...) + end + return MPS(tensors) +end + +function make_operator_mpo( + rng::AbstractRNG, + nsites::Int, + operator_bond_dim::Int, + acted_sites::Vector{Index{Int64}}, + spectator_sites::Vector{Index{Int64}}, +)::MPO + operator_links = [Index(operator_bond_dim, "Link,H,l=$site") for site in 1:(nsites - 1)] + tensors = Vector{ITensor}(undef, nsites) + for site in 1:nsites + core_indices = vcat( + operator_indices(site, nsites, operator_links), + Index{Int64}[acted_sites[site], prime(acted_sites[site])], + ) + core = random_itensor(rng, core_indices...) 
+ spectator_identity = delta(spectator_sites[site], prime(spectator_sites[site])) + tensors[site] = core * spectator_identity + end + return MPO(tensors) +end + +function elapsed_seconds(f)::Tuple{Float64, Any} + start = time_ns() + result = f() + return ((time_ns() - start) / 1.0e9, result) +end + +function timed_linsolve_updater( + problem, + init; + internal_kwargs, + coefficients, + stats::LocalSolveStats, + kwargs..., +) + stats.local_updates += 1 + operator = ITensorTDVP.operator(problem) + + rhs_time, rhs = elapsed_seconds() do + return ITensorTDVP.constant_term(problem) + end + stats.rhs_time += rhs_time + + function timed_operator(x) + apply_time, y = elapsed_seconds() do + return operator(x) + end + stats.apply_time += apply_time + return y + end + + gmres_time, solve_result = elapsed_seconds() do + return KrylovKit.linsolve( + timed_operator, + rhs, + init, + coefficients[1], + coefficients[2]; + kwargs..., + ) + end + stats.gmres_time += gmres_time + x, info = solve_result + stats.krylov_iterations += info.numiter + stats.krylov_ops += info.numops + return x, (; info) +end + +function run_prepared_solve(H, rhs, init; nsweeps, cutoff, maxdim, a0, a1, tol, maxiter, krylovdim) + stats = LocalSolveStats() + solve_time, solution = elapsed_seconds() do + return ITensorTDVP.linsolve( + H, + rhs, + init, + a0, + a1; + maxdim, + cutoff, + nsweeps, + nsite=2, + reverse_step=false, + outputlevel=0, + updater=timed_linsolve_updater, + updater_kwargs=(; + stats, + ishermitian=false, + tol, + maxiter, + krylovdim, + ), + ) + end + return solve_time, solution, stats +end + +function main(args::Vector{String}) + nsites = parse_positive_int_arg(args, 1, 38, "N") + state_bond_dim = parse_positive_int_arg(args, 2, 8, "state_bond_dim") + operator_bond_dim = parse_positive_int_arg(args, 3, 8, "operator_bond_dim") + nsweeps = parse_positive_int_arg(args, 4, 1, "nsweeps") + maxiter = parse_positive_int_arg(args, 5, 10, "krylov_maxiter") + krylovdim = 
parse_positive_int_arg(args, 6, 30, "krylov_dim") + + nsites >= 2 || error("N must be at least 2 for a two-site local solve") + maybe_set_blas_threads_from_env!() + + phys_dim = 2 + seed = 20260518 + cutoff = 0.0 + maxdim = state_bond_dim + a0 = 1.0 + a1 = 0.01 + tol = 1.0e-30 + + setup_time, prepared = elapsed_seconds() do + rng = MersenneTwister(seed) + acted_sites = [Index(phys_dim, "s=$site") for site in 1:nsites] + spectator_sites = [Index(phys_dim, "q=$site") for site in 1:nsites] + state = make_state_mps(rng, nsites, state_bond_dim, acted_sites, spectator_sites) + operator = make_operator_mpo( + rng, + nsites, + operator_bond_dim, + acted_sites, + spectator_sites, + ) + return (; operator, rhs=deepcopy(state), init=deepcopy(state)) + end + H = prepared.operator + rhs = prepared.rhs + init = prepared.init + + # Compile the relevant local solve path outside the reported solve timing. + run_prepared_solve( + H, + rhs, + init; + nsweeps=1, + cutoff, + maxdim, + a0, + a1, + tol, + maxiter, + krylovdim, + ) + GC.gc() + + solve_time, solution, stats = run_prepared_solve( + H, + rhs, + init; + nsweeps, + cutoff, + maxdim, + a0, + a1, + tol, + maxiter, + krylovdim, + ) + + println("=== Prepared local linsolve benchmark (Julia/ITensorTDVP) ===") + println("N = $nsites") + println("phys_dim = $phys_dim") + println("state_bond_dim = $state_bond_dim") + println("operator_bond_dim = $operator_bond_dim") + println("nsweeps = $nsweeps") + println("krylov_maxiter = $maxiter") + println("krylov_dim = $krylovdim") + @printf("krylov_tol = %.1e\n", tol) + println("coefficients = ($a0, $a1)") + println("threads = $(Threads.nthreads())") + println("blas_threads = $(BLAS.get_num_threads())") + println() + + println("--- Prepared solve ---") + @printf("setup excluded from solve: %.3f ms\n", setup_time * 1000.0) + @printf("solve total: %.3f ms\n", solve_time * 1000.0) + println("local_updates = $(stats.local_updates)") + println("krylov_iterations = $(stats.krylov_iterations)") + 
println("krylov_ops = $(stats.krylov_ops)")
+    @printf("rhs projection inside updates: %.3f ms\n", stats.rhs_time * 1000.0)
+    @printf("local GMRES total: %.3f ms\n", stats.gmres_time * 1000.0)
+    @printf("projected apply inside GMRES: %.3f ms\n", stats.apply_time * 1000.0)
+    @printf(
+        "replacebond/factorization/orthogonalization overhead: %.3f ms\n",
+        max(0.0, solve_time - stats.gmres_time - stats.rhs_time) * 1000.0,  # solve time not spent in GMRES or RHS projection, clamped at zero
+    )
+    println("solution max bond dim = $(maxlinkdim(solution))")
+end
+
+main(ARGS)
diff --git a/benchmarks/julia/benchmark_projected_apply.jl b/benchmarks/julia/benchmark_projected_apply.jl
new file mode 100644
index 00000000..4b73a9ca
--- /dev/null
+++ b/benchmarks/julia/benchmark_projected_apply.jl
@@ -0,0 +1,228 @@
+# Benchmark isolated ITensors.ProjMPO local apply calls.
+#
+# Run:
+#   julia --project=benchmarks/julia benchmarks/julia/benchmark_projected_apply.jl
+#
+# Optional args:
+#   julia --project=benchmarks/julia benchmarks/julia/benchmark_projected_apply.jl [N] [state_bond_dim] [operator_bond_dim] [repeats] [step_index]
+#
+# For a one-thread comparison with the Rust example:
+#   BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/benchmark_projected_apply.jl 38 32 32 3 0
+
+import Pkg
+
+Pkg.activate(@__DIR__)
+Pkg.instantiate()
+
+using ITensors
+using LinearAlgebra
+using Printf
+using Random
+
+ITensors.disable_warn_order()
+
+function parse_positive_int_arg(args::Vector{String}, index::Int, default::Int, name::String)::Int  # args[index] if present, else default; must be > 0
+    value = index <= length(args) ? parse(Int, args[index]) : default
+    value > 0 || error("$name must be greater than zero")
+    return value
+end
+
+function parse_nonnegative_int_arg(args::Vector{String}, index::Int, default::Int, name::String)::Int  # same lookup as above, but zero is accepted
+    value = index <= length(args) ?
parse(Int, args[index]) : default + value >= 0 || error("$name must be nonnegative") + return value +end + +function maybe_set_blas_threads_from_env!() + haskey(ENV, "BLAS_NUM_THREADS") || return + nthreads = parse(Int, ENV["BLAS_NUM_THREADS"]) + nthreads > 0 || error("BLAS_NUM_THREADS must be greater than zero") + BLAS.set_num_threads(nthreads) +end + +function state_indices( + site::Int, + nsites::Int, + acted_sites::Vector{Index{Int64}}, + spectator_sites::Vector{Index{Int64}}, + state_links::Vector{Index{Int64}}, +) + sites = Index{Int64}[spectator_sites[site], acted_sites[site]] + if nsites == 1 + return sites + elseif site == 1 + return vcat(sites, Index{Int64}[state_links[site]]) + elseif site == nsites + return vcat(Index{Int64}[state_links[site - 1]], sites) + else + return vcat(Index{Int64}[state_links[site - 1]], sites, Index{Int64}[state_links[site]]) + end +end + +function operator_indices(site::Int, nsites::Int, operator_links::Vector{Index{Int64}}) + if nsites == 1 + return Index{Int64}[] + elseif site == 1 + return Index{Int64}[operator_links[site]] + elseif site == nsites + return Index{Int64}[operator_links[site - 1]] + else + return Index{Int64}[operator_links[site - 1], operator_links[site]] + end +end + +function make_state_mps( + rng::AbstractRNG, + nsites::Int, + state_bond_dim::Int, + acted_sites::Vector{Index{Int64}}, + spectator_sites::Vector{Index{Int64}}, +)::MPS + state_links = [Index(state_bond_dim, "Link,psi,l=$site") for site in 1:(nsites - 1)] + tensors = Vector{ITensor}(undef, nsites) + for site in 1:nsites + indices = state_indices(site, nsites, acted_sites, spectator_sites, state_links) + tensors[site] = random_itensor(rng, indices...) 
+ end + return MPS(tensors) +end + +function make_operator_mpo( + rng::AbstractRNG, + nsites::Int, + operator_bond_dim::Int, + acted_sites::Vector{Index{Int64}}, + spectator_sites::Vector{Index{Int64}}, +)::MPO + operator_links = [Index(operator_bond_dim, "Link,H,l=$site") for site in 1:(nsites - 1)] + tensors = Vector{ITensor}(undef, nsites) + for site in 1:nsites + core_indices = vcat( + operator_indices(site, nsites, operator_links), + Index{Int64}[acted_sites[site], prime(acted_sites[site])], + ) + core = random_itensor(rng, core_indices...) + spectator_identity = delta(spectator_sites[site], prime(spectator_sites[site])) + tensors[site] = core * spectator_identity + end + return MPO(tensors) +end + +function two_site_sweep_positions(nsites::Int, center_pos::Int)::Vector{Int} + positions = collect(center_pos:(nsites - 1)) + append!(positions, collect((center_pos - 1):-1:1)) + return positions +end + +function elapsed_seconds(f)::Tuple{Float64, Any} + start = time_ns() + result = f() + return ((time_ns() - start) / 1.0e9, result) +end + +function summarize(label::String, times::Vector{Float64}) + mean = sum(times) / length(times) + min_time = minimum(times) + max_time = maximum(times) + @printf( + "%s: mean=%.3f ms min=%.3f ms max=%.3f ms n=%d\n", + label, + mean * 1000.0, + min_time * 1000.0, + max_time * 1000.0, + length(times), + ) +end + +function main(args::Vector{String}) + nsites = parse_positive_int_arg(args, 1, 38, "N") + state_bond_dim = parse_positive_int_arg(args, 2, 8, "state_bond_dim") + operator_bond_dim = parse_positive_int_arg(args, 3, 8, "operator_bond_dim") + repeats = parse_positive_int_arg(args, 4, 20, "repeats") + step_index = parse_nonnegative_int_arg(args, 5, 0, "step_index") + + nsites >= 2 || error("N must be at least 2 for a two-site local step") + maybe_set_blas_threads_from_env!() + + phys_dim = 2 + seed = 20260518 + rng = MersenneTwister(seed) + + acted_sites = [Index(phys_dim, "s=$site") for site in 1:nsites] + spectator_sites = 
[Index(phys_dim, "q=$site") for site in 1:nsites] + + psi = make_state_mps(rng, nsites, state_bond_dim, acted_sites, spectator_sites) + H = make_operator_mpo(rng, nsites, operator_bond_dim, acted_sites, spectator_sites) + + center_pos = nsites ÷ 2 + 1 + positions = two_site_sweep_positions(nsites, center_pos) + local_pos = positions[mod(step_index, length(positions)) + 1] + phi = psi[local_pos] * psi[local_pos + 1] + + # Compile the relevant ITensor/ProjMPO code before taking timings. + warmup_projected = ITensors.ProjMPO(H) + position!(warmup_projected, psi, local_pos) + warmup_out = warmup_projected(phi) + GC.@preserve warmup_out nothing + GC.gc() + + println("=== ProjMPO local apply benchmark ===") + println("N = $nsites") + println("phys_dim = $phys_dim") + println("state_bond_dim = $state_bond_dim") + println("operator_bond_dim = $operator_bond_dim") + println("repeats = $repeats") + println("threads = $(Threads.nthreads())") + println("blas_threads = $(BLAS.get_num_threads())") + println("center_pos = $center_pos") + println("step_index = $(mod(step_index, length(positions)))") + println("local_sites = ($local_pos, $(local_pos + 1))") + println("local_dims = $(dim.(inds(phi)))") + println() + + projected_ref = Ref{Any}() + cold_time, cold_out = elapsed_seconds() do + projected = ITensors.ProjMPO(H) + position!(projected, psi, local_pos) + out = projected(phi) + projected_ref[] = projected + return out + end + @printf( + "cold apply (environment build + one apply): %.3f ms, output_rank=%d\n", + cold_time * 1000.0, + order(cold_out), + ) + + projected = projected_ref[] + warm_times = Float64[] + sizehint!(warm_times, repeats) + warm_order_sum = 0 + for _ in 1:repeats + t, out = elapsed_seconds() do + return projected(phi) + end + warm_order_sum += order(out) + push!(warm_times, t) + end + summarize("warm apply (environment cache hot)", warm_times) + + cold_times = Float64[] + sizehint!(cold_times, repeats) + cold_order_sum = 0 + for _ in 1:repeats + t, out = 
elapsed_seconds() do
+            projected_cold = ITensors.ProjMPO(H)  # fresh projector on every repeat, so no environment is reused
+            position!(projected_cold, psi, local_pos)
+            return projected_cold(phi)
+        end
+        cold_order_sum += order(out)
+        push!(cold_times, t)
+    end
+    summarize("cold apply repeated (fresh environment cache)", cold_times)
+
+    GC.@preserve warm_order_sum cold_order_sum nothing
+    return nothing
+end
+
+main(ARGS)
diff --git a/benchmarks/julia/benchmark_tensor_ops.jl b/benchmarks/julia/benchmark_tensor_ops.jl
new file mode 100644
index 00000000..aa48302a
--- /dev/null
+++ b/benchmarks/julia/benchmark_tensor_ops.jl
@@ -0,0 +1,127 @@
+# Benchmark non-AD dense ITensor vector-space operations.
+#
+# Run:
+#   BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/benchmark_tensor_ops.jl
+#
+# Optional args:
+#   julia --project=benchmarks/julia benchmarks/julia/benchmark_tensor_ops.jl [repeats] [dim1] [dim2] ...
+#
+# Example matching the Rust command:
+#   BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/benchmark_tensor_ops.jl 20000 6 2 2 6
+
+import Pkg
+
+Pkg.activate(@__DIR__)
+Pkg.instantiate()
+
+using ITensors
+using LinearAlgebra
+using Printf
+using Random
+
+ITensors.disable_warn_order()
+
+function parse_args(args::Vector{String})  # args[1] = repeats, args[2:end] = tensor dimensions
+    repeats = length(args) >= 1 ? parse(Int, args[1]) : 20_000
+    repeats > 0 || error("repeats must be greater than zero")
+    dims = length(args) >= 2 ?
parse.(Int, args[2:end]) : [6, 2, 2, 6] + !isempty(dims) || error("at least one dimension is required") + all(>(0), dims) || error("all dimensions must be positive") + return repeats, dims +end + +function maybe_set_blas_threads_from_env!() + haskey(ENV, "BLAS_NUM_THREADS") || return + nthreads = parse(Int, ENV["BLAS_NUM_THREADS"]) + nthreads > 0 || error("BLAS_NUM_THREADS must be greater than zero") + BLAS.set_num_threads(nthreads) +end + +function elapsed_seconds(f) + start = time_ns() + result = f() + return (time_ns() - start) / 1.0e9, result +end + +function main() + maybe_set_blas_threads_from_env!() + repeats, dims = parse_args(ARGS) + rng = MersenneTwister(0x5eed1234) + inds = [Index(dim, "bench,n=$n") for (n, dim) in enumerate(dims)] + a = random_itensor(rng, ComplexF64, inds...) + b = random_itensor(rng, ComplexF64, inds...) + alpha = 0.7 - 0.2im + beta = -0.3 + 0.4im + + for _ in 1:32 + inner(a, b) + norm(a) + alpha * a + beta * b + (conj(a) * b)[] + end + + element_count = prod(dims) + println("=== ITensors.jl non-AD tensor ops benchmark ===") + println("dims=$(dims) elements=$(element_count) repeats=$(repeats) dtype=ComplexF64") + + inner_seconds, inner_checksum = elapsed_seconds() do + checksum = zero(ComplexF64) + for _ in 1:repeats + checksum += inner(a, b) + end + checksum + end + @printf( + "inner_seconds = %.6f per_call_us = %.3f checksum = %.6e%+.6eim\n", + inner_seconds, + inner_seconds * 1.0e6 / repeats, + real(inner_checksum), + imag(inner_checksum), + ) + + norm_seconds, norm_checksum = elapsed_seconds() do + checksum = 0.0 + for _ in 1:repeats + checksum += norm(a) + end + checksum + end + @printf( + "norm_seconds = %.6f per_call_us = %.3f checksum = %.6e\n", + norm_seconds, + norm_seconds * 1.0e6 / repeats, + norm_checksum, + ) + + axpby_seconds, axpby_checksum = elapsed_seconds() do + checksum = 0.0 + for _ in 1:repeats + out = alpha * a + beta * b + checksum += norm(out) + end + checksum + end + @printf( + "axpby_seconds = %.6f 
per_call_us = %.3f checksum = %.6e\n", + axpby_seconds, + axpby_seconds * 1.0e6 / repeats, + axpby_checksum, + ) + + conj_contract_seconds, conj_contract_checksum = elapsed_seconds() do + checksum = zero(ComplexF64) + for _ in 1:repeats + checksum += (conj(a) * b)[] + end + checksum + end + @printf( + "conj_contract_sum_seconds = %.6f per_call_us = %.3f checksum = %.6e%+.6eim\n", + conj_contract_seconds, + conj_contract_seconds * 1.0e6 / repeats, + real(conj_contract_checksum), + imag(conj_contract_checksum), + ) +end + +main() diff --git a/benchmarks/julia/benchmark_tt_ops.jl b/benchmarks/julia/benchmark_tt_ops.jl new file mode 100644 index 00000000..925cea3c --- /dev/null +++ b/benchmarks/julia/benchmark_tt_ops.jl @@ -0,0 +1,272 @@ +#!/usr/bin/env julia + +using LinearAlgebra +using Printf +using Statistics +using ITensors +using ITensorMPS + +const DEFAULT_L = 32 +const DEFAULT_D = 2 +const DEFAULT_ZIPUP_L = 10 +const DEFAULT_CHIS = [4, 8, 16, 32, 64] +const DEFAULT_WARMUP_SECONDS = 1.0 +const DEFAULT_MEASUREMENT_SECONDS = 2.0 +const DEFAULT_MIN_SAMPLES = 10 + +function usage() + println(""" +Usage: julia --project=benchmarks/julia benchmarks/julia/benchmark_tt_ops.jl [options] + +Options: + --L N MPS length (default: $(DEFAULT_L)) + --d N Physical dimension (default: $(DEFAULT_D)) + --zipup-L N MPO zipup length (default: $(DEFAULT_ZIPUP_L)) + --chis LIST Comma-separated bond dimensions (default: $(join(DEFAULT_CHIS, ","))) + --warm-up-time SECONDS Warm-up time after first JIT call (default: $(DEFAULT_WARMUP_SECONDS)) + --measurement-time SECONDS Measurement time per case (default: $(DEFAULT_MEASUREMENT_SECONDS)) + --min-samples N Minimum samples per case (default: $(DEFAULT_MIN_SAMPLES)) + --blas-threads N Julia BLAS threads (default: 1) + --help Show this help text +""") +end + +function parse_args(args) + opts = Dict{String, Any}( + "L" => DEFAULT_L, + "d" => DEFAULT_D, + "zipup_L" => DEFAULT_ZIPUP_L, + "chis" => copy(DEFAULT_CHIS), + "warm_up_time" => 
DEFAULT_WARMUP_SECONDS,
+        "measurement_time" => DEFAULT_MEASUREMENT_SECONDS,
+        "min_samples" => DEFAULT_MIN_SAMPLES,
+        "blas_threads" => 1,
+    )
+
+    i = 1
+    # Step to the value that must follow `flag`; report a clear error instead
+    # of a BoundsError when the flag is the last command-line argument.
+    function flag_value(flag)
+        i += 1
+        i <= length(args) || error("missing value for argument: $flag")
+        return args[i]
+    end
+
+    while i <= length(args)
+        arg = args[i]
+        if arg == "--help"
+            usage()
+            exit(0)
+        elseif arg == "--L"
+            opts["L"] = parse(Int, flag_value(arg))
+        elseif arg == "--d"
+            opts["d"] = parse(Int, flag_value(arg))
+        elseif arg == "--zipup-L"
+            opts["zipup_L"] = parse(Int, flag_value(arg))
+        elseif arg == "--chis"
+            opts["chis"] = parse.(Int, split(flag_value(arg), ","))
+        elseif arg == "--warm-up-time"
+            opts["warm_up_time"] = parse(Float64, flag_value(arg))
+        elseif arg == "--measurement-time"
+            opts["measurement_time"] = parse(Float64, flag_value(arg))
+        elseif arg == "--min-samples"
+            opts["min_samples"] = parse(Int, flag_value(arg))
+        elseif arg == "--blas-threads"
+            opts["blas_threads"] = parse(Int, flag_value(arg))
+        else
+            error("unknown argument: $arg")
+        end
+        i += 1  # move past the flag value that was just consumed
+    end
+
+    return opts
+end
+
+function deterministic_value(idx::Int, seed::Int)::ComplexF64  # reproducible fixture entry from modular arithmetic on idx and seed
+    real = ((idx * 17 + seed * 13 + 3) % 97) / 97 - 0.5
+    imag = ((idx * 29 + seed * 7 + 5) % 89) / 89 - 0.5
+    return ComplexF64(real, imag)
+end
+
+function deterministic_itensor(inds::Tuple, seed::Int)::ITensor  # ITensor over `inds` filled from deterministic_value
+    dims = map(dim, inds)
+    data = Vector{ComplexF64}(undef, prod(dims))
+    @inbounds for pos in eachindex(data)
+        data[pos] = deterministic_value(pos - 1, seed)  # zero-based linear position
+    end
+    return ITensor(reshape(data, dims...), inds...)
+end + +function deterministic_mps(sites::Vector{<:Index}, chi::Int, seed_offset::Int)::MPS + nsites = length(sites) + links = [Index(chi, "Link,mps,l=$n,seed=$seed_offset") for n in 1:(nsites - 1)] + tensors = Vector{ITensor}(undef, nsites) + + for site in 1:nsites + inds = + if nsites == 1 + (sites[site],) + elseif site == 1 + (sites[site], links[site]) + elseif site == nsites + (links[site - 1], sites[site]) + else + (links[site - 1], sites[site], links[site]) + end + tensors[site] = deterministic_itensor(inds, seed_offset + site) + end + + return MPS(tensors) +end + +function deterministic_mpo( + input_sites::Vector{<:Index}, + output_sites::Vector{<:Index}, + chi::Int, + seed_offset::Int, +)::MPO + nsites = length(input_sites) + @assert length(output_sites) == nsites + links = [Index(chi, "Link,mpo,l=$n,seed=$seed_offset") for n in 1:(nsites - 1)] + tensors = Vector{ITensor}(undef, nsites) + + for site in 1:nsites + inds = + if nsites == 1 + (input_sites[site], output_sites[site]) + elseif site == 1 + (input_sites[site], output_sites[site], links[site]) + elseif site == nsites + (links[site - 1], input_sites[site], output_sites[site]) + else + (links[site - 1], input_sites[site], output_sites[site], links[site]) + end + tensors[site] = deterministic_itensor(inds, seed_offset + site) + end + + return MPO(tensors) +end + +function run_for_seconds(f; warmup_seconds::Float64, measurement_seconds::Float64, min_samples::Int) + sink = f() + GC.gc() + + warmup_start = time_ns() + while (time_ns() - warmup_start) / 1.0e9 < warmup_seconds + sink = f() + end + GC.gc() + + times_ms = Float64[] + measurement_start = time_ns() + while (time_ns() - measurement_start) / 1.0e9 < measurement_seconds || length(times_ms) < min_samples + start = time_ns() + sink = f() + push!(times_ms, (time_ns() - start) / 1.0e6) + end + + return sink, times_ms +end + +function print_result(case; params, times_ms, value, max_bond) + @printf( + "%s,%s,%d,%.6f,%.6f,%.6f,%.6f,%d,%s\n", + case, + 
params, + length(times_ms), + minimum(times_ms), + median(times_ms), + mean(times_ms), + maximum(times_ms), + max_bond, + repr(value), + ) +end + +function main() + opts = parse_args(ARGS) + BLAS.set_num_threads(opts["blas_threads"]) + + L = opts["L"] + d = opts["d"] + zipup_L = opts["zipup_L"] + chis = opts["chis"] + warmup_seconds = opts["warm_up_time"] + measurement_seconds = opts["measurement_time"] + min_samples = opts["min_samples"] + + println("ITensorMPS TensorTrain-level ops benchmark") + println(" Julia: $(VERSION)") + println(" ITensors: $(Base.pkgversion(ITensors))") + println(" ITensorMPS: $(Base.pkgversion(ITensorMPS))") + println(" threads: Julia=$(Threads.nthreads()) BLAS=$(BLAS.get_num_threads())") + println(" L: $L") + println(" d: $d") + println(" zipup L: $zipup_L") + println(" chis: $(join(chis, ","))") + println(" warm-up time: $warmup_seconds") + println(" measurement time: $measurement_seconds") + println(" min samples: $min_samples") + println() + println("case,params,samples,min_ms,median_ms,mean_ms,max_ms,max_bond,value") + + for chi in chis + sites = [Index(d, "Site,n=$site") for site in 1:L] + bra = deterministic_mps(sites, chi, 0) + ket = deterministic_mps(sites, chi, L) + mps_params = "L_$(L)_chi_$(chi)_d_$(d)" + + inner_value, inner_times = run_for_seconds( + () -> inner(bra, ket); + warmup_seconds, + measurement_seconds, + min_samples, + ) + print_result( + "itensormps_inner_mps"; + params = mps_params, + times_ms = inner_times, + value = inner_value, + max_bond = 0, + ) + + sum_mps, directsum_times = run_for_seconds( + () -> +(bra, ket; alg = "directsum"); + warmup_seconds, + measurement_seconds, + min_samples, + ) + print_result( + "itensormps_directsum_mps"; + params = mps_params, + times_ms = directsum_times, + value = inner(sum_mps, sum_mps), + max_bond = maxlinkdim(sum_mps), + ) + + zipup_sites_in = [Index(d, "ZipIn,n=$site") for site in 1:zipup_L] + zipup_sites_mid = [Index(d, "ZipMid,n=$site") for site in 1:zipup_L] + 
zipup_sites_out = [Index(d, "ZipOut,n=$site") for site in 1:zipup_L] + mpo_a = deterministic_mpo(zipup_sites_in, zipup_sites_mid, chi, 2L) + mpo_b = deterministic_mpo(zipup_sites_mid, zipup_sites_out, chi, 2L + zipup_L) + orthogonalize!(mpo_a, 1) + orthogonalize!(mpo_b, 1) + zipup_params = "L_$(zipup_L)_chi_$(chi)_d_$(d)_maxdim_$(chi)" + + zipup_result, zipup_times = run_for_seconds( + () -> contract(mpo_a, mpo_b; alg = "zipup", maxdim = chi, cutoff = 0.0); + warmup_seconds, + measurement_seconds, + min_samples, + ) + print_result( + "itensormps_zipup_mpo_prepared"; + params = zipup_params, + times_ms = zipup_times, + value = inner(zipup_result, zipup_result), + max_bond = maxlinkdim(zipup_result), + ) + end +end + +main() diff --git a/benchmarks/julia/dump_local_linsolve_inputs.jl b/benchmarks/julia/dump_local_linsolve_inputs.jl new file mode 100644 index 00000000..5667de21 --- /dev/null +++ b/benchmarks/julia/dump_local_linsolve_inputs.jl @@ -0,0 +1,167 @@ +# Dump prepared local linsolve inputs in ITensorMPS-compatible HDF5. +# +# Run: +# BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/dump_local_linsolve_inputs.jl +# +# Optional args: +# BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/dump_local_linsolve_inputs.jl [output_path] [N] [state_bond_dim] [operator_bond_dim] + +import Pkg + +Pkg.activate(@__DIR__) +Pkg.instantiate() + +using HDF5 +using ITensors +using LinearAlgebra +using Printf +using Random + +ITensors.disable_warn_order() + +function parse_positive_int_arg(args::Vector{String}, index::Int, default::Int, name::String)::Int + value = index <= length(args) ? 
parse(Int, args[index]) : default + value > 0 || error("$name must be greater than zero") + return value +end + +function maybe_set_blas_threads_from_env!() + haskey(ENV, "BLAS_NUM_THREADS") || return + nthreads = parse(Int, ENV["BLAS_NUM_THREADS"]) + nthreads > 0 || error("BLAS_NUM_THREADS must be greater than zero") + BLAS.set_num_threads(nthreads) +end + +function state_indices( + site::Int, + nsites::Int, + acted_sites::Vector{Index{Int64}}, + spectator_sites::Vector{Index{Int64}}, + state_links::Vector{Index{Int64}}, +) + sites = Index{Int64}[spectator_sites[site], acted_sites[site]] + if nsites == 1 + return sites + elseif site == 1 + return vcat(sites, Index{Int64}[state_links[site]]) + elseif site == nsites + return vcat(Index{Int64}[state_links[site - 1]], sites) + else + return vcat(Index{Int64}[state_links[site - 1]], sites, Index{Int64}[state_links[site]]) + end +end + +function operator_indices(site::Int, nsites::Int, operator_links::Vector{Index{Int64}}) + if nsites == 1 + return Index{Int64}[] + elseif site == 1 + return Index{Int64}[operator_links[site]] + elseif site == nsites + return Index{Int64}[operator_links[site - 1]] + else + return Index{Int64}[operator_links[site - 1], operator_links[site]] + end +end + +function make_state_mps( + rng::AbstractRNG, + nsites::Int, + state_bond_dim::Int, + acted_sites::Vector{Index{Int64}}, + spectator_sites::Vector{Index{Int64}}, +)::MPS + state_links = [Index(state_bond_dim, "Link,psi,l=$site") for site in 1:(nsites - 1)] + tensors = Vector{ITensor}(undef, nsites) + for site in 1:nsites + indices = state_indices(site, nsites, acted_sites, spectator_sites, state_links) + tensors[site] = random_itensor(rng, indices...) 
+ end + return MPS(tensors) +end + +function make_operator_mpo( + rng::AbstractRNG, + nsites::Int, + operator_bond_dim::Int, + acted_sites::Vector{Index{Int64}}, + spectator_sites::Vector{Index{Int64}}, +)::MPO + operator_links = [Index(operator_bond_dim, "Link,H,l=$site") for site in 1:(nsites - 1)] + tensors = Vector{ITensor}(undef, nsites) + for site in 1:nsites + core_indices = vcat( + operator_indices(site, nsites, operator_links), + Index{Int64}[acted_sites[site], prime(acted_sites[site])], + ) + core = random_itensor(rng, core_indices...) + spectator_identity = delta(spectator_sites[site], prime(spectator_sites[site])) + tensors[site] = core * spectator_identity + end + return MPO(tensors) +end + +function default_output_path(nsites::Int, state_bond_dim::Int, operator_bond_dim::Int)::String + root = normpath(joinpath(@__DIR__, "..", "results")) + return joinpath( + root, + "local_linsolve_inputs_N$(nsites)_b$(state_bond_dim)_o$(operator_bond_dim).h5", + ) +end + +function summarize_mps(name::String, psi::MPS) + site_counts = [length(inds(psi[i])) for i in 1:length(psi)] + println("$name.length = $(length(psi))") + println("$name.maxlinkdim = $(maxlinkdim(psi))") + println("$name.tensor_index_counts = $(site_counts)") + if length(psi) > 0 + println("$name.first_inds = $(inds(psi[1]))") + println("$name.last_inds = $(inds(psi[end]))") + end +end + +function main(args::Vector{String}) + nsites = parse_positive_int_arg(args, 2, 38, "N") + state_bond_dim = parse_positive_int_arg(args, 3, 32, "state_bond_dim") + operator_bond_dim = parse_positive_int_arg(args, 4, 32, "operator_bond_dim") + nsites >= 2 || error("N must be at least 2") + maybe_set_blas_threads_from_env!() + + output_path = length(args) >= 1 ? 
args[1] : default_output_path(nsites, state_bond_dim, operator_bond_dim) + mkpath(dirname(output_path)) + + phys_dim = 2 + seed = 20260518 + rng = MersenneTwister(seed) + acted_sites = [Index(phys_dim, "s=$site") for site in 1:nsites] + spectator_sites = [Index(phys_dim, "q=$site") for site in 1:nsites] + rhs = make_state_mps(rng, nsites, state_bond_dim, acted_sites, spectator_sites) + init = deepcopy(rhs) + operator = make_operator_mpo(rng, nsites, operator_bond_dim, acted_sites, spectator_sites) + operator_as_mps = MPS([operator[i] for i in 1:length(operator)]) + + h5open(output_path, "w") do file + write(file, "operator_as_mps", operator_as_mps) + write(file, "rhs", rhs) + write(file, "init", init) + + params = create_group(file, "params") + write(params, "N", Int64(nsites)) + write(params, "phys_dim", Int64(phys_dim)) + write(params, "state_bond_dim", Int64(state_bond_dim)) + write(params, "operator_bond_dim", Int64(operator_bond_dim)) + write(params, "seed", Int64(seed)) + write(params, "format_note", "operator_as_mps stores Julia MPO site tensors in ITensorMPS MPS schema") + end + + println("=== Dumped local linsolve inputs (Julia/HDF5) ===") + println("output_path = $output_path") + println("N = $nsites") + println("phys_dim = $phys_dim") + println("state_bond_dim = $state_bond_dim") + println("operator_bond_dim = $operator_bond_dim") + summarize_mps("operator_as_mps", operator_as_mps) + summarize_mps("rhs", rhs) + summarize_mps("init", init) +end + +main(ARGS) diff --git a/benchmarks/results/2026-05-18-hdf5-input-dump.md b/benchmarks/results/2026-05-18-hdf5-input-dump.md new file mode 100644 index 00000000..9e4a6c26 --- /dev/null +++ b/benchmarks/results/2026-05-18-hdf5-input-dump.md @@ -0,0 +1,91 @@ +# HDF5 input dump for local linsolve parity + +## Commands + +```bash +BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/dump_local_linsolve_inputs.jl benchmarks/results/local_linsolve_inputs_N8_b4_o4.h5 8 4 4 +cargo run -p tensor4all-hdf5 
--example inspect_mps_inputs --release -- benchmarks/results/local_linsolve_inputs_N8_b4_o4.h5 + +BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/dump_local_linsolve_inputs.jl benchmarks/results/local_linsolve_inputs_N38_b32_o32.h5 38 32 32 +cargo run -p tensor4all-hdf5 --example inspect_mps_inputs --release -- benchmarks/results/local_linsolve_inputs_N38_b32_o32.h5 +``` + +## Findings + +- Julia can write the prepared local operator and states in one + ITensorMPS-compatible HDF5 file. +- The operator is a Julia `MPO`, stored as `operator_as_mps` via + `MPS([H[i] for i in 1:length(H)])`. +- Rust reads `operator_as_mps`, `rhs`, and `init` with + `tensor4all_hdf5::load_mps`. +- For `N=38, state_bond_dim=32, operator_bond_dim=32`, Rust sees: + - `operator_as_mps.length = 38` + - `operator_as_mps.bond_dims = [32, ..., 32]` + - `rhs.length = init.length = 38` + - `rhs.bond_dims = init.bond_dims = [32, ..., 32]` +- Raw site tensors loaded with `load_itensor("operator_as_mps/MPS[i]")` + preserve the Julia HDF5 index order. After `de51179` identified the + unwanted behavior and the follow-up TensorTrain fix, loading the whole object + with `load_mps` also preserves the site tensor index order. The index + identities, prime levels, dimensions, tags, and axis order are preserved. + +## TensorTrain index-order normalization relaxation + +The old `TensorTrain` constructor path permuted site tensor indices into a +chain-friendly convention: + +```text +[left_link, site_indices..., right_link] +``` + +This can be useful for APIs that require a simple chain layout, likely including +conversion to or from `tensor4all-simplett::TensorTrain`. It is not needed at +the ITensor-like `TensorTrain` boundary itself. In particular, HDF5 +interoperability and Julia/Rust parity debugging require `load_mps` to preserve +the raw ITensors.jl site tensor index order. + +Implemented policy: + +- Keep raw `ITensor` HDF5 load/store order-preserving. 
+- Keep `TensorTrain::new`, `TensorTrain::with_ortho`, `TensorTrain::from_treetn`, + and `TensorTrain::set_tensor_checked` order-preserving. +- Keep chain-order normalization explicit and local to operations that actually + require it. The current `norm_squared_fast_path` still applies it only to a + clone before packing sites. +- Do not reorder fit-contraction inputs to satisfy an algorithm precondition. + Fit contraction is covered by a regression with non-chain-ordered site tensor + axes and must rely on index identity/topology rather than axis position. +- If a future SimpleTT conversion or dense evaluation path needs canonical chain + axis order, move the `permuteinds` step into that conversion path or expose an + explicitly named helper such as `into_chain_ordered`. +- Regression tests now cover constructor, `with_ortho`, `from_treetn`, setter, + and HDF5 `load_mps` order preservation. + +## Conjugation optimization policy + +Profiling the local GMRES path showed that repeated complex inner products spend +most of their time in `conj`-then-contract style work on small dense tensors. +The preferred fix is not an `inner_product`-only hand-written storage loop. That +kind of local fast path would make this one workload faster, but it would bypass +the general tensor/contraction path and make AD, layout variants, and backend +behavior harder to keep consistent. + +Policy: + +- Optimize the general traced/contraction path first. +- In `tenferro-rs`, fold `Conj -> DotGeneral` into a conjugated GEMM/dot operand + when the backend and layout can support it, with materializing fallback. +- In `tensor4all-rs`, keep Krylov and TreeTN code expressed through the general + vector-space/tensor operations so they benefit from backend improvements. +- Avoid adding ad hoc `TensorDynLen::inner_product` or `norm` special cases that + only recognize the current QuanticsNEGF workload. 
+- If a lower-level helper is introduced, make it a general backend operation + with clear semantics, tests for real and complex dtypes, and no downstream + reach-through into storage internals. + +## Generated files + +The generated `.h5` files are local benchmark artifacts and are ignored by git: + +- `benchmarks/results/local_linsolve_inputs_N8_b4_o4.h5` +- `benchmarks/results/local_linsolve_inputs_N38_b32_o32.h5` diff --git a/benchmarks/results/2026-05-18-local-linsolve.md b/benchmarks/results/2026-05-18-local-linsolve.md new file mode 100644 index 00000000..15b7d074 --- /dev/null +++ b/benchmarks/results/2026-05-18-local-linsolve.md @@ -0,0 +1,107 @@ +# Prepared Local Linsolve Benchmark + +Date: 2026-05-18 + +Purpose: compare the solve body after operator/RHS/initial-state construction. +Setup time is reported separately and excluded from the prepared solve timing. + +## Commands + +Rust: + +```bash +RAYON_NUM_THREADS=1 cargo run -p tensor4all-treetn --example benchmark_local_linsolve --release -- 8 4 4 1 4 4 0 +RAYON_NUM_THREADS=1 cargo run -p tensor4all-treetn --example benchmark_local_linsolve --release -- 38 32 32 1 10 30 0 +RAYON_NUM_THREADS=1 T4A_PROFILE_CONTRACT=1 cargo run -p tensor4all-treetn --example benchmark_local_linsolve --release -- 38 32 32 1 1 10 0 +``` + +Julia: + +```bash +BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/benchmark_local_linsolve.jl 8 4 4 1 4 4 +BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/benchmark_local_linsolve.jl 8 4 4 1 1 10 +BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/benchmark_local_linsolve.jl 38 32 32 1 1 10 +``` + +## Representative Results + +### Small Case + +| Implementation | N | Bonds | Sweep steps | Solve total | Local operator apps | Apply time | Other solve overhead | +|---|---:|---:|---:|---:|---:|---:|---:| +| Rust | 8 | 4/4 | 14 | 47.0 ms | single step: 6 | single step: 3.1 ms | single step GMRES overhead: 1.0 ms | +| Julia 
`maxiter=4,krylovdim=4` | 8 | 4/4 | 14 | 16.0 ms | 196 | 4.7 ms | replacebond/factorization/orthogonalization: 2.7 ms | +| Julia `maxiter=1,krylovdim=10` | 8 | 4/4 | 14 | 14.1 ms | 154 | 3.3 ms | replacebond/factorization/orthogonalization: 2.4 ms | + +### Larger Case + +| Implementation | N | Bonds | Sweep steps | Solve total | Local operator apps | Apply time | RHS time | Other solve overhead | +|---|---:|---:|---:|---:|---:|---:|---:|---:| +| Rust `gmres_max_restarts=10,gmres_restart_dim=30` before convention fix | 38 | 32/32 | 74 | 6.69 s | single step: 12 | single step: 139.9 ms | single step: 89.7 ms | single step GMRES overhead: 5.4 ms | +| Rust `gmres_max_restarts=1,gmres_restart_dim=10` after KrylovKit convention fix | 38 | 32/32 | 74 | 6.89 s | single step: 12 | single step: 144.0 ms | single step: 91.3 ms | single step GMRES overhead: 5.9 ms | +| Julia `maxiter=1,krylovdim=10` | 38 | 32/32 | 74 | 10.47 s | 814 | 9.85 s | 7.0 ms | 0.30 s | + +`Julia maxiter=10,krylovdim=30` was intentionally interrupted after more than +two minutes before the Rust convention was fixed. KrylovKit defines `maxiter` +as the number of restart cycles, so the maximum number of expansion steps is +roughly `maxiter * krylovdim`. Rust now follows that convention: +`gmres_max_restarts` maps to `GmresOptions::max_restarts`, while +`gmres_restart_dim` maps to the restart cycle length. + +The comparable one-restart case is therefore Julia `maxiter=1,krylovdim=10` +against Rust `gmres_max_restarts=1,gmres_restart_dim=10`. Both perform 74 local +updates and about 814 local operator applications in this benchmark. + +## Finding + +The bottleneck is local projected operator application, not setup. In the +larger Julia prepared solve, `projected apply inside GMRES` accounts for +approximately 94% of solve time. In the Rust single local GMRES measurement, +`ProjectedOperator::apply` accounts for approximately 96% of local GMRES time. 
+ +For comparable local operator application counts, the synthetic prepared +benchmark does not show Rust as slower than Julia. The larger remaining gap to +track in QuanticsNEGF should therefore be problem-specific rank/topology +distribution, vector-space operation overhead, or a difference in the local +operator applications actually performed. + +The isolated Rust single-step numbers include cold environment construction, so +they intentionally overestimate one step inside a full sweep. The full `N=38` +Rust solve is consistent with warm projected apply cost: about 74 local update +steps times about 12 local matvecs per step times about 6 ms per warm apply is +already about 5.3 s, close to the measured 6.69 s before other sweep overhead. + +## Follow-up: 2026-05-20 After Spectator Handling Fix + +Reran the one-thread Rust prepared local linsolve benchmark after fixing +`ProjectedState` reference-link handling and no-site operator spectators. + +```bash +RAYON_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +cargo run -q -p tensor4all-treetn --example benchmark_local_linsolve \ + --release -- 38 32 32 1 1 10 0 +``` + +| Implementation | N | Bonds | Sweep steps | Solve total | Local operator apps | Apply time | RHS time | Other solve overhead | +|---|---:|---:|---:|---:|---:|---:|---:|---:| +| Rust `gmres_max_restarts=1,gmres_restart_dim=10` after spectator fix | 38 | 32/32 | 74 | 4.45 s | single step: 12 | single step: 108.2 ms | single step: 63.9 ms | single step GMRES overhead: 4.3 ms | + +This is faster than the previous comparable Rust row (`6.89 s`). The single +local solve still spends almost all GMRES time in projected local operator +application: about `108.2 ms / 12 = 9.0 ms` per local matvec, consistent with +the isolated cached projected-apply benchmark. 
+ +The heavier diagnostic command + +```bash +RAYON_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +cargo run -q -p tensor4all-treetn --example benchmark_local_linsolve \ + --release -- 38 32 32 1 10 30 0 +``` + +ran the single local GMRES to `300` iterations without convergence +(`apply_count=311`, projected apply inside GMRES `1456.5 ms`) and the full +sweep took `120.0 s`. This setting is therefore a stress test of repeated +local matvecs, not the Julia-comparable one-restart timing. diff --git a/benchmarks/results/2026-05-18-projected-apply.md b/benchmarks/results/2026-05-18-projected-apply.md new file mode 100644 index 00000000..4fa27f0a --- /dev/null +++ b/benchmarks/results/2026-05-18-projected-apply.md @@ -0,0 +1,59 @@ +# Projected Local Apply Benchmark + +Date: 2026-05-18 + +Purpose: isolate the local projected operator apply used by two-site local +linsolve/TDVP-style updates. These timings are for synthetic chains with two +physical index groups per site: one acted index and one spectator index. The +Julia benchmark includes the spectator identity in the MPO, matching +QuanticsNEGF.jl's `add_dummy_indices` layout. 
+ +## Commands + +Rust: + +```bash +RAYON_NUM_THREADS=1 cargo run -p tensor4all-treetn --example benchmark_projected_apply --release -- 38 32 32 3 0 +RAYON_NUM_THREADS=1 cargo run -p tensor4all-treetn --example benchmark_projected_apply --release -- 38 64 64 2 0 +``` + +Julia: + +```bash +BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/benchmark_projected_apply.jl 38 32 32 3 0 +BLAS_NUM_THREADS=1 julia --project=benchmarks/julia benchmarks/julia/benchmark_projected_apply.jl 38 64 64 2 0 +``` + +## Representative Results + +| Implementation | N | State bond | Operator bond | Cold apply | Warm apply mean | Cold repeated mean | +|---|---:|---:|---:|---:|---:|---:| +| Rust `ProjectedOperator::apply` | 38 | 32 | 32 | 70.3 ms | 6.0 ms | 45.0 ms | +| Julia `ProjMPO` | 38 | 32 | 32 | 52.9 ms | 7.7 ms | 73.5 ms | +| Rust `ProjectedOperator::apply` | 38 | 64 | 64 | 564.0 ms | 68.2 ms | 532.8 ms | +| Julia `ProjMPO` | 38 | 64 | 64 | 807.2 ms | 159.4 ms | 759.7 ms | + +Interpretation: both implementations show that the warm local projected apply +is already O(10 ms) at bond 32 and O(100 ms) at bond 64. A long local Dyson +test with thousands of local GMRES matvecs can therefore be slow from this hot +path alone, even when environment caches are effective. 
+ +## Follow-up: 2026-05-20 After Spectator Handling Fix + +Reran the one-thread Rust projected apply benchmark after fixing linsolve +spectator-node handling: + +```bash +RAYON_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +cargo run -q -p tensor4all-treetn --example benchmark_projected_apply \ + --release -- 38 32 32 3 0 +``` + +| Implementation | N | State bond | Operator bond | Cold apply | Warm apply mean | Cold repeated mean | +|---|---:|---:|---:|---:|---:|---:| +| Rust `ProjectedOperator::apply` | 38 | 32 | 32 | 77.4 ms | 7.1 ms | 67.6 ms | + +The warm cached local matvec remains about `7 ms`, consistent with the earlier +O(10 ms) estimate for bond 32. The fix did not introduce a new hot-path cost +in the cached projected apply path. diff --git a/benchmarks/results/2026-05-19-tt-inner-product-bottleneck.md b/benchmarks/results/2026-05-19-tt-inner-product-bottleneck.md new file mode 100644 index 00000000..01e67a18 --- /dev/null +++ b/benchmarks/results/2026-05-19-tt-inner-product-bottleneck.md @@ -0,0 +1,585 @@ +# TensorTrain Inner Product Bottleneck Check + +Date: 2026-05-19 + +This note records a focused check of `TensorTrain::inner` performance. The +benchmark source is `benchmarks/rust/benchmark_tt_ops.rs`, included by +`crates/tensor4all-itensorlike/examples/benchmark_tt_ops.rs`. + +## Main Finding + +`TensorTrain::inner` already conjugates one site tensor at a time and contracts +it immediately. It does not first build a fully conjugated MPS. A full-MPS +`conj()` costs only about 0.18-0.20 ms for `L=32, chi<=16`, while the inner +product costs about 1.3 ms, so whole-MPS conjugation is not the primary +bottleneck in the current Rust implementation. + +The dominant cost for small bond dimensions is fixed per-site contraction +overhead. For `L=32`, the inner product performs 32 tensor conjugations and 63 +small pairwise contractions. 
ITensorMPS.jl is much faster in this regime, which +suggests that tensor4all-rs is paying more overhead per small contraction +(index/spec preparation, tensor wrapper dispatch, and tenferro eager einsum +entry) rather than doing asymptotically more arithmetic. + +## Rust Command + +```bash +RAYON_NUM_THREADS=1 \ +cargo run -p tensor4all-itensorlike --example benchmark_tt_ops --release -- \ + --L 32 --chis 4,8,16 --warm-up-time 0.05 --measurement-time 0.15 \ + --min-samples 5 --no-zipup +``` + +Additional runs used `--chis 32,64` and `--L 8/64 --chis 4`. + +## Julia Command + +```bash +BLAS_NUM_THREADS=1 \ +julia --project=benchmarks/julia benchmarks/julia/benchmark_tt_ops.jl \ + --L 32 --chis 4,8,16 --warm-up-time 0.05 --measurement-time 0.15 \ + --min-samples 5 --blas-threads 1 +``` + +Additional runs used `--chis 32,64`. + +## Median Timings + +| Case | Rust median ms | Julia median ms | Ratio | +| --- | ---: | ---: | ---: | +| `L=32, chi=4` inner | 1.29 | 0.205 | 6.3x | +| `L=32, chi=8` inner | 1.30 | 0.225 | 5.8x | +| `L=32, chi=16` inner | 1.35 | 0.353 | 3.8x | +| `L=32, chi=32` inner | 2.07 | 1.10 | 1.9x | +| `L=32, chi=64` inner | 6.89 | 5.44 | 1.3x | + +The gap shrinks as bond dimension grows, supporting the fixed-overhead +hypothesis. 
+ +## Rust Variants + +For `L=32`: + +| Variant | chi=4 ms | chi=8 ms | chi=16 ms | Meaning | +| --- | ---: | ---: | ---: | --- | +| current `TensorTrain::inner` | 1.29 | 1.30 | 1.35 | production path | +| sitewise pair, no `sim_internal_inds` | 1.29 | 1.29 | 1.36 | same chain algorithm without internal-index simulation | +| sitewise 3-array contract, no simulation | 1.17 | 1.12 | 1.22 | combines `env * conj(A_i) * B_i` per site | +| sitewise binary through generic `contract(&[a,b])` | 1.53 | 1.62 | 1.57 | slower, generic binary entry has too much overhead | +| full MPS `conj()` only | 0.18 | 0.18 | 0.19 | not dominant | +| per-site `conj()` loop only | 0.19 | 0.20 | 0.21 | not dominant | + +For `chi=32/64`, the 3-array variant is neutral or slower, so changing +`TensorTrain::inner` to that unconditionally would not be a clean improvement. + +## Next Optimization Targets + +- Reduce fixed overhead in `TensorDynLen` pairwise contractions for small dense + tensors. +- Avoid materializing conjugated payloads by carrying a conjugation flag down + into tenferro where possible. +- Keep the MPS algorithm sitewise; building a conjugated MPS is unnecessary. +- Do not introduce a TensorTrain-specific shortcut unless the same improvement + can be expressed as a general tensor contraction optimization. + +## Follow-up: tenferro Raw Comparison + +After adding an experimental eager-einsum contraction-plan cache in local +`tenferro-rs`, `TensorTrain::inner` improved substantially for small bond +dimensions. 
The remaining gap was checked against the in-tree tenferro +benchmarks: + +```bash +RAYON_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +cargo bench -p tenferro --bench mps_inner_product_eager -- eval_local_path \ + --warm-up-time 0.05 --measurement-time 0.25 --sample-size 10 + +RAYON_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +cargo bench -p tenferro --bench mps_inner_product -- eval_only \ + --warm-up-time 0.05 --measurement-time 0.25 --sample-size 10 + +RAYON_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +cargo run -p tensor4all-itensorlike --example benchmark_tt_ops --release -- \ + --L 32 --chis 4,8,16,32,64 --warm-up-time 0.05 \ + --measurement-time 0.25 --min-samples 10 --inner-only +``` + +| Case (`L=32,d=2`) | tensor4all `TensorTrain::inner` | tenferro eager local path | tenferro traced eval | Remaining gap vs eager | +| --- | ---: | ---: | ---: | ---: | +| `chi=4` | 0.731 ms | 0.327 ms | 0.313 ms | 2.24x | +| `chi=8` | 0.744 ms | 0.327 ms | 0.322 ms | 2.28x | +| `chi=16` | 0.876 ms | 0.361 ms | 0.347 ms | 2.43x | +| `chi=32` | 1.721 ms | 1.161 ms | 1.066 ms | 1.48x | +| `chi=64` | 6.529 ms | 6.198 ms | 5.912 ms | 1.05x | + +The same fixture was also measured with ITensorMPS.jl: + +```bash +JULIA_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +julia --startup-file=no --project=benchmarks/julia \ + ../tenferro-rs/scripts/bench-itensormps-inner-product.jl \ + --L 32 --d 2 --chis 4,8,16,32,64 --warm-up-time 0.05 \ + --measurement-time 0.25 --min-samples 10 --blas-threads 1 +``` + +Julia versions in this run: + +- Julia 1.12.5 +- ITensors 0.6.23 +- ITensorMPS 0.2.6 + +| Case (`L=32,d=2`) | tensor4all `TensorTrain::inner` | tenferro eager local path | tenferro traced eval | ITensorMPS.jl `inner` | tensor4all / ITensorMPS | +| --- | 
---: | ---: | ---: | ---: | ---: | +| `chi=4` | 0.731 ms | 0.327 ms | 0.313 ms | 0.199 ms | 3.67x | +| `chi=8` | 0.744 ms | 0.327 ms | 0.322 ms | 0.232 ms | 3.22x | +| `chi=16` | 0.876 ms | 0.361 ms | 0.347 ms | 0.303 ms | 2.89x | +| `chi=32` | 1.721 ms | 1.161 ms | 1.066 ms | 1.067 ms | 1.61x | +| `chi=64` | 6.529 ms | 6.198 ms | 5.912 ms | 6.912 ms | 0.94x | + +The high-bond case is now essentially arithmetic-bound and close to tenferro. +The small-bond gap is a fixed-cost problem. + +Additional profiling of `TensorTrain::inner` showed `sim_internal_inds()` is not +the bottleneck: even with profiling overhead, it was only about 0.01-0.04 ms for +`L=32`. The visible costs are: + +- explicit per-site `conj()` payload materialization, about 0.2-0.5 ms depending + on measurement noise and bond dimension; +- generic `TensorDynLen::contract_pair` dispatch into eager einsum for 63 tiny + contractions. + +The tenferro eager benchmark avoids both costs: it calls +`dot_general_with_conj` directly, so conjugation is passed as a flag into the +backend and the generic einsum label-analysis path is bypassed. + +Clean design implication: the next upstream-quality improvement should be a +general tensor contraction API that can carry per-operand conjugation flags +through the normal `TensorDynLen`/tenferro path. `TensorTrain::inner` can then +use that general API instead of materializing `bra.tensor(i).conj()`. This is +not a TensorTrain-specific native shortcut; it is the same semantic operation +expressed without an unnecessary payload copy. 
+ +## Follow-up: Read-only Tensor Input + +After introducing the general read-only tensor input path (`TensorRead` / +`TensorView`) in local `tenferro-rs` and wiring tensor4all compact payload +contractions through it, the same one-thread TT inner benchmark was rerun: + +```bash +RAYON_NUM_THREADS=1 cargo run -q -p tensor4all-itensorlike \ + --example benchmark_tt_ops --release -- \ + --L 32 --zipup-L 10 --chis 4,8,16,32,64 \ + --measurement-time 0.25 --min-samples 10 --inner-only +``` + +| Case (`L=32,d=2`) | tensor4all `TensorTrain::inner` median | Previous median | +| --- | ---: | ---: | +| `chi=4` | 0.708 ms | 0.731 ms | +| `chi=8` | 0.710 ms | 0.744 ms | +| `chi=16` | 0.803 ms | 0.876 ms | +| `chi=32` | 1.725 ms | 1.721 ms | +| `chi=64` | 6.465 ms | 6.529 ms | + +The read-only path removes the avoidable compact-payload copy for contiguous +storage inputs, but it does not eliminate the remaining small-bond fixed cost: +`TensorTrain::inner` still performs many tiny generic contractions and still +materializes the conjugated site tensors. 
+ +The non-AD `TensorDynLen` operation benchmark was also updated for the current +`contract_pair` API and rerun: + +```bash +RAYON_NUM_THREADS=1 cargo run -q -p tensor4all-core \ + --example benchmark_tensor_ops --release -- 20000 6 2 2 6 +``` + +| Operation | Time | Per call | +| --- | ---: | ---: | +| `inner_product` | 0.313624 s | 15.681 us | +| `norm` | 0.273433 s | 13.672 us | +| `axpby` | 0.019593 s | 0.980 us | +| `conj_contract_sum` | 0.270288 s | 13.514 us | + +## Follow-up: Current tenferro and Julia Comparison + +The same `L=32,d=2` benchmark was rerun against current local `tenferro-rs` +and ITensorMPS.jl with one thread: + +```bash +RAYON_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +cargo bench -p tenferro --bench mps_inner_product_eager -- \ + eval_local_path --warm-up-time 0.05 --measurement-time 0.25 --sample-size 10 + +RAYON_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +cargo bench -p tenferro --bench mps_inner_product -- \ + eval_only --warm-up-time 0.05 --measurement-time 0.25 --sample-size 10 + +JULIA_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +julia --startup-file=no --project=benchmarks/julia \ + ../tenferro-rs/scripts/bench-itensormps-inner-product.jl \ + --L 32 --d 2 --chis 4,8,16,32,64 \ + --warm-up-time 0.05 --measurement-time 0.25 \ + --min-samples 10 --blas-threads 1 +``` + +| Case (`L=32,d=2`) | tensor4all `TensorTrain::inner` median | tenferro eager local path | tenferro traced eval | ITensorMPS.jl `inner` | tensor4all / Julia | +| --- | ---: | ---: | ---: | ---: | ---: | +| `chi=4` | 0.708 ms | 0.333 ms | 0.301 ms | 0.194 ms | 3.64x | +| `chi=8` | 0.710 ms | 0.328 ms | 0.309 ms | 0.219 ms | 3.24x | +| `chi=16` | 0.803 ms | 0.340 ms | 0.342 ms | 0.303 ms | 2.65x | +| `chi=32` | 1.725 ms | 1.050 ms | 0.972 ms | 1.087 ms | 1.59x | +| 
`chi=64` | 6.465 ms | 6.090 ms | 5.859 ms | 6.866 ms | 0.94x | + +Interpretation: the high-bond case is fine; tensor4all is now slightly faster +than Julia at `chi=64`. The remaining gap is concentrated at small bond +dimension, where fixed costs dominate. The gap to tenferro's direct eager path +is still about 2.1-2.4x for `chi <= 16`, so the next target is not arithmetic +throughput but avoiding the tiny-contraction overhead and materialized +per-site conjugation. + +## Follow-up: tenferro Binary Einsum Fast Path + +Added a conservative non-AD binary einsum fast path in local `tenferro-rs`. +For two-input contractions with unique labels and at least one shared +contracting label, eager einsum now bypasses contraction-tree planning and the +generic `HashMap`/`HashSet` label-analysis path, directly lowering to +`dot_general_read`. Repeated-label, one-sided reduction, and outer-product +cases still use the generic path. + +The same one-thread tensor4all benchmark was rerun: + +```bash +RAYON_NUM_THREADS=1 cargo run -q -p tensor4all-itensorlike \ + --example benchmark_tt_ops --release -- \ + --L 32 --zipup-L 10 --chis 4,8,16,32,64 \ + --measurement-time 0.25 --min-samples 10 --inner-only +``` + +| Case (`L=32,d=2`) | After binary fast path | Previous read-only path | +| --- | ---: | ---: | +| `chi=4` | 0.685 ms | 0.708 ms | +| `chi=8` | 0.712 ms | 0.710 ms | +| `chi=16` | 0.768 ms | 0.803 ms | +| `chi=32` | 1.646 ms | 1.725 ms | +| `chi=64` | 6.343 ms | 6.465 ms | + +This helps, but does not remove the small-bond fixed-cost gap. The remaining +visible costs are still per-site conjugation materialization and many tiny +contraction calls. The fast path is therefore useful and generally clean, but +not the whole fix. + +## Follow-up: Overhead Breakdown + +Added a benchmark-only variant that precomputes the conjugated bra site tensors +outside the timed loop and then runs the same sitewise `TensorDynLen` +`contract_pair` sequence. 
This isolates explicit conjugation materialization +from the contraction wrapper overhead: + +```bash +RAYON_NUM_THREADS=1 cargo run -q -p tensor4all-itensorlike \ + --example benchmark_tt_ops --release -- \ + --L 32 --zipup-L 10 --chis 4,8,16,32,64 \ + --warm-up-time 0.5 --measurement-time 0.25 \ + --min-samples 10 --inner-only +``` + +| Case (`L=32,d=2`) | Current `TensorTrain::inner` | Preconjugated sitewise pair | Difference | +| --- | ---: | ---: | ---: | +| `chi=4` | 0.700 ms | 0.473 ms | 0.227 ms | +| `chi=8` | 0.710 ms | 0.483 ms | 0.226 ms | +| `chi=16` | 0.780 ms | 0.562 ms | 0.218 ms | +| `chi=32` | 1.668 ms | 1.429 ms | 0.240 ms | +| `chi=64` | 6.407 ms | 5.916 ms | 0.491 ms | + +The same run confirmed that direct tenferro eager local path for `chi=4` is +about `0.337 ms`. Therefore the `chi=4` gap decomposes roughly as: + +- explicit per-site `conj()` materialization: about `0.23 ms`; +- remaining `TensorDynLen::contract_pair` wrapper/index/rebuild overhead over + direct dot-general: about `0.14 ms`; +- arithmetic/backend GEMM itself is not the bottleneck at small bond dimension. + +Additional aggregated profiling with `T4A_PROFILE_PAIRWISE_CONTRACT=1` showed +the non-GEMM `TensorDynLen` wrapper work is small per call but repeated 63 +times: `prepare_contraction`, result axis-class computation, binary subscript +construction, and result `TensorDynLen` reconstruction together account for the +same order as the remaining gap after preconjugation. + +Conclusion: the true small-bond overhead is not contraction-path search and not +GEMM. It is the combination of (1) materialized conjugation and (2) repeatedly +calling the fully generic indexed `TensorDynLen::contract_pair` path for a known +MPS local contraction. + +## Follow-up: Operand-Level Conjugation + +Implemented pairwise contraction options with operand-level conjugation: +`contract_pair_with_operand_options(lhs, rhs, PairwiseContractionOptions)`. 
+`TensorTrain::inner` now uses `lhs_conj` / `rhs_conj` flags instead of +materializing `tensor.conj()` for each bra site. + +`T4A_PROFILE_TT_INNER=1` confirms `conj_ms=0.000000` in the new +`TensorTrain::inner` path. + +The same one-thread benchmark now gives: + +| Case (`L=32,d=2`) | New `TensorTrain::inner` | Previous `TensorTrain::inner` | Preconjugated baseline | +| --- | ---: | ---: | ---: | +| `chi=4` | 0.461 ms | 0.700 ms | 0.473 ms | +| `chi=8` | 0.474 ms | 0.710 ms | 0.483 ms | +| `chi=16` | 0.460 ms | 0.780 ms | 0.562 ms | +| `chi=32` | 1.517 ms | 1.668 ms | 1.429 ms | +| `chi=64` | 6.218 ms | 6.407 ms | 5.916 ms | + +This removes the materialized-conjugation cost from the real `inner` path. +The remaining small-bond gap to direct tenferro eager local contraction is now +mostly the generic `TensorDynLen` pairwise wrapper/index/result-rebuild cost. + +## Follow-up: 2026-05-20 After Linsolve Spectator Fix + +After the TreeTN linsolve spectator-node fixes, the one-thread +`TensorTrain::inner` benchmark was rerun: + +```bash +RAYON_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +cargo run -q -p tensor4all-itensorlike --example benchmark_tt_ops \ + --release -- --L 32 --zipup-L 10 --chis 4,8,16,32,64 \ + --warm-up-time 0.5 --measurement-time 0.25 \ + --min-samples 10 --inner-only +``` + +| Case (`L=32,d=2`) | Current `TensorTrain::inner` median | Preconjugated sitewise pair median | +| --- | ---: | ---: | +| `chi=4` | 0.464 ms | 0.457 ms | +| `chi=8` | 0.476 ms | 0.457 ms | +| `chi=16` | 0.449 ms | 0.425 ms | +| `chi=32` | 1.508 ms | 1.495 ms | +| `chi=64` | 6.233 ms | 6.232 ms | + +The current production path is now essentially at the preconjugated baseline, +which confirms that operand-level conjugation removed the explicit per-site +conjugation materialization cost from `TensorTrain::inner`. 
+ +The focused profiling command was also rerun for `L=32, chi=4`: + +```bash +T4A_PROFILE_TT_INNER=1 RAYON_NUM_THREADS=1 \ +cargo run -q -p tensor4all-itensorlike --example benchmark_tt_ops \ + --release -- --L 32 --zipup-L 10 --chis 4 \ + --warm-up-time 0 --measurement-time 0 --min-samples 1 --inner-only +``` + +The profiled production path still reports `conj_ms=0.000000`; profiling +overhead is visible in single-sample timings, and `contract_ms` remains the +dominant component. The linsolve changes therefore did not reintroduce +materialized site-tensor conjugation into `TensorTrain::inner`. + +### Current Comparison With Direct tenferro And ITensorMPS.jl + +Direct tenferro eager local path was rerun with: + +```bash +RAYON_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +cargo bench -p tenferro --bench mps_inner_product_eager -- \ + eval_local_path --warm-up-time 0.05 --measurement-time 0.25 \ + --sample-size 10 +``` + +ITensorMPS.jl was rerun with: + +```bash +JULIA_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +julia --startup-file=no --project=benchmarks/julia \ + ../tenferro-rs/scripts/bench-itensormps-inner-product.jl \ + --L 32 --d 2 --chis 4,8,16,32,64 \ + --warm-up-time 0.05 --measurement-time 0.25 \ + --min-samples 10 --blas-threads 1 +``` + +| Case (`L=32,d=2`) | tensor4all `TensorTrain::inner` | direct tenferro eager | ITensorMPS.jl `inner` | tensor4all / tenferro | tensor4all / Julia | +| --- | ---: | ---: | ---: | ---: | ---: | +| `chi=4` | 0.464 ms | 0.328 ms | 0.208 ms | 1.41x | 2.23x | +| `chi=8` | 0.476 ms | 0.313 ms | 0.228 ms | 1.52x | 2.09x | +| `chi=16` | 0.449 ms | 0.349 ms | 0.318 ms | 1.29x | 1.41x | +| `chi=32` | 1.508 ms | 1.049 ms | 1.124 ms | 1.44x | 1.34x | +| `chi=64` | 6.233 ms | 6.162 ms | 6.878 ms | 1.01x | 0.91x | + +The operand-level conjugation fix moves tensor4all close to direct tenferro at +large bond dimension 
and removes the old explicit-conjugation gap. The +remaining small-bond gap is still fixed overhead above direct tenferro's local +path, plus ITensorMPS.jl's particularly low overhead for tiny contractions. + +### Follow-up: TensorDynLen Wrapper/Rebuild Fixed-Cost Breakdown + +The benchmark now includes two same-process raw tenferro baselines using the +same boundary-rank convention as `TensorTrain`: + +- `tenferro_raw_eager_inner_t4a_shapes`: direct eager `dot_general_with_conj` + on the same MPS tensor shapes, bypassing `TensorDynLen` indices/storage; +- `tenferro_raw_eager_inner_t4a_shapes_snapshot_outputs`: the same direct + eager path, but cloning each intermediate output tensor to approximate the + `TensorDynLen` output-storage snapshot. + +One-thread command: + +```bash +RAYON_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +cargo run -q -p tensor4all-itensorlike --example benchmark_tt_ops \ + --release -- --L 32 --zipup-L 10 --chis 4,8,16,32,64 \ + --warm-up-time 0.5 --measurement-time 0.25 \ + --min-samples 10 --inner-only +``` + +| Case (`L=32,d=2`) | `TensorDynLen` `inner` | raw tenferro same shapes | raw tenferro + output clone | clone-only delta | wrapper gap after clone | +| --- | ---: | ---: | ---: | ---: | ---: | +| `chi=4` | 0.461 ms | 0.307 ms | 0.318 ms | 0.011 ms | 0.143 ms | +| `chi=8` | 0.481 ms | 0.309 ms | 0.332 ms | 0.023 ms | 0.149 ms | +| `chi=16` | 0.445 ms | 0.333 ms | 0.334 ms | 0.001 ms | 0.111 ms | +| `chi=32` | 1.494 ms | 1.138 ms | 1.183 ms | 0.045 ms | 0.311 ms | +| `chi=64` | 6.277 ms | 6.030 ms | 6.283 ms | 0.253 ms | -0.005 ms | + +The raw tenferro values are intentionally measured in the tensor4all benchmark +binary, so this comparison avoids Criterion/JIT/process differences. 
The +values confirm that `TensorDynLen` adds a real small-bond fixed cost, but the +large payload-copy hypothesis is not supported: + +- input `Storage -> NativeTensor` materialization does not appear in the + profiled `inner` path; `eager_cache` is hit; +- intermediate output snapshot copying is real, but only accounts for about + `0.01-0.02 ms` at `chi=4/8`; +- at `chi=64`, output copying becomes visible (`~0.25 ms`), but arithmetic + dominates and the full path is already essentially raw-tenferro speed. + +The profiled `TensorDynLen` pairwise contraction path for a single +`L=32, chi=4` inner product reports: + +| Component | Calls | Total | Per call | Bytes | +| --- | ---: | ---: | ---: | ---: | +| `dot_general_execute` | 63 | 0.987 ms | 15.7 us | 0 | +| `from_inner_axis_classes` | 63 | 0.091 ms | 1.45 us | 0 | +| `result_axis_classes` | 63 | 0.043 ms | 0.69 us | 0 | +| `operand_indices` | 126 | 0.036 ms | 0.28 us | 0 | +| `prepare_contraction` | 63 | 0.027 ms | 0.43 us | 0 | +| `from_inner_storage_snapshot` | 63 | 0.020 ms | 0.32 us | 23,440 | +| `as_native` | 126 | 0.010 ms | 0.08 us | 0 | +| `from_inner_eager_cache` | 63 | 0.010 ms | 0.16 us | 0 | + +For `chi=8/32/64`, `from_inner_storage_snapshot` copies `93,456` / +`1,491,984` / `5,965,840` bytes, respectively. Therefore the unnecessary-copy +candidate to remove is specifically the output storage snapshot in +`TensorDynLen::from_inner_with_axis_classes`; it is not the small-`chi` dominant +cost, but it is a real bandwidth cost at larger bond dimensions and it +duplicates the already-owned `EagerTensor` kept in `eager_cache`. 
+ +ITensorMPS.jl on the same deterministic MPS gives: + +| Case (`L=32,d=2`) | ITensorMPS.jl `inner` | +| --- | ---: | +| `chi=4` | 0.207 ms | +| `chi=8` | 0.265 ms | +| `chi=16` | 0.331 ms | +| `chi=32` | 0.954 ms | +| `chi=64` | 5.543 ms | + +Current interpretation: + +- ITensorMPS.jl has much lower tiny-contraction fixed overhead than the current + `TensorDynLen` path and even lower than the current direct tenferro eager + path for `chi=4/8`; +- the residual tensor4all overhead is mostly the repeated generic + index/wrapper/rebuild path around 63 tiny contractions, plus tenferro eager + per-call execution/session overhead; +- the next clean optimization candidate is to make `TensorDynLen` results able + to keep the eager payload as the primary representation and lazily build + `Storage` only when a storage-backed operation actually requires it. That + would remove the duplicated output payload copy generally, without adding an + MPS-specific fast path. + +### Follow-up: CPU-Only Eager-Dense TensorDynLen Payload + +Implemented a CPU-only first step toward `TensorDynLen = EagerTensor + layout + +indices`: dense eager contraction results now store the tenferro +`EagerTensor` directly instead of immediately snapshotting into +`Storage`. `Storage` remains available as a materialized compatibility +snapshot through `storage()` / `to_storage()`, but dense hot-path contraction +intermediates no longer carry both copies. + +The implementation keeps non-dense/structured results on the existing storage +path for now, and preserves the old eager cache there so AD through diagonal SVD +singular values continues to work. 
+ +One-thread release benchmark: + +```bash +RAYON_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 OPENBLAS_NUM_THREADS=1 \ +OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ +cargo run -q -p tensor4all-itensorlike --example benchmark_tt_ops \ + --release -- --L 32 --zipup-L 10 --chis 4,8,16,32,64 \ + --warm-up-time 0.5 --measurement-time 0.25 \ + --min-samples 10 --inner-only +``` + +| Case (`L=32,d=2`) | tensor4all after eager-dense payload | tensor4all before | direct tenferro same shapes | ITensorMPS.jl `inner` | +| --- | ---: | ---: | ---: | ---: | +| `chi=4` | 0.444 ms | 0.461 ms | 0.312 ms | 0.207 ms | +| `chi=8` | 0.458 ms | 0.481 ms | 0.315 ms | 0.265 ms | +| `chi=16` | 0.396 ms | 0.445 ms | 0.332 ms | 0.331 ms | +| `chi=32` | 1.244 ms | 1.494 ms | 1.142 ms | 0.954 ms | +| `chi=64` | 6.150 ms | 6.277 ms | 5.981 ms | 5.543 ms | + +The small-bond gap is still mostly generic wrapper/per-call overhead, not payload +copy. The larger-bond cases improve more because the removed output snapshot +copy was real bandwidth work. + +Validation: + +```bash +cargo test --release -q -p tensor4all-core +``` + +passed. + +### Follow-up: SmallVec Contraction Preparation + +Implemented SmallVec-backed contraction preparation in `tensor4all-core`: + +- `ContractionSpec` now stores axes, result indices, and result dimensions in + `SmallVec<[T; 8]>`. +- Small contractions use the ITensor-like nested-loop matcher with no hash + table. +- Larger contractions fall back to a hash map keyed by `IndexLike::id()`, with + `is_contractable` still used for the final dimension/prime/direction check. +- Result construction uses boolean axis flags instead of repeated + `axes.contains(...)` scans. 
+ +The pairwise profile confirms that preparation is not the current bottleneck: + +| Section (`L=32,d=2,chi=4`) | Calls | Total | Per call | +| --- | ---: | ---: | ---: | +| `prepare_contraction` | 63 | 0.027 ms | 0.43 us | +| `result_axis_classes` | 63 | 0.035 ms | 0.55 us | +| `dot_general_execute` | 63 | 2.619 ms | 41.58 us | + +The benchmark run after this change was globally slower than the earlier table, +including the direct tenferro baseline, so it should not be used as an absolute +performance regression measurement. It does show that SmallVec cleanup removes +heap-oriented preparation code without moving the dominant cost; the remaining +gap is still in per-call eager contraction/wrapper execution. + +Validation: + +```bash +cargo test --release -q -p tensor4all-core +``` + +passed. diff --git a/benchmarks/results/2026-05-19-tt-ops.md b/benchmarks/results/2026-05-19-tt-ops.md new file mode 100644 index 00000000..5f3a8af7 --- /dev/null +++ b/benchmarks/results/2026-05-19-tt-ops.md @@ -0,0 +1,55 @@ +# TensorTrain Operation Benchmarks + +Date: 2026-05-19 + +This records the initial tensor4all-rs vs ITensorMPS.jl benchmark harness for +TensorTrain-level operations: + +- MPS inner product +- strict direct-sum MPS addition +- prepared MPO x MPO zipup contraction + +## Commands + +Rust: + +```bash +RAYON_NUM_THREADS=1 \ +cargo run -p tensor4all-itensorlike --example benchmark_tt_ops --release -- \ + --L 32 --zipup-L 10 --chis 4,8 --warm-up-time 0.1 --measurement-time 0.2 --min-samples 3 +``` + +Julia: + +```bash +BLAS_NUM_THREADS=1 \ +julia --project=benchmarks/julia benchmarks/julia/benchmark_tt_ops.jl \ + --L 32 --zipup-L 10 --chis 4,8 --warm-up-time 0.1 --measurement-time 0.2 --min-samples 3 +``` + +## Local Output + +These were short smoke/performance runs, not final publication numbers. 
+ +| Case | Params | Rust median ms | Julia median ms | Notes | +| --- | --- | ---: | ---: | --- | +| inner MPS | `L=32, chi=4, d=2` | 1.110 | 0.197 | Checksums match | +| directsum MPS | `L=32, chi=4, d=2` | 0.403 | 0.448 | Checksums match | +| zipup MPO prepared | `L=10, chi=4, d=2, maxdim=4` | 1.231 | 0.780 | Same benchmark shape, approximate zipup checksums close | +| inner MPS | `L=32, chi=8, d=2` | 1.151 | 0.224 | Checksums match | +| directsum MPS | `L=32, chi=8, d=2` | 1.149 | 0.584 | Checksums match up to numerical imaginary noise | +| zipup MPO prepared | `L=10, chi=8, d=2, maxdim=8` | 2.297 | 1.963 | Same benchmark shape, approximate zipup checksums close | + +## Notes + +- The Rust benchmark source is `benchmarks/rust/benchmark_tt_ops.rs`, included + by `crates/tensor4all-itensorlike/examples/benchmark_tt_ops.rs`. +- The Julia benchmark source is `benchmarks/julia/benchmark_tt_ops.jl`. +- `inner` and `directsum` use matching deterministic Complex64 MPS fixtures and + are suitable for checksum comparison. +- `zipup` uses deterministic MPO fixtures and the same `maxdim=chi` setting. + Since both libraries run approximate zipup implementations, the checksum is a + sanity signal rather than an exact parity assertion. +- The run also required updating tensor4all-rs to tenferro-rs' public API + cleanup: eager einsum now uses `tenferro::eager_tensor::einsum`, and traced + einsum uses `tenferro::traced_tensor::{einsum_with, EinsumOptimize}`. diff --git a/benchmarks/rust/benchmark_local_linsolve.rs b/benchmarks/rust/benchmark_local_linsolve.rs new file mode 100644 index 00000000..2f3ca247 --- /dev/null +++ b/benchmarks/rust/benchmark_local_linsolve.rs @@ -0,0 +1,398 @@ +// Benchmark isolated local GMRES and full two-site square_linsolve sweeps. 
+// +// Run: +// RAYON_NUM_THREADS=1 cargo run -p tensor4all-treetn --example benchmark_local_linsolve --release +// +// Optional args: +// RAYON_NUM_THREADS=1 cargo run -p tensor4all-treetn --example benchmark_local_linsolve --release -- [N] [state_bond_dim] [operator_bond_dim] [nsweeps] [gmres_max_restarts] [gmres_restart_dim] [step_index] + +use std::cell::{Cell, RefCell}; +use std::collections::{HashMap, HashSet}; +use std::hint::black_box; +use std::rc::Rc; +use std::time::{Duration, Instant}; + +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use tensor4all_core::{ + index::{DynId, Index}, + krylov::{gmres, GmresOptions}, + print_and_reset_contract_profile, reset_contract_profile, + AnyScalar, DynIndex, IndexLike, SvdTruncationPolicy, TensorContractionLike, TensorDynLen, + TensorIndex, +}; +use tensor4all_treetn::{ + square_linsolve, CanonicalizationOptions, IndexMapping, LinsolveOptions, LocalUpdateSweepPlan, + ProjectedOperator, ProjectedState, TreeTN, +}; + +fn make_node_name(i: usize) -> String { + format!("site{i}") +} + +fn unique_dyn_index(used: &mut HashSet, dim: usize, rng: &mut StdRng) -> DynIndex { + loop { + let id = DynId(rng.random()); + if used.insert(id) { + return Index::new(id, dim); + } + } +} + +fn chain_node_indices(n: usize, i: usize, bonds: &[DynIndex], sites: &[DynIndex]) -> Vec { + if n == 1 { + vec![sites[i].clone()] + } else if i == 0 { + vec![sites[i].clone(), bonds[i].clone()] + } else if i + 1 == n { + vec![bonds[i - 1].clone(), sites[i].clone()] + } else { + vec![bonds[i - 1].clone(), sites[i].clone(), bonds[i].clone()] + } +} + +fn create_state_chain( + n: usize, + state_bond_dim: usize, + acted_sites: &[DynIndex], + spectator_sites: &[DynIndex], + used_ids: &mut HashSet, + rng: &mut StdRng, +) -> anyhow::Result> { + let mut tree = TreeTN::::new(); + let bonds: Vec<_> = (0..n.saturating_sub(1)) + .map(|_| unique_dyn_index(used_ids, state_bond_dim, rng)) + .collect(); + + let mut nodes = Vec::with_capacity(n); + for (i, 
spectator_site) in spectator_sites.iter().enumerate().take(n) { + let mut indices = chain_node_indices(n, 
i, &bonds, acted_sites); + indices.insert(0, spectator_site.clone()); + let tensor = TensorDynLen::random::(rng, indices)?; + let node = tree.add_tensor(make_node_name(i), tensor)?; + nodes.push(node); + } + + for i in 0..n.saturating_sub(1) { + tree.connect(nodes[i], &bonds[i], nodes[i + 1], &bonds[i])?; + } + + Ok(tree) +} + +#[allow(clippy::type_complexity)] +fn create_operator_chain( + n: usize, + phys_dim: usize, + operator_bond_dim: usize, + acted_sites: &[DynIndex], + used_ids: &mut HashSet, + rng: &mut StdRng, +) -> anyhow::Result<( + TreeTN, + HashMap>, + HashMap>, +)> { + let mut tree = TreeTN::::new(); + let bonds: Vec<_> = (0..n.saturating_sub(1)) + .map(|_| unique_dyn_index(used_ids, operator_bond_dim, rng)) + .collect(); + let s_in: Vec<_> = (0..n) + .map(|_| unique_dyn_index(used_ids, phys_dim, rng)) + .collect(); + let s_out: Vec<_> = (0..n) + .map(|_| unique_dyn_index(used_ids, phys_dim, rng)) + .collect(); + + let mut input_mapping = HashMap::new(); + let mut output_mapping = HashMap::new(); + let mut nodes = Vec::with_capacity(n); + for i in 0..n { + let indices = if n == 1 { + vec![s_out[i].clone(), s_in[i].clone()] + } else if i == 0 { + vec![s_out[i].clone(), s_in[i].clone(), bonds[i].clone()] + } else if i + 1 == n { + vec![bonds[i - 1].clone(), s_out[i].clone(), s_in[i].clone()] + } else { + vec![ + bonds[i - 1].clone(), + s_out[i].clone(), + s_in[i].clone(), + bonds[i].clone(), + ] + }; + let tensor = TensorDynLen::random::(rng, indices)?; + let name = make_node_name(i); + let node = tree.add_tensor(name.clone(), tensor)?; + nodes.push(node); + + input_mapping.insert( + name.clone(), + IndexMapping { + true_index: acted_sites[i].clone(), + internal_index: s_in[i].clone(), + }, + ); + output_mapping.insert( + name, + IndexMapping { + true_index: acted_sites[i].clone(), + internal_index: s_out[i].clone(), + }, + ); + } + + for i in 0..n.saturating_sub(1) { + tree.connect(nodes[i], &bonds[i], nodes[i + 1], &bonds[i])?; + } + + Ok((tree, 
input_mapping, output_mapping)) +} + +fn same_index_order(a: &[DynIndex], b: &[DynIndex]) -> bool { + a.len() == b.len() + && a.iter() + .zip(b.iter()) + .all(|(ai, bi)| ai == bi && ai.dim() == bi.dim()) +} + +fn same_index_set(a: &[DynIndex], b: &[DynIndex]) -> bool { + a.len() == b.len() + && a.iter() + .all(|ai| b.iter().any(|bi| ai == bi && ai.dim() == bi.dim())) +} + +fn align_rhs_to_init(init: &TensorDynLen, rhs: TensorDynLen) -> anyhow::Result { + let init_indices = init.external_indices(); + let rhs_indices = rhs.external_indices(); + if !same_index_set(&init_indices, &rhs_indices) { + anyhow::bail!( + "RHS local index set does not match init: init={:?}, rhs={:?}", + init_indices, + rhs_indices + ); + } + if same_index_order(&init_indices, &rhs_indices) { + Ok(rhs) + } else { + rhs.permuteinds(&init_indices) + } +} + +fn max_bond_dim(tree: &TreeTN) -> usize { + tree.site_index_network() + .edges() + .filter_map(|(a, b)| tree.edge_between(&a, &b)) + .filter_map(|edge| tree.bond_index(edge)) + .map(|idx| idx.dim()) + .max() + .unwrap_or(1) +} + +fn main() -> anyhow::Result<()> { + let args: Vec = std::env::args().collect(); + let n_sites: usize = args.get(1).and_then(|s| s.parse().ok()).unwrap_or(38); + let state_bond_dim: usize = args.get(2).and_then(|s| s.parse().ok()).unwrap_or(8); + let operator_bond_dim: usize = args.get(3).and_then(|s| s.parse().ok()).unwrap_or(8); + let nfullsweeps: usize = args.get(4).and_then(|s| s.parse().ok()).unwrap_or(1); + let gmres_max_restarts: usize = args.get(5).and_then(|s| s.parse().ok()).unwrap_or(10); + let gmres_restart_dim: usize = args.get(6).and_then(|s| s.parse().ok()).unwrap_or(30); + let step_index: usize = args.get(7).and_then(|s| s.parse().ok()).unwrap_or(0); + + anyhow::ensure!( + n_sites >= 2, + "N must be at least 2 for a two-site local step" + ); + anyhow::ensure!(nfullsweeps > 0, "nsweeps must be greater than zero"); + anyhow::ensure!(gmres_max_restarts > 0, "gmres_max_restarts must be greater than zero"); 
+ anyhow::ensure!(gmres_restart_dim > 0, "gmres_restart_dim must be greater than zero"); + + let phys_dim = 2usize; + let seed = 20260518_u64; + let a0 = AnyScalar::new_real(1.0); + let a1 = AnyScalar::new_real(0.01); + let gmres_tol = 1.0e-30; + let mut used_ids = HashSet::::new(); + let mut rng = StdRng::seed_from_u64(seed); + + let acted_sites: Vec<_> = (0..n_sites) + .map(|_| unique_dyn_index(&mut used_ids, phys_dim, &mut rng)) + .collect(); + let spectator_sites: Vec<_> = (0..n_sites) + .map(|_| unique_dyn_index(&mut used_ids, phys_dim, &mut rng)) + .collect(); + + let state_raw = create_state_chain( + n_sites, + state_bond_dim, + &acted_sites, + &spectator_sites, + &mut used_ids, + &mut rng, + )?; + let rhs = state_raw.clone(); + let (operator, input_mapping, output_mapping) = create_operator_chain( + n_sites, + phys_dim, + operator_bond_dim, + &acted_sites, + &mut used_ids, + &mut rng, + )?; + + let center = make_node_name(n_sites / 2); + let state = state_raw + .clone() + .canonicalize([center.clone()], CanonicalizationOptions::default())?; + let reference_state = state.sim_linkinds()?; + let plan = LocalUpdateSweepPlan::from_treetn(&state, &center, 2) + .ok_or_else(|| anyhow::anyhow!("failed to build two-site sweep plan"))?; + let step = plan + .steps + .get(step_index % plan.steps.len()) + .ok_or_else(|| anyhow::anyhow!("empty sweep plan"))?; + let local_tensor = state.extract_subtree(&step.nodes)?.contract_to_tensor()?; + + println!("=== Local GMRES / linsolve benchmark (Rust/tensor4all-rs) ==="); + println!("N = {n_sites}"); + println!("phys_dim = {phys_dim}"); + println!("state_bond_dim = {state_bond_dim}"); + println!("operator_bond_dim = {operator_bond_dim}"); + println!("nsweeps = {nfullsweeps}"); + println!("gmres_max_restarts = {gmres_max_restarts}"); + println!("gmres_restart_dim = {gmres_restart_dim}"); + println!("gmres_tol = {gmres_tol:.1e}"); + println!("coefficients = ({a0:?}, {a1:?})"); + println!("center = {center}"); + 
println!("sweep_plan_steps = {}", plan.steps.len()); + println!("step_index = {}", step_index % plan.steps.len()); + println!("step_nodes = {:?}", step.nodes); + println!("local_dims = {:?}", local_tensor.dims()); + println!(); + + reset_contract_profile(); + let mut projected_state = ProjectedState::new(rhs.clone()); + let rhs_start = Instant::now(); + let rhs_local_raw = + projected_state.local_constant_term(&step.nodes, &state, state.site_index_network())?; + let rhs_time = rhs_start.elapsed(); + let rhs_local = align_rhs_to_init(&local_tensor, rhs_local_raw)?; + + let projected_operator = RefCell::new(ProjectedOperator::with_index_mappings( + operator.clone(), + input_mapping.clone(), + output_mapping.clone(), + )); + let apply_count = Rc::new(Cell::new(0usize)); + let apply_time = Rc::new(RefCell::new(Duration::ZERO)); + let combine_time = Rc::new(RefCell::new(Duration::ZERO)); + let apply_count_ref = Rc::clone(&apply_count); + let apply_time_ref = Rc::clone(&apply_time); + let combine_time_ref = Rc::clone(&combine_time); + + let gmres_options = GmresOptions { + max_iter: gmres_restart_dim, + rtol: gmres_tol, + max_restarts: gmres_max_restarts, + verbose: false, + check_true_residual: false, + }; + + let gmres_start = Instant::now(); + let gmres_result = gmres( + |x: &TensorDynLen| { + apply_count_ref.set(apply_count_ref.get() + 1); + let apply_start = Instant::now(); + let hx = projected_operator.borrow_mut().apply( + black_box(x), + &step.nodes, + &state, + &reference_state, + state.site_index_network(), + )?; + *apply_time_ref.borrow_mut() += apply_start.elapsed(); + + let combine_start = Instant::now(); + let y = x.axpby(a0.clone(), &hx, a1.clone())?; + *combine_time_ref.borrow_mut() += combine_start.elapsed(); + Ok(y) + }, + &rhs_local, + &local_tensor, + &gmres_options, + )?; + let gmres_time = gmres_start.elapsed(); + + println!("--- Single local GMRES step ---"); + println!( + "rhs projection: {:.3} ms, rhs_rank={}", + rhs_time.as_secs_f64() * 1000.0, 
+ rhs_local.external_indices().len() + ); + println!( + "gmres total: {:.3} ms, iterations={}, converged={}, residual={:.3e}", + gmres_time.as_secs_f64() * 1000.0, + gmres_result.iterations, + gmres_result.converged, + gmres_result.residual_norm + ); + println!("apply_count = {}", apply_count.get()); + println!( + "projected apply inside GMRES: {:.3} ms", + apply_time.borrow().as_secs_f64() * 1000.0 + ); + println!( + "a0*x + a1*Hx combine: {:.3} ms", + combine_time.borrow().as_secs_f64() * 1000.0 + ); + println!( + "unaccounted GMRES/vector overhead: {:.3} ms", + (gmres_time + .saturating_sub(*apply_time.borrow()) + .saturating_sub(*combine_time.borrow())) + .as_secs_f64() + * 1000.0 + ); + println!(); + + let options = LinsolveOptions::new(nfullsweeps) + .with_coefficients(a0, a1) + .with_gmres_tol(gmres_tol) + .with_gmres_max_restarts(gmres_max_restarts) + .with_gmres_restart_dim(gmres_restart_dim) + .with_max_rank(state_bond_dim) + .with_svd_policy(SvdTruncationPolicy::new(0.0)) + .with_residual_check(false); + + let full_start = Instant::now(); + let full_result = square_linsolve( + &operator, + &rhs, + state_raw, + &center, + options, + Some(input_mapping), + Some(output_mapping), + )?; + let full_time = full_start.elapsed(); + + println!("--- Full two-site square_linsolve ---"); + println!( + "total: {:.3} ms, sweeps={}, residual_reported={}", + full_time.as_secs_f64() * 1000.0, + full_result.sweeps, + full_result.residual.is_some() + ); + println!( + "expected local update steps = {}", + plan.steps.len() * nfullsweeps + ); + println!( + "solution max bond dim = {}", + max_bond_dim(&full_result.solution) + ); + print_and_reset_contract_profile(); + + Ok(()) +} diff --git a/benchmarks/rust/benchmark_projected_apply.rs b/benchmarks/rust/benchmark_projected_apply.rs new file mode 100644 index 00000000..8bc0736b --- /dev/null +++ b/benchmarks/rust/benchmark_projected_apply.rs @@ -0,0 +1,285 @@ +// Benchmark isolated `ProjectedOperator::apply` calls for mapped MPO 
local solves. +// +// Run: +// cargo run -p tensor4all-treetn --example benchmark_projected_apply --release +// +// Optional args: +// cargo run -p tensor4all-treetn --example benchmark_projected_apply --release -- [N] [state_bond_dim] [operator_bond_dim] [repeats] [step_index] + +use std::collections::{HashMap, HashSet}; +use std::hint::black_box; +use std::time::{Duration, Instant}; + +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use tensor4all_core::{ + index::{DynId, Index}, + print_and_reset_contract_profile, reset_contract_profile, + DynIndex, TensorDynLen, TensorIndex, +}; +use tensor4all_treetn::{IndexMapping, LocalUpdateSweepPlan, ProjectedOperator, TreeTN}; + +fn make_node_name(i: usize) -> String { + format!("site{i}") +} + +fn unique_dyn_index(used: &mut HashSet, dim: usize, rng: &mut StdRng) -> DynIndex { + loop { + let id = DynId(rng.random()); + if used.insert(id) { + return Index::new(id, dim); + } + } +} + +fn chain_node_indices(n: usize, i: usize, bonds: &[DynIndex], sites: &[DynIndex]) -> Vec { + if n == 1 { + vec![sites[i].clone()] + } else if i == 0 { + vec![sites[i].clone(), bonds[i].clone()] + } else if i + 1 == n { + vec![bonds[i - 1].clone(), sites[i].clone()] + } else { + vec![bonds[i - 1].clone(), sites[i].clone(), bonds[i].clone()] + } +} + +fn create_state_chain( + n: usize, + state_bond_dim: usize, + acted_sites: &[DynIndex], + spectator_sites: &[DynIndex], + used_ids: &mut HashSet, + rng: &mut StdRng, +) -> anyhow::Result> { + let mut tree = TreeTN::::new(); + let bonds: Vec<_> = (0..n.saturating_sub(1)) + .map(|_| unique_dyn_index(used_ids, state_bond_dim, rng)) + .collect(); + + let mut nodes = Vec::with_capacity(n); + for (i, spectator_site) in spectator_sites.iter().enumerate().take(n) { + let mut indices = chain_node_indices(n, i, &bonds, acted_sites); + indices.insert(0, spectator_site.clone()); + let tensor = TensorDynLen::random::(rng, indices)?; + let node = tree.add_tensor(make_node_name(i), tensor)?; + nodes.push(node); + } + + 
for i in 0..n.saturating_sub(1) { + 
tree.connect(nodes[i], &bonds[i], nodes[i + 1], &bonds[i])?; + } + + Ok(tree) +} + +#[allow(clippy::type_complexity)] +fn create_operator_chain( + n: usize, + phys_dim: usize, + operator_bond_dim: usize, + acted_sites: &[DynIndex], + used_ids: &mut HashSet, + rng: &mut StdRng, +) -> anyhow::Result<( + TreeTN, + HashMap>, + HashMap>, +)> { + let mut tree = TreeTN::::new(); + let bonds: Vec<_> = (0..n.saturating_sub(1)) + .map(|_| unique_dyn_index(used_ids, operator_bond_dim, rng)) + .collect(); + let s_in: Vec<_> = (0..n) + .map(|_| unique_dyn_index(used_ids, phys_dim, rng)) + .collect(); + let s_out: Vec<_> = (0..n) + .map(|_| unique_dyn_index(used_ids, phys_dim, rng)) + .collect(); + + let mut input_mapping = HashMap::new(); + let mut output_mapping = HashMap::new(); + let mut nodes = Vec::with_capacity(n); + for i in 0..n { + let sites = vec![s_out[i].clone(), s_in[i].clone()]; + let mut indices = if n == 1 { + sites + } else if i == 0 { + vec![s_out[i].clone(), s_in[i].clone(), bonds[i].clone()] + } else if i + 1 == n { + vec![bonds[i - 1].clone(), s_out[i].clone(), s_in[i].clone()] + } else { + vec![ + bonds[i - 1].clone(), + s_out[i].clone(), + s_in[i].clone(), + bonds[i].clone(), + ] + }; + indices.shrink_to_fit(); + let tensor = TensorDynLen::random::(rng, indices)?; + let name = make_node_name(i); + let node = tree.add_tensor(name.clone(), tensor)?; + nodes.push(node); + + input_mapping.insert( + name.clone(), + IndexMapping { + true_index: acted_sites[i].clone(), + internal_index: s_in[i].clone(), + }, + ); + output_mapping.insert( + name, + IndexMapping { + true_index: acted_sites[i].clone(), + internal_index: s_out[i].clone(), + }, + ); + } + + for i in 0..n.saturating_sub(1) { + tree.connect(nodes[i], &bonds[i], nodes[i + 1], &bonds[i])?; + } + + Ok((tree, input_mapping, output_mapping)) +} + +fn summarize(label: &str, times: &[Duration]) { + let secs: Vec = times.iter().map(Duration::as_secs_f64).collect(); + let mean = secs.iter().sum::() / secs.len() 
as f64; + let min = secs.iter().copied().fold(f64::INFINITY, f64::min); + let max = secs.iter().copied().fold(f64::NEG_INFINITY, f64::max); + println!( + "{label}: mean={:.3} ms min={:.3} ms max={:.3} ms n={}", + mean * 1000.0, + min * 1000.0, + max * 1000.0, + times.len() + ); +} + +fn main() -> anyhow::Result<()> { + let args: Vec = std::env::args().collect(); + let n_sites: usize = args.get(1).and_then(|s| s.parse().ok()).unwrap_or(38); + let state_bond_dim: usize = args.get(2).and_then(|s| s.parse().ok()).unwrap_or(8); + let operator_bond_dim: usize = args.get(3).and_then(|s| s.parse().ok()).unwrap_or(8); + let repeats: usize = args.get(4).and_then(|s| s.parse().ok()).unwrap_or(20); + let step_index: usize = args.get(5).and_then(|s| s.parse().ok()).unwrap_or(0); + + anyhow::ensure!( + n_sites >= 2, + "N must be at least 2 for a two-site local step" + ); + anyhow::ensure!(repeats > 0, "repeats must be greater than zero"); + + let phys_dim = 2usize; + let seed = 20260518_u64; + let mut used_ids = HashSet::::new(); + let mut rng = StdRng::seed_from_u64(seed); + + let acted_sites: Vec<_> = (0..n_sites) + .map(|_| unique_dyn_index(&mut used_ids, phys_dim, &mut rng)) + .collect(); + let spectator_sites: Vec<_> = (0..n_sites) + .map(|_| unique_dyn_index(&mut used_ids, phys_dim, &mut rng)) + .collect(); + + let state = create_state_chain( + n_sites, + state_bond_dim, + &acted_sites, + &spectator_sites, + &mut used_ids, + &mut rng, + )?; + let reference_state = state.sim_linkinds()?; + let (operator, input_mapping, output_mapping) = create_operator_chain( + n_sites, + phys_dim, + operator_bond_dim, + &acted_sites, + &mut used_ids, + &mut rng, + )?; + + let center = make_node_name(n_sites / 2); + let plan = LocalUpdateSweepPlan::from_treetn(&state, ¢er, 2) + .ok_or_else(|| anyhow::anyhow!("failed to build two-site sweep plan"))?; + let step = plan + .steps + .get(step_index % plan.steps.len()) + .ok_or_else(|| anyhow::anyhow!("empty sweep plan"))?; + let local_tensor = 
state.extract_subtree(&step.nodes)?.contract_to_tensor()?; + + println!("=== ProjectedOperator::apply benchmark ==="); + println!("N = {n_sites}"); + println!("phys_dim = {phys_dim}"); + println!("state_bond_dim = {state_bond_dim}"); + println!("operator_bond_dim = {operator_bond_dim}"); + println!("repeats = {repeats}"); + println!("center = {center}"); + println!("step_index = {}", step_index % plan.steps.len()); + println!("step_nodes = {:?}", step.nodes); + println!("local_dims = {:?}", local_tensor.dims()); + println!(); + + reset_contract_profile(); + let mut projected_cold = ProjectedOperator::with_index_mappings( + operator.clone(), + input_mapping.clone(), + output_mapping.clone(), + ); + let cold_start = Instant::now(); + let cold_result = projected_cold.apply( + black_box(&local_tensor), + &step.nodes, + &state, + &reference_state, + state.site_index_network(), + )?; + let cold = cold_start.elapsed(); + println!( + "cold apply (environment build + one apply): {:.3} ms, output_rank={}", + cold.as_secs_f64() * 1000.0, + cold_result.external_indices().len() + ); + + let mut warm_times = Vec::with_capacity(repeats); + for _ in 0..repeats { + let start = Instant::now(); + let out = projected_cold.apply( + black_box(&local_tensor), + &step.nodes, + &state, + &reference_state, + state.site_index_network(), + )?; + black_box(out); + warm_times.push(start.elapsed()); + } + summarize("warm apply (environment cache hot)", &warm_times); + + let mut cold_times = Vec::with_capacity(repeats); + for _ in 0..repeats { + let mut projected = ProjectedOperator::with_index_mappings( + operator.clone(), + input_mapping.clone(), + output_mapping.clone(), + ); + let start = Instant::now(); + let out = projected.apply( + black_box(&local_tensor), + &step.nodes, + &state, + &reference_state, + state.site_index_network(), + )?; + black_box(out); + cold_times.push(start.elapsed()); + } + summarize("cold apply repeated (fresh environment cache)", &cold_times); + 
print_and_reset_contract_profile(); + + Ok(()) +} diff --git a/benchmarks/rust/benchmark_tensor_ops.rs b/benchmarks/rust/benchmark_tensor_ops.rs new file mode 100644 index 00000000..afb2b6ee --- /dev/null +++ b/benchmarks/rust/benchmark_tensor_ops.rs @@ -0,0 +1,143 @@ +// Benchmark non-AD TensorDynLen vector-space operations. +// +// Run: +// RAYON_NUM_THREADS=1 cargo run -p tensor4all-core --example benchmark_tensor_ops --release +// +// Optional args: +// cargo run -p tensor4all-core --example benchmark_tensor_ops --release -- ... +// +// Example matching a two-site local tensor with small bonds: +// RAYON_NUM_THREADS=1 cargo run -p tensor4all-core --example benchmark_tensor_ops --release -- 20000 6 2 2 6 + +use std::hint::black_box; +use std::time::Instant; + +use anyhow::{bail, Result}; +use num_complex::Complex64; +use rand::rngs::StdRng; +use rand::SeedableRng; +use tensor4all_core::{AnyScalar, DynIndex, TensorContractionLike, TensorDynLen}; + +fn parse_args() -> Result<(usize, Vec)> { + let args = std::env::args().skip(1).collect::>(); + let repeats = args + .first() + .map(|value| value.parse::()) + .transpose()? + .unwrap_or(20_000); + if repeats == 0 { + bail!("repeats must be greater than zero"); + } + + let dims = if args.len() > 1 { + args[1..] + .iter() + .map(|value| value.parse::()) + .collect::, _>>()? 
+ } else { + vec![6, 2, 2, 6] + }; + if dims.is_empty() || dims.contains(&0) { + bail!("all dimensions must be positive"); + } + Ok((repeats, dims)) +} + +fn make_indices(dims: &[usize]) -> Vec { + dims.iter().map(|&dim| DynIndex::new_dyn(dim)).collect() +} + +fn elapsed_seconds(mut f: impl FnMut() -> T) -> (f64, T) { + let started = Instant::now(); + let result = f(); + (started.elapsed().as_secs_f64(), result) +} + +fn main() -> Result<()> { + let (repeats, dims) = parse_args()?; + let element_count = dims.iter().product::(); + let indices = make_indices(&dims); + let mut rng = StdRng::seed_from_u64(0x5EED_1234); + let a = TensorDynLen::random::(&mut rng, indices.clone())?; + let b = TensorDynLen::random::(&mut rng, indices)?; + let alpha = AnyScalar::new_complex(0.7, -0.2); + let beta = AnyScalar::new_complex(-0.3, 0.4); + + // Warm up caches and allocator paths. + for _ in 0..32 { + black_box(a.inner_product(&b)?); + black_box(a.norm()); + black_box(a.axpby(alpha.clone(), &b, beta.clone())?); + black_box(a.conj().contract_pair(&b)?.sum()?); + } + + println!("=== TensorDynLen non-AD tensor ops benchmark ==="); + println!( + "dims={dims:?} elements={element_count} repeats={repeats} dtype=Complex64" + ); + + let (inner_seconds, inner_checksum) = elapsed_seconds(|| -> Result { + let mut checksum = Complex64::new(0.0, 0.0); + for _ in 0..repeats { + let value: Complex64 = black_box(a.inner_product(black_box(&b))?).into(); + checksum += value; + } + Ok(checksum) + }); + let inner_checksum = inner_checksum?; + println!( + "inner_seconds = {:.6} per_call_us = {:.3} checksum = {:.6e}+{:.6e}im", + inner_seconds, + inner_seconds * 1.0e6 / repeats as f64, + inner_checksum.re, + inner_checksum.im, + ); + + let (norm_seconds, norm_checksum) = elapsed_seconds(|| { + let mut checksum = 0.0; + for _ in 0..repeats { + checksum += black_box(a.norm()); + } + checksum + }); + println!( + "norm_seconds = {:.6} per_call_us = {:.3} checksum = {:.6e}", + norm_seconds, + norm_seconds * 
1.0e6 / repeats as f64, + norm_checksum, + ); + + let (axpby_seconds, axpby_checksum) = elapsed_seconds(|| -> Result { + let mut checksum = 0.0; + for _ in 0..repeats { + let out = black_box(a.axpby(alpha.clone(), black_box(&b), beta.clone())?); + checksum += black_box(out.maxabs()); + } + Ok(checksum) + }); + println!( + "axpby_seconds = {:.6} per_call_us = {:.3} checksum = {:.6e}", + axpby_seconds, + axpby_seconds * 1.0e6 / repeats as f64, + axpby_checksum?, + ); + + let (conj_contract_seconds, conj_contract_checksum) = elapsed_seconds(|| -> Result { + let mut checksum = Complex64::new(0.0, 0.0); + for _ in 0..repeats { + let value: Complex64 = black_box(a.conj().contract_pair(black_box(&b))?.sum()?).into(); + checksum += value; + } + Ok(checksum) + }); + let conj_contract_checksum = conj_contract_checksum?; + println!( + "conj_contract_sum_seconds = {:.6} per_call_us = {:.3} checksum = {:.6e}+{:.6e}im", + conj_contract_seconds, + conj_contract_seconds * 1.0e6 / repeats as f64, + conj_contract_checksum.re, + conj_contract_checksum.im, + ); + + Ok(()) +} diff --git a/benchmarks/rust/benchmark_tt_ops.rs b/benchmarks/rust/benchmark_tt_ops.rs new file mode 100644 index 00000000..1f6b9037 --- /dev/null +++ b/benchmarks/rust/benchmark_tt_ops.rs @@ -0,0 +1,904 @@ +// Benchmark TensorTrain-level operations against ITensorMPS.jl. 
+// +// Run: +// RAYON_NUM_THREADS=1 cargo run -p tensor4all-itensorlike --example benchmark_tt_ops --release +// +// Optional args: +// --L N +// --d N +// --zipup-L N +// --chis 4,8,16,32 +// --warm-up-time SECONDS +// --measurement-time SECONDS +// --min-samples N +// --inner-only + +use std::hint::black_box; +use std::sync::Arc; +use std::time::Instant; + +use anyhow::{bail, Context, Result}; +use num_complex::Complex64; +use tensor4all_core::{ + contract, contract_pair, print_and_reset_pairwise_contract_profile, + reset_pairwise_contract_profile, AnyScalar, DynIndex, SvdTruncationPolicy, + TensorContractionLike, TensorDynLen, +}; +use tensor4all_itensorlike::{CanonicalForm, ContractOptions, TensorTrain}; +use tenferro::{CpuBackend, DotGeneralConfig, EagerContext, EagerTensor, Tensor, TypedTensor}; + +#[derive(Debug, Clone)] +struct Options { + length: usize, + phys_dim: usize, + zipup_length: usize, + chis: Vec, + warmup_seconds: f64, + measurement_seconds: f64, + min_samples: usize, + skip_zipup: bool, + inner_only: bool, +} + +impl Default for Options { + fn default() -> Self { + Self { + length: 32, + phys_dim: 2, + zipup_length: 10, + chis: vec![4, 8, 16, 32, 64], + warmup_seconds: 1.0, + measurement_seconds: 2.0, + min_samples: 10, + skip_zipup: false, + inner_only: false, + } + } +} + +fn parse_args() -> Result { + let mut opts = Options::default(); + let args = std::env::args().skip(1).collect::>(); + let mut i = 0; + while i < args.len() { + match args[i].as_str() { + "--help" | "-h" => { + println!( + "Usage: benchmark_tt_ops [--L N] [--d N] [--zipup-L N] \ + [--chis LIST] [--warm-up-time SEC] [--measurement-time SEC] \ + [--min-samples N] [--no-zipup] [--inner-only]" + ); + std::process::exit(0); + } + "--L" => { + i += 1; + opts.length = parse_arg(&args, i, "--L")?; + } + "--d" => { + i += 1; + opts.phys_dim = parse_arg(&args, i, "--d")?; + } + "--zipup-L" => { + i += 1; + opts.zipup_length = parse_arg(&args, i, "--zipup-L")?; + } + "--chis" => { + 
i += 1; + let value = args + .get(i) + .with_context(|| "missing value for --chis".to_string())?; + opts.chis = value + .split(',') + .map(|part| part.parse::()) + .collect::, _>>()?; + } + "--warm-up-time" => { + i += 1; + opts.warmup_seconds = parse_arg(&args, i, "--warm-up-time")?; + } + "--measurement-time" => { + i += 1; + opts.measurement_seconds = parse_arg(&args, i, "--measurement-time")?; + } + "--min-samples" => { + i += 1; + opts.min_samples = parse_arg(&args, i, "--min-samples")?; + } + "--no-zipup" => { + opts.skip_zipup = true; + } + "--inner-only" => { + opts.inner_only = true; + opts.skip_zipup = true; + } + other => bail!("unknown argument: {other}"), + } + i += 1; + } + + if opts.length == 0 || opts.zipup_length == 0 || opts.phys_dim == 0 { + bail!("lengths and physical dimension must be positive"); + } + if opts.chis.is_empty() || opts.chis.contains(&0) { + bail!("all bond dimensions must be positive"); + } + if opts.warmup_seconds < 0.0 || opts.measurement_seconds < 0.0 { + bail!("timing windows must be nonnegative"); + } + if opts.min_samples == 0 { + bail!("min-samples must be positive"); + } + + Ok(opts) +} + +fn parse_arg(args: &[String], index: usize, name: &str) -> Result +where + T: std::str::FromStr, + T::Err: std::error::Error + Send + Sync + 'static, +{ + args.get(index) + .with_context(|| format!("missing value for {name}"))? 
+ .parse::() + .with_context(|| format!("invalid value for {name}")) +} + +fn deterministic_value(idx: usize, seed: usize) -> Complex64 { + let real = ((idx * 17 + seed * 13 + 3) % 97) as f64 / 97.0 - 0.5; + let imag = ((idx * 29 + seed * 7 + 5) % 89) as f64 / 89.0 - 0.5; + Complex64::new(real, imag) +} + +fn deterministic_tensor(indices: Vec, seed: usize) -> Result { + let len = indices.iter().map(|index| index.size()).product::(); + let data = (0..len) + .map(|idx| deterministic_value(idx, seed)) + .collect::>(); + TensorDynLen::from_dense(indices, data) +} + +fn deterministic_native_tensor(shape: Vec, seed: usize) -> Tensor { + let len = shape.iter().product::(); + let data = (0..len) + .map(|idx| deterministic_value(idx, seed)) + .collect::>(); + Tensor::C64(TypedTensor::from_vec(shape, data)) +} + +fn make_sites(length: usize, phys_dim: usize) -> Vec { + (0..length) + .map(|_| DynIndex::new_dyn(phys_dim)) + .collect() +} + +fn make_mps(sites: &[DynIndex], chi: usize, seed_offset: usize) -> Result { + let length = sites.len(); + let links = (0..length.saturating_sub(1)) + .map(|_| DynIndex::new_dyn(chi)) + .collect::>(); + let mut tensors = Vec::with_capacity(length); + + for site in 0..length { + let mut indices = Vec::new(); + if site > 0 { + indices.push(links[site - 1].clone()); + } + indices.push(sites[site].clone()); + if site + 1 < length { + indices.push(links[site].clone()); + } + tensors.push(deterministic_tensor(indices, seed_offset + site + 1)?); + } + + Ok(TensorTrain::new(tensors)?) 
+} + +fn make_native_mps_t4a_shapes( + length: usize, + phys_dim: usize, + chi: usize, + seed_offset: usize, +) -> Vec { + (0..length) + .map(|site| { + let mut shape = Vec::new(); + if site > 0 { + shape.push(chi); + } + shape.push(phys_dim); + if site + 1 < length { + shape.push(chi); + } + deterministic_native_tensor(shape, seed_offset + site + 1) + }) + .collect() +} + +fn eager_mps_tensors( + ctx: &Arc>, + tensors: Vec, +) -> Vec> { + tensors + .into_iter() + .map(|tensor| EagerTensor::from_tensor_in(tensor, Arc::clone(ctx))) + .collect() +} + +#[derive(Debug, Clone)] +struct RawEagerInnerConfigs { + first_site: DotGeneralConfig, + env_bra: DotGeneralConfig, + tmp_ket: DotGeneralConfig, +} + +impl RawEagerInnerConfigs { + fn new() -> Self { + Self { + first_site: DotGeneralConfig { + lhs_contracting_dims: vec![0], + rhs_contracting_dims: vec![0], + lhs_batch_dims: vec![], + rhs_batch_dims: vec![], + }, + env_bra: DotGeneralConfig { + lhs_contracting_dims: vec![0], + rhs_contracting_dims: vec![0], + lhs_batch_dims: vec![], + rhs_batch_dims: vec![], + }, + tmp_ket: DotGeneralConfig { + lhs_contracting_dims: vec![0, 1], + rhs_contracting_dims: vec![0, 1], + lhs_batch_dims: vec![], + rhs_batch_dims: vec![], + }, + } + } +} + +fn maybe_snapshot_output(tensor: &EagerTensor, snapshot_outputs: bool) { + if snapshot_outputs { + black_box(tensor.data().clone()); + } +} + +fn raw_eager_inner_t4a_shapes( + bra: &[EagerTensor], + ket: &[EagerTensor], + configs: &RawEagerInnerConfigs, + snapshot_outputs: bool, +) -> Result { + if bra.len() != ket.len() { + bail!( + "raw eager inputs must have the same length: {} vs {}", + bra.len(), + ket.len() + ); + } + if bra.is_empty() { + return Ok(Complex64::new(0.0, 0.0)); + } + + let mut env = bra[0] + .dot_general_with_conj(&ket[0], &configs.first_site, true, false) + .context("failed raw eager first-site contraction")?; + maybe_snapshot_output(&env, snapshot_outputs); + + for site in 1..bra.len() { + env = env + 
.dot_general_with_conj(&bra[site], &configs.env_bra, false, true) + .with_context(|| format!("failed raw eager env-bra contraction at site {site}"))?; + maybe_snapshot_output(&env, snapshot_outputs); + + env = env + .dot_general_with_conj(&ket[site], &configs.tmp_ket, false, false) + .with_context(|| format!("failed raw eager tmp-ket contraction at site {site}"))?; + maybe_snapshot_output(&env, snapshot_outputs); + } + + let scalar = env + .data() + .as_slice::() + .context("raw eager inner output is not Complex64")? + .first() + .copied() + .context("raw eager inner output is empty")?; + Ok(scalar) +} + +fn make_mpo( + input_sites: &[DynIndex], + output_sites: &[DynIndex], + chi: usize, + seed_offset: usize, +) -> Result { + if input_sites.len() != output_sites.len() { + bail!("input/output site lengths must match"); + } + + let length = input_sites.len(); + let links = (0..length.saturating_sub(1)) + .map(|_| DynIndex::new_dyn(chi)) + .collect::>(); + let mut tensors = Vec::with_capacity(length); + + for site in 0..length { + let mut indices = Vec::new(); + if site > 0 { + indices.push(links[site - 1].clone()); + } + indices.push(input_sites[site].clone()); + indices.push(output_sites[site].clone()); + if site + 1 < length { + indices.push(links[site].clone()); + } + tensors.push(deterministic_tensor(indices, seed_offset + site + 1)?); + } + + Ok(TensorTrain::new(tensors)?) 
+} + +fn run_for_seconds( + warmup_seconds: f64, + measurement_seconds: f64, + min_samples: usize, + mut f: F, +) -> Result<(T, Vec)> +where + F: FnMut() -> Result, +{ + let mut sink = f()?; + + let warmup_start = Instant::now(); + while warmup_start.elapsed().as_secs_f64() < warmup_seconds { + sink = f()?; + black_box(&sink); + } + + let mut times_ms = Vec::new(); + let measurement_start = Instant::now(); + while measurement_start.elapsed().as_secs_f64() < measurement_seconds + || times_ms.len() < min_samples + { + let start = Instant::now(); + sink = f()?; + black_box(&sink); + times_ms.push(start.elapsed().as_secs_f64() * 1.0e3); + } + + Ok((sink, times_ms)) +} + +fn stats_ms(times: &[f64]) -> (f64, f64, f64, f64) { + let mut sorted = times.to_vec(); + sorted.sort_by(|a, b| a.total_cmp(b)); + let min = sorted[0]; + let max = *sorted.last().expect("non-empty timings"); + let mean = sorted.iter().sum::() / sorted.len() as f64; + let median = if sorted.len().is_multiple_of(2) { + let hi = sorted.len() / 2; + 0.5 * (sorted[hi - 1] + sorted[hi]) + } else { + sorted[sorted.len() / 2] + }; + (min, median, mean, max) +} + +fn print_result(case: &str, params: &str, times: &[f64], value: &str, max_bond: usize) { + let (min, median, mean, max) = stats_ms(times); + println!( + "{case},{params},{},{min:.6},{median:.6},{mean:.6},{max:.6},{max_bond},{value}", + times.len() + ); +} + +fn inner_sitewise_pair_no_sim(bra: &TensorTrain, ket: &TensorTrain) -> Result { + if bra.len() != ket.len() { + bail!( + "Tensor trains must have the same length for inner product: {} vs {}", + bra.len(), + ket.len() + ); + } + if bra.is_empty() { + return Ok(AnyScalar::new_real(0.0)); + } + + let a0_conj = bra.tensor(0)?.conj(); + let mut env = contract_pair(&a0_conj, ket.tensor(0)?) 
+ .context("failed to contract leftmost site tensors")?; + + for site in 1..bra.len() { + let ai_conj = bra.tensor(site)?.conj(); + env = contract_pair(&env, &ai_conj) + .with_context(|| format!("failed to contract environment with site {site}"))?; + env = contract_pair(&env, ket.tensor(site)?) + .with_context(|| format!("failed to contract ket tensor at site {site}"))?; + } + + env.sum() +} + +fn preconjugate_sites(bra: &TensorTrain) -> Result> { + (0..bra.len()) + .map(|site| Ok(bra.tensor(site)?.conj())) + .collect() +} + +fn inner_sitewise_pair_preconj_no_sim( + bra_conj: &[TensorDynLen], + ket: &TensorTrain, +) -> Result { + if bra_conj.len() != ket.len() { + bail!( + "Tensor trains must have the same length for inner product: {} vs {}", + bra_conj.len(), + ket.len() + ); + } + if bra_conj.is_empty() { + return Ok(AnyScalar::new_real(0.0)); + } + + let mut env = contract_pair(&bra_conj[0], ket.tensor(0)?) + .context("failed to contract leftmost site tensors")?; + + for (site, bra_site) in bra_conj.iter().enumerate().skip(1) { + env = contract_pair(&env, bra_site) + .with_context(|| format!("failed to contract environment with site {site}"))?; + env = contract_pair(&env, ket.tensor(site)?) + .with_context(|| format!("failed to contract ket tensor at site {site}"))?; + } + + env.sum() +} + +fn inner_sitewise_nary_no_sim(bra: &TensorTrain, ket: &TensorTrain) -> Result { + if bra.len() != ket.len() { + bail!( + "Tensor trains must have the same length for inner product: {} vs {}", + bra.len(), + ket.len() + ); + } + if bra.is_empty() { + return Ok(AnyScalar::new_real(0.0)); + } + + let a0_conj = bra.tensor(0)?.conj(); + let mut env = contract_pair(&a0_conj, ket.tensor(0)?) 
+ .context("failed to contract leftmost site tensors")?; + + for site in 1..bra.len() { + let ai_conj = bra.tensor(site)?.conj(); + env = contract(&[&env, &ai_conj, ket.tensor(site)?]) + .with_context(|| format!("failed to contract three-tensor environment at site {site}"))?; + } + + env.sum() +} + +fn inner_sitewise_binary_contract_no_sim(bra: &TensorTrain, ket: &TensorTrain) -> Result { + if bra.len() != ket.len() { + bail!( + "Tensor trains must have the same length for inner product: {} vs {}", + bra.len(), + ket.len() + ); + } + if bra.is_empty() { + return Ok(AnyScalar::new_real(0.0)); + } + + let a0_conj = bra.tensor(0)?.conj(); + let mut env = + contract(&[&a0_conj, ket.tensor(0)?]).context("failed to contract leftmost site tensors")?; + + for site in 1..bra.len() { + let ai_conj = bra.tensor(site)?.conj(); + env = contract(&[&env, &ai_conj]) + .with_context(|| format!("failed to contract environment with site {site}"))?; + env = contract(&[&env, ket.tensor(site)?]) + .with_context(|| format!("failed to contract ket tensor at site {site}"))?; + } + + env.sum() +} + +#[derive(Debug, Default)] +struct InnerBreakdown { + conj_ms: f64, + first_contract_ms: f64, + env_contract_ms: f64, + ket_contract_ms: f64, + nary_contract_ms: f64, + sum_ms: f64, +} + +fn inner_sitewise_pair_breakdown_no_sim( + bra: &TensorTrain, + ket: &TensorTrain, +) -> Result<(AnyScalar, InnerBreakdown)> { + if bra.len() != ket.len() { + bail!( + "Tensor trains must have the same length for inner product: {} vs {}", + bra.len(), + ket.len() + ); + } + if bra.is_empty() { + return Ok((AnyScalar::new_real(0.0), InnerBreakdown::default())); + } + + let mut breakdown = InnerBreakdown::default(); + + let started = Instant::now(); + let a0_conj = bra.tensor(0)?.conj(); + breakdown.conj_ms += started.elapsed().as_secs_f64() * 1.0e3; + + let started = Instant::now(); + let mut env = contract_pair(&a0_conj, ket.tensor(0)?) 
+ .context("failed to contract leftmost site tensors")?; + breakdown.first_contract_ms += started.elapsed().as_secs_f64() * 1.0e3; + + for site in 1..bra.len() { + let started = Instant::now(); + let ai_conj = bra.tensor(site)?.conj(); + breakdown.conj_ms += started.elapsed().as_secs_f64() * 1.0e3; + + let started = Instant::now(); + env = contract_pair(&env, &ai_conj) + .with_context(|| format!("failed to contract environment with site {site}"))?; + breakdown.env_contract_ms += started.elapsed().as_secs_f64() * 1.0e3; + + let started = Instant::now(); + env = contract_pair(&env, ket.tensor(site)?) + .with_context(|| format!("failed to contract ket tensor at site {site}"))?; + breakdown.ket_contract_ms += started.elapsed().as_secs_f64() * 1.0e3; + } + + let started = Instant::now(); + let value = env.sum()?; + breakdown.sum_ms += started.elapsed().as_secs_f64() * 1.0e3; + Ok((value, breakdown)) +} + +fn inner_sitewise_nary_breakdown_no_sim( + bra: &TensorTrain, + ket: &TensorTrain, +) -> Result<(AnyScalar, InnerBreakdown)> { + if bra.len() != ket.len() { + bail!( + "Tensor trains must have the same length for inner product: {} vs {}", + bra.len(), + ket.len() + ); + } + if bra.is_empty() { + return Ok((AnyScalar::new_real(0.0), InnerBreakdown::default())); + } + + let mut breakdown = InnerBreakdown::default(); + + let started = Instant::now(); + let a0_conj = bra.tensor(0)?.conj(); + breakdown.conj_ms += started.elapsed().as_secs_f64() * 1.0e3; + + let started = Instant::now(); + let mut env = contract_pair(&a0_conj, ket.tensor(0)?) 
+ .context("failed to contract leftmost site tensors")?; + breakdown.first_contract_ms += started.elapsed().as_secs_f64() * 1.0e3; + + for site in 1..bra.len() { + let started = Instant::now(); + let ai_conj = bra.tensor(site)?.conj(); + breakdown.conj_ms += started.elapsed().as_secs_f64() * 1.0e3; + + let started = Instant::now(); + env = contract(&[&env, &ai_conj, ket.tensor(site)?]) + .with_context(|| format!("failed to contract three-tensor environment at site {site}"))?; + breakdown.nary_contract_ms += started.elapsed().as_secs_f64() * 1.0e3; + } + + let started = Instant::now(); + let value = env.sum()?; + breakdown.sum_ms += started.elapsed().as_secs_f64() * 1.0e3; + Ok((value, breakdown)) +} + +fn print_inner_breakdown(kind: &str, params: &str, value: AnyScalar, breakdown: &InnerBreakdown) { + eprintln!( + "inner_breakdown,{kind},{params},conj_ms={:.6},first_contract_ms={:.6},env_contract_ms={:.6},ket_contract_ms={:.6},nary_contract_ms={:.6},sum_ms={:.6},value={}", + breakdown.conj_ms, + breakdown.first_contract_ms, + breakdown.env_contract_ms, + breakdown.ket_contract_ms, + breakdown.nary_contract_ms, + breakdown.sum_ms, + format_scalar(value), + ); +} + +fn conj_sites(tt: &TensorTrain) -> Result { + let mut total_elements = 0usize; + for site in 0..tt.len() { + let tensor = tt.tensor(site)?.conj(); + total_elements += tensor.dims().iter().product::(); + black_box(&tensor); + } + Ok(total_elements) +} + +fn format_scalar(value: impl Into) -> String { + let value = value.into(); + if let Some(z) = value.as_c64() { + format!("{:.12e}+{:.12e}im", z.re, z.im) + } else { + format!("{:.12e}", value.real()) + } +} + +fn main() -> Result<()> { + let opts = parse_args()?; + + println!("tensor4all TensorTrain ops benchmark"); + println!(" L: {}", opts.length); + println!(" d: {}", opts.phys_dim); + println!(" zipup L: {}", opts.zipup_length); + println!( + " chis: {}", + opts.chis + .iter() + .map(|chi| chi.to_string()) + .collect::>() + .join(",") + ); + println!(" 
warm-up time: {}", opts.warmup_seconds); + println!(" measurement time: {}", opts.measurement_seconds); + println!(" min samples: {}", opts.min_samples); + println!(" skip zipup: {}", opts.skip_zipup); + println!(" inner only: {}", opts.inner_only); + println!(); + println!("case,params,samples,min_ms,median_ms,mean_ms,max_ms,max_bond,value"); + + for &chi in &opts.chis { + let sites = make_sites(opts.length, opts.phys_dim); + let bra = make_mps(&sites, chi, 0)?; + let ket = make_mps(&sites, chi, opts.length)?; + let bra_conj = preconjugate_sites(&bra)?; + let raw_ctx = EagerContext::with_backend(CpuBackend::with_threads(1)); + let raw_bra = eager_mps_tensors( + &raw_ctx, + make_native_mps_t4a_shapes(opts.length, opts.phys_dim, chi, 0), + ); + let raw_ket = eager_mps_tensors( + &raw_ctx, + make_native_mps_t4a_shapes(opts.length, opts.phys_dim, chi, opts.length), + ); + let raw_configs = RawEagerInnerConfigs::new(); + let mps_params = format!("L_{}_chi_{}_d_{}", opts.length, chi, opts.phys_dim); + + reset_pairwise_contract_profile(); + let profiled_inner_value = bra.inner(&ket)?; + black_box(profiled_inner_value); + eprintln!("pairwise_profile,current_inner,{mps_params}"); + print_and_reset_pairwise_contract_profile(); + + let (pair_profile_value, pair_breakdown) = + inner_sitewise_pair_breakdown_no_sim(&bra, &ket)?; + print_inner_breakdown( + "sitewise_pair_no_sim_once", + &mps_params, + pair_profile_value, + &pair_breakdown, + ); + let (nary_profile_value, nary_breakdown) = + inner_sitewise_nary_breakdown_no_sim(&bra, &ket)?; + print_inner_breakdown( + "sitewise_nary_no_sim_once", + &mps_params, + nary_profile_value, + &nary_breakdown, + ); + + let (inner_value, inner_times) = run_for_seconds( + opts.warmup_seconds, + opts.measurement_seconds, + opts.min_samples, + || Ok(bra.inner(black_box(&ket))?), + )?; + print_result( + "tensor4all_inner_mps", + &mps_params, + &inner_times, + &format_scalar(inner_value), + 0, + ); + + let (raw_eager_value, raw_eager_times) = 
run_for_seconds( + opts.warmup_seconds, + opts.measurement_seconds, + opts.min_samples, + || { + raw_eager_inner_t4a_shapes( + black_box(&raw_bra), + black_box(&raw_ket), + black_box(&raw_configs), + false, + ) + }, + )?; + print_result( + "tenferro_raw_eager_inner_t4a_shapes", + &mps_params, + &raw_eager_times, + &format_scalar(raw_eager_value), + 0, + ); + + let (raw_eager_snapshot_value, raw_eager_snapshot_times) = run_for_seconds( + opts.warmup_seconds, + opts.measurement_seconds, + opts.min_samples, + || { + raw_eager_inner_t4a_shapes( + black_box(&raw_bra), + black_box(&raw_ket), + black_box(&raw_configs), + true, + ) + }, + )?; + print_result( + "tenferro_raw_eager_inner_t4a_shapes_snapshot_outputs", + &mps_params, + &raw_eager_snapshot_times, + &format_scalar(raw_eager_snapshot_value), + 0, + ); + + let (preconj_pair_value, preconj_pair_times) = run_for_seconds( + opts.warmup_seconds, + opts.measurement_seconds, + opts.min_samples, + || inner_sitewise_pair_preconj_no_sim(black_box(&bra_conj), black_box(&ket)), + )?; + print_result( + "tensor4all_inner_mps_sitewise_pair_preconj_no_sim", + &mps_params, + &preconj_pair_times, + &format_scalar(preconj_pair_value), + 0, + ); + + if opts.inner_only { + continue; + } + + let (sitewise_pair_value, sitewise_pair_times) = run_for_seconds( + opts.warmup_seconds, + opts.measurement_seconds, + opts.min_samples, + || inner_sitewise_pair_no_sim(black_box(&bra), black_box(&ket)), + )?; + print_result( + "tensor4all_inner_mps_sitewise_pair_no_sim", + &mps_params, + &sitewise_pair_times, + &format_scalar(sitewise_pair_value), + 0, + ); + + let (sitewise_nary_value, sitewise_nary_times) = run_for_seconds( + opts.warmup_seconds, + opts.measurement_seconds, + opts.min_samples, + || inner_sitewise_nary_no_sim(black_box(&bra), black_box(&ket)), + )?; + print_result( + "tensor4all_inner_mps_sitewise_nary_no_sim", + &mps_params, + &sitewise_nary_times, + &format_scalar(sitewise_nary_value), + 0, + ); + + let (binary_contract_value, 
binary_contract_times) = run_for_seconds( + opts.warmup_seconds, + opts.measurement_seconds, + opts.min_samples, + || inner_sitewise_binary_contract_no_sim(black_box(&bra), black_box(&ket)), + )?; + print_result( + "tensor4all_inner_mps_sitewise_binary_contract_no_sim", + &mps_params, + &binary_contract_times, + &format_scalar(binary_contract_value), + 0, + ); + + let (conj_tt, conj_times) = run_for_seconds( + opts.warmup_seconds, + opts.measurement_seconds, + opts.min_samples, + || Ok(black_box(&bra).conj()), + )?; + print_result( + "tensor4all_conj_mps", + &mps_params, + &conj_times, + "ok", + conj_tt.maxbonddim(), + ); + + let (conj_elements, conj_sites_times) = run_for_seconds( + opts.warmup_seconds, + opts.measurement_seconds, + opts.min_samples, + || conj_sites(black_box(&bra)), + )?; + print_result( + "tensor4all_conj_sites_mps", + &mps_params, + &conj_sites_times, + &conj_elements.to_string(), + 0, + ); + + let (sum, directsum_times) = run_for_seconds( + opts.warmup_seconds, + opts.measurement_seconds, + opts.min_samples, + || Ok(bra.add(black_box(&ket))?), + )?; + let sum_norm = sum.inner(&sum)?; + print_result( + "tensor4all_directsum_mps", + &mps_params, + &directsum_times, + &format_scalar(sum_norm), + sum.maxbonddim(), + ); + + if opts.skip_zipup { + continue; + } + + let zipup_sites_in = make_sites(opts.zipup_length, opts.phys_dim); + let zipup_sites_mid = make_sites(opts.zipup_length, opts.phys_dim); + let zipup_sites_out = make_sites(opts.zipup_length, opts.phys_dim); + let mut mpo_a = make_mpo(&zipup_sites_in, &zipup_sites_mid, chi, 2 * opts.length)?; + let mut mpo_b = make_mpo( + &zipup_sites_mid, + &zipup_sites_out, + chi, + 2 * opts.length + opts.zipup_length, + )?; + mpo_a.orthogonalize_with(0, CanonicalForm::Unitary)?; + mpo_b.orthogonalize_with(0, CanonicalForm::Unitary)?; + + let zipup_options = ContractOptions::zipup() + .with_svd_policy(SvdTruncationPolicy::new(0.0)) + .with_max_rank(chi); + let zipup_params = format!( + 
"L_{}_chi_{}_d_{}_maxdim_{}", + opts.zipup_length, chi, opts.phys_dim, chi + ); + let (zipup_result, zipup_times) = run_for_seconds( + opts.warmup_seconds, + opts.measurement_seconds, + opts.min_samples, + || Ok(mpo_a.contract(black_box(&mpo_b), black_box(&zipup_options))?), + )?; + let zipup_norm = zipup_result.inner(&zipup_result)?; + print_result( + "tensor4all_zipup_mpo_prepared", + &zipup_params, + &zipup_times, + &format_scalar(zipup_norm), + zipup_result.maxbonddim(), + ); + } + + Ok(()) +} diff --git a/benchmarks/rust/inspect_hdf5_mps_inputs.rs b/benchmarks/rust/inspect_hdf5_mps_inputs.rs new file mode 100644 index 00000000..b77128cb --- /dev/null +++ b/benchmarks/rust/inspect_hdf5_mps_inputs.rs @@ -0,0 +1,67 @@ +// Inspect Julia-dumped local linsolve inputs stored in ITensorMPS-compatible HDF5. +// +// Run: +// cargo run -p tensor4all-hdf5 --example inspect_mps_inputs --release -- benchmarks/results/local_linsolve_inputs_N8_b4_o4.h5 + +use std::env; + +use tensor4all_core::TensorDynLen; +use tensor4all_hdf5::{load_itensor, load_mps}; +use tensor4all_itensorlike::TensorTrain; + +fn summarize(name: &str, tt: &TensorTrain) { + let tensors = tt.tensors(); + let siteinds = tt.siteinds(); + let tensor_dims: Vec<_> = tensors.iter().map(|tensor| tensor.dims()).collect(); + let tensor_index_counts: Vec<_> = tensors.iter().map(|tensor| tensor.indices().len()).collect(); + let site_index_counts: Vec<_> = siteinds.iter().map(Vec::len).collect(); + + println!("{name}.length = {}", tt.len()); + println!("{name}.llim = {}", tt.llim()); + println!("{name}.rlim = {}", tt.rlim()); + println!("{name}.bond_dims = {:?}", tt.bond_dims()); + println!("{name}.maxbonddim = {}", tt.maxbonddim()); + println!("{name}.tensor_dims = {:?}", tensor_dims); + println!("{name}.tensor_index_counts = {:?}", tensor_index_counts); + println!("{name}.site_index_counts = {:?}", site_index_counts); + + if let Some(first) = tensors.first() { + println!("{name}.first_tensor_indices = {:?}", 
first.indices()); + } + if let Some(last) = tensors.last() { + println!("{name}.last_tensor_indices = {:?}", last.indices()); + } +} + +fn summarize_raw_tensor(path: &str, name: &str) -> anyhow::Result<()> { + let tensor: TensorDynLen = load_itensor(path, name)?; + println!("{name}.raw_dims = {:?}", tensor.dims()); + println!("{name}.raw_indices = {:?}", tensor.indices()); + Ok(()) +} + +fn main() -> anyhow::Result<()> { + let path = env::args() + .nth(1) + .unwrap_or_else(|| "benchmarks/results/local_linsolve_inputs_N8_b4_o4.h5".to_string()); + + println!("=== Inspect HDF5 MPS inputs (Rust/tensor4all-hdf5) ==="); + println!("path = {path}"); + + let operator_as_mps = load_mps(&path, "operator_as_mps")?; + let rhs = load_mps(&path, "rhs")?; + let init = load_mps(&path, "init")?; + + summarize("operator_as_mps", &operator_as_mps); + summarize("rhs", &rhs); + summarize("init", &init); + + println!("--- Raw HDF5 site tensors ---"); + let last_operator_site = format!("operator_as_mps/MPS[{}]", operator_as_mps.len()); + summarize_raw_tensor(&path, "operator_as_mps/MPS[1]")?; + summarize_raw_tensor(&path, &last_operator_site)?; + summarize_raw_tensor(&path, "rhs/MPS[1]")?; + summarize_raw_tensor(&path, "init/MPS[1]")?; + + Ok(()) +} diff --git a/crates/tensor4all-capi/include/tensor4all_capi.h b/crates/tensor4all-capi/include/tensor4all_capi.h index e150ee4c..96d2a716 100644 --- a/crates/tensor4all-capi/include/tensor4all_capi.h +++ b/crates/tensor4all-capi/include/tensor4all_capi.h @@ -619,11 +619,11 @@ enum t4a_status_code t4a_tensor_contract(const struct t4a_tensor *a, /** * Contract multiple tensors while retaining selected shared indices as output legs. 
*/ -enum t4a_status_code t4a_tensor_contract_multi(const struct t4a_tensor *const *tensors, - size_t n_tensors, - const struct t4a_index *const *retain_indices, - size_t n_retain, - struct t4a_tensor **out); +enum t4a_status_code t4a_tensor_contract_many_retain(const struct t4a_tensor *const *tensors, + size_t n_tensors, + const struct t4a_index *const *retain_indices, + size_t n_retain, + struct t4a_tensor **out); /** * Contract two tensors while retaining selected shared indices as output legs. @@ -1045,9 +1045,9 @@ enum t4a_status_code t4a_treetn_linsolve(const struct t4a_treetn *operator_, const struct t4a_svd_truncation_policy *policy, size_t maxdim, size_t nfullsweeps, - double krylov_tol, - size_t krylov_maxiter, - size_t krylov_dim, + double gmres_tol, + size_t gmres_max_restarts, + size_t gmres_restart_dim, double a0, double a1, double convergence_tol, diff --git a/crates/tensor4all-capi/src/tensor.rs b/crates/tensor4all-capi/src/tensor.rs index 75aa2fb3..203cb522 100644 --- a/crates/tensor4all-capi/src/tensor.rs +++ b/crates/tensor4all-capi/src/tensor.rs @@ -5,8 +5,8 @@ use std::sync::Arc; use num_complex::Complex64; use tensor4all_core::{ - contract_multi_with_options, qr_with, svd_with, AllowedPairs, ContractionOptions, QrOptions, - SvdOptions, SvdTruncationPolicy, + contract_pair, contract_pair_with_options, contract_with_options, qr_with, svd_with, + ContractionOptions, QrOptions, SvdOptions, SvdTruncationPolicy, }; use tensor4all_tensorbackend::Storage; @@ -635,9 +635,7 @@ pub extern "C" fn t4a_tensor_contract( } run_catching(out, || unsafe { - let tensor = (*a) - .inner() - .contract((*b).inner()) + let tensor = contract_pair((*a).inner(), (*b).inner()) .map_err(|err| capi_error(T4A_INVALID_ARGUMENT, err))?; Ok(t4a_tensor::new(tensor)) }) @@ -656,11 +654,8 @@ pub extern "C" fn t4a_tensor_contract_retain( let a = require_tensor(a)?; let b = require_tensor(b)?; let retain_indices = read_indices_from_ptrs(n_retain, retain_indices)?; - let options = - 
ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain_indices); - let result = a - .inner() - .contract_with_options(b.inner(), options) + let options = ContractionOptions::new().with_retain_indices(&retain_indices); + let result = contract_pair_with_options(a.inner(), b.inner(), options) .map_err(|err| capi_error(T4A_INVALID_ARGUMENT, err))?; Ok(t4a_tensor::new(result)) }) @@ -668,7 +663,7 @@ pub extern "C" fn t4a_tensor_contract_retain( /// Contract multiple tensors while retaining selected shared indices as output legs. #[unsafe(no_mangle)] -pub extern "C" fn t4a_tensor_contract_multi( +pub extern "C" fn t4a_tensor_contract_many_retain( tensors: *const *const t4a_tensor, n_tensors: usize, retain_indices: *const *const t4a_index, @@ -678,9 +673,8 @@ pub extern "C" fn t4a_tensor_contract_multi( run_catching(out, || { let tensors = read_tensor_refs(tensors, n_tensors)?; let retain_indices = read_indices_from_ptrs(n_retain, retain_indices)?; - let options = - ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain_indices); - let result = contract_multi_with_options(&tensors, options) + let options = ContractionOptions::new().with_retain_indices(&retain_indices); + let result = contract_with_options(&tensors, options) .map_err(|err| capi_error(T4A_INVALID_ARGUMENT, err))?; Ok(t4a_tensor::new(result)) }) diff --git a/crates/tensor4all-capi/src/tensor/tests/mod.rs b/crates/tensor4all-capi/src/tensor/tests/mod.rs index d1fdee8d..12ee7def 100644 --- a/crates/tensor4all-capi/src/tensor/tests/mod.rs +++ b/crates/tensor4all-capi/src/tensor/tests/mod.rs @@ -5,7 +5,7 @@ use crate::types::{ }; use num_complex::Complex64; use std::ffi::CStr; -use tensor4all_core::AnyScalar; +use tensor4all_core::{AnyScalar, TensorContractionLike}; fn last_error() -> String { let mut len = 0usize; @@ -195,15 +195,15 @@ fn reconstruct_svd( let mut perm = vec![v.indices.len() - 1]; perm.extend(0..(v.indices.len() - 1)); let vh = v.conj().permute(&perm).unwrap(); - 
let svh = s.contract(&vh).unwrap(); + let svh = s.contract_pair(&vh).unwrap(); let sim_bond = s.indices[1].clone(); let bond = v.indices[v.indices.len() - 1].clone(); let svh = svh.replaceind(&sim_bond, &bond).unwrap(); - unsafe { (*u).inner().contract(&svh).unwrap() } + unsafe { (*u).inner().contract_pair(&svh).unwrap() } } fn reconstruct_qr(q: *const t4a_tensor, r: *const t4a_tensor) -> InternalTensor { - unsafe { (*q).inner().contract((*r).inner()).unwrap() } + unsafe { (*q).inner().contract_pair((*r).inner()).unwrap() } } fn internal_tensor_f64(indices: &[*const t4a_index], data: &[f64]) -> InternalTensor { @@ -838,7 +838,7 @@ fn test_tensor_contract_retain_preserves_shared_index() { } #[test] -fn test_tensor_contract_multi_retain_preserves_batch_index() { +fn test_tensor_contract_many_retain_preserves_batch_index() { let i = new_index(2); let j = new_index(3); let k = new_index(2); @@ -864,7 +864,7 @@ fn test_tensor_contract_multi_retain_preserves_batch_index() { let retain = [j as *const t4a_index]; let mut out = std::ptr::null_mut(); assert_eq!( - t4a_tensor_contract_multi( + t4a_tensor_contract_many_retain( tensors.as_ptr(), tensors.len(), retain.as_ptr(), diff --git a/crates/tensor4all-capi/src/treetn.rs b/crates/tensor4all-capi/src/treetn.rs index fd2ec662..094839dd 100644 --- a/crates/tensor4all-capi/src/treetn.rs +++ b/crates/tensor4all-capi/src/treetn.rs @@ -1912,9 +1912,9 @@ pub extern "C" fn t4a_treetn_linsolve( policy: *const t4a_svd_truncation_policy, maxdim: libc::size_t, nfullsweeps: libc::size_t, - krylov_tol: libc::c_double, - krylov_maxiter: libc::size_t, - krylov_dim: libc::size_t, + gmres_tol: libc::c_double, + gmres_max_restarts: libc::size_t, + gmres_restart_dim: libc::size_t, a0: libc::c_double, a1: libc::c_double, convergence_tol: libc::c_double, @@ -1945,20 +1945,23 @@ pub extern "C" fn t4a_treetn_linsolve( run_catching(out, || { require_node(init.inner(), center_vertex)?; - if !krylov_tol.is_finite() || krylov_tol <= 0.0 { + if 
!gmres_tol.is_finite() || gmres_tol <= 0.0 { return Err(capi_error( T4A_INVALID_ARGUMENT, - format!("krylov_tol must be finite and > 0, got {krylov_tol}"), + format!("gmres_tol must be finite and > 0, got {gmres_tol}"), )); } - if krylov_maxiter == 0 { + if gmres_max_restarts == 0 { return Err(capi_error( T4A_INVALID_ARGUMENT, - "krylov_maxiter must be >= 1", + "gmres_max_restarts must be >= 1", )); } - if krylov_dim == 0 { - return Err(capi_error(T4A_INVALID_ARGUMENT, "krylov_dim must be >= 1")); + if gmres_restart_dim == 0 { + return Err(capi_error( + T4A_INVALID_ARGUMENT, + "gmres_restart_dim must be >= 1", + )); } if !a0.is_finite() || !a1.is_finite() { return Err(capi_error( @@ -1989,9 +1992,9 @@ pub extern "C" fn t4a_treetn_linsolve( let mut options = LinsolveOptions::new(nfullsweeps) .with_truncation(TruncationOptions::new()) - .with_krylov_tol(krylov_tol) - .with_krylov_maxiter(krylov_maxiter) - .with_krylov_dim(krylov_dim) + .with_gmres_tol(gmres_tol) + .with_gmres_max_restarts(gmres_max_restarts) + .with_gmres_restart_dim(gmres_restart_dim) .with_coefficients(a0, a1); if let Some(policy) = resolve_svd_policy(policy) { options = options.with_svd_policy(policy); diff --git a/crates/tensor4all-core/Cargo.toml b/crates/tensor4all-core/Cargo.toml index 778ee8ac..aa8ae1f3 100644 --- a/crates/tensor4all-core/Cargo.toml +++ b/crates/tensor4all-core/Cargo.toml @@ -35,6 +35,7 @@ rand.workspace = true rand_distr.workspace = true omeco.workspace = true petgraph.workspace = true +smallvec.workspace = true tenferro.workspace = true tenferro-tensor.workspace = true diff --git a/crates/tensor4all-core/README.md b/crates/tensor4all-core/README.md index 0fc7fdb1..57f71707 100644 --- a/crates/tensor4all-core/README.md +++ b/crates/tensor4all-core/README.md @@ -7,7 +7,7 @@ Core tensor library: Index system, dynamic-rank Tensor, contraction, SVD/QR/LU f - `Index` — flexible index with tags and prime levels - `TensorDynLen` — dynamic-rank tensor with flexible index types - 
`Storage` — dense or diagonal storage for `f64` and `Complex64` -- `contract()` / `contract_multi()` — pairwise and multi-tensor contraction +- `contract()` / `contract_with_options()` — connected tensor-network contraction - `svd()` / `qr()` — factorizations with truncation support ## Example diff --git a/crates/tensor4all-core/examples/benchmark_tensor_ops.rs b/crates/tensor4all-core/examples/benchmark_tensor_ops.rs new file mode 100644 index 00000000..9f1a5ac4 --- /dev/null +++ b/crates/tensor4all-core/examples/benchmark_tensor_ops.rs @@ -0,0 +1 @@ +include!("../../../benchmarks/rust/benchmark_tensor_ops.rs"); diff --git a/crates/tensor4all-core/src/any_scalar.rs b/crates/tensor4all-core/src/any_scalar.rs index 57160815..e4de9def 100644 --- a/crates/tensor4all-core/src/any_scalar.rs +++ b/crates/tensor4all-core/src/any_scalar.rs @@ -678,6 +678,9 @@ impl AnyScalar { /// assert_eq!(scalar.as_c64().map(|z| (z.re, z.im)), Some((3.0, 4.0))); /// ``` pub fn try_conj(&self) -> Result { + if !self.tracks_grad() { + return Ok(Self::from_backend_scalar(self.to_backend_scalar().conj())); + } Self::from_eager_unary(self, "conj", |tensor| tensor.conj()) } @@ -778,7 +781,7 @@ impl AnyScalar { /// assert!(scalar.is_real()); /// ``` pub fn sqrt(&self) -> Self { - if self.is_complex() || self.real() < 0.0 { + if !self.tracks_grad() || self.is_complex() || self.real() < 0.0 { Self::from_backend_scalar(self.to_backend_scalar().sqrt()) } else { Self::from_eager_unary(self, "sqrt", |tensor| tensor.sqrt()) @@ -869,14 +872,29 @@ impl AnyScalar { } pub(crate) fn try_add(&self, rhs: &Self) -> Result { + if !self.tracks_grad() && !rhs.tracks_grad() { + return Ok(Self::from_backend_scalar( + self.to_backend_scalar() + rhs.to_backend_scalar(), + )); + } Self::from_eager_binary(self, rhs, "add", |lhs, rhs| lhs.add(rhs)) } pub(crate) fn try_mul(&self, rhs: &Self) -> Result { + if !self.tracks_grad() && !rhs.tracks_grad() { + return Ok(Self::from_backend_scalar( + self.to_backend_scalar() * 
rhs.to_backend_scalar(), + )); + } Self::from_eager_binary(self, rhs, "mul", |lhs, rhs| lhs.mul(rhs)) } pub(crate) fn try_div(&self, rhs: &Self) -> Result { + if !self.tracks_grad() && !rhs.tracks_grad() { + return Ok(Self::from_backend_scalar( + self.to_backend_scalar() / rhs.to_backend_scalar(), + )); + } if self.as_tensor()?.as_native()?.dtype() == rhs.as_tensor()?.as_native()?.dtype() { Self::from_eager_binary(self, rhs, "div", |lhs, rhs| lhs.div(rhs)) } else { @@ -887,6 +905,9 @@ impl AnyScalar { } pub(crate) fn try_neg(&self) -> Result { + if !self.tracks_grad() { + return Ok(Self::from_backend_scalar(-self.to_backend_scalar())); + } Self::from_eager_unary(self, "neg", |tensor| tensor.neg()) } } @@ -1176,3 +1197,33 @@ impl fmt::Debug for AnyScalar { .finish() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn non_grad_scalar_arithmetic_uses_plain_values() { + let a = AnyScalar::new_real(3.0); + let b = AnyScalar::new_real(4.0); + + let value = ((a.clone() + b.clone()) * b.clone() - AnyScalar::new_real(8.0)) + / AnyScalar::new_real(2.0); + + assert_eq!(value.as_f64(), Some(10.0)); + assert!(!value.tracks_grad()); + assert!(value.as_tensor().is_ok()); + } + + #[test] + fn tracked_scalar_arithmetic_preserves_autodiff() { + let x = AnyScalar::new_real(2.0).enable_grad().unwrap(); + let y = &x * &x; + + assert!(y.tracks_grad()); + y.backward().unwrap(); + + let grad = x.grad().unwrap().unwrap(); + assert_eq!(grad.as_f64(), Some(4.0)); + } +} diff --git a/crates/tensor4all-core/src/block_tensor.rs b/crates/tensor4all-core/src/block_tensor.rs index c7dd218c..4e53310b 100644 --- a/crates/tensor4all-core/src/block_tensor.rs +++ b/crates/tensor4all-core/src/block_tensor.rs @@ -34,8 +34,9 @@ use std::collections::HashSet; use crate::any_scalar::AnyScalar; use crate::tensor_index::TensorIndex; use crate::tensor_like::{ - AllowedPairs, DirectSumResult, FactorizeError, FactorizeOptions, FactorizeResult, - LinearizationOrder, TensorLike, + DirectSumResult, 
FactorizeError, FactorizeOptions, FactorizeResult, LinearizationOrder, + TensorConstructionLike, TensorContractionLike, TensorFactorizationLike, TensorLike, + TensorVectorSpace, }; use anyhow::Result; @@ -412,7 +413,7 @@ impl TensorIndex for BlockTensor { // TensorLike implementation // ============================================================================ -impl TensorLike for BlockTensor { +impl TensorVectorSpace for BlockTensor { // ------------------------------------------------------------------------ // Vector space operations (required for GMRES) // ------------------------------------------------------------------------ @@ -476,41 +477,22 @@ impl TensorLike for BlockTensor { Ok(sum) } - fn conj(&self) -> Self { - let conjugated: Vec = self.blocks.iter().map(|b| b.conj()).collect(); - Self { - blocks: conjugated, - shape: self.shape, - } - } - fn validate(&self) -> Result<()> { self.validate_indices() } +} +impl TensorContractionLike for BlockTensor { // ------------------------------------------------------------------------ - // Operations not supported for BlockTensor (return error, don't panic) + // Tensor network operations // ------------------------------------------------------------------------ - fn factorize( - &self, - _left_inds: &[::Index], - _options: &FactorizeOptions, - ) -> std::result::Result, FactorizeError> { - Err(FactorizeError::ComputationError(anyhow::anyhow!( - "BlockTensor does not support factorize" - ))) - } - - fn factorize_full_rank( - &self, - _left_inds: &[::Index], - _alg: crate::FactorizeAlg, - _canonical: crate::Canonical, - ) -> std::result::Result, FactorizeError> { - Err(FactorizeError::ComputationError(anyhow::anyhow!( - "BlockTensor does not support factorize_full_rank" - ))) + fn conj(&self) -> Self { + let conjugated: Vec = self.blocks.iter().map(|b| b.conj()).collect(); + Self { + blocks: conjugated, + shape: self.shape, + } } fn direct_sum( @@ -546,14 +528,35 @@ impl TensorLike for BlockTensor { }) } - fn 
contract(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> Result { + fn contract(_tensors: &[&Self]) -> Result { anyhow::bail!("BlockTensor does not support contract") } +} - fn contract_connected(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> Result { - anyhow::bail!("BlockTensor does not support contract_connected") +impl TensorFactorizationLike for BlockTensor { + fn factorize( + &self, + _left_inds: &[::Index], + _options: &FactorizeOptions, + ) -> std::result::Result, FactorizeError> { + Err(FactorizeError::ComputationError(anyhow::anyhow!( + "BlockTensor does not support factorize" + ))) } + fn factorize_full_rank( + &self, + _left_inds: &[::Index], + _alg: crate::FactorizeAlg, + _canonical: crate::Canonical, + ) -> std::result::Result, FactorizeError> { + Err(FactorizeError::ComputationError(anyhow::anyhow!( + "BlockTensor does not support factorize_full_rank" + ))) + } +} + +impl TensorConstructionLike for BlockTensor { fn diagonal( _input_index: &::Index, _output_index: &::Index, diff --git a/crates/tensor4all-core/src/block_tensor/tests/mod.rs b/crates/tensor4all-core/src/block_tensor/tests/mod.rs index 51a8320d..f000b73e 100644 --- a/crates/tensor4all-core/src/block_tensor/tests/mod.rs +++ b/crates/tensor4all-core/src/block_tensor/tests/mod.rs @@ -3,6 +3,7 @@ use crate::defaults::tensordynlen::TensorDynLen; use crate::defaults::DynIndex; use crate::index_like::IndexLike; use crate::krylov::{gmres, GmresOptions}; +use crate::tensor_like::TensorContractionLike; /// Helper to create a 1D tensor (vector) with given data and shared index. 
fn make_vector_with_index(data: Vec, idx: &DynIndex) -> TensorDynLen { @@ -80,13 +81,14 @@ fn test_fuse_indices_delegates_to_blocks_and_preserves_shape() { TensorDynLen::from_dense(vec![i.clone(), j.clone()], vec![5.0, 6.0, 7.0, 8.0]).unwrap(); let block = BlockTensor::new(vec![block_a, block_b], (1, 2)).unwrap(); - let fused_block = as crate::tensor_like::TensorLike>::fuse_indices( - &block, - &[i, j], - fused, - crate::tensor_like::LinearizationOrder::ColumnMajor, - ) - .unwrap(); + let fused_block = + as crate::tensor_like::TensorContractionLike>::fuse_indices( + &block, + &[i, j], + fused, + crate::tensor_like::LinearizationOrder::ColumnMajor, + ) + .unwrap(); assert_eq!(fused_block.shape(), (1, 2)); assert_eq!(fused_block.num_blocks(), 2); @@ -647,17 +649,7 @@ fn test_contract_unsupported() { let b1 = make_vector_with_index(vec![1.0, 2.0], &idx); let block = BlockTensor::new(vec![b1], (1, 1)).unwrap(); - let result = BlockTensor::::contract(&[&block], AllowedPairs::All); - assert!(result.is_err()); -} - -#[test] -fn test_contract_connected_unsupported() { - let idx = DynIndex::new_dyn(2); - let b1 = make_vector_with_index(vec![1.0, 2.0], &idx); - let block = BlockTensor::new(vec![b1], (1, 1)).unwrap(); - - let result = BlockTensor::::contract_connected(&[&block], AllowedPairs::All); + let result = BlockTensor::::contract(&[&block]); assert!(result.is_err()); } diff --git a/crates/tensor4all-core/src/defaults/contract.rs b/crates/tensor4all-core/src/defaults/contract.rs index 44256ec3..85e3bd09 100644 --- a/crates/tensor4all-core/src/defaults/contract.rs +++ b/crates/tensor4all-core/src/defaults/contract.rs @@ -8,8 +8,8 @@ //! //! # Main Functions //! -//! - [`contract_multi`]: Contracts tensors, handling disconnected components via outer product -//! - [`contract_connected`]: Contracts tensors that must form a connected graph +//! - [`contract`]: Contracts one connected tensor network +//! 
- [`contract_with_options`]: Contracts one connected tensor network with retained indices //! //! # Diag Tensor Handling //! @@ -27,14 +27,13 @@ use std::time::{Duration, Instant}; use anyhow::Result; use petgraph::algo::connected_components; use petgraph::prelude::*; -use tenferro::eager_einsum::eager_einsum_ad; +use tenferro::eager_tensor::einsum_subscripts as eager_einsum_ad; +use tenferro::EinsumSubscripts; use tensor4all_tensorbackend::{einsum_native_tensors, einsum_native_tensors_owned}; use crate::defaults::{DynId, DynIndex, TensorDynLen}; use crate::index_like::IndexLike; -use crate::tensor_like::AllowedPairs; - #[derive(Debug, Clone, Hash, PartialEq, Eq)] struct ContractOperandSignature { dims: Vec, @@ -95,7 +94,7 @@ pub fn print_and_reset_contract_profile() { state.borrow_mut().clear(); entries.sort_by_key(|(_, entry)| Reverse(entry.total_time)); - eprintln!("=== contract_multi Profile ==="); + eprintln!("=== contract Profile ==="); for (idx, (signature, entry)) in entries.into_iter().take(20).enumerate() { let operands = signature .operands @@ -129,34 +128,30 @@ pub fn print_and_reset_contract_profile() { /// Options for multi-tensor contraction. /// -/// Use this to choose which tensor pairs may contract and which shared indices -/// should be retained in the output instead of summed over. +/// Use this to choose which shared indices should be retained in the output +/// instead of summed over. 
/// /// # Examples /// /// ``` -/// use tensor4all_core::{AllowedPairs, ContractionOptions, DynIndex}; +/// use tensor4all_core::{ContractionOptions, DynIndex}; /// /// let batch = DynIndex::new_dyn(2); /// let retain = [batch.clone()]; -/// let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain); +/// let options = ContractionOptions::new().with_retain_indices(&retain); /// -/// assert!(matches!(options.allowed, AllowedPairs::All)); /// assert_eq!(options.retain_indices, &[batch]); /// ``` #[derive(Clone, Copy, Debug)] pub struct ContractionOptions<'a> { - /// Contractability policy for tensor pairs. - pub allowed: AllowedPairs<'a>, /// Indices that should remain in the result even if they appear more than once. pub retain_indices: &'a [DynIndex], } impl<'a> ContractionOptions<'a> { /// Create contraction options with no retained indices. - pub fn new(allowed: AllowedPairs<'a>) -> Self { + pub fn new() -> Self { Self { - allowed, retain_indices: &[], } } @@ -168,200 +163,234 @@ impl<'a> ContractionOptions<'a> { } } -/// Contract multiple tensors into a single tensor, handling disconnected components. -/// -/// This function automatically handles disconnected tensor graphs by: -/// 1. Finding connected components based on contractable indices -/// 2. Contracting each connected component separately -/// 3. Combining results using outer product -/// -/// # Arguments -/// * `tensors` - Slice of tensors to contract -/// * `allowed` - Specifies which tensor pairs can have their indices contracted -/// -/// # Returns -/// The result of contracting all tensors over allowed contractable indices. -/// If tensors form disconnected components, they are combined via outer product. 
-/// -/// # Behavior by N -/// - N=0: Error -/// - N=1: Clone of input -/// - N>=2: Contract connected components, combine with outer product +impl Default for ContractionOptions<'_> { + fn default() -> Self { + Self::new() + } +} + +/// Options for pairwise tensor contraction. /// -/// # Errors -/// - `AllowedPairs::Specified` contains a pair with no contractable indices +/// The conjugation flags are semantically equivalent to contracting +/// `lhs.conj()` or `rhs.conj()`, but allow implementations to pass conjugation +/// to the backend without materializing a conjugated tensor. /// /// # Examples /// /// ``` -/// use tensor4all_core::{TensorDynLen, DynIndex, contract_multi, AllowedPairs}; +/// use num_complex::Complex64; +/// use tensor4all_core::{ +/// contract_pair, contract_pair_with_operand_options, DynIndex, +/// PairwiseContractionOptions, TensorDynLen, +/// }; /// -/// // A[i, j] and B[j, k] share index j — contract to get C[i, k] /// let i = DynIndex::new_dyn(2); -/// let j = DynIndex::new_dyn(3); -/// let k = DynIndex::new_dyn(4); -/// -/// let a = TensorDynLen::from_dense( -/// vec![i.clone(), j.clone()], -/// vec![1.0_f64; 6], +/// let lhs = TensorDynLen::from_dense( +/// vec![i.clone()], +/// vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, -1.0)], /// ).unwrap(); -/// let b = TensorDynLen::from_dense( -/// vec![j.clone(), k.clone()], -/// vec![1.0_f64; 12], +/// let rhs = TensorDynLen::from_dense( +/// vec![i], +/// vec![Complex64::new(2.0, 0.5), Complex64::new(-1.0, 4.0)], /// ).unwrap(); /// -/// let c = contract_multi(&[&a, &b], AllowedPairs::All).unwrap(); -/// assert_eq!(c.dims(), vec![2, 4]); +/// let options = PairwiseContractionOptions::new().with_lhs_conj(true); +/// let flagged = contract_pair_with_operand_options(&lhs, &rhs, options).unwrap(); +/// let materialized = contract_pair(&lhs.conj(), &rhs).unwrap(); +/// +/// assert!((flagged.sum().unwrap() - materialized.sum().unwrap()).abs() < 1e-12); /// ``` -pub fn contract_multi( - 
tensors: &[&TensorDynLen], - allowed: AllowedPairs<'_>, -) -> Result { - contract_multi_with_options(tensors, ContractionOptions::new(allowed)) +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub struct PairwiseContractionOptions { + /// Whether to conjugate the left operand before contraction. + pub lhs_conj: bool, + /// Whether to conjugate the right operand before contraction. + pub rhs_conj: bool, +} + +impl PairwiseContractionOptions { + /// Create pairwise contraction options with no operand conjugation. + /// + /// # Examples + /// + /// ``` + /// use tensor4all_core::PairwiseContractionOptions; + /// + /// let options = PairwiseContractionOptions::new(); + /// assert!(!options.lhs_conj); + /// assert!(!options.rhs_conj); + /// ``` + pub fn new() -> Self { + Self::default() + } + + /// Set whether the left operand is conjugated during contraction. + /// + /// # Examples + /// + /// ``` + /// use tensor4all_core::PairwiseContractionOptions; + /// + /// let options = PairwiseContractionOptions::new().with_lhs_conj(true); + /// assert!(options.lhs_conj); + /// assert!(!options.rhs_conj); + /// ``` + pub fn with_lhs_conj(mut self, lhs_conj: bool) -> Self { + self.lhs_conj = lhs_conj; + self + } + + /// Set whether the right operand is conjugated during contraction. + /// + /// # Examples + /// + /// ``` + /// use tensor4all_core::PairwiseContractionOptions; + /// + /// let options = PairwiseContractionOptions::new().with_rhs_conj(true); + /// assert!(!options.lhs_conj); + /// assert!(options.rhs_conj); + /// ``` + pub fn with_rhs_conj(mut self, rhs_conj: bool) -> Self { + self.rhs_conj = rhs_conj; + self + } + + pub(crate) fn has_conj(self) -> bool { + self.lhs_conj || self.rhs_conj + } } -/// Contract multiple tensors into a single tensor with additional options. +/// Contract a connected tensor network with the default semantics. /// -/// This behaves like [`contract_multi`] but also allows selected shared indices -/// to be retained in the output. 
+/// This is the normal public entry point for N-ary tensor contraction. It +/// contracts all common contractable indices and requires the input tensors to +/// form one connected tensor graph. Disconnected inputs are rejected so missing +/// links do not silently become outer products. /// -/// # Arguments -/// * `tensors` - Slice of tensors to contract -/// * `options` - Pair-selection policy and retained indices +/// Use explicit [`outer_product`] calls when an outer product of disconnected +/// components is intentional. +pub fn contract(tensors: &[&TensorDynLen]) -> Result { + contract_with_options(tensors, ContractionOptions::new()) +} + +/// Contract a connected tensor network with advanced options. +pub fn contract_with_options( + tensors: &[&TensorDynLen], + options: ContractionOptions<'_>, +) -> Result { + contract_with_options_impl(tensors, options) +} + +/// Contract owned tensors with the default connected-network semantics. +pub fn contract_owned(tensors: Vec) -> Result { + contract_owned_with_options(tensors, ContractionOptions::new()) +} + +/// Contract owned tensors with advanced connected-network options. +pub fn contract_owned_with_options( + tensors: Vec, + options: ContractionOptions<'_>, +) -> Result { + let tensor_refs = tensors.iter().collect::>(); + let components = + find_tensor_connected_components_with_retained(&tensor_refs, options.retain_indices); + if components.len() > 1 { + return Err(anyhow::anyhow!( + "Tensors form disconnected components; use explicit outer_product operations for an intentional disconnected product" + )); + } + drop(tensor_refs); + contract_owned_with_options_impl(tensors, options) +} + +/// Contract two tensors with the default pairwise semantics. /// -/// # Returns -/// The contracted tensor, possibly with retained shared indices in the result. +/// This is the concrete `TensorDynLen` entry point for binary contraction. 
It +/// contracts all common indices and preserves the pairwise structured fast +/// paths used by [`TensorContractionLike::contract_pair`]. +pub fn contract_pair(lhs: &TensorDynLen, rhs: &TensorDynLen) -> Result { + lhs.try_contract_pairwise_default_with_options(rhs, PairwiseContractionOptions::new()) +} + +/// Contract two tensors with operand-level conjugation options. /// -/// # Errors -/// Returns an error if: -/// - no tensors are provided -/// - `AllowedPairs::Specified` contains a pair with no contractable indices -/// - a retained index does not appear in the inputs -/// - a shared internal label has inconsistent dimensions +/// This has the same index semantics as [`contract_pair`], with optional +/// conjugation applied to either operand before matching and contracting common +/// indices. Implementations may pass conjugation to the backend to avoid +/// materializing conjugated payloads. /// /// # Examples /// /// ``` -/// use tensor4all_core::{contract_multi_with_options, AllowedPairs, ContractionOptions, DynIndex, TensorDynLen}; +/// use num_complex::Complex64; +/// use tensor4all_core::{ +/// contract_pair, contract_pair_with_operand_options, DynIndex, +/// PairwiseContractionOptions, TensorDynLen, +/// }; /// /// let i = DynIndex::new_dyn(2); -/// let j = DynIndex::new_dyn(3); -/// let k = DynIndex::new_dyn(4); +/// let lhs = TensorDynLen::from_dense( +/// vec![i.clone()], +/// vec![Complex64::new(1.0, 1.0), Complex64::new(0.0, 2.0)], +/// ).unwrap(); +/// let rhs = TensorDynLen::from_dense( +/// vec![i], +/// vec![Complex64::new(2.0, 0.0), Complex64::new(3.0, -1.0)], +/// ).unwrap(); +/// +/// let flagged = contract_pair_with_operand_options( +/// &lhs, +/// &rhs, +/// PairwiseContractionOptions::new().with_lhs_conj(true), +/// ).unwrap(); +/// let materialized = contract_pair(&lhs.conj(), &rhs).unwrap(); /// -/// let a = TensorDynLen::from_dense(vec![i.clone(), j.clone()], vec![1.0_f64; 6]).unwrap(); -/// let b = 
TensorDynLen::from_dense(vec![j.clone(), k.clone()], vec![1.0_f64; 12]).unwrap(); -/// let retain_indices = [j.clone()]; -/// let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain_indices); -/// let c = contract_multi_with_options(&[&a, &b], options).unwrap(); -/// assert_eq!(c.dims(), vec![2, 3, 4]); +/// assert!((flagged.sum().unwrap() - materialized.sum().unwrap()).abs() < 1e-12); /// ``` -pub fn contract_multi_with_options( - tensors: &[&TensorDynLen], - options: ContractionOptions<'_>, +pub fn contract_pair_with_operand_options( + lhs: &TensorDynLen, + rhs: &TensorDynLen, + options: PairwiseContractionOptions, ) -> Result { - match tensors.len() { - 0 => Err(anyhow::anyhow!("No tensors to contract")), - _ => { - validate_retained_indices_exist(tensors, options.retain_indices)?; - if tensors.len() == 1 { - return Ok((*tensors[0]).clone()); - } - - // Validate AllowedPairs::Specified pairs have contractable indices - if let AllowedPairs::Specified(pairs) = options.allowed { - for &(i, j) in pairs { - if !has_contractable_indices(tensors[i], tensors[j]) { - return Err(anyhow::anyhow!( - "Specified pair ({}, {}) has no contractable indices", - i, - j - )); - } - } - } + lhs.try_contract_pairwise_default_with_options(rhs, options) +} - // Find connected components - let components = find_tensor_connected_components_with_retained( - tensors, - options.allowed, - options.retain_indices, - ); +/// Contract two tensors with explicit contraction options. 
+pub fn contract_pair_with_options( + lhs: &TensorDynLen, + rhs: &TensorDynLen, + options: ContractionOptions<'_>, +) -> Result { + contract_with_options(&[lhs, rhs], options) +} - if components.len() == 1 { - // All tensors connected - use optimized contraction (skip connectivity check) - contract_multi_impl(tensors, options) - } else { - // Multiple components - contract each and combine with outer product - let mut results: Vec = Vec::new(); - for component in &components { - let component_tensors: Vec<&TensorDynLen> = - component.iter().map(|&i| tensors[i]).collect(); - let component_retain_indices = - retained_indices_for_component(tensors, component, options.retain_indices); - - // Remap AllowedPairs for the component (connectivity already verified) - let remapped_allowed = remap_allowed_pairs(options.allowed, component); - let component_options = ContractionOptions { - allowed: remapped_allowed.as_ref(), - retain_indices: &component_retain_indices, - }; - let contracted = contract_multi_impl(&component_tensors, component_options)?; - results.push(contracted); - } +/// Contract two tensors along explicitly specified index pairs. +pub fn tensordot( + lhs: &TensorDynLen, + rhs: &TensorDynLen, + pairs: &[(DynIndex, DynIndex)], +) -> Result { + lhs.try_tensordot_pairwise_explicit(rhs, pairs) +} - // Combine with outer product - let mut results_iter = results.into_iter(); - let Some(mut result) = results_iter.next() else { - return Err(anyhow::anyhow!("No contracted components produced")); - }; - for other in results_iter { - result = result.outer_product(&other)?; - } - Ok(result) - } - } - } +/// Compute the outer product of two tensors. +/// +/// This is an explicit tensor product, not a dense-only operation. Compact +/// structured storage is preserved when the operand layouts allow it. +pub fn outer_product(lhs: &TensorDynLen, rhs: &TensorDynLen) -> Result { + lhs.try_outer_product_pairwise(rhs) } /// Contract multiple owned tensors into a single tensor. 
/// -/// This is the consuming counterpart to [`contract_multi_with_options`]. It +/// This is the consuming implementation for [`contract_owned_with_options`]. It /// preserves the same contraction semantics while allowing eligible non-AD /// dense inputs to use tenferro's owned eager einsum executor. When any input /// tracks gradients, or when compact structured metadata needs the borrowed /// path, this function falls back to the shared borrowed execution so semantics /// and reverse-mode AD remain intact. -/// -/// # Arguments -/// * `tensors` - Owned tensors to contract. -/// * `options` - Pair-selection policy and retained indices. -/// -/// # Returns -/// The contracted tensor, with retained shared indices preserved in the output. -/// -/// # Errors -/// Returns an error for the same conditions as -/// [`contract_multi_with_options`], including empty input, invalid retained -/// indices, and incompatible contraction pairs. -/// -/// # Examples -/// -/// ``` -/// use tensor4all_core::{contract_multi_owned, contract_multi_with_options, AllowedPairs, ContractionOptions, DynIndex, TensorDynLen}; -/// -/// let i = DynIndex::new_dyn(2); -/// let j = DynIndex::new_dyn(3); -/// let k = DynIndex::new_dyn(4); -/// let a = TensorDynLen::from_dense(vec![i.clone(), j.clone()], vec![1.0_f64; 6]).unwrap(); -/// let b = TensorDynLen::from_dense(vec![j.clone(), k.clone()], vec![1.0_f64; 12]).unwrap(); -/// let options = ContractionOptions::new(AllowedPairs::All); -/// -/// let owned = contract_multi_owned(vec![a.clone(), b.clone()], options).unwrap(); -/// let borrowed = contract_multi_with_options(&[&a, &b], options).unwrap(); -/// assert_eq!(owned.indices(), borrowed.indices()); -/// assert_eq!(owned.to_vec::().unwrap(), borrowed.to_vec::().unwrap()); -/// ``` -pub fn contract_multi_owned( +fn contract_owned_with_options_impl( tensors: Vec, options: ContractionOptions<'_>, ) -> Result { @@ -379,33 +408,22 @@ pub fn contract_multi_owned( return Ok(tensor); } - if let 
AllowedPairs::Specified(pairs) = options.allowed { - for &(i, j) in pairs { - if !has_contractable_indices(tensor_refs[i], tensor_refs[j]) { - return Err(anyhow::anyhow!( - "Specified pair ({}, {}) has no contractable indices", - i, - j - )); - } - } - } - let requires_borrowed_path = tensor_refs.iter().any(|tensor| tensor.tracks_grad()) || tensor_refs .iter() .any(|tensor| !has_dense_axis_classes(tensor)); if requires_borrowed_path { - return contract_multi_with_options(&tensor_refs, options); + return contract_with_options(&tensor_refs, options); } let components = find_tensor_connected_components_with_retained( &tensor_refs, - options.allowed, options.retain_indices, ); if components.len() > 1 { - return contract_multi_with_options(&tensor_refs, options); + return Err(anyhow::anyhow!( + "Tensors form disconnected components; use explicit outer_product operations for an intentional disconnected product" + )); } let mut diag_uf = AxisUnionFind::new(); @@ -440,104 +458,7 @@ fn has_dense_axis_classes(tensor: &TensorDynLen) -> bool { .eq(0..tensor.indices().len()) } -/// Contract multiple tensors that form a connected graph. -/// -/// Uses einsum optimization via tensorbackend. -/// -/// # Arguments -/// * `tensors` - Slice of tensors to contract (must form a connected graph) -/// * `allowed` - Specifies which tensor pairs can have their indices contracted -/// -/// # Returns -/// The result of contracting all tensors over allowed contractable indices. -/// -/// # Connectivity Requirement -/// All tensors must form a connected graph through contractable indices. -/// Two tensors are connected if they share a contractable index (same ID, dual direction). -/// If the tensors form disconnected components, this function returns an error. -/// -/// Use [`contract_multi`] if you want automatic handling of disconnected components. 
-/// -/// # Behavior by N -/// - N=0: Error -/// - N=1: Clone of input -/// - N>=2: Optimized order via the tensorbackend einsum path -/// -/// # Examples -/// -/// ``` -/// use tensor4all_core::{TensorDynLen, DynIndex, contract_connected, AllowedPairs}; -/// -/// // A[i, j] contracted with B[j, k] -/// let i = DynIndex::new_dyn(2); -/// let j = DynIndex::new_dyn(3); -/// let k = DynIndex::new_dyn(4); -/// -/// let a = TensorDynLen::from_dense( -/// vec![i.clone(), j.clone()], -/// vec![1.0_f64; 6], -/// ).unwrap(); -/// let b = TensorDynLen::from_dense( -/// vec![j.clone(), k.clone()], -/// vec![1.0_f64; 12], -/// ).unwrap(); -/// -/// let c = contract_connected(&[&a, &b], AllowedPairs::All).unwrap(); -/// assert_eq!(c.dims(), vec![2, 4]); -/// ``` -pub fn contract_connected( - tensors: &[&TensorDynLen], - allowed: AllowedPairs<'_>, -) -> Result { - contract_connected_with_options(tensors, ContractionOptions::new(allowed)) -} - -/// Contract a connected tensor network with additional options. -/// -/// This behaves like [`contract_connected`] but also allows selected shared -/// indices to be retained in the output. -/// -/// # Arguments -/// * `tensors` - Slice of tensors to contract -/// * `options` - Pair-selection policy and retained indices -/// -/// # Returns -/// The contracted tensor. -/// -/// # Errors -/// Returns an error if the tensors are disconnected, no tensors are provided, -/// or retained indices are invalid. 
-/// -/// # Examples -/// -/// ``` -/// use tensor4all_core::{ -/// contract_connected_with_options, AllowedPairs, ContractionOptions, DynIndex, TensorDynLen, -/// }; -/// -/// let batch = DynIndex::new_dyn(2); -/// let i = DynIndex::new_dyn(2); -/// let k = DynIndex::new_dyn(3); -/// let j = DynIndex::new_dyn(2); -/// -/// let a = TensorDynLen::from_dense( -/// vec![batch.clone(), i.clone(), k.clone()], -/// vec![1.0_f64; 12], -/// ) -/// .unwrap(); -/// let b = TensorDynLen::from_dense( -/// vec![batch.clone(), k, j.clone()], -/// vec![1.0_f64; 12], -/// ) -/// .unwrap(); -/// let retain = [batch.clone()]; -/// let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain); -/// -/// let c = contract_connected_with_options(&[&a, &b], options).unwrap(); -/// assert_eq!(c.indices(), &[batch, i, j]); -/// assert_eq!(c.to_vec::().unwrap(), vec![3.0; 8]); -/// ``` -pub fn contract_connected_with_options( +fn contract_with_options_impl( tensors: &[&TensorDynLen], options: ContractionOptions<'_>, ) -> Result { @@ -550,11 +471,8 @@ pub fn contract_connected_with_options( } // Check connectivity first - let components = find_tensor_connected_components_with_retained( - tensors, - options.allowed, - options.retain_indices, - ); + let components = + find_tensor_connected_components_with_retained(tensors, options.retain_indices); if components.len() > 1 { return Err(anyhow::anyhow!( "Disconnected tensor network: {} components found", @@ -562,7 +480,7 @@ pub fn contract_connected_with_options( )); } // Connectivity verified - skip check in impl - contract_multi_impl(tensors, options) + contract_impl(tensors, options) } } } @@ -727,7 +645,7 @@ pub fn collect_sizes(tensors: &[&TensorDynLen], uf: &mut AxisUnionFind) -> HashM /// /// This implementation preserves storage type: if all inputs are F64, the result /// is F64; if any input is C64, the result is C64. 
-fn contract_multi_impl( +fn contract_impl( tensors: &[&TensorDynLen], options: ContractionOptions<'_>, ) -> Result { @@ -739,7 +657,7 @@ fn contract_multi_impl( // 2. Build the contraction plan from internal labels. let plan = build_contraction_plan(tensors, options, &mut diag_uf)?; - // Note: Connectivity check is done by caller (contract_multi or contract_connected) + // Note: Connectivity check is done by caller. // via find_tensor_connected_components before calling this function // 3. Build sizes from unique internal IDs. @@ -830,7 +748,7 @@ fn execute_contraction_plan( }; let mut result = (*first).clone(); for tensor in iter { - result = result.try_contract_pairwise_default(tensor)?; + result = contract_pair(&result, tensor)?; } return Ok(result); } @@ -867,28 +785,26 @@ fn execute_contraction_plan( fn build_einsum_subscripts_from_usize_ids( input_ids: &[Vec], output_ids: &[usize], -) -> Result { - fn ids_to_subscript(ids: &[usize]) -> Result { - const LETTERS: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; - let mut out = String::with_capacity(ids.len()); - for &id in ids { - let letter = LETTERS.get(id).ok_or_else(|| { - anyhow::anyhow!("einsum label {id} exceeds supported label range") - })?; - out.push(char::from(*letter)); - } - Ok(out) - } - +) -> Result { let inputs = input_ids .iter() - .map(|ids| ids_to_subscript(ids)) + .map(|ids| { + ids.iter() + .map(|&id| { + u32::try_from(id) + .map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range")) + }) + .collect::>>() + }) + .collect::>>()?; + let output = output_ids + .iter() + .map(|&id| { + u32::try_from(id).map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range")) + }) .collect::>>()?; - Ok(format!( - "{}->{}", - inputs.join(","), - ids_to_subscript(output_ids)? - )) + let input_refs = inputs.iter().map(Vec::as_slice).collect::>(); + Ok(EinsumSubscripts::new(&input_refs, &output)) } /// A contraction plan with internal labels and result ordering. 
@@ -907,7 +823,7 @@ fn build_contraction_plan( ) -> Result { let retained_indices: HashSet = options.retain_indices.iter().cloned().collect(); let (input_ids, internal_id_to_original) = - build_internal_ids(tensors, options.allowed, diag_uf, &retained_indices)?; + build_internal_ids(tensors, diag_uf, &retained_indices)?; let mut counts: HashMap = HashMap::new(); for ids in &input_ids { @@ -978,28 +894,6 @@ fn validate_retained_indices_exist( Ok(()) } -fn retained_indices_for_component( - tensors: &[&TensorDynLen], - component: &[usize], - retain_indices: &[DynIndex], -) -> Vec { - let mut seen = HashSet::new(); - let mut retained = Vec::new(); - for retain in retain_indices { - if seen.insert(retain.clone()) - && component.iter().any(|&tensor_idx| { - tensors[tensor_idx] - .indices() - .iter() - .any(|idx| idx == retain) - }) - { - retained.push(retain.clone()); - } - } - retained -} - fn validate_unique_output_indices(indices: &[DynIndex]) -> Result<()> { let mut seen = HashSet::new(); for idx in indices { @@ -1097,7 +991,6 @@ fn output_axis_classes( #[allow(clippy::type_complexity)] fn build_internal_ids( tensors: &[&TensorDynLen], - allowed: AllowedPairs<'_>, _diag_uf: &mut AxisUnionFind, retained_indices: &HashSet, ) -> Result<(Vec>, HashMap)> { @@ -1107,54 +1000,42 @@ fn build_internal_ids( let mut assigned: HashMap<(usize, usize), usize> = HashMap::new(); let mut internal_id_to_original: HashMap = HashMap::new(); - // Process contractable pairs - let pairs_to_process: Vec<(usize, usize)> = match allowed { - AllowedPairs::All => { - let mut pairs = Vec::new(); - for ti in 0..tensors.len() { - for tj in (ti + 1)..tensors.len() { - pairs.push((ti, tj)); - } - } - pairs - } - AllowedPairs::Specified(pairs) => pairs.to_vec(), - }; - - for (ti, tj) in pairs_to_process { - for (pi, idx_i) in tensors[ti].indices.iter().enumerate() { - for (pj, idx_j) in tensors[tj].indices.iter().enumerate() { - if idx_i.is_contractable(idx_j) { - let key_i = (ti, pi); - let key_j = 
(tj, pj); - - match (assigned.get(&key_i).copied(), assigned.get(&key_j).copied()) { - (None, None) => { - let internal_id = if let Some(&id) = index_to_internal.get(idx_i) { - id - } else { - let id = next_id; - next_id += 1; + for ti in 0..tensors.len() { + for tj in (ti + 1)..tensors.len() { + for (pi, idx_i) in tensors[ti].indices.iter().enumerate() { + for (pj, idx_j) in tensors[tj].indices.iter().enumerate() { + if idx_i.is_contractable(idx_j) { + let key_i = (ti, pi); + let key_j = (tj, pj); + + match (assigned.get(&key_i).copied(), assigned.get(&key_j).copied()) { + (None, None) => { + let internal_id = if let Some(&id) = index_to_internal.get(idx_i) { + id + } else { + let id = next_id; + next_id += 1; + index_to_internal.insert(idx_i.clone(), id); + internal_id_to_original.insert(id, key_i); + id + }; + assigned.insert(key_i, internal_id); + assigned.insert(key_j, internal_id); + if idx_i != idx_j { + index_to_internal.insert(idx_j.clone(), internal_id); + } + } + (Some(id), None) => { + assigned.insert(key_j, id); + index_to_internal.insert(idx_j.clone(), id); + } + (None, Some(id)) => { + assigned.insert(key_i, id); index_to_internal.insert(idx_i.clone(), id); - internal_id_to_original.insert(id, key_i); - id - }; - assigned.insert(key_i, internal_id); - assigned.insert(key_j, internal_id); - if idx_i != idx_j { - index_to_internal.insert(idx_j.clone(), internal_id); } - } - (Some(id), None) => { - assigned.insert(key_j, id); - index_to_internal.insert(idx_j.clone(), id); - } - (None, Some(id)) => { - assigned.insert(key_i, id); - index_to_internal.insert(idx_i.clone(), id); - } - (Some(_id_i), Some(_id_j)) => { - // Both already assigned + (Some(_id_i), Some(_id_j)) => { + // Both already assigned + } } } } @@ -1217,16 +1098,12 @@ fn has_contractable_indices(a: &TensorDynLen, b: &TensorDynLen) -> bool { /// /// Uses petgraph for O(V+E) connected component detection. 
#[allow(dead_code)] -fn find_tensor_connected_components( - tensors: &[&TensorDynLen], - allowed: AllowedPairs<'_>, -) -> Vec> { - find_tensor_connected_components_with_retained(tensors, allowed, &[]) +fn find_tensor_connected_components(tensors: &[&TensorDynLen]) -> Vec> { + find_tensor_connected_components_with_retained(tensors, &[]) } fn find_tensor_connected_components_with_retained( tensors: &[&TensorDynLen], - allowed: AllowedPairs<'_>, retain_indices: &[DynIndex], ) -> Vec> { let n = tensors.len(); @@ -1241,22 +1118,10 @@ fn find_tensor_connected_components_with_retained( let mut graph = UnGraph::<(), ()>::new_undirected(); let nodes: Vec<_> = (0..n).map(|_| graph.add_node(())).collect(); - // Add edges based on connectivity - match allowed { - AllowedPairs::All => { - for i in 0..n { - for j in (i + 1)..n { - if has_contractable_indices(tensors[i], tensors[j]) { - graph.add_edge(nodes[i], nodes[j], ()); - } - } - } - } - AllowedPairs::Specified(pairs) => { - for &(i, j) in pairs { - if has_contractable_indices(tensors[i], tensors[j]) { - graph.add_edge(nodes[i], nodes[j], ()); - } + for i in 0..n { + for j in (i + 1)..n { + if has_contractable_indices(tensors[i], tensors[j]) { + graph.add_edge(nodes[i], nodes[j], ()); } } } @@ -1310,46 +1175,5 @@ fn shares_retained_index(a: &TensorDynLen, b: &TensorDynLen, retain_indices: &[D }) } -/// Remap AllowedPairs for a subset of tensors. 
-fn remap_allowed_pairs(allowed: AllowedPairs<'_>, component: &[usize]) -> RemappedAllowedPairs { - match allowed { - AllowedPairs::All => RemappedAllowedPairs::All, - AllowedPairs::Specified(pairs) => { - let orig_to_local: HashMap = component - .iter() - .enumerate() - .map(|(local, &orig)| (orig, local)) - .collect(); - - let remapped: Vec<(usize, usize)> = pairs - .iter() - .filter_map( - |&(i, j)| match (orig_to_local.get(&i), orig_to_local.get(&j)) { - (Some(&li), Some(&lj)) => Some((li, lj)), - _ => None, - }, - ) - .collect(); - - RemappedAllowedPairs::Specified(remapped) - } - } -} - -/// Owned version of AllowedPairs for remapped components. -enum RemappedAllowedPairs { - All, - Specified(Vec<(usize, usize)>), -} - -impl RemappedAllowedPairs { - fn as_ref(&self) -> AllowedPairs<'_> { - match self { - RemappedAllowedPairs::All => AllowedPairs::All, - RemappedAllowedPairs::Specified(pairs) => AllowedPairs::Specified(pairs), - } - } -} - #[cfg(test)] mod tests; diff --git a/crates/tensor4all-core/src/defaults/contract/tests/mod.rs b/crates/tensor4all-core/src/defaults/contract/tests/mod.rs index 66d2d11c..a75953cc 100644 --- a/crates/tensor4all-core/src/defaults/contract/tests/mod.rs +++ b/crates/tensor4all-core/src/defaults/contract/tests/mod.rs @@ -1,6 +1,5 @@ use super::*; use crate::defaults::Index; -use crate::tensor_like::TensorLike; use num_complex::Complex64; use std::ffi::OsString; use std::time::Duration; @@ -69,34 +68,74 @@ fn col_major_offset(coords: &[usize], dims: &[usize]) -> usize { } // ======================================================================== -// contract_multi tests +// contract tests // ======================================================================== #[test] -fn test_contract_multi_empty() { +fn test_contract_empty() { let tensors: Vec<&TensorDynLen> = vec![]; - let result = contract_multi(&tensors, AllowedPairs::All); + let result = contract(&tensors); assert!(result.is_err()); } #[test] -fn 
test_contract_multi_single() { +fn test_contract_single() { let tensor = make_test_tensor(&[2, 3], &[1, 2]); - let result = contract_multi(&[&tensor], AllowedPairs::All).unwrap(); + let result = contract(&[&tensor]).unwrap(); assert_eq!(result.dims(), tensor.dims()); } #[test] -fn test_contract_multi_pair() { +fn test_contract_pair_nary_entry() { // A[i,j] * B[j,k] -> C[i,k] let a = make_test_tensor(&[2, 3], &[1, 2]); // i=1, j=2 let b = make_test_tensor(&[3, 4], &[2, 3]); // j=2, k=3 - let result = contract_multi(&[&a, &b], AllowedPairs::All).unwrap(); + let result = contract(&[&a, &b]).unwrap(); assert_eq!(result.dims(), vec![2, 4]); // i, k } #[test] -fn test_contract_multi_diag_diag_partial_preserves_diagonal_storage() { +fn test_contract_default_entry_rejects_disconnected_inputs() { + let i = DynIndex::new_dyn(2); + let j = DynIndex::new_dyn(3); + + let a = TensorDynLen::from_dense(vec![i], vec![1.0_f64, 2.0]).unwrap(); + let b = TensorDynLen::from_dense(vec![j], vec![3.0_f64, 4.0, 5.0]).unwrap(); + + let err = contract(&[&a, &b]).unwrap_err(); + assert!(err.to_string().contains("Disconnected")); +} + +#[test] +fn test_outer_product_is_explicit_disconnected_entry() { + let i = DynIndex::new_dyn(2); + let j = DynIndex::new_dyn(3); + + let a = TensorDynLen::from_dense(vec![i.clone()], vec![1.0_f64, 2.0]).unwrap(); + let b = TensorDynLen::from_dense(vec![j.clone()], vec![3.0_f64, 4.0, 5.0]).unwrap(); + + let result = outer_product(&a, &b).unwrap(); + assert_eq!(result.indices(), &[i, j]); + assert_eq!( + result.to_vec::().unwrap(), + vec![3.0, 6.0, 4.0, 8.0, 5.0, 10.0] + ); +} + +#[test] +fn test_contract_owned_rejects_disconnected_inputs() { + let i = DynIndex::new_dyn(2); + let j = DynIndex::new_dyn(2); + + let a = TensorDynLen::from_dense(vec![i], vec![1.0_f64, 2.0]).unwrap(); + let b = TensorDynLen::from_dense(vec![j], vec![3.0_f64, 4.0]).unwrap(); + + let err = contract_owned(vec![a, b]).unwrap_err(); + assert!(err.to_string().contains("disconnected")); +} + 
+#[test] +fn test_contract_diag_diag_partial_preserves_diagonal_storage() { let i = Index::new(DynId(1), 3); let j = Index::new(DynId(2), 3); let k = Index::new(DynId(3), 3); @@ -104,7 +143,7 @@ fn test_contract_multi_diag_diag_partial_preserves_diagonal_storage() { let a = TensorDynLen::from_diag(vec![i.clone(), j.clone()], vec![1.0_f64, 2.0, 3.0]).unwrap(); let b = TensorDynLen::from_diag(vec![j, k.clone()], vec![4.0_f64, 5.0, 6.0]).unwrap(); - let result = contract_multi(&[&a, &b], AllowedPairs::All).unwrap(); + let result = contract(&[&a, &b]).unwrap(); assert_eq!(result.dims(), vec![3, 3]); assert!(result.is_diag()); @@ -115,36 +154,36 @@ fn test_contract_multi_diag_diag_partial_preserves_diagonal_storage() { } #[test] -fn test_contract_multi_three() { +fn test_contract_three() { // A[i,j] * B[j,k] * C[k,l] -> D[i,l] let a = make_test_tensor(&[2, 3], &[1, 2]); // i=1, j=2 let b = make_test_tensor(&[3, 4], &[2, 3]); // j=2, k=3 let c = make_test_tensor(&[4, 5], &[3, 4]); // k=3, l=4 - let result = contract_multi(&[&a, &b, &c], AllowedPairs::All).unwrap(); + let result = contract(&[&a, &b, &c]).unwrap(); let mut sorted_dims = result.dims(); sorted_dims.sort(); assert_eq!(sorted_dims, vec![2, 5]); // i=2, l=5 } #[test] -fn test_contract_multi_four() { +fn test_contract_four() { // A[i,j] * B[j,k] * C[k,l] * D[l,m] -> E[i,m] let a = make_test_tensor(&[2, 3], &[1, 2]); let b = make_test_tensor(&[3, 4], &[2, 3]); let c = make_test_tensor(&[4, 5], &[3, 4]); let d = make_test_tensor(&[5, 6], &[4, 5]); - let result = contract_multi(&[&a, &b, &c, &d], AllowedPairs::All).unwrap(); + let result = contract(&[&a, &b, &c, &d]).unwrap(); let mut sorted_dims = result.dims(); sorted_dims.sort(); assert_eq!(sorted_dims, vec![2, 6]); // i=2, m=6 } #[test] -fn test_contract_multi_outer_product() { +fn test_outer_product_matrix_matrix() { // A[i,j] * B[k,l] (no common indices) -> outer product C[i,j,k,l] let a = make_test_tensor(&[2, 3], &[1, 2]); let b = make_test_tensor(&[4, 5], 
&[3, 4]); - let result = contract_multi(&[&a, &b], AllowedPairs::All).unwrap(); + let result = outer_product(&a, &b).unwrap(); let result_dims = result.dims(); let total_elements: usize = result_dims.iter().product(); assert_eq!(total_elements, 2 * 3 * 4 * 5); @@ -152,11 +191,11 @@ fn test_contract_multi_outer_product() { } #[test] -fn test_contract_multi_vector_outer_product() { +fn test_outer_product_vector_vector() { // A[i] * B[j] (no common indices) -> outer product C[i,j] let a = make_test_tensor(&[2], &[1]); // i=1 let b = make_test_tensor(&[3], &[2]); // j=2 - let result = contract_multi(&[&a, &b], AllowedPairs::All).unwrap(); + let result = outer_product(&a, &b).unwrap(); let result_dims = result.dims(); let total_elements: usize = result_dims.iter().product(); assert_eq!(total_elements, 2 * 3); @@ -164,10 +203,10 @@ fn test_contract_multi_vector_outer_product() { } #[test] -fn test_contract_connected_disconnected_error() { +fn test_contract_disconnected_error() { let a = make_test_tensor(&[2, 3], &[1, 2]); let b = make_test_tensor(&[4, 5], &[3, 4]); - let result = contract_connected(&[&a, &b], AllowedPairs::All); + let result = contract(&[&a, &b]); assert!(result.is_err()); assert!(result .unwrap_err() @@ -177,21 +216,7 @@ fn test_contract_connected_disconnected_error() { } #[test] -fn test_contract_connected_specified_no_contractable_error() { - let a = make_test_tensor(&[2, 3], &[1, 2]); // i=1, j=2 - let b = make_test_tensor(&[4, 5], &[3, 4]); // k=3, l=4 (no common with a) - let result = contract_connected(&[&a, &b], AllowedPairs::Specified(&[(0, 1)])); - assert!(result.is_err()); - let err_msg = result.unwrap_err().to_string().to_lowercase(); - assert!( - err_msg.contains("disconnected") || err_msg.contains("no contractable"), - "Expected error about disconnected or no contractable indices, got: {}", - err_msg - ); -} - -#[test] -fn test_contract_multi_with_options_retains_shared_batch_index() { +fn 
test_contract_with_options_retains_shared_batch_index() { let batch = Index::new(DynId(10), 2); let i = Index::new(DynId(11), 2); let k = Index::new(DynId(12), 3); @@ -209,8 +234,8 @@ fn test_contract_multi_with_options_retains_shared_batch_index() { ); let retain_indices = [batch.clone()]; - let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain_indices); - let result = contract_multi_with_options(&[&a, &b], options).unwrap(); + let options = ContractionOptions::new().with_retain_indices(&retain_indices); + let result = contract_with_options(&[&a, &b], options).unwrap(); assert_eq!(result.indices(), &[batch.clone(), i.clone(), j.clone()]); assert_eq!(result.dims(), vec![2, 2, 2]); @@ -255,13 +280,12 @@ fn test_tensor_contract_with_options_retains_shared_index() { vec![1.0; 12], ); - let result = a - .contract_with_options( - &b, - ContractionOptions::new(AllowedPairs::All) - .with_retain_indices(std::slice::from_ref(&batch)), - ) - .unwrap(); + let result = contract_pair_with_options( + &a, + &b, + ContractionOptions::new().with_retain_indices(std::slice::from_ref(&batch)), + ) + .unwrap(); assert_eq!(result.indices(), &[batch, i, j]); assert_eq!(result.dims(), vec![2, 2, 2]); @@ -286,8 +310,8 @@ fn test_contract_retains_exact_same_id_prime_index() { ); let retain_indices = [batch_prime.clone()]; - let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain_indices); - let result = contract_multi_with_options(&[&a, &b], options).unwrap(); + let options = ContractionOptions::new().with_retain_indices(&retain_indices); + let result = contract_with_options(&[&a, &b], options).unwrap(); assert_eq!(result.indices(), &[batch_prime]); assert_eq!(result.dims(), vec![2]); @@ -295,7 +319,7 @@ fn test_contract_retains_exact_same_id_prime_index() { } #[test] -fn test_contract_multi_with_options_supports_three_way_retained_label() { +fn test_contract_with_options_supports_three_way_retained_label() { let batch = 
Index::new(DynId(20), 2); let i = Index::new(DynId(21), 2); let j = Index::new(DynId(22), 3); @@ -318,8 +342,8 @@ fn test_contract_multi_with_options_supports_three_way_retained_label() { ); let retain_indices = [batch.clone()]; - let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain_indices); - let result = contract_multi_with_options(&[&a, &b, &c], options).unwrap(); + let options = ContractionOptions::new().with_retain_indices(&retain_indices); + let result = contract_with_options(&[&a, &b, &c], options).unwrap(); assert_eq!( result.indices(), @@ -351,7 +375,7 @@ fn test_contract_multi_with_options_supports_three_way_retained_label() { } #[test] -fn test_contract_multi_with_options_errors_for_missing_retained_index() { +fn test_contract_with_options_errors_for_missing_retained_index() { let i = Index::new(DynId(30), 2); let j = Index::new(DynId(31), 3); let missing = Index::new(DynId(32), 2); @@ -360,14 +384,14 @@ fn test_contract_multi_with_options_errors_for_missing_retained_index() { let b = make_test_tensor_from_data(&[3], vec![j.clone()], vec![3.0, 4.0, 5.0]); let retain_indices = [missing]; - let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain_indices); - let result = contract_multi_with_options(&[&a, &b], options); + let options = ContractionOptions::new().with_retain_indices(&retain_indices); + let result = contract_with_options(&[&a, &b], options); assert!(result.is_err()); } #[test] -fn test_contract_multi_with_options_retained_index_connects_components() { +fn test_contract_with_options_retained_index_connects_components() { let batch = Index::new(DynId(40), 2); let i = Index::new(DynId(41), 2); let j = Index::new(DynId(42), 3); @@ -384,9 +408,8 @@ fn test_contract_multi_with_options_retained_index_connects_components() { ); let retain_indices = [batch.clone()]; - let options = - ContractionOptions::new(AllowedPairs::Specified(&[])).with_retain_indices(&retain_indices); - let result = 
contract_multi_with_options(&[&a, &b], options).unwrap(); + let options = ContractionOptions::new().with_retain_indices(&retain_indices); + let result = contract_with_options(&[&a, &b], options).unwrap(); assert_eq!(result.indices(), &[batch.clone(), i.clone(), j.clone()]); assert_eq!(result.dims(), vec![2, 2, 3]); @@ -410,32 +433,7 @@ fn test_contract_multi_with_options_retained_index_connects_components() { } #[test] -fn test_contract_multi_with_options_does_not_contract_unretained_shared_index_between_retained_components( -) { - let batch = Index::new(DynId(50), 2); - let i = Index::new(DynId(51), 2); - - let a = make_test_tensor_from_data( - &[2, 2], - vec![batch.clone(), i.clone()], - vec![1.0, 2.0, 3.0, 4.0], - ); - let b = make_test_tensor_from_data( - &[2, 2], - vec![batch.clone(), i.clone()], - vec![5.0, 6.0, 7.0, 8.0], - ); - - let retain_indices = [batch.clone()]; - let options = - ContractionOptions::new(AllowedPairs::Specified(&[])).with_retain_indices(&retain_indices); - let result = contract_multi_with_options(&[&a, &b], options); - - assert!(result.is_err()); -} - -#[test] -fn test_contract_multi_owned_matches_borrowed_with_options() { +fn test_contract_owned_with_options_matches_borrowed_with_options() { let batch = Index::new(DynId(70), 2); let i = Index::new(DynId(71), 2); let k = Index::new(DynId(72), 3); @@ -453,9 +451,9 @@ fn test_contract_multi_owned_matches_borrowed_with_options() { ); let retain_indices = [batch.clone()]; - let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain_indices); - let owned = contract_multi_owned(vec![a.clone(), b.clone()], options).unwrap(); - let borrowed = contract_multi_with_options(&[&a, &b], options).unwrap(); + let options = ContractionOptions::new().with_retain_indices(&retain_indices); + let owned = contract_owned_with_options(vec![a.clone(), b.clone()], options).unwrap(); + let borrowed = contract_with_options(&[&a, &b], options).unwrap(); assert_eq!(owned.indices(), 
borrowed.indices()); assert_eq!(owned.dims(), borrowed.dims()); @@ -466,16 +464,16 @@ fn test_contract_multi_owned_matches_borrowed_with_options() { } #[test] -fn test_contract_multi_owned_single_matches_borrowed_with_specified_pairs() { +fn test_contract_owned_with_options_single_matches_borrowed() { let a = make_test_tensor_from_data( &[2, 3], vec![Index::new(DynId(74), 2), Index::new(DynId(75), 3)], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], ); - let options = ContractionOptions::new(AllowedPairs::Specified(&[(0, 1)])); + let options = ContractionOptions::new(); - let owned = contract_multi_owned(vec![a.clone()], options).unwrap(); - let borrowed = contract_multi_with_options(&[&a], options).unwrap(); + let owned = contract_owned_with_options(vec![a.clone()], options).unwrap(); + let borrowed = contract_with_options(&[&a], options).unwrap(); assert_eq!(owned.indices(), borrowed.indices()); assert_eq!( @@ -485,17 +483,17 @@ fn test_contract_multi_owned_single_matches_borrowed_with_specified_pairs() { } #[test] -fn test_contract_multi_owned_falls_back_for_structured_storage() { +fn test_contract_owned_with_options_falls_back_for_structured_storage() { let i = Index::new(DynId(76), 3); let j = Index::new(DynId(77), 3); let k = Index::new(DynId(78), 3); let a = TensorDynLen::from_diag(vec![i.clone(), j.clone()], vec![1.0_f64, 2.0, 3.0]).unwrap(); let b = TensorDynLen::from_diag(vec![j, k.clone()], vec![4.0_f64, 5.0, 6.0]).unwrap(); - let options = ContractionOptions::new(AllowedPairs::All); + let options = ContractionOptions::new(); - let owned = contract_multi_owned(vec![a.clone(), b.clone()], options).unwrap(); - let borrowed = contract_multi_with_options(&[&a, &b], options).unwrap(); + let owned = contract_owned_with_options(vec![a.clone(), b.clone()], options).unwrap(); + let borrowed = contract_with_options(&[&a, &b], options).unwrap(); assert_eq!(owned.indices(), borrowed.indices()); assert_eq!(owned.storage().storage_kind(), StorageKind::Diagonal); @@ -505,7 +503,7 
@@ fn test_contract_multi_owned_falls_back_for_structured_storage() { } #[test] -fn test_contract_multi_owned_supports_three_way_retained_label() { +fn test_contract_owned_with_options_supports_three_way_retained_label() { let batch = Index::new(DynId(80), 2); let i = Index::new(DynId(81), 2); let j = Index::new(DynId(82), 3); @@ -528,9 +526,10 @@ fn test_contract_multi_owned_supports_three_way_retained_label() { ); let retain_indices = [batch.clone()]; - let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain_indices); - let owned = contract_multi_owned(vec![a.clone(), b.clone(), c.clone()], options).unwrap(); - let borrowed = contract_multi_with_options(&[&a, &b, &c], options).unwrap(); + let options = ContractionOptions::new().with_retain_indices(&retain_indices); + let owned = + contract_owned_with_options(vec![a.clone(), b.clone(), c.clone()], options).unwrap(); + let borrowed = contract_with_options(&[&a, &b, &c], options).unwrap(); assert_eq!( owned.indices(), @@ -545,7 +544,7 @@ fn test_contract_multi_owned_supports_three_way_retained_label() { } #[test] -fn test_contract_multi_owned_falls_back_to_borrowed_for_grad_tensors() { +fn test_contract_owned_with_options_falls_back_to_borrowed_for_grad_tensors() { let batch = Index::new(DynId(90), 2); let i = Index::new(DynId(91), 2); let k = Index::new(DynId(92), 3); @@ -565,11 +564,11 @@ fn test_contract_multi_owned_falls_back_to_borrowed_for_grad_tensors() { ); let retain_indices = [batch.clone()]; - let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain_indices); - let owned = contract_multi_owned(vec![x.clone(), y], options).unwrap(); + let options = ContractionOptions::new().with_retain_indices(&retain_indices); + let owned = contract_owned_with_options(vec![x.clone(), y], options).unwrap(); let ones = TensorDynLen::from_dense(owned.indices().to_vec(), vec![1.0; 8]).unwrap(); - let loss = contract_multi(&[&owned, &ones], AllowedPairs::All).unwrap(); + 
let loss = contract(&[&owned, &ones]).unwrap(); loss.backward().unwrap(); let grad = x.grad().unwrap().unwrap(); @@ -580,13 +579,10 @@ fn test_contract_multi_owned_falls_back_to_borrowed_for_grad_tensors() { #[test] fn test_find_tensor_connected_components_trivial_cases() { let empty: Vec<&TensorDynLen> = Vec::new(); - assert!(find_tensor_connected_components(&empty, AllowedPairs::All).is_empty()); + assert!(find_tensor_connected_components(&empty).is_empty()); let a = make_test_tensor(&[2, 3], &[1, 2]); - assert_eq!( - find_tensor_connected_components(&[&a], AllowedPairs::All), - vec![vec![0]] - ); + assert_eq!(find_tensor_connected_components(&[&a]), vec![vec![0]]); } #[test] @@ -596,20 +592,11 @@ fn test_find_tensor_connected_components_multiple_components() { let c = make_test_tensor(&[5, 6], &[4, 5]); assert_eq!( - find_tensor_connected_components(&[&a, &b, &c], AllowedPairs::All), + find_tensor_connected_components(&[&a, &b, &c]), vec![vec![0, 1], vec![2]] ); } -#[test] -fn test_remap_allowed_pairs_filters_pairs_outside_component() { - let remapped = remap_allowed_pairs(AllowedPairs::Specified(&[(0, 1), (1, 3), (2, 3)]), &[1, 3]); - match remapped { - RemappedAllowedPairs::All => panic!("expected specified remapped pairs"), - RemappedAllowedPairs::Specified(pairs) => assert_eq!(pairs, vec![(0, 1)]), - } -} - #[test] fn test_union_find_remap_and_tensor_id_helpers() { let i = DynId(1); @@ -662,45 +649,44 @@ fn test_contract_profile_helpers() { } // ======================================================================== -// AllowedPairs::Specified tests +// Connected contraction semantics // ======================================================================== #[test] -fn test_contract_specified_pairs() { +fn test_contract_all_pairs() { // A[i,j], B[j,k], C[i,l] - tensors 0, 1, 2 let a = make_test_tensor(&[2, 3], &[1, 2]); // i=1, j=2 let b = make_test_tensor(&[3, 4], &[2, 3]); // j=2, k=3 let c = make_test_tensor(&[2, 5], &[1, 4]); // i=1, l=4 - let result 
= contract_multi(&[&a, &b, &c], AllowedPairs::Specified(&[(0, 1), (0, 2)])).unwrap(); + let result = contract(&[&a, &b, &c]).unwrap(); let mut sorted_dims = result.dims(); sorted_dims.sort(); assert_eq!(sorted_dims, vec![4, 5]); // k=4, l=5 } #[test] -fn test_contract_specified_no_contractable_indices_error() { +fn test_contract_disconnected_component_error() { let a = make_test_tensor(&[2, 3], &[1, 2]); // i=1, j=2 let b = make_test_tensor(&[3, 4], &[2, 3]); // j=2, k=3 let c = make_test_tensor(&[6, 5], &[5, 4]); // m=5, l=4 (no common with B) - let result = contract_multi(&[&a, &b, &c], AllowedPairs::Specified(&[(0, 1), (1, 2)])); + let result = contract(&[&a, &b, &c]); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("no contractable indices")); + let err_msg = result.unwrap_err().to_string().to_lowercase(); + assert!( + err_msg.contains("disconnected") || err_msg.contains("no contractable indices"), + "Expected error about disconnected tensors or missing contractable indices, got: {err_msg}" + ); } #[test] -fn test_contract_specified_disconnected_outer_product() { +fn test_contract_components_then_explicit_outer_product() { let a = make_test_tensor(&[2, 3], &[1, 2]); // i=1, j=2 let b = make_test_tensor(&[3, 4], &[2, 3]); // j=2, k=3 let c = make_test_tensor(&[4, 5], &[4, 5]); // m=4, n=5 let d = make_test_tensor(&[5, 6], &[5, 6]); // n=5, p=6 - let result = contract_multi( - &[&a, &b, &c, &d], - AllowedPairs::Specified(&[(0, 1), (2, 3)]), - ) - .unwrap(); + let left = contract(&[&a, &b]).unwrap(); + let right = contract(&[&c, &d]).unwrap(); + let result = outer_product(&left, &right).unwrap(); assert_eq!(result.dims().len(), 4); let mut sorted_dims = result.dims(); sorted_dims.sort(); @@ -1007,7 +993,7 @@ fn test_remap_preserves_order() { } // ======================================================================== -// contract_connected tests +// dense contract tests // 
======================================================================== fn make_dense_tensor(shape: &[usize], ids: &[u64]) -> TensorDynLen { @@ -1024,35 +1010,35 @@ fn make_dense_tensor(shape: &[usize], ids: &[u64]) -> TensorDynLen { } #[test] -fn test_contract_connected_empty() { +fn test_contract_empty_dense_entry() { let tensors: Vec<&TensorDynLen> = vec![]; - let result = contract_connected(&tensors, AllowedPairs::All); + let result = contract(&tensors); assert!(result.is_err()); } #[test] -fn test_contract_connected_single() { +fn test_contract_single_dense_entry() { let tensor = make_dense_tensor(&[2, 3], &[1, 2]); - let result = contract_connected(&[&tensor], AllowedPairs::All).unwrap(); + let result = contract(&[&tensor]).unwrap(); assert_eq!(result.dims(), tensor.dims()); } #[test] -fn test_contract_connected_pair_dense() { +fn test_contract_pair_dense() { // A[i,j] * B[j,k] -> C[i,k] let a = make_dense_tensor(&[2, 3], &[1, 2]); // i=1, j=2 let b = make_dense_tensor(&[3, 4], &[2, 3]); // j=2, k=3 - let result = contract_connected(&[&a, &b], AllowedPairs::All).unwrap(); + let result = contract(&[&a, &b]).unwrap(); assert_eq!(result.dims(), vec![2, 4]); // i, k } #[test] -fn test_contract_connected_three_dense() { +fn test_contract_three_dense() { // A[i,j] * B[j,k] * C[k,l] -> D[i,l] let a = make_dense_tensor(&[2, 3], &[1, 2]); // i=1, j=2 let b = make_dense_tensor(&[3, 4], &[2, 3]); // j=2, k=3 let c = make_dense_tensor(&[4, 5], &[3, 4]); // k=3, l=4 - let result = contract_connected(&[&a, &b, &c], AllowedPairs::All).unwrap(); + let result = contract(&[&a, &b, &c]).unwrap(); let mut sorted_dims = result.dims(); sorted_dims.sort(); assert_eq!(sorted_dims, vec![2, 5]); // i=2, l=5 diff --git a/crates/tensor4all-core/src/defaults/factorize.rs b/crates/tensor4all-core/src/defaults/factorize.rs index 7a253df9..4589ff37 100644 --- a/crates/tensor4all-core/src/defaults/factorize.rs +++ b/crates/tensor4all-core/src/defaults/factorize.rs @@ -34,7 +34,7 @@ use 
crate::defaults::tensordynlen::unfold_split_inner; use crate::defaults::DynIndex; -use crate::{unfold_split, TensorDynLen}; +use crate::{contract_pair, unfold_split, TensorDynLen}; use num_complex::{Complex64, ComplexFloat}; use tensor4all_tcicore::{rrlu, AbstractMatrixCI, MatrixLUCI, RrLUOptions, Scalar as MatrixScalar}; use tensor4all_tensorbackend::{Matrix, TensorElement}; @@ -123,7 +123,7 @@ pub fn factorize( /// /// ``` /// use tensor4all_core::{ -/// factorize_full_rank, Canonical, DynIndex, FactorizeAlg, TensorDynLen, TensorLike, +/// factorize_full_rank, Canonical, DynIndex, FactorizeAlg, TensorContractionLike, TensorDynLen, /// }; /// /// let i = DynIndex::new_dyn(2); @@ -139,7 +139,7 @@ pub fn factorize( /// FactorizeAlg::QR, /// Canonical::Left, /// )?; -/// let reconstructed = result.left.contract(&result.right).unwrap(); +/// let reconstructed = result.left.contract_pair(&result.right).unwrap(); /// assert!(tensor.sub(&reconstructed)?.maxabs() < 1.0e-18); /// # Ok::<(), Box>(()) /// ``` @@ -264,7 +264,7 @@ fn factorize_svd_with_options( match canonical { Canonical::Left => { // L = U (orthogonal), R = S * V^H - let right_contracted = s.contract(&vh)?; + let right_contracted = contract_pair(&s, &vh)?; let right = right_contracted.replaceind(&sim_bond_index, &bond_index)?; Ok(FactorizeResult { left: u, @@ -276,7 +276,7 @@ fn factorize_svd_with_options( } Canonical::Right => { // L = U * S, R = V^H - let left_contracted = u.contract(&s)?; + let left_contracted = contract_pair(&u, &s)?; let left = left_contracted.replaceind(&sim_bond_index, &bond_index)?; Ok(FactorizeResult { left, diff --git a/crates/tensor4all-core/src/defaults/mod.rs b/crates/tensor4all-core/src/defaults/mod.rs index 9a096a3e..5a7cb6aa 100644 --- a/crates/tensor4all-core/src/defaults/mod.rs +++ b/crates/tensor4all-core/src/defaults/mod.rs @@ -32,14 +32,16 @@ pub mod qr; pub mod svd; pub use contract::{ - build_diag_union, collect_sizes, contract_connected, 
contract_connected_with_options, - contract_multi, contract_multi_owned, contract_multi_with_options, + build_diag_union, collect_sizes, contract, contract_owned, contract_owned_with_options, + contract_pair, contract_pair_with_options, contract_with_options, outer_product, print_and_reset_contract_profile, remap_output_ids, remap_tensor_ids, reset_contract_profile, - AxisUnionFind, ContractionOptions, + tensordot, AxisUnionFind, ContractionOptions, }; pub use index::{DefaultIndex, DefaultTagSet, DynId, DynIndex, Index, TagSet}; pub use tensordynlen::{ - compute_permutation_from_indices, diag_tensor_dyn_len, unfold_split, TensorDynLen, + compute_permutation_from_indices, diag_tensor_dyn_len, + print_and_reset_pairwise_contract_profile, reset_pairwise_contract_profile, unfold_split, + TensorDynLen, }; // Re-export linear algebra functions and types diff --git a/crates/tensor4all-core/src/defaults/qr.rs b/crates/tensor4all-core/src/defaults/qr.rs index bccf3a48..c6c4394b 100644 --- a/crates/tensor4all-core/src/defaults/qr.rs +++ b/crates/tensor4all-core/src/defaults/qr.rs @@ -30,7 +30,7 @@ pub enum QrError { /// /// ``` /// use tensor4all_core::qr::{QrOptions, qr_with}; -/// use tensor4all_core::{DynIndex, TensorDynLen}; +/// use tensor4all_core::{DynIndex, TensorContractionLike, TensorDynLen}; /// /// let i = DynIndex::new_dyn(3); /// let j = DynIndex::new_dyn(3); @@ -41,7 +41,7 @@ pub enum QrError { /// let (q, r) = qr_with::(&tensor, &[i], &opts).unwrap(); /// /// // Q * R recovers the original tensor -/// let recovered = q.contract(&r).unwrap(); +/// let recovered = q.contract_pair(&r).unwrap(); /// assert!(tensor.distance(&recovered).unwrap() < 1e-12); /// ``` #[derive(Debug, Clone, Copy)] diff --git a/crates/tensor4all-core/src/defaults/structured_contraction.rs b/crates/tensor4all-core/src/defaults/structured_contraction.rs index b802d44f..edc6cffc 100644 --- a/crates/tensor4all-core/src/defaults/structured_contraction.rs +++ 
b/crates/tensor4all-core/src/defaults/structured_contraction.rs @@ -3,7 +3,7 @@ use std::collections::{HashMap, HashSet}; use anyhow::{anyhow, ensure, Result}; use num_complex::Complex64; use tenferro::{DType, Tensor as NativeTensor}; -use tensor4all_tensorbackend::{einsum_native_tensors, Storage}; +use tensor4all_tensorbackend::{einsum_native_tensors, NativeTensorReadInput, Storage}; #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct OperandLayout { @@ -262,6 +262,7 @@ fn unique_first_appearance(values: &[usize]) -> Vec { result } +#[allow(dead_code)] pub(crate) fn normalize_payload_for_roots( payload: &NativeTensor, roots: &[usize], @@ -295,6 +296,39 @@ pub(crate) fn normalize_payload_for_roots( Ok((current_payload, current_roots)) } +pub(crate) fn normalize_payload_read_for_roots<'a>( + payload: NativeTensorReadInput<'a>, + roots: &[usize], +) -> Result<(NativeTensorReadInput<'a>, Vec<usize>)> { + ensure!( + payload.shape().len() == roots.len(), + "payload rank {} does not match root label count {}", + payload.shape().len(), + roots.len() + ); + + if unique_first_appearance(roots).len() == roots.len() { + return Ok((payload, roots.to_vec())); + } + + let mut current_payload = payload.as_read().to_tensor(); + let mut current_roots = roots.to_vec(); + while let Some((axis_a, axis_b)) = first_duplicate_pair(&current_roots) { + let mut input_ids: Vec<usize> = (0..current_roots.len()).collect(); + input_ids[axis_b] = input_ids[axis_a]; + let output_ids: Vec<usize> = input_ids + .iter() + .enumerate() + .filter_map(|(axis, &label)| (axis != axis_b).then_some(label)) + .collect(); + + current_payload = einsum_native_tensors(&[(&current_payload, &input_ids)], &output_ids)?; + current_roots.remove(axis_b); + } + + Ok((NativeTensorReadInput::Owned(current_payload), current_roots)) +} + pub(crate) fn storage_payload_native(storage: &Storage) -> Result { if storage.is_f64() { Ok(NativeTensor::from_vec( diff --git a/crates/tensor4all-core/src/defaults/tensordynlen.rs 
b/crates/tensor4all-core/src/defaults/tensordynlen.rs index c83c6fc6..12bec45e 100644 --- a/crates/tensor4all-core/src/defaults/tensordynlen.rs +++ b/crates/tensor4all-core/src/defaults/tensordynlen.rs @@ -8,24 +8,130 @@ use num_complex::Complex64; use num_traits::Zero; use rand::Rng; use rand_distr::{Distribution, StandardNormal}; -use std::collections::HashSet; +use std::cell::RefCell; +use std::cmp::Reverse; +use std::collections::{HashMap, HashSet}; +use std::env; use std::sync::{Arc, OnceLock}; -use tenferro::eager_einsum::eager_einsum_ad; -use tenferro::{CpuBackend, DType, EagerTensor, Tensor as NativeTensor}; +use std::time::{Duration, Instant}; +use tenferro::eager_tensor::einsum_subscripts as eager_einsum_ad; +use tenferro::{ + CpuBackend, DType, DotGeneralConfig, EagerTensor, EinsumSubscripts, Tensor as NativeTensor, +}; use tensor4all_tensorbackend::{ axpby_native_tensor, contract_native_tensor, default_eager_ctx, dense_native_tensor_from_col_major, diag_native_tensor_from_col_major, native_tensor_primal_to_dense_col_major, native_tensor_primal_to_diag_c64, native_tensor_primal_to_diag_f64, native_tensor_primal_to_storage, scale_native_tensor, - storage_to_native_tensor, StorageScalar, TensorElement, + storage_payload_native_read_input, storage_to_native_tensor, AnyScalar as BackendScalar, + StorageScalar, TensorElement, }; use tensor4all_tensorbackend::{Storage, StorageKind}; +use super::contract::PairwiseContractionOptions; use super::structured_contraction::{ - normalize_payload_for_roots, storage_from_payload_native, storage_payload_native, + normalize_payload_read_for_roots, storage_from_payload_native, storage_payload_native, OperandLayout, StructuredContractionPlan, StructuredContractionSpec, }; +#[derive(Debug, Default, Clone)] +struct PairwiseContractProfileEntry { + calls: usize, + total_time: Duration, + total_bytes: usize, +} + +thread_local! 
{ + static PAIRWISE_CONTRACT_PROFILE_STATE: RefCell> = + RefCell::new(HashMap::new()); +} + +fn pairwise_contract_profile_enabled() -> bool { + static ENABLED: OnceLock = OnceLock::new(); + *ENABLED.get_or_init(|| env::var("T4A_PROFILE_PAIRWISE_CONTRACT").is_ok()) +} + +fn record_pairwise_contract_profile(section: &'static str, elapsed: Duration) { + if !pairwise_contract_profile_enabled() { + return; + } + PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| { + let mut state = state.borrow_mut(); + let entry = state.entry(section).or_default(); + entry.calls += 1; + entry.total_time += elapsed; + }); +} + +fn record_pairwise_contract_profile_bytes(section: &'static str, bytes: usize) { + if !pairwise_contract_profile_enabled() { + return; + } + PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| { + let mut state = state.borrow_mut(); + let entry = state.entry(section).or_default(); + entry.total_bytes += bytes; + }); +} + +fn profile_pairwise_contract_section(section: &'static str, f: impl FnOnce() -> T) -> T { + if !pairwise_contract_profile_enabled() { + return f(); + } + let started = Instant::now(); + let result = f(); + record_pairwise_contract_profile(section, started.elapsed()); + result +} + +/// Reset the aggregated pairwise `TensorDynLen` contraction profile. +pub fn reset_pairwise_contract_profile() { + PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| state.borrow_mut().clear()); +} + +/// Print and clear the aggregated pairwise `TensorDynLen` contraction profile. 
+pub fn print_and_reset_pairwise_contract_profile() { + if !pairwise_contract_profile_enabled() { + return; + } + PAIRWISE_CONTRACT_PROFILE_STATE.with(|state| { + let mut entries: Vec<_> = state + .borrow() + .iter() + .map(|(section, entry)| (*section, entry.clone())) + .collect(); + state.borrow_mut().clear(); + entries.sort_by_key(|(_, entry)| Reverse(entry.total_time)); + + eprintln!("=== TensorDynLen pairwise contract profile ==="); + for (section, entry) in entries { + let per_call_us = if entry.calls == 0 { + 0.0 + } else { + entry.total_time.as_secs_f64() * 1.0e6 / entry.calls as f64 + }; + eprintln!( + "{section}: calls={} total={:.6}ms per_call={:.3}us bytes={}", + entry.calls, + entry.total_time.as_secs_f64() * 1.0e3, + per_call_us, + entry.total_bytes, + ); + } + }); +} + +fn native_tensor_profile_bytes(native: &NativeTensor) -> usize { + let element_size = match native.dtype() { + DType::F32 => 4, + DType::F64 => 8, + DType::C32 => 8, + DType::C64 => 16, + DType::I64 => 8, + }; + native.shape().iter().product::() * element_size +} + /// Trait for scalar types that can generate random values from a standard /// normal distribution. /// @@ -115,6 +221,148 @@ pub(crate) struct StructuredAdValue { axis_classes: Vec, } +#[derive(Clone)] +pub(crate) enum TensorDynLenStorage { + Materialized(Arc), + Eager { + inner: Arc>, + axis_classes: Vec, + }, +} + +impl TensorDynLenStorage { + fn from_storage(storage: Arc) -> Self { + Self::Materialized(storage) + } + + fn from_eager_dense(inner: EagerTensor, rank: usize) -> Self { + Self::Eager { + inner: Arc::new(inner), + axis_classes: TensorDynLen::dense_axis_classes(rank), + } + } + + fn eager(&self) -> Option<&EagerTensor> { + match self { + Self::Materialized(_) => None, + Self::Eager { inner, .. } => Some(inner.as_ref()), + } + } + + fn axis_classes(&self) -> &[usize] { + match self { + Self::Materialized(storage) => storage.axis_classes(), + Self::Eager { axis_classes, .. 
} => axis_classes, + } + } + + fn payload_dims(&self) -> &[usize] { + match self { + Self::Materialized(storage) => storage.payload_dims(), + Self::Eager { inner, .. } => inner.data().shape(), + } + } + + fn payload_strides_vec(&self) -> Vec { + match self { + Self::Materialized(storage) => storage.payload_strides().to_vec(), + Self::Eager { inner, .. } => { + let mut stride = 1isize; + inner + .data() + .shape() + .iter() + .map(|&dim| { + let current = stride; + stride *= isize::try_from(dim).unwrap_or(isize::MAX); + current + }) + .collect() + } + } + } + + fn is_f64(&self) -> bool { + match self { + Self::Materialized(storage) => storage.is_f64(), + Self::Eager { inner, .. } => inner.data().dtype() == DType::F64, + } + } + + fn is_c64(&self) -> bool { + match self { + Self::Materialized(storage) => storage.is_c64(), + Self::Eager { inner, .. } => inner.data().dtype() == DType::C64, + } + } + + fn is_complex(&self) -> bool { + match self { + Self::Materialized(storage) => storage.is_complex(), + Self::Eager { inner, .. } => matches!(inner.data().dtype(), DType::C32 | DType::C64), + } + } + + fn is_diag(&self) -> bool { + match self { + Self::Materialized(storage) => storage.is_diag(), + Self::Eager { axis_classes, .. } => TensorDynLen::is_diag_axis_classes(axis_classes), + } + } + + fn storage_kind(&self) -> StorageKind { + match self { + Self::Materialized(storage) => storage.storage_kind(), + Self::Eager { axis_classes, .. 
} => { + if axis_classes.iter().copied().eq(0..axis_classes.len()) { + StorageKind::Dense + } else if TensorDynLen::is_diag_axis_classes(axis_classes) { + StorageKind::Diagonal + } else { + StorageKind::Structured + } + } + } + } + + fn materialize(&self, logical_rank: usize) -> Result> { + match self { + Self::Materialized(storage) => Ok(Arc::clone(storage)), + Self::Eager { + inner, + axis_classes, + } => Ok(Arc::new( + TensorDynLen::storage_from_native_with_axis_classes( + inner.data(), + axis_classes, + logical_rank, + )?, + )), + } + } + + fn scale(&self, scalar: &BackendScalar) -> Result { + Ok(self.materialize(self.axis_classes().len())?.scale(scalar)) + } + + fn conj(&self) -> Result { + match self { + Self::Materialized(storage) => Ok(Self::Materialized(Arc::new(storage.conj()))), + Self::Eager { + inner, + axis_classes, + } => Ok(Self::Eager { + inner: Arc::new(inner.conj()?), + axis_classes: axis_classes.clone(), + }), + } + } + + fn max_abs(&self) -> Result { + Ok(self.materialize(self.axis_classes().len())?.max_abs()) + } +} + /// Dynamic-rank tensor with structured payload storage -- the central data type /// of tensor4all. /// @@ -133,7 +381,7 @@ pub(crate) struct StructuredAdValue { /// | Extract data | [`to_vec`](Self::to_vec), [`into_dense_col_major_parts`](Self::into_dense_col_major_parts), [`sum`](Self::sum), [`only`](Self::only) | /// | Contraction | [`contract`](Self::contract) | /// | Arithmetic | [`add`](Self::add), [`scale`](Self::scale), [`axpby`](Self::axpby) | -/// | Factorization | via [`TensorLike::factorize`] | +/// | Factorization | via [`TensorFactorizationLike::factorize`](crate::TensorFactorizationLike::factorize) | /// | Norms | [`norm`](Self::norm), [`norm_squared`](Self::norm_squared), [`maxabs`](Self::maxabs) | /// | Index ops | [`replaceind`](Self::replaceind), [`permute_indices`](Self::permute_indices) | /// @@ -171,7 +419,7 @@ pub struct TensorDynLen { /// Full index information (includes tags and other metadata). 
pub indices: Vec, /// Authoritative compact payload storage. - pub(crate) storage: Arc, + pub(crate) storage: TensorDynLenStorage, /// Optional tracked compact payload used to preserve structured AD layouts. pub(crate) structured_ad: Option>, /// Lazily materialized eager payload for native execution and AD. @@ -179,8 +427,6 @@ pub struct TensorDynLen { } impl TensorDynLen { - const EINSUM_LABELS: &'static [u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; - fn dense_axis_classes(rank: usize) -> Vec { (0..rank).collect() } @@ -231,15 +477,30 @@ impl TensorDynLen { axis_classes.len() >= 2 && axis_classes.iter().all(|&class_id| class_id == 0) } - fn einsum_labels(ids: &[usize]) -> Result { - let mut out = String::with_capacity(ids.len()); - for &id in ids { - let label = Self::EINSUM_LABELS.get(id).ok_or_else(|| { - anyhow::anyhow!("einsum label {id} exceeds supported label range") - })?; - out.push(char::from(*label)); - } - Ok(out) + fn einsum_subscripts_from_usize_ids( + inputs: &[Vec], + output: &[usize], + ) -> Result { + let input_labels = inputs + .iter() + .map(|ids| { + ids.iter() + .map(|&id| { + u32::try_from(id) + .map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range")) + }) + .collect::>>() + }) + .collect::>>()?; + let output_labels = output + .iter() + .map(|&id| { + u32::try_from(id) + .map_err(|_| anyhow::anyhow!("einsum label {id} exceeds u32 range")) + }) + .collect::>>()?; + let input_refs = input_labels.iter().map(Vec::as_slice).collect::>(); + Ok(EinsumSubscripts::new(&input_refs, &output_labels)) } fn build_binary_einsum_subscripts( @@ -247,7 +508,7 @@ impl TensorDynLen { axes_a: &[usize], rhs_rank: usize, axes_b: &[usize], - ) -> Result { + ) -> Result { anyhow::ensure!( axes_a.len() == axes_b.len(), "contract axis length mismatch: lhs {:?}, rhs {:?}", @@ -302,12 +563,22 @@ impl TensorDynLen { } } - Ok(format!( - "{},{}->{}", - Self::einsum_labels(&lhs_ids)?, - Self::einsum_labels(&rhs_ids)?, - 
Self::einsum_labels(&output_ids)?, - )) + Self::einsum_subscripts_from_usize_ids(&[lhs_ids, rhs_ids], &output_ids) + } + + fn binary_dot_general_config(axes_a: &[usize], axes_b: &[usize]) -> Result { + anyhow::ensure!( + axes_a.len() == axes_b.len(), + "contract axis length mismatch: lhs {:?}, rhs {:?}", + axes_a, + axes_b + ); + Ok(DotGeneralConfig { + lhs_contracting_dims: axes_a.to_vec(), + rhs_contracting_dims: axes_b.to_vec(), + lhs_batch_dims: vec![], + rhs_batch_dims: vec![], + }) } fn binary_contraction_axis_classes( @@ -395,14 +666,9 @@ impl TensorDynLen { axis_classes } - fn scale_subscripts(rank: usize) -> Result { - if rank == 0 { - Ok("->".to_string()) - } else { - let ids: Vec = (0..rank).collect(); - let labels = Self::einsum_labels(&ids)?; - Ok(format!("{labels},->{labels}")) - } + fn scale_subscripts(rank: usize) -> Result { + let ids: Vec = (0..rank).collect(); + Self::einsum_subscripts_from_usize_ids(&[ids.clone(), Vec::new()], &ids) } fn validate_indices(indices: &[DynIndex]) -> Result<()> { @@ -447,7 +713,7 @@ impl TensorDynLen { fn compact_payload_inner(&self) -> Result> { Ok(EagerTensor::from_tensor_in( - storage_payload_native(self.storage.as_ref())?, + storage_payload_native(self.storage.materialize(self.indices.len())?.as_ref())?, default_eager_ctx(), )) } @@ -474,6 +740,14 @@ impl TensorDynLen { Ok(()) } + fn operand_indices_for_contraction(&self, conjugate: bool) -> Vec { + if conjugate { + self.indices.iter().map(|index| index.conj()).collect() + } else { + self.indices.clone() + } + } + fn build_binary_contraction_labels( lhs_rank: usize, axes_a: &[usize], @@ -540,13 +814,8 @@ impl TensorDynLen { fn build_payload_einsum_subscripts( input_roots: &[Vec], output_roots: &[usize], - ) -> Result { - let input_labels = input_roots - .iter() - .map(|roots| Self::einsum_labels(roots)) - .collect::>>()?; - let output = Self::einsum_labels(output_roots)?; - Ok(format!("{}->{}", input_labels.join(","), output)) + ) -> Result { + 
Self::einsum_subscripts_from_usize_ids(input_roots, output_roots) } fn normalize_eager_payload_for_roots( @@ -629,7 +898,7 @@ impl TensorDynLen { Self::validate_storage_matches_indices(&indices, &storage)?; Ok(Self { indices, - storage: Arc::new(storage), + storage: TensorDynLenStorage::from_storage(Arc::new(storage)), structured_ad: Some(Arc::new(StructuredAdValue { payload: Arc::new(payload_inner), payload_dims, @@ -743,16 +1012,18 @@ impl TensorDynLen { ); } - let lhs = storage_payload_native(self.storage.as_ref())?; - let rhs = storage_payload_native(other.storage.as_ref())?; + let lhs_storage = self.storage.materialize(self.indices.len())?; + let rhs_storage = other.storage.materialize(other.indices.len())?; + let lhs = storage_payload_native_read_input(lhs_storage.as_ref())?; + let rhs = storage_payload_native_read_input(rhs_storage.as_ref())?; if lhs.dtype() != rhs.dtype() { return Err(anyhow::anyhow!( "structured payload contraction requires matching payload dtypes" )); } - let (lhs, lhs_labels) = normalize_payload_for_roots(&lhs, &lhs_roots)?; - let (rhs, rhs_labels) = normalize_payload_for_roots(&rhs, &rhs_roots)?; - let payload = tensor4all_tensorbackend::einsum_native_tensors( + let (lhs, lhs_labels) = normalize_payload_read_for_roots(lhs, &lhs_roots)?; + let (rhs, rhs_labels) = normalize_payload_read_for_roots(rhs, &rhs_roots)?; + let payload = tensor4all_tensorbackend::einsum_native_tensor_reads( &[(&lhs, lhs_labels.as_slice()), (&rhs, rhs_labels.as_slice())], &plan.output_payload_roots, )?; @@ -836,15 +1107,15 @@ impl TensorDynLen { positions: &[usize], ) -> Result { if self.storage.is_f64() { - let payload = self - .storage + let storage = self.storage.materialize(self.indices.len())?; + let payload = storage .payload_f64_col_major_vec() .map_err(anyhow::Error::msg)?; let data = Self::dense_selected_diag_payload(payload, &kept_dims, positions); Self::from_dense(kept_indices, data) } else if self.storage.is_c64() { - let payload = self - .storage + 
let storage = self.storage.materialize(self.indices.len())?; + let payload = storage .payload_c64_col_major_vec() .map_err(anyhow::Error::msg)?; let data = Self::dense_selected_diag_payload(payload, &kept_dims, positions); @@ -1033,8 +1304,8 @@ impl TensorDynLen { positions: &[usize], ) -> Result { if self.storage.is_f64() { - let payload = self - .storage + let storage = self.storage.materialize(self.indices.len())?; + let payload = storage .payload_f64_col_major_vec() .map_err(anyhow::Error::msg)?; self.select_structured_indices_typed( @@ -1046,8 +1317,8 @@ impl TensorDynLen { positions, ) } else if self.storage.is_c64() { - let payload = self - .storage + let storage = self.storage.materialize(self.indices.len())?; + let payload = storage .payload_c64_col_major_vec() .map_err(anyhow::Error::msg)?; self.select_structured_indices_typed( @@ -1087,9 +1358,20 @@ impl TensorDynLen { return Ok(value.payload.as_ref()); } } + if let Some(inner) = self.storage.eager() { + return Ok(inner); + } if self.eager_cache.get().is_none() { - let native = Self::seed_native_payload(self.storage.as_ref(), &self.dims()) - .context("TensorDynLen materialization failed")?; + let dims = self.dims(); + let native = profile_pairwise_contract_section("materialize_storage_to_native", || { + let storage = self.storage.materialize(self.indices.len())?; + Self::seed_native_payload(storage.as_ref(), &dims) + }) + .context("TensorDynLen materialization failed")?; + record_pairwise_contract_profile_bytes( + "materialize_storage_to_native", + native_tensor_profile_bytes(&native), + ); let _ = self.eager_cache.set(Arc::new(EagerTensor::from_tensor_in( native, default_eager_ctx(), @@ -1474,7 +1756,7 @@ impl TensorDynLen { Self::validate_storage_matches_indices(&indices, storage.as_ref())?; Ok(Self { indices, - storage, + storage: TensorDynLenStorage::from_storage(storage), structured_ad: None, eager_cache: Self::empty_eager_cache(), }) @@ -1551,8 +1833,12 @@ impl TensorDynLen { inner: EagerTensor, 
axis_classes: Vec, ) -> Result { - let dims = Self::expected_dims_from_indices(&indices); - Self::validate_indices(&indices)?; + let dims = profile_pairwise_contract_section("from_inner_expected_dims", || { + Self::expected_dims_from_indices(&indices) + }); + profile_pairwise_contract_section("from_inner_validate_indices", || { + Self::validate_indices(&indices) + })?; if dims != inner.data().shape() { return Err(anyhow::anyhow!( "native payload dims {:?} do not match indices dims {:?}", @@ -1561,18 +1847,39 @@ impl TensorDynLen { )); } if Self::is_diag_axis_classes(&axis_classes) { - Self::validate_diag_dims(&dims)?; + profile_pairwise_contract_section("from_inner_validate_diag_dims", || { + Self::validate_diag_dims(&dims) + })?; } - let storage = Self::storage_from_native_with_axis_classes( - inner.data(), - &axis_classes, - indices.len(), - )?; + let (storage, eager_cache) = if axis_classes == Self::dense_axis_classes(indices.len()) { + ( + TensorDynLenStorage::from_eager_dense(inner, indices.len()), + Self::empty_eager_cache(), + ) + } else { + let storage = profile_pairwise_contract_section("from_inner_storage_snapshot", || { + Self::storage_from_native_with_axis_classes( + inner.data(), + &axis_classes, + indices.len(), + ) + })?; + record_pairwise_contract_profile_bytes( + "from_inner_storage_snapshot", + native_tensor_profile_bytes(inner.data()), + ); + ( + TensorDynLenStorage::from_storage(Arc::new(storage)), + profile_pairwise_contract_section("from_inner_eager_cache", || { + Self::eager_cache_with(inner) + }), + ) + }; Ok(Self { indices, - storage: Arc::new(storage), + storage, structured_ad: None, - eager_cache: Self::eager_cache_with(inner), + eager_cache, }) } @@ -1588,7 +1895,8 @@ impl TensorDynLen { /// Enable reverse-mode AD tracking on this tensor by creating a tracked leaf. 
pub fn enable_grad(self) -> Result { - let payload = storage_payload_native(self.storage.as_ref()) + let materialized = self.storage.materialize(self.indices.len())?; + let payload = storage_payload_native(materialized.as_ref()) .context("TensorDynLen::enable_grad failed")?; let payload_dims = self.storage.payload_dims().to_vec(); let axis_classes = self.storage.axis_classes().to_vec(); @@ -1609,6 +1917,7 @@ impl TensorDynLen { self.structured_ad .as_ref() .is_some_and(|value| value.payload.tracks_grad()) + || self.storage.eager().is_some_and(EagerTensor::tracks_grad) || self .eager_cache .get() @@ -1648,6 +1957,9 @@ impl TensorDynLen { if let Some(value) = self.tracked_compact_payload_value() { value.payload.clear_grad(); } + if let Some(inner) = self.storage.eager() { + inner.clear_grad(); + } if let Some(inner) = self.eager_cache.get() { inner.clear_grad(); } @@ -1672,7 +1984,10 @@ impl TensorDynLen { /// Detach this tensor from the reverse graph. pub fn detach(&self) -> Result { if self.tracked_compact_payload_value().is_some() { - return Self::from_storage(self.indices.clone(), Arc::clone(&self.storage)); + return Self::from_storage( + self.indices.clone(), + self.storage.materialize(self.indices.len())?, + ); } Self::from_inner_with_axis_classes( self.indices.clone(), @@ -1688,12 +2003,14 @@ impl TensorDynLen { /// Materialize the primal snapshot as storage. pub fn to_storage(&self) -> Result> { - Ok(Arc::clone(&self.storage)) + self.storage.materialize(self.indices.len()) } /// Returns the authoritative compact storage. pub fn storage(&self) -> Arc { - Arc::clone(&self.storage) + self.storage + .materialize(self.indices.len()) + .expect("TensorDynLen storage materialization failed") } /// Sum all elements, returning `AnyScalar`. 
@@ -1785,7 +2102,7 @@ impl TensorDynLen { if perm.iter().copied().eq(0..perm.len()) { return Ok(Self { indices: new_indices.to_vec(), - storage: Arc::clone(&self.storage), + storage: self.storage.clone(), structured_ad: self.structured_ad.clone(), eager_cache: Arc::clone(&self.eager_cache), }); @@ -1848,194 +2165,154 @@ impl TensorDynLen { Self::from_inner_with_axis_classes(new_indices, permuted, axis_classes) } - /// Contract this tensor with another tensor along common indices. - /// - /// This method finds common indices between `self` and `other`, then contracts - /// along those indices. The result tensor contains all non-contracted indices - /// from both tensors, with indices from `self` appearing first, followed by - /// indices from `other` that are not common. - /// - /// # Arguments - /// * `other` - The tensor to contract with - /// - /// # Returns - /// A new tensor resulting from the contraction. - /// - /// # Panics - /// Panics if there are no common indices, if common indices have mismatched - /// dimensions, or if storage types don't match. 
- /// - /// # Example - /// ``` - /// use tensor4all_core::TensorDynLen; - /// use tensor4all_core::index::{DefaultIndex as Index, DynId}; - /// - /// // Create two tensors: A[i, j] and B[j, k] - /// let i = Index::new_dyn(2); - /// let j = Index::new_dyn(3); - /// let k = Index::new_dyn(4); - /// - /// let indices_a = vec![i.clone(), j.clone()]; - /// let tensor_a: TensorDynLen = TensorDynLen::from_dense(indices_a, vec![0.0; 6]).unwrap(); - /// - /// let indices_b = vec![j.clone(), k.clone()]; - /// let tensor_b: TensorDynLen = TensorDynLen::from_dense(indices_b, vec![0.0; 12]).unwrap(); - /// - /// // Contract along j with the default pairwise semantics: result is C[i, k] - /// let result = tensor_a.contract(&tensor_b).unwrap(); - /// assert_eq!(result.dims(), vec![2, 4]); - /// ``` - pub fn contract(&self, other: &Self) -> Result { - self.try_contract_pairwise_default(other) + pub(crate) fn try_contract_pairwise_default(&self, other: &Self) -> Result { + self.try_contract_pairwise_default_with_options(other, PairwiseContractionOptions::new()) } - /// Contract this tensor with another tensor using explicit contraction options. - /// - /// # Arguments - /// * `other` - The tensor to contract with. - /// * `options` - Pair-selection policy and retained indices. - /// - /// # Returns - /// The contracted tensor, or an error if the contraction cannot be built. - /// - /// # Errors - /// Returns an error if the tensors are disconnected, retained indices are - /// invalid, or the contraction plan cannot be executed. 
- /// - /// # Examples - /// ``` - /// use tensor4all_core::{AllowedPairs, ContractionOptions, DynIndex, TensorDynLen}; - /// - /// let batch = DynIndex::new_dyn(2); - /// let i = DynIndex::new_dyn(2); - /// let k = DynIndex::new_dyn(3); - /// let j = DynIndex::new_dyn(2); - /// - /// let a = TensorDynLen::from_dense( - /// vec![batch.clone(), i.clone(), k.clone()], - /// vec![1.0_f64; 12], - /// ).unwrap(); - /// let b = TensorDynLen::from_dense( - /// vec![batch.clone(), k.clone(), j.clone()], - /// vec![1.0_f64; 12], - /// ).unwrap(); - /// let retain = [batch.clone()]; - /// let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain); - /// let result = a.contract_with_options(&b, options).unwrap(); - /// - /// assert_eq!(result.indices(), &[batch, i, j]); - /// assert_eq!(result.dims(), vec![2, 2, 2]); - /// assert_eq!(result.to_vec::().unwrap(), vec![3.0; 8]); - /// ``` - pub fn contract_with_options( + pub(crate) fn try_contract_pairwise_default_with_options( &self, other: &Self, - options: crate::defaults::contract::ContractionOptions<'_>, + options: PairwiseContractionOptions, ) -> Result { - crate::defaults::contract::contract_multi_with_options(&[self, other], options) - } - - pub(crate) fn try_contract_pairwise_default(&self, other: &Self) -> Result { - let self_dims = Self::expected_dims_from_indices(&self.indices); - let other_dims = Self::expected_dims_from_indices(&other.indices); - let spec = prepare_contraction(&self.indices, &self_dims, &other.indices, &other_dims) - .context("contraction preparation failed")?; - let result_axis_classes = Self::binary_contraction_axis_classes( - self.storage.axis_classes(), - &spec.axes_a, - other.storage.axis_classes(), - &spec.axes_b, - ); - - if self.should_use_structured_payload_contract(other) { - return self.contract_structured_payloads( - other, - spec.result_indices, + let self_indices = profile_pairwise_contract_section("operand_indices", || { + 
self.operand_indices_for_contraction(options.lhs_conj) + }); + let other_indices = profile_pairwise_contract_section("operand_indices", || { + other.operand_indices_for_contraction(options.rhs_conj) + }); + let self_dims = profile_pairwise_contract_section("expected_dims", || { + Self::expected_dims_from_indices(&self_indices) + }); + let other_dims = profile_pairwise_contract_section("expected_dims", || { + Self::expected_dims_from_indices(&other_indices) + }); + let spec = profile_pairwise_contract_section("prepare_contraction", || { + prepare_contraction(&self_indices, &self_dims, &other_indices, &other_dims) + }) + .context("contraction preparation failed")?; + let result_axis_classes = profile_pairwise_contract_section("result_axis_classes", || { + Self::binary_contraction_axis_classes( + self.storage.axis_classes(), &spec.axes_a, + other.storage.axis_classes(), &spec.axes_b, - ); + ) + }); + + if profile_pairwise_contract_section("structured_check", || { + self.should_use_structured_payload_contract(other) + }) { + if options.has_conj() { + let lhs = if options.lhs_conj { + self.conj() + } else { + self.clone() + }; + let rhs = if options.rhs_conj { + other.conj() + } else { + other.clone() + }; + return profile_pairwise_contract_section("structured_conj_fallback", || { + lhs.try_contract_pairwise_default(&rhs) + }); + } + return profile_pairwise_contract_section("structured_payload_contract", || { + self.contract_structured_payloads( + other, + spec.result_indices.into_vec(), + &spec.axes_a, + &spec.axes_b, + ) + }); } if self.indices.is_empty() && other.indices.is_empty() { - let result = self - .try_materialized_inner()? 
- .mul(other.try_materialized_inner()?)?; - return Self::from_inner(spec.result_indices, result); + if options.has_conj() { + let lhs = if options.lhs_conj { + self.conj() + } else { + self.clone() + }; + let rhs = if options.rhs_conj { + other.conj() + } else { + other.clone() + }; + return lhs.try_contract_pairwise_default(&rhs); + } + let result = profile_pairwise_contract_section("scalar_mul", || { + Ok::<_, anyhow::Error>( + self.try_materialized_inner()? + .mul(other.try_materialized_inner()?)?, + ) + })?; + return profile_pairwise_contract_section("from_inner", || { + Self::from_inner(spec.result_indices.into_vec(), result) + }); } - let self_native = self.as_native()?; - let other_native = other.as_native()?; + let self_native = profile_pairwise_contract_section("as_native", || self.as_native())?; + let other_native = profile_pairwise_contract_section("as_native", || other.as_native())?; if self_native.dtype() != other_native.dtype() { - let result_native = - contract_native_tensor(self_native, &spec.axes_a, other_native, &spec.axes_b)?; - return Self::from_native_with_axis_classes( - spec.result_indices, - result_native, - result_axis_classes, - ); + if options.has_conj() { + let lhs = if options.lhs_conj { + self.conj() + } else { + self.clone() + }; + let rhs = if options.rhs_conj { + other.conj() + } else { + other.clone() + }; + return lhs.try_contract_pairwise_default(&rhs); + } + let result_native = profile_pairwise_contract_section("native_contract", || { + contract_native_tensor(self_native, &spec.axes_a, other_native, &spec.axes_b) + })?; + return profile_pairwise_contract_section("from_native", || { + Self::from_native_with_axis_classes( + spec.result_indices.into_vec(), + result_native, + result_axis_classes, + ) + }); } - let subscripts = Self::build_binary_einsum_subscripts( - self.indices.len(), - &spec.axes_a, - other.indices.len(), - &spec.axes_b, - )?; - let result = eager_einsum_ad( - &[ - self.try_materialized_inner()?, - 
other.try_materialized_inner()?, - ], - &subscripts, - )?; - Self::from_inner_with_axis_classes(spec.result_indices, result, result_axis_classes) + let config = profile_pairwise_contract_section("build_dot_general_config", || { + Self::binary_dot_general_config(&spec.axes_a, &spec.axes_b) + })?; + let result = profile_pairwise_contract_section("dot_general_with_conj", || { + let lhs = profile_pairwise_contract_section("lhs_try_materialized_inner", || { + self.try_materialized_inner() + })?; + let rhs = profile_pairwise_contract_section("rhs_try_materialized_inner", || { + other.try_materialized_inner() + })?; + profile_pairwise_contract_section("dot_general_execute", || { + lhs.dot_general_with_conj(rhs, &config, options.lhs_conj, options.rhs_conj) + }) + .map_err(anyhow::Error::from) + })?; + record_pairwise_contract_profile_bytes( + "dot_general_output", + native_tensor_profile_bytes(result.data()), + ); + profile_pairwise_contract_section("from_inner_axis_classes", || { + Self::from_inner_with_axis_classes( + spec.result_indices.into_vec(), + result, + result_axis_classes, + ) + }) } - /// Contract this tensor with another tensor along explicitly specified index pairs. - /// - /// Similar to NumPy's `tensordot`, this method contracts only along the explicitly - /// specified pairs of indices. Unlike `contract()` which automatically contracts - /// all common indices, `tensordot` gives you explicit control over which indices - /// to contract. 
- /// - /// # Arguments - /// * `other` - The tensor to contract with - /// * `pairs` - Pairs of indices to contract: `(index_from_self, index_from_other)` - /// - /// # Returns - /// A new tensor resulting from the contraction, or an error if: - /// - Any specified index is not found in the respective tensor - /// - Dimensions don't match for any pair - /// - The same axis is specified multiple times in `self` or `other` - /// - There are common indices (same ID) that are not in the contraction pairs - /// (batch contraction is not yet implemented) - /// - /// # Future: Batch Contraction - /// In a future version, common indices not specified in `pairs` will be treated - /// as batch dimensions (like batched GEMM). Currently, this case returns an error. - /// - /// # Example - /// ``` - /// use tensor4all_core::TensorDynLen; - /// use tensor4all_core::index::{DefaultIndex as Index, DynId}; - /// - /// // Create two tensors: A[i, j] and B[k, l] where j and k have same dimension but different IDs - /// let i = Index::new_dyn(2); - /// let j = Index::new_dyn(3); - /// let k = Index::new_dyn(3); // Same dimension as j, but different ID - /// let l = Index::new_dyn(4); - /// - /// let indices_a = vec![i.clone(), j.clone()]; - /// let tensor_a: TensorDynLen = TensorDynLen::from_dense(indices_a, vec![0.0; 6]).unwrap(); - /// - /// let indices_b = vec![k.clone(), l.clone()]; - /// let tensor_b: TensorDynLen = TensorDynLen::from_dense(indices_b, vec![0.0; 12]).unwrap(); - /// - /// // Contract j (from A) with k (from B): result is C[i, l] - /// let result = tensor_a.tensordot(&tensor_b, &[(j.clone(), k.clone())]).unwrap(); - /// assert_eq!(result.dims(), vec![2, 4]); - /// ``` - pub fn tensordot(&self, other: &Self, pairs: &[(DynIndex, DynIndex)]) -> Result { + pub(crate) fn try_tensordot_pairwise_explicit( + &self, + other: &Self, + pairs: &[(DynIndex, DynIndex)], + ) -> Result { use crate::index_ops::ContractionError; let self_dims = 
Self::expected_dims_from_indices(&self.indices); @@ -2084,7 +2361,7 @@ impl TensorDynLen { if self.should_use_structured_payload_contract(other) { return self.contract_structured_payloads( other, - spec.result_indices, + spec.result_indices.into_vec(), &spec.axes_a, &spec.axes_b, ); @@ -2095,7 +2372,7 @@ impl TensorDynLen { .try_materialized_inner()? .mul(other.try_materialized_inner()?) .map_err(|e| anyhow::anyhow!("tensordot scalar multiply failed: {e}"))?; - return Self::from_inner(spec.result_indices, result); + return Self::from_inner(spec.result_indices.into_vec(), result); } let self_native = self.as_native()?; @@ -2104,7 +2381,7 @@ impl TensorDynLen { let result_native = contract_native_tensor(self_native, &spec.axes_a, other_native, &spec.axes_b)?; return Self::from_native_with_axis_classes( - spec.result_indices, + spec.result_indices.into_vec(), result_native, result_axis_classes, ); @@ -2124,39 +2401,14 @@ impl TensorDynLen { &subscripts, ) .map_err(|e| anyhow::anyhow!("tensordot failed: {e}"))?; - Self::from_inner_with_axis_classes(spec.result_indices, result, result_axis_classes) + Self::from_inner_with_axis_classes( + spec.result_indices.into_vec(), + result, + result_axis_classes, + ) } - /// Compute the outer product (tensor product) of two tensors. - /// - /// Creates a new tensor whose indices are the concatenation of the indices - /// from both input tensors. The result has shape `[...self.dims, ...other.dims]`. - /// - /// This is equivalent to numpy's `np.outer` or `np.tensordot(a, b, axes=0)`, - /// or ITensor's `*` operator when there are no common indices. - /// - /// # Arguments - /// * `other` - The other tensor to compute outer product with - /// - /// # Returns - /// A new tensor with indices from both tensors. 
- /// - /// # Example - /// ``` - /// use tensor4all_core::TensorDynLen; - /// use tensor4all_core::index::{DefaultIndex as Index, DynId}; - /// - /// let i = Index::new_dyn(2); - /// let j = Index::new_dyn(3); - /// let tensor_a: TensorDynLen = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap(); - /// let tensor_b: TensorDynLen = - /// TensorDynLen::from_dense(vec![j.clone()], vec![1.0, 2.0, 3.0]).unwrap(); - /// - /// // Outer product: C[i, j] = A[i] * B[j] - /// let result = tensor_a.outer_product(&tensor_b).unwrap(); - /// assert_eq!(result.dims(), vec![2, 3]); - /// ``` - pub fn outer_product(&self, other: &Self) -> Result { + pub(crate) fn try_outer_product_pairwise(&self, other: &Self) -> Result { use anyhow::Context; // Check for common indices - outer product should have none @@ -2406,7 +2658,7 @@ impl TensorDynLen { let same_compact_layout = self.storage.payload_dims() == other_aligned.storage.payload_dims() - && self.storage.payload_strides() == other_aligned.storage.payload_strides() + && self.storage.payload_strides_vec() == other_aligned.storage.payload_strides_vec() && self.storage.axis_classes() == other_aligned.storage.axis_classes(); if same_compact_layout && !self.tracks_grad() @@ -2414,11 +2666,14 @@ impl TensorDynLen { && !a.tracks_grad() && !b.tracks_grad() { - let combined = self + let lhs_storage = self.storage.materialize(self.indices.len())?; + let rhs_storage = other_aligned .storage + .materialize(other_aligned.indices.len())?; + let combined = lhs_storage .axpby( &a.to_backend_scalar(), - other_aligned.storage.as_ref(), + rhs_storage.as_ref(), &b.to_backend_scalar(), ) .map_err(|e| anyhow::anyhow!("storage axpby failed: {e}"))?; @@ -2470,6 +2725,11 @@ impl TensorDynLen { /// assert_eq!(scaled.to_vec::().unwrap(), vec![2.0, 4.0, 6.0]); /// ``` pub fn scale(&self, scalar: AnyScalar) -> Result { + if !self.tracks_grad() && !scalar.tracks_grad() { + let scaled = self.storage.scale(&scalar.to_backend_scalar())?; + return 
Self::from_storage(self.indices.clone(), Arc::new(scaled)); + } + let self_native = self.as_native()?; let scalar_native = scalar.as_tensor()?.as_native()?; if self_native.dtype() != scalar_native.dtype() { @@ -2526,15 +2786,21 @@ impl TensorDynLen { let other_set: HashSet<_> = other.indices.iter().collect(); if self_set == other_set { let other_aligned = other.permute_indices(&self.indices)?; - let result = self.conj().contract(&other_aligned)?; + let result = super::contract::contract_pair_with_operand_options( + self, + &other_aligned, + PairwiseContractionOptions::new().with_lhs_conj(true), + )?; return result.sum(); } } // Contract self.conj() with other over all indices - let conj_self = self.conj(); - let result = - super::contract::contract_multi(&[&conj_self, other], crate::AllowedPairs::All)?; + let result = super::contract::contract_pair_with_operand_options( + self, + other, + PairwiseContractionOptions::new().with_lhs_conj(true), + )?; // Result should be a scalar (no indices) result.sum() } @@ -2602,7 +2868,7 @@ impl TensorDynLen { Ok(Self { indices: new_indices, - storage: Arc::clone(&self.storage), + storage: self.storage.clone(), structured_ad: self.structured_ad.clone(), eager_cache: Arc::clone(&self.eager_cache), }) @@ -2680,7 +2946,7 @@ impl TensorDynLen { Ok(Self { indices: new_indices_vec, - storage: Arc::clone(&self.storage), + storage: self.storage.clone(), structured_ad: self.structured_ad.clone(), eager_cache: Arc::clone(&self.eager_cache), }) @@ -2736,7 +3002,9 @@ impl TensorDynLen { .unwrap_or_else(Self::empty_eager_cache); Self { indices: new_indices, - storage: Arc::new(self.storage.conj()), + storage: self.storage.conj().unwrap_or_else(|_| { + TensorDynLenStorage::from_storage(Arc::new(self.storage().conj())) + }), structured_ad, eager_cache, } @@ -2785,7 +3053,7 @@ impl TensorDynLen { // Contract tensor with its conjugate over all indices → scalar // ||T||² = Σ T_ijk... * conj(T_ijk...) 
= Σ |T_ijk...|² let conj = self.conj(); - let scalar = self.contract(&conj)?; + let scalar = super::contract::contract_pair(self, &conj)?; // The mathematical result is nonnegative and real. Clamp tiny negative // roundoff so downstream `sqrt` stays well-defined for complex tensors. Ok(scalar.sum()?.real().max(0.0)) @@ -2820,9 +3088,80 @@ impl TensorDynLen { /// assert!((t.maxabs() - 5.0).abs() < 1e-12); /// ``` pub fn maxabs(&self) -> f64 { - self.to_storage() - .map(|storage| storage.max_abs()) - .unwrap_or(0.0) + self.storage.max_abs().unwrap_or(0.0) + } + + /// Element-wise subtraction with index alignment. + /// + /// This computes `self - other` using the same vector-space semantics as + /// [`TensorVectorSpace`](crate::TensorVectorSpace). + /// + /// # Errors + /// Returns an error if the tensors cannot be aligned or subtracted. + pub fn sub(&self, other: &Self) -> Result { + self.axpby(AnyScalar::new_real(1.0), other, AnyScalar::new_real(-1.0)) + } + + /// Negate all elements. + /// + /// # Errors + /// Returns an error if scalar multiplication fails for the tensor storage. + pub fn neg(&self) -> Result { + self.scale(AnyScalar::new_real(-1.0)) + } + + /// Approximate equality check using Julia `isapprox`-style semantics. + /// + /// Returns `true` when `||self - other|| <= max(atol, rtol * + /// max(||self||, ||other||))`. + pub fn isapprox(&self, other: &Self, atol: f64, rtol: f64) -> bool { + let diff = match self.sub(other) { + Ok(d) => d, + Err(_) => return false, + }; + let diff_norm = diff.norm(); + diff_norm <= atol.max(rtol * self.norm().max(other.norm())) + } + + /// Create a diagonal Kronecker-delta tensor for one input/output index pair. + /// + /// # Errors + /// Returns an error if the two indices have different dimensions. + pub fn diagonal(input_index: &DynIndex, output_index: &DynIndex) -> Result { + ::diagonal(input_index, output_index) + } + + /// Create a product of Kronecker-delta tensors for paired index lists. 
+ /// + /// # Errors + /// Returns an error if the index lists have different lengths or paired + /// dimensions do not match. + pub fn delta(input_indices: &[DynIndex], output_indices: &[DynIndex]) -> Result { + ::delta(input_indices, output_indices) + } + + /// Create a scalar tensor equal to one. + /// + /// # Errors + /// Returns an error if dense scalar construction fails. + pub fn scalar_one() -> Result { + ::scalar_one() + } + + /// Create a tensor filled with ones over the given indices. + /// + /// # Errors + /// Returns an error if the tensor size overflows or dense construction fails. + pub fn ones(indices: &[DynIndex]) -> Result { + ::ones(indices) + } + + /// Create a one-hot tensor with value one at the specified index positions. + /// + /// # Errors + /// Returns an error if any coordinate is outside its index dimension. + pub fn onehot(index_vals: &[(DynIndex, usize)]) -> Result { + ::onehot(index_vals) } /// Compute the relative distance between two tensors. @@ -3087,9 +3426,34 @@ impl TensorIndex for TensorDynLen { // TensorLike implementation for TensorDynLen // ============================================================================ -use crate::tensor_like::{FactorizeError, FactorizeOptions, FactorizeResult, TensorLike}; +use crate::tensor_like::{ + FactorizeError, FactorizeOptions, FactorizeResult, TensorConstructionLike, + TensorContractionLike, TensorFactorizationLike, TensorVectorSpace, +}; + +impl TensorVectorSpace for TensorDynLen { + fn norm_squared(&self) -> f64 { + TensorDynLen::norm_squared(self) + } + + fn maxabs(&self) -> f64 { + TensorDynLen::maxabs(self) + } + + fn axpby(&self, a: crate::AnyScalar, other: &Self, b: crate::AnyScalar) -> Result { + TensorDynLen::axpby(self, a, other, b) + } + + fn scale(&self, scalar: crate::AnyScalar) -> Result { + TensorDynLen::scale(self, scalar) + } + + fn inner_product(&self, other: &Self) -> Result { + TensorDynLen::inner_product(self, other) + } +} -impl TensorLike for TensorDynLen { +impl 
TensorFactorizationLike for TensorDynLen { fn factorize( &self, left_inds: &[DynIndex], @@ -3106,7 +3470,9 @@ impl TensorLike for TensorDynLen { ) -> std::result::Result, FactorizeError> { crate::factorize::factorize_full_rank(self, left_inds, alg, canonical) } +} +impl TensorContractionLike for TensorDynLen { fn conj(&self) -> Self { // Delegate to the inherent method (complex conjugate for dense tensors) TensorDynLen::conj(self) @@ -3125,17 +3491,7 @@ impl TensorLike for TensorDynLen { } fn outer_product(&self, other: &Self) -> Result { - // Delegate to the inherent method - TensorDynLen::outer_product(self, other) - } - - fn norm_squared(&self) -> f64 { - // Delegate to the inherent method - TensorDynLen::norm_squared(self) - } - - fn maxabs(&self) -> f64 { - TensorDynLen::maxabs(self) + super::contract::outer_product(self, other) } fn permuteinds(&self, new_order: &[DynIndex]) -> Result { @@ -3152,39 +3508,20 @@ impl TensorLike for TensorDynLen { TensorDynLen::fuse_indices(self, old_indices, new_index, order) } - fn contract(tensors: &[&Self], allowed: crate::AllowedPairs<'_>) -> Result { - // Delegate to contract_multi which handles disconnected components - super::contract::contract_multi(tensors, allowed) + fn contract(tensors: &[&Self]) -> Result { + super::contract::contract(tensors) } fn contract_pair(&self, other: &Self) -> Result { - self.try_contract_pairwise_default(other) - } - - fn contract_connected(tensors: &[&Self], allowed: crate::AllowedPairs<'_>) -> Result { - // Delegate to contract_connected which requires connected graph - super::contract::contract_connected(tensors, allowed) + super::contract::contract_pair(self, other) } +} +impl TensorConstructionLike for TensorDynLen { fn select_indices(&self, selected_indices: &[DynIndex], positions: &[usize]) -> Result { TensorDynLen::select_indices(self, selected_indices, positions) } - fn axpby(&self, a: crate::AnyScalar, other: &Self, b: crate::AnyScalar) -> Result { - // Delegate to the inherent 
method - TensorDynLen::axpby(self, a, other, b) - } - - fn scale(&self, scalar: crate::AnyScalar) -> Result { - // Delegate to the inherent method - TensorDynLen::scale(self, scalar) - } - - fn inner_product(&self, other: &Self) -> Result { - // Delegate to the inherent method - TensorDynLen::inner_product(self, other) - } - fn diagonal(input_index: &DynIndex, output_index: &DynIndex) -> Result { let dim = input_index.dim(); if dim != output_index.dim() { @@ -3529,7 +3866,7 @@ impl TensorDynLen { /// /// # Examples /// ``` - /// use tensor4all_core::{DynIndex, LinearizationOrder, TensorDynLen, TensorLike}; + /// use tensor4all_core::{DynIndex, LinearizationOrder, TensorDynLen}; /// /// let i = DynIndex::new_dyn(2); /// let j = DynIndex::new_dyn(2); @@ -3652,7 +3989,7 @@ impl TensorDynLen { /// /// # Examples /// ``` - /// use tensor4all_core::{DynIndex, LinearizationOrder, TensorDynLen, TensorLike}; + /// use tensor4all_core::{DynIndex, LinearizationOrder, TensorDynLen}; /// /// let fused = DynIndex::new_dyn(4); /// let i = DynIndex::new_dyn(2); diff --git a/crates/tensor4all-core/src/index_ops.rs b/crates/tensor4all-core/src/index_ops.rs index a7eba58b..f8beb4e6 100644 --- a/crates/tensor4all-core/src/index_ops.rs +++ b/crates/tensor4all-core/src/index_ops.rs @@ -1,4 +1,16 @@ use crate::IndexLike; +use smallvec::SmallVec; + +const SMALL_CONTRACTION_INLINE: usize = 8; +const LINEAR_CONTRACTION_SCAN_LIMIT: usize = 64; + +/// Small axis list used by contraction preparation. +pub(crate) type AxisVec = SmallVec<[usize; SMALL_CONTRACTION_INLINE]>; +/// Small index list used by contraction preparation. +pub(crate) type IndexVec = SmallVec<[I; SMALL_CONTRACTION_INLINE]>; + +type AxisPairVec = SmallVec<[(usize, usize); SMALL_CONTRACTION_INLINE]>; +type BoolVec = SmallVec<[bool; SMALL_CONTRACTION_INLINE]>; /// Error type for index replacement operations. 
#[derive(Debug, Clone, PartialEq, Eq)] @@ -495,7 +507,19 @@ pub fn common_inds(indices_a: &[I], indices_b: &[I]) -> Vec { /// assert_eq!(positions, vec![(1, 0)]); // j is at position 1 in a, position 0 in b /// ``` pub fn common_ind_positions(indices_a: &[I], indices_b: &[I]) -> Vec<(usize, usize)> { - let mut positions = Vec::new(); + common_ind_positions_small(indices_a, indices_b).into_vec() +} + +fn common_ind_positions_small(indices_a: &[I], indices_b: &[I]) -> AxisPairVec { + let scan_work = indices_a.len().saturating_mul(indices_b.len()); + if scan_work <= LINEAR_CONTRACTION_SCAN_LIMIT { + return common_ind_positions_linear(indices_a, indices_b); + } + common_ind_positions_hashed(indices_a, indices_b) +} + +fn common_ind_positions_linear(indices_a: &[I], indices_b: &[I]) -> AxisPairVec { + let mut positions = AxisPairVec::new(); for (pos_a, idx_a) in indices_a.iter().enumerate() { for (pos_b, idx_b) in indices_b.iter().enumerate() { if idx_a.is_contractable(idx_b) { @@ -507,26 +531,48 @@ pub fn common_ind_positions(indices_a: &[I], indices_b: &[I]) -> V positions } +fn common_ind_positions_hashed(indices_a: &[I], indices_b: &[I]) -> AxisPairVec { + use std::collections::HashMap; + + let mut positions_by_id: HashMap<&I::Id, SmallVec<[usize; 2]>> = + HashMap::with_capacity(indices_b.len()); + for (pos_b, idx_b) in indices_b.iter().enumerate() { + positions_by_id.entry(idx_b.id()).or_default().push(pos_b); + } + + let mut positions = AxisPairVec::new(); + for (pos_a, idx_a) in indices_a.iter().enumerate() { + let Some(candidate_positions) = positions_by_id.get(idx_a.id()) else { + continue; + }; + for &pos_b in candidate_positions { + if idx_a.is_contractable(&indices_b[pos_b]) { + positions.push((pos_a, pos_b)); + break; // Each index in a can match at most one in b + } + } + } + positions +} + /// Result of preparing a tensor contraction. 
/// /// Contains all the information needed to perform the contraction: /// - Which axes to contract from each tensor -/// - The resulting indices and dimensions after contraction +/// - The resulting indices after contraction #[derive(Debug, Clone)] -pub struct ContractionSpec { +pub(crate) struct ContractionSpec { /// Axes to contract from the first tensor (positions in `indices_a`). - pub axes_a: Vec, + pub axes_a: AxisVec, /// Axes to contract from the second tensor (positions in `indices_b`). - pub axes_b: Vec, + pub axes_b: AxisVec, /// Indices of the result tensor (non-contracted indices from both tensors). - pub result_indices: Vec, - /// Dimensions of the result tensor. - pub result_dims: Vec, + pub result_indices: IndexVec, } /// Error type for contraction preparation. #[derive(Debug, Clone, PartialEq, Eq)] -pub enum ContractionError { +pub(crate) enum ContractionError { /// No common indices found for contraction. NoCommonIndices, /// Dimension mismatch for a common index. @@ -591,29 +637,9 @@ impl std::error::Error for ContractionError {} /// Prepare contraction data for two tensors that share common indices. /// -/// This function finds common indices and computes the axes to contract -/// and the resulting indices/dimensions. 
-/// -/// # Example -/// ``` -/// use tensor4all_core::index::DefaultIndex as Index; -/// use tensor4all_core::index_ops::prepare_contraction; -/// -/// let i = Index::new_dyn(2); -/// let j = Index::new_dyn(3); -/// let k = Index::new_dyn(4); -/// -/// let indices_a = vec![i.clone(), j.clone()]; -/// let dims_a = vec![2, 3]; -/// let indices_b = vec![j.clone(), k.clone()]; -/// let dims_b = vec![3, 4]; -/// -/// let spec = prepare_contraction(&indices_a, &dims_a, &indices_b, &dims_b).unwrap(); -/// assert_eq!(spec.axes_a, vec![1]); // j is at position 1 in a -/// assert_eq!(spec.axes_b, vec![0]); // j is at position 0 in b -/// assert_eq!(spec.result_dims, vec![2, 4]); // [i, k] -/// ``` -pub fn prepare_contraction( +/// This internal helper finds common indices and computes the axes to contract +/// together with the resulting non-contracted indices. +pub(crate) fn prepare_contraction( indices_a: &[I], dims_a: &[usize], indices_b: &[I], @@ -621,9 +647,14 @@ pub fn prepare_contraction( ) -> Result, ContractionError> { // Find common indices and their positions. // If no common indices exist, this becomes an outer product (empty axes). 
- let positions = common_ind_positions(indices_a, indices_b); + let positions = common_ind_positions_small(indices_a, indices_b); - let (axes_a, axes_b): (Vec<_>, Vec<_>) = positions.iter().copied().unzip(); + let mut axes_a = AxisVec::with_capacity(positions.len()); + let mut axes_b = AxisVec::with_capacity(positions.len()); + for &(pos_a, pos_b) in &positions { + axes_a.push(pos_a); + axes_b.push(pos_b); + } // Verify dimensions match for &(pos_a, pos_b) in &positions { @@ -637,92 +668,52 @@ pub fn prepare_contraction( } } - // Build result indices and dimensions (non-contracted indices) - let mut result_indices = Vec::new(); - let mut result_dims = Vec::new(); - - for (i, idx) in indices_a.iter().enumerate() { - if !axes_a.contains(&i) { - result_indices.push(idx.clone()); - result_dims.push(dims_a[i]); - } - } - - for (i, idx) in indices_b.iter().enumerate() { - if !axes_b.contains(&i) { - result_indices.push(idx.clone()); - result_dims.push(dims_b[i]); - } - } + let result_indices = build_contraction_result_indices(indices_a, &axes_a, indices_b, &axes_b); Ok(ContractionSpec { axes_a, axes_b, result_indices, - result_dims, }) } /// Prepare contraction data for explicit index pairs (like tensordot). /// -/// Unlike `prepare_contraction`, this function takes explicit pairs of indices -/// to contract, allowing contraction of indices with different IDs. 
-/// -/// # Example -/// ``` -/// use tensor4all_core::index::DefaultIndex as Index; -/// use tensor4all_core::index_ops::prepare_contraction_pairs; -/// -/// let i = Index::new_dyn(2); -/// let j = Index::new_dyn(3); -/// let k = Index::new_dyn(3); // Same dim as j but different ID -/// let l = Index::new_dyn(4); -/// -/// let indices_a = vec![i.clone(), j.clone()]; -/// let dims_a = vec![2, 3]; -/// let indices_b = vec![k.clone(), l.clone()]; -/// let dims_b = vec![3, 4]; -/// -/// // Contract j with k -/// let spec = prepare_contraction_pairs( -/// &indices_a, &dims_a, -/// &indices_b, &dims_b, -/// &[(j.clone(), k.clone())] -/// ).unwrap(); -/// assert_eq!(spec.axes_a, vec![1]); -/// assert_eq!(spec.axes_b, vec![0]); -/// assert_eq!(spec.result_dims, vec![2, 4]); -/// ``` -pub fn prepare_contraction_pairs( +/// Unlike `prepare_contraction`, this internal helper takes explicit pairs of +/// indices to contract, allowing contraction of indices with different IDs. +pub(crate) fn prepare_contraction_pairs( indices_a: &[I], dims_a: &[usize], indices_b: &[I], dims_b: &[usize], pairs: &[(I, I)], ) -> Result, ContractionError> { - use std::collections::HashSet; - if pairs.is_empty() { return Err(ContractionError::NoCommonIndices); } // Check for batch contraction (common indices not in pairs). The explicit // pair list identifies axes by full index metadata, not by ID alone. 
- let contracted_a_indices: HashSet<_> = pairs.iter().map(|(idx, _)| idx).collect(); - let contracted_b_indices: HashSet<_> = pairs.iter().map(|(_, idx)| idx).collect(); - - let common_positions = common_ind_positions(indices_a, indices_b); + let common_positions = common_ind_positions_small(indices_a, indices_b); for (pos_a, pos_b) in &common_positions { let idx_a = &indices_a[*pos_a]; let idx_b = &indices_b[*pos_b]; - if !contracted_a_indices.contains(idx_a) || !contracted_b_indices.contains(idx_b) { + if !pairs + .iter() + .any(|(contracted_idx, _)| contracted_idx == idx_a) + || !pairs + .iter() + .any(|(_, contracted_idx)| contracted_idx == idx_b) + { return Err(ContractionError::BatchContractionNotImplemented); } } // Find positions and validate - let mut axes_a = Vec::new(); - let mut axes_b = Vec::new(); + let mut axes_a = AxisVec::with_capacity(pairs.len()); + let mut axes_b = AxisVec::with_capacity(pairs.len()); + let mut contracted_a = bool_flags(indices_a.len()); + let mut contracted_b = bool_flags(indices_b.len()); for (idx_a, idx_b) in pairs { let pos_a = indices_a @@ -746,45 +737,115 @@ pub fn prepare_contraction_pairs( } // Check for duplicate axes - if axes_a.contains(&pos_a) { + if contracted_a[pos_a] { return Err(ContractionError::DuplicateAxis { tensor: "self", pos: pos_a, }); } - if axes_b.contains(&pos_b) { + if contracted_b[pos_b] { return Err(ContractionError::DuplicateAxis { tensor: "other", pos: pos_b, }); } + contracted_a[pos_a] = true; + contracted_b[pos_b] = true; axes_a.push(pos_a); axes_b.push(pos_b); } - // Build result indices and dimensions - let mut result_indices = Vec::new(); - let mut result_dims = Vec::new(); + let result_indices = build_contraction_result_indices(indices_a, &axes_a, indices_b, &axes_b); + + Ok(ContractionSpec { + axes_a, + axes_b, + result_indices, + }) +} + +fn bool_flags(len: usize) -> BoolVec { + let mut flags = BoolVec::with_capacity(len); + flags.resize(len, false); + flags +} + +fn 
build_contraction_result_indices( + indices_a: &[I], + axes_a: &[usize], + indices_b: &[I], + axes_b: &[usize], +) -> IndexVec { + let mut contracted_a = bool_flags(indices_a.len()); + let mut contracted_b = bool_flags(indices_b.len()); + for &axis in axes_a { + contracted_a[axis] = true; + } + for &axis in axes_b { + contracted_b[axis] = true; + } + + let result_len = indices_a.len() + indices_b.len() - axes_a.len() - axes_b.len(); + let mut result_indices = IndexVec::with_capacity(result_len); for (i, idx) in indices_a.iter().enumerate() { - if !axes_a.contains(&i) { + if !contracted_a[i] { result_indices.push(idx.clone()); - result_dims.push(dims_a[i]); } } for (i, idx) in indices_b.iter().enumerate() { - if !axes_b.contains(&i) { + if !contracted_b[i] { result_indices.push(idx.clone()); - result_dims.push(dims_b[i]); } } - Ok(ContractionSpec { - axes_a, - axes_b, - result_indices, - result_dims, - }) + result_indices +} + +#[cfg(test)] +mod tests { + use super::{prepare_contraction, prepare_contraction_pairs}; + use crate::index::DefaultIndex as Index; + + #[test] + fn prepare_contraction_pairs_selects_exact_same_id_prime_index() { + let i = Index::new_dyn(2); + let i_prime = i.prime(); + let spec = prepare_contraction_pairs( + &[i.clone(), i_prime.clone()], + &[2, 2], + std::slice::from_ref(&i_prime), + &[2], + &[(i_prime.clone(), i_prime.clone())], + ) + .unwrap(); + + assert_eq!(spec.axes_a.as_slice(), &[1]); + assert_eq!(spec.axes_b.as_slice(), &[0]); + assert_eq!(spec.result_indices.as_slice(), &[i]); + } + + #[test] + fn prepare_contraction_large_rank_uses_hash_fallback_semantics() { + let mut lhs: Vec<_> = (0..9).map(|_| Index::new_dyn(2)).collect(); + let shared = lhs[7].clone(); + let mut rhs: Vec<_> = (0..9).map(|_| Index::new_dyn(2)).collect(); + rhs[5] = shared; + + let lhs_dims = vec![2; lhs.len()]; + let rhs_dims = vec![2; rhs.len()]; + let spec = prepare_contraction(&lhs, &lhs_dims, &rhs, &rhs_dims).unwrap(); + + 
assert_eq!(spec.axes_a.as_slice(), &[7]); + assert_eq!(spec.axes_b.as_slice(), &[5]); + assert_eq!(spec.result_indices.len(), lhs.len() + rhs.len() - 2); + + lhs.remove(7); + rhs.remove(5); + let mut expected = lhs; + expected.extend(rhs); + assert_eq!(spec.result_indices.as_slice(), expected.as_slice()); + } } diff --git a/crates/tensor4all-core/src/krylov.rs b/crates/tensor4all-core/src/krylov.rs index 1d707945..2ead2896 100644 --- a/crates/tensor4all-core/src/krylov.rs +++ b/crates/tensor4all-core/src/krylov.rs @@ -1,6 +1,6 @@ //! Krylov subspace methods for solving linear equations with abstract tensors. //! -//! This module provides iterative solvers that work with any type implementing [`TensorLike`], +//! This module provides iterative solvers that work with any type implementing [`TensorVectorSpace`], //! enabling their use in tensor network algorithms without requiring dense vector representations. //! //! # Solvers @@ -17,7 +17,7 @@ //! ``` //! use tensor4all_core::{ //! krylov::{gmres, GmresOptions}, -//! DynIndex, TensorDynLen, TensorLike, +//! DynIndex, TensorDynLen, TensorVectorSpace, //! }; //! //! # fn main() -> anyhow::Result<()> { @@ -35,8 +35,93 @@ //! 
``` use crate::any_scalar::AnyScalar; -use crate::TensorLike; +use crate::TensorVectorSpace; use anyhow::Result; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::{Duration, Instant}; + +static GMRES_OP_PROFILE_COUNTER: AtomicUsize = AtomicUsize::new(0); + +#[derive(Debug, Clone)] +struct GmresOpProfile { + started: Instant, + b_norm: Duration, + apply: Duration, + inner_product: Duration, + axpby: Duration, + norm: Duration, + scale: Duration, + triangular_solve: Duration, + solution_update: Duration, + apply_calls: usize, + inner_product_calls: usize, + axpby_calls: usize, + norm_calls: usize, + scale_calls: usize, + triangular_solve_calls: usize, + solution_update_calls: usize, +} + +impl Default for GmresOpProfile { + fn default() -> Self { + Self { + started: Instant::now(), + b_norm: Duration::ZERO, + apply: Duration::ZERO, + inner_product: Duration::ZERO, + axpby: Duration::ZERO, + norm: Duration::ZERO, + scale: Duration::ZERO, + triangular_solve: Duration::ZERO, + solution_update: Duration::ZERO, + apply_calls: 0, + inner_product_calls: 0, + axpby_calls: 0, + norm_calls: 0, + scale_calls: 0, + triangular_solve_calls: 0, + solution_update_calls: 0, + } + } +} + +impl GmresOpProfile { + fn measured(&self) -> Duration { + self.b_norm + + self.apply + + self.inner_product + + self.axpby + + self.norm + + self.scale + + self.triangular_solve + + self.solution_update + } + + fn print(&self, id: usize, iterations: usize, residual_norm: f64, converged: bool) { + let total = self.started.elapsed(); + let other = total.saturating_sub(self.measured()); + eprintln!( + "T4A gmres_op_profile #{id}: iterations={iterations} residual={residual_norm:.6e} converged={converged} total_ms={:.3} apply_ms={:.3} apply_calls={} inner_ms={:.3} inner_calls={} axpby_ms={:.3} axpby_calls={} norm_ms={:.3} norm_calls={} scale_ms={:.3} scale_calls={} update_ms={:.3} update_calls={} triangular_ms={:.3} triangular_calls={} b_norm_ms={:.3} other_ms={:.3}", + total.as_secs_f64() 
* 1000.0, + self.apply.as_secs_f64() * 1000.0, + self.apply_calls, + self.inner_product.as_secs_f64() * 1000.0, + self.inner_product_calls, + self.axpby.as_secs_f64() * 1000.0, + self.axpby_calls, + self.norm.as_secs_f64() * 1000.0, + self.norm_calls, + self.scale.as_secs_f64() * 1000.0, + self.scale_calls, + self.solution_update.as_secs_f64() * 1000.0, + self.solution_update_calls, + self.triangular_solve.as_secs_f64() * 1000.0, + self.triangular_solve_calls, + self.b_norm.as_secs_f64() * 1000.0, + other.as_secs_f64() * 1000.0, + ); + } +} /// Options for GMRES solver. /// @@ -95,6 +180,28 @@ impl Default for GmresOptions { } } +#[derive(Debug, Clone, Copy, PartialEq)] +enum GmresTolerance { + Relative(f64), + Absolute(f64), +} + +impl GmresTolerance { + fn residual_value(self, residual_norm: f64, b_norm: f64) -> f64 { + match self { + Self::Relative(_) => residual_norm / b_norm, + Self::Absolute(_) => residual_norm, + } + } + + fn is_converged(self, residual_norm: f64, b_norm: f64) -> bool { + match self { + Self::Relative(rtol) => residual_norm / b_norm < rtol, + Self::Absolute(atol) => residual_norm < atol, + } + } +} + /// Result of GMRES solver. /// /// Contains the solution, iteration count, final residual norm, and @@ -103,7 +210,7 @@ impl Default for GmresOptions { /// # Examples /// /// ``` -/// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; +/// use tensor4all_core::{DynIndex, TensorDynLen, TensorVectorSpace}; /// use tensor4all_core::krylov::{gmres, GmresOptions}; /// /// let i = DynIndex::new_dyn(2); @@ -132,7 +239,7 @@ pub struct GmresResult { /// Solve `A x = b` using GMRES (Generalized Minimal Residual Method). /// /// This implements the restarted GMRES algorithm that works with abstract tensor types -/// through the [`TensorLike`] trait's vector space operations. +/// through the [`TensorVectorSpace`] trait's vector space operations. 
/// /// # Algorithm /// @@ -142,7 +249,7 @@ pub struct GmresResult { /// /// # Type Parameters /// -/// * `T` - A tensor type implementing `TensorLike` +/// * `T` - A tensor type implementing `TensorVectorSpace` /// * `F` - A function that applies the linear operator: `F(x) = A x` /// /// # Arguments @@ -163,7 +270,547 @@ pub struct GmresResult { /// - The linear operator application fails pub fn gmres(apply_a: F, b: &T, x0: &T, options: &GmresOptions) -> Result> where - T: TensorLike, + T: TensorVectorSpace, + F: Fn(&T) -> Result, +{ + gmres_impl( + apply_a, + b, + x0, + options, + GmresTolerance::Relative(options.rtol), + None, + ) +} + +/// Solve `A x = b` using GMRES with an absolute residual tolerance. +/// +/// This variant stops when `||b - A*x|| < atol`. The default [`gmres`] API uses +/// relative residual tolerance and is preferred for scale-independent solves. +pub fn gmres_with_absolute_tolerance( + apply_a: F, + b: &T, + x0: &T, + options: &GmresOptions, + atol: f64, +) -> Result> +where + T: TensorVectorSpace, + F: Fn(&T) -> Result, +{ + gmres_impl( + apply_a, + b, + x0, + options, + GmresTolerance::Absolute(atol), + None, + ) +} + +/// Solve `(a0 I + a1 A) x = b` using GMRES with relative residual tolerance. +/// +/// The Arnoldi basis is built from the unshifted `A` callback, while affine +/// coefficients are applied in the projected Hessenberg problem, matching +/// KrylovKit's affine linear-solve convention. +pub fn gmres_affine( + apply_a: F, + b: &T, + x0: &T, + a0: AnyScalar, + a1: AnyScalar, + options: &GmresOptions, +) -> Result> +where + T: TensorVectorSpace, + F: Fn(&T) -> Result, +{ + gmres_affine_impl( + apply_a, + b, + x0, + a0, + a1, + options, + GmresTolerance::Relative(options.rtol), + ) +} + +/// Solve `(a0 I + a1 A) x = b` using GMRES with an absolute residual tolerance. +/// +/// The Arnoldi basis is built from the unshifted `A` callback, while the affine +/// coefficients are applied to the small Hessenberg problem. 
This mirrors +/// KrylovKit's `linsolve(operator, b, a0, a1)` algorithm and avoids changing the +/// Krylov basis when affine coefficients are present. +pub fn gmres_affine_with_absolute_tolerance( + apply_a: F, + b: &T, + x0: &T, + a0: AnyScalar, + a1: AnyScalar, + options: &GmresOptions, + atol: f64, +) -> Result> +where + T: TensorVectorSpace, + F: Fn(&T) -> Result, +{ + gmres_affine_impl( + apply_a, + b, + x0, + a0, + a1, + options, + GmresTolerance::Absolute(atol), + ) +} + +fn gmres_affine_impl( + apply_a: F, + b: &T, + x0: &T, + a0: AnyScalar, + a1: AnyScalar, + options: &GmresOptions, + tolerance: GmresTolerance, +) -> Result> +where + T: TensorVectorSpace, + F: Fn(&T) -> Result, +{ + b.validate()?; + x0.validate()?; + + let profile_enabled = std::env::var_os("T4A_GMRES_OP_PROFILE").is_some(); + let profile_id = if profile_enabled { + GMRES_OP_PROFILE_COUNTER.fetch_add(1, Ordering::Relaxed) + } else { + 0 + }; + let mut profile = GmresOpProfile::default(); + let mut total_iters = 0usize; + + macro_rules! 
finish { + ($result:expr) => {{ + let result = $result; + if profile_enabled { + profile.print( + profile_id, + result.iterations, + result.residual_norm, + result.converged, + ); + } + return Ok(result); + }}; + } + + let started = Instant::now(); + let b_norm = b.norm(); + if profile_enabled { + profile.b_norm += started.elapsed(); + } + if b_norm < 1e-15 { + finish!(GmresResult { + solution: x0.clone(), + iterations: 0, + residual_norm: 0.0, + converged: true, + }); + } + if a0.is_zero() && a1.is_zero() { + anyhow::bail!("gmres_affine: at least one affine coefficient must be nonzero"); + } + if a1.is_zero() { + let started = Instant::now(); + let solution = b.scale(AnyScalar::new_real(1.0) / a0)?; + if profile_enabled { + profile.scale += started.elapsed(); + profile.scale_calls += 1; + } + finish!(GmresResult { + solution, + iterations: 0, + residual_norm: 0.0, + converged: true, + }); + } + + let mut x = x0.clone(); + + for restart in 0..options.max_restarts { + let started = Instant::now(); + let ax = apply_a(&x)?; + if profile_enabled { + profile.apply += started.elapsed(); + profile.apply_calls += 1; + } + if restart == 0 { + ax.validate()?; + } + let started = Instant::now(); + let affine_x = x.axpby(a0.clone(), &ax, a1.clone())?; + if profile_enabled { + profile.axpby += started.elapsed(); + profile.axpby_calls += 1; + } + let started = Instant::now(); + let r = b.axpby( + AnyScalar::new_real(1.0), + &affine_x, + AnyScalar::new_real(-1.0), + )?; + if profile_enabled { + profile.axpby += started.elapsed(); + profile.axpby_calls += 1; + } + let started = Instant::now(); + let r_norm = r.norm(); + if profile_enabled { + profile.norm += started.elapsed(); + profile.norm_calls += 1; + } + let residual_value = tolerance.residual_value(r_norm, b_norm); + if options.verbose { + eprintln!( + "GMRES restart {}: initial residual = {:.6e}", + restart, residual_value + ); + } + if tolerance.is_converged(r_norm, b_norm) { + finish!(GmresResult { + solution: x, + 
iterations: total_iters, + residual_norm: residual_value, + converged: true, + }); + } + + let cycle_max_iter = options.max_iter; + let mut v_basis: Vec = Vec::with_capacity(cycle_max_iter + 1); + let started = Instant::now(); + v_basis.push(r.scale(AnyScalar::new_real(1.0 / r_norm))?); + if profile_enabled { + profile.scale += started.elapsed(); + profile.scale_calls += 1; + } + + let mut h_matrix: Vec> = Vec::with_capacity(cycle_max_iter); + let mut cs: Vec = Vec::with_capacity(cycle_max_iter); + let mut sn: Vec = Vec::with_capacity(cycle_max_iter); + let mut g: Vec = vec![AnyScalar::new_real(r_norm)]; + let mut solution_already_updated = false; + + for j in 0..cycle_max_iter { + total_iters += 1; + + let started = Instant::now(); + let w = apply_a(&v_basis[j])?; + if profile_enabled { + profile.apply += started.elapsed(); + profile.apply_calls += 1; + } + let mut h_a_col: Vec = Vec::with_capacity(j + 2); + let mut w_orth = w; + + for v_i in v_basis.iter().take(j + 1) { + let started = Instant::now(); + let h_ij = v_i.inner_product(&w_orth)?; + if profile_enabled { + profile.inner_product += started.elapsed(); + profile.inner_product_calls += 1; + } + h_a_col.push(h_ij.clone()); + let neg_h_ij = AnyScalar::new_real(0.0) - h_ij; + let started = Instant::now(); + w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_h_ij)?; + if profile_enabled { + profile.axpby += started.elapsed(); + profile.axpby_calls += 1; + } + } + for (i, v_i) in v_basis.iter().take(j + 1).enumerate() { + let started = Instant::now(); + let correction = v_i.inner_product(&w_orth)?; + if profile_enabled { + profile.inner_product += started.elapsed(); + profile.inner_product_calls += 1; + } + h_a_col[i] = h_a_col[i].clone() + correction.clone(); + let neg_correction = AnyScalar::new_real(0.0) - correction; + let started = Instant::now(); + w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_correction)?; + if profile_enabled { + profile.axpby += started.elapsed(); + profile.axpby_calls 
+= 1; + } + } + + let started = Instant::now(); + let h_jp1_j_real = w_orth.norm(); + if profile_enabled { + profile.norm += started.elapsed(); + profile.norm_calls += 1; + } + h_a_col.push(AnyScalar::new_real(h_jp1_j_real)); + + let mut h_col: Vec = Vec::with_capacity(j + 2); + for h in h_a_col.iter().take(j) { + h_col.push(a1.clone() * h.clone()); + } + h_col.push(a0.clone() + a1.clone() * h_a_col[j].clone()); + h_col.push(a1.clone() * h_a_col[j + 1].clone()); + + #[allow(clippy::needless_range_loop)] + for i in 0..j { + let h_i = h_col[i].clone(); + let h_ip1 = h_col[i + 1].clone(); + let (new_hi, new_hip1) = apply_givens_rotation(&cs[i], &sn[i], &h_i, &h_ip1); + h_col[i] = new_hi; + h_col[i + 1] = new_hip1; + } + + let (c_j, s_j) = compute_givens_rotation(&h_col[j], &h_col[j + 1]); + cs.push(c_j.clone()); + sn.push(s_j.clone()); + + let (new_hj, _) = apply_givens_rotation(&c_j, &s_j, &h_col[j], &h_col[j + 1]); + h_col[j] = new_hj; + h_col[j + 1] = AnyScalar::new_real(0.0); + + let g_j = g[j].clone(); + let g_jp1 = AnyScalar::new_real(0.0); + let (new_gj, new_gjp1) = apply_givens_rotation(&c_j, &s_j, &g_j, &g_jp1); + g[j] = new_gj; + let res_norm = new_gjp1.abs(); + g.push(new_gjp1); + + h_matrix.push(h_col); + let residual_value = tolerance.residual_value(res_norm, b_norm); + if options.verbose { + eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, residual_value); + } + + if tolerance.is_converged(res_norm, b_norm) { + let started = Instant::now(); + let y = solve_upper_triangular(&h_matrix, &g[..=j])?; + if profile_enabled { + profile.triangular_solve += started.elapsed(); + profile.triangular_solve_calls += 1; + } + let started = Instant::now(); + x = update_solution(&x, &v_basis[..=j], &y)?; + if profile_enabled { + profile.solution_update += started.elapsed(); + profile.solution_update_calls += 1; + } + if options.check_true_residual { + let started = Instant::now(); + let ax_check = apply_a(&x)?; + if profile_enabled { + profile.apply += 
started.elapsed(); + profile.apply_calls += 1; + } + let started = Instant::now(); + let affine_check = x.axpby(a0.clone(), &ax_check, a1.clone())?; + if profile_enabled { + profile.axpby += started.elapsed(); + profile.axpby_calls += 1; + } + let started = Instant::now(); + let r_check = b.axpby( + AnyScalar::new_real(1.0), + &affine_check, + AnyScalar::new_real(-1.0), + )?; + if profile_enabled { + profile.axpby += started.elapsed(); + profile.axpby_calls += 1; + } + let started = Instant::now(); + let true_abs_res = r_check.norm(); + if profile_enabled { + profile.norm += started.elapsed(); + profile.norm_calls += 1; + } + let true_residual_value = tolerance.residual_value(true_abs_res, b_norm); + if options.verbose { + eprintln!( + "GMRES true residual check: hessenberg={:.6e}, checked={:.6e}", + residual_value, true_residual_value + ); + } + if tolerance.is_converged(true_abs_res, b_norm) { + finish!(GmresResult { + solution: x, + iterations: total_iters, + residual_norm: true_residual_value, + converged: true, + }); + } + solution_already_updated = true; + break; + } else { + finish!(GmresResult { + solution: x, + iterations: total_iters, + residual_norm: residual_value, + converged: true, + }); + } + } + + if h_jp1_j_real > 1e-14 { + let started = Instant::now(); + v_basis.push(w_orth.scale(AnyScalar::new_real(1.0 / h_jp1_j_real))?); + if profile_enabled { + profile.scale += started.elapsed(); + profile.scale_calls += 1; + } + } else { + let started = Instant::now(); + let y = solve_upper_triangular(&h_matrix, &g[..=j])?; + if profile_enabled { + profile.triangular_solve += started.elapsed(); + profile.triangular_solve_calls += 1; + } + let started = Instant::now(); + x = update_solution(&x, &v_basis[..=j], &y)?; + if profile_enabled { + profile.solution_update += started.elapsed(); + profile.solution_update_calls += 1; + } + let started = Instant::now(); + let ax_final = apply_a(&x)?; + if profile_enabled { + profile.apply += started.elapsed(); + 
profile.apply_calls += 1; + } + let started = Instant::now(); + let affine_final = x.axpby(a0.clone(), &ax_final, a1.clone())?; + if profile_enabled { + profile.axpby += started.elapsed(); + profile.axpby_calls += 1; + } + let started = Instant::now(); + let r_final = b.axpby( + AnyScalar::new_real(1.0), + &affine_final, + AnyScalar::new_real(-1.0), + )?; + if profile_enabled { + profile.axpby += started.elapsed(); + profile.axpby_calls += 1; + } + let started = Instant::now(); + let final_abs_res = r_final.norm(); + if profile_enabled { + profile.norm += started.elapsed(); + profile.norm_calls += 1; + } + let final_res = tolerance.residual_value(final_abs_res, b_norm); + finish!(GmresResult { + solution: x, + iterations: total_iters, + residual_norm: final_res, + converged: tolerance.is_converged(final_abs_res, b_norm), + }); + } + } + + if !solution_already_updated { + let actual_iters = h_matrix.len(); + let started = Instant::now(); + let y = solve_upper_triangular(&h_matrix, &g[..actual_iters])?; + if profile_enabled { + profile.triangular_solve += started.elapsed(); + profile.triangular_solve_calls += 1; + } + let started = Instant::now(); + x = update_solution(&x, &v_basis[..actual_iters], &y)?; + if profile_enabled { + profile.solution_update += started.elapsed(); + profile.solution_update_calls += 1; + } + } + } + + let started = Instant::now(); + let ax_final = apply_a(&x)?; + if profile_enabled { + profile.apply += started.elapsed(); + profile.apply_calls += 1; + } + let started = Instant::now(); + let affine_final = x.axpby(a0, &ax_final, a1)?; + if profile_enabled { + profile.axpby += started.elapsed(); + profile.axpby_calls += 1; + } + let started = Instant::now(); + let r_final = b.axpby( + AnyScalar::new_real(1.0), + &affine_final, + AnyScalar::new_real(-1.0), + )?; + if profile_enabled { + profile.axpby += started.elapsed(); + profile.axpby_calls += 1; + } + let started = Instant::now(); + let final_abs_res = r_final.norm(); + if profile_enabled { 
+ profile.norm += started.elapsed(); + profile.norm_calls += 1; + } + let final_res = tolerance.residual_value(final_abs_res, b_norm); + + finish!(GmresResult { + solution: x, + iterations: total_iters, + residual_norm: final_res, + converged: tolerance.is_converged(final_abs_res, b_norm), + }) +} + +/// Solve `A x = b` using GMRES while enforcing a total iteration limit. +/// +/// [`GmresOptions::max_iter`] remains the restart cycle length and +/// [`GmresOptions::max_restarts`] remains the maximum number of restart cycles. +/// `max_total_iter` caps the total number of Arnoldi steps across all restart +/// cycles; the final cycle is shortened when necessary. +pub fn gmres_with_total_iteration_limit( + apply_a: F, + b: &T, + x0: &T, + options: &GmresOptions, + max_total_iter: usize, +) -> Result> +where + T: TensorVectorSpace, + F: Fn(&T) -> Result, +{ + gmres_impl( + apply_a, + b, + x0, + options, + GmresTolerance::Relative(options.rtol), + Some(max_total_iter), + ) +} + +fn gmres_impl( + apply_a: F, + b: &T, + x0: &T, + options: &GmresOptions, + tolerance: GmresTolerance, + max_total_iter: Option, +) -> Result> +where + T: TensorVectorSpace, F: Fn(&T) -> Result, { // Validate structural consistency of inputs @@ -185,6 +832,20 @@ where let mut total_iters = 0; for _restart in 0..options.max_restarts { + let cycle_max_iter = match max_total_iter { + Some(limit) => { + let remaining = limit.saturating_sub(total_iters); + if remaining == 0 { + break; + } + options.max_iter.min(remaining) + } + None => options.max_iter, + }; + if cycle_max_iter == 0 { + break; + } + // Compute initial residual: r = b - A*x let ax = apply_a(&x)?; // Validate operator output on first restart @@ -194,38 +855,39 @@ where // r = 1.0 * b + (-1.0) * ax let r = b.axpby(AnyScalar::new_real(1.0), &ax, AnyScalar::new_real(-1.0))?; let r_norm = r.norm(); - let rel_res = r_norm / b_norm; + let residual_value = tolerance.residual_value(r_norm, b_norm); if options.verbose { eprintln!( "GMRES 
restart {}: initial residual = {:.6e}", - _restart, rel_res + _restart, residual_value ); } - if rel_res < options.rtol { + if tolerance.is_converged(r_norm, b_norm) { return Ok(GmresResult { solution: x, iterations: total_iters, - residual_norm: rel_res, + residual_norm: residual_value, converged: true, }); } // Arnoldi process with modified Gram-Schmidt - let mut v_basis: Vec = Vec::with_capacity(options.max_iter + 1); - let mut h_matrix: Vec> = Vec::with_capacity(options.max_iter); + let mut v_basis: Vec = Vec::with_capacity(cycle_max_iter + 1); + let mut h_matrix: Vec> = Vec::with_capacity(cycle_max_iter); // v_0 = r / ||r|| let v0 = r.scale(AnyScalar::new_real(1.0 / r_norm))?; v_basis.push(v0); // Initialize Givens rotation storage - let mut cs: Vec = Vec::with_capacity(options.max_iter); - let mut sn: Vec = Vec::with_capacity(options.max_iter); + let mut cs: Vec = Vec::with_capacity(cycle_max_iter); + let mut sn: Vec = Vec::with_capacity(cycle_max_iter); let mut g: Vec = vec![AnyScalar::new_real(r_norm)]; // residual in upper Hessenberg space + let mut solution_already_updated = false; - for j in 0..options.max_iter { + for j in 0..cycle_max_iter { total_iters += 1; // w = A * v_j @@ -243,6 +905,16 @@ where w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_h_ij)?; } + // KrylovKit's default orthogonalizer is ModifiedGramSchmidt2. + // The second pass is important for long Krylov bases and complex + // non-Hermitian local problems. 
+ for (i, v_i) in v_basis.iter().take(j + 1).enumerate() { + let correction = v_i.inner_product(&w_orth)?; + h_col[i] = h_col[i].clone() + correction.clone(); + let neg_correction = AnyScalar::new_real(0.0) - correction; + w_orth = w_orth.axpby(AnyScalar::new_real(1.0), v_i, neg_correction)?; + } + let h_jp1_j_real = w_orth.norm(); let h_jp1_j = AnyScalar::new_real(h_jp1_j_real); h_col.push(h_jp1_j); @@ -278,22 +950,51 @@ where h_matrix.push(h_col); // Check convergence - let rel_res = res_norm / b_norm; + let residual_value = tolerance.residual_value(res_norm, b_norm); if options.verbose { - eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, rel_res); + eprintln!("GMRES iter {}: residual = {:.6e}", j + 1, residual_value); } - if rel_res < options.rtol { + if tolerance.is_converged(res_norm, b_norm) { // Solve upper triangular system and update x let y = solve_upper_triangular(&h_matrix, &g[..=j])?; x = update_solution(&x, &v_basis[..=j], &y)?; - return Ok(GmresResult { - solution: x, - iterations: total_iters, - residual_norm: rel_res, - converged: true, - }); + if options.check_true_residual { + let ax_check = apply_a(&x)?; + let r_check = b.axpby( + AnyScalar::new_real(1.0), + &ax_check, + AnyScalar::new_real(-1.0), + )?; + let true_abs_res = r_check.norm(); + let true_residual_value = tolerance.residual_value(true_abs_res, b_norm); + + if options.verbose { + eprintln!( + "GMRES true residual check: hessenberg={:.6e}, checked={:.6e}", + residual_value, true_residual_value + ); + } + + if tolerance.is_converged(true_abs_res, b_norm) { + return Ok(GmresResult { + solution: x, + iterations: total_iters, + residual_norm: true_residual_value, + converged: true, + }); + } + solution_already_updated = true; + break; + } else { + return Ok(GmresResult { + solution: x, + iterations: total_iters, + residual_norm: residual_value, + converged: true, + }); + } } // Add new basis vector (if not converged and h_jp1_j is not too small) @@ -310,19 +1011,23 @@ where &ax_final, 
AnyScalar::new_real(-1.0), )?; - let final_res = r_final.norm() / b_norm; + let final_abs_res = r_final.norm(); + let final_res = tolerance.residual_value(final_abs_res, b_norm); return Ok(GmresResult { solution: x, iterations: total_iters, residual_norm: final_res, - converged: final_res < options.rtol, + converged: tolerance.is_converged(final_abs_res, b_norm), }); } } // End of restart cycle - update x with current solution - let y = solve_upper_triangular(&h_matrix, &g[..options.max_iter])?; - x = update_solution(&x, &v_basis[..options.max_iter], &y)?; + if !solution_already_updated { + let actual_iters = h_matrix.len(); + let y = solve_upper_triangular(&h_matrix, &g[..actual_iters])?; + x = update_solution(&x, &v_basis[..actual_iters], &y)?; + } } // Compute final residual @@ -332,13 +1037,14 @@ where &ax_final, AnyScalar::new_real(-1.0), )?; - let final_res = r_final.norm() / b_norm; + let final_abs_res = r_final.norm(); + let final_res = tolerance.residual_value(final_abs_res, b_norm); Ok(GmresResult { solution: x, iterations: total_iters, residual_norm: final_res, - converged: final_res < options.rtol, + converged: tolerance.is_converged(final_abs_res, b_norm), }) } @@ -349,7 +1055,7 @@ where /// /// # Type Parameters /// -/// * `T` - A tensor type implementing `TensorLike` +/// * `T` - A tensor type implementing `TensorVectorSpace` /// * `F` - A function that applies the linear operator: `F(x) = A x` /// * `Tr` - A function that truncates a tensor in-place: `Tr(&mut x)` /// @@ -372,7 +1078,7 @@ where /// Solve `2x = b` with a no-op truncation function: /// /// ``` -/// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike, AnyScalar}; +/// use tensor4all_core::{DynIndex, TensorDynLen, TensorVectorSpace, AnyScalar}; /// use tensor4all_core::krylov::{gmres_with_truncation, GmresOptions}; /// /// let i = DynIndex::new_dyn(2); @@ -398,7 +1104,7 @@ pub fn gmres_with_truncation( truncate: Tr, ) -> Result> where - T: TensorLike, + T: TensorVectorSpace, F: 
Fn(&T) -> Result, Tr: Fn(&mut T) -> Result<()>, { @@ -848,7 +1554,7 @@ pub struct RestartGmresResult { /// /// # Type Parameters /// -/// * `T` - A tensor type implementing `TensorLike` +/// * `T` - A tensor type implementing `TensorVectorSpace` /// * `F` - A function that applies the linear operator: `F(x) = A x` /// * `Tr` - A function that truncates a tensor in-place: `Tr(&mut x)` /// @@ -869,7 +1575,7 @@ pub struct RestartGmresResult { /// Solve `5x = b` with no truncation: /// /// ``` -/// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike, AnyScalar}; +/// use tensor4all_core::{DynIndex, TensorDynLen, TensorVectorSpace, AnyScalar}; /// use tensor4all_core::krylov::{restart_gmres_with_truncation, RestartGmresOptions}; /// /// let i = DynIndex::new_dyn(3); @@ -894,7 +1600,7 @@ pub fn restart_gmres_with_truncation( truncate: Tr, ) -> Result> where - T: TensorLike, + T: TensorVectorSpace, F: Fn(&T) -> Result, Tr: Fn(&mut T) -> Result<()>, { @@ -1044,13 +1750,22 @@ where /// This function keeps computation in `AnyScalar` space to preserve AD metadata /// as much as possible. 
fn compute_givens_rotation(a: &AnyScalar, b: &AnyScalar) -> (AnyScalar, AnyScalar) { - // r^2 = conj(a)*a + conj(b)*b (works for both real and complex) - let norm2 = a.clone().conj() * a.clone() + b.clone().conj() * b.clone(); - let r = norm2.sqrt(); - if r.abs() < 1e-15 { + let a_abs = a.abs(); + let b_abs = b.abs(); + let r = (a_abs * a_abs + b_abs * b_abs).sqrt(); + if r < 1e-15 { (AnyScalar::new_real(1.0), AnyScalar::new_real(0.0)) + } else if a_abs < 1e-15 { + ( + AnyScalar::new_real(0.0), + b.clone().conj() / AnyScalar::new_real(r), + ) } else { - (a.clone() / r.clone(), b.clone() / r) + let phase = a.clone() / AnyScalar::new_real(a_abs); + ( + AnyScalar::new_real(a_abs / r), + phase * b.clone().conj() / AnyScalar::new_real(r), + ) } } @@ -1102,7 +1817,7 @@ fn solve_upper_triangular(h: &[Vec], g: &[AnyScalar]) -> Result(x: &T, v_basis: &[T], y: &[AnyScalar]) -> Result { +fn update_solution(x: &T, v_basis: &[T], y: &[AnyScalar]) -> Result { let mut result = x.clone(); for (vi, yi) in v_basis.iter().zip(y.iter()) { @@ -1126,7 +1841,7 @@ fn update_solution_truncated( truncate: &Tr, ) -> Result where - T: TensorLike, + T: TensorVectorSpace, Tr: Fn(&mut T) -> Result<()>, { let mut result = x.clone(); diff --git a/crates/tensor4all-core/src/krylov/tests/mod.rs b/crates/tensor4all-core/src/krylov/tests/mod.rs index 15922379..bc64b3b3 100644 --- a/crates/tensor4all-core/src/krylov/tests/mod.rs +++ b/crates/tensor4all-core/src/krylov/tests/mod.rs @@ -1,6 +1,549 @@ use super::*; use crate::defaults::tensordynlen::TensorDynLen; use crate::defaults::DynIndex; +use crate::tensor_index::TensorIndex; +use crate::TensorVectorSpace; +use num_complex::Complex64; +use std::sync::Mutex; + +static GMRES_PROFILE_ENV_LOCK: Mutex<()> = Mutex::new(()); + +struct GmresProfileEnvGuard; + +impl GmresProfileEnvGuard { + fn set() -> Self { + unsafe { + std::env::set_var("T4A_GMRES_OP_PROFILE", "1"); + } + Self + } +} + +impl Drop for GmresProfileEnvGuard { + fn drop(&mut self) { + unsafe 
{ + std::env::remove_var("T4A_GMRES_OP_PROFILE"); + } + } +} + +fn with_gmres_profile_env(f: impl FnOnce() -> R) -> R { + let _lock = GMRES_PROFILE_ENV_LOCK.lock().unwrap(); + let _guard = GmresProfileEnvGuard::set(); + f() +} + +#[derive(Debug, Clone)] +struct PlainVector { + data: Vec, +} + +impl TensorIndex for PlainVector { + type Index = DynIndex; + + fn external_indices(&self) -> Vec { + Vec::new() + } + + fn replaceind(&self, _old_index: &Self::Index, _new_index: &Self::Index) -> Result { + Ok(self.clone()) + } + + fn replaceinds( + &self, + _old_indices: &[Self::Index], + _new_indices: &[Self::Index], + ) -> Result { + Ok(self.clone()) + } +} + +impl TensorVectorSpace for PlainVector { + fn norm_squared(&self) -> f64 { + self.data.iter().map(|x| x * x).sum() + } + + fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result { + anyhow::ensure!( + self.data.len() == other.data.len(), + "vector lengths must match" + ); + anyhow::ensure!( + a.is_real() && b.is_real(), + "PlainVector test helper only supports real coefficients" + ); + let data = self + .data + .iter() + .zip(other.data.iter()) + .map(|(&x, &y)| a.real() * x + b.real() * y) + .collect(); + Ok(Self { data }) + } + + fn scale(&self, scalar: AnyScalar) -> Result { + anyhow::ensure!( + scalar.is_real(), + "PlainVector test helper only supports real coefficients" + ); + Ok(Self { + data: self.data.iter().map(|&x| scalar.real() * x).collect(), + }) + } + + fn inner_product(&self, other: &Self) -> Result { + anyhow::ensure!( + self.data.len() == other.data.len(), + "vector lengths must match" + ); + Ok(AnyScalar::new_real( + self.data + .iter() + .zip(other.data.iter()) + .map(|(&x, &y)| x * y) + .sum(), + )) + } + + fn maxabs(&self) -> f64 { + self.data.iter().map(|x| x.abs()).fold(0.0, f64::max) + } +} + +#[test] +fn gmres_accepts_vector_space_without_tensorlike() { + let b = PlainVector { + data: vec![1.0, -2.0], + }; + let x0 = PlainVector { + data: vec![0.0, 0.0], + }; + let result = 
gmres( + |x: &PlainVector| Ok(x.clone()), + &b, + &x0, + &GmresOptions::default(), + ) + .expect("GMRES should accept TensorVectorSpace-only values"); + + assert!(result.converged); + assert!(result.solution.sub(&b).unwrap().maxabs() < 1e-12); +} + +#[test] +fn gmres_absolute_tolerance_and_total_iteration_limit_paths() { + let idx = DynIndex::new_dyn(2); + let b = make_vector_with_index(vec![4.0, 9.0], &idx); + let x0 = make_vector_with_index(vec![0.0, 0.0], &idx); + let apply_a = |x: &TensorDynLen| -> Result { scale_vector_f64(x, &[2.0, 3.0]) }; + let options = GmresOptions { + max_iter: 4, + rtol: 1e-12, + max_restarts: 2, + verbose: false, + check_true_residual: true, + }; + + let result = gmres_with_absolute_tolerance(apply_a, &b, &x0, &options, 1e-10).unwrap(); + assert!(result.converged); + let expected = make_vector_with_index(vec![2.0, 3.0], &idx); + assert!(result.solution.sub(&expected).unwrap().maxabs() < 1e-10); + + let limited = gmres_with_total_iteration_limit( + |x: &TensorDynLen| scale_vector_f64(x, &[2.0, 3.0]), + &b, + &x0, + &options, + 0, + ) + .unwrap(); + assert!(!limited.converged); + assert_eq!(limited.iterations, 0); +} + +#[test] +fn gmres_affine_matches_shifted_system_and_scalar_shortcuts() { + let idx = DynIndex::new_dyn(2); + let b = make_vector_with_index(vec![10.0, 28.0], &idx); + let x0 = make_vector_with_index(vec![0.0, 0.0], &idx); + let options = GmresOptions { + max_iter: 4, + rtol: 1e-12, + max_restarts: 2, + verbose: true, + check_true_residual: true, + }; + + let result = gmres_affine_with_absolute_tolerance( + |x: &TensorDynLen| scale_vector_f64(x, &[2.0, 5.0]), + &b, + &x0, + AnyScalar::new_real(1.0), + AnyScalar::new_real(2.0), + &options, + 1e-10, + ) + .unwrap(); + assert!(result.converged); + let expected = make_vector_with_index(vec![2.0, 28.0 / 11.0], &idx); + assert!(result.solution.sub(&expected).unwrap().maxabs() < 1e-10); + + let scaled = gmres_affine( + |x: &TensorDynLen| scale_vector_f64(x, &[100.0, 100.0]), + 
&b, + &x0, + AnyScalar::new_real(2.0), + AnyScalar::new_real(0.0), + &options, + ) + .unwrap(); + assert!(scaled.converged); + assert_eq!(scaled.iterations, 0); + assert!( + scaled + .solution + .sub(&make_vector_with_index(vec![5.0, 14.0], &idx)) + .unwrap() + .maxabs() + < 1e-12 + ); + + let err = gmres_affine( + |x: &TensorDynLen| Ok(x.clone()), + &b, + &x0, + AnyScalar::new_real(0.0), + AnyScalar::new_real(0.0), + &options, + ) + .unwrap_err(); + assert!(err.to_string().contains("at least one affine coefficient")); +} + +#[test] +fn gmres_affine_profile_and_zero_rhs_paths() { + let idx = DynIndex::new_dyn(1); + let b = make_vector_with_index(vec![0.0], &idx); + let x0 = make_vector_with_index(vec![7.0], &idx); + let options = GmresOptions { + max_iter: 1, + rtol: 1e-12, + max_restarts: 1, + verbose: false, + check_true_residual: false, + }; + + let result = with_gmres_profile_env(|| { + gmres_affine( + |x: &TensorDynLen| Ok(x.clone()), + &b, + &x0, + AnyScalar::new_real(1.0), + AnyScalar::new_real(1.0), + &options, + ) + .unwrap() + }); + + assert!(result.converged); + assert_eq!(result.iterations, 0); + assert!(result.solution.sub(&x0).unwrap().maxabs() < 1e-12); +} + +#[test] +fn restart_gmres_options_builder_sets_all_fields() { + let options = RestartGmresOptions::new() + .with_max_outer_iters(3) + .with_rtol(1e-7) + .with_inner_max_iter(5) + .with_inner_max_restarts(2) + .with_min_reduction(0.8) + .with_inner_rtol(0.25) + .with_verbose(true); + + assert_eq!(options.max_outer_iters, 3); + assert_eq!(options.rtol, 1e-7); + assert_eq!(options.inner_max_iter, 5); + assert_eq!(options.inner_max_restarts, 2); + assert_eq!(options.min_reduction, Some(0.8)); + assert_eq!(options.inner_rtol, Some(0.25)); + assert!(options.verbose); +} + +#[test] +fn gmres_affine_profile_covers_nonconverged_restart_and_final_paths() { + let idx = DynIndex::new_dyn(2); + let b = make_vector_with_index(vec![1.0, 1.0], &idx); + let x0 = make_vector_with_index(vec![0.0, 0.0], &idx); + 
let options = GmresOptions { + max_iter: 1, + rtol: 1e-30, + max_restarts: 1, + verbose: true, + check_true_residual: false, + }; + + let result = with_gmres_profile_env(|| { + gmres_affine( + |x: &TensorDynLen| scale_vector_f64(x, &[2.0, 5.0]), + &b, + &x0, + AnyScalar::new_real(0.5), + AnyScalar::new_real(1.25), + &options, + ) + .unwrap() + }); + + assert!(!result.converged); + assert_eq!(result.iterations, 1); + assert!(result.residual_norm.is_finite()); +} + +#[test] +fn gmres_affine_profile_covers_scalar_shortcut() { + let idx = DynIndex::new_dyn(2); + let b = make_vector_with_index(vec![2.0, 4.0], &idx); + let x0 = make_vector_with_index(vec![0.0, 0.0], &idx); + let options = GmresOptions::default(); + + let result = with_gmres_profile_env(|| { + gmres_affine( + |x: &TensorDynLen| Ok(x.clone()), + &b, + &x0, + AnyScalar::new_real(2.0), + AnyScalar::new_real(0.0), + &options, + ) + .unwrap() + }); + + assert!(result.converged); + assert_eq!(result.iterations, 0); + assert_eq!(result.solution.to_vec::().unwrap(), vec![1.0, 2.0]); +} + +#[test] +fn gmres_lucky_breakdown_paths_are_reachable_with_zero_tolerance() { + let idx = DynIndex::new_dyn(1); + let b = make_vector_with_index(vec![3.0], &idx); + let x0 = make_vector_with_index(vec![0.0], &idx); + let options = GmresOptions { + max_iter: 1, + rtol: 0.0, + max_restarts: 1, + verbose: false, + check_true_residual: false, + }; + + let result = gmres(|x: &TensorDynLen| Ok(x.clone()), &b, &x0, &options).unwrap(); + assert!(!result.converged); + assert_eq!(result.iterations, 1); + assert!(result.solution.sub(&b).unwrap().maxabs() < 1e-12); + + let affine = gmres_affine( + |x: &TensorDynLen| Ok(x.clone()), + &b, + &x0, + AnyScalar::new_real(0.0), + AnyScalar::new_real(1.0), + &options, + ) + .unwrap(); + assert!(!affine.converged); + assert_eq!(affine.iterations, 1); + assert!(affine.solution.sub(&b).unwrap().maxabs() < 1e-12); + + let truncated = gmres_with_truncation( + |x: &TensorDynLen| Ok(x.clone()), + &b, + 
&x0, + &options, + |_: &mut TensorDynLen| Ok(()), + ) + .unwrap(); + assert!(!truncated.converged); + assert_eq!(truncated.iterations, 1); + assert!(truncated.solution.sub(&b).unwrap().maxabs() < 1e-12); +} + +#[test] +fn gmres_convergence_branches_cover_true_residual_and_affine_fast_finish() { + let idx = DynIndex::new_dyn(1); + let b = make_vector_with_index(vec![3.0], &idx); + let x0 = make_vector_with_index(vec![0.0], &idx); + let options = GmresOptions { + max_iter: 1, + rtol: 1e-12, + max_restarts: 1, + verbose: true, + check_true_residual: true, + }; + + let checked = gmres(|x: &TensorDynLen| Ok(x.clone()), &b, &x0, &options).unwrap(); + assert!(checked.converged); + assert!(checked.solution.sub(&b).unwrap().maxabs() < 1e-12); + + let truncated = gmres_with_truncation( + |x: &TensorDynLen| Ok(x.clone()), + &b, + &x0, + &options, + |_: &mut TensorDynLen| Ok(()), + ) + .unwrap(); + assert!(truncated.converged); + assert!(truncated.solution.sub(&b).unwrap().maxabs() < 1e-12); + + let no_true_check = GmresOptions { + check_true_residual: false, + ..options + }; + let affine = gmres_affine( + |x: &TensorDynLen| Ok(x.clone()), + &b, + &x0, + AnyScalar::new_real(0.0), + AnyScalar::new_real(1.0), + &no_true_check, + ) + .unwrap(); + assert!(affine.converged); + assert!(affine.solution.sub(&b).unwrap().maxabs() < 1e-12); +} + +#[test] +fn gmres_affine_profile_covers_true_residual_rejection_and_lucky_breakdown() { + use std::cell::Cell; + + let idx = DynIndex::new_dyn(1); + let b = make_vector_with_index(vec![1.0], &idx); + let x0 = make_vector_with_index(vec![0.0], &idx); + let checked_options = GmresOptions { + max_iter: 1, + rtol: 1e-12, + max_restarts: 1, + verbose: true, + check_true_residual: true, + }; + + let calls = Cell::new(0usize); + let checked = with_gmres_profile_env(|| { + gmres_affine( + |x: &TensorDynLen| { + let call = calls.get(); + calls.set(call + 1); + if call == 2 { + x.scale(AnyScalar::new_real(2.0)) + } else { + Ok(x.clone()) + } + }, + &b, + 
&x0, + AnyScalar::new_real(0.0), + AnyScalar::new_real(1.0), + &checked_options, + ) + .unwrap() + }); + assert!(checked.converged); + assert!(checked.solution.sub(&b).unwrap().maxabs() < 1e-12); + + let lucky_options = GmresOptions { + check_true_residual: false, + rtol: 0.0, + ..checked_options + }; + let lucky = with_gmres_profile_env(|| { + gmres_affine( + |x: &TensorDynLen| Ok(x.clone()), + &b, + &x0, + AnyScalar::new_real(0.0), + AnyScalar::new_real(1.0), + &lucky_options, + ) + .unwrap() + }); + assert!(!lucky.converged); + assert_eq!(lucky.iterations, 1); + assert!(lucky.solution.sub(&b).unwrap().maxabs() < 1e-12); +} + +#[test] +fn gmres_zero_cycle_and_restart_nonzero_update_paths() { + let idx = DynIndex::new_dyn(1); + let b = make_vector_with_index(vec![2.0], &idx); + let x0 = make_vector_with_index(vec![0.0], &idx); + let zero_cycle_options = GmresOptions { + max_iter: 0, + rtol: 1e-12, + max_restarts: 1, + verbose: false, + check_true_residual: false, + }; + + let zero_cycle = gmres_with_total_iteration_limit( + |x: &TensorDynLen| Ok(x.clone()), + &b, + &x0, + &zero_cycle_options, + 1, + ) + .unwrap(); + assert!(!zero_cycle.converged); + assert_eq!(zero_cycle.iterations, 0); + + let restart_options = RestartGmresOptions { + max_outer_iters: 2, + rtol: 0.0, + inner_max_iter: 1, + inner_max_restarts: 0, + min_reduction: None, + inner_rtol: Some(0.0), + verbose: true, + }; + let restarted = restart_gmres_with_truncation( + |x: &TensorDynLen| Ok(x.clone()), + &b, + None, + &restart_options, + |_: &mut TensorDynLen| Ok(()), + ) + .unwrap(); + assert!(!restarted.converged); + assert_eq!(restarted.outer_iterations, 2); + assert!(restarted.solution.sub(&b).unwrap().maxabs() < 1e-12); +} + +#[test] +fn gmres_private_helpers_cover_edge_paths() { + let zero = AnyScalar::new_real(0.0); + let (c, s) = compute_givens_rotation(&zero, &zero); + assert!(c.is_real()); + assert!(s.is_real()); + assert_eq!(c.real(), 1.0); + assert_eq!(s.real(), 0.0); + + let empty = 
solve_upper_triangular(&[], &[]).unwrap(); + assert!(empty.is_empty()); + + let singular = solve_upper_triangular( + &[vec![AnyScalar::new_real(0.0)]], + &[AnyScalar::new_real(1.0)], + ) + .unwrap_err(); + assert!(singular.to_string().contains("Near-singular")); + + let idx = DynIndex::new_dyn(2); + let x = make_vector_with_index(vec![1.0, 1.0], &idx); + let v = make_vector_with_index(vec![2.0, -1.0], &idx); + let updated = update_solution(&x, &[v], &[AnyScalar::new_real(3.0)]).unwrap(); + assert_eq!(updated.to_vec::().unwrap(), vec![7.0, -2.0]); +} + /// Helper to create a 1D tensor (vector) with given data and shared index. fn make_vector_with_index(data: Vec, idx: &DynIndex) -> TensorDynLen { TensorDynLen::from_dense(vec![idx.clone()], data).unwrap() @@ -25,6 +568,19 @@ fn apply_matrix2_f64(x: &TensorDynLen, a_data: &[f64; 4]) -> Result, idx: &DynIndex) -> TensorDynLen { + TensorDynLen::from_dense(vec![idx.clone()], data).unwrap() +} + +fn apply_matrix2_c64(x: &TensorDynLen, a_data: &[Complex64; 4]) -> Result { + let x_data = x.to_vec::()?; + let result_data = vec![ + a_data[0] * x_data[0] + a_data[1] * x_data[1], + a_data[2] * x_data[0] + a_data[3] * x_data[1], + ]; + Ok(TensorDynLen::from_dense(x.indices.clone(), result_data).unwrap()) +} + #[test] fn test_givens_rotation_real() { let a = AnyScalar::new_real(3.0); @@ -61,20 +617,19 @@ fn test_givens_rotation_complex() { let b = AnyScalar::new_complex(1.0, -2.0); let (c, s) = compute_givens_rotation(&a, &b); - assert!(c.is_complex()); + assert!(c.is_real()); assert!(s.is_complex()); - assert_eq!(c.is_real(), s.is_real()); - // c*a + s*b should recover sqrt(|a|^2 + |b|^2) on the real axis. 
- let rotated = c.clone() * a + s.clone() * b; - assert!(rotated.is_complex()); - assert!(rotated.real().is_finite()); - assert!(rotated.imag().is_finite()); + let (first, second) = apply_givens_rotation(&c, &s, &a, &b); + let expected_norm = (3.0_f64 * 3.0 + 4.0 * 4.0 + 1.0 * 1.0 + 2.0 * 2.0).sqrt(); + assert!((first.abs() - expected_norm).abs() < 1e-12); + assert!(second.abs() < 1e-12, "{second:?}"); + assert!(((c.abs() * c.abs() + s.abs() * s.abs()) - 1.0).abs() < 1e-12); } #[test] fn test_apply_givens_rotation_complex() { - let c = AnyScalar::new_complex(0.6, 0.1); + let c = AnyScalar::new_real(0.6); let s = AnyScalar::new_complex(0.8, -0.2); let x = AnyScalar::new_complex(3.0, 1.0); let y = AnyScalar::new_complex(4.0, -2.0); @@ -87,6 +642,30 @@ fn test_apply_givens_rotation_complex() { assert!(new_y.real().is_finite() && new_y.imag().is_finite()); } +#[test] +fn test_complex_givens_rotation_eliminates_second_component() { + let cases = [ + ( + AnyScalar::new_complex(2.0, -3.0), + AnyScalar::new_complex(-1.5, 0.75), + ), + ( + AnyScalar::new_complex(0.0, 0.0), + AnyScalar::new_complex(1.0, -2.0), + ), + ( + AnyScalar::new_complex(-0.25, 0.5), + AnyScalar::new_complex(3.0, 4.0), + ), + ]; + + for (a, b) in cases { + let (c, s) = compute_givens_rotation(&a, &b); + let (_first, second) = apply_givens_rotation(&c, &s, &a, &b); + assert!(second.abs() < 1e-12, "a={a:?}, b={b:?}, second={second:?}"); + } +} + #[test] fn test_gmres_identity_operator() { // Solve A x = b where A = I (identity) @@ -169,6 +748,87 @@ fn test_gmres_diagonal_matrix() { ); } +#[test] +fn test_gmres_complex_nonsymmetric_matrix() { + let idx = DynIndex::new_dyn(2); + let expected_x = make_vector_c64( + vec![Complex64::new(1.0, -0.5), Complex64::new(-2.0, 0.75)], + &idx, + ); + let x0 = make_vector_c64(vec![Complex64::new(0.0, 0.0); 2], &idx); + let a_data = [ + Complex64::new(2.0, 0.5), + Complex64::new(-1.0, 0.25), + Complex64::new(0.75, -0.5), + Complex64::new(1.5, 1.0), + ]; + let b = 
apply_matrix2_c64(&expected_x, &a_data).unwrap(); + let apply_a = move |x: &TensorDynLen| apply_matrix2_c64(x, &a_data); + let options = GmresOptions { + max_iter: 2, + rtol: 1e-12, + max_restarts: 1, + verbose: false, + check_true_residual: true, + }; + + let result = gmres(apply_a, &b, &x0, &options).unwrap(); + assert!(result.converged, "residual={}", result.residual_norm); + let diff = result + .solution + .axpby( + AnyScalar::new_real(1.0), + &expected_x, + AnyScalar::new_real(-1.0), + ) + .unwrap(); + assert!(diff.norm() < 1e-10, "solution error={}", diff.norm()); +} + +#[test] +fn test_gmres_with_total_iteration_limit_shortens_final_restart() { + let idx = DynIndex::new_dyn(6); + let b = make_vector_with_index(vec![1.0; 6], &idx); + let x0 = make_vector_with_index(vec![0.0; 6], &idx); + + let diag = [1.0, 1.7, 2.3, 3.1, 4.2, 5.6]; + let apply_a = move |x: &TensorDynLen| scale_vector_f64(x, &diag); + let options = GmresOptions { + max_iter: 3, + rtol: 0.0, + max_restarts: 3, + verbose: false, + check_true_residual: false, + }; + + let result = gmres_with_total_iteration_limit(apply_a, &b, &x0, &options, 4).unwrap(); + + assert_eq!(result.iterations, 4); + assert!(!result.converged); +} + +#[test] +fn test_gmres_with_total_iteration_limit_allows_zero_iterations() { + let idx = DynIndex::new_dyn(3); + let b = make_vector_with_index(vec![1.0, 2.0, 3.0], &idx); + let x0 = make_vector_with_index(vec![0.0, 0.0, 0.0], &idx); + + let apply_a = |x: &TensorDynLen| -> Result { Ok(x.clone()) }; + let options = GmresOptions { + max_iter: 3, + rtol: 1e-10, + max_restarts: 2, + verbose: false, + check_true_residual: false, + }; + + let result = gmres_with_total_iteration_limit(apply_a, &b, &x0, &options, 0).unwrap(); + + assert_eq!(result.iterations, 0); + assert!(!result.converged); + assert!(result.solution.sub(&x0).unwrap().norm() < 1.0e-12); +} + #[test] fn test_gmres_nonsymmetric_matrix() { // Solve A x = b where A is a 2x2 non-symmetric matrix diff --git 
a/crates/tensor4all-core/src/lib.rs b/crates/tensor4all-core/src/lib.rs index cb0cd0e1..2ce42b0c 100644 --- a/crates/tensor4all-core/src/lib.rs +++ b/crates/tensor4all-core/src/lib.rs @@ -54,9 +54,7 @@ pub use index_like::{ConjState, IndexLike}; pub mod index_ops; pub use index_ops::{ check_unique_indices, common_ind_positions, common_inds, hascommoninds, hasind, hasinds, - noncommon_inds, prepare_contraction, prepare_contraction_pairs, replaceinds, - replaceinds_in_place, union_inds, unique_inds, ContractionError, ContractionSpec, - ReplaceIndsError, + noncommon_inds, replaceinds, replaceinds_in_place, union_inds, unique_inds, ReplaceIndsError, }; pub use smallstring::{SmallChar, SmallString, SmallStringError}; pub use tagset::{Tag, TagSetError, TagSetLike}; @@ -85,16 +83,19 @@ pub use tensor4all_tensorbackend::{ print_and_reset_native_einsum_profile, reset_native_einsum_profile, }; pub use tensor_like::{ - AllowedPairs, Canonical, DirectSumResult, FactorizeAlg, FactorizeError, FactorizeOptions, - FactorizeResult, LinearizationOrder, TensorLike, + Canonical, DirectSumResult, FactorizeAlg, FactorizeError, FactorizeOptions, FactorizeResult, + LinearizationOrder, TensorConstructionLike, TensorContractionLike, TensorFactorizationLike, + TensorLike, TensorVectorSpace, }; -// Contraction - backwards compatibility -pub use defaults::contract; pub use defaults::contract::{ - contract_connected, contract_connected_with_options, contract_multi, contract_multi_owned, - contract_multi_with_options, print_and_reset_contract_profile, reset_contract_profile, - ContractionOptions, + contract, contract_owned, contract_owned_with_options, contract_pair, + contract_pair_with_operand_options, contract_pair_with_options, contract_with_options, + outer_product, print_and_reset_contract_profile, reset_contract_profile, tensordot, + ContractionOptions, PairwiseContractionOptions, +}; +pub use defaults::tensordynlen::{ + print_and_reset_pairwise_contract_profile, 
reset_pairwise_contract_profile, }; // Re-export linear algebra modules from defaults for backwards compatibility diff --git a/crates/tensor4all-core/src/tensor_like.rs b/crates/tensor4all-core/src/tensor_like.rs index 075fa88a..a9c44265 100644 --- a/crates/tensor4all-core/src/tensor_like.rs +++ b/crates/tensor4all-core/src/tensor_like.rs @@ -31,7 +31,9 @@ use thiserror::Error; /// # Examples /// /// ``` -/// use tensor4all_core::{factorize, Canonical, DynIndex, FactorizeOptions, TensorDynLen}; +/// use tensor4all_core::{ +/// factorize, Canonical, DynIndex, FactorizeOptions, TensorContractionLike, TensorDynLen, +/// }; /// /// let i = DynIndex::new_dyn(3); /// let j = DynIndex::new_dyn(3); @@ -104,7 +106,7 @@ pub enum FactorizeError { /// Factorization algorithm. /// -/// Determines which matrix decomposition is used by [`TensorLike::factorize`]. +/// Determines which matrix decomposition is used by [`TensorFactorizationLike::factorize`]. /// /// # Examples /// @@ -135,7 +137,9 @@ pub enum FactorizeAlg { /// # Examples /// /// ``` -/// use tensor4all_core::{factorize, Canonical, DynIndex, FactorizeOptions, TensorDynLen}; +/// use tensor4all_core::{ +/// factorize, Canonical, DynIndex, FactorizeOptions, TensorContractionLike, TensorDynLen, +/// }; /// /// let i = DynIndex::new_dyn(3); /// let j = DynIndex::new_dyn(3); @@ -157,8 +161,8 @@ pub enum FactorizeAlg { /// ).unwrap(); /// /// // Both recover the same tensor -/// let recovered_left = left_result.left.contract(&left_result.right).unwrap(); -/// let recovered_right = right_result.left.contract(&right_result.right).unwrap(); +/// let recovered_left = left_result.left.contract_pair(&left_result.right).unwrap(); +/// let recovered_right = right_result.left.contract_pair(&right_result.right).unwrap(); /// assert!(recovered_left.distance(&recovered_right).unwrap() < 1e-12); /// ``` #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] @@ -179,7 +183,7 @@ pub enum Canonical { /// Options for tensor factorization. 
/// /// Controls the algorithm, canonical direction, and truncation parameters -/// for [`TensorLike::factorize`]. +/// for [`TensorFactorizationLike::factorize`]. /// /// # Defaults /// @@ -426,7 +430,9 @@ impl FactorizeOptions { /// # Examples /// /// ``` -/// use tensor4all_core::{factorize, DynIndex, FactorizeOptions, TensorDynLen}; +/// use tensor4all_core::{ +/// factorize, DynIndex, FactorizeOptions, TensorContractionLike, TensorDynLen, +/// }; /// /// let i = DynIndex::new_dyn(3); /// let j = DynIndex::new_dyn(4); @@ -436,7 +442,7 @@ impl FactorizeOptions { /// let result = factorize(&tensor, &[i.clone()], &FactorizeOptions::svd()).unwrap(); /// /// // Contracting left * right recovers the original tensor -/// let recovered = result.left.contract(&result.right).unwrap(); +/// let recovered = result.left.contract_pair(&result.right).unwrap(); /// assert!(tensor.distance(&recovered).unwrap() < 1e-12); /// /// // SVD provides singular values @@ -444,7 +450,7 @@ impl FactorizeOptions { /// assert_eq!(result.singular_values.as_ref().unwrap().len(), result.rank); /// ``` #[derive(Debug, Clone)] -pub struct FactorizeResult { +pub struct FactorizeResult { /// Left factor tensor. pub left: T, /// Right factor tensor. @@ -461,61 +467,6 @@ pub struct FactorizeResult { // Contraction types // ============================================================================ -/// Specifies which tensor pairs are allowed to contract. -/// -/// This enum controls which tensor pairs can have their indices contracted -/// in multi-tensor contraction operations. This is useful for tensor networks -/// where the graph structure determines which tensors are connected. 
-/// -/// # Example -/// -/// ``` -/// use tensor4all_core::{AllowedPairs, DynIndex, TensorDynLen, TensorLike}; -/// -/// # fn main() -> anyhow::Result<()> { -/// let i = DynIndex::new_dyn(2); -/// let j = DynIndex::new_dyn(2); -/// let k = DynIndex::new_dyn(2); -/// -/// let a = TensorDynLen::from_dense( -/// vec![i.clone(), j.clone()], -/// vec![1.0, 0.0, 0.0, 1.0], -/// )?; -/// let b = TensorDynLen::from_dense( -/// vec![j.clone(), k.clone()], -/// vec![1.0, 2.0, 3.0, 4.0], -/// )?; -/// let c = TensorDynLen::from_dense(vec![k.clone()], vec![1.0, 10.0])?; -/// -/// let tensor_refs: Vec<&TensorDynLen> = vec![&a, &b, &c]; -/// let all = ::contract(&tensor_refs, AllowedPairs::All)?; -/// -/// let edges = vec![(0, 1), (1, 2)]; -/// let specified = -/// ::contract(&tensor_refs, AllowedPairs::Specified(&edges))?; -/// -/// assert_eq!(all.dims(), vec![2]); -/// assert!(all.sub(&specified)?.maxabs() < 1e-12); -/// # Ok(()) -/// # } -/// ``` -#[derive(Debug, Clone, Copy)] -pub enum AllowedPairs<'a> { - /// All tensor pairs are allowed to contract. - /// - /// Indices with matching IDs across any two tensors will be contracted. - /// This is the default behavior, equivalent to ITensor's `*` operator. - All, - /// Only specified tensor pairs are allowed to contract. - /// - /// Each pair is `(tensor_idx_a, tensor_idx_b)` into the input tensor slice. - /// Indices are only contracted if they belong to an allowed pair. - /// - /// This is useful for tensor networks where the graph structure - /// determines which tensors are connected (e.g., TreeTN edges). - Specified(&'a [(usize, usize)]), -} - /// Linearization order used when fusing or unfusing multiple logical indices /// into one physical index. 
/// @@ -557,502 +508,20 @@ impl LinearizationOrder { } // ============================================================================ -// TensorLike trait (fully generic) +// Capability traits (fully generic) // ============================================================================ -/// Trait for tensor-like objects that expose external indices and support contraction. -/// -/// This trait is **fully generic** (monomorphic), meaning it does not support -/// trait objects (`dyn TensorLike`). For heterogeneous tensor collections, -/// use an enum wrapper instead. -/// -/// # Design Principles -/// -/// - **Minimal interface**: Only external indices and automatic contraction -/// - **Fully generic**: Uses associated type for `Index`, returns `Self` -/// - **Stable ordering**: `external_indices()` returns indices in deterministic order -/// - **No trait objects**: Requires `Sized`, cannot use `dyn TensorLike` -/// -/// # Example -/// -/// ``` -/// use tensor4all_core::{AllowedPairs, DynIndex, TensorDynLen, TensorLike}; -/// -/// fn contract_pair(a: &TensorDynLen, b: &TensorDynLen) -> anyhow::Result { -/// Ok(::contract(&[a, b], AllowedPairs::All)?) -/// } -/// -/// # fn main() -> anyhow::Result<()> { -/// let i = DynIndex::new_dyn(2); -/// let j = DynIndex::new_dyn(2); -/// let a = TensorDynLen::from_dense( -/// vec![i.clone(), j.clone()], -/// vec![1.0, 0.0, 0.0, 1.0], -/// )?; -/// let b = TensorDynLen::from_dense(vec![j.clone()], vec![2.0, 3.0])?; +/// Vector-space operations for iterative linear algebra over tensor-like values. 
/// -/// let result = contract_pair(&a, &b)?; -/// assert_eq!(result.to_vec::()?, vec![2.0, 3.0]); -/// # Ok(()) -/// # } -/// ``` -/// -/// # Heterogeneous Collections -/// -/// For mixing different tensor types, define an enum: -/// -/// ``` -/// use tensor4all_core::{block_tensor::BlockTensor, DynIndex, TensorDynLen}; -/// -/// let i = DynIndex::new_dyn(2); -/// let dense = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap(); -/// let block = BlockTensor::new(vec![dense.clone()], (1, 1)).unwrap(); -/// -/// enum TensorNetwork { -/// Dense(TensorDynLen), -/// Block(BlockTensor), -/// } -/// -/// let network = TensorNetwork::Block(block); -/// assert!(matches!(network, TensorNetwork::Block(_))); -/// ``` -/// -/// # Supertrait -/// -/// `TensorLike` extends `TensorIndex`, which provides: -/// - `external_indices()` - Get all external indices -/// - `num_external_indices()` - Count external indices -/// - `replaceind()` / `replaceinds()` - Replace indices -/// -/// This separation allows tensor networks (like `TreeTN`) to implement -/// index operations without implementing contraction/factorization. -pub trait TensorLike: TensorIndex { - /// Factorize this tensor into left and right factors. - /// - /// This function dispatches to the appropriate algorithm based on `options.alg`: - /// - `SVD`: Singular Value Decomposition - /// - `QR`: QR decomposition - /// - `LU`: Rank-revealing LU decomposition - /// - `CI`: Cross Interpolation - /// - /// The `canonical` option controls which factor is "canonical": - /// - `Canonical::Left`: Left factor is orthogonal (SVD/QR) or unit-diagonal (LU/CI) - /// - `Canonical::Right`: Right factor is orthogonal (SVD) or unit-diagonal (LU/CI) - /// - /// # Arguments - /// * `left_inds` - Indices to place on the left side - /// * `options` - Factorization options - /// - /// # Returns - /// A `FactorizeResult` containing the left and right factors, bond index, - /// singular values (for SVD), and rank. 
- /// - /// # Errors - /// Returns `FactorizeError` if: - /// - The storage type is not supported (only DenseF64 and DenseC64) - /// - QR is used with `Canonical::Right` - /// - The underlying algorithm fails - fn factorize( - &self, - left_inds: &[::Index], - options: &FactorizeOptions, - ) -> std::result::Result, FactorizeError>; - - /// Factorize this tensor without applying truncation controls. - /// - /// Use this for exact tensor rewrites such as canonicalization, where the - /// contracted factors must preserve the represented tensor up to numerical - /// roundoff. Unlike [`Self::factorize`], this method must not consult global - /// SVD/QR/LU truncation defaults or apply maximum-rank limits. - /// - /// # Arguments - /// * `left_inds` - Indices to place on the left side. - /// * `alg` - Decomposition algorithm to use. - /// * `canonical` - Which factor should carry the canonical form. - /// - /// # Returns - /// A `FactorizeResult` containing the left and right factors, bond index, - /// singular values for SVD, and retained exact numerical rank. - /// - /// # Errors - /// Returns [`FactorizeError`] if: - /// - the storage type is not supported, - /// - the requested canonical direction is unsupported for the algorithm, or - /// - the underlying decomposition fails. - /// - /// # Examples - /// - /// ``` - /// use tensor4all_core::{ - /// Canonical, DynIndex, FactorizeAlg, TensorDynLen, TensorLike, - /// }; - /// - /// let i = DynIndex::new_dyn(2); - /// let j = DynIndex::new_dyn(2); - /// let tensor = TensorDynLen::from_dense( - /// vec![i.clone(), j.clone()], - /// vec![1.0_f64, 0.0, 0.0, 1.0e-16], - /// )?; - /// - /// let result = tensor.factorize_full_rank( - /// std::slice::from_ref(&i), - /// FactorizeAlg::QR, - /// Canonical::Left, - /// )?; - /// let reconstructed = result.left.contract(&result.right)?; - /// assert!(tensor.distance(&reconstructed)? 
< 1.0e-18); - /// # Ok::<(), Box>(()) - /// ``` - fn factorize_full_rank( - &self, - left_inds: &[::Index], - alg: FactorizeAlg, - canonical: Canonical, - ) -> std::result::Result, FactorizeError>; - - /// Tensor conjugate operation. - /// - /// This is a generalized conjugate operation that depends on the tensor type: - /// - For dense tensors (TensorDynLen): element-wise complex conjugate - /// - For symmetric tensors: tensor conjugate considering symmetry sectors - /// - /// This operation is essential for computing inner products and overlaps - /// in tensor network algorithms like fitting. - /// - /// # Returns - /// A new tensor representing the tensor conjugate. - fn conj(&self) -> Self; - - /// Direct sum of two tensors along specified index pairs. - /// - /// For tensors A and B with indices to be summed specified as pairs, - /// creates a new tensor C where each paired index has dimension = dim_A + dim_B. - /// Non-paired indices must match exactly between A and B (same ID). - /// - /// # Arguments - /// - /// * `other` - Second tensor - /// * `pairs` - Pairs of (self_index, other_index) to be summed. Each pair creates - /// a new index in the result with dimension = dim(self_index) + dim(other_index). - /// - /// # Returns - /// - /// A `DirectSumResult` containing the result tensor and new indices created - /// for the summed dimensions (one per pair). 
- /// - /// # Example - /// - /// ``` - /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; - /// - /// # fn main() -> anyhow::Result<()> { - /// let j = DynIndex::new_dyn(2); - /// let k = DynIndex::new_dyn(3); - /// - /// let a = TensorDynLen::from_dense(vec![j.clone()], vec![1.0, 2.0])?; - /// let b = TensorDynLen::from_dense(vec![k.clone()], vec![3.0, 4.0, 5.0])?; - /// let result = a.direct_sum(&b, &[(j.clone(), k.clone())])?; - /// - /// assert_eq!(result.new_indices.len(), 1); - /// assert_eq!(result.tensor.dims(), vec![5]); - /// assert_eq!(result.tensor.to_vec::()?, vec![1.0, 2.0, 3.0, 4.0, 5.0]); - /// # Ok(()) - /// # } - /// ``` - fn direct_sum( - &self, - other: &Self, - pairs: &[(::Index, ::Index)], - ) -> Result>; - - /// Outer product (tensor product) of two tensors. - /// - /// Computes the tensor product of `self` and `other`, resulting in a tensor - /// with all indices from both tensors. No indices are contracted. - /// - /// # Arguments - /// - /// * `other` - The other tensor to compute outer product with - /// - /// # Returns - /// - /// A new tensor representing the outer product. - /// - /// # Errors - /// - /// Returns an error if the tensors have common indices (by ID). - /// Use `tensordot` for contraction when indices overlap. - fn outer_product(&self, other: &Self) -> Result; - +/// This trait intentionally does not require tensor contraction/einsum, +/// factorization, or tensor-network construction. Krylov solvers should depend +/// on this trait instead of [`TensorLike`] so block vectors and other abstract +/// state types do not have to provide unrelated tensor-network operations. +pub trait TensorVectorSpace: TensorIndex { /// Compute the squared Frobenius norm of the tensor. - /// - /// The squared Frobenius norm is defined as the sum of squared absolute values - /// of all tensor elements: `||T||_F^2 = sum_i |T_i|^2`. 
- /// - /// This is used for computing norms in tensor network algorithms, - /// convergence checks, and normalization. - /// - /// # Returns - /// The squared Frobenius norm as a non-negative f64. fn norm_squared(&self) -> f64; - /// Permute tensor indices to match the specified order. - /// - /// This reorders the tensor's axes to match the order specified by `new_order`. - /// The indices in `new_order` are matched by ID with the tensor's current indices. - /// - /// # Arguments - /// - /// * `new_order` - The desired order of indices (matched by ID) - /// - /// # Returns - /// - /// A new tensor with permuted indices. - /// - /// # Errors - /// - /// Returns an error if: - /// - The number of indices doesn't match - /// - An index ID in `new_order` is not found in the tensor - fn permuteinds(&self, new_order: &[::Index]) -> Result; - - /// Fuse local tensor indices into one replacement index. - /// - /// This is a local axis fusion operation: it reshapes the tensor so - /// `old_indices` are replaced by `new_index`. Indices are matched by ID, - /// and `old_indices` must be non-empty. The order of `old_indices` defines - /// the fused coordinate linearization, while `order` defines how those - /// coordinates map to the replacement axis. `new_index.dim()` must equal - /// the product of the matched axis dimensions. The replacement index is - /// inserted at the earliest fused axis position, and the remaining indices - /// retain their relative order. - /// - /// Implementations should return `Err` if this operation is unsupported or - /// if exact local fusion cannot be represented by the tensor type. - /// - /// # Arguments - /// - /// * `old_indices` - Existing local indices to fuse, matched by ID. Must be non-empty. - /// * `new_index` - Replacement index whose dimension is the fused product. - /// * `order` - Linearization order for mapping old coordinates to the fused axis. 
- /// - /// # Returns - /// - /// A new tensor with `old_indices` replaced by `new_index`. - /// - /// # Errors - /// - /// Returns an error if: - /// - `old_indices` is empty - /// - Any requested index ID is missing from the tensor - /// - The replacement dimension does not match the product of fused axis dimensions - /// - The tensor type does not support local axis fusion - /// - Exact local fusion cannot be represented by the tensor type - /// - /// # Examples - /// - /// ``` - /// use tensor4all_core::{ - /// DynIndex, IndexLike, LinearizationOrder, TensorDynLen, TensorLike, - /// }; - /// - /// # fn main() -> anyhow::Result<()> { - /// let i = DynIndex::new_dyn(2); - /// let j = DynIndex::new_dyn(3); - /// let k = DynIndex::new_dyn(2); - /// let fused = DynIndex::new_link(6)?; - /// let data: Vec = (0..12).map(|value| value as f64).collect(); - /// let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone(), k.clone()], data)?; - /// - /// let fused_tensor = ::fuse_indices( - /// &tensor, - /// &[j.clone(), i.clone()], - /// fused.clone(), - /// LinearizationOrder::ColumnMajor, - /// )?; - /// - /// assert_eq!(fused_tensor.indices(), &[fused.clone(), k.clone()]); - /// assert_eq!(fused_tensor.dims(), vec![6, 2]); - /// - /// let roundtrip = fused_tensor - /// .unfuse_index(&fused, &[j, i], LinearizationOrder::ColumnMajor)? - /// .permuteinds(tensor.indices())?; - /// assert!(roundtrip.isapprox(&tensor, 1e-12, 0.0)); - /// # Ok(()) - /// # } - /// ``` - fn fuse_indices( - &self, - old_indices: &[::Index], - new_index: ::Index, - order: LinearizationOrder, - ) -> Result; - - /// Contract multiple tensors over their contractable indices. - /// - /// This method contracts 2 or more tensors. Pairs of indices that satisfy - /// `is_contractable()` (same ID, same dimension, compatible ConjState) - /// are contracted based on the `allowed` parameter. - /// - /// Handles disconnected tensor graphs automatically by: - /// 1. 
Finding connected components based on contractable indices - /// 2. Contracting each connected component separately - /// 3. Combining results using outer product - /// - /// # Arguments - /// - /// * `tensors` - Slice of tensor references to contract (must have length >= 1) - /// * `allowed` - Specifies which tensor pairs can have their indices contracted: - /// - `AllowedPairs::All`: Contract all contractable index pairs (default behavior) - /// - `AllowedPairs::Specified(&[(i, j)])`: Only contract indices between specified tensor pairs - /// - /// # Returns - /// - /// A new tensor representing the contracted result. - /// If tensors form disconnected components, they are combined via outer product. - /// - /// # Behavior by N - /// - N=0: Error - /// - N=1: Clone of input - /// - N>=2: Contract connected components, combine with outer product - /// - /// # Errors - /// - /// Returns an error if: - /// - No tensors are provided - /// - `AllowedPairs::Specified` contains a pair with no contractable indices - /// - /// # Example - /// - /// ``` - /// use tensor4all_core::{AllowedPairs, DynIndex, TensorDynLen, TensorLike}; - /// - /// # fn main() -> anyhow::Result<()> { - /// let i = DynIndex::new_dyn(2); - /// let j = DynIndex::new_dyn(2); - /// let k = DynIndex::new_dyn(2); - /// - /// let a = TensorDynLen::from_dense( - /// vec![i.clone(), j.clone()], - /// vec![1.0, 0.0, 0.0, 1.0], - /// )?; - /// let b = TensorDynLen::from_dense( - /// vec![j.clone(), k.clone()], - /// vec![1.0, 2.0, 3.0, 4.0], - /// )?; - /// let c = TensorDynLen::from_dense(vec![k.clone()], vec![1.0, 10.0])?; - /// - /// let all = ::contract(&[&a, &b, &c], AllowedPairs::All)?; - /// let specified = ::contract( - /// &[&a, &b, &c], - /// AllowedPairs::Specified(&[(0, 1), (1, 2)]), - /// )?; - /// - /// assert_eq!(all.dims(), vec![2]); - /// assert!(all.sub(&specified)?.maxabs() < 1e-12); - /// # Ok(()) - /// # } - /// ``` - fn contract(tensors: &[&Self], allowed: AllowedPairs<'_>) -> Result; 
- - /// Contract this tensor with one other tensor using default pairwise semantics. - /// - /// This contracts all compatible common indices between `self` and `other`. - /// Implementations may override it with a specialized two-tensor path. The - /// default implementation calls [`Self::contract`] with two inputs and - /// [`AllowedPairs::All`]. - /// - /// # Arguments - /// * `other` - The tensor to contract with. It should share at least one - /// compatible common index unless the implementation supports outer - /// products through its default pairwise semantics. - /// - /// # Returns - /// The tensor produced by contracting `self` and `other`. - /// - /// # Errors - /// Returns an error if pairwise contraction cannot be performed, for example - /// because common indices have incompatible dimensions or the contraction - /// executor fails. - /// - /// # Examples - /// - /// ``` - /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; - /// - /// let i = DynIndex::new_dyn(2); - /// let j = DynIndex::new_dyn(2); - /// let k = DynIndex::new_dyn(2); - /// let a = TensorDynLen::from_dense( - /// vec![i.clone(), j.clone()], - /// vec![1.0_f64, 0.0, 0.0, 1.0], - /// )?; - /// let b = TensorDynLen::from_dense(vec![j.clone(), k.clone()], vec![2.0, 3.0, 4.0, 5.0])?; - /// - /// let result = a.contract_pair(&b)?; - /// assert_eq!(result.dims(), vec![2, 2]); - /// assert_eq!(result.to_vec::()?, vec![2.0, 3.0, 4.0, 5.0]); - /// # Ok::<(), anyhow::Error>(()) - /// ``` - fn contract_pair(&self, other: &Self) -> Result { - Self::contract(&[self, other], AllowedPairs::All) - } - - /// Contract multiple tensors that must form a connected graph. - /// - /// This is the core contraction method that requires all tensors to be - /// connected through contractable indices. Use [`Self::contract`] if you want - /// automatic handling of disconnected components via outer product. 
- /// - /// # Arguments - /// - /// * `tensors` - Slice of tensor references to contract (must form a connected graph) - /// * `allowed` - Specifies which tensor pairs can have their indices contracted - /// - /// # Returns - /// - /// A new tensor representing the contracted result. - /// - /// # Errors - /// - /// Returns an error if: - /// - No tensors are provided - /// - The tensors form a disconnected graph - /// - /// # Example - /// - /// ``` - /// use tensor4all_core::{AllowedPairs, DynIndex, TensorDynLen, TensorLike}; - /// - /// # fn main() -> anyhow::Result<()> { - /// let i = DynIndex::new_dyn(2); - /// let j = DynIndex::new_dyn(2); - /// let k = DynIndex::new_dyn(2); - /// - /// let a = TensorDynLen::from_dense( - /// vec![i.clone(), j.clone()], - /// vec![1.0, 0.0, 0.0, 1.0], - /// )?; - /// let b = TensorDynLen::from_dense( - /// vec![j.clone(), k.clone()], - /// vec![1.0, 2.0, 3.0, 4.0], - /// )?; - /// let c = TensorDynLen::from_dense(vec![k.clone()], vec![1.0, 10.0])?; - /// - /// let result = TensorDynLen::contract_connected(&[&a, &b, &c], AllowedPairs::All)?; - /// assert_eq!(result.dims(), vec![2]); - /// # Ok(()) - /// # } - /// ``` - fn contract_connected(tensors: &[&Self], allowed: AllowedPairs<'_>) -> Result; - - // ======================================================================== - // Vector space operations (for Krylov solvers) - // ======================================================================== - /// Compute a linear combination: `a * self + b * other`. - /// - /// This is the fundamental vector space operation. fn axpby(&self, a: AnyScalar, other: &Self, b: AnyScalar) -> Result; /// Scalar multiplication. @@ -1069,58 +538,14 @@ pub trait TensorLike: TensorIndex { } /// Try to compute the maximum absolute value of all tensor elements. - /// - /// This is the error-aware form of [`Self::maxabs`]. 
Tensor - /// implementations should return an error when an exact elementwise maximum - /// would require hidden full-network materialization or another unsupported - /// operation. - /// - /// # Returns - /// The L-infinity norm when it is available without violating the tensor - /// representation's scalability contract. - /// - /// # Errors - /// Returns an error when the tensor type cannot compute an exact maximum - /// absolute value through a safe/scalable implementation. - /// - /// # Examples - /// - /// ``` - /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; - /// - /// let i = DynIndex::new_dyn(3); - /// let tensor = TensorDynLen::from_dense(vec![i], vec![1.0, -3.0, 2.0])?; - /// assert_eq!(tensor.try_maxabs()?, 3.0); - /// # Ok::<(), Box>(()) - /// ``` fn try_maxabs(&self) -> Result { Ok(self.maxabs()) } /// Maximum absolute value of all elements (L-infinity norm). - /// - /// This infallible method is kept for dense/reference tensor code. Generic - /// code should prefer [`Self::try_maxabs`] so tensor-network - /// implementations can report that exact elementwise maxima are - /// unsupported instead of hiding dense materialization. Implementations - /// that cannot compute this value safely may return `NaN`. fn maxabs(&self) -> f64; /// Element-wise subtraction: `self - other`. - /// - /// Indices are automatically permuted to match `self`'s order via `axpby`. 
- /// - /// # Examples - /// - /// ``` - /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; - /// - /// let i = DynIndex::new_dyn(2); - /// let a = TensorDynLen::from_dense(vec![i.clone()], vec![5.0, 3.0]).unwrap(); - /// let b = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 1.0]).unwrap(); - /// let diff = a.sub(&b).unwrap(); - /// assert_eq!(diff.to_vec::().unwrap(), vec![4.0, 2.0]); - /// ``` fn sub(&self, other: &Self) -> Result { self.axpby(AnyScalar::new_real(1.0), other, AnyScalar::new_real(-1.0)) } @@ -1131,22 +556,6 @@ pub trait TensorLike: TensorIndex { } /// Approximate equality check (Julia `isapprox` semantics). - /// - /// Returns `true` if `||self - other|| <= max(atol, rtol * max(||self||, ||other||))`. - /// - /// # Examples - /// - /// ``` - /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; - /// - /// let i = DynIndex::new_dyn(3); - /// let a = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0, 3.0]).unwrap(); - /// let b = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0, 3.0]).unwrap(); - /// assert!(a.isapprox(&b, 1e-12, 0.0)); - /// - /// let c = TensorDynLen::from_dense(vec![i.clone()], vec![1.1, 2.0, 3.0]).unwrap(); - /// assert!(!a.isapprox(&c, 1e-3, 0.0)); - /// ``` fn isapprox(&self, other: &Self, atol: f64, rtol: f64) -> bool { let diff = match self.sub(other) { Ok(d) => d, @@ -1156,94 +565,87 @@ pub trait TensorLike: TensorIndex { diff_norm <= atol.max(rtol * self.norm().max(other.norm())) } + /// Validate structural consistency of this tensor-like vector. + fn validate(&self) -> Result<()> { + Ok(()) + } +} + +/// Contraction/einsum-style operations for tensor-like values. +/// +/// Types that only need vector-space algebra should not implement or require +/// this trait. Tree tensor-network algorithms should use this trait when they +/// truly need index-based contraction. +pub trait TensorContractionLike: TensorIndex { + /// Tensor conjugate operation. 
+ fn conj(&self) -> Self; + + /// Direct sum of two tensors along specified index pairs. + fn direct_sum( + &self, + other: &Self, + pairs: &[(::Index, ::Index)], + ) -> Result>; + + /// Outer product (tensor product) of two tensors. + fn outer_product(&self, other: &Self) -> Result; + + /// Permute tensor indices to match the specified order. + fn permuteinds(&self, new_order: &[::Index]) -> Result; + + /// Fuse local tensor indices into one replacement index. + fn fuse_indices( + &self, + old_indices: &[::Index], + new_index: ::Index, + order: LinearizationOrder, + ) -> Result; + + /// Contract a connected tensor network over its contractable indices. + fn contract(tensors: &[&Self]) -> Result; + + /// Contract this tensor with one other tensor using default pairwise semantics. + fn contract_pair(&self, other: &Self) -> Result { + Self::contract(&[self, other]) + } + /// Validate structural consistency of this tensor. - /// - /// The default implementation does nothing (always succeeds). - /// Types with internal structure (for example, block-sparse tensors) can override - /// this to check invariants such as index sharing between blocks. fn validate(&self) -> Result<()> { Ok(()) } +} + +/// Factorization operations for tensor-like values. +pub trait TensorFactorizationLike: TensorIndex { + /// Factorize this tensor into left and right factors. + fn factorize( + &self, + left_inds: &[::Index], + options: &FactorizeOptions, + ) -> std::result::Result, FactorizeError>; + /// Factorize this tensor without applying truncation controls. + fn factorize_full_rank( + &self, + left_inds: &[::Index], + alg: FactorizeAlg, + canonical: Canonical, + ) -> std::result::Result, FactorizeError>; +} + +/// Constructors and selection helpers for index-labelled tensors. +pub trait TensorConstructionLike: TensorContractionLike { /// Create a diagonal (Kronecker delta) tensor for a single index pair. 
- /// - /// Creates a 2D tensor `T[i, o]` where `T[i, o] = δ_{i,o}` (1 if i==o, 0 otherwise). - /// - /// # Arguments - /// - /// * `input_index` - Input index - /// * `output_index` - Output index (must have same dimension as input) - /// - /// # Returns - /// - /// A 2D tensor with shape `[dim, dim]` representing the identity matrix. - /// - /// # Errors - /// - /// Returns an error if dimensions don't match. - /// - /// # Examples - /// - /// ``` - /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; - /// - /// let i = DynIndex::new_dyn(3); - /// let o = DynIndex::new_dyn(3); - /// let delta = TensorDynLen::diagonal(&i, &o).unwrap(); - /// - /// assert_eq!(delta.dims(), vec![3, 3]); - /// let data = delta.to_vec::().unwrap(); - /// // Identity matrix in column-major: [1,0,0, 0,1,0, 0,0,1] - /// assert!((data[0] - 1.0).abs() < 1e-12); - /// assert!((data[4] - 1.0).abs() < 1e-12); - /// assert!((data[8] - 1.0).abs() < 1e-12); - /// assert!((data[1]).abs() < 1e-12); - /// ``` fn diagonal( input_index: &::Index, output_index: &::Index, ) -> Result; /// Create a delta (identity) tensor as outer product of diagonals. - /// - /// For paired indices `(i1, o1), (i2, o2), ...`, creates a tensor where: - /// `T[i1, o1, i2, o2, ...] = δ_{i1,o1} × δ_{i2,o2} × ...` - /// - /// This is computed as the outer product of individual diagonal tensors. - /// - /// # Arguments - /// - /// * `input_indices` - Input indices - /// * `output_indices` - Output indices (must have same length and matching dimensions) - /// - /// # Returns - /// - /// A tensor representing the identity operator on the given index space. 
- /// - /// # Errors - /// - /// Returns an error if: - /// - Number of input and output indices don't match - /// - Dimensions of paired indices don't match - /// - /// # Examples - /// - /// ``` - /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; - /// - /// let i1 = DynIndex::new_dyn(2); - /// let o1 = DynIndex::new_dyn(2); - /// let i2 = DynIndex::new_dyn(3); - /// let o2 = DynIndex::new_dyn(3); - /// - /// let d = TensorDynLen::delta(&[i1, i2], &[o1, o2]).unwrap(); - /// assert_eq!(d.dims(), vec![2, 2, 3, 3]); - /// ``` fn delta( input_indices: &[::Index], output_indices: &[::Index], ) -> Result { - // Validate same number of input and output indices if input_indices.len() != output_indices.len() { return Err(anyhow::anyhow!( "Number of input indices ({}) must match output indices ({})", @@ -1253,11 +655,9 @@ pub trait TensorLike: TensorIndex { } if input_indices.is_empty() { - // Return a scalar tensor with value 1.0 return Self::scalar_one(); } - // Build as outer product of diagonal tensors let mut result = Self::diagonal(&input_indices[0], &output_indices[0])?; for (inp, out) in input_indices[1..].iter().zip(output_indices[1..].iter()) { let diag = Self::diagonal(inp, out)?; @@ -1267,85 +667,12 @@ pub trait TensorLike: TensorIndex { } /// Create a scalar tensor with value 1.0. - /// - /// This is used as the identity element for outer products. - /// - /// # Examples - /// - /// ``` - /// use tensor4all_core::{TensorDynLen, TensorLike}; - /// - /// let one = TensorDynLen::scalar_one().unwrap(); - /// assert_eq!(one.dims(), Vec::::new()); - /// assert!((one.only().unwrap().real() - 1.0).abs() < 1e-12); - /// ``` fn scalar_one() -> Result; /// Create a tensor filled with 1.0 for the given indices. - /// - /// This is useful for adding indices to tensors via outer product - /// without changing tensor values (since multiplying by 1 is identity). 
- /// - /// # Example - /// To add a dummy index `l` to tensor `T`: - /// ``` - /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; - /// - /// # fn main() -> anyhow::Result<()> { - /// let i = DynIndex::new_dyn(2); - /// let l = DynIndex::new_dyn(3); - /// let t = TensorDynLen::from_dense(vec![i.clone()], vec![2.0, 4.0])?; - /// - /// let ones = TensorDynLen::ones(&[l.clone()])?; - /// let t_with_l = t.outer_product(&ones)?; - /// - /// assert_eq!(t_with_l.dims(), vec![2, 3]); - /// assert_eq!(t_with_l.to_vec::()?, vec![2.0, 4.0, 2.0, 4.0, 2.0, 4.0]); - /// # Ok(()) - /// # } - /// ``` fn ones(indices: &[::Index]) -> Result; /// Select fixed coordinates for a subset of this tensor's external indices. - /// - /// This returns a new tensor with `selected_indices` removed and the - /// corresponding coordinates fixed to `positions`. Implementations may - /// override this with direct slicing. The default implementation contracts - /// with a one-hot tensor, so it works for any tensor type that supports - /// [`Self::onehot`] and [`Self::contract`]. - /// - /// # Arguments - /// * `selected_indices` - External indices to fix. Each index must appear - /// at most once and should be present in the tensor. - /// * `positions` - Zero-based coordinates, one for each selected index. - /// - /// # Returns - /// A tensor with the selected indices removed. If no indices are selected, - /// this returns a clone of `self`. - /// - /// # Errors - /// Returns an error if the argument lengths differ, if an index is repeated, - /// if a coordinate is outside the index dimension, or if the one-hot - /// contraction fails. 
- /// - /// # Examples - /// - /// ``` - /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; - /// - /// let i = DynIndex::new_dyn(2); - /// let j = DynIndex::new_dyn(3); - /// let tensor = TensorDynLen::from_dense( - /// vec![i.clone(), j.clone()], - /// vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0], - /// )?; - /// - /// let selected = - /// ::select_indices(&tensor, &[j], &[1])?; - /// assert_eq!(selected.dims(), vec![2]); - /// assert_eq!(selected.to_vec::()?, vec![3.0, 4.0]); - /// # Ok::<(), anyhow::Error>(()) - /// ``` fn select_indices( &self, selected_indices: &[::Index], @@ -1382,39 +709,92 @@ pub trait TensorLike: TensorIndex { .zip(positions.iter().copied()) .collect::>(); let onehot = Self::onehot(&index_vals)?; - Self::contract(&[self, &onehot], AllowedPairs::All) + Self::contract(&[self, &onehot]) } /// Create a one-hot tensor with value 1.0 at the specified index positions. - /// - /// Similar to ITensors.jl's `onehot(i => 1, j => 2)`. - /// - /// # Arguments - /// * `index_vals` - Pairs of (Index, 0-indexed position) - /// - /// # Errors - /// Returns error if any value >= corresponding index dimension. 
- /// - /// # Examples - /// - /// ``` - /// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; - /// - /// let i = DynIndex::new_dyn(3); - /// let j = DynIndex::new_dyn(2); - /// - /// // One-hot at i=1, j=0 - /// let t = TensorDynLen::onehot(&[(i, 1), (j, 0)]).unwrap(); - /// assert_eq!(t.dims(), vec![3, 2]); - /// - /// let data = t.to_vec::().unwrap(); - /// // column-major 3x2: element at (1,0) = index 1 - /// assert!((data[1] - 1.0).abs() < 1e-12); - /// assert!((t.sum().unwrap().real() - 1.0).abs() < 1e-12); // exactly one non-zero - /// ``` fn onehot(index_vals: &[(::Index, usize)]) -> Result; } +// ============================================================================ +// TensorLike trait (fully generic composite) +// ============================================================================ + +/// Trait for tensor-like objects that expose external indices and support contraction. +/// +/// This trait is **fully generic** (monomorphic), meaning it does not support +/// trait objects (`dyn TensorLike`). For heterogeneous tensor collections, +/// use an enum wrapper instead. +/// +/// # Design Principles +/// +/// - **Capability composition**: combines vector-space, factorization, construction, and contraction traits +/// - **Fully generic**: Uses associated type for `Index`, returns `Self` +/// - **Stable ordering**: `external_indices()` returns indices in deterministic order +/// - **No trait objects**: Requires `Sized`, cannot use `dyn TensorLike` +/// +/// # Example +/// +/// ``` +/// use tensor4all_core::{DynIndex, TensorContractionLike, TensorDynLen}; +/// +/// fn contract_pair(a: &TensorDynLen, b: &TensorDynLen) -> anyhow::Result { +/// Ok(::contract(&[a, b])?) 
+/// } +/// +/// # fn main() -> anyhow::Result<()> { +/// let i = DynIndex::new_dyn(2); +/// let j = DynIndex::new_dyn(2); +/// let a = TensorDynLen::from_dense( +/// vec![i.clone(), j.clone()], +/// vec![1.0, 0.0, 0.0, 1.0], +/// )?; +/// let b = TensorDynLen::from_dense(vec![j.clone()], vec![2.0, 3.0])?; +/// +/// let result = contract_pair(&a, &b)?; +/// assert_eq!(result.to_vec::()?, vec![2.0, 3.0]); +/// # Ok(()) +/// # } +/// ``` +/// +/// # Heterogeneous Collections +/// +/// For mixing different tensor types, define an enum: +/// +/// ``` +/// use tensor4all_core::{block_tensor::BlockTensor, DynIndex, TensorDynLen}; +/// +/// let i = DynIndex::new_dyn(2); +/// let dense = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap(); +/// let block = BlockTensor::new(vec![dense.clone()], (1, 1)).unwrap(); +/// +/// enum TensorNetwork { +/// Dense(TensorDynLen), +/// Block(BlockTensor), +/// } +/// +/// let network = TensorNetwork::Block(block); +/// assert!(matches!(network, TensorNetwork::Block(_))); +/// ``` +/// +/// # Supertrait +/// +/// `TensorLike` extends several capability traits. Through those traits it provides: +/// - `external_indices()` - Get all external indices +/// - `num_external_indices()` - Count external indices +/// - `replaceind()` / `replaceinds()` - Replace indices +/// - vector-space operations such as `axpby`, `inner_product`, and `norm` +/// - tensor-network operations such as contraction, construction, and factorization +/// +/// Use narrower traits such as [`TensorVectorSpace`] or +/// [`TensorContractionLike`] when an algorithm does not need the full surface. +pub trait TensorLike: TensorVectorSpace + TensorFactorizationLike + TensorConstructionLike {} + +impl TensorLike for T where + T: TensorVectorSpace + TensorFactorizationLike + TensorConstructionLike +{ +} + /// Result of direct sum operation. 
/// /// Contains the resulting tensor and the new indices created for the summed @@ -1423,7 +803,7 @@ pub trait TensorLike: TensorIndex { /// # Examples /// /// ``` -/// use tensor4all_core::{DynIndex, TensorDynLen, TensorLike, IndexLike}; +/// use tensor4all_core::{DynIndex, IndexLike, TensorContractionLike, TensorDynLen}; /// /// let i = DynIndex::new_dyn(2); /// let j = DynIndex::new_dyn(3); @@ -1439,7 +819,7 @@ pub trait TensorLike: TensorIndex { /// assert_eq!(result.tensor.to_vec::().unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0]); /// ``` #[derive(Debug, Clone)] -pub struct DirectSumResult { +pub struct DirectSumResult { /// The resulting tensor from direct sum. pub tensor: T, /// New indices created for the summed dimensions (one per pair). diff --git a/crates/tensor4all-core/tests/ad_integration.rs b/crates/tensor4all-core/tests/ad_integration.rs index a5f7157a..0b8dac5b 100644 --- a/crates/tensor4all-core/tests/ad_integration.rs +++ b/crates/tensor4all-core/tests/ad_integration.rs @@ -1,4 +1,6 @@ -use tensor4all_core::{factorize_full_rank, svd, Canonical, FactorizeAlg, Index, TensorDynLen}; +use tensor4all_core::{ + factorize_full_rank, svd, Canonical, FactorizeAlg, Index, TensorContractionLike, TensorDynLen, +}; fn assert_f64_slice_close(actual: &[f64], expected: &[f64], tol: f64) { assert_eq!(actual.len(), expected.len()); @@ -73,7 +75,7 @@ fn factorize_qr_reconstruction_preserves_gradient_to_input() { Canonical::Left, ) .unwrap(); - let reconstructed = result.left.contract(&result.right).unwrap(); + let reconstructed = result.left.contract_pair(&result.right).unwrap(); let loss = reconstructed.sum().unwrap(); assert!(loss.tracks_grad()); loss.backward().unwrap(); @@ -92,7 +94,7 @@ fn assert_ci_reconstruction_gradient(canonical: Canonical) { let result = factorize_full_rank(&x, std::slice::from_ref(&i), FactorizeAlg::CI, canonical).unwrap(); - let reconstructed = result.left.contract(&result.right).unwrap(); + let reconstructed = 
result.left.contract_pair(&result.right).unwrap(); let loss = reconstructed.sum().unwrap(); assert!(loss.tracks_grad()); loss.backward().unwrap(); diff --git a/crates/tensor4all-core/tests/bench_contract_fit_patterns.rs b/crates/tensor4all-core/tests/bench_contract_fit_patterns.rs index 6d1cdaea..181422d5 100644 --- a/crates/tensor4all-core/tests/bench_contract_fit_patterns.rs +++ b/crates/tensor4all-core/tests/bench_contract_fit_patterns.rs @@ -1,7 +1,7 @@ use std::hint::black_box; use std::time::{Duration, Instant}; -use tensor4all_core::{AllowedPairs, DynIndex, TensorDynLen, TensorLike}; +use tensor4all_core::{DynIndex, TensorContractionLike, TensorDynLen}; use tensor4all_tensorbackend::{dense_native_tensor_from_col_major, einsum_native_tensors}; fn make_data(dims: &[usize], offset: usize) -> Vec { @@ -217,8 +217,7 @@ fn bench_contract_fit_patterns_vs_native() { eprintln!("\n=== TensorDynLen contract vs native einsum ==="); let env3_contract = time_best_of("env3 TensorDynLen::contract", 2_000, || { - ::contract(&[&env3_a, &env3_b, &env3_c], AllowedPairs::All) - .unwrap() + ::contract(&[&env3_a, &env3_b, &env3_c]).unwrap() }); let env3_native = time_best_of("env3 native einsum", 2_000, || { einsum_native_tensors( @@ -233,11 +232,8 @@ fn bench_contract_fit_patterns_vs_native() { }); let env4_contract = time_best_of("env4 TensorDynLen::contract", 600, || { - ::contract( - &[&env4_a, &env4_b, &env4_c, &env4_d], - AllowedPairs::All, - ) - .unwrap() + ::contract(&[&env4_a, &env4_b, &env4_c, &env4_d]) + .unwrap() }); let env4_native = time_best_of("env4 native einsum", 600, || { einsum_native_tensors( @@ -252,10 +248,9 @@ fn bench_contract_fit_patterns_vs_native() { .unwrap() }); let env6_contract = time_best_of("env6 TensorDynLen::contract", 400, || { - ::contract( - &[&env6_a, &env6_b, &env6_c, &env6_d, &env6_e, &env6_f], - AllowedPairs::All, - ) + ::contract(&[ + &env6_a, &env6_b, &env6_c, &env6_d, &env6_e, &env6_f, + ]) .unwrap() }); let env6_native = 
time_best_of("env6 native einsum", 400, || { diff --git a/crates/tensor4all-core/tests/bug_qr_after_permute.rs b/crates/tensor4all-core/tests/bug_qr_after_permute.rs index f836245a..2e618203 100644 --- a/crates/tensor4all-core/tests/bug_qr_after_permute.rs +++ b/crates/tensor4all-core/tests/bug_qr_after_permute.rs @@ -7,7 +7,9 @@ //! SVD gives ~1e-14 on the same data. use num_complex::Complex64; -use tensor4all_core::{factorize, svd, DynIndex, FactorizeOptions, TensorDynLen}; +use tensor4all_core::{ + factorize, svd, DynIndex, FactorizeOptions, TensorContractionLike, TensorDynLen, +}; /// Create a [5,2,2,5] tensor with data that triggers the QR bug. fn make_buggy_tensor() -> TensorDynLen { @@ -126,7 +128,7 @@ fn make_buggy_tensor() -> TensorDynLen { fn reconstruction_error(t: &TensorDynLen, left_inds: &[DynIndex], opts: &FactorizeOptions) -> f64 { let result = factorize(t, left_inds, opts).unwrap(); - let recon = result.left.contract(&result.right).unwrap(); + let recon = result.left.contract_pair(&result.right).unwrap(); let neg = recon .scale(tensor4all_core::AnyScalar::new_real(-1.0)) .unwrap(); @@ -140,11 +142,11 @@ fn svd_reconstruction_error(t: &TensorDynLen, left_inds: &[DynIndex]) -> f64 { perm.extend(0..v.indices.len() - 1); let vh = v.conj().permute(&perm).unwrap(); let svh = s - .contract(&vh) + .contract_pair(&vh) .unwrap() .replaceind(&s.indices[1], &u.indices[u.indices.len() - 1]) .unwrap(); - let recon = u.contract(&svh).unwrap(); + let recon = u.contract_pair(&svh).unwrap(); let neg = recon .scale(tensor4all_core::AnyScalar::new_real(-1.0)) .unwrap(); diff --git a/crates/tensor4all-core/tests/common_index_ops.rs b/crates/tensor4all-core/tests/common_index_ops.rs index 0a65fec7..30241f58 100644 --- a/crates/tensor4all-core/tests/common_index_ops.rs +++ b/crates/tensor4all-core/tests/common_index_ops.rs @@ -1,7 +1,7 @@ use tensor4all_core::index::DefaultIndex as Index; use tensor4all_core::index_ops::{ - common_inds, hascommoninds, hasind, hasinds, 
noncommon_inds, prepare_contraction_pairs, - replaceinds, replaceinds_in_place, union_inds, unique_inds, ReplaceIndsError, + common_inds, hascommoninds, hasind, hasinds, noncommon_inds, replaceinds, replaceinds_in_place, + union_inds, unique_inds, ReplaceIndsError, }; use tensor4all_core::IndexLike; @@ -221,24 +221,6 @@ fn test_index_ops_distinguish_same_id_prime_pair() { assert_eq!(replaced, vec![i, replacement]); } -#[test] -fn test_prepare_contraction_pairs_selects_exact_same_id_prime_index() { - let i = Index::new_dyn(2); - let i_prime = i.prime(); - let spec = prepare_contraction_pairs( - &[i.clone(), i_prime.clone()], - &[2, 2], - std::slice::from_ref(&i_prime), - &[2], - &[(i_prime.clone(), i_prime.clone())], - ) - .unwrap(); - - assert_eq!(spec.axes_a, vec![1]); - assert_eq!(spec.axes_b, vec![0]); - assert_eq!(spec.result_indices, vec![i]); -} - #[test] fn test_replaceinds_multiple_replacements() { let i = Index::new_dyn(2); diff --git a/crates/tensor4all-core/tests/error_paths.rs b/crates/tensor4all-core/tests/error_paths.rs index ee09312c..b1ed4aba 100644 --- a/crates/tensor4all-core/tests/error_paths.rs +++ b/crates/tensor4all-core/tests/error_paths.rs @@ -4,7 +4,8 @@ use tensor4all_core::block_tensor::BlockTensor; use tensor4all_core::col_major_array::{ColMajorArrayError, ColMajorArrayRef}; use tensor4all_core::index_like::IndexLike; use tensor4all_core::{ - compute_permutation_from_indices, diag_tensor_dyn_len, DynIndex, TensorDynLen, + compute_permutation_from_indices, diag_tensor_dyn_len, DynIndex, TensorContractionLike, + TensorDynLen, }; use tensor4all_tensorbackend::Storage; @@ -51,7 +52,7 @@ fn tensor_contract_rejects_mismatched_common_dimension() { let a = TensorDynLen::from_dense(vec![shared_left], vec![1.0_f64, 2.0]).unwrap(); let b = TensorDynLen::from_dense(vec![shared_right], vec![1.0_f64, 2.0, 3.0]).unwrap(); - let err = a.contract(&b).unwrap_err(); + let err = a.contract_pair(&b).unwrap_err(); let message = err.to_string(); assert!( diff 
--git a/crates/tensor4all-core/tests/linalg_factorize.rs b/crates/tensor4all-core/tests/linalg_factorize.rs index a25ef659..980ca01f 100644 --- a/crates/tensor4all-core/tests/linalg_factorize.rs +++ b/crates/tensor4all-core/tests/linalg_factorize.rs @@ -3,9 +3,9 @@ use tensor4all_core::index::Index; use tensor4all_core::{ factorize, factorize_full_rank, Canonical, DynIndex, FactorizeAlg, FactorizeError, - FactorizeOptions, + FactorizeOptions, TensorContractionLike, }; -use tensor4all_core::{SvdTruncationPolicy, TensorDynLen, TensorLike}; +use tensor4all_core::{SvdTruncationPolicy, TensorDynLen}; // ============================================================================ // Test Data Helpers @@ -62,7 +62,7 @@ fn test_factorize_reconstruction(options: &FactorizeOptions) { let result = factorize(&tensor, &left_inds, options).unwrap(); // Verify reconstruction: left * right ≈ original - let reconstructed = result.left.contract(&result.right).unwrap(); + let reconstructed = result.left.contract_pair(&result.right).unwrap(); assert_tensors_approx_equal(&tensor, &reconstructed, 1e-10); } @@ -163,7 +163,7 @@ fn test_factorize_svd_rank3() { let options = FactorizeOptions::svd(); let result = factorize(&tensor, &left_inds, &options).unwrap(); - let reconstructed = result.left.contract(&result.right).unwrap(); + let reconstructed = result.left.contract_pair(&result.right).unwrap(); assert_tensors_approx_equal(&tensor, &reconstructed, 1e-10); } @@ -233,7 +233,7 @@ fn test_factorize_lu_ci_reconstruction_with_unit_dim_axis() { qr_rtol: None, }; let result = factorize(&tensor, &left_inds, &options).unwrap(); - let reconstructed = result.left.contract(&result.right).unwrap(); + let reconstructed = result.left.contract_pair(&result.right).unwrap(); assert_tensors_approx_equal(&tensor, &reconstructed, 1e-10); } } @@ -252,7 +252,7 @@ fn test_factorize_lu_ci_reconstruction_with_col_major_matrix_input() { qr_rtol: None, }; let result = factorize(&tensor, &left_inds, 
&options).unwrap(); - let reconstructed = result.left.contract(&result.right).unwrap(); + let reconstructed = result.left.contract_pair(&result.right).unwrap(); assert_tensors_approx_equal(&tensor, &reconstructed, 1e-10); } } @@ -292,7 +292,7 @@ fn test_factorize_full_rank_preserves_near_dependent_components() { "{alg:?} full-rank factorization dropped a near-dependent component" ); - let reconstructed = result.left.contract(&result.right).unwrap(); + let reconstructed = result.left.contract_pair(&result.right).unwrap(); assert_tensors_approx_equal(&tensor, &reconstructed, 1.0e-18); } } @@ -357,10 +357,10 @@ fn test_diag_dense_contraction_svd_internals() { assert!(common_found, "S and V should share a common index"); // Contractions should work - let sv = s.contract(&v).unwrap(); + let sv = s.contract_pair(&v).unwrap(); assert_eq!(sv.dims().len(), 2, "S*V should be a 2D tensor"); - let us = u.contract(&s).unwrap(); + let us = u.contract_pair(&s).unwrap(); assert_eq!(us.dims().len(), 2, "U*S should be a 2D tensor"); } diff --git a/crates/tensor4all-core/tests/linalg_qr.rs b/crates/tensor4all-core/tests/linalg_qr.rs index cfb96c8d..b25c51b5 100644 --- a/crates/tensor4all-core/tests/linalg_qr.rs +++ b/crates/tensor4all-core/tests/linalg_qr.rs @@ -1,7 +1,7 @@ use num_complex::Complex64; use tensor4all_core::index::DefaultIndex as Index; use tensor4all_core::{qr, DynIndex}; -use tensor4all_core::{TensorDynLen, TensorLike}; +use tensor4all_core::{TensorContractionLike, TensorDynLen}; fn dense_f64(indices: Vec, data: Vec) -> TensorDynLen { TensorDynLen::from_dense(indices, data).unwrap() @@ -84,7 +84,7 @@ fn test_qr_reconstruction() { let (q, r) = qr::(&tensor, std::slice::from_ref(&i)).expect("QR should succeed"); // Reconstruct: A = Q * R - let reconstructed = q.contract(&r).unwrap(); + let reconstructed = q.contract_pair(&r).unwrap(); // Check reconstruction accuracy assert!( @@ -181,7 +181,7 @@ fn test_qr_nontrivial_split_reconstruction() { let tensor = 
dense_f64(vec![i.clone(), j.clone(), k.clone(), l.clone()], data); let (q, r) = qr::(&tensor, &[i.clone(), k.clone()]).expect("QR should succeed"); - let reconstructed = q.contract(&r).unwrap(); + let reconstructed = q.contract_pair(&r).unwrap(); assert!( tensor.isapprox(&reconstructed, 1e-8, 0.0), @@ -207,7 +207,7 @@ fn test_qr_complex_reconstruction() { let (q, r) = qr::(&tensor, std::slice::from_ref(&i_idx)).expect("Complex QR should succeed"); - let reconstructed = q.contract(&r).unwrap(); + let reconstructed = q.contract_pair(&r).unwrap(); assert!( tensor.isapprox(&reconstructed, 1e-8, 0.0), "Complex QR reconstruction failed: maxabs diff = {}", @@ -218,7 +218,7 @@ fn test_qr_complex_reconstruction() { /// Helper: compute ||Q*R - T|| via tensor contraction for any tensor shape. fn qr_reconstruction_error_f64(t: &TensorDynLen, left_inds: &[DynIndex]) -> f64 { let (q, r) = qr::(t, left_inds).expect("QR should succeed"); - let recon = q.contract(&r).unwrap(); + let recon = q.contract_pair(&r).unwrap(); let neg = recon .scale(tensor4all_core::AnyScalar::new_real(-1.0)) .unwrap(); @@ -228,7 +228,7 @@ fn qr_reconstruction_error_f64(t: &TensorDynLen, left_inds: &[DynIndex]) -> f64 fn qr_reconstruction_error_c64(t: &TensorDynLen, left_inds: &[DynIndex]) -> f64 { let (q, r) = qr::(t, left_inds).expect("QR should succeed"); - let recon = q.contract(&r).unwrap(); + let recon = q.contract_pair(&r).unwrap(); let neg = recon .scale(tensor4all_core::AnyScalar::new_real(-1.0)) .unwrap(); diff --git a/crates/tensor4all-core/tests/linalg_svd.rs b/crates/tensor4all-core/tests/linalg_svd.rs index 69c5e9e9..cbbdd914 100644 --- a/crates/tensor4all-core/tests/linalg_svd.rs +++ b/crates/tensor4all-core/tests/linalg_svd.rs @@ -4,7 +4,7 @@ use tensor4all_core::{ default_svd_truncation_policy, set_default_svd_truncation_policy, svd, svd_with, SvdOptions, SvdTruncationPolicy, }; -use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; +use tensor4all_core::{DynIndex, 
TensorContractionLike, TensorDynLen}; fn dense_f64(indices: Vec, data: Vec) -> TensorDynLen { TensorDynLen::from_dense(indices, data).unwrap() @@ -28,11 +28,11 @@ fn vh_from_v(v: &TensorDynLen) -> TensorDynLen { fn reconstruct_from_svd(u: &TensorDynLen, s: &TensorDynLen, v: &TensorDynLen) -> TensorDynLen { let vh = vh_from_v(v); - let svh = s.contract(&vh).unwrap(); + let svh = s.contract_pair(&vh).unwrap(); let sim_bond = s.indices[1].clone(); let bond = v.indices[v.indices.len() - 1].clone(); let svh = svh.replaceind(&sim_bond, &bond).unwrap(); - u.contract(&svh).unwrap() + u.contract_pair(&svh).unwrap() } #[test] diff --git a/crates/tensor4all-core/tests/tensor_any_scalar.rs b/crates/tensor4all-core/tests/tensor_any_scalar.rs index d5aef075..6cf02ee2 100644 --- a/crates/tensor4all-core/tests/tensor_any_scalar.rs +++ b/crates/tensor4all-core/tests/tensor_any_scalar.rs @@ -1,4 +1,4 @@ -use num_complex::Complex64; +use num_complex::{Complex32, Complex64}; use num_traits::{One, Zero}; use tensor4all_core::AnyScalar; @@ -8,6 +8,31 @@ fn test_is_complex() { assert!(AnyScalar::new_complex(1.0, 0.0).is_complex()); } +#[test] +fn test_from_f32_and_complex32_scalar_paths() { + let real = AnyScalar::from_value(1.25_f32); + assert_eq!(real.as_f64(), Some(1.25)); + assert_eq!(real.real(), 1.25); + assert_eq!(real.imag(), 0.0); + assert_eq!(real.abs(), 1.25); + assert!(!real.is_zero()); + assert_eq!(Complex64::from(real), Complex64::new(1.25, 0.0)); + + let zero = AnyScalar::from_value(0.0_f32); + assert!(zero.is_zero()); + + let complex = AnyScalar::from_value(Complex32::new(3.0, -4.0)); + assert!(complex.is_complex()); + assert_eq!(complex.real(), 3.0); + assert_eq!(complex.imag(), -4.0); + assert_eq!(complex.abs(), 5.0); + assert_eq!(complex.as_c64(), Some(Complex64::new(3.0, -4.0))); + assert_eq!(Complex64::from(complex), Complex64::new(3.0, -4.0)); + + let complex_zero = AnyScalar::from_value(Complex32::new(0.0, 0.0)); + assert!(complex_zero.is_zero()); +} + #[test] fn 
test_real() { assert_eq!(AnyScalar::new_real(3.5).real(), 3.5); diff --git a/crates/tensor4all-core/tests/tensor_comparison.rs b/crates/tensor4all-core/tests/tensor_comparison.rs index 0f610103..ca8b1e1d 100644 --- a/crates/tensor4all-core/tests/tensor_comparison.rs +++ b/crates/tensor4all-core/tests/tensor_comparison.rs @@ -1,6 +1,6 @@ use num_complex::Complex64; use tensor4all_core::index::DefaultIndex as Index; -use tensor4all_core::{diag_tensor_dyn_len, TensorDynLen, TensorLike}; +use tensor4all_core::{diag_tensor_dyn_len, TensorDynLen}; #[test] fn test_sub_identical_tensors_is_zero() { diff --git a/crates/tensor4all-core/tests/tensor_contract_multi_pair_equivalence.rs b/crates/tensor4all-core/tests/tensor_contract_nary_pair_equivalence.rs similarity index 76% rename from crates/tensor4all-core/tests/tensor_contract_multi_pair_equivalence.rs rename to crates/tensor4all-core/tests/tensor_contract_nary_pair_equivalence.rs index 4f540271..39785b68 100644 --- a/crates/tensor4all-core/tests/tensor_contract_multi_pair_equivalence.rs +++ b/crates/tensor4all-core/tests/tensor_contract_nary_pair_equivalence.rs @@ -1,6 +1,7 @@ use tensor4all_core::index::DefaultIndex as Index; use tensor4all_core::{ - factorize, svd, AllowedPairs, Canonical, DynIndex, FactorizeOptions, TensorDynLen, TensorLike, + factorize, outer_product, svd, Canonical, DynIndex, FactorizeOptions, TensorContractionLike, + TensorDynLen, }; fn make_tensor(indices: Vec, data: Vec, dims: &[usize]) -> TensorDynLen { @@ -44,7 +45,7 @@ fn permute_col_major(data: &[f64], dims: &[usize], perm: &[usize]) -> Vec { } #[test] -fn test_contract_multi_pair_matches_binary_contract() { +fn test_contract_nary_pair_matches_binary_contract() { let l01 = Index::new_dyn(3); let s1 = Index::new_dyn(2); let l12 = Index::new_dyn(3); @@ -63,19 +64,18 @@ fn test_contract_multi_pair_matches_binary_contract() { &[3, 2], ); - let binary = t1.contract(&t2).unwrap(); - let multi = - ::contract(&[&t1, &t2], 
AllowedPairs::All).expect("contract"); + let binary = t1.contract_pair(&t2).unwrap(); + let multi = ::contract(&[&t1, &t2]).expect("contract"); assert!( multi.isapprox(&binary, 1e-12, 0.0), - "multi-contract and binary contract differ: maxabs diff = {}", + "nary contract and binary contract differ: maxabs diff = {}", multi.sub(&binary).unwrap().maxabs() ); } #[test] -fn test_contract_multi_three_matches_sequential_binary_contract() { +fn test_contract_nary_three_matches_sequential_binary_contract() { let i = Index::new_dyn(2); let a = Index::new_dyn(3); let b = Index::new_dyn(2); @@ -98,19 +98,19 @@ fn test_contract_multi_three_matches_sequential_binary_contract() { &[3, 2], ); - let sequential = t0.contract(&t1).unwrap().contract(&t2).unwrap(); - let multi = ::contract(&[&t0, &t1, &t2], AllowedPairs::All) - .expect("contract"); + let sequential = t0.contract_pair(&t1).unwrap().contract_pair(&t2).unwrap(); + let multi = + ::contract(&[&t0, &t1, &t2]).expect("contract"); assert!( multi.isapprox(&sequential, 1e-12, 0.0), - "3-tensor multi-contract and sequential contract differ: maxabs diff = {}", + "3-tensor nary contract and sequential contract differ: maxabs diff = {}", multi.sub(&sequential).unwrap().maxabs() ); } #[test] -fn test_contract_multi_pair_matches_binary_contract_for_zero_masked_inputs() { +fn test_contract_nary_pair_matches_binary_contract_for_zero_masked_inputs() { let s0 = Index::new_dyn(2); let l01 = Index::new_dyn(3); let s1 = Index::new_dyn(2); @@ -122,19 +122,18 @@ fn test_contract_multi_pair_matches_binary_contract_for_zero_masked_inputs() { ); let t1 = make_tensor(vec![l01, s1], (1..=6).map(|x| x as f64).collect(), &[3, 2]); - let binary = t0.contract(&t1).unwrap(); - let multi = - ::contract(&[&t0, &t1], AllowedPairs::All).expect("contract"); + let binary = t0.contract_pair(&t1).unwrap(); + let multi = ::contract(&[&t0, &t1]).expect("contract"); assert!( multi.isapprox(&binary, 1e-12, 0.0), - "zero-masked multi-contract and binary contract 
differ: maxabs diff = {}", + "zero-masked nary contract and binary contract differ: maxabs diff = {}", multi.sub(&binary).unwrap().maxabs() ); } #[test] -fn test_zipup_zero_masked_root_multi_matches_sequential_binary_contract() { +fn test_zipup_zero_masked_root_nary_matches_sequential_binary_contract() { let s0 = Index::new_dyn(2); let s1 = Index::new_dyn(2); let s2 = Index::new_dyn(2); @@ -162,8 +161,7 @@ fn test_zipup_zero_masked_root_multi_matches_sequential_binary_contract() { &[3, 2], ); - let leaf = ::contract(&[&a0, &b0], AllowedPairs::All) - .expect("leaf contract"); + let leaf = outer_product(&a0, &b0).expect("leaf outer product"); let permuted_leaf = leaf .permute_indices(&[s0.clone(), s1.clone(), l01.clone(), l12.clone()]) .unwrap(); @@ -180,14 +178,14 @@ fn test_zipup_zero_masked_root_multi_matches_sequential_binary_contract() { let (u, s, v) = svd::(&leaf, &[s0.clone(), s1.clone()]).expect("svd"); let vh = v.conj().permute(&[2, 0, 1]).unwrap(); - let svh = s.contract(&vh).unwrap(); + let svh = s.contract_pair(&vh).unwrap(); let svh = svh .replaceind( &s.indices[1].clone(), &v.indices[v.indices.len() - 1].clone(), ) .unwrap(); - let svd_reconstructed = u.contract(&svh).unwrap(); + let svd_reconstructed = u.contract_pair(&svh).unwrap(); assert!( svd_reconstructed.isapprox(&leaf, 1e-10, 0.0), "svd leaf does not reconstruct: maxabs diff = {}", @@ -201,7 +199,7 @@ fn test_zipup_zero_masked_root_multi_matches_sequential_binary_contract() { ) .expect("factorize"); - let reconstructed_leaf = factorized.left.contract(&factorized.right).unwrap(); + let reconstructed_leaf = factorized.left.contract_pair(&factorized.right).unwrap(); assert!( reconstructed_leaf.isapprox(&leaf, 1e-10, 0.0), "factorized leaf does not reconstruct: maxabs diff = {}", @@ -210,17 +208,16 @@ fn test_zipup_zero_masked_root_multi_matches_sequential_binary_contract() { let sequential = factorized .right - .contract(&a1) + .contract_pair(&a1) .unwrap() - .contract(&b1) + .contract_pair(&b1) 
.unwrap(); - let multi = - ::contract(&[&factorized.right, &a1, &b1], AllowedPairs::All) - .expect("root contract"); + let multi = ::contract(&[&factorized.right, &a1, &b1]) + .expect("root contract"); assert!( multi.isapprox(&sequential, 1e-10, 0.0), - "zipup root multi-contract and sequential binary contract differ: maxabs diff = {}", + "zipup root nary contract and sequential binary contract differ: maxabs diff = {}", multi.sub(&sequential).unwrap().maxabs() ); } diff --git a/crates/tensor4all-core/tests/tensor_contraction.rs b/crates/tensor4all-core/tests/tensor_contraction.rs index 34284003..9a90c65c 100644 --- a/crates/tensor4all-core/tests/tensor_contraction.rs +++ b/crates/tensor4all-core/tests/tensor_contraction.rs @@ -1,7 +1,10 @@ use num_complex::Complex64; use tensor4all_core::index::DefaultIndex as Index; use tensor4all_core::index_ops::common_inds; -use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; +use tensor4all_core::{ + contract_pair, contract_pair_with_operand_options, tensordot, DynIndex, + PairwiseContractionOptions, TensorContractionLike, TensorDynLen, +}; use tensor4all_tensorbackend::{Storage, StorageKind}; fn dense_f64(indices: Vec, data: Vec) -> TensorDynLen { @@ -49,7 +52,7 @@ fn test_contract_dyn_len_matrix_multiplication() { let tensor_b = dense_f64(vec![j.clone(), k.clone()], vec![1.0; 12]); // Contract along j: result should be C[i, k] with all 3.0 (since each element is sum of 3 ones) - let result = tensor_a.contract(&tensor_b).unwrap(); + let result = tensor_a.contract_pair(&tensor_b).unwrap(); assert_eq!(result.dims(), vec![2, 4]); assert_eq!(result.indices.len(), 2); assert_eq!(result.indices[0].id, i.id); @@ -74,7 +77,7 @@ fn test_mul_operator_contraction() { let tensor_b = dense_f64(vec![j.clone(), k.clone()], vec![1.0; 12]); // Contract along j using * operator: result should be C[i, k] with all 3.0 - let result = tensor_a.contract(&tensor_b).unwrap(); + let result = tensor_a.contract_pair(&tensor_b).unwrap(); 
assert_eq!(result.dims(), vec![2, 4]); assert_eq!(result.indices.len(), 2); assert_eq!(result.indices[0].id, i.id); @@ -95,7 +98,7 @@ fn test_mul_operator_owned() { let tensor_b = dense_f64(vec![j.clone(), k.clone()], vec![1.0; 12]); // Use * operator with owned tensors - let result = tensor_a.contract(&tensor_b).unwrap(); + let result = tensor_a.contract_pair(&tensor_b).unwrap(); assert_eq!(result.dims(), vec![2, 4]); assert_eq!(result.indices.len(), 2); } @@ -111,7 +114,7 @@ fn test_contract_no_common_indices_gives_outer_product() { let tensor_b = TensorDynLen::zeros::(vec![k.clone()]).unwrap(); // No common indices → outer product - let result = tensor_a.contract(&tensor_b).unwrap(); + let result = tensor_a.contract_pair(&tensor_b).unwrap(); assert_eq!(result.dims(), vec![2, 3, 4]); assert_eq!(result.indices.len(), 3); } @@ -124,7 +127,7 @@ fn test_contract_no_common_indices_preserves_left_then_right_index_order_and_val let tensor_a = TensorDynLen::from_dense(vec![i.clone()], vec![2.0, -1.0]).unwrap(); let tensor_b = TensorDynLen::from_dense(vec![j.clone()], vec![3.0, 4.0, -2.0]).unwrap(); - let result = tensor_a.contract(&tensor_b).unwrap(); + let result = tensor_a.contract_pair(&tensor_b).unwrap(); assert_eq!(result.indices, vec![i, j]); let expected = TensorDynLen::from_dense( @@ -147,7 +150,7 @@ fn structured_tensor_contract_materializes_to_correct_dense_result() { assert!(diag.is_diag()); let dense = TensorDynLen::from_dense(vec![j, k.clone()], vec![5.0, 7.0, 11.0, 13.0]).unwrap(); - let result = diag.contract(&dense).unwrap(); + let result = diag.contract_pair(&dense).unwrap(); let expected = TensorDynLen::from_dense(vec![i, k], vec![10.0, 21.0, 22.0, 39.0]).unwrap(); assert!(result.sub(&expected).unwrap().maxabs() < 1e-12); @@ -177,7 +180,7 @@ fn general_structured_contract_preserves_output_axis_classes() { ) .unwrap(); - let result = structured.contract(&dense).unwrap(); + let result = structured.contract_pair(&dense).unwrap(); assert_eq!(result.indices, 
vec![i, k, l]); assert_eq!(result.storage().storage_kind(), StorageKind::Structured); @@ -208,7 +211,7 @@ fn test_contract_three_indices() { let tensor_b = dense_f64(vec![j.clone(), k.clone(), l.clone()], vec![1.0; 60]); // Contract along j and k: result should be C[i, l] with all 12.0 (3 * 4 = 12) - let result = tensor_a.contract(&tensor_b).unwrap(); + let result = tensor_a.contract_pair(&tensor_b).unwrap(); assert_eq!(result.dims(), vec![2, 5]); assert_eq!(result.indices.len(), 2); assert_eq!(result.indices[0].id, i.id); @@ -242,7 +245,7 @@ fn test_contract_mixed_f64_c64() { // Contract along j: result should be C[i, k] (Complex64) // Expected result: [[1+2i + 5+6i, 3+4i + 7+8i], [1+2i + 5+6i, 3+4i + 7+8i]] // = [[6+8i, 10+12i], [6+8i, 10+12i]] - let result = tensor_a.contract(&tensor_b).unwrap(); + let result = tensor_a.contract_pair(&tensor_b).unwrap(); assert_eq!(result.dims(), vec![2, 2]); assert_eq!(result.indices.len(), 2); assert_eq!(result.indices[0].id, i.id); @@ -289,7 +292,7 @@ fn test_contract_mixed_c64_f64() { // C[0,1] = (1+2i)*1 + (3+4i)*1 = 4+6i // C[1,0] = (5+6i)*1 + (7+8i)*1 = 12+14i // C[1,1] = (5+6i)*1 + (7+8i)*1 = 12+14i - let result = tensor_a.contract(&tensor_b).unwrap(); + let result = tensor_a.contract_pair(&tensor_b).unwrap(); assert_eq!(result.dims(), vec![2, 2]); assert_eq!( @@ -303,6 +306,82 @@ fn test_contract_mixed_c64_f64() { ); } +#[test] +fn test_contract_pair_with_lhs_conj_matches_materialized_conj() { + let i = Index::new_dyn(2); + let j = Index::new_dyn(2); + let k = Index::new_dyn(2); + + let lhs = dense_c64( + vec![i.clone(), j.clone()], + vec![ + Complex64::new(1.0, 2.0), + Complex64::new(3.0, -1.0), + Complex64::new(-2.0, 0.5), + Complex64::new(0.25, 4.0), + ], + ); + let rhs = dense_c64( + vec![j, k.clone()], + vec![ + Complex64::new(0.5, -1.0), + Complex64::new(2.0, 3.0), + Complex64::new(-1.5, 0.25), + Complex64::new(4.0, -0.5), + ], + ); + + let flagged = contract_pair_with_operand_options( + &lhs, + &rhs, + 
PairwiseContractionOptions::new().with_lhs_conj(true), + ) + .unwrap(); + let materialized = contract_pair(&lhs.conj(), &rhs).unwrap(); + + assert!(flagged.isapprox(&materialized, 1e-12, 0.0)); + assert_eq!(flagged.indices[0].id, i.id); + assert_eq!(flagged.indices[1].id, k.id); +} + +#[test] +fn test_contract_pair_with_rhs_conj_matches_materialized_conj() { + let i = Index::new_dyn(2); + let j = Index::new_dyn(2); + let k = Index::new_dyn(2); + + let lhs = dense_c64( + vec![i.clone(), j.clone()], + vec![ + Complex64::new(1.0, -1.0), + Complex64::new(0.0, 2.0), + Complex64::new(3.0, 0.5), + Complex64::new(-2.0, 1.5), + ], + ); + let rhs = dense_c64( + vec![j, k.clone()], + vec![ + Complex64::new(2.0, 1.0), + Complex64::new(-3.0, 0.25), + Complex64::new(1.5, -2.0), + Complex64::new(0.5, 4.0), + ], + ); + + let flagged = contract_pair_with_operand_options( + &lhs, + &rhs, + PairwiseContractionOptions::new().with_rhs_conj(true), + ) + .unwrap(); + let materialized = contract_pair(&lhs, &rhs.conj()).unwrap(); + + assert!(flagged.isapprox(&materialized, 1e-12, 0.0)); + assert_eq!(flagged.indices[0].id, i.id); + assert_eq!(flagged.indices[1].id, k.id); +} + #[test] fn test_tensordot_different_ids() { // Test tensordot with indices that have different IDs but same dimensions @@ -318,9 +397,7 @@ fn test_tensordot_different_ids() { let tensor_b = dense_f64(vec![k.clone(), l.clone()], vec![1.0; 12]); // Contract j (from A) with k (from B): result should be C[i, l] with all 3.0 - let result = tensor_a - .tensordot(&tensor_b, &[(j.clone(), k.clone())]) - .unwrap(); + let result = tensordot(&tensor_a, &tensor_b, &[(j.clone(), k.clone())]).unwrap(); assert_eq!(result.dims(), vec![2, 4]); assert_eq!(result.indices.len(), 2); assert_eq!(result.indices[0].id, i.id); @@ -340,7 +417,7 @@ fn test_tensordot_dimension_mismatch() { let tensor_b = TensorDynLen::zeros::(vec![k.clone()]).unwrap(); - let result = tensor_a.tensordot(&tensor_b, &[(j.clone(), k.clone())]); + let result = 
tensordot(&tensor_a, &tensor_b, &[(j.clone(), k.clone())]); assert!(result.is_err()); if let Err(e) = result { let err_msg = format!("{}", e); @@ -365,7 +442,7 @@ fn test_tensordot_index_not_found() { let tensor_b = TensorDynLen::zeros::(vec![k.clone()]).unwrap(); // Try to contract with a non-existent index from tensor_a - let result = tensor_a.tensordot(&tensor_b, &[(nonexistent.clone(), k.clone())]); + let result = tensordot(&tensor_a, &tensor_b, &[(nonexistent.clone(), k.clone())]); assert!(result.is_err()); if let Err(e) = result { let err_msg = format!("{}", e); @@ -390,7 +467,8 @@ fn test_tensordot_duplicate_axis() { let tensor_b = TensorDynLen::zeros::(vec![k.clone(), l.clone()]).unwrap(); // Try to contract j twice (duplicate axis in self) - let result = tensor_a.tensordot( + let result = tensordot( + &tensor_a, &tensor_b, &[ (j.clone(), k.clone()), @@ -411,7 +489,7 @@ fn test_tensordot_empty_pairs() { let tensor_b = TensorDynLen::zeros::(vec![j.clone()]).unwrap(); - let result = tensor_a.tensordot(&tensor_b, &[]); + let result = tensordot(&tensor_a, &tensor_b, &[]); assert!(result.is_err()); if let Err(e) = result { let err_msg = format!("{}", e); @@ -443,7 +521,7 @@ fn test_tensordot_common_index_not_in_pairs() { // Try to contract only k with l, leaving j as a "batch" dimension // This should fail because batch contraction is not yet implemented - let result = tensor_a.tensordot(&tensor_b, &[(k.clone(), l.clone())]); + let result = tensordot(&tensor_a, &tensor_b, &[(k.clone(), l.clone())]); assert!(result.is_err()); if let Err(e) = result { let err_msg = format!("{}", e); @@ -471,7 +549,7 @@ fn test_tensordot_common_index_in_pairs_ok() { let tensor_b = dense_f64(vec![j.clone(), k.clone()], vec![1.0; 12]); // Contract j with j - this should work because the common index is in pairs - let result = tensor_a.tensordot(&tensor_b, &[(j.clone(), j.clone())]); + let result = tensordot(&tensor_a, &tensor_b, &[(j.clone(), j.clone())]); assert!(result.is_ok()); 
let result = result.unwrap(); assert_eq!(result.dims(), vec![2, 4]); @@ -488,7 +566,7 @@ fn test_scalar_times_tensor() { let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone()], data.clone()).unwrap(); - let result = scalar.contract(&tensor).unwrap(); + let result = scalar.contract_pair(&tensor).unwrap(); assert_eq!(result.dims(), vec![2, 3]); assert_eq!(result.to_vec::().unwrap(), data); } @@ -501,7 +579,7 @@ fn test_tensor_times_scalar() { let data = vec![10.0, 20.0]; let tensor = TensorDynLen::from_dense(vec![i.clone()], data.clone()).unwrap(); - let result = tensor.contract(&scalar).unwrap(); + let result = tensor.contract_pair(&scalar).unwrap(); assert_eq!(result.dims(), vec![2]); assert_eq!(result.to_vec::().unwrap(), data); } @@ -511,7 +589,7 @@ fn test_scalar_times_scalar() { let s1 = TensorDynLen::scalar(3.0).unwrap(); let s2 = TensorDynLen::scalar(5.0).unwrap(); - let result = s1.contract(&s2).unwrap(); + let result = s1.contract_pair(&s2).unwrap(); assert_eq!(result.dims().len(), 0); let val = result.to_vec::().unwrap(); assert_eq!(val.len(), 1); @@ -526,7 +604,7 @@ fn test_mul_operator_scalar_times_tensor() { let data = vec![1.0, 2.0, 3.0]; let tensor = TensorDynLen::from_dense(vec![i.clone()], data.clone()).unwrap(); - let result = scalar.contract(&tensor).unwrap(); + let result = scalar.contract_pair(&tensor).unwrap(); assert_eq!(result.dims(), vec![3]); assert_eq!(result.to_vec::().unwrap(), data); } @@ -540,8 +618,8 @@ fn test_foldl_sequential_contraction() { let b = TensorDynLen::from_dense(vec![j.clone(), i.clone()], vec![2.0; 6]).unwrap(); let mut acc = TensorDynLen::scalar_one().unwrap(); - acc = acc.contract(&a).unwrap(); // acc = a (outer product with scalar) - acc = acc.contract(&b).unwrap(); // acc = contract(a, b) over i and j + acc = acc.contract_pair(&a).unwrap(); // acc = a (outer product with scalar) + acc = acc.contract_pair(&b).unwrap(); // acc = contract(a, b) over i and j // 
a[i,j] * b[j,i] = sum_j(a[i,j]*b[j,i]) summed over both → scalar assert_eq!(acc.dims().len(), 0); diff --git a/crates/tensor4all-core/tests/tensor_diag.rs b/crates/tensor4all-core/tests/tensor_diag.rs index 1f98b7b9..c6aa55fa 100644 --- a/crates/tensor4all-core/tests/tensor_diag.rs +++ b/crates/tensor4all-core/tests/tensor_diag.rs @@ -1,7 +1,9 @@ use num_complex::Complex64; use tensor4all_core::index::DefaultIndex as Index; -use tensor4all_core::TensorLike; -use tensor4all_core::{diag_tensor_dyn_len, AnyScalar, TensorDynLen}; +use tensor4all_core::{ + diag_tensor_dyn_len, outer_product, tensordot, AnyScalar, TensorConstructionLike, + TensorContractionLike, TensorDynLen, +}; use tensor4all_tensorbackend::StorageKind; #[test] @@ -121,7 +123,7 @@ fn test_diag_tensor_contract_diag_diag_all_contracted() { let tensor_b = diag_tensor_dyn_len(vec![i.clone(), j.clone()], diag_b).unwrap(); // Contract all indices: result should be scalar (inner product) - let result = tensor_a.contract(&tensor_b).unwrap(); + let result = tensor_a.contract_pair(&tensor_b).unwrap(); // Result should be scalar: 1*3 + 2*4 = 11 assert_eq!(result.dims().len(), 0); @@ -141,7 +143,7 @@ fn test_diag_tensor_contract_diag_diag_partial() { let tensor_b = diag_tensor_dyn_len(vec![j.clone(), k.clone()], diag_b).unwrap(); // Contract along j: result should be DiagTensor[i, k] - let result = tensor_a.contract(&tensor_b).unwrap(); + let result = tensor_a.contract_pair(&tensor_b).unwrap(); assert_eq!(result.dims(), vec![3, 3]); assert!(result.is_diag()); @@ -163,11 +165,11 @@ fn tracked_diag_partial_contraction_preserves_diag_result_and_grad() { .unwrap(); let b = diag_tensor_dyn_len(vec![j, k.clone()], vec![7.0, 11.0, 13.0]).unwrap(); - let c = a.contract(&b).unwrap(); + let c = a.contract_pair(&b).unwrap(); assert_eq!(c.storage().storage_kind(), StorageKind::Diagonal); let ones = diag_tensor_dyn_len(vec![i, k], vec![1.0, 1.0, 1.0]).unwrap(); - let loss = c.contract(&ones).unwrap(); + let loss = 
c.contract_pair(&ones).unwrap(); loss.backward().unwrap(); let grad = a.grad().unwrap().unwrap(); @@ -188,9 +190,8 @@ fn test_diag_tensor_tensordot_diag_diag_partial_preserves_diagonal_storage() { let tensor_a = diag_tensor_dyn_len(vec![i.clone(), j.clone()], vec![1.0, 2.0, 3.0]).unwrap(); let tensor_b = diag_tensor_dyn_len(vec![k.clone(), l.clone()], vec![4.0, 5.0, 6.0]).unwrap(); - let result = tensor_a - .tensordot(&tensor_b, &[(j, k)]) - .expect("diag-diag tensordot should succeed"); + let result = + tensordot(&tensor_a, &tensor_b, &[(j, k)]).expect("diag-diag tensordot should succeed"); assert_eq!(result.dims(), vec![3, 3]); assert!(result.is_diag()); @@ -214,7 +215,7 @@ fn test_diag_tensor_contract_diag_dense() { let tensor_b = TensorDynLen::from_dense(vec![j.clone(), k.clone()], vec![1.0; 4]).unwrap(); // Contract along j: result should be DenseTensor[i, k] - let result = tensor_a.contract(&tensor_b).unwrap(); + let result = tensor_a.contract_pair(&tensor_b).unwrap(); assert_eq!(result.dims(), vec![2, 2]); let expected = TensorDynLen::from_dense(vec![i, k], vec![1.0, 2.0, 1.0, 2.0]).unwrap(); @@ -268,7 +269,7 @@ fn tensorlike_diagonal_uses_compact_diagonal_storage() { let i = Index::new_dyn(4); let o = Index::new_dyn(4); - let delta = ::diagonal(&i, &o).unwrap(); + let delta = ::diagonal(&i, &o).unwrap(); assert!(delta.is_diag()); assert_eq!(delta.storage().storage_kind(), StorageKind::Diagonal); @@ -287,19 +288,22 @@ fn tensorlike_delta_two_pairs_preserves_independent_copy_structure() { let i2 = Index::new_dyn(3); let o2 = Index::new_dyn(3); - let delta = - ::delta(&[i1.clone(), i2.clone()], &[o1.clone(), o2.clone()]) - .unwrap(); + let delta = ::delta( + &[i1.clone(), i2.clone()], + &[o1.clone(), o2.clone()], + ) + .unwrap(); assert_eq!(delta.dims(), vec![2, 2, 3, 3]); assert_eq!(delta.storage().storage_kind(), StorageKind::Structured); assert_eq!(delta.storage().payload_dims(), &[2, 3]); assert_eq!(delta.storage().axis_classes(), &[0, 0, 1, 1]); - let 
expected = TensorDynLen::from_diag(vec![i1, o1], vec![1.0_f64, 1.0]) - .unwrap() - .outer_product(&TensorDynLen::from_diag(vec![i2, o2], vec![1.0_f64, 1.0, 1.0]).unwrap()) - .unwrap(); + let expected = outer_product( + &TensorDynLen::from_diag(vec![i1, o1], vec![1.0_f64, 1.0]).unwrap(), + &TensorDynLen::from_diag(vec![i2, o2], vec![1.0_f64, 1.0, 1.0]).unwrap(), + ) + .unwrap(); assert!(delta.isapprox(&expected, 1e-12, 0.0)); } @@ -493,7 +497,7 @@ fn test_diag_tensor_contract_rank3() { let tensor_b = diag_tensor_dyn_len(vec![k.clone(), l.clone()], diag_b).unwrap(); // Contract along k: result should be DiagTensor[i, j, l] - let result = tensor_a.contract(&tensor_b).unwrap(); + let result = tensor_a.contract_pair(&tensor_b).unwrap(); assert_eq!(result.dims(), vec![2, 2, 2]); assert!(result.is_diag()); diff --git a/crates/tensor4all-core/tests/tensor_native_ad.rs b/crates/tensor4all-core/tests/tensor_native_ad.rs index 4fbf3657..7d87438a 100644 --- a/crates/tensor4all-core/tests/tensor_native_ad.rs +++ b/crates/tensor4all-core/tests/tensor_native_ad.rs @@ -1,6 +1,5 @@ use tensor4all_core::{ - contract_multi, contract_multi_with_options, AllowedPairs, ContractionOptions, Index, - TensorDynLen, + contract, contract_with_options, ContractionOptions, Index, TensorContractionLike, TensorDynLen, }; use tensor4all_tensorbackend::{Storage, StorageKind}; @@ -47,7 +46,7 @@ fn contraction_without_grad_returns_rank_zero_scalar() { let a = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0, 3.0]).unwrap(); let ones = TensorDynLen::from_dense(vec![i], vec![1.0, 1.0, 1.0]).unwrap(); - let result = contract_multi(&[&a, &ones], AllowedPairs::All).unwrap(); + let result = contract(&[&a, &ones]).unwrap(); assert!(result.indices().is_empty()); assert_eq!(result.to_vec::().unwrap(), vec![6.0]); @@ -62,13 +61,13 @@ fn backward_accumulates_until_clear_grad() { .unwrap(); let ones = TensorDynLen::from_dense(vec![i], vec![1.0, 1.0, 1.0]).unwrap(); - let loss = contract_multi(&[&x, 
&ones], AllowedPairs::All).unwrap(); + let loss = contract(&[&x, &ones]).unwrap(); loss.backward().unwrap(); let grad = x.grad().unwrap().unwrap(); assert_eq!(grad.to_vec::().unwrap(), vec![1.0, 1.0, 1.0]); - let loss = contract_multi(&[&x, &ones], AllowedPairs::All).unwrap(); + let loss = contract(&[&x, &ones]).unwrap(); loss.backward().unwrap(); let grad = x.grad().unwrap().unwrap(); @@ -97,7 +96,7 @@ fn general_structured_grad_preserves_input_axis_classes() { .unwrap(); let ones = TensorDynLen::from_dense(vec![i, j, k], vec![1.0; 12]).unwrap(); - let loss = contract_multi(&[&x, &ones], AllowedPairs::All).unwrap(); + let loss = contract(&[&x, &ones]).unwrap(); loss.backward().unwrap(); let grad = x.grad().unwrap().unwrap(); @@ -128,7 +127,7 @@ fn clone_shares_tracked_leaf_gradient_slot() { let x = TensorDynLen::scalar(2.0).unwrap().enable_grad().unwrap(); let alias = x.clone(); - let loss = x.contract(&alias).unwrap(); + let loss = x.contract_pair(&alias).unwrap(); loss.backward().unwrap(); let grad_x = x.grad().unwrap().unwrap(); @@ -154,9 +153,9 @@ fn retained_multi_contraction_preserves_grad_path() { let y = TensorDynLen::from_dense(vec![batch.clone(), k.clone(), j.clone()], vec![1.0; 12]).unwrap(); let retain_indices = [batch.clone()]; - let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain_indices); + let options = ContractionOptions::new().with_retain_indices(&retain_indices); - let result = contract_multi_with_options(&[&x, &y], options).unwrap(); + let result = contract_with_options(&[&x, &y], options).unwrap(); assert_eq!(result.dims(), vec![2, 2, 2]); assert_eq!( result.to_vec::().unwrap(), @@ -164,7 +163,7 @@ fn retained_multi_contraction_preserves_grad_path() { ); let ones = TensorDynLen::from_dense(result.indices().to_vec(), vec![1.0; 8]).unwrap(); - let loss = contract_multi(&[&result, &ones], AllowedPairs::All).unwrap(); + let loss = contract(&[&result, &ones]).unwrap(); loss.backward().unwrap(); let grad = 
x.grad().unwrap().unwrap(); @@ -193,9 +192,9 @@ fn structured_retained_multi_contraction_errors_before_detaching_grad() { let y = TensorDynLen::from_dense(vec![batch.clone(), k.clone(), j.clone()], vec![1.0; 8]).unwrap(); let retain_indices = [batch.clone()]; - let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain_indices); + let options = ContractionOptions::new().with_retain_indices(&retain_indices); - let err = contract_multi_with_options(&[&x, &y], options).unwrap_err(); + let err = contract_with_options(&[&x, &y], options).unwrap_err(); let message = err.to_string(); assert!( message.contains("structured storage") || message.contains("not yet supported"), diff --git a/crates/tensor4all-core/tests/tensor_tensor_like.rs b/crates/tensor4all-core/tests/tensor_tensor_like.rs index 485d9540..419b257b 100644 --- a/crates/tensor4all-core/tests/tensor_tensor_like.rs +++ b/crates/tensor4all-core/tests/tensor_tensor_like.rs @@ -2,7 +2,10 @@ use tensor4all_core::index::{DynId, Index}; use tensor4all_core::DynIndex; -use tensor4all_core::{AllowedPairs, TensorDynLen, TensorIndex, TensorLike}; +use tensor4all_core::{ + outer_product, TensorConstructionLike, TensorContractionLike, TensorDynLen, TensorIndex, + TensorVectorSpace, +}; /// Helper to create a simple tensor with given dimensions fn make_tensor(dims: &[usize]) -> TensorDynLen { @@ -51,8 +54,8 @@ fn test_tensor_like_contract_basic() { let b_data: Vec = (0..12).map(|x| x as f64).collect(); let b = TensorDynLen::from_dense(vec![j_copy.clone(), k.clone()], b_data).unwrap(); - // Use TensorLike::contract - auto-detects contractable pairs via is_contractable - let c = ::contract(&[&a, &b], AllowedPairs::All) + // Use TensorContractionLike::contract - auto-detects contractable pairs via is_contractable + let c = ::contract(&[&a, &b]) .expect("contract should succeed"); // Result should be 2x4 @@ -60,11 +63,49 @@ fn test_tensor_like_contract_basic() { } #[test] -fn 
test_contract_allowed_pairs_specified() { +fn tensor_vector_space_default_methods_cover_common_paths() { + let i = Index::::new_dyn(3); + let a = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, -2.0, 3.0]).unwrap(); + let b = TensorDynLen::from_dense(vec![i.clone()], vec![1.0, -2.0, 3.0 + 1.0e-13]).unwrap(); + + let neg = TensorVectorSpace::neg(&a).unwrap(); + assert_eq!(neg.to_vec::().unwrap(), vec![-1.0, 2.0, -3.0]); + assert!(a.isapprox(&b, 1.0e-12, 0.0)); + + let j = Index::::new_dyn(2); + let incompatible = TensorDynLen::from_dense(vec![j], vec![1.0, 2.0]).unwrap(); + assert!(!a.isapprox(&incompatible, 1.0e-12, 0.0)); +} + +#[test] +fn tensor_contraction_and_construction_default_methods_cover_paths() { + let i = Index::::new_dyn(2); + let j = Index::::new_dyn(3); + let a = TensorDynLen::from_dense(vec![i.clone()], vec![2.0, 5.0]).unwrap(); + let b = TensorDynLen::from_dense(vec![i.clone(), j.clone()], vec![1.0; 6]).unwrap(); + + let contracted = a.contract_pair(&b).unwrap(); + assert_eq!(contracted.indices, vec![j.clone()]); + assert_eq!(contracted.to_vec::().unwrap(), vec![7.0, 7.0, 7.0]); + TensorContractionLike::validate(&contracted).unwrap(); + + let unchanged = TensorConstructionLike::select_indices(&b, &[], &[]).unwrap(); + assert!(unchanged.isapprox(&b, 0.0, 0.0)); + + let selected = + TensorConstructionLike::select_indices(&b, std::slice::from_ref(&i), &[1]).unwrap(); + assert_eq!(selected.indices, vec![j.clone()]); + assert_eq!(selected.to_vec::().unwrap(), vec![1.0, 1.0, 1.0]); + + let err = + TensorConstructionLike::select_indices(&b, std::slice::from_ref(&i), &[2]).unwrap_err(); + assert!(err.to_string().contains("out of range")); +} + +#[test] +fn test_contract_three_tensor_chain() { // Create three tensors: A(i,j), B(j,k), C(k,l) - // With AllowedPairs::Specified(&[(0, 1), (1, 2)]) - A-B and B-C pairs allowed - // j is shared between A and B, k is shared between B and C - // All tensors form a connected chain: A-j-B-k-C + // j is shared 
between A and B, k is shared between B and C. let i = Index::::new_dyn(2); let j = Index::::new_dyn(3); let k = Index::::new_dyn(4); @@ -84,14 +125,8 @@ fn test_contract_allowed_pairs_specified() { let c_data: Vec = (0..20).map(|x| x as f64).collect(); let c = TensorDynLen::from_dense(vec![k_copy.clone(), l.clone()], c_data).unwrap(); - // Contract with specified pairs - // j is contracted between A and B (in pair (0,1)) - // k is contracted between B and C (in pair (1,2)) - let result = ::contract( - &[&a, &b, &c], - AllowedPairs::Specified(&[(0, 1), (1, 2)]), - ) - .expect("contract should succeed"); + let result = ::contract(&[&a, &b, &c]) + .expect("contract should succeed"); // Result should have: i (from A, dim=2), l (from C, dim=5) // j and k are contracted @@ -102,9 +137,7 @@ fn test_contract_allowed_pairs_specified() { } #[test] -fn test_contract_specified_empty_with_common_indices_errors() { - // AllowedPairs::Specified(&[]) with tensors that share index IDs should error - // because outer_product requires tensors to have no common indices +fn test_outer_product_with_common_indices_errors() { let i = Index::::new_dyn(2); let j = Index::::new_dyn(3); @@ -118,19 +151,16 @@ fn test_contract_specified_empty_with_common_indices_errors() { let b_data: Vec = (0..6).map(|x| x as f64).collect(); let b = TensorDynLen::from_dense(vec![i_copy.clone(), j_copy.clone()], b_data).unwrap(); - // With empty allowed pairs and tensors that share index IDs, - // outer_product will fail because tensors have common indices - let result = ::contract(&[&a, &b], AllowedPairs::Specified(&[])); + let result = outer_product(&a, &b); assert!(result.is_err()); - let err_msg = result.unwrap_err().to_string(); + let err_msg = result.unwrap_err().to_string().to_lowercase(); assert!(err_msg.contains("common indices")); } #[test] -fn test_contract_specified_empty_outer_product() { - // AllowedPairs::Specified(&[]) with tensors that have different index IDs - // should succeed via outer 
product +fn test_outer_product_disconnected_tensors() { + // Disconnected inputs require an explicit outer product. let i = Index::::new_dyn(2); let j = Index::::new_dyn(3); let k = Index::::new_dyn(4); @@ -144,9 +174,7 @@ fn test_contract_specified_empty_outer_product() { let b_data: Vec = (0..20).map(|x| x as f64).collect(); let b = TensorDynLen::from_dense(vec![k.clone(), l.clone()], b_data).unwrap(); - // With empty allowed pairs and different index IDs, outer product succeeds - let result = - ::contract(&[&a, &b], AllowedPairs::Specified(&[])).unwrap(); + let result = outer_product(&a, &b).unwrap(); // Result should have 4 indices (i, j, k, l) let mut sorted_dims = result.dims(); @@ -156,15 +184,14 @@ fn test_contract_specified_empty_outer_product() { } #[test] -fn test_contract_specified_empty_outer_product_preserves_input_component_order() { +fn test_outer_product_preserves_input_component_order() { let i = Index::::new_dyn(2); let j = Index::::new_dyn(3); let a = TensorDynLen::from_dense(vec![i.clone()], vec![2.0, -1.0]).unwrap(); let b = TensorDynLen::from_dense(vec![j.clone()], vec![3.0, 4.0, -2.0]).unwrap(); - let result = - ::contract(&[&a, &b], AllowedPairs::Specified(&[])).unwrap(); + let result = outer_product(&a, &b).unwrap(); assert_eq!(result.indices, vec![i, j]); let expected = TensorDynLen::from_dense( @@ -179,10 +206,7 @@ fn test_contract_specified_empty_outer_product_preserves_input_component_order() } #[test] -fn test_contract_specified_disconnected_outer_product() { - // AllowedPairs::Specified(&[(0,1), (2,3)]) with 4 tensors - // This creates a disconnected graph: {A,B} and {C,D} - // Each component contracts within itself, then outer product combines them +fn test_contract_components_then_outer_product() { let i = Index::::new_dyn(2); let j = Index::::new_dyn(3); @@ -193,13 +217,9 @@ fn test_contract_specified_disconnected_outer_product() { let j_copy = Index::new(j.id, j.dim); let d = TensorDynLen::from_dense(vec![j_copy.clone()], 
vec![8.0, 9.0, 10.0]).unwrap(); - // Disconnected pairs: (0,1) and (2,3) - // A and B contract i, C and D contract j, then outer product combines results - let result = ::contract( - &[&a, &b, &c, &d], - AllowedPairs::Specified(&[(0, 1), (2, 3)]), - ) - .unwrap(); + let left = ::contract(&[&a, &b]).unwrap(); + let right = ::contract(&[&c, &d]).unwrap(); + let result = outer_product(&left, &right).unwrap(); // A(i) * B(i) contracts to scalar (dim 0) // C(j) * D(j) contracts to scalar (dim 0) @@ -264,8 +284,6 @@ fn test_onehot_empty() { #[test] fn test_onehot_contraction() { - use tensor4all_core::AllowedPairs; - // Create a tensor A(i,j) and a onehot V(i) let i = Index::new_dyn(3); let j = Index::new_dyn(4); @@ -279,7 +297,7 @@ fn test_onehot_contraction() { let v = TensorDynLen::onehot(&[(i.clone(), 1)]).unwrap(); // Contract: V(i) * A(i,j) = A[1,:] - let result = ::contract(&[&v, &a], AllowedPairs::All).unwrap(); + let result = ::contract(&[&v, &a]).unwrap(); assert_eq!(result.dims(), vec![4]); let data = result.to_vec::().unwrap(); // Row i=1 of the column-major 3×4 matrix: [1, 4, 7, 10] diff --git a/crates/tensor4all-core/tests/tensor_unfuse.rs b/crates/tensor4all-core/tests/tensor_unfuse.rs index eb0a7f28..2de71cdb 100644 --- a/crates/tensor4all-core/tests/tensor_unfuse.rs +++ b/crates/tensor4all-core/tests/tensor_unfuse.rs @@ -1,6 +1,6 @@ use num_complex::Complex64; use tensor4all_core::{ - DynIndex, IndexLike, LinearizationOrder, TensorDynLen, TensorIndex, TensorLike, + DynIndex, IndexLike, LinearizationOrder, TensorContractionLike, TensorDynLen, TensorIndex, }; #[test] @@ -92,7 +92,7 @@ fn fuse_indices_trait_dispatch_on_tensordynlen_uses_old_index_order() { let data: Vec = (0..12).map(|x| x as f64).collect(); let tensor = TensorDynLen::from_dense(vec![i.clone(), j.clone(), k.clone()], data).unwrap(); - let fused_tensor = ::fuse_indices( + let fused_tensor = ::fuse_indices( &tensor, &[j.clone(), i.clone()], fused.clone(), diff --git 
a/crates/tensor4all-hdf5/examples/inspect_mps_inputs.rs b/crates/tensor4all-hdf5/examples/inspect_mps_inputs.rs new file mode 100644 index 00000000..9ebc521b --- /dev/null +++ b/crates/tensor4all-hdf5/examples/inspect_mps_inputs.rs @@ -0,0 +1 @@ +include!("../../../benchmarks/rust/inspect_hdf5_mps_inputs.rs"); diff --git a/crates/tensor4all-hdf5/src/lib.rs b/crates/tensor4all-hdf5/src/lib.rs index dd2966d7..a2cbe06b 100644 --- a/crates/tensor4all-hdf5/src/lib.rs +++ b/crates/tensor4all-hdf5/src/lib.rs @@ -153,6 +153,43 @@ pub fn save_itensor(filepath: &str, name: &str, tensor: &TensorDynLen) -> Result itensor::write_itensor(&group, tensor) } +/// Append a [`TensorDynLen`] as an ITensors.jl-compatible `ITensor` to an HDF5 file. +/// +/// Opens `filepath` read/write if it exists, or creates it otherwise, then +/// writes the tensor under `name`. This is useful for files containing multiple +/// tensor objects. The target group must not already exist. +/// +/// # Errors +/// +/// Returns an error if the file cannot be opened for appending, if `name` +/// already exists, or if the tensor storage type is unsupported. 
+/// +/// # Examples +/// +/// ``` +/// use tensor4all_core::{DynIndex, TensorDynLen}; +/// use tensor4all_hdf5::{append_itensor, load_itensor}; +/// +/// # fn main() -> anyhow::Result<()> { +/// let dir = tempfile::tempdir()?; +/// let path = dir.path().join("append_itensor.h5"); +/// let path = path.to_str().unwrap(); +/// let a = TensorDynLen::from_dense(vec![DynIndex::new_dyn(2)], vec![1.0, 2.0])?; +/// let b = TensorDynLen::from_dense(vec![DynIndex::new_dyn(2)], vec![3.0, 4.0])?; +/// +/// append_itensor(path, "a", &a)?; +/// append_itensor(path, "b", &b)?; +/// assert_eq!(load_itensor(path, "a")?.to_vec::<f64>()?, vec![1.0, 2.0]); +/// assert_eq!(load_itensor(path, "b")?.to_vec::<f64>()?, vec![3.0, 4.0]); +/// # Ok(()) +/// # } +/// ``` +pub fn append_itensor(filepath: &str, name: &str, tensor: &TensorDynLen) -> Result<()> { + let file = File::append(filepath)?; + let group = file.create_group(name)?; + itensor::write_itensor(&group, tensor) +} + /// Load a [`TensorDynLen`] from an ITensors.jl-compatible `ITensor` in an HDF5 file. /// /// Opens the file in read-only mode and reads the tensor from the group named @@ -264,6 +301,47 @@ pub fn save_mps(filepath: &str, name: &str, tt: &TensorTrain) -> Result<()> { mps::write_mps(&group, tt) } +/// Append a [`TensorTrain`] as an ITensorMPS.jl-compatible `MPS` to an HDF5 file. +/// +/// Opens `filepath` read/write if it exists, or creates it otherwise, then +/// writes the MPS under `name`. This keeps the same `MPS` v1 schema as +/// [`save_mps`] while allowing multiple named MPS objects in a single file. +/// The target group must not already exist. +/// +/// # Errors +/// +/// Returns an error if the file cannot be opened for appending, if `name` +/// already exists, or if any site tensor uses unsupported storage. 
+/// +/// # Examples +/// +/// ``` +/// use tensor4all_core::{DynIndex, TensorDynLen}; +/// use tensor4all_hdf5::{append_mps, load_mps}; +/// use tensor4all_itensorlike::TensorTrain; +/// +/// # fn main() -> anyhow::Result<()> { +/// let dir = tempfile::tempdir()?; +/// let path = dir.path().join("append_mps.h5"); +/// let path = path.to_str().unwrap(); +/// let s0 = DynIndex::new_dyn(2); +/// let s1 = DynIndex::new_dyn(2); +/// let a = TensorTrain::new(vec![TensorDynLen::from_dense(vec![s0], vec![1.0, 2.0])?])?; +/// let b = TensorTrain::new(vec![TensorDynLen::from_dense(vec![s1], vec![3.0, 4.0])?])?; +/// +/// append_mps(path, "a", &a)?; +/// append_mps(path, "b", &b)?; +/// assert_eq!(load_mps(path, "a")?.len(), 1); +/// assert_eq!(load_mps(path, "b")?.siteinds()[0][0].size(), 2); +/// # Ok(()) +/// # } +/// ``` +pub fn append_mps(filepath: &str, name: &str, tt: &TensorTrain) -> Result<()> { + let file = File::append(filepath)?; + let group = file.create_group(name)?; + mps::write_mps(&group, tt) +} + /// Load a [`TensorTrain`] from an ITensorMPS.jl-compatible `MPS` in an HDF5 file. 
/// /// Opens the file in read-only mode and reads the MPS from the group named diff --git a/crates/tensor4all-hdf5/tests/test_hdf5.rs b/crates/tensor4all-hdf5/tests/test_hdf5.rs index 40400e0c..842bbdd8 100644 --- a/crates/tensor4all-hdf5/tests/test_hdf5.rs +++ b/crates/tensor4all-hdf5/tests/test_hdf5.rs @@ -3,7 +3,7 @@ use hdf5_metno::File; use num_complex::Complex64; use tensor4all_core::index::{DynId, DynIndex, Index, TagSet}; use tensor4all_core::TensorDynLen; -use tensor4all_hdf5::{load_itensor, load_mps, save_itensor, save_mps}; +use tensor4all_hdf5::{append_itensor, append_mps, load_itensor, load_mps, save_itensor, save_mps}; use tensor4all_itensorlike::{CanonicalForm, TensorTrain}; fn temp_path(name: &str) -> String { @@ -112,6 +112,34 @@ fn test_itensor_c64_roundtrip() { std::fs::remove_file(&path).ok(); } +#[test] +fn test_append_itensor_keeps_multiple_named_objects() { + let path = temp_path("append_itensor_multiple"); + std::fs::remove_file(&path).ok(); + let first = TensorDynLen::from_dense(vec![DynIndex::new_dyn(2)], vec![1.0, 2.0]).unwrap(); + let second = TensorDynLen::from_dense(vec![DynIndex::new_dyn(2)], vec![3.0, 4.0]).unwrap(); + + append_itensor(&path, "first", &first).unwrap(); + append_itensor(&path, "second", &second).unwrap(); + + assert_eq!( + load_itensor(&path, "first") + .unwrap() + .to_vec::() + .unwrap(), + vec![1.0, 2.0] + ); + assert_eq!( + load_itensor(&path, "second") + .unwrap() + .to_vec::() + .unwrap(), + vec![3.0, 4.0] + ); + + std::fs::remove_file(&path).ok(); +} + #[test] fn test_itensor_c64_storage_dataset_uses_column_major_linearization() { let path = temp_path("itensor_c64_storage_column_major"); @@ -240,6 +268,65 @@ fn test_mps_roundtrip() { std::fs::remove_file(&path).ok(); } +#[test] +fn test_mps_load_preserves_site_tensor_index_order() { + let path = temp_path("mps_preserve_index_order"); + + let site0 = Index::new_with_size(DynId(10), 2); + let link = Index::new_with_size(DynId(11), 3); + let site1 = 
Index::new_with_size(DynId(12), 2); + let left = TensorDynLen::from_dense( + vec![link.clone(), site0.clone()], + vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0], + ) + .unwrap(); + let right = TensorDynLen::from_dense( + vec![site1.clone(), link.clone()], + vec![6.0, 7.0, 8.0, 9.0, 10.0, 11.0], + ) + .unwrap(); + let mps = TensorTrain::new(vec![left, right]).unwrap(); + + save_mps(&path, "mps", &mps).unwrap(); + let loaded = load_mps(&path, "mps").unwrap(); + + assert_eq!(loaded.tensor(0).unwrap().indices(), &[link.clone(), site0]); + assert_eq!(loaded.tensor(1).unwrap().indices(), &[site1, link]); + + std::fs::remove_file(&path).ok(); +} + +#[test] +fn test_append_mps_keeps_multiple_named_objects() { + let path = temp_path("append_mps_multiple"); + std::fs::remove_file(&path).ok(); + let first = TensorTrain::new(vec![TensorDynLen::from_dense( + vec![DynIndex::new_dyn(2)], + vec![1.0, 2.0], + ) + .unwrap()]) + .unwrap(); + let second = TensorTrain::new(vec![TensorDynLen::from_dense( + vec![DynIndex::new_dyn(2)], + vec![3.0, 4.0], + ) + .unwrap()]) + .unwrap(); + + append_mps(&path, "first", &first).unwrap(); + append_mps(&path, "second", &second).unwrap(); + + assert_eq!(load_mps(&path, "first").unwrap().siteinds()[0][0].size(), 2); + assert_eq!( + load_mps(&path, "second").unwrap().tensors()[0] + .to_vec::() + .unwrap(), + vec![3.0, 4.0] + ); + + std::fs::remove_file(&path).ok(); +} + #[test] fn test_mps_ortho_roundtrip() { let path = temp_path("mps_ortho"); diff --git a/crates/tensor4all-itensorlike/Cargo.toml b/crates/tensor4all-itensorlike/Cargo.toml index e440d670..db8ff57c 100644 --- a/crates/tensor4all-itensorlike/Cargo.toml +++ b/crates/tensor4all-itensorlike/Cargo.toml @@ -32,4 +32,5 @@ anyhow.workspace = true [dev-dependencies] approx.workspace = true +tenferro.workspace = true tensor4all-tensorbackend = { path = "../tensor4all-tensorbackend", default-features = false, features = ["tenferro-cpu-faer"] } diff --git 
a/crates/tensor4all-itensorlike/examples/benchmark_tt_ops.rs b/crates/tensor4all-itensorlike/examples/benchmark_tt_ops.rs new file mode 100644 index 00000000..d8b558f8 --- /dev/null +++ b/crates/tensor4all-itensorlike/examples/benchmark_tt_ops.rs @@ -0,0 +1 @@ +include!("../../../benchmarks/rust/benchmark_tt_ops.rs"); diff --git a/crates/tensor4all-itensorlike/examples/test_gmres_block_mpo.rs b/crates/tensor4all-itensorlike/examples/test_gmres_block_mpo.rs index 4f555df6..69107d63 100644 --- a/crates/tensor4all-itensorlike/examples/test_gmres_block_mpo.rs +++ b/crates/tensor4all-itensorlike/examples/test_gmres_block_mpo.rs @@ -17,7 +17,7 @@ use tensor4all_core::block_tensor::BlockTensor; use tensor4all_core::krylov::{ gmres_with_truncation, restart_gmres_with_truncation, GmresOptions, RestartGmresOptions, }; -use tensor4all_core::TensorLike; +use tensor4all_core::TensorVectorSpace; use tensor4all_core::{AnyScalar, DynIndex, IndexLike, TensorDynLen, TensorIndex}; use tensor4all_itensorlike::{ContractOptions, TensorTrain, TruncateOptions}; diff --git a/crates/tensor4all-itensorlike/examples/test_gmres_block_mps.rs b/crates/tensor4all-itensorlike/examples/test_gmres_block_mps.rs index 8f32448c..17dba49c 100644 --- a/crates/tensor4all-itensorlike/examples/test_gmres_block_mps.rs +++ b/crates/tensor4all-itensorlike/examples/test_gmres_block_mps.rs @@ -14,7 +14,7 @@ use tensor4all_core::block_tensor::BlockTensor; use tensor4all_core::krylov::{ gmres_with_truncation, restart_gmres_with_truncation, GmresOptions, RestartGmresOptions, }; -use tensor4all_core::TensorLike; +use tensor4all_core::TensorVectorSpace; use tensor4all_core::{AnyScalar, DynIndex, IndexLike, TensorDynLen, TensorIndex}; use tensor4all_itensorlike::{ContractOptions, TensorTrain, TruncateOptions}; diff --git a/crates/tensor4all-itensorlike/src/contract.rs b/crates/tensor4all-itensorlike/src/contract.rs index 6c7518aa..d7e8d478 100644 --- a/crates/tensor4all-itensorlike/src/contract.rs +++ 
b/crates/tensor4all-itensorlike/src/contract.rs @@ -95,7 +95,7 @@ pub fn contract( let result_inner = if matches!(options.method(), ContractMethod::Zipup) { a.as_treetn() - .contract_zipup_tree_accumulated( + .contract_zipup_with( b.as_treetn(), &center, CanonicalForm::Unitary, @@ -137,6 +137,19 @@ impl TensorTrain { pub fn contract(&self, other: &Self, options: &ContractOptions) -> Result { contract(self, other, options) } + + /// Contract two tensor trains with explicit contraction options. + /// + /// This is an alias for [`TensorTrain::contract`]. It exists for callers + /// that use `contract_pair` to mean pairwise tensor-train contraction with + /// compression, rather than the dense tensor-level pair contraction. + /// + /// # Errors + /// Returns an error if the tensor trains cannot be contracted with the + /// requested options. + pub fn contract_pair(&self, other: &Self, options: &ContractOptions) -> Result { + self.contract(other, options) + } } #[cfg(test)] diff --git a/crates/tensor4all-itensorlike/src/contract/tests/mod.rs b/crates/tensor4all-itensorlike/src/contract/tests/mod.rs index 10cf721c..32c315bf 100644 --- a/crates/tensor4all-itensorlike/src/contract/tests/mod.rs +++ b/crates/tensor4all-itensorlike/src/contract/tests/mod.rs @@ -1,6 +1,6 @@ use super::*; use crate::TensorTrainError; -use tensor4all_core::{DynId, DynIndex, Index, TensorDynLen, TensorLike}; +use tensor4all_core::{DynId, DynIndex, Index, TensorContractionLike, TensorDynLen}; /// Helper to create a simple tensor for testing fn make_tensor(indices: Vec) -> TensorDynLen { @@ -23,7 +23,7 @@ fn assert_matches_naive(tt1: &TensorTrain, tt2: &TensorTrain, result: &TensorTra let naive_result = tt1 .to_dense() .unwrap() - .contract(&tt2.to_dense().unwrap()) + .contract_pair(&tt2.to_dense().unwrap()) .unwrap(); let result_dense = result.to_dense().unwrap(); assert!( @@ -181,7 +181,7 @@ fn test_contract_zipup_matches_naive_for_zero_masked_inputs() { .unwrap(); let tt2 = TensorTrain::new(vec![ - 
make_tensor(vec![s1.clone(), l12.clone()]), + make_tensor(vec![s0.clone(), l12.clone()]), TensorDynLen::from_dense( vec![l12.clone(), s2.clone()], vec![1.0, 0.0, 3.0, 0.0, 5.0, 0.0], @@ -191,7 +191,7 @@ fn test_contract_zipup_matches_naive_for_zero_masked_inputs() { .unwrap(); let result = contract(&tt1, &tt2, &ContractOptions::zipup()).unwrap(); - assert_eq!(result.len(), 2); + assert_eq!(result.len(), 1); assert_matches_naive(&tt1, &tt2, &result); } @@ -214,7 +214,7 @@ fn test_treetn_zipup_matches_naive_for_zero_masked_inputs() { .unwrap(); let tt2 = TensorTrain::new(vec![ - make_tensor(vec![s1.clone(), l12.clone()]), + make_tensor(vec![s0.clone(), l12.clone()]), TensorDynLen::from_dense( vec![l12.clone(), s2.clone()], vec![1.0, 0.0, 3.0, 0.0, 5.0, 0.0], @@ -232,7 +232,7 @@ fn test_treetn_zipup_matches_naive_for_zero_masked_inputs() { ) .unwrap(); let result = TensorTrain::from_inner(result_inner, Some(CanonicalForm::Unitary)).unwrap(); - assert_eq!(result.len(), 2); + assert_eq!(result.len(), 1); assert_matches_naive(&tt1, &tt2, &result); } @@ -269,7 +269,26 @@ fn test_contract_fit_two_sites() { let options = ContractOptions::fit().with_max_rank(10).with_nhalfsweeps(4); let result = contract(&tt1, &tt2, &options).unwrap(); - assert_eq!(result.len(), 1); + assert_matches_naive(&tt1, &tt2, &result); +} + +#[test] +fn test_contract_fit_accepts_non_chain_ordered_site_tensors() { + let s0 = idx(1080, 2); + let s1 = idx(1081, 2); + let l01_a = idx(1082, 3); + let l01_b = idx(1083, 3); + + let t1_0 = make_tensor(vec![l01_a.clone(), s0.clone()]); + let t1_1 = make_tensor(vec![s1.clone(), l01_a.clone()]); + let tt1 = TensorTrain::new(vec![t1_0, t1_1]).unwrap(); + + let t2_0 = make_tensor(vec![l01_b.clone(), s0.clone()]); + let t2_1 = make_tensor(vec![s1.clone(), l01_b.clone()]); + let tt2 = TensorTrain::new(vec![t2_0, t2_1]).unwrap(); + + let options = ContractOptions::fit().with_max_rank(10).with_nhalfsweeps(4); + let result = contract(&tt1, &tt2, &options).unwrap(); 
assert_matches_naive(&tt1, &tt2, &result); } @@ -310,7 +329,7 @@ fn test_contract_method_uses_tt_contract() { let options = ContractOptions::zipup(); let result_free = contract(&tt1, &tt2, &options).unwrap(); - let result_method = tt1.contract(&tt2, &options).unwrap(); + let result_method = tt1.contract_pair(&tt2, &options).unwrap(); // Both should produce TTs with the same length assert_eq!(result_free.len(), result_method.len()); @@ -376,7 +395,7 @@ fn test_contract_zipup_with_truncation() { let naive_result = tt1 .to_dense() .unwrap() - .contract(&tt2.to_dense().unwrap()) + .contract_pair(&tt2.to_dense().unwrap()) .unwrap(); let naive_data = naive_result.to_vec::().unwrap(); let result_data = result.to_dense().unwrap().to_vec::().unwrap(); diff --git a/crates/tensor4all-itensorlike/src/linsolve.rs b/crates/tensor4all-itensorlike/src/linsolve.rs index 4aea233b..73b8038a 100644 --- a/crates/tensor4all-itensorlike/src/linsolve.rs +++ b/crates/tensor4all-itensorlike/src/linsolve.rs @@ -54,24 +54,24 @@ pub fn linsolve( validate_svd_truncation_options(options.max_rank(), options.svd_policy())?; - if !options.krylov_tol().is_finite() || options.krylov_tol() <= 0.0 { + if !options.gmres_tol().is_finite() || options.gmres_tol() <= 0.0 { return Err(TensorTrainError::OperationError { message: format!( - "krylov_tol must be finite and > 0, got {}", - options.krylov_tol() + "gmres_tol must be finite and > 0, got {}", + options.gmres_tol() ), }); } - if options.krylov_maxiter() == 0 { + if options.gmres_max_restarts() == 0 { return Err(TensorTrainError::OperationError { - message: "krylov_maxiter must be >= 1".to_string(), + message: "gmres_max_restarts must be >= 1".to_string(), }); } - if options.krylov_dim() == 0 { + if options.gmres_restart_dim() == 0 { return Err(TensorTrainError::OperationError { - message: "krylov_dim must be >= 1".to_string(), + message: "gmres_restart_dim must be >= 1".to_string(), }); } @@ -86,12 +86,13 @@ pub fn linsolve( // Convert LinsolveOptions 
→ treetn::LinsolveOptions let nfullsweeps = options.nhalfsweeps() / 2; + let (a0, a1) = options.coefficients(); let treetn_options = tensor4all_treetn::LinsolveOptions::new(nfullsweeps) .with_truncation(TruncationOptions::new()) - .with_krylov_tol(options.krylov_tol()) - .with_krylov_maxiter(options.krylov_maxiter()) - .with_krylov_dim(options.krylov_dim()) - .with_coefficients(options.coefficients().0, options.coefficients().1); + .with_gmres_tol(options.gmres_tol()) + .with_gmres_max_restarts(options.gmres_max_restarts()) + .with_gmres_restart_dim(options.gmres_restart_dim()) + .with_coefficients(a0, a1); let treetn_options = if let Some(policy) = options.svd_policy() { treetn_options.with_svd_policy(policy) @@ -110,6 +111,7 @@ pub fn linsolve( } else { treetn_options }; + let treetn_options = treetn_options.with_residual_check(options.check_residual()); // Use the last site as the sweep center let center = init.len() - 1; diff --git a/crates/tensor4all-itensorlike/src/options.rs b/crates/tensor4all-itensorlike/src/options.rs index 96207c60..7d0ad238 100644 --- a/crates/tensor4all-itensorlike/src/options.rs +++ b/crates/tensor4all-itensorlike/src/options.rs @@ -2,7 +2,7 @@ use std::ops::Range; -use tensor4all_core::SvdTruncationPolicy; +use tensor4all_core::{AnyScalar, SvdTruncationPolicy}; use crate::error::{Result, TensorTrainError}; @@ -310,7 +310,7 @@ impl ContractOptions { /// let opts = LinsolveOptions::new(5) /// .with_svd_policy(SvdTruncationPolicy::new(1e-10)) /// .with_max_rank(64) -/// .with_krylov_tol(1e-8) +/// .with_gmres_tol(1e-8) /// .with_coefficients(1.0, -1.0); /// /// assert_eq!(opts.max_rank(), Some(64)); @@ -322,12 +322,13 @@ pub struct LinsolveOptions { nhalfsweeps: usize, max_rank: Option, svd_policy: Option, - krylov_tol: f64, - krylov_maxiter: usize, - krylov_dim: usize, - a0: f64, - a1: f64, + gmres_tol: f64, + gmres_max_restarts: usize, + gmres_restart_dim: usize, + a0: AnyScalar, + a1: AnyScalar, convergence_tol: Option, + 
check_residual: bool, } impl Default for LinsolveOptions { @@ -336,12 +337,13 @@ impl Default for LinsolveOptions { nhalfsweeps: 10, max_rank: None, svd_policy: None, - krylov_tol: 1e-10, - krylov_maxiter: 100, - krylov_dim: 30, - a0: 0.0, - a1: 1.0, + gmres_tol: 1e-10, + gmres_max_restarts: 100, + gmres_restart_dim: 30, + a0: AnyScalar::new_real(0.0), + a1: AnyScalar::new_real(1.0), convergence_tol: None, + check_residual: true, } } } @@ -380,27 +382,34 @@ impl LinsolveOptions { } /// Set GMRES tolerance. - pub fn with_krylov_tol(mut self, tol: f64) -> Self { - self.krylov_tol = tol; + pub fn with_gmres_tol(mut self, tol: f64) -> Self { + self.gmres_tol = tol; self } - /// Set maximum GMRES iterations per local solve. - pub fn with_krylov_maxiter(mut self, maxiter: usize) -> Self { - self.krylov_maxiter = maxiter; + /// Set maximum number of GMRES restart cycles per local solve. + /// + /// This matches KrylovKit's `maxiter` convention. The maximum number of + /// operator expansion steps is roughly `gmres_max_restarts * gmres_restart_dim`. + pub fn with_gmres_max_restarts(mut self, max_restarts: usize) -> Self { + self.gmres_max_restarts = max_restarts; self } - /// Set Krylov subspace dimension (restart parameter). - pub fn with_krylov_dim(mut self, dim: usize) -> Self { - self.krylov_dim = dim; + /// Set GMRES restart cycle length. + pub fn with_gmres_restart_dim(mut self, dim: usize) -> Self { + self.gmres_restart_dim = dim; self } /// Set coefficients `a₀` and `a₁` in `(a₀ + a₁ * A) * x = b`. - pub fn with_coefficients(mut self, a0: f64, a1: f64) -> Self { - self.a0 = a0; - self.a1 = a1; + pub fn with_coefficients(mut self, a0: A0, a1: A1) -> Self + where + A0: Into, + A1: Into, + { + self.a0 = a0.into(); + self.a1 = a1.into(); self } @@ -410,6 +419,12 @@ impl LinsolveOptions { self } + /// Set whether to compute the final true residual after the sweep. 
+ pub fn with_residual_check(mut self, check_residual: bool) -> Self { + self.check_residual = check_residual; + self + } + /// Get the maximum retained bond dimension. #[inline] pub fn max_rank(&self) -> Option { @@ -430,26 +445,26 @@ impl LinsolveOptions { /// Get GMRES tolerance. #[inline] - pub fn krylov_tol(&self) -> f64 { - self.krylov_tol + pub fn gmres_tol(&self) -> f64 { + self.gmres_tol } - /// Get maximum GMRES iterations per local solve. + /// Get maximum number of GMRES restart cycles per local solve. #[inline] - pub fn krylov_maxiter(&self) -> usize { - self.krylov_maxiter + pub fn gmres_max_restarts(&self) -> usize { + self.gmres_max_restarts } - /// Get Krylov subspace dimension. + /// Get GMRES restart cycle length. #[inline] - pub fn krylov_dim(&self) -> usize { - self.krylov_dim + pub fn gmres_restart_dim(&self) -> usize { + self.gmres_restart_dim } /// Get coefficients `(a0, a1)`. #[inline] - pub fn coefficients(&self) -> (f64, f64) { - (self.a0, self.a1) + pub fn coefficients(&self) -> (AnyScalar, AnyScalar) { + (self.a0.clone(), self.a1.clone()) } /// Get convergence tolerance. @@ -457,6 +472,12 @@ impl LinsolveOptions { pub fn convergence_tol(&self) -> Option { self.convergence_tol } + + /// Get whether the final true residual is computed after the sweep. 
+ #[inline] + pub fn check_residual(&self) -> bool { + self.check_residual + } } #[cfg(test)] diff --git a/crates/tensor4all-itensorlike/src/options/tests/mod.rs b/crates/tensor4all-itensorlike/src/options/tests/mod.rs index 4f16ee7e..536849f6 100644 --- a/crates/tensor4all-itensorlike/src/options/tests/mod.rs +++ b/crates/tensor4all-itensorlike/src/options/tests/mod.rs @@ -1,4 +1,6 @@ use super::*; +use num_complex::Complex64; +use tensor4all_core::AnyScalar; use tensor4all_core::SvdTruncationPolicy; #[test] @@ -105,11 +107,15 @@ fn test_contract_method_default() { fn test_linsolve_options_default() { let opts = LinsolveOptions::default(); assert_eq!(opts.nhalfsweeps(), 10); - assert_eq!(opts.coefficients(), (0.0, 1.0)); - assert_eq!(opts.krylov_tol(), 1e-10); - assert_eq!(opts.krylov_maxiter(), 100); - assert_eq!(opts.krylov_dim(), 30); + assert_eq!( + opts.coefficients(), + (AnyScalar::new_real(0.0), AnyScalar::new_real(1.0)) + ); + assert_eq!(opts.gmres_tol(), 1e-10); + assert_eq!(opts.gmres_max_restarts(), 100); + assert_eq!(opts.gmres_restart_dim(), 30); assert_eq!(opts.convergence_tol(), None); + assert!(opts.check_residual()); assert_eq!(opts.svd_policy(), None); assert_eq!(opts.max_rank(), None); } @@ -129,20 +135,39 @@ fn test_linsolve_options_builder() { .with_nsweeps(10) .with_svd_policy(policy) .with_max_rank(100) - .with_krylov_tol(1e-8) - .with_krylov_maxiter(200) - .with_krylov_dim(50) + .with_gmres_tol(1e-8) + .with_gmres_max_restarts(200) + .with_gmres_restart_dim(50) .with_coefficients(1.0, -1.0) - .with_convergence_tol(1e-6); + .with_convergence_tol(1e-6) + .with_residual_check(false); assert_eq!(opts.nhalfsweeps(), 20); assert_eq!(opts.svd_policy(), Some(policy)); assert_eq!(opts.max_rank(), Some(100)); - assert_eq!(opts.krylov_tol(), 1e-8); - assert_eq!(opts.krylov_maxiter(), 200); - assert_eq!(opts.krylov_dim(), 50); - assert_eq!(opts.coefficients(), (1.0, -1.0)); + assert_eq!(opts.gmres_tol(), 1e-8); + assert_eq!(opts.gmres_max_restarts(), 
200); + assert_eq!(opts.gmres_restart_dim(), 50); + assert_eq!( + opts.coefficients(), + (AnyScalar::new_real(1.0), AnyScalar::new_real(-1.0)) + ); assert_eq!(opts.convergence_tol(), Some(1e-6)); + assert!(!opts.check_residual()); +} + +#[test] +fn test_linsolve_options_complex_coefficients() { + let opts = LinsolveOptions::default() + .with_coefficients(Complex64::new(0.25, -0.5), AnyScalar::new_complex(1.5, 2.0)); + + assert_eq!( + opts.coefficients(), + ( + AnyScalar::new_complex(0.25, -0.5), + AnyScalar::new_complex(1.5, 2.0) + ) + ); } #[test] diff --git a/crates/tensor4all-itensorlike/src/tensortrain.rs b/crates/tensor4all-itensorlike/src/tensortrain.rs index 2d9aa490..f131f609 100644 --- a/crates/tensor4all-itensorlike/src/tensortrain.rs +++ b/crates/tensor4all-itensorlike/src/tensortrain.rs @@ -7,18 +7,72 @@ //! `TreeTN` where node names are site indices (0, 1, 2, ...). use num_complex::Complex64; +use std::env; use std::ops::Range; -use tensor4all_core::{common_inds, hascommoninds, DynIndex, IndexLike}; +use std::sync::OnceLock; +use std::time::{Duration, Instant}; use tensor4all_core::{ - AllowedPairs, AnyScalar, Canonical, CommonScalar, DirectSumResult, FactorizeAlg, - FactorizeError, FactorizeOptions, FactorizeResult, LinearizationOrder, TensorDynLen, - TensorElement, TensorIndex, TensorLike, + common_inds, contract_pair, contract_pair_with_operand_options, hascommoninds, DynIndex, + IndexLike, PairwiseContractionOptions, +}; +use tensor4all_core::{ + AnyScalar, Canonical, CommonScalar, DirectSumResult, FactorizeAlg, FactorizeError, + FactorizeOptions, FactorizeResult, LinearizationOrder, TensorConstructionLike, + TensorContractionLike, TensorDynLen, TensorElement, TensorFactorizationLike, TensorIndex, + TensorVectorSpace, }; use tensor4all_treetn::{CanonicalizationOptions, TreeTN, TruncationOptions}; use crate::error::{Result, TensorTrainError}; use crate::options::{validate_svd_truncation_options, CanonicalForm, TruncateOptions}; +#[derive(Debug, 
Default)] +struct TensorTrainInnerProfile { + sim_internal_inds: Duration, + node_lookup: Duration, + right_tensor_clone: Duration, + conj: Duration, + contract: Duration, + final_dims: Duration, + sum: Duration, +} + +fn tensortrain_inner_profile_enabled() -> bool { + static ENABLED: OnceLock = OnceLock::new(); + *ENABLED.get_or_init(|| env::var("T4A_PROFILE_TT_INNER").is_ok()) +} + +fn profile_tt_inner_section(enabled: bool, slot: &mut Duration, f: impl FnOnce() -> T) -> T { + if !enabled { + return f(); + } + let started = Instant::now(); + let result = f(); + *slot += started.elapsed(); + result +} + +fn print_tt_inner_profile(profile: &TensorTrainInnerProfile, length: usize) { + let total = profile.sim_internal_inds + + profile.node_lookup + + profile.right_tensor_clone + + profile.conj + + profile.contract + + profile.final_dims + + profile.sum; + eprintln!( + "tt_inner_profile,L={length},total_ms={:.6},sim_internal_inds_ms={:.6},node_lookup_ms={:.6},right_tensor_clone_ms={:.6},conj_ms={:.6},contract_ms={:.6},final_dims_ms={:.6},sum_ms={:.6}", + total.as_secs_f64() * 1.0e3, + profile.sim_internal_inds.as_secs_f64() * 1.0e3, + profile.node_lookup.as_secs_f64() * 1.0e3, + profile.right_tensor_clone.as_secs_f64() * 1.0e3, + profile.conj.as_secs_f64() * 1.0e3, + profile.contract.as_secs_f64() * 1.0e3, + profile.final_dims.as_secs_f64() * 1.0e3, + profile.sum.as_secs_f64() * 1.0e3, + ); +} + /// Tensor Train with orthogonality tracking. /// /// This type represents a tensor train as a sequence of tensors with tracked @@ -176,11 +230,10 @@ impl TensorTrain { } })?; - let mut tt = Self { + let tt = Self { treetn, canonical_form: None, }; - tt.normalize_site_tensor_orders()?; Ok(tt) } @@ -239,12 +292,78 @@ impl TensorTrain { })?; } } - if tt.has_simple_linear_links() { - tt.normalize_site_tensor_orders()?; - } Ok(tt) } + /// Create a tensor train from a linear-chain [`TreeTN`]. 
+ /// + /// The input tree must use `usize` node names and represent a tensor-train + /// chain. Node names are renumbered to `0..len` when necessary. Site tensor + /// index order is preserved. + /// + /// # Errors + /// + /// Returns an error if the tree cannot be interpreted as a valid tensor + /// train, for example because adjacent site tensors have incompatible + /// shared indices. + /// + /// # Examples + /// + /// ``` + /// use tensor4all_core::{DynIndex, TensorDynLen}; + /// use tensor4all_itensorlike::TensorTrain; + /// use tensor4all_treetn::TreeTN; + /// + /// # fn main() -> anyhow::Result<()> { + /// let site0 = DynIndex::new_dyn(2); + /// let link = DynIndex::new_bond(1)?; + /// let site1 = DynIndex::new_dyn(2); + /// let t0 = TensorDynLen::from_dense(vec![site0, link.clone()], vec![1.0, 0.0])?; + /// let t1 = TensorDynLen::from_dense(vec![link, site1], vec![2.0, 0.0])?; + /// let tree = TreeTN::from_tensors(vec![t0, t1], vec![0usize, 1usize])?; + /// + /// let tt = TensorTrain::from_treetn(tree)?; + /// assert_eq!(tt.len(), 2); + /// assert_eq!( + /// tt.siteinds() + /// .into_iter() + /// .map(|indices| indices.into_iter().map(|idx| idx.size()).collect::>()) + /// .collect::>(), + /// vec![vec![2], vec![2]] + /// ); + /// # Ok(()) + /// # } + /// ``` + pub fn from_treetn(treetn: TreeTN) -> Result { + Self::from_inner(treetn, None) + } + + /// Consume this tensor train and return its underlying [`TreeTN`]. + /// + /// Use this when a chain MPS must be passed to APIs that operate on general + /// tree tensor networks. The returned tree preserves the tensor and index + /// metadata stored in the tensor train. 
+ /// + /// # Examples + /// + /// ``` + /// use tensor4all_core::{DynIndex, TensorDynLen}; + /// use tensor4all_itensorlike::TensorTrain; + /// + /// # fn main() -> anyhow::Result<()> { + /// let site = DynIndex::new_dyn(2); + /// let tensor = TensorDynLen::from_dense(vec![site], vec![1.0, 2.0])?; + /// let tt = TensorTrain::new(vec![tensor])?; + /// + /// let tree = tt.into_treetn(); + /// assert_eq!(tree.node_count(), 1); + /// # Ok(()) + /// # } + /// ``` + pub fn into_treetn(self) -> TreeTN { + self.treetn + } + /// Get a reference to the underlying TreeTN. /// /// This is a crate-internal accessor used by `contract` and `linsolve`. @@ -638,6 +757,69 @@ impl TensorTrain { left_ok && right_ok } + fn with_explicit_unit_links(&self) -> Result { + if self.len() <= 1 { + return Ok(self.clone()); + } + + let mut tensors = (0..self.len()) + .map(|site| self.tensor_checked(site).cloned()) + .collect::>>()?; + for site in 0..tensors.len() - 1 { + let common = common_inds(tensors[site].indices(), tensors[site + 1].indices()); + if common.len() > 1 { + let fused_dim = common.iter().try_fold(1usize, |acc, index| { + acc.checked_mul(index.dim()) + .ok_or_else(|| TensorTrainError::InvalidStructure { + message: "parallel link fusion would overflow index dimension" + .to_string(), + }) + })?; + let fused_link = DynIndex::new_dyn(fused_dim); + tensors[site] = tensors[site] + .fuse_indices(&common, fused_link.clone(), LinearizationOrder::ColumnMajor) + .map_err(|e| TensorTrainError::OperationError { + message: format!("failed to fuse parallel TT links: {e}"), + })?; + tensors[site + 1] = tensors[site + 1] + .fuse_indices(&common, fused_link, LinearizationOrder::ColumnMajor) + .map_err(|e| TensorTrainError::OperationError { + message: format!("failed to fuse parallel TT links: {e}"), + })?; + continue; + } + if common.len() == 1 { + continue; + } + + let link = DynIndex::new_dyn(1); + let left_link = + ::ones(std::slice::from_ref(&link)) + .map_err(|e| 
TensorTrainError::OperationError { + message: format!("failed to build implicit unit link tensor: {e}"), + })?; + tensors[site] = tensors[site].outer_product(&left_link).map_err(|e| { + TensorTrainError::OperationError { + message: format!("failed to attach implicit unit link: {e}"), + } + })?; + + let right_link = + ::ones(&[link]).map_err(|e| { + TensorTrainError::OperationError { + message: format!("failed to build implicit unit link tensor: {e}"), + } + })?; + tensors[site + 1] = tensors[site + 1].outer_product(&right_link).map_err(|e| { + TensorTrainError::OperationError { + message: format!("failed to attach implicit unit link: {e}"), + } + })?; + } + + Self::new(tensors) + } + fn normalize_site_tensor_order(&mut self, site: usize) -> Result<()> { if !self.can_normalize_site_tensor_order(site) { return Ok(()); @@ -838,8 +1020,7 @@ impl TensorTrain { /// assert!(tt.set_tensor_checked(2, tt.tensor(0).unwrap().clone()).is_err()); /// ``` pub fn set_tensor_checked(&mut self, site: usize, tensor: TensorDynLen) -> Result<()> { - self.set_tensor_raw(site, tensor) - .and_then(|()| self.normalize_site_tensor_order(site))?; + self.set_tensor_raw(site, tensor)?; // Invalidate orthogonality self.treetn .set_canonical_region(Vec::::new()) @@ -1048,7 +1229,12 @@ impl TensorTrain { // Sequential bra-ket contraction along the chain: O(N·D²·d). // TreeTN::inner() uses contract_naive which is O(d^N) and OOMs for large N. 
- let other_sim = other.treetn.sim_internal_inds(); + let profile_enabled = tensortrain_inner_profile_enabled(); + let mut profile = TensorTrainInnerProfile::default(); + let other_sim = + profile_tt_inner_section(profile_enabled, &mut profile.sim_internal_inds, || { + other.treetn.sim_internal_inds() + }); let node_idx = |ttn: &TreeTN, site: usize| { ttn.node_index(&site) @@ -1059,48 +1245,76 @@ impl TensorTrain { // Start with leftmost tensors - contract over site indices only let mut env = { - let a0_conj = self.tensor_checked(0)?.conj(); - let b0_node = node_idx(&other_sim, 0)?; - let b0 = other_sim - .tensor(b0_node) - .ok_or_else(|| TensorTrainError::InvalidStructure { - message: "missing tensor for site 0 in simulated right operand".to_string(), - })? - .clone(); - a0_conj - .contract(&b0) + let a0 = profile_tt_inner_section(profile_enabled, &mut profile.node_lookup, || { + self.tensor_checked(0) + })?; + let b0_node = + profile_tt_inner_section(profile_enabled, &mut profile.node_lookup, || { + node_idx(&other_sim, 0) + })?; + let b0 = + profile_tt_inner_section(profile_enabled, &mut profile.right_tensor_clone, || { + Ok::( + other_sim + .tensor(b0_node) + .ok_or_else(|| TensorTrainError::InvalidStructure { + message: "missing tensor for site 0 in simulated right operand" + .to_string(), + })? + .clone(), + ) + })?; + profile_tt_inner_section(profile_enabled, &mut profile.contract, || { + contract_pair_with_operand_options( + a0, + &b0, + PairwiseContractionOptions::new().with_lhs_conj(true), + ) .map_err(|err| TensorTrainError::OperationError { message: format!("failed to contract leftmost tensors: {err}"), - })? + }) + })? 
}; // Sweep through remaining sites for i in 1..self.len() { - let ai_conj = self.tensor_checked(i)?.conj(); - let bi_node = node_idx(&other_sim, i)?; - let bi = + let ai = profile_tt_inner_section(profile_enabled, &mut profile.node_lookup, || { + self.tensor_checked(i) + })?; + let bi_node = + profile_tt_inner_section(profile_enabled, &mut profile.node_lookup, || { + node_idx(&other_sim, i) + })?; + let bi = profile_tt_inner_section(profile_enabled, &mut profile.node_lookup, || { other_sim .tensor(bi_node) .ok_or_else(|| TensorTrainError::InvalidStructure { message: format!("missing tensor for site {i} in simulated right operand"), - })?; + }) + })?; // Contract: env * conj(A_i) (over self's link index) - env = env - .contract(&ai_conj) + env = profile_tt_inner_section(profile_enabled, &mut profile.contract, || { + contract_pair_with_operand_options( + &env, + ai, + PairwiseContractionOptions::new().with_rhs_conj(true), + ) .map_err(|err| TensorTrainError::OperationError { message: format!("failed to contract environment with site {i}: {err}"), - })?; + }) + })?; // Contract: result * B_i (over other's link index and site indices) - env = env - .contract(bi) - .map_err(|err| TensorTrainError::OperationError { + env = profile_tt_inner_section(profile_enabled, &mut profile.contract, || { + contract_pair(&env, bi).map_err(|err| TensorTrainError::OperationError { message: format!("failed to contract right operand at site {i}: {err}"), - })?; + }) + })?; } // Result should be a scalar (0-dimensional tensor) - let dims = env.dims(); + let dims = + profile_tt_inner_section(profile_enabled, &mut profile.final_dims, || env.dims()); let total_size: usize = if dims.is_empty() { 1 } else { @@ -1114,9 +1328,15 @@ impl TensorTrain { ), }); } - env.sum().map_err(|err| TensorTrainError::OperationError { - message: format!("failed to sum scalar inner-product tensor: {err}"), - }) + let result = profile_tt_inner_section(profile_enabled, &mut profile.sum, || { + 
env.sum().map_err(|err| TensorTrainError::OperationError { + message: format!("failed to sum scalar inner-product tensor: {err}"), + }) + }); + if profile_enabled { + print_tt_inner_profile(&profile, self.len()); + } + result } /// Compute the squared norm of the tensor train. @@ -1375,6 +1595,86 @@ impl TensorTrain { message: format!("TT addition failed: {}", e), })?; + Self::from_inner(result_inner, None)?.with_explicit_unit_links() + } + + /// Add two tensor trains after reindexing `other` to this tensor train's site space. + /// + /// This method is useful when two tensor trains represent the same logical + /// vector space but carry distinct site-index IDs, for example after + /// independent contractions. It pairs site indices site-by-site by + /// dimension, rewrites `other` to use `self`'s site-index IDs, then performs + /// strict tensor-train addition. + /// + /// # Arguments + /// + /// * `other` - The tensor train to reindex and add. It must have the same + /// chain length and compatible site dimensions as `self`. + /// + /// # Returns + /// + /// A tensor train representing `self + other`, with site indices matching + /// `self`. + /// + /// # Errors + /// + /// Returns an error if the two tensor trains have incompatible chain + /// topology, site counts, or site dimensions, or if the strict addition + /// fails after reindexing. 
+ /// + /// # Examples + /// + /// ``` + /// use tensor4all_core::{DynId, Index, TensorDynLen}; + /// use tensor4all_itensorlike::TensorTrain; + /// + /// fn one_site(id: u64, values: Vec) -> TensorTrain { + /// let site = Index::new_with_size(DynId(id), 2); + /// let tensor = TensorDynLen::from_dense(vec![site], values).unwrap(); + /// TensorTrain::new(vec![tensor]).unwrap() + /// } + /// + /// let lhs = one_site(0, vec![1.0, 2.0]); + /// let rhs = one_site(1, vec![3.0, 4.0]); + /// let sum = lhs.add_reindexed_like_self(&rhs).unwrap(); + /// + /// let dense = sum.to_dense().unwrap(); + /// assert_eq!(dense.to_vec::().unwrap(), vec![4.0, 6.0]); + /// assert_eq!(dense.indices()[0], lhs.siteinds()[0][0]); + /// ``` + pub fn add_reindexed_like_self(&self, other: &Self) -> Result { + if self.is_empty() && other.is_empty() { + return Ok(Self::default()); + } + + if self.is_empty() { + return Ok(other.clone()); + } + + if other.is_empty() { + return Ok(self.clone()); + } + + if self.len() != other.len() { + return Err(TensorTrainError::InvalidStructure { + message: format!( + "Tensor trains must have the same length for reindexed addition: {} vs {}", + self.len(), + other.len() + ), + }); + } + + let lhs = self.with_explicit_unit_links()?; + let rhs = other.with_explicit_unit_links()?; + + let result_inner = lhs + .treetn + .add_reindexed_like_self(&rhs.treetn) + .map_err(|e| TensorTrainError::InvalidStructure { + message: format!("TT reindexed addition failed: {}", e), + })?; + Self::from_inner(result_inner, None) } @@ -1497,7 +1797,7 @@ impl TensorIndex for TensorTrain { // TensorLike implementation for TensorTrain // ============================================================================ -impl TensorLike for TensorTrain { +impl TensorVectorSpace for TensorTrain { // ======================================================================== // GMRES-required methods (fully supported) // ======================================================================== @@ 
-1520,13 +1820,19 @@ impl TensorLike for TensorTrain { fn try_maxabs(&self) -> anyhow::Result { anyhow::bail!( - "TensorTrain does not support TensorLike::maxabs without explicit dense materialization; use TensorTrain::dense_maxabs() for small reference checks or norm-based residuals for long tensor trains" + "TensorTrain does not support TensorVectorSpace::maxabs without explicit dense materialization; use TensorTrain::dense_maxabs() for small reference checks or norm-based residuals for long tensor trains" ) } fn maxabs(&self) -> f64 { f64::NAN } +} + +impl TensorContractionLike for TensorTrain { + // ======================================================================== + // Tensor network operations + // ======================================================================== fn conj(&self) -> Self { let mut result = self.clone(); @@ -1539,37 +1845,8 @@ impl TensorLike for TensorTrain { result } - // ======================================================================== - // Methods not supported by TensorTrain - // ======================================================================== - - fn factorize( - &self, - _left_inds: &[Self::Index], - _options: &FactorizeOptions, - ) -> std::result::Result, FactorizeError> { - Err(FactorizeError::UnsupportedStorage( - "TensorTrain does not support factorize; use orthogonalize() instead", - )) - } - - fn factorize_full_rank( - &self, - _left_inds: &[Self::Index], - _alg: FactorizeAlg, - _canonical: Canonical, - ) -> std::result::Result, FactorizeError> { - Err(FactorizeError::UnsupportedStorage( - "TensorTrain does not support factorize_full_rank; use orthogonalize() instead", - )) - } - - fn contract(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> anyhow::Result { - anyhow::bail!("TensorTrain does not support TensorLike::contract; use TensorTrain::contract() method instead") - } - - fn contract_connected(_tensors: &[&Self], _allowed: AllowedPairs<'_>) -> anyhow::Result { - anyhow::bail!("TensorTrain does not 
support TensorLike::contract_connected; use TensorTrain::contract() method instead") + fn contract(_tensors: &[&Self]) -> anyhow::Result { + anyhow::bail!("TensorTrain does not support TensorContractionLike::contract; use TensorTrain::contract() method instead") } fn direct_sum( @@ -1594,9 +1871,34 @@ impl TensorLike for TensorTrain { _new_index: Self::Index, _order: LinearizationOrder, ) -> anyhow::Result { - anyhow::bail!("TensorTrain does not support TensorLike::fuse_indices") + anyhow::bail!("TensorTrain does not support TensorContractionLike::fuse_indices") + } +} + +impl TensorFactorizationLike for TensorTrain { + fn factorize( + &self, + _left_inds: &[Self::Index], + _options: &FactorizeOptions, + ) -> std::result::Result, FactorizeError> { + Err(FactorizeError::UnsupportedStorage( + "TensorTrain does not support factorize; use orthogonalize() instead", + )) } + fn factorize_full_rank( + &self, + _left_inds: &[Self::Index], + _alg: FactorizeAlg, + _canonical: Canonical, + ) -> std::result::Result, FactorizeError> { + Err(FactorizeError::UnsupportedStorage( + "TensorTrain does not support factorize_full_rank; use orthogonalize() instead", + )) + } +} + +impl TensorConstructionLike for TensorTrain { fn diagonal(input: &Self::Index, output: &Self::Index) -> anyhow::Result { // Create a single-site TensorTrain with an identity tensor let delta = TensorDynLen::diagonal(input, output)?; diff --git a/crates/tensor4all-itensorlike/src/tensortrain/tests/mod.rs b/crates/tensor4all-itensorlike/src/tensortrain/tests/mod.rs index 256a4f17..7df71b46 100644 --- a/crates/tensor4all-itensorlike/src/tensortrain/tests/mod.rs +++ b/crates/tensor4all-itensorlike/src/tensortrain/tests/mod.rs @@ -1,5 +1,6 @@ use super::*; -use tensor4all_core::{DynId, Index, LinearizationOrder, TensorLike}; +use std::time::Duration; +use tensor4all_core::{DynId, Index, LinearizationOrder, TensorContractionLike, TensorVectorSpace}; /// Helper to create a simple tensor for testing fn 
make_tensor(indices: Vec) -> TensorDynLen { @@ -34,6 +35,73 @@ fn test_single_site_tt() { assert_eq!(tt.bond_dims(), Vec::::new()); } +#[test] +fn profile_helpers_and_basic_accessors_cover_paths() { + let mut elapsed = Duration::ZERO; + let value = profile_tt_inner_section(true, &mut elapsed, || 42usize); + assert_eq!(value, 42); + assert!(elapsed >= Duration::ZERO); + print_tt_inner_profile(&TensorTrainInnerProfile::default(), 0); + + let tensor = make_tensor(vec![idx(0, 2)]); + let tt = TensorTrain::new(vec![tensor]).unwrap(); + assert_eq!(tt.tensors().len(), 1); + assert!(tt + .tensor_checked(5) + .unwrap_err() + .to_string() + .contains("out of bounds")); + assert_eq!(tt.clone().into_treetn().node_count(), 1); + assert!(tt.norm_squared() >= 0.0); +} + +#[test] +fn add_reindexed_like_self_aligns_site_indices_before_addition() { + let i0 = idx(0, 2); + let i1 = idx(1, 2); + let link = idx(2, 2); + let j0 = idx(10, 2); + let j1 = idx(11, 2); + let rhs_link = idx(12, 2); + + let lhs = TensorTrain::new(vec![ + TensorDynLen::from_dense(vec![i0.clone(), link.clone()], vec![1.0, 2.0, 3.0, 4.0]).unwrap(), + TensorDynLen::from_dense(vec![link.clone(), i1.clone()], vec![5.0, 6.0, 7.0, 8.0]).unwrap(), + ]) + .unwrap(); + let rhs = TensorTrain::new(vec![ + TensorDynLen::from_dense(vec![j0.clone(), rhs_link.clone()], vec![2.0, 3.0, 4.0, 5.0]) + .unwrap(), + TensorDynLen::from_dense( + vec![rhs_link.clone(), j1.clone()], + vec![7.0, 11.0, 13.0, 17.0], + ) + .unwrap(), + ]) + .unwrap(); + + let sum = lhs.add_reindexed_like_self(&rhs).unwrap(); + assert_eq!(sum.len(), 2); + assert_eq!(sum.siteinds(), vec![vec![i0.clone()], vec![i1.clone()]]); + + let dense = sum.to_dense().unwrap(); + let rhs_dense = rhs + .to_dense() + .unwrap() + .replaceinds(&[j0, j1], &[i0, i1]) + .unwrap(); + let expected = lhs + .to_dense() + .unwrap() + .axpby( + AnyScalar::new_real(1.0), + &rhs_dense, + AnyScalar::new_real(1.0), + ) + .unwrap(); + assert!(dense.sub(&expected).unwrap().maxabs() < 1e-12); 
+} + #[test] fn test_fuse_indices_trait_dispatch_returns_unsupported_error() { let i = idx(0, 2); @@ -41,7 +109,7 @@ fn test_fuse_indices_trait_dispatch_returns_unsupported_error() { let tensor = make_tensor(vec![i.clone()]); let tt = TensorTrain::new(vec![tensor]).unwrap(); - let err = ::fuse_indices( + let err = ::fuse_indices( &tt, &[i], fused, @@ -51,7 +119,7 @@ fn test_fuse_indices_trait_dispatch_returns_unsupported_error() { assert!(err .to_string() - .contains("TensorTrain does not support TensorLike::fuse_indices")); + .contains("TensorTrain does not support TensorContractionLike::fuse_indices")); } #[test] @@ -102,6 +170,52 @@ fn test_multi_site_indices() { assert_eq!(site_inds[1].len(), 1); // site 1 has 1 index } +#[test] +fn test_new_preserves_site_tensor_index_order() { + let s0 = idx(0, 2); + let l01 = idx(1, 3); + let s1 = idx(2, 2); + + let t0 = make_tensor(vec![l01.clone(), s0.clone()]); + let t1 = make_tensor(vec![s1.clone(), l01.clone()]); + + let tt = TensorTrain::new(vec![t0, t1]).unwrap(); + + assert_eq!(tt.tensor(0).unwrap().indices(), &[l01.clone(), s0]); + assert_eq!(tt.tensor(1).unwrap().indices(), &[s1, l01]); +} + +#[test] +fn test_with_ortho_preserves_site_tensor_index_order() { + let s0 = idx(0, 2); + let l01 = idx(1, 3); + let s1 = idx(2, 2); + + let t0 = make_tensor(vec![l01.clone(), s0.clone()]); + let t1 = make_tensor(vec![s1.clone(), l01.clone()]); + + let tt = TensorTrain::with_ortho(vec![t0, t1], -1, 1, Some(CanonicalForm::Unitary)).unwrap(); + + assert_eq!(tt.tensor(0).unwrap().indices(), &[l01.clone(), s0]); + assert_eq!(tt.tensor(1).unwrap().indices(), &[s1, l01]); +} + +#[test] +fn test_from_treetn_preserves_site_tensor_index_order() { + let s0 = idx(0, 2); + let l01 = idx(1, 3); + let s1 = idx(2, 2); + + let t0 = make_tensor(vec![l01.clone(), s0.clone()]); + let t1 = make_tensor(vec![s1.clone(), l01.clone()]); + let tree = tensor4all_treetn::TreeTN::from_tensors(vec![t0, t1], vec![0usize, 1usize]).unwrap(); + + let tt = 
TensorTrain::from_treetn(tree).unwrap(); + + assert_eq!(tt.tensor(0).unwrap().indices(), &[l01.clone(), s0]); + assert_eq!(tt.tensor(1).unwrap().indices(), &[s1, l01]); +} + #[test] fn test_ortho_tracking() { let s0 = idx(0, 2); @@ -296,10 +410,18 @@ fn test_contract_with_fit_method() { // Test contract with Fit method let options = ContractOptions::fit().with_max_rank(10).with_nhalfsweeps(4); // 4 half-sweeps = 2 full sweeps - let result = tt1.contract(&tt2, &options); + let result = tt1.contract_pair(&tt2, &options); assert!(result.is_ok()); let result_tt = result.unwrap(); - assert_eq!(result_tt.len(), 1); + let naive_result = tt1 + .to_dense() + .unwrap() + .contract_pair(&tt2.to_dense().unwrap()) + .unwrap(); + assert!(result_tt + .to_dense() + .unwrap() + .isapprox(&naive_result, 1e-10, 0.0)); } #[test] @@ -317,10 +439,18 @@ fn test_contract_with_naive_method() { // Test contract with Naive method let options = ContractOptions::naive().with_dense_reference_limit(2); - let result = tt1.contract(&tt2, &options); + let result = tt1.contract_pair(&tt2, &options); assert!(result.is_ok()); let result_tt = result.unwrap(); - assert_eq!(result_tt.len(), 1); + let naive_result = tt1 + .to_dense() + .unwrap() + .contract_pair(&tt2.to_dense().unwrap()) + .unwrap(); + assert!(result_tt + .to_dense() + .unwrap() + .isapprox(&naive_result, 1e-10, 0.0)); } #[test] @@ -344,10 +474,18 @@ fn test_contract_nhalfsweeps_conversion() { // Test that nhalfsweeps is correctly converted to nfullsweeps // nhalfsweeps=6 should become nfullsweeps=3 let options = ContractOptions::fit().with_nhalfsweeps(6).with_max_rank(10); - let result = tt1.contract(&tt2, &options); + let result = tt1.contract_pair(&tt2, &options); assert!(result.is_ok()); let result_tt = result.unwrap(); - assert_eq!(result_tt.len(), 1); + let naive_result = tt1 + .to_dense() + .unwrap() + .contract_pair(&tt2.to_dense().unwrap()) + .unwrap(); + assert!(result_tt + .to_dense() + .unwrap() + .isapprox(&naive_result, 
1e-10, 0.0)); } #[test] @@ -368,7 +506,7 @@ fn test_contract_fit_odd_nhalfsweeps_errors() { let tt2 = TensorTrain::new(vec![t2_0, t2_1]).unwrap(); let options = ContractOptions::fit().with_nhalfsweeps(1).with_max_rank(10); - let err = tt1.contract(&tt2, &options).unwrap_err(); + let err = tt1.contract_pair(&tt2, &options).unwrap_err(); assert!(matches!(err, TensorTrainError::OperationError { .. })); } @@ -424,7 +562,7 @@ fn test_to_dense() { let dense = tt.to_dense().unwrap(); // Expected: contract t0 and t1 along l01 - let expected = t0.contract(&t1).unwrap(); + let expected = t0.contract_pair(&t1).unwrap(); // Compare results let dense_data = dense.to_vec::().unwrap(); @@ -483,7 +621,7 @@ fn test_to_dense_three_sites() { let dense = tt.to_dense().unwrap(); // Expected: contract t0, t1, t2 sequentially - let expected = t0.contract(&t1).unwrap().contract(&t2).unwrap(); + let expected = t0.contract_pair(&t1).unwrap().contract_pair(&t2).unwrap(); let dense_data = dense.to_vec::().unwrap(); let expected_data = expected.to_vec::().unwrap(); @@ -688,6 +826,22 @@ fn test_set_tensor_checked_invalid_site_errors() { )); } +#[test] +fn test_set_tensor_checked_preserves_replacement_index_order() { + let s0 = idx(0, 2); + let l01 = idx(1, 3); + let s1 = idx(2, 2); + + let t0 = make_tensor(vec![s0.clone(), l01.clone()]); + let t1 = make_tensor(vec![l01.clone(), s1]); + let replacement = make_tensor(vec![l01.clone(), s0.clone()]); + + let mut tt = TensorTrain::new(vec![t0, t1]).unwrap(); + tt.set_tensor_checked(0, replacement).unwrap(); + + assert_eq!(tt.tensor(0).unwrap().indices(), &[l01, s0]); +} + #[test] fn test_tensor_mut_checked_invalid_site_errors() { let s0 = idx(0, 2); @@ -830,8 +984,6 @@ fn test_axpby() { #[test] fn test_tensor_like_scale() { - use tensor4all_core::TensorLike; - let s0 = idx(0, 2); let l01 = idx(1, 3); let s1 = idx(2, 2); @@ -841,11 +993,11 @@ fn test_tensor_like_scale() { let tt = TensorTrain::new(vec![t0, t1]).unwrap(); - // Use TensorLike::scale - 
let scaled = TensorLike::scale(&tt, AnyScalar::new_real(2.0)).unwrap(); + // Use TensorVectorSpace::scale + let scaled = TensorVectorSpace::scale(&tt, AnyScalar::new_real(2.0)).unwrap(); let orig_norm = tt.norm(); - let scaled_norm = TensorLike::norm(&scaled); + let scaled_norm = TensorVectorSpace::norm(&scaled); assert!( (scaled_norm - 2.0 * orig_norm).abs() < 1e-10, "Expected scaled_norm = {}, got {}", @@ -856,7 +1008,7 @@ fn test_tensor_like_scale() { #[test] fn test_tensor_like_inner_product() { - use tensor4all_core::TensorLike; + use tensor4all_core::TensorVectorSpace; let s0 = idx(0, 2); let l01 = idx(1, 3); @@ -867,8 +1019,8 @@ fn test_tensor_like_inner_product() { let tt = TensorTrain::new(vec![t0, t1]).unwrap(); - // TensorLike::inner_product should equal TensorTrain::inner - let inner_via_trait = TensorLike::inner_product(&tt, &tt).unwrap(); + // TensorVectorSpace::inner_product should equal TensorTrain::inner + let inner_via_trait = TensorVectorSpace::inner_product(&tt, &tt).unwrap(); let inner_direct = tt.inner(&tt).unwrap(); assert!( @@ -1225,8 +1377,6 @@ fn test_dense_maxabs_is_explicit_dense_reference_api() { #[test] fn test_tensor_like_maxabs_is_not_hidden_dense_for_tensor_train() { - use tensor4all_core::TensorLike; - let s0 = idx(0, 2); let l01 = idx(1, 3); let s1 = idx(2, 2); @@ -1236,14 +1386,14 @@ fn test_tensor_like_maxabs_is_not_hidden_dense_for_tensor_train() { let tt = TensorTrain::new(vec![t0, t1]).unwrap(); - let err = TensorLike::try_maxabs(&tt).unwrap_err(); + let err = TensorVectorSpace::try_maxabs(&tt).unwrap_err(); assert!(err.to_string().contains("explicit dense materialization")); - assert!(TensorLike::maxabs(&tt).is_nan()); + assert!(TensorVectorSpace::maxabs(&tt).is_nan()); } #[test] fn test_tensor_like_conj() { - use tensor4all_core::TensorLike; + use tensor4all_core::TensorContractionLike; let s0 = idx(0, 2); let l01 = idx(1, 3); @@ -1255,7 +1405,7 @@ fn test_tensor_like_conj() { let tt = TensorTrain::new(vec![t0, 
t1]).unwrap(); // For real tensors, conj should be identical - let conj_tt = TensorLike::conj(&tt); + let conj_tt = TensorContractionLike::conj(&tt); assert_eq!(conj_tt.len(), tt.len()); let orig_dense = tt.to_dense().unwrap(); diff --git a/crates/tensor4all-itensorlike/tests/linsolve_mpo.rs b/crates/tensor4all-itensorlike/tests/linsolve_mpo.rs index f407dd7b..8e31caa8 100644 --- a/crates/tensor4all-itensorlike/tests/linsolve_mpo.rs +++ b/crates/tensor4all-itensorlike/tests/linsolve_mpo.rs @@ -4,7 +4,8 @@ //! Previously, `linsolve` failed with index mismatch errors because //! it did not pass IndexMapping to the treetn solver. -use tensor4all_core::{DynIndex, TensorDynLen}; +use num_complex::Complex64; +use tensor4all_core::{AnyScalar, DynIndex, TensorDynLen}; use tensor4all_itensorlike::{LinsolveOptions, TensorTrain}; /// Build a 2-site identity MPO where input indices are the SAME objects @@ -65,8 +66,8 @@ fn test_linsolve_identity_mpo_distinct_output_indices() { // This previously failed with "Index count mismatch" or "index structure mismatch" let options = LinsolveOptions::new(3) - .with_krylov_tol(1e-10) - .with_krylov_dim(10) + .with_gmres_tol(1e-10) + .with_gmres_restart_dim(10) .with_max_rank(4); let result = operator.linsolve(&rhs, init, &options).unwrap(); @@ -74,3 +75,71 @@ fn test_linsolve_identity_mpo_distinct_output_indices() { // For identity operator, solution should match RHS assert_eq!(result.len(), 2); } + +#[test] +fn test_linsolve_identity_mpo_accepts_complex_coefficients() { + let phys_dim = 2; + + let s0 = DynIndex::new_dyn(phys_dim); + let s1 = DynIndex::new_dyn(phys_dim); + let s0_out = DynIndex::new_dyn(phys_dim); + let s1_out = DynIndex::new_dyn(phys_dim); + let b_mps = DynIndex::new_dyn(phys_dim); + let b_mpo = DynIndex::new_dyn(1); + + let mut data0 = vec![Complex64::new(0.0, 0.0); phys_dim * phys_dim]; + for i in 0..phys_dim { + data0[i * phys_dim + i] = Complex64::new(1.0, 0.0); + } + let t0_mps = 
TensorDynLen::from_dense(vec![s0.clone(), b_mps.clone()], data0).unwrap(); + + let rhs_values = vec![ + Complex64::new(1.0, 0.5), + Complex64::new(2.0, -1.0), + Complex64::new(-0.25, 0.75), + Complex64::new(0.5, 1.5), + ]; + let t1_mps = + TensorDynLen::from_dense(vec![b_mps.clone(), s1.clone()], rhs_values.clone()).unwrap(); + let rhs = TensorTrain::new(vec![t0_mps.clone(), t1_mps]).unwrap(); + + let zero0 = t0_mps.scale(AnyScalar::new_real(0.0)).unwrap(); + let zero1 = TensorDynLen::from_dense( + vec![b_mps.clone(), s1.clone()], + vec![Complex64::new(0.0, 0.0); phys_dim * phys_dim], + ) + .unwrap(); + let init = TensorTrain::new(vec![zero0, zero1]).unwrap(); + + let mut id_data = vec![Complex64::new(0.0, 0.0); phys_dim * phys_dim]; + for i in 0..phys_dim { + id_data[i * phys_dim + i] = Complex64::new(1.0, 0.0); + } + let t0_mpo = TensorDynLen::from_dense( + vec![s0_out.clone(), s0.clone(), b_mpo.clone()], + id_data.clone(), + ) + .unwrap(); + let t1_mpo = TensorDynLen::from_dense(vec![b_mpo, s1_out, s1.clone()], id_data).unwrap(); + let operator = TensorTrain::new(vec![t0_mpo, t1_mpo]).unwrap(); + + let options = LinsolveOptions::new(3) + .with_gmres_tol(1e-10) + .with_gmres_restart_dim(10) + .with_max_rank(4) + .with_coefficients(0.0, Complex64::new(0.0, 1.0)); + + let result = operator.linsolve(&rhs, init, &options).unwrap(); + let dense = result.to_dense().unwrap().to_vec::().unwrap(); + let expected = rhs_values + .iter() + .map(|value| Complex64::new(0.0, -1.0) * value) + .collect::>(); + + for (actual, expected) in dense.iter().zip(expected.iter()) { + assert!( + (*actual - *expected).norm() < 1e-8, + "actual={actual:?}, expected={expected:?}" + ); + } +} diff --git a/crates/tensor4all-partitionedtt/src/contract/tests/mod.rs b/crates/tensor4all-partitionedtt/src/contract/tests/mod.rs index cb318246..6c536451 100644 --- a/crates/tensor4all-partitionedtt/src/contract/tests/mod.rs +++ b/crates/tensor4all-partitionedtt/src/contract/tests/mod.rs @@ -1,7 +1,7 @@ 
use super::*; use num_complex::Complex64; use tensor4all_core::index::Index; -use tensor4all_core::{DynIndex, TensorDynLen, TensorElement}; +use tensor4all_core::{DynIndex, TensorContractionLike, TensorDynLen, TensorElement}; use tensor4all_itensorlike::TensorTrain; /// Trait for scalar types used in tests. @@ -279,7 +279,7 @@ fn test_contract_numerical_correctness_generic() { // Manually compute expected contraction: R = T1 * T2 (contracting s1) // T1 has indices [s0, s1], T2 has indices [s1, s2] // Result R has indices [s0, s2] - let expected = t1_full.contract(&t2_full).unwrap(); + let expected = t1_full.contract_pair(&t2_full).unwrap(); // Compare: both should have the same data let contracted_data = T::extract_slice(&contracted_full); @@ -361,7 +361,7 @@ fn test_contract_with_projectors_numerical_correctness_generic() let t2_proj = project_dense_tensor_at_index::(&t2_full, &s2, 1); // Contract projected tensors - let expected = t1_proj.contract(&t2_proj).unwrap(); + let expected = t1_proj.contract_pair(&t2_proj).unwrap(); // Compare let contracted_data = T::extract_slice(&contracted_full); @@ -419,7 +419,7 @@ fn test_contract_with_projectors_numerical_correctness_default_zipup_generic(&t1_full, &s0, s0_val); let t2_proj = project_dense_tensor_at_index::(&t2_full, &s2, s2_val); - let expected = t1_proj.contract(&t2_proj).unwrap(); + let expected = t1_proj.contract_pair(&t2_proj).unwrap(); let contracted_data = T::extract_slice(&contracted_full); let expected_data = T::extract_slice(&expected); @@ -484,7 +484,7 @@ fn test_contract_with_projector_on_contracted_index_generic() { let t1_proj = project_dense_tensor_at_index::(&t1_full, &s1, 0); let t2_proj = project_dense_tensor_at_index::(&t2_full, &s1, 0); - let expected = t1_proj.contract(&t2_proj).unwrap(); + let expected = t1_proj.contract_pair(&t2_proj).unwrap(); let contracted_data = T::extract_slice(&contracted_full); let expected_data = T::extract_slice(&expected); @@ -540,7 +540,7 @@ fn 
test_contract_one_side_has_projector_generic() { // Compute expected: project t1 to s0=1, t2 unchanged let t1_proj = project_dense_tensor_at_index::(&t1_full, &s0, 1); - let expected = t1_proj.contract(&t2_full).unwrap(); + let expected = t1_proj.contract_pair(&t2_full).unwrap(); let contracted_data = T::extract_slice(&contracted_full); let expected_data = T::extract_slice(&expected); @@ -597,7 +597,7 @@ fn test_proj_contract_numerical_correctness_generic() { let t1_proj = project_dense_tensor_at_index::(&t1_full, &s0, 0); let t2_proj = project_dense_tensor_at_index::(&t2_full, &s2, 1); - let expected = t1_proj.contract(&t2_proj).unwrap(); + let expected = t1_proj.contract_pair(&t2_proj).unwrap(); let contracted_data = T::extract_slice(&contracted_full); let expected_data = T::extract_slice(&expected); diff --git a/crates/tensor4all-partitionedtt/src/partitioned_tt.rs b/crates/tensor4all-partitionedtt/src/partitioned_tt.rs index 159e4666..73a1c32c 100644 --- a/crates/tensor4all-partitionedtt/src/partitioned_tt.rs +++ b/crates/tensor4all-partitionedtt/src/partitioned_tt.rs @@ -257,12 +257,15 @@ impl PartitionedTT { // Check if we already have a subdomain with the same projector if let Some(existing) = result.get_mut(&proj) { // Sum the subdomains using TT addition - let mut summed_tt = existing.data().add(contracted.data()).map_err(|e| { - PartitionedTTError::TensorTrainError(format!( - "TT addition in contract failed: {}", - e - )) - })?; + let mut summed_tt = existing + .data() + .add_reindexed_like_self(contracted.data()) + .map_err(|e| { + PartitionedTTError::TensorTrainError(format!( + "TT addition in contract failed: {}", + e + )) + })?; // Truncate after addition using the same truncation params as contraction let mut truncate_opts = TruncateOptions::svd(); if let Some(policy) = options.svd_policy() { diff --git a/crates/tensor4all-partitionedtt/src/partitioned_tt/tests/mod.rs b/crates/tensor4all-partitionedtt/src/partitioned_tt/tests/mod.rs index 
61658e8a..379d920f 100644 --- a/crates/tensor4all-partitionedtt/src/partitioned_tt/tests/mod.rs +++ b/crates/tensor4all-partitionedtt/src/partitioned_tt/tests/mod.rs @@ -1,6 +1,6 @@ use super::*; use tensor4all_core::index::Index; -use tensor4all_core::TensorDynLen; +use tensor4all_core::{TensorContractionLike, TensorDynLen}; fn make_index(size: usize) -> DynIndex { Index::new_dyn(size) @@ -245,7 +245,7 @@ fn test_partitioned_tt_contract_numerical() { // Project t2 to s2=s2_val let t2_proj = project_dense_tensor_at_index(&t2_full, &s2, s2_val); - let expected = t1_proj.contract(&t2_proj).unwrap(); + let expected = t1_proj.contract_pair(&t2_proj).unwrap(); let expected_data = expected.to_vec::().unwrap(); assert_eq!( diff --git a/crates/tensor4all-quanticstransform/src/affine.rs b/crates/tensor4all-quanticstransform/src/affine.rs index 01896d32..684851d7 100644 --- a/crates/tensor4all-quanticstransform/src/affine.rs +++ b/crates/tensor4all-quanticstransform/src/affine.rs @@ -14,6 +14,8 @@ use num_integer::Integer; use num_rational::Rational64; use num_traits::One; use sprs::CsMat; +use tensor4all_core::index::{DynId, Index, TagSet}; +use tensor4all_core::LinearizationOrder; use tensor4all_simplett::{types::tensor3_zeros, AbstractTensorTrain, Tensor3Ops, TensorTrain}; use crate::common::{ @@ -488,6 +490,93 @@ pub fn affine_operator( tensortrain_to_linear_operator_asymmetric(&remapped_mpo, &input_dims, &output_dims) } +/// Create an affine operator with interleaved binary variable indices. +/// +/// This is the same forward coordinate map as [`affine_operator`], but each bit +/// node carries one binary output index per output variable and one binary input +/// index per input variable instead of fusing variables into local dimensions +/// `2^M` and `2^N`. The mapping order at each node is +/// `(y0, y1, ..., yM-1)` for outputs and `(x0, x1, ..., xN-1)` for inputs. 
+/// +/// Use this form when the state stores variables as separate interleaved QTT +/// site indices and should bind them through [`LinearOperator::new_multi`]. +/// +/// # Arguments +/// +/// * `r` - Bits per variable. Node `0` is the most significant bit. +/// * `params` - Rational affine map `y = A*x + b`. +/// * `bc` - Boundary condition for each output variable. +/// +/// # Returns +/// +/// A [`LinearOperator`] whose node `i` has `params.n` input mappings and +/// `params.m` output mappings, all with binary dimension. +/// +/// # Errors +/// +/// Returns an error when `r == 0`, when `bc.len() != params.m`, or when the +/// affine tensor network cannot be constructed. +/// +/// # Examples +/// +/// ``` +/// use tensor4all_quanticstransform::{ +/// affine_operator_interleaved, AffineParams, BoundaryCondition, +/// }; +/// +/// let params = AffineParams::from_integers(vec![1, 0, 0, 1], vec![0, 0], 2, 2).unwrap(); +/// let bc = vec![BoundaryCondition::Periodic; 2]; +/// let op = affine_operator_interleaved(3, &params, &bc).unwrap(); +/// +/// assert_eq!(op.mpo().node_count(), 3); +/// assert_eq!(op.get_output_mappings(&0).unwrap().len(), 2); +/// assert_eq!(op.get_input_mappings(&0).unwrap().len(), 2); +/// ``` +pub fn affine_operator_interleaved( + r: usize, + params: &AffineParams, + bc: &[BoundaryCondition], +) -> Result { + let mut op = affine_operator(r, params, bc)?; + + let fused_output_indices = (0..r) + .map(|site| { + op.get_output_mapping(&site) + .ok_or_else(|| anyhow::anyhow!("missing affine output mapping for site {site}")) + .map(|mapping| mapping.true_index.clone()) + }) + .collect::>>()?; + let fused_input_indices = (0..r) + .map(|site| { + op.get_input_mapping(&site) + .ok_or_else(|| anyhow::anyhow!("missing affine input mapping for site {site}")) + .map(|mapping| mapping.true_index.clone()) + }) + .collect::>>()?; + + for site in 0..r { + let output_indices = (0..params.m) + .map(|_| Index::::new_dyn(2)) + .collect::>(); + op = 
op.unfuse_output_index( + &fused_output_indices[site], + &output_indices, + LinearizationOrder::ColumnMajor, + )?; + + let input_indices = (0..params.n) + .map(|_| Index::::new_dyn(2)) + .collect::>(); + op = op.unfuse_input_index( + &fused_input_indices[site], + &input_indices, + LinearizationOrder::ColumnMajor, + )?; + } + + Ok(op) +} + /// Compute the full affine transformation matrix directly (for verification). /// /// This computes the transformation matrix by directly evaluating y = A*x + b diff --git a/crates/tensor4all-quanticstransform/src/affine/tests/mod.rs b/crates/tensor4all-quanticstransform/src/affine/tests/mod.rs index eab942ef..cf10faed 100644 --- a/crates/tensor4all-quanticstransform/src/affine/tests/mod.rs +++ b/crates/tensor4all-quanticstransform/src/affine/tests/mod.rs @@ -532,6 +532,74 @@ fn affine_matrix_to_dense_tensor( TensorDynLen::from_dense(indices, data).expect("failed to build affine reference tensor") } +fn affine_interleaved_matrix_to_dense_tensor( + matrix: &CsMat, + op: &QuanticsOperator, + r: usize, + m: usize, + n: usize, + template: &TensorDynLen, +) -> TensorDynLen { + let indices = template.indices().to_vec(); + let dims = template.dims(); + let mut id_to_pos = std::collections::HashMap::new(); + for (pos, index) in indices.iter().enumerate() { + id_to_pos.insert(*index.id(), pos); + } + + let output_positions = (0..r) + .map(|site| { + op.get_output_mappings(&site) + .expect("missing output mappings") + .iter() + .map(|mapping| { + *id_to_pos + .get(mapping.internal_index.id()) + .expect("output index not found in contracted tensor") + }) + .collect::>() + }) + .collect::>(); + let input_positions = (0..r) + .map(|site| { + op.get_input_mappings(&site) + .expect("missing input mappings") + .iter() + .map(|mapping| { + *id_to_pos + .get(mapping.internal_index.id()) + .expect("input index not found in contracted tensor") + }) + .collect::>() + }) + .collect::>(); + + let mut data = vec![Complex64::new(0.0, 0.0); 
dims.iter().product()]; + let mut coords = vec![0usize; dims.len()]; + + for (y_flat, row) in matrix.outer_iterator().enumerate() { + for (x_flat, value) in row.iter() { + coords.fill(0); + for site in 0..r { + let bit_pos = r - 1 - site; + for var in 0..m { + let y_var = (y_flat >> (var * r)) & ((1 << r) - 1); + coords[output_positions[site][var]] = (y_var >> bit_pos) & 1; + } + for var in 0..n { + let x_var = (x_flat >> (var * r)) & ((1 << r) - 1); + coords[input_positions[site][var]] = (x_var >> bit_pos) & 1; + } + } + let offset = column_major_offset(&dims, &coords); + data[offset] = Complex64::new(*value, 0.0); + } + } + + TensorDynLen::from_dense(indices, data) + .expect("failed to build interleaved affine reference tensor") +} + /// Assert that the MPO representation matches the direct sparse matrix computation /// for all elements. This is the primary correctness check: two independent algorithms /// (carry-based MPO vs direct enumeration) must agree. @@ -1003,6 +1071,35 @@ fn test_unfused_vs_fused_equivalence() { } } +#[test] +fn test_affine_operator_interleaved_matches_matrix() { + let r = 2; + let params = AffineParams::from_integers(vec![1, 1, 0, 1], vec![0, 0], 2, 2).unwrap(); + let bc = vec![BoundaryCondition::Periodic; 2]; + + let matrix = affine_transform_matrix(r, &params, &bc).unwrap(); + let op = affine_operator_interleaved(r, &params, &bc).unwrap(); + + for site in 0..r { + assert_eq!(op.get_output_mappings(&site).unwrap().len(), params.m); + assert_eq!(op.get_input_mappings(&site).unwrap().len(), params.n); + for mapping in op.get_output_mappings(&site).unwrap() { + assert_eq!(mapping.true_index.dim(), 2); + assert_eq!(mapping.internal_index.dim(), 2); + } + for mapping in op.get_input_mappings(&site).unwrap() { + assert_eq!(mapping.true_index.dim(), 2); + assert_eq!(mapping.internal_index.dim(), 2); + } + } + + let actual = op.mpo.contract_to_tensor().unwrap(); + let expected = + affine_interleaved_matrix_to_dense_tensor(&matrix, &op, r, params.m, 
params.n, &actual); + let maxabs = actual.distance(&expected).unwrap(); + assert!(maxabs < 1e-10, "interleaved affine maxabs={maxabs}"); +} + #[test] fn test_affine_parametric_full() { // From Quantics.jl "full R=$R, boundary=$boundary, M=$M, N=$N" test diff --git a/crates/tensor4all-quanticstransform/src/lib.rs b/crates/tensor4all-quanticstransform/src/lib.rs index 72f22700..7049e729 100644 --- a/crates/tensor4all-quanticstransform/src/lib.rs +++ b/crates/tensor4all-quanticstransform/src/lib.rs @@ -67,8 +67,8 @@ mod phase_rotation; mod shift; pub use affine::{ - affine_operator, affine_transform_matrix, affine_transform_tensors_unfused, AffineParams, - LinearConstraintRow, UnfusedTensorInfo, + affine_operator, affine_operator_interleaved, affine_transform_matrix, + affine_transform_tensors_unfused, AffineParams, LinearConstraintRow, UnfusedTensorInfo, }; pub use common::{BoundaryCondition, CarryDirection}; pub use cumsum::{cumsum_operator, triangle_operator, TriangleType}; diff --git a/crates/tensor4all-tensorbackend/Cargo.toml b/crates/tensor4all-tensorbackend/Cargo.toml index 07841dc7..30ea07b5 100644 --- a/crates/tensor4all-tensorbackend/Cargo.toml +++ b/crates/tensor4all-tensorbackend/Cargo.toml @@ -23,6 +23,7 @@ thiserror.workspace = true anyhow.workspace = true rand.workspace = true rand_distr.workspace = true +omeco.workspace = true tenferro.workspace = true tenferro-device.workspace = true tenferro-einsum.workspace = true diff --git a/crates/tensor4all-tensorbackend/src/context.rs b/crates/tensor4all-tensorbackend/src/context.rs index d28ec92f..14b00337 100644 --- a/crates/tensor4all-tensorbackend/src/context.rs +++ b/crates/tensor4all-tensorbackend/src/context.rs @@ -2,18 +2,21 @@ //! //! tensor4all-rs routes tenferro CPU execution through one process-global //! `CpuContext`, matching tenferro's `cpu:0` default-global thread-pool model. -//! Plain tensor operations and eager AD currently use separate `CpuBackend` -//! 
values because tenferro does not yet expose a public API for borrowing the -//! backend owned by an `EagerContext`. Both backends are created -//! from the same global CPU context, so thread-pool configuration is shared. +//! Plain tensor operations, cached traced execution, and eager AD currently use +//! separate `CpuBackend` values because tenferro does not yet expose a public +//! API for borrowing the backend owned by an `EagerContext`. All +//! backends are created from the same global CPU context, so thread-pool +//! configuration is shared. use std::sync::{Arc, Mutex, OnceLock}; -use tenferro::{CpuBackend, EagerContext}; +use tenferro::{CpuBackend, EagerContext, Engine}; +use tenferro_tensor::buffer_pool::BufferPoolStats; use tenferro_tensor::cpu::CpuContext; static DEFAULT_CPU_CONTEXT: OnceLock> = OnceLock::new(); static DEFAULT_BACKEND: OnceLock> = OnceLock::new(); +static DEFAULT_ENGINE: OnceLock>> = OnceLock::new(); static DEFAULT_EAGER_CTX: OnceLock>> = OnceLock::new(); fn default_cpu_context() -> Arc { @@ -26,6 +29,18 @@ fn default_backend() -> &'static Mutex { DEFAULT_BACKEND.get_or_init(|| Mutex::new(CpuBackend::from_context(default_cpu_context()))) } +fn default_engine() -> &'static Mutex> { + DEFAULT_ENGINE + .get_or_init(|| Mutex::new(Engine::new(CpuBackend::from_context(default_cpu_context())))) +} + +fn lock_default_engine() -> std::sync::MutexGuard<'static, Engine> { + match default_engine().lock() { + Ok(guard) => guard, + Err(poisoned) => poisoned.into_inner(), + } +} + /// Run a closure against the process-global CPU backend. /// /// This is the canonical entry point for typed and untyped tenferro tensor @@ -38,10 +53,40 @@ pub fn with_default_backend(f: impl FnOnce(&mut CpuBackend) -> R) -> R { f(&mut backend) } +/// Run a closure against the process-global tenferro execution engine. +/// +/// This is used for native tensor operations that benefit from tenferro's +/// persistent execution caches, such as N-ary einsum contraction paths. 
+pub(crate) fn with_default_engine(f: impl FnOnce(&mut Engine) -> R) -> R { + let mut engine = lock_default_engine(); + f(&mut engine) +} + +/// Return retained-buffer statistics for the process-global execution engine. +pub(crate) fn default_engine_buffer_pool_stats() -> BufferPoolStats { + lock_default_engine().buffer_pool_stats() +} + +/// Reset retained buffers in the process-global execution engine. +pub(crate) fn reset_default_engine_buffer_pool() { + lock_default_engine().reset_buffer_pool(); +} + +/// Drop and recreate the process-global execution engine. +/// +/// This releases tenferro's retained execution buffers and cached contraction +/// paths. It is intended for diagnostics and memory-pressure recovery, not for +/// normal hot loops where the caches are valuable. +pub(crate) fn reset_default_engine() { + let mut engine = lock_default_engine(); + *engine = Engine::new(CpuBackend::from_context(default_cpu_context())); +} + /// Return the process-global eager context used for reverse-mode AD. /// -/// This context owns a separate `CpuBackend` from [`with_default_backend`], but -/// both backends share the same process-global tenferro CPU context. +/// This context owns a separate `CpuBackend` from [`with_default_backend`] and +/// the cached execution engine, but all backends share the same process-global +/// tenferro CPU context. 
pub fn default_eager_ctx() -> Arc> { DEFAULT_EAGER_CTX .get_or_init(|| EagerContext::with_backend(CpuBackend::from_context(default_cpu_context()))) @@ -80,4 +125,15 @@ mod tests { assert_eq!(main_threads, worker_threads); } + + #[test] + fn default_engine_is_shared_across_threads() { + let main_threads = with_default_engine(|engine| engine.backend().num_threads()); + let worker_threads = + std::thread::spawn(|| with_default_engine(|engine| engine.backend().num_threads())) + .join() + .expect("worker thread should complete"); + + assert_eq!(main_threads, worker_threads); + } } diff --git a/crates/tensor4all-tensorbackend/src/lib.rs b/crates/tensor4all-tensorbackend/src/lib.rs index 57fb8dd3..3fb0d401 100644 --- a/crates/tensor4all-tensorbackend/src/lib.rs +++ b/crates/tensor4all-tensorbackend/src/lib.rs @@ -19,6 +19,8 @@ mod backend; mod context; /// Dense column-major matrix type and backend-backed matrix utilities. mod matrix; +/// Process-level memory pressure helpers. +mod memory; /// Tensor snapshot storage types and low-level dense/diagonal kernels. 
mod storage; pub(crate) mod tenferro_bridge; @@ -35,6 +37,7 @@ pub use matrix::{ from_vec2d, mat_mul, submatrix, submatrix_argmax, swap_cols, swap_rows, transpose, BlasMul, Matrix, MatrixScalar, }; +pub use memory::{release_process_allocator_cached_memory, AllocatorPressureRelief}; pub use storage::{ contract_storage, make_mut_storage, mindim, Storage, StorageError, StorageKind, StorageResult, StorageScalar, StructuredStorage, SumFromStorage, @@ -42,13 +45,14 @@ pub use storage::{ pub use tenferro_bridge::{ axpby_native_tensor, axpby_storage_native, conj_native_tensor, contract_native_tensor, contract_storage_native, dense_native_tensor_from_col_major, diag_native_tensor_from_col_major, - einsum_native_tensors, einsum_native_tensors_owned, + einsum_native_tensor_reads, einsum_native_tensors, einsum_native_tensors_owned, native_tensor_primal_to_dense_c64_col_major, native_tensor_primal_to_dense_col_major, native_tensor_primal_to_dense_f64_col_major, native_tensor_primal_to_diag_c64, native_tensor_primal_to_diag_f64, native_tensor_primal_to_storage, outer_product_native_tensor, outer_product_storage_native, permute_native_tensor, permute_storage_native, print_and_reset_native_einsum_profile, qr_native_tensor, reset_native_einsum_profile, reshape_col_major_native_tensor, scale_native_tensor, scale_storage_native, - storage_to_native_tensor, sum_native_tensor, svd_native_tensor, tangent_native_tensor, + storage_payload_native_read_input, storage_to_native_tensor, sum_native_tensor, + svd_native_tensor, tangent_native_tensor, NativeTensorReadInput, }; pub use tensor_element::TensorElement; diff --git a/crates/tensor4all-tensorbackend/src/memory.rs b/crates/tensor4all-tensorbackend/src/memory.rs new file mode 100644 index 00000000..4a9bd82f --- /dev/null +++ b/crates/tensor4all-tensorbackend/src/memory.rs @@ -0,0 +1,104 @@ +//! Process-level memory pressure helpers. + +/// Result reported by [`release_process_allocator_cached_memory`]. 
+/// +/// The fields are platform-specific because allocator pressure-relief APIs do +/// not expose a uniform contract. On macOS, `released_bytes` is the value +/// returned by `malloc_zone_pressure_relief`. On Linux, `success` is the boolean +/// result returned by `malloc_trim(0)`. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub struct AllocatorPressureRelief { + /// Whether tensor4all-rs has a platform hook for the current target. + pub supported: bool, + /// Platform-reported released bytes when the allocator exposes that value. + pub released_bytes: Option, + /// Platform-reported success when the allocator exposes only a status bit. + pub success: Option, +} + +/// Ask the process allocator to return cached/free memory to the operating system. +/// +/// This is a diagnostic and memory-pressure hook for the platform system +/// allocator only. It does not release memory that is still owned by live +/// tensors or explicit buffer pools, and it may have no effect if the program is +/// built with a custom global allocator. 
+/// +/// # Examples +/// +/// ``` +/// use tensor4all_tensorbackend::release_process_allocator_cached_memory; +/// +/// let report = release_process_allocator_cached_memory(); +/// assert_eq!( +/// report.supported, +/// cfg!(any(target_os = "macos", target_os = "linux")) +/// ); +/// ``` +pub fn release_process_allocator_cached_memory() -> AllocatorPressureRelief { + allocator_pressure_relief() +} + +#[cfg(target_os = "macos")] +fn allocator_pressure_relief() -> AllocatorPressureRelief { + use std::ffi::c_void; + + extern "C" { + fn malloc_default_zone() -> *mut c_void; + fn malloc_zone_pressure_relief(zone: *mut c_void, goal: usize) -> usize; + } + + unsafe { + let zone = malloc_default_zone(); + if zone.is_null() { + AllocatorPressureRelief { + supported: true, + released_bytes: Some(0), + success: Some(false), + } + } else { + let released_bytes = malloc_zone_pressure_relief(zone, 0); + AllocatorPressureRelief { + supported: true, + released_bytes: Some(released_bytes), + success: Some(released_bytes > 0), + } + } + } +} + +#[cfg(target_os = "linux")] +fn allocator_pressure_relief() -> AllocatorPressureRelief { + extern "C" { + fn malloc_trim(pad: usize) -> i32; + } + + let success = unsafe { malloc_trim(0) != 0 }; + AllocatorPressureRelief { + supported: true, + released_bytes: None, + success: Some(success), + } +} + +#[cfg(not(any(target_os = "macos", target_os = "linux")))] +fn allocator_pressure_relief() -> AllocatorPressureRelief { + AllocatorPressureRelief { + supported: false, + released_bytes: None, + success: None, + } +} + +#[cfg(test)] +mod tests { + use super::release_process_allocator_cached_memory; + + #[test] + fn reports_platform_support() { + let report = release_process_allocator_cached_memory(); + assert_eq!( + report.supported, + cfg!(any(target_os = "macos", target_os = "linux")) + ); + } +} diff --git a/crates/tensor4all-tensorbackend/src/storage.rs b/crates/tensor4all-tensorbackend/src/storage.rs index 755b2250..4448c6cf 100644 --- 
a/crates/tensor4all-tensorbackend/src/storage.rs +++ b/crates/tensor4all-tensorbackend/src/storage.rs @@ -532,6 +532,16 @@ impl StructuredStorage { None } } + + /// Returns a borrowed compact-payload view when the payload is already + /// stored contiguously in column-major order. + pub fn payload_col_major_view_if_contiguous(&self) -> Option<&[T]> { + if matches!(col_major_strides(&self.payload_dims), Ok(strides) if strides == self.strides) { + Some(&self.data) + } else { + None + } + } } impl StructuredStorage { @@ -1292,6 +1302,19 @@ impl Storage { } } + /// Borrows the compact `f64` payload when it is already contiguous in + /// column-major payload order. + pub fn payload_f64_col_major_view_if_contiguous(&self) -> StorageResult> { + match &self.0 { + StorageRepr::F64(value) => Ok(value.payload_col_major_view_if_contiguous()), + StorageRepr::C64(_) => Err(StorageError::ScalarKindMismatch { + expected: "f64", + actual: storage_scalar_kind(&self.0), + operation: "borrowing f64 payload", + }), + } + } + /// Copies the compact `Complex64` payload in column-major payload order. /// /// This does not materialize logical dense values. Complex payloads are @@ -1322,6 +1345,19 @@ impl Storage { } } + /// Borrows the compact `Complex64` payload when it is already contiguous in + /// column-major payload order. + pub fn payload_c64_col_major_view_if_contiguous(&self) -> StorageResult> { + match &self.0 { + StorageRepr::C64(value) => Ok(value.payload_col_major_view_if_contiguous()), + StorageRepr::F64(_) => Err(StorageError::ScalarKindMismatch { + expected: "Complex64", + actual: storage_scalar_kind(&self.0), + operation: "borrowing c64 payload", + }), + } + } + /// Check if this storage uses f64 scalar type. 
/// /// # Examples diff --git a/crates/tensor4all-tensorbackend/src/tenferro_bridge.rs b/crates/tensor4all-tensorbackend/src/tenferro_bridge.rs index 4e2a7ac3..bb951fc0 100644 --- a/crates/tensor4all-tensorbackend/src/tenferro_bridge.rs +++ b/crates/tensor4all-tensorbackend/src/tenferro_bridge.rs @@ -2,28 +2,73 @@ use std::cell::RefCell; use std::cmp::Reverse; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::env; use std::time::{Duration, Instant}; use anyhow::{anyhow, ensure, Result}; use num_complex::{Complex32, Complex64}; -use tenferro::eager_einsum::{eager_einsum, eager_einsum_owned}; -use tenferro::{DType, Tensor as NativeTensor, TensorBackend}; +use omeco::ScoreFunction; +use tenferro::traced_tensor::{einsum_subscripts_with, EinsumOptimize}; +use tenferro::{ + DType, EinsumSubscripts, Tensor as NativeTensor, TensorBackend, TensorRead, TensorView, + TracedTensor, +}; +use tenferro_einsum::{ContractionOptimizerOptions, ContractionTree, Subscripts}; use crate::any_scalar::promote_scalar_native; -use crate::context::with_default_backend; +use crate::context::{ + default_engine_buffer_pool_stats, reset_default_engine, reset_default_engine_buffer_pool, + with_default_backend, with_default_engine, +}; +use crate::memory::release_process_allocator_cached_memory; use crate::storage::Storage; #[cfg(test)] use crate::storage::StorageRepr; use crate::tensor_element::TensorElement; use crate::AnyScalar; +/// Read-only native tensor input that can either borrow external payload data +/// or own a temporary materialized tensor. +pub enum NativeTensorReadInput<'a> { + /// Borrowed read-only tensor input. + Borrowed(TensorRead<'a>), + /// Owned temporary tensor input. + Owned(NativeTensor), +} + +impl<'a> NativeTensorReadInput<'a> { + /// Return this input as a read-only tenferro tensor input. 
+ pub fn as_read(&'a self) -> TensorRead<'a> { + match self { + Self::Borrowed(read) => *read, + Self::Owned(tensor) => TensorRead::from_tensor(tensor), + } + } + + /// Return the scalar dtype of this input. + pub fn dtype(&self) -> DType { + match self { + Self::Borrowed(read) => read.dtype(), + Self::Owned(tensor) => tensor.dtype(), + } + } + + /// Return the tensor shape of this input. + pub fn shape(&self) -> &[usize] { + match self { + Self::Borrowed(read) => read.shape(), + Self::Owned(tensor) => tensor.shape(), + } + } +} + #[cfg(test)] use std::cell::Cell; #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] enum NativeEinsumPath { + Owned, Borrowed, BorrowedWithConversions, } @@ -51,6 +96,8 @@ struct NativeEinsumProfileEntry { thread_local! { static NATIVE_EINSUM_PROFILE_STATE: RefCell> = RefCell::new(HashMap::new()); + static NATIVE_EINSUM_TRACE_STATE: RefCell> = + RefCell::new(HashSet::new()); } #[cfg(test)] @@ -66,21 +113,65 @@ fn native_einsum_profile_enabled() -> bool { env::var("T4A_PROFILE_NATIVE_EINSUM").is_ok() } +fn native_einsum_path_trace_enabled() -> bool { + env::var("T4A_TRACE_NATIVE_EINSUM_PATHS").is_ok() +} + +fn native_einsum_path_trace_min_bytes() -> usize { + env::var("T4A_TRACE_NATIVE_EINSUM_MIN_BYTES") + .ok() + .and_then(|value| value.parse().ok()) + .unwrap_or(0) +} + +fn native_einsum_path_trace_max_signatures() -> usize { + env::var("T4A_TRACE_NATIVE_EINSUM_MAX_SIGNATURES") + .ok() + .and_then(|value| value.parse().ok()) + .unwrap_or(64) +} + +fn native_einsum_pool_trace_enabled() -> bool { + env::var("T4A_TRACE_NATIVE_EINSUM_POOL").is_ok() +} + +fn native_einsum_pool_trace_min_output_bytes() -> usize { + env::var("T4A_TRACE_NATIVE_EINSUM_POOL_MIN_OUTPUT_BYTES") + .ok() + .and_then(|value| value.parse().ok()) + .unwrap_or(0) +} + +fn native_einsum_pool_trace_min_retained_bytes() -> usize { + env::var("T4A_TRACE_NATIVE_EINSUM_POOL_MIN_RETAINED_BYTES") + .ok() + .and_then(|value| value.parse().ok()) + .unwrap_or(0) +} + +fn 
reset_native_einsum_engine_after_call() -> bool { + env::var("T4A_RESET_NATIVE_EINSUM_ENGINE_AFTER_CALL").is_ok() +} + +fn reset_native_einsum_buffer_pool_after_call() -> bool { + env::var("T4A_RESET_NATIVE_EINSUM_BUFFER_POOL_AFTER_CALL").is_ok() +} + +fn release_allocator_after_native_einsum_call() -> bool { + env::var("T4A_RELEASE_ALLOCATOR_AFTER_NATIVE_EINSUM_CALL").is_ok() +} + #[cfg(test)] pub(crate) fn set_native_einsum_profile_enabled_for_tests(enabled: bool) { FORCE_NATIVE_EINSUM_PROFILE.with(|slot| slot.set(enabled)); } -fn record_native_einsum_profile( +fn native_einsum_signature( path: NativeEinsumPath, operands: &[(&NativeTensor, &[usize])], output_ids: &[u32], - elapsed: Duration, -) { - if !native_einsum_profile_enabled() { - return; - } - let signature = NativeEinsumSignature { +) -> NativeEinsumSignature { + NativeEinsumSignature { path, operands: operands .iter() @@ -91,7 +182,19 @@ fn record_native_einsum_profile( }) .collect(), output_ids: output_ids.to_vec(), - }; + } +} + +fn record_native_einsum_profile( + path: NativeEinsumPath, + operands: &[(&NativeTensor, &[usize])], + output_ids: &[u32], + elapsed: Duration, +) { + if !native_einsum_profile_enabled() { + return; + } + let signature = native_einsum_signature(path, operands, output_ids); NATIVE_EINSUM_PROFILE_STATE.with(|state| { let mut state = state.borrow_mut(); let entry = state.entry(signature).or_default(); @@ -100,9 +203,254 @@ fn record_native_einsum_profile( }); } +fn dtype_size_bytes(dtype: DType) -> usize { + match dtype { + DType::F32 => 4, + DType::F64 => 8, + DType::C32 => 8, + DType::C64 => 16, + DType::I64 => 8, + } +} + +fn native_tensor_bytes(tensor: &NativeTensor) -> usize { + tensor + .shape() + .iter() + .copied() + .fold(1usize, usize::saturating_mul) + .saturating_mul(dtype_size_bytes(tensor.dtype())) +} + +fn format_label(label: u32) -> String { + char::from_u32(label).map_or_else(|| label.to_string(), |label| label.to_string()) +} + +fn format_labels(labels: &[u32]) 
-> String { + if labels.is_empty() { + "scalar".to_string() + } else { + labels + .iter() + .map(|&label| format_label(label)) + .collect::>() + .join("") + } +} + +fn label_dims(subscripts: &Subscripts, shapes: &[Vec]) -> Result> { + let mut dims = HashMap::new(); + for (labels, shape) in subscripts.inputs.iter().zip(shapes.iter()) { + ensure!( + labels.len() == shape.len(), + "einsum labels {:?} do not match shape {:?}", + labels, + shape + ); + for (&label, &dim) in labels.iter().zip(shape.iter()) { + if let Some(previous) = dims.insert(label, dim) { + ensure!( + previous == dim, + "inconsistent dimension for einsum label {}: {} vs {}", + format_label(label), + previous, + dim + ); + } + } + } + Ok(dims) +} + +fn labels_size(labels: &[u32], dims: &HashMap) -> usize { + labels.iter().fold(1usize, |size, label| { + size.saturating_mul(dims.get(label).copied().unwrap_or(1)) + }) +} + +fn union_labels(lhs: &[u32], rhs: &[u32]) -> Vec { + let mut seen = HashSet::new(); + let mut labels = Vec::new(); + for &label in lhs.iter().chain(rhs.iter()) { + if seen.insert(label) { + labels.push(label); + } + } + labels +} + +#[derive(Debug)] +struct NativeEinsumPlanReport { + lines: Vec, + peak_intermediate_bytes: usize, +} + +fn time_optimized_contraction_options() -> ContractionOptimizerOptions { + ContractionOptimizerOptions { + score: ScoreFunction::time_optimized(), + ..ContractionOptimizerOptions::default() + } +} + +fn native_einsum_plan_report_with_options( + signature: &NativeEinsumSignature, + optimizer_name: &'static str, + options: &ContractionOptimizerOptions, +) -> Result { + let input_ids = signature + .operands + .iter() + .map(|operand| operand.ids.as_slice()) + .collect::>(); + let subscripts_string = build_einsum_subscripts(&input_ids, &signature.output_ids)?; + let subscripts = Subscripts { + inputs: input_ids.iter().map(|ids| ids.to_vec()).collect(), + output: signature.output_ids.clone(), + }; + let shapes = signature + .operands + .iter() + 
.map(|operand| operand.shape.clone()) + .collect::>(); + let shape_refs = shapes.iter().map(Vec::as_slice).collect::>(); + let tree = ContractionTree::optimize_with_options(&subscripts, &shape_refs, options) + .map_err(|e| anyhow!("failed to optimize native einsum path: {e}"))?; + let dims = label_dims(&subscripts, &shapes)?; + let dtype = signature + .operands + .first() + .map(|operand| operand.dtype) + .unwrap_or(DType::F64); + let dtype_size = dtype_size_bytes(dtype); + + let mut lines = Vec::new(); + lines.push(format!( + "optimizer={optimizer_name} subscripts={subscripts_string} dtype={dtype:?} steps={}", + tree.step_count() + )); + let mut peak_intermediate_elems = 1usize; + for step in 0..tree.step_count() { + let Some((left, right)) = tree.step_pair(step) else { + continue; + }; + let Some((lhs, rhs, out)) = tree.step_subscripts(step) else { + continue; + }; + let lhs_elems = labels_size(lhs, &dims); + let rhs_elems = labels_size(rhs, &dims); + let out_elems = labels_size(out, &dims); + let flop_index_elems = labels_size(&union_labels(lhs, rhs), &dims); + peak_intermediate_elems = peak_intermediate_elems.max(out_elems); + lines.push(format!( + " step {step:02}: pair=({left},{right}) {}[{}] x {}[{}] -> {}[{}] flop_index={} intermediate={} elems ({:.3} MiB)", + format_labels(lhs), + lhs_elems, + format_labels(rhs), + rhs_elems, + format_labels(out), + out_elems, + flop_index_elems, + out_elems, + out_elems as f64 * dtype_size as f64 / (1024.0 * 1024.0), + )); + } + let peak_intermediate_bytes = peak_intermediate_elems.saturating_mul(dtype_size); + lines.push(format!( + " peak_intermediate={} elems ({:.3} MiB)", + peak_intermediate_elems, + peak_intermediate_bytes as f64 / (1024.0 * 1024.0) + )); + + Ok(NativeEinsumPlanReport { + lines, + peak_intermediate_bytes, + }) +} + +fn native_einsum_time_optimized_plan_report( + signature: &NativeEinsumSignature, +) -> Result { + native_einsum_plan_report_with_options( + signature, + "time_optimized", + 
&time_optimized_contraction_options(), + ) +} + +fn native_einsum_balanced_plan_report( + signature: &NativeEinsumSignature, +) -> Result { + native_einsum_plan_report_with_options( + signature, + "balanced_default", + &ContractionOptimizerOptions::default(), + ) +} + +fn maybe_trace_native_einsum_path( + path: NativeEinsumPath, + operands: &[(&NativeTensor, &[usize])], + output_ids: &[u32], +) { + if !native_einsum_path_trace_enabled() { + return; + } + let signature = native_einsum_signature(path, operands, output_ids); + let report = match native_einsum_time_optimized_plan_report(&signature) { + Ok(report) if report.peak_intermediate_bytes >= native_einsum_path_trace_min_bytes() => { + report + } + Ok(_) => return, + Err(err) => { + eprintln!("native_einsum path trace failed: {err:#}"); + return; + } + }; + + let max_signatures = native_einsum_path_trace_max_signatures(); + let should_trace = NATIVE_EINSUM_TRACE_STATE.with(|state| { + let mut state = state.borrow_mut(); + if state.len() >= max_signatures || state.contains(&signature) { + false + } else { + state.insert(signature.clone()); + true + } + }); + if !should_trace { + return; + } + + eprintln!("=== native_einsum Path Trace ==="); + eprintln!( + "path={:?} output_ids={:?}", + signature.path, signature.output_ids + ); + for operand in &signature.operands { + eprintln!( + " operand shape={:?} ids={:?} dtype={:?}", + operand.shape, operand.ids, operand.dtype + ); + } + for line in report.lines { + eprintln!("{line}"); + } + if env::var("T4A_TRACE_NATIVE_EINSUM_COMPARE_BALANCED").is_ok() { + match native_einsum_balanced_plan_report(&signature) { + Ok(balanced) => { + for line in balanced.lines { + eprintln!("{line}"); + } + } + Err(err) => eprintln!("balanced native_einsum path trace failed: {err:#}"), + } + } +} + /// Reset the aggregated native einsum profile. 
pub fn reset_native_einsum_profile() { NATIVE_EINSUM_PROFILE_STATE.with(|state| state.borrow_mut().clear()); + NATIVE_EINSUM_TRACE_STATE.with(|state| state.borrow_mut().clear()); } /// Print and clear the aggregated native einsum profile. @@ -129,12 +477,20 @@ pub fn print_and_reset_native_einsum_profile() { entry.total_time.as_secs_f64() * 1e6 / entry.calls as f64, signature.output_ids, ); - for operand in signature.operands { + for operand in &signature.operands { eprintln!( " shape={:?} ids={:?} dtype={:?}", operand.shape, operand.ids, operand.dtype ); } + match native_einsum_time_optimized_plan_report(&signature) { + Ok(report) => { + for line in report.lines { + eprintln!(" {line}"); + } + } + Err(err) => eprintln!(" path report failed: {err:#}"), + } } }); } @@ -189,6 +545,115 @@ fn build_einsum_subscripts(operands: &[&[u32]], output_ids: &[u32]) -> Result Result { + let placeholders = inputs + .iter() + .map(|tensor| TracedTensor::input_concrete_shape(tensor.dtype(), tensor.shape())) + .collect::>(); + let placeholder_refs = placeholders.iter().collect::>(); + let bindings = placeholders + .iter() + .zip(inputs.iter()) + .map(|(placeholder, tensor)| (placeholder, *tensor)) + .collect::>(); + + let trace_pool = native_einsum_pool_trace_enabled(); + let pool_before = trace_pool.then(default_engine_buffer_pool_stats); + let result = with_default_engine(|engine| { + let mut result = einsum_subscripts_with( + engine, + &placeholder_refs, + subscripts, + EinsumOptimize::default(), + ) + .map_err(|e| anyhow!("native einsum failed: {e}"))?; + result + .eval_with_inputs(engine, &bindings) + .cloned() + .map_err(|e| anyhow!("native einsum failed: {e}")) + })?; + if trace_pool { + let pool_after = default_engine_buffer_pool_stats(); + let output_bytes = native_tensor_bytes(&result); + let retained_threshold = native_einsum_pool_trace_min_retained_bytes(); + if pool_after != pool_before.unwrap_or_default() + && pool_after.capacity_bytes >= retained_threshold + || 
output_bytes >= native_einsum_pool_trace_min_output_bytes() + { + let before = pool_before.unwrap_or_default(); + eprintln!( + "native_einsum pool subscripts={subscripts:?} before_buffers={} before_capacity={:.3} MiB after_buffers={} after_capacity={:.3} MiB output_shape={:?} output_bytes={:.3} MiB", + before.buffers, + before.capacity_bytes as f64 / (1024.0 * 1024.0), + pool_after.buffers, + pool_after.capacity_bytes as f64 / (1024.0 * 1024.0), + result.shape(), + output_bytes as f64 / (1024.0 * 1024.0), + ); + } + } + if reset_native_einsum_engine_after_call() { + let before_reset = trace_pool.then(default_engine_buffer_pool_stats); + reset_default_engine(); + if trace_pool + && before_reset.unwrap_or_default().capacity_bytes + >= native_einsum_pool_trace_min_retained_bytes() + { + let before = before_reset.unwrap_or_default(); + let after = default_engine_buffer_pool_stats(); + eprintln!( + "native_einsum engine_reset before_buffers={} before_capacity={:.3} MiB after_buffers={} after_capacity={:.3} MiB", + before.buffers, + before.capacity_bytes as f64 / (1024.0 * 1024.0), + after.buffers, + after.capacity_bytes as f64 / (1024.0 * 1024.0), + ); + } + } else if reset_native_einsum_buffer_pool_after_call() { + let before_clear = trace_pool.then(default_engine_buffer_pool_stats); + reset_default_engine_buffer_pool(); + if trace_pool + && before_clear.unwrap_or_default().capacity_bytes + >= native_einsum_pool_trace_min_retained_bytes() + { + let before = before_clear.unwrap_or_default(); + let after = default_engine_buffer_pool_stats(); + eprintln!( + "native_einsum pool_reset before_buffers={} before_capacity={:.3} MiB after_buffers={} after_capacity={:.3} MiB", + before.buffers, + before.capacity_bytes as f64 / (1024.0 * 1024.0), + after.buffers, + after.capacity_bytes as f64 / (1024.0 * 1024.0), + ); + } + } + if release_allocator_after_native_einsum_call() { + let report = release_process_allocator_cached_memory(); + if trace_pool && 
(report.released_bytes.unwrap_or(0) > 0 || report.success == Some(true)) { + eprintln!( + "native_einsum allocator_pressure_relief supported={} released_bytes={:?} success={:?}", + report.supported, + report.released_bytes, + report.success, + ); + } + } + Ok(result) +} + +fn cached_einsum_native_reads( + inputs: &[TensorRead<'_>], + subscripts: &Subscripts, +) -> Result { + with_default_backend(|backend| { + tenferro_einsum::eager_einsum_read_subscripts(backend, inputs, subscripts) + .map_err(|e| anyhow!("native read einsum failed: {e}")) + }) +} + /// Build native einsum ids for a binary contraction. pub(crate) fn build_binary_einsum_ids( lhs_rank: usize, @@ -290,6 +755,46 @@ pub fn storage_to_native_tensor(storage: &Storage, logical_dims: &[usize]) -> Re } } +/// Build a read-only native tensor input over the compact storage payload. +/// +/// Contiguous payloads are borrowed without copying. Non-contiguous payloads +/// are materialized into an owned native tensor. +pub fn storage_payload_native_read_input(storage: &Storage) -> Result> { + if storage.is_f64() { + if let Some(view) = storage + .payload_f64_col_major_view_if_contiguous() + .map_err(anyhow::Error::msg)? + { + return Ok(NativeTensorReadInput::Borrowed(TensorRead::from_view( + TensorView::f64(storage.payload_dims(), view)?, + ))); + } + Ok(NativeTensorReadInput::Owned(NativeTensor::from_vec( + storage.payload_dims().to_vec(), + storage + .payload_f64_col_major_vec() + .map_err(anyhow::Error::msg)?, + ))) + } else if storage.is_c64() { + if let Some(view) = storage + .payload_c64_col_major_view_if_contiguous() + .map_err(anyhow::Error::msg)? 
+ { + return Ok(NativeTensorReadInput::Borrowed(TensorRead::from_view( + TensorView::c64(storage.payload_dims(), view)?, + ))); + } + Ok(NativeTensorReadInput::Owned(NativeTensor::from_vec( + storage.payload_dims().to_vec(), + storage + .payload_c64_col_major_vec() + .map_err(anyhow::Error::msg)?, + ))) + } else { + Err(anyhow!("unsupported storage scalar type")) + } +} + /// Materialize a native tensor into dense storage. pub fn native_tensor_primal_to_storage(tensor: &NativeTensor) -> Result { match tensor.dtype() { @@ -636,11 +1141,12 @@ pub fn axpby_native_tensor( } } -/// Execute an eager einsum over owned native tensors. +/// Execute a cached einsum over owned native tensors. /// /// This is the consuming bridge used by higher-level owned contraction APIs. -/// Inputs are promoted to a common dtype before the owned tenferro eager -/// einsum runs. +/// Inputs are promoted to a common dtype before tenferro evaluates the +/// contraction. Repeated calls with the same equation and shapes reuse +/// tenferro's process-global contraction path cache. /// /// # Arguments /// * `operands` - Native tensors paired with numeric einsum labels for each axis. 
@@ -702,20 +1208,37 @@ pub fn einsum_native_tensors_owned( let input_slices = input_ids.iter().map(Vec::as_slice).collect::>(); let output_ids_u32 = output_ids.iter().map(|&id| id as u32).collect::>(); - let subscripts = build_einsum_subscripts(&input_slices, &output_ids_u32)?; + let subscripts = EinsumSubscripts::new(&input_slices, &output_ids_u32); - let result = - with_default_backend(|backend| eager_einsum_owned(backend, converted, &subscripts)) - .map_err(|e| anyhow!("native einsum failed: {e}"))?; + let input_refs = converted.iter().collect::>(); + let trace_ids = input_ids + .iter() + .map(|ids| ids.iter().map(|&id| id as usize).collect::>()) + .collect::>(); + let trace_operands = input_refs + .iter() + .zip(trace_ids.iter()) + .map(|(tensor, ids)| (*tensor, ids.as_slice())) + .collect::>(); + maybe_trace_native_einsum_path(NativeEinsumPath::Owned, &trace_operands, &output_ids_u32); + let started = Instant::now(); + let result = cached_einsum_native_tensors(&input_refs, &subscripts)?; + record_native_einsum_profile( + NativeEinsumPath::Owned, + &trace_operands, + &output_ids_u32, + started.elapsed(), + ); Ok(result) } -/// Execute an eager einsum over borrowed native tensors. +/// Execute a cached einsum over borrowed native tensors. /// /// Inputs are promoted to a common dtype before contraction. Operands that /// already have the target dtype are passed to the backend by reference; /// operands with another dtype are converted into temporary native tensors and -/// then borrowed for the contraction. +/// then borrowed for the contraction. Repeated calls with the same equation +/// and shapes reuse tenferro's process-global contraction path cache. /// /// # Arguments /// * `operands` - Native tensors paired with numeric einsum labels for each axis. 
@@ -780,27 +1303,82 @@ pub fn einsum_native_tensors( let input_slices = input_ids.iter().map(Vec::as_slice).collect::>(); let output_ids_u32 = output_ids.iter().map(|&id| id as u32).collect::>(); - let subscripts = build_einsum_subscripts(&input_slices, &output_ids_u32)?; + let subscripts = EinsumSubscripts::new(&input_slices, &output_ids_u32); let input_refs = operands .iter() .zip(converted.iter()) .map(|((tensor, _), converted)| converted.as_ref().unwrap_or(*tensor)) .collect::>(); - let result = with_default_backend(|backend| eager_einsum(backend, &input_refs, &subscripts)) - .map_err(|e| anyhow!("native einsum failed: {e}"))?; - record_native_einsum_profile( - if has_conversions { - NativeEinsumPath::BorrowedWithConversions - } else { - NativeEinsumPath::Borrowed - }, - operands, - &output_ids_u32, - started.elapsed(), - ); + let trace_path = if has_conversions { + NativeEinsumPath::BorrowedWithConversions + } else { + NativeEinsumPath::Borrowed + }; + maybe_trace_native_einsum_path(trace_path, operands, &output_ids_u32); + let result = cached_einsum_native_tensors(&input_refs, &subscripts)?; + record_native_einsum_profile(trace_path, operands, &output_ids_u32, started.elapsed()); Ok(result) } +/// Execute a cached einsum over read-only native tensor inputs. +/// +/// Backends may consume borrowed host views directly or materialize/upload them +/// inside their execution session. Mixed dtypes are promoted by materializing +/// only the operands that require conversion. 
+pub fn einsum_native_tensor_reads( + operands: &[(&NativeTensorReadInput<'_>, &[usize])], + output_ids: &[usize], +) -> Result { + ensure!( + !operands.is_empty(), + "native einsum requires at least one operand" + ); + + let target = common_dtype( + &operands + .iter() + .map(|(tensor, _)| tensor.dtype()) + .collect::>(), + ); + let mut converted = Vec::with_capacity(operands.len()); + let mut input_ids = Vec::with_capacity(operands.len()); + let mut read_inputs = Vec::with_capacity(operands.len()); + + for (tensor, ids) in operands { + ensure!( + tensor.shape().len() == ids.len(), + "einsum id list {:?} does not match tensor shape {:?}", + ids, + tensor.shape() + ); + input_ids.push(ids.iter().map(|&id| id as u32).collect::>()); + if tensor.dtype() == target { + converted.push(None); + } else { + converted.push(Some(convert_tensor(&tensor.as_read().to_tensor(), target)?)); + } + } + + for (tensor, converted) in operands + .iter() + .map(|(tensor, _)| *tensor) + .zip(converted.iter()) + { + if let Some(converted) = converted { + read_inputs.push(TensorRead::from_tensor(converted)); + } else { + read_inputs.push(tensor.as_read()); + } + } + + let output_ids_u32 = output_ids.iter().map(|&id| id as u32).collect::>(); + let subscripts = Subscripts { + inputs: input_ids, + output: output_ids_u32, + }; + cached_einsum_native_reads(&read_inputs, &subscripts) +} + /// Permute axes of a native tensor. 
pub fn permute_native_tensor(tensor: &NativeTensor, perm: &[usize]) -> Result { with_default_backend(|backend| tensor.transpose(perm, backend)) diff --git a/crates/tensor4all-tensorbackend/src/tenferro_bridge/tests/mod.rs b/crates/tensor4all-tensorbackend/src/tenferro_bridge/tests/mod.rs index f4b005cb..16440446 100644 --- a/crates/tensor4all-tensorbackend/src/tenferro_bridge/tests/mod.rs +++ b/crates/tensor4all-tensorbackend/src/tenferro_bridge/tests/mod.rs @@ -36,6 +36,16 @@ fn recorded_native_einsum_call_count(path: NativeEinsumPath) -> usize { }) } +fn default_engine_contains_einsum_subscripts_key( + inputs: &[&[u32]], + output: &[u32], + shapes: Vec>, +) -> bool { + crate::context::with_default_engine(|engine| { + engine.einsum_cache_contains_subscripts(&(EinsumSubscripts::new(inputs, output), shapes)) + }) +} + struct ProfileGuard; impl ProfileGuard { @@ -281,6 +291,27 @@ fn einsum_native_tensors_supports_retained_shared_nary_label() { assert_eq!(values[0], 55.0); } +#[test] +fn einsum_native_tensors_populates_process_global_path_cache() { + let a = NativeTensor::from_vec(vec![2, 3, 4], vec![1.0_f64; 24]); + let b = NativeTensor::from_vec(vec![4, 5], vec![2.0_f64; 20]); + let c = NativeTensor::from_vec(vec![3, 2], vec![3.0_f64; 6]); + + let out = + einsum_native_tensors(&[(&a, &[0, 1, 2]), (&b, &[2, 3]), (&c, &[1, 0])], &[3]).unwrap(); + + assert_eq!(out.shape(), &[5]); + assert_eq!( + native_tensor_primal_to_dense_f64_col_major(&out).unwrap(), + vec![144.0; 5] + ); + assert!(default_engine_contains_einsum_subscripts_key( + &[&[0, 1, 2], &[2, 3], &[1, 0]], + &[3], + vec![vec![2, 3, 4], vec![4, 5], vec![3, 2]] + )); +} + #[test] fn einsum_native_tensors_mixed_dtype_records_borrowed_conversion_profile() { let _guard = ProfileGuard::enable(); @@ -336,6 +367,112 @@ fn einsum_native_tensors_dense_binary_records_borrowed_profile() { ); } +#[test] +fn native_read_input_owned_and_plan_helpers_cover_debug_paths() { + let tensor = NativeTensor::from_vec(vec![2], 
vec![1.0_f64, 2.0]); + let input = NativeTensorReadInput::Owned(tensor.clone()); + assert_eq!(input.dtype(), DType::F64); + assert_eq!(input.shape(), &[2]); + assert_eq!(input.as_read().shape(), &[2]); + + assert_eq!(native_einsum_path_trace_min_bytes(), 0); + assert_eq!(native_einsum_path_trace_max_signatures(), 64); + assert_eq!(native_einsum_pool_trace_min_output_bytes(), 0); + assert_eq!(native_einsum_pool_trace_min_retained_bytes(), 0); + + assert_eq!(dtype_size_bytes(DType::F32), 4); + assert_eq!(dtype_size_bytes(DType::F64), 8); + assert_eq!(dtype_size_bytes(DType::C32), 8); + assert_eq!(dtype_size_bytes(DType::C64), 16); + assert_eq!(dtype_size_bytes(DType::I64), 8); + assert_eq!(native_tensor_bytes(&tensor), 16); + assert_eq!(format_label('x' as u32), "x"); + assert_eq!(format_label(0x110000), "1114112"); + assert_eq!(format_labels(&[]), "scalar"); + assert_eq!(format_labels(&['a' as u32, 'b' as u32]), "ab"); + + let subscripts = Subscripts { + inputs: vec![vec![0, 1], vec![1, 2]], + output: vec![0, 2], + }; + let dims = label_dims(&subscripts, &[vec![2, 3], vec![3, 4]]).unwrap(); + assert_eq!(labels_size(&[0, 2], &dims), 8); + assert_eq!(union_labels(&[0, 1], &[1, 2]), vec![0, 1, 2]); + + let bad_len = label_dims(&subscripts, &[vec![2], vec![3, 4]]).unwrap_err(); + assert!(bad_len.to_string().contains("do not match shape")); + let bad_dim = label_dims(&subscripts, &[vec![2, 3], vec![4, 4]]).unwrap_err(); + assert!(bad_dim.to_string().contains("inconsistent dimension")); + + let signature = NativeEinsumSignature { + path: NativeEinsumPath::Borrowed, + operands: vec![ + NativeOperandSignature { + shape: vec![2, 3], + ids: vec![0, 1], + dtype: DType::F64, + }, + NativeOperandSignature { + shape: vec![3, 4], + ids: vec![1, 2], + dtype: DType::F64, + }, + ], + output_ids: vec![0, 2], + }; + let time_report = native_einsum_time_optimized_plan_report(&signature).unwrap(); + let balanced_report = native_einsum_balanced_plan_report(&signature).unwrap(); + 
assert!(time_report.peak_intermediate_bytes >= 8); + assert!(!time_report.lines.is_empty()); + assert!(!balanced_report.lines.is_empty()); +} + +#[test] +fn native_einsum_profile_print_and_c32_arithmetic_paths() { + let _guard = ProfileGuard::enable(); + + let lhs = NativeTensor::from_vec( + vec![2], + vec![Complex32::new(1.0, 2.0), Complex32::new(-3.0, 0.5)], + ); + let rhs = NativeTensor::from_vec( + vec![2], + vec![Complex32::new(0.5, -1.0), Complex32::new(4.0, 2.0)], + ); + + let scaled = scale_native_tensor( + &lhs, + &crate::AnyScalar::from_value(Complex32::new(2.0, -1.0)), + ) + .unwrap(); + assert_eq!(scaled.dtype(), DType::C32); + assert_eq!( + scaled.as_slice::().unwrap(), + &[Complex32::new(4.0, 3.0), Complex32::new(-5.5, 4.0)] + ); + + let combined = axpby_native_tensor( + &lhs, + &crate::AnyScalar::from_value(Complex32::new(1.0, 0.0)), + &rhs, + &crate::AnyScalar::from_value(Complex32::new(0.0, 1.0)), + ) + .unwrap(); + assert_eq!(combined.dtype(), DType::C32); + assert_eq!( + combined.as_slice::().unwrap(), + &[Complex32::new(2.0, 2.5), Complex32::new(-5.0, 4.5)] + ); + + let contraction = einsum_native_tensors(&[(&lhs, &[0]), (&rhs, &[0])], &[]).unwrap(); + assert_eq!(contraction.shape(), &[] as &[usize]); + print_and_reset_native_einsum_profile(); + assert_eq!( + recorded_native_einsum_call_count(NativeEinsumPath::Borrowed), + 0 + ); +} + #[test] fn contract_native_tensor_restores_rhs_free_axis_order() { let lhs = storage_to_native_tensor( diff --git a/crates/tensor4all-treetn/docs/linsolve.md b/crates/tensor4all-treetn/docs/linsolve.md index 39eb4a03..1ab3e101 100644 --- a/crates/tensor4all-treetn/docs/linsolve.md +++ b/crates/tensor4all-treetn/docs/linsolve.md @@ -112,13 +112,19 @@ the square linsolve keeps a persistent reference state: Operationally: - `solve_local` constructs `LocalLinOp` with a reference state, so `ProjectedOperator::apply(..., ket_state=state, reference_state=self.reference_state, ...)` uses distinct - ket/reference link 
namespaces and avoids spurious traces when performing a full contraction. ### 4) Local RHS construction: `ProjectedState` The local RHS tensor `b_local` is constructed via environments of the RHS state network: -- `ProjectedState::local_constant_term(region, ket_state, topology)` (2-chain contraction) +- `ProjectedState::local_constant_term(region, reference_state, topology)` (2-chain contraction) + +`reference_state` uses a link-index namespace independent of the RHS/ket network +so `<ref|b>` environments keep the reference-side boundary links open instead +of accidentally tracing over bra/ket boundary bonds. `solve_local` then maps +those open reference boundary bonds back to the current state bonds before +calling GMRES. Relevant file: - `crates/tensor4all-treetn/src/linsolve/square/projected_state.rs` @@ -175,13 +181,13 @@ Source: ## Known pitfall: unintended contraction due to shared index IDs The contraction engine treats indices as contractable if they share the same ID (plus compatible -`ConjState`, dimensions). When bra/ket share the same index IDs, `AllowedPairs::All` can contract +`ConjState`, dimensions). When bra/ket share the same index IDs, a full contraction can contract more aggressively than intended if the implementation does not explicitly separate bra/ket index namespaces or restrict contraction pairs. This is the direction for "root-cause" fixes: - separate bra/ket index identities (e.g. via priming / directed indices), and/or -- constrain contractions (e.g. `AllowedPairs::Specified`) so environments do not introduce +- use TreeTN-level topology-aware contractions so environments do not introduce unintended traces. 
Related reproducer examples: @@ -201,9 +207,9 @@ apply_local_update_sweep(state: TreeTN, plan, updater) │ ├─ contract_region(subtree, step.nodes) -> init_local: T │ ├─ solve_local(region, init_local, state=full_treetn) -> solved_local: T │ │ ├─ ProjectedState::local_constant_term(...) -> rhs_local: T - │ │ ├─ linop = LocalLinOp::new(projected_operator, region, state.clone(), reference_state.clone(), a0, a1) - │ │ └─ gmres(apply_a, rhs_local, init_local, gmres_options) - │ │ └─ apply_a(x_local: &T) = LocalLinOp::apply(x_local) + │ │ ├─ linop = LocalLinOp::new(projected_operator, region, state, reference_state) + │ │ └─ gmres_affine(apply_a, rhs_local, init_local, a0, a1, gmres_options) + │ │ └─ apply_a(x_local: &T) = LocalLinOp::apply_projected(x_local) │ │ ├─ ProjectedOperator::apply(v=x_local, region, ket_state=state, reference_state, topology) │ │ │ ├─ ensure_environments(...) │ │ │ │ └─ compute_environment(from, to, ket_state, reference_state, topology) // recursive diff --git a/crates/tensor4all-treetn/examples/bench_swap_interleave_r45.rs b/crates/tensor4all-treetn/examples/bench_swap_interleave_r45.rs index fabbfbc9..337e081b 100644 --- a/crates/tensor4all-treetn/examples/bench_swap_interleave_r45.rs +++ b/crates/tensor4all-treetn/examples/bench_swap_interleave_r45.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use rand_chacha::rand_core::{RngCore, SeedableRng}; -use tensor4all_core::{DynIndex, IndexLike, TensorDynLen, TensorLike}; +use tensor4all_core::{DynIndex, IndexLike, TensorDynLen}; use tensor4all_treetn::{SwapOptions, TreeTN}; fn build_tn( diff --git a/crates/tensor4all-treetn/examples/benchmark_linsolve.rs b/crates/tensor4all-treetn/examples/benchmark_linsolve.rs index 1cedc5dc..c53e2467 100644 --- a/crates/tensor4all-treetn/examples/benchmark_linsolve.rs +++ b/crates/tensor4all-treetn/examples/benchmark_linsolve.rs @@ -14,7 +14,9 @@ use std::time::{Duration, Instant}; use rand::rngs::StdRng; use rand::SeedableRng; -use 
tensor4all_core::{index::DynId, DynIndex, IndexLike, TensorDynLen, TensorIndex, TensorLike}; +use tensor4all_core::{ + index::DynId, DynIndex, IndexLike, TensorContractionLike, TensorDynLen, TensorIndex, +}; use tensor4all_treetn::{ apply_linear_operator, apply_local_update_sweep, random_treetn, ApplyOptions, CanonicalForm, CanonicalizationOptions, IndexMapping, LinearOperator, LinkSpace, LinsolveOptions, @@ -187,9 +189,9 @@ fn main() -> anyhow::Result<()> { let a0 = 0.0_f64; let a1 = 1.0_f64; - let krylov_tol = 1e-6_f64; - let krylov_maxiter = 20usize; - let krylov_dim = 30usize; + let gmres_tol = 1e-6_f64; + let gmres_max_restarts = 20usize; + let gmres_restart_dim = 30usize; let n_runs = 10usize; @@ -201,7 +203,7 @@ fn main() -> anyhow::Result<()> { println!("max_rank = {max_rank}"); println!("cutoff = {cutoff}"); println!("rtol = sqrt(cutoff) = {rtol:.6}"); - println!("GMRES: tol={krylov_tol}, maxiter={krylov_maxiter}, krylov_dim={krylov_dim}"); + println!("GMRES: tol={gmres_tol}, maxiter={gmres_max_restarts}, gmres_restart_dim={gmres_restart_dim}"); println!("coefficients: a0={a0}, a1={a1}"); println!("seed = {seed} (used for random MPO)"); println!("n_runs = {n_runs} (excluding warmup)"); @@ -233,9 +235,9 @@ fn main() -> anyhow::Result<()> { let options = LinsolveOptions::default() .with_nfullsweeps(n_sweeps) .with_truncation(truncation) - .with_krylov_tol(krylov_tol) - .with_krylov_maxiter(krylov_maxiter) - .with_krylov_dim(krylov_dim) + .with_gmres_tol(gmres_tol) + .with_gmres_max_restarts(gmres_max_restarts) + .with_gmres_restart_dim(gmres_restart_dim) .with_coefficients(a0, a1); let compute_rel_residual = |x: &TreeTN| -> anyhow::Result { diff --git a/crates/tensor4all-treetn/examples/benchmark_linsolve_mpo.rs b/crates/tensor4all-treetn/examples/benchmark_linsolve_mpo.rs index e53720fe..fe6cb548 100644 --- a/crates/tensor4all-treetn/examples/benchmark_linsolve_mpo.rs +++ b/crates/tensor4all-treetn/examples/benchmark_linsolve_mpo.rs @@ -19,7 +19,7 @@ use 
rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use tensor4all_core::{ index::{DynId, Index}, - DynIndex, IndexLike, TensorDynLen, TensorIndex, TensorLike, + DynIndex, IndexLike, TensorContractionLike, TensorDynLen, TensorIndex, }; use tensor4all_treetn::{ apply_linear_operator, apply_local_update_sweep, ApplyOptions, CanonicalForm, @@ -269,9 +269,9 @@ fn main() -> anyhow::Result<()> { let a0 = 0.0_f64; let a1 = 1.0_f64; - let krylov_tol = 1e-6_f64; - let krylov_maxiter = 50usize; - let krylov_dim = 30usize; + let gmres_tol = 1e-6_f64; + let gmres_max_restarts = 50usize; + let gmres_restart_dim = 30usize; let n_runs = 10usize; @@ -285,7 +285,7 @@ fn main() -> anyhow::Result<()> { println!("n_sweeps = {n_sweeps}"); println!("cutoff = {cutoff}"); println!("rtol = sqrt(cutoff) = {rtol:.6}"); - println!("GMRES: tol={krylov_tol}, maxiter={krylov_maxiter}, krylov_dim={krylov_dim}"); + println!("GMRES: tol={gmres_tol}, maxiter={gmres_max_restarts}, gmres_restart_dim={gmres_restart_dim}"); println!("coefficients: a0={a0}, a1={a1}"); println!("seed = {seed} (used for random operator MPO A0)"); println!("n_runs = {n_runs} (excluding warmup)"); @@ -343,9 +343,9 @@ fn main() -> anyhow::Result<()> { let options = LinsolveOptions::default() .with_nfullsweeps(n_sweeps) .with_truncation(truncation) - .with_krylov_tol(krylov_tol) - .with_krylov_maxiter(krylov_maxiter) - .with_krylov_dim(krylov_dim) + .with_gmres_tol(gmres_tol) + .with_gmres_max_restarts(gmres_max_restarts) + .with_gmres_restart_dim(gmres_restart_dim) .with_coefficients(a0, a1); let compute_rel_residual = |x: &TreeTN| -> anyhow::Result { diff --git a/crates/tensor4all-treetn/examples/benchmark_local_linsolve.rs b/crates/tensor4all-treetn/examples/benchmark_local_linsolve.rs new file mode 100644 index 00000000..a8fe686a --- /dev/null +++ b/crates/tensor4all-treetn/examples/benchmark_local_linsolve.rs @@ -0,0 +1 @@ +include!("../../../benchmarks/rust/benchmark_local_linsolve.rs"); diff --git 
a/crates/tensor4all-treetn/examples/benchmark_projected_apply.rs b/crates/tensor4all-treetn/examples/benchmark_projected_apply.rs new file mode 100644 index 00000000..680f92ef --- /dev/null +++ b/crates/tensor4all-treetn/examples/benchmark_projected_apply.rs @@ -0,0 +1 @@ +include!("../../../benchmarks/rust/benchmark_projected_apply.rs"); diff --git a/crates/tensor4all-treetn/examples/compare_mps_mpo_pauli.rs b/crates/tensor4all-treetn/examples/compare_mps_mpo_pauli.rs index f027bab0..55a522d8 100644 --- a/crates/tensor4all-treetn/examples/compare_mps_mpo_pauli.rs +++ b/crates/tensor4all-treetn/examples/compare_mps_mpo_pauli.rs @@ -11,7 +11,9 @@ use std::collections::{HashMap, HashSet}; use std::time::Instant; -use tensor4all_core::{index::DynId, DynIndex, IndexLike, TensorDynLen, TensorIndex, TensorLike}; +use tensor4all_core::{ + index::DynId, DynIndex, IndexLike, TensorContractionLike, TensorDynLen, TensorIndex, +}; use tensor4all_treetn::{ apply_linear_operator, apply_local_update_sweep, ApplyOptions, CanonicalForm, CanonicalizationOptions, IndexMapping, LinearOperator, LinsolveOptions, LocalUpdateSweepPlan, @@ -358,9 +360,9 @@ fn main() -> anyhow::Result<()> { let rtol = cutoff.sqrt(); let a0 = 0.0_f64; let a1 = 1.0_f64; - let krylov_tol = 1e-6_f64; - let krylov_maxiter = 20usize; - let krylov_dim = 30usize; + let gmres_tol = 1e-6_f64; + let gmres_max_restarts = 20usize; + let gmres_restart_dim = 30usize; println!("=== Compare MPS vs MPO linsolve times with Pauli-X operator ==="); println!("N = {n}"); @@ -371,7 +373,7 @@ fn main() -> anyhow::Result<()> { println!("cutoff = {cutoff}"); println!("rtol = sqrt(cutoff) = {rtol}"); println!("coefficients: a0 = {a0}, a1 = {a1}"); - println!("GMRES: tol = {krylov_tol}, maxiter = {krylov_maxiter}, krylov_dim = {krylov_dim}"); + println!("GMRES: tol = {gmres_tol}, maxiter = {gmres_max_restarts}, gmres_restart_dim = {gmres_restart_dim}"); println!(); let truncation = TruncationOptions::default() @@ -381,9 +383,9 @@ fn 
main() -> anyhow::Result<()> { let options = LinsolveOptions::default() .with_nfullsweeps(n_sweeps) .with_truncation(truncation) - .with_krylov_tol(krylov_tol) - .with_krylov_maxiter(krylov_maxiter) - .with_krylov_dim(krylov_dim) + .with_gmres_tol(gmres_tol) + .with_gmres_max_restarts(gmres_max_restarts) + .with_gmres_restart_dim(gmres_restart_dim) .with_coefficients(a0, a1); // ========== Test 1: MPS ========== diff --git a/crates/tensor4all-treetn/examples/minimal_linsolve_identity_cases.rs b/crates/tensor4all-treetn/examples/minimal_linsolve_identity_cases.rs index 571595f6..843f8f85 100644 --- a/crates/tensor4all-treetn/examples/minimal_linsolve_identity_cases.rs +++ b/crates/tensor4all-treetn/examples/minimal_linsolve_identity_cases.rs @@ -15,7 +15,7 @@ use std::collections::{HashMap, HashSet}; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; -use tensor4all_core::{index::DynId, DynIndex, IndexLike, TensorDynLen}; +use tensor4all_core::{index::DynId, DynIndex, IndexLike, TensorContractionLike, TensorDynLen}; use tensor4all_treetn::{ apply_local_update_sweep, CanonicalizationOptions, IndexMapping, LinsolveOptions, LocalUpdateStep, LocalUpdateSweepPlan, LocalUpdater, SquareLinsolveUpdater, TreeTN, @@ -245,9 +245,9 @@ fn case_ok_identity_single_1site_step() -> anyhow::Result<()> { let options = LinsolveOptions::default() .with_nfullsweeps(1) .with_max_rank(state_bond_dim) - .with_krylov_tol(1e-6) - .with_krylov_maxiter(20) - .with_krylov_dim(30) + .with_gmres_tol(1e-6) + .with_gmres_max_restarts(20) + .with_gmres_restart_dim(30) .with_coefficients(0.0, 1.0); let mut x = init.canonicalize([center.clone()], CanonicalizationOptions::default())?; @@ -289,9 +289,9 @@ fn case_fail_identity_2site_sweep() -> anyhow::Result<()> { let options = LinsolveOptions::default() .with_nfullsweeps(1) .with_max_rank(state_bond_dim) - .with_krylov_tol(1e-6) - .with_krylov_maxiter(20) - .with_krylov_dim(30) + .with_gmres_tol(1e-6) + .with_gmres_max_restarts(20) + 
.with_gmres_restart_dim(30) .with_coefficients(1.0, 1.0); let seed = 1234u64; diff --git a/crates/tensor4all-treetn/examples/repro_linsolve_random_mpo.rs b/crates/tensor4all-treetn/examples/repro_linsolve_random_mpo.rs index abc58101..1b93e404 100644 --- a/crates/tensor4all-treetn/examples/repro_linsolve_random_mpo.rs +++ b/crates/tensor4all-treetn/examples/repro_linsolve_random_mpo.rs @@ -15,7 +15,7 @@ use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use tensor4all_core::{ index::{DynId, Index}, - DynIndex, IndexLike, TensorDynLen, TensorIndex, TensorLike, + DynIndex, IndexLike, TensorContractionLike, TensorDynLen, TensorIndex, }; use tensor4all_treetn::{ apply_linear_operator, apply_local_update_sweep, ApplyOptions, CanonicalizationOptions, @@ -390,9 +390,9 @@ fn run_linsolve_test( let options = LinsolveOptions::default() .with_nfullsweeps(n_sweeps) .with_max_rank(max_rank) - .with_krylov_tol(1e-8) - .with_krylov_maxiter(50) - .with_krylov_dim(30) + .with_gmres_tol(1e-8) + .with_gmres_max_restarts(50) + .with_gmres_restart_dim(30) .with_coefficients(a0, a1); // Initial guess: x0 = rhs diff --git a/crates/tensor4all-treetn/examples/repro_linsolve_single_run.rs b/crates/tensor4all-treetn/examples/repro_linsolve_single_run.rs index 98c835e4..61e4b9c5 100644 --- a/crates/tensor4all-treetn/examples/repro_linsolve_single_run.rs +++ b/crates/tensor4all-treetn/examples/repro_linsolve_single_run.rs @@ -286,9 +286,9 @@ fn main() -> anyhow::Result<()> { let cutoff = 1e-8_f64; let rtol = cutoff.sqrt(); - let krylov_tol = 1e-6_f64; - let krylov_maxiter = 20usize; - let krylov_dim = 30usize; + let gmres_tol = 1e-6_f64; + let gmres_max_restarts = 20usize; + let gmres_restart_dim = 30usize; println!("=== repro_linsolve_single_run ==="); println!("N = {n_sites}"); @@ -300,7 +300,7 @@ fn main() -> anyhow::Result<()> { println!("max_rank = {max_rank}"); println!("cutoff = {cutoff}"); println!("rtol = sqrt(cutoff) = {rtol:.6}"); - println!("GMRES: tol={krylov_tol}, 
maxiter={krylov_maxiter}, krylov_dim={krylov_dim}"); + println!("GMRES: tol={gmres_tol}, maxiter={gmres_max_restarts}, gmres_restart_dim={gmres_restart_dim}"); println!("coefficients: a0={a0}, a1={a1}"); // This line remains unchanged println!(); @@ -342,9 +342,9 @@ fn main() -> anyhow::Result<()> { let options = LinsolveOptions::default() .with_nfullsweeps(n_sweeps) .with_truncation(truncation) - .with_krylov_tol(krylov_tol) - .with_krylov_maxiter(krylov_maxiter) - .with_krylov_dim(krylov_dim) + .with_gmres_tol(gmres_tol) + .with_gmres_max_restarts(gmres_max_restarts) + .with_gmres_restart_dim(gmres_restart_dim) .with_coefficients(a0, a1); let plan = LocalUpdateSweepPlan::from_treetn(&x, &center, 2) diff --git a/crates/tensor4all-treetn/examples/test_linsolve_general_coefficients_n10.rs b/crates/tensor4all-treetn/examples/test_linsolve_general_coefficients_n10.rs index 9ae923cd..bb69ea51 100644 --- a/crates/tensor4all-treetn/examples/test_linsolve_general_coefficients_n10.rs +++ b/crates/tensor4all-treetn/examples/test_linsolve_general_coefficients_n10.rs @@ -394,7 +394,7 @@ fn run_test_case(a0: f64, a1: f64, init_mode: &str, bond_dim: usize) -> anyhow:: // For N=10, we need higher max_rank to handle the larger system let options = LinsolveOptions::default() .with_nfullsweeps(10) - .with_krylov_tol(1e-10) + .with_gmres_tol(1e-10) .with_max_rank(100) .with_coefficients(a0, a1); diff --git a/crates/tensor4all-treetn/examples/test_linsolve_general_coefficients_n3.rs b/crates/tensor4all-treetn/examples/test_linsolve_general_coefficients_n3.rs index 97270b44..fcae606f 100644 --- a/crates/tensor4all-treetn/examples/test_linsolve_general_coefficients_n3.rs +++ b/crates/tensor4all-treetn/examples/test_linsolve_general_coefficients_n3.rs @@ -505,7 +505,7 @@ fn run_test_case( // Setup linsolve options and updater let options = LinsolveOptions::default() .with_nfullsweeps(10) - .with_krylov_tol(1e-10) + .with_gmres_tol(1e-10) .with_max_rank(50) .with_coefficients(a0, a1); diff
--git a/crates/tensor4all-treetn/examples/test_linsolve_identity_operator_issue.rs b/crates/tensor4all-treetn/examples/test_linsolve_identity_operator_issue.rs index 2408a458..e32cf545 100644 --- a/crates/tensor4all-treetn/examples/test_linsolve_identity_operator_issue.rs +++ b/crates/tensor4all-treetn/examples/test_linsolve_identity_operator_issue.rs @@ -600,15 +600,15 @@ fn run_test_case( // Adjust GMRES parameters for better convergence let options = LinsolveOptions::default() .with_nfullsweeps(10) - .with_krylov_tol(1e-8) // Slightly relaxed from 1e-10 - .with_krylov_maxiter(200) // Increased from default 100 - .with_krylov_dim(50) // Increased from default 30 + .with_gmres_tol(1e-8) // Slightly relaxed from 1e-10 + .with_gmres_max_restarts(200) // Increased from default 100 + .with_gmres_restart_dim(50) // Increased from default 30 .with_max_rank(max_rank) .with_coefficients(a0, a1) .with_convergence_tol(1e-6); // Early termination if residual < 1e-6 if verbose { - println!("Linsolve options: max_rank={max_rank}, nfullsweeps=10, krylov_tol=1e-8, krylov_maxiter=200, krylov_dim=50, convergence_tol=1e-6"); + println!("Linsolve options: max_rank={max_rank}, nfullsweeps=10, gmres_tol=1e-8, gmres_max_restarts=200, gmres_restart_dim=50, convergence_tol=1e-6"); } let mut updater = SquareLinsolveUpdater::with_index_mappings( diff --git a/crates/tensor4all-treetn/examples/test_linsolve_identity_residual_n3.rs b/crates/tensor4all-treetn/examples/test_linsolve_identity_residual_n3.rs index b44e8915..b8cf9162 100644 --- a/crates/tensor4all-treetn/examples/test_linsolve_identity_residual_n3.rs +++ b/crates/tensor4all-treetn/examples/test_linsolve_identity_residual_n3.rs @@ -289,7 +289,7 @@ fn run_test_case(init_mode: &str, bond_dim: usize) -> anyhow::Result<()> { // Setup linsolve options and updater let options = LinsolveOptions::default() .with_nfullsweeps(20) - .with_krylov_tol(1e-10) + .with_gmres_tol(1e-10) .with_max_rank(4) .with_coefficients(a0, a1); diff --git 
a/crates/tensor4all-treetn/examples/test_linsolve_identity_residual_n3_variants.rs b/crates/tensor4all-treetn/examples/test_linsolve_identity_residual_n3_variants.rs index 215d21e4..fe1a926b 100644 --- a/crates/tensor4all-treetn/examples/test_linsolve_identity_residual_n3_variants.rs +++ b/crates/tensor4all-treetn/examples/test_linsolve_identity_residual_n3_variants.rs @@ -349,7 +349,7 @@ fn run_test_case(a0: f64, a1: f64, init_mode: &str, bond_dim: usize) -> anyhow:: // Setup linsolve options and updater let options = LinsolveOptions::default() .with_nfullsweeps(10) - .with_krylov_tol(1e-10) + .with_gmres_tol(1e-10) .with_max_rank(50) .with_coefficients(a0, a1); diff --git a/crates/tensor4all-treetn/examples/test_linsolve_mpo_identity.rs b/crates/tensor4all-treetn/examples/test_linsolve_mpo_identity.rs index d36e9990..cf6371f9 100644 --- a/crates/tensor4all-treetn/examples/test_linsolve_mpo_identity.rs +++ b/crates/tensor4all-treetn/examples/test_linsolve_mpo_identity.rs @@ -16,7 +16,7 @@ use num_complex::Complex64; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; use tensor4all_core::{ - index::DynId, AnyScalar, DynIndex, IndexLike, TensorDynLen, TensorIndex, TensorLike, + index::DynId, AnyScalar, DynIndex, IndexLike, TensorContractionLike, TensorDynLen, TensorIndex, }; use tensor4all_treetn::{ apply_linear_operator, apply_local_update_sweep, ApplyOptions, CanonicalizationOptions, @@ -381,7 +381,7 @@ fn compute_residual( let mut b2 = 0.0_f64; for ((ax_i, x_i), b_i) in ax_vec.iter().zip(x_vec.iter()).zip(b_vec.iter()) { let opx_i = x_i * a0 + ax_i * a1; - let ri = opx_i - b_i; + let ri: Complex64 = opx_i - *b_i; r2 += ri.norm_sqr(); b2 += b_i.norm_sqr(); } @@ -431,9 +431,9 @@ fn main() -> anyhow::Result<()> { let options = LinsolveOptions::default() .with_nfullsweeps(5) .with_max_rank(4) - .with_krylov_tol(1e-8) - .with_krylov_maxiter(20) - .with_krylov_dim(30) + .with_gmres_tol(1e-8) + .with_gmres_max_restarts(20) + .with_gmres_restart_dim(30) 
.with_coefficients(1.0, 0.0); // a0=1, a1=0 => I * x = b let (operator_a, a_input_mapping, a_output_mapping) = @@ -560,9 +560,9 @@ fn main() -> anyhow::Result<()> { let options_case3 = LinsolveOptions::default() .with_nfullsweeps(5) .with_max_rank(4) - .with_krylov_tol(1e-8) - .with_krylov_maxiter(20) - .with_krylov_dim(30) + .with_gmres_tol(1e-8) + .with_gmres_max_restarts(20) + .with_gmres_restart_dim(30) .with_coefficients(0.0, 1.0); // a0=0, a1=1 => (i*I)*x = b for init_mode in ["rhs", "random"] { diff --git a/crates/tensor4all-treetn/examples/test_linsolve_mpo_pauli_imaginary.rs b/crates/tensor4all-treetn/examples/test_linsolve_mpo_pauli_imaginary.rs index 086a1400..d2933338 100644 --- a/crates/tensor4all-treetn/examples/test_linsolve_mpo_pauli_imaginary.rs +++ b/crates/tensor4all-treetn/examples/test_linsolve_mpo_pauli_imaginary.rs @@ -17,7 +17,9 @@ use num_complex::Complex64; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; use tensor4all_core::index::DynId; -use tensor4all_core::{AnyScalar, DynIndex, IndexLike, TensorDynLen, TensorIndex, TensorLike}; +use tensor4all_core::{ + AnyScalar, DynIndex, IndexLike, TensorContractionLike, TensorDynLen, TensorIndex, +}; use tensor4all_treetn::{ apply_linear_operator, apply_local_update_sweep, ApplyOptions, CanonicalizationOptions, IndexMapping, LinearOperator, LinsolveOptions, LocalUpdateSweepPlan, SquareLinsolveUpdater, @@ -828,9 +830,9 @@ fn main() -> anyhow::Result<()> { let options = LinsolveOptions::default() .with_nfullsweeps(10) .with_max_rank(50) - .with_krylov_tol(1e-10) - .with_krylov_maxiter(30) - .with_krylov_dim(30) + .with_gmres_tol(1e-10) + .with_gmres_max_restarts(30) + .with_gmres_restart_dim(30) .with_coefficients(0.0, 1.0); // a0=0, a1=1 => (i*X) * x = b let center = make_node_name(n / 2); diff --git a/crates/tensor4all-treetn/examples/test_linsolve_mpo_pauli_operator.rs b/crates/tensor4all-treetn/examples/test_linsolve_mpo_pauli_operator.rs index a550f51a..cff5e6b1 100644 --- 
a/crates/tensor4all-treetn/examples/test_linsolve_mpo_pauli_operator.rs +++ b/crates/tensor4all-treetn/examples/test_linsolve_mpo_pauli_operator.rs @@ -13,7 +13,7 @@ use num_complex::Complex64; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; use tensor4all_core::{ - index::DynId, AnyScalar, DynIndex, IndexLike, TensorDynLen, TensorIndex, TensorLike, + index::DynId, AnyScalar, DynIndex, IndexLike, TensorContractionLike, TensorDynLen, TensorIndex, }; use tensor4all_treetn::{ apply_linear_operator, apply_local_update_sweep, ApplyOptions, CanonicalizationOptions, @@ -602,9 +602,9 @@ fn main() -> anyhow::Result<()> { let options_1 = LinsolveOptions::default() .with_nfullsweeps(5) .with_max_rank(4) - .with_krylov_tol(1e-8) - .with_krylov_maxiter(20) - .with_krylov_dim(30) + .with_gmres_tol(1e-8) + .with_gmres_max_restarts(20) + .with_gmres_restart_dim(30) .with_coefficients(0.0, 1.0); // a0=0, a1=1 => A * x = b println!("--- Solving A*x = b with init=rhs (i.e., init=b) ---"); @@ -703,9 +703,9 @@ fn main() -> anyhow::Result<()> { let options_2 = LinsolveOptions::default() .with_nfullsweeps(5) .with_max_rank(4) - .with_krylov_tol(1e-8) - .with_krylov_maxiter(20) - .with_krylov_dim(30) + .with_gmres_tol(1e-8) + .with_gmres_max_restarts(20) + .with_gmres_restart_dim(30) .with_coefficients(2.0, 1.0); // a0=2, a1=1 => (2I + A) * x = b println!("--- Solving (2I + A)*x = b with init=rhs (i.e., init=b) ---"); diff --git a/crates/tensor4all-treetn/examples/test_linsolve_mps_identity_n2.rs b/crates/tensor4all-treetn/examples/test_linsolve_mps_identity_n2.rs index dff3882a..df3dde70 100644 --- a/crates/tensor4all-treetn/examples/test_linsolve_mps_identity_n2.rs +++ b/crates/tensor4all-treetn/examples/test_linsolve_mps_identity_n2.rs @@ -9,7 +9,7 @@ use num_complex::Complex64; use rand::Rng; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; -use tensor4all_core::{index::DynId, DynIndex, IndexLike, TensorDynLen}; +use tensor4all_core::{index::DynId, DynIndex, IndexLike, 
TensorContractionLike, TensorDynLen}; use tensor4all_treetn::{ apply_linear_operator, apply_local_update_sweep, ApplyOptions, CanonicalizationOptions, IndexMapping, LinearOperator, LinsolveOptions, LocalUpdateSweepPlan, SquareLinsolveUpdater, @@ -302,7 +302,7 @@ fn run_case(phys_dim: usize) -> anyhow::Result<()> { let options = LinsolveOptions::default() .with_nfullsweeps(5) .with_max_rank(bond_dim) - .with_krylov_tol(1e-8); + .with_gmres_tol(1e-8); let mut x = x_true.clone(); diff --git a/crates/tensor4all-treetn/examples/test_linsolve_mps_pauli_imaginary.rs b/crates/tensor4all-treetn/examples/test_linsolve_mps_pauli_imaginary.rs index 9ec6febf..728a69c3 100644 --- a/crates/tensor4all-treetn/examples/test_linsolve_mps_pauli_imaginary.rs +++ b/crates/tensor4all-treetn/examples/test_linsolve_mps_pauli_imaginary.rs @@ -16,7 +16,7 @@ use rand::SeedableRng; use rand_chacha::ChaCha8Rng; use tensor4all_core::{ - index::DynId, AnyScalar, DynIndex, IndexLike, TensorDynLen, TensorIndex, TensorLike, + index::DynId, AnyScalar, DynIndex, IndexLike, TensorContractionLike, TensorDynLen, TensorIndex, }; use tensor4all_treetn::{ apply_linear_operator, apply_local_update_sweep, ApplyOptions, CanonicalizationOptions, @@ -407,9 +407,9 @@ fn main() -> anyhow::Result<()> { let options = LinsolveOptions::default() .with_nfullsweeps(10) .with_max_rank(bond_dim) - .with_krylov_tol(1e-10) - .with_krylov_maxiter(30) - .with_krylov_dim(30) + .with_gmres_tol(1e-10) + .with_gmres_max_restarts(30) + .with_gmres_restart_dim(30) .with_coefficients(0.0, 1.0); let n_sweeps = 10usize; diff --git a/crates/tensor4all-treetn/src/lib.rs b/crates/tensor4all-treetn/src/lib.rs index d37fad0e..a4f710ec 100644 --- a/crates/tensor4all-treetn/src/lib.rs +++ b/crates/tensor4all-treetn/src/lib.rs @@ -54,8 +54,10 @@ pub use operator::{ pub use options::{CanonicalizationOptions, RestructureOptions, SplitOptions, TruncationOptions}; pub use random::{random_treetn, LinkSpace}; pub use simplett_bridge::{ + 
fix_and_remove_site_from_treetn_chain, insert_onehot_site_in_treetn_chain, tensor_train_to_treetn, tensor_train_to_treetn_with_names, tensor_train_to_treetn_with_names_and_site_indices, treetn_to_tensor_train, + weighted_remove_site_from_treetn_chain, }; pub use site_index_network::SiteIndexNetwork; pub use treetn::{ @@ -67,6 +69,7 @@ pub use treetn::{ get_boundary_edges, hadamard, partial_contract, + partial_contract_to_site_network, sum_over_indices, weighted_sum_over_index_pairs, BoundaryEdge, @@ -91,9 +94,9 @@ pub use treetn::{ // Re-export linsolve types from new location pub use linsolve::{ - square_linsolve, EnvironmentCache, LinsolveOptions, LinsolveVerifyReport, NetworkTopology, - NodeVerifyDetail, ProjectedOperator, ProjectedState, SquareLinsolveResult, - SquareLinsolveUpdater, + relative_linear_system_residual, square_linsolve, EnvironmentCache, LinsolveOptions, + LinsolveVerifyReport, NetworkTopology, NodeVerifyDetail, ProjectedOperator, ProjectedState, + SquareLinsolveResult, SquareLinsolveUpdater, }; use petgraph::graph::NodeIndex; diff --git a/crates/tensor4all-treetn/src/linsolve/common/mod.rs b/crates/tensor4all-treetn/src/linsolve/common/mod.rs index 95d7ad47..e7765149 100644 --- a/crates/tensor4all-treetn/src/linsolve/common/mod.rs +++ b/crates/tensor4all-treetn/src/linsolve/common/mod.rs @@ -8,5 +8,5 @@ mod options; mod projected_operator; pub use environment::{EnvironmentCache, NetworkTopology}; -pub use options::LinsolveOptions; +pub use options::{GmresToleranceMode, LinsolveOptions}; pub use projected_operator::ProjectedOperator; diff --git a/crates/tensor4all-treetn/src/linsolve/common/options.rs b/crates/tensor4all-treetn/src/linsolve/common/options.rs index 9093e3d6..e1d3e899 100644 --- a/crates/tensor4all-treetn/src/linsolve/common/options.rs +++ b/crates/tensor4all-treetn/src/linsolve/common/options.rs @@ -1,7 +1,16 @@ //! Common options for linsolve algorithms. 
use crate::TruncationOptions; -use tensor4all_core::SvdTruncationPolicy; +use tensor4all_core::{AnyScalar, SvdTruncationPolicy}; + +/// Residual tolerance convention for local GMRES solves. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum GmresToleranceMode { + /// Stop when `||b - A*x|| / ||b|| < gmres_tol`. + Relative, + /// Stop when `||b - A*x|| < gmres_tol`. + Absolute, +} /// Options for the linsolve algorithm. #[derive(Debug, Clone)] @@ -13,18 +22,28 @@ pub struct LinsolveOptions { /// Truncation options for factorization. pub truncation: TruncationOptions, /// Tolerance for GMRES convergence. - pub krylov_tol: f64, - /// Maximum GMRES iterations per local solve. - pub krylov_maxiter: usize, - /// Krylov subspace dimension (restart parameter). - pub krylov_dim: usize, + pub gmres_tol: f64, + /// Residual tolerance convention for GMRES convergence. + pub gmres_tolerance_mode: GmresToleranceMode, + /// Maximum number of GMRES restart cycles per local solve. + /// + /// This matches KrylovKit's `maxiter` convention. The maximum number of + /// operator expansion steps is roughly `gmres_max_restarts * gmres_restart_dim`. + pub gmres_max_restarts: usize, + /// GMRES restart cycle length. + pub gmres_restart_dim: usize, /// Coefficient a₀ in (a₀ + a₁ * A) * x = b. - pub a0: f64, + pub a0: AnyScalar, /// Coefficient a₁ in (a₀ + a₁ * A) * x = b. - pub a1: f64, + pub a1: AnyScalar, /// Convergence tolerance for early termination. /// If Some(tol), stop when relative residual < tol. pub convergence_tol: Option, + /// Whether to compute and return a final true residual after the sweep. + /// + /// Disabling this skips an extra operator application after the requested sweeps. A residual + /// is still computed when `convergence_tol` is set because early stopping depends on it. 
+ pub check_residual: bool, } impl Default for LinsolveOptions { @@ -32,12 +51,14 @@ impl Default for LinsolveOptions { Self { nfullsweeps: 5, truncation: TruncationOptions::default(), - krylov_tol: 1e-10, - krylov_maxiter: 100, - krylov_dim: 30, - a0: 0.0, - a1: 1.0, + gmres_tol: 1e-10, + gmres_tolerance_mode: GmresToleranceMode::Relative, + gmres_max_restarts: 100, + gmres_restart_dim: 30, + a0: AnyScalar::new_real(0.0), + a1: AnyScalar::new_real(1.0), convergence_tol: None, + check_residual: true, } } } @@ -76,27 +97,51 @@ impl LinsolveOptions { } /// Set GMRES tolerance. - pub fn with_krylov_tol(mut self, tol: f64) -> Self { - self.krylov_tol = tol; + pub fn with_gmres_tol(mut self, tol: f64) -> Self { + self.gmres_tol = tol; + self + } + + /// Set the residual tolerance convention for GMRES. + pub fn with_gmres_tolerance_mode(mut self, mode: GmresToleranceMode) -> Self { + self.gmres_tolerance_mode = mode; + self + } + + /// Use relative residual convergence for GMRES. + pub fn with_gmres_relative_tolerance(mut self) -> Self { + self.gmres_tolerance_mode = GmresToleranceMode::Relative; self } - /// Set maximum GMRES iterations. - pub fn with_krylov_maxiter(mut self, maxiter: usize) -> Self { - self.krylov_maxiter = maxiter; + /// Use absolute residual convergence for GMRES. + pub fn with_gmres_absolute_tolerance(mut self) -> Self { + self.gmres_tolerance_mode = GmresToleranceMode::Absolute; + self + } + + /// Set maximum number of GMRES restart cycles. + /// + /// This follows KrylovKit's `maxiter` convention. + pub fn with_gmres_max_restarts(mut self, max_restarts: usize) -> Self { + self.gmres_max_restarts = max_restarts; self } - /// Set Krylov subspace dimension. - pub fn with_krylov_dim(mut self, dim: usize) -> Self { - self.krylov_dim = dim; + /// Set GMRES restart cycle length. + pub fn with_gmres_restart_dim(mut self, dim: usize) -> Self { + self.gmres_restart_dim = dim; self } /// Set coefficients a₀ and a₁. 
- pub fn with_coefficients(mut self, a0: f64, a1: f64) -> Self { - self.a0 = a0; - self.a1 = a1; + pub fn with_coefficients(mut self, a0: A0, a1: A1) -> Self + where + A0: Into, + A1: Into, + { + self.a0 = a0.into(); + self.a1 = a1.into(); self } @@ -105,6 +150,16 @@ impl LinsolveOptions { self.convergence_tol = Some(tol); self } + + /// Set whether `square_linsolve` computes a final true residual. + /// + /// Use `false` when the caller only needs the swept solution and wants to avoid the extra + /// post-solve operator application. The residual is still evaluated if `convergence_tol` is + /// enabled, because it is required for early stopping. + pub fn with_residual_check(mut self, check_residual: bool) -> Self { + self.check_residual = check_residual; + self + } } #[cfg(test)] diff --git a/crates/tensor4all-treetn/src/linsolve/common/options/tests/mod.rs b/crates/tensor4all-treetn/src/linsolve/common/options/tests/mod.rs index 9c097636..65164cfe 100644 --- a/crates/tensor4all-treetn/src/linsolve/common/options/tests/mod.rs +++ b/crates/tensor4all-treetn/src/linsolve/common/options/tests/mod.rs @@ -1,13 +1,16 @@ use super::*; +use num_complex::Complex64; use tensor4all_core::SvdTruncationPolicy; #[test] fn test_default_options() { let opts = LinsolveOptions::default(); assert_eq!(opts.nfullsweeps, 5); - assert_eq!(opts.a0, 0.0); - assert_eq!(opts.a1, 1.0); + assert_eq!(opts.a0, AnyScalar::new_real(0.0)); + assert_eq!(opts.a1, AnyScalar::new_real(1.0)); + assert_eq!(opts.gmres_tolerance_mode, GmresToleranceMode::Relative); assert!(opts.convergence_tol.is_none()); + assert!(opts.check_residual); } #[test] @@ -18,15 +21,30 @@ fn test_builder_pattern() { let opts = LinsolveOptions::new(5) .with_max_rank(100) .with_svd_policy(policy) - .with_krylov_tol(1e-8) + .with_gmres_tol(1e-8) + .with_gmres_absolute_tolerance() .with_coefficients(1.0, -1.0) - .with_convergence_tol(1e-6); + .with_convergence_tol(1e-6) + .with_residual_check(false); assert_eq!(opts.nfullsweeps, 5); 
assert_eq!(opts.truncation.max_rank(), Some(100)); assert_eq!(opts.truncation.svd_policy(), Some(policy)); - assert_eq!(opts.krylov_tol, 1e-8); - assert_eq!(opts.a0, 1.0); - assert_eq!(opts.a1, -1.0); + assert_eq!(opts.gmres_tol, 1e-8); + assert_eq!(opts.gmres_tolerance_mode, GmresToleranceMode::Absolute); + assert_eq!(opts.a0, AnyScalar::new_real(1.0)); + assert_eq!(opts.a1, AnyScalar::new_real(-1.0)); assert_eq!(opts.convergence_tol, Some(1e-6)); + assert!(!opts.check_residual); +} + +#[test] +fn test_complex_coefficients() { + let opts = LinsolveOptions::default().with_coefficients( + Complex64::new(0.5, -0.25), + AnyScalar::new_complex(-1.0, 2.0), + ); + + assert_eq!(opts.a0, AnyScalar::new_complex(0.5, -0.25)); + assert_eq!(opts.a1, AnyScalar::new_complex(-1.0, 2.0)); } diff --git a/crates/tensor4all-treetn/src/linsolve/common/projected_operator.rs b/crates/tensor4all-treetn/src/linsolve/common/projected_operator.rs index d33c224c..9f1c3caa 100644 --- a/crates/tensor4all-treetn/src/linsolve/common/projected_operator.rs +++ b/crates/tensor4all-treetn/src/linsolve/common/projected_operator.rs @@ -13,7 +13,7 @@ use std::hash::Hash; use anyhow::Result; -use tensor4all_core::{AllowedPairs, IndexLike, TensorLike}; +use tensor4all_core::{IndexLike, TensorLike}; use super::environment::{EnvironmentCache, NetworkTopology}; use crate::operator::IndexMapping; @@ -56,6 +56,29 @@ where pub output_mapping: Option>>, } +struct LocalIndexMapping { + true_in: I, + internal_in: I, + temp_in: I, + true_out: I, + internal_out: I, + temp_out: I, +} + +enum ContractOperand<'a, T> { + Borrowed(&'a T), + Owned(T), +} + +impl<'a, T> ContractOperand<'a, T> { + fn as_ref(&'a self) -> &'a T { + match self { + Self::Borrowed(tensor) => tensor, + Self::Owned(tensor) => tensor, + } + } +} + impl ProjectedOperator where T: TensorLike, @@ -122,7 +145,7 @@ where // Ensure environments are computed self.ensure_environments(region, ket_state, bra_state, topology)?; - let mut all_tensors: Vec = 
Vec::new(); + let mut all_tensors = Vec::new(); let mut temp_out_to_true: Vec<(T::Index, T::Index)> = Vec::new(); if let (Some(ref input_mapping), Some(ref output_mapping)) = @@ -131,52 +154,78 @@ where // MPO-with-mappings path: use unique temp indices to avoid duplicate IDs. // Replace true_index -> temp_in on v (never use internal_index on v). // Use same temp_in/temp_out on op tensors so they contract with v. - let mut per_node: Vec<(T::Index, T::Index, T::Index)> = Vec::new(); + let mut per_node: Vec>> = Vec::new(); for node in region { - let im = input_mapping - .get(node) - .ok_or_else(|| anyhow::anyhow!("Missing input_mapping for node {:?}", node))?; - let om = output_mapping - .get(node) - .ok_or_else(|| anyhow::anyhow!("Missing output_mapping for node {:?}", node))?; - let temp_in = im.internal_index.sim(); - let temp_out = om.internal_index.sim(); - per_node.push((temp_in, temp_out, om.true_index.clone())); + match (input_mapping.get(node), output_mapping.get(node)) { + (Some(im), Some(om)) => { + let temp_in = im.internal_index.sim(); + let temp_out = om.internal_index.sim(); + per_node.push(Some(LocalIndexMapping { + true_in: im.true_index.clone(), + internal_in: im.internal_index.clone(), + temp_in, + true_out: om.true_index.clone(), + internal_out: om.internal_index.clone(), + temp_out, + })); + } + (None, None) => { + if self + .operator + .site_space(node) + .is_some_and(|site_space| !site_space.is_empty()) + { + return Err(anyhow::anyhow!( + "Missing index mappings for operator node {:?} with non-empty site space", + node + )); + } + per_node.push(None); + } + (None, Some(_)) => { + return Err(anyhow::anyhow!("Missing input_mapping for node {:?}", node)); + } + (Some(_), None) => { + return Err(anyhow::anyhow!( + "Missing output_mapping for node {:?}", + node + )); + } + } } - let mut transformed_v = v.clone(); - for (node, (temp_in, _temp_out, _)) in region.iter().zip(per_node.iter()) { - let im = input_mapping - .get(node) - .ok_or_else(|| 
anyhow::anyhow!("Missing input_mapping for node {:?}", node))?; - transformed_v = transformed_v.replaceind(&im.true_index, temp_in)?; + if per_node.iter().any(Option::is_some) { + let mut transformed_v = v.clone(); + for mapping in per_node.iter().flatten() { + transformed_v = transformed_v.replaceind(&mapping.true_in, &mapping.temp_in)?; + } + all_tensors.push(ContractOperand::Owned(transformed_v)); + } else { + all_tensors.push(ContractOperand::Borrowed(v)); } - all_tensors.push(transformed_v); - - for (node, (temp_in, temp_out, true_idx)) in region.iter().zip(per_node.iter()) { - let im = input_mapping - .get(node) - .ok_or_else(|| anyhow::anyhow!("Missing input_mapping for node {:?}", node))?; - let om = output_mapping - .get(node) - .ok_or_else(|| anyhow::anyhow!("Missing output_mapping for node {:?}", node))?; + + for (node, mapping) in region.iter().zip(per_node.iter()) { let node_idx = self .operator .node_index(node) .ok_or_else(|| anyhow::anyhow!("Node {:?} not found in operator", node))?; - let mut t = self + let tensor = self .operator .tensor(node_idx) - .ok_or_else(|| anyhow::anyhow!("Tensor not found in operator"))? 
- .clone(); - t = t.replaceind(&im.internal_index, temp_in)?; - t = t.replaceind(&om.internal_index, temp_out)?; - all_tensors.push(t); - temp_out_to_true.push((temp_out.clone(), true_idx.clone())); + .ok_or_else(|| anyhow::anyhow!("Tensor not found in operator"))?; + if let Some(mapping) = mapping { + let mut t = tensor.clone(); + t = t.replaceind(&mapping.internal_in, &mapping.temp_in)?; + t = t.replaceind(&mapping.internal_out, &mapping.temp_out)?; + temp_out_to_true.push((mapping.temp_out.clone(), mapping.true_out.clone())); + all_tensors.push(ContractOperand::Owned(t)); + } else { + all_tensors.push(ContractOperand::Borrowed(tensor)); + } } } else { // No mappings: plain path - all_tensors.push(v.clone()); + all_tensors.push(ContractOperand::Borrowed(v)); for node in region { let node_idx = self .operator @@ -185,9 +234,8 @@ where let tensor = self .operator .tensor(node_idx) - .ok_or_else(|| anyhow::anyhow!("Tensor not found in operator"))? - .clone(); - all_tensors.push(tensor); + .ok_or_else(|| anyhow::anyhow!("Tensor not found in operator"))?; + all_tensors.push(ContractOperand::Borrowed(tensor)); } } @@ -198,13 +246,13 @@ where continue; } if let Some(env) = self.envs.get(&neighbor, node) { - all_tensors.push(env.clone()); + all_tensors.push(ContractOperand::Borrowed(env)); } } } - let tensor_refs: Vec<&T> = all_tensors.iter().collect(); - let mut contracted = T::contract(&tensor_refs, AllowedPairs::All)?; + let tensor_refs: Vec<&T> = all_tensors.iter().map(ContractOperand::as_ref).collect(); + let mut contracted = T::contract(&tensor_refs)?; // Replace temp_out -> true_index for (temp_out, true_idx) in &temp_out_to_true { @@ -301,9 +349,9 @@ where } // Collect child environments - let child_envs: Vec = child_neighbors + let child_envs: Vec<&T> = child_neighbors .iter() - .filter_map(|child| self.envs.get(child, from).cloned()) + .filter_map(|child| self.envs.get(child, from)) .collect(); // Get tensors from bra (V_out), operator, and ket (V_in) at this 
node @@ -339,33 +387,65 @@ where let bra_conj = tensor_bra.conj(); + let input_unmapped = self + .input_mapping + .as_ref() + .is_none_or(|mapping| !mapping.contains_key(from)); + let output_unmapped = self + .output_mapping + .as_ref() + .is_none_or(|mapping| !mapping.contains_key(from)); + let operator_site_is_empty = self + .operator + .site_space(from) + .is_some_and(|site_space| site_space.is_empty()); + let no_site_spectator = input_unmapped && output_unmapped && operator_site_is_empty; + + let mut tensor_refs = Vec::with_capacity(3 + child_envs.len()); + // Transform ket tensor for contraction with operator - let transformed_ket = if let Some(ref input_mapping) = self.input_mapping { + if let Some(ref input_mapping) = self.input_mapping { if let Some(mapping) = input_mapping.get(from) { - tensor_ket.replaceind(&mapping.true_index, &mapping.internal_index)? + tensor_refs.push(ContractOperand::Owned( + tensor_ket.replaceind(&mapping.true_index, &mapping.internal_index)?, + )); } else { - tensor_ket.clone() + tensor_refs.push(ContractOperand::Borrowed(tensor_ket)); } } else { - tensor_ket.clone() - }; + tensor_refs.push(ContractOperand::Borrowed(tensor_ket)); + } + + if !no_site_spectator { + tensor_refs.push(ContractOperand::Borrowed(tensor_op)); + } // Transform bra_conj tensor for contraction with operator - let transformed_bra_conj = if let Some(ref output_mapping) = self.output_mapping { + if let Some(ref output_mapping) = self.output_mapping { if let Some(mapping) = output_mapping.get(from) { - bra_conj.replaceind(&mapping.true_index, &mapping.internal_index)? 
+ tensor_refs.push(ContractOperand::Owned( + bra_conj.replaceind(&mapping.true_index, &mapping.internal_index)?, + )); } else { - bra_conj.clone() + tensor_refs.push(ContractOperand::Owned(bra_conj)); } } else { - bra_conj.clone() - }; + tensor_refs.push(ContractOperand::Owned(bra_conj)); + } // Contract ket, op, bra, and child environments together // Let contract() find the optimal contraction order - let mut tensor_refs: Vec<&T> = vec![&transformed_ket, tensor_op, &transformed_bra_conj]; - tensor_refs.extend(child_envs.iter()); - T::contract(&tensor_refs, AllowedPairs::All) + tensor_refs.extend(child_envs.into_iter().map(ContractOperand::Borrowed)); + let tensor_refs = tensor_refs + .iter() + .map(ContractOperand::as_ref) + .collect::>(); + let contracted = T::contract(&tensor_refs)?; + if no_site_spectator { + contracted.contract_pair(tensor_op) + } else { + Ok(contracted) + } } /// Compute the local dimension (size of the local Hilbert space). diff --git a/crates/tensor4all-treetn/src/linsolve/mod.rs b/crates/tensor4all-treetn/src/linsolve/mod.rs index 6b1b5960..b977211e 100644 --- a/crates/tensor4all-treetn/src/linsolve/mod.rs +++ b/crates/tensor4all-treetn/src/linsolve/mod.rs @@ -25,8 +25,10 @@ pub mod common; pub mod square; // Re-export commonly used types -pub use common::{EnvironmentCache, LinsolveOptions, NetworkTopology, ProjectedOperator}; +pub use common::{ + EnvironmentCache, GmresToleranceMode, LinsolveOptions, NetworkTopology, ProjectedOperator, +}; pub use square::{ - square_linsolve, LinsolveVerifyReport, NodeVerifyDetail, ProjectedState, SquareLinsolveResult, - SquareLinsolveUpdater, + relative_linear_system_residual, square_linsolve, LinsolveVerifyReport, NodeVerifyDetail, + ProjectedState, SquareLinsolveResult, SquareLinsolveUpdater, }; diff --git a/crates/tensor4all-treetn/src/linsolve/square/local_linop.rs b/crates/tensor4all-treetn/src/linsolve/square/local_linop.rs index 048f0b67..caebc36d 100644 --- 
a/crates/tensor4all-treetn/src/linsolve/square/local_linop.rs +++ b/crates/tensor4all-treetn/src/linsolve/square/local_linop.rs @@ -7,7 +7,6 @@ use std::hash::Hash; use std::sync::{Arc, RwLock}; use anyhow::{Context, Result}; -use tensor4all_core::any_scalar::AnyScalar; use tensor4all_core::{IndexLike, TensorLike}; use crate::linsolve::common::ProjectedOperator; @@ -15,12 +14,13 @@ use crate::treetn::TreeTN; /// LocalLinOp: Wraps the projected operator for local GMRES solving. /// -/// This applies the local linear operator: `y = a₀ * x + a₁ * H * x` -/// where H is the projected operator. +/// This applies only the projected local operator `H * x`. +/// Affine coefficients such as `a₀ I + a₁ H` belong to the Krylov solver +/// so the Arnoldi basis is built from the same unshifted operator as KrylovKit. /// /// This is the V_in = V_out specialized version that maintains a separate /// reference state for stable environment computation. -pub struct LocalLinOp +pub struct LocalLinOp<'a, T, V> where T: TensorLike + 'static, ::Id: @@ -32,17 +32,13 @@ where /// The region being updated pub region: Vec, /// Current state for ket in environment computation - pub state: TreeTN, + pub state: &'a TreeTN, /// Reference state for bra in environment computation /// Uses separate bond indices to prevent unintended contractions - pub reference_state: TreeTN, - /// Coefficient a₀ (can be real or complex) - pub a0: AnyScalar, - /// Coefficient a₁ (can be real or complex) - pub a1: AnyScalar, + pub reference_state: &'a TreeTN, } -impl LocalLinOp +impl<'a, T, V> LocalLinOp<'a, T, V> where T: TensorLike + 'static, T::Index: IndexLike, @@ -57,27 +53,81 @@ where pub fn new( projected_operator: Arc>>, region: Vec, - state: TreeTN, - reference_state: TreeTN, - a0: AnyScalar, - a1: AnyScalar, + state: &'a TreeTN, + reference_state: &'a TreeTN, ) -> Self { Self { projected_operator, region, state, reference_state, - a0, - a1, } } - /// Apply the local linear operator: `y = a₀ * x + a₁ * H * 
x` - /// - /// This is used by `tensor4all_core::krylov::gmres` to solve the local problem. - pub fn apply(&self, x: &T) -> Result { - // Apply operator: H * x - // ProjectedOperator handles environment computation and index mappings + fn local_input_indices(&self) -> Result> { + let Some((first_node, rest_nodes)) = self.region.split_first() else { + return Err(anyhow::anyhow!( + "LocalLinOp::apply_projected: region must not be empty" + )); + }; + let first_idx = self.state.node_index(first_node).ok_or_else(|| { + anyhow::anyhow!( + "LocalLinOp::apply_projected: node {:?} not found in state", + first_node + ) + })?; + let mut local = self + .state + .tensor(first_idx) + .ok_or_else(|| { + anyhow::anyhow!( + "LocalLinOp::apply_projected: tensor for node {:?} not found in state", + first_node + ) + })? + .clone(); + + for node in rest_nodes { + let node_idx = self.state.node_index(node).ok_or_else(|| { + anyhow::anyhow!( + "LocalLinOp::apply_projected: node {:?} not found in state", + node + ) + })?; + let tensor = self.state.tensor(node_idx).ok_or_else(|| { + anyhow::anyhow!( + "LocalLinOp::apply_projected: tensor for node {:?} not found in state", + node + ) + })?; + local = local.contract_pair(tensor)?; + } + + Ok(local.external_indices()) + } + + fn same_index_set(left: &[T::Index], right: &[T::Index]) -> bool { + let left_keys: std::collections::HashSet<_> = + left.iter().map(|idx| (idx.clone(), idx.dim())).collect(); + let right_keys: std::collections::HashSet<_> = + right.iter().map(|idx| (idx.clone(), idx.dim())).collect(); + left.len() == right.len() && left_keys == right_keys + } + + /// Apply the projected local operator: `y = H * x`. 
+ pub fn apply_projected(&self, x: &T) -> Result { + let x_indices = x.external_indices(); + let expected_input_indices = self.local_input_indices()?; + if !Self::same_index_set(&x_indices, &expected_input_indices) { + return Err(anyhow::anyhow!( + "LocalLinOp::apply_projected: index structure mismatch between input (x) and the local state vector space:\n x has {} indices: {:?}\n expected {} indices: {:?}", + x_indices.len(), + x_indices.iter().map(|i| format!("{:?}:{}", i.id(), i.dim())).collect::>(), + expected_input_indices.len(), + expected_input_indices.iter().map(|i| format!("{:?}:{}", i.id(), i.dim())).collect::>(), + )); + } + let mut proj_op = self .projected_operator .write() @@ -86,75 +136,31 @@ where }) .context("LocalLinOp::apply: lock poisoned")?; - let mut hx = proj_op.apply( + let hx = proj_op.apply( x, &self.region, - &self.state, - &self.reference_state, + self.state, + self.reference_state, self.state.site_index_network(), )?; - // Map output tensor's boundary bond indices back to ket space - // The projected operator application produces output with bra-side boundary bonds - for node in &self.region { - for neighbor in self.state.site_index_network().neighbors(node) { - if !self.region.contains(&neighbor) { - let ket_edge = match self.state.edge_between(node, &neighbor) { - Some(e) => e, - None => continue, - }; - let bra_edge = match self.reference_state.edge_between(node, &neighbor) { - Some(e) => e, - None => continue, - }; - let ket_bond = match self.state.bond_index(ket_edge) { - Some(b) => b, - None => continue, - }; - let bra_bond = match self.reference_state.bond_index(bra_edge) { - Some(b) => b, - None => continue, - }; - - // Only replace if hx actually contains the exact bra bond. 
- if hx.external_indices().iter().any(|idx| idx == bra_bond) { - hx = hx.replaceind(bra_bond, ket_bond)?; - } - } - } - } - - // When a0 = 0, just return a1 * H * x (avoids axpby which requires same indices) - if self.a0.is_zero() { - return hx.scale(self.a1.clone()); - } - - // Align hx indices to match x's index order for axpby. - // Check that hx and x have the same full index structure. - let x_indices = x.external_indices(); let hx_indices = hx.external_indices(); - let x_index_keys: std::collections::HashSet<_> = - x_indices.iter().map(|i| (i.clone(), i.dim())).collect(); - let hx_index_keys: std::collections::HashSet<_> = - hx_indices.iter().map(|i| (i.clone(), i.dim())).collect(); - - let hx_aligned = if x_index_keys == hx_index_keys && x_indices.len() == hx_indices.len() { - // Same index set and count - permute to match order - hx.permuteinds(&x_indices)? - } else { + let indices_match = x_indices.len() == hx_indices.len() + && x_indices + .iter() + .zip(hx_indices.iter()) + .all(|(x_idx, hx_idx)| x_idx == hx_idx && x_idx.dim() == hx_idx.dim()); + if !indices_match { return Err(anyhow::anyhow!( - "LocalLinOp::apply: index structure mismatch between operator output (hx) and input (x):\n x has {} indices: {:?}\n hx has {} indices: {:?}\n x index keys: {:?}\n hx index keys: {:?}\n\nThis suggests the projected operator application produced output with different index structure than expected.", + "LocalLinOp::apply_projected: index structure mismatch between operator output (hx) and input (x):\n x has {} indices: {:?}\n hx has {} indices: {:?}\n\nProjectedOperator::apply is expected to return output in the same local vector space and index order as its input.", x_indices.len(), x_indices.iter().map(|i| format!("{:?}:{}", i.id(), i.dim())).collect::>(), hx_indices.len(), hx_indices.iter().map(|i| format!("{:?}:{}", i.id(), i.dim())).collect::>(), - x_index_keys.iter().map(|(i, dim)| format!("{i:?}:{dim}")).collect::>(), - hx_index_keys.iter().map(|(i, dim)| 
format!("{i:?}:{dim}")).collect::>(), )); - }; + } - // Compute y = a₀ * x + a₁ * H * x - x.axpby(self.a0.clone(), &hx_aligned, self.a1.clone()) + Ok(hx) } } diff --git a/crates/tensor4all-treetn/src/linsolve/square/local_linop/tests/mod.rs b/crates/tensor4all-treetn/src/linsolve/square/local_linop/tests/mod.rs index 4f046d6c..61885af9 100644 --- a/crates/tensor4all-treetn/src/linsolve/square/local_linop/tests/mod.rs +++ b/crates/tensor4all-treetn/src/linsolve/square/local_linop/tests/mod.rs @@ -1,7 +1,7 @@ use std::collections::{HashMap, HashSet}; use tensor4all_core::index::DynId; -use tensor4all_core::{AnyScalar, DynIndex, IndexLike, TensorDynLen, TensorIndex}; +use tensor4all_core::{DynIndex, IndexLike, TensorDynLen, TensorIndex}; use crate::operator::IndexMapping; use crate::treetn::TreeTN; @@ -32,20 +32,16 @@ fn test_local_linop_new() { let linop = LocalLinOp::new( projected_op, vec!["site0".to_string()], - state, - reference_state, - AnyScalar::new_real(1.0), - AnyScalar::new_real(0.0), + &state, + &reference_state, ); assert_eq!(linop.region.len(), 1); - assert_eq!(linop.a0, AnyScalar::new_real(1.0)); - assert_eq!(linop.a1, AnyScalar::new_real(0.0)); } -/// Apply with a0=0 hits the early return path (scale only, no index alignment). +/// Projected apply requires output to stay in the input local vector space. 
#[test] -fn test_local_linop_apply_a0_zero() { +fn test_local_linop_apply_projected_rejects_mismatch() { use crate::linsolve::common::ProjectedOperator; let mut state = TreeTN::::new(); @@ -59,10 +55,8 @@ fn test_local_linop_apply_a0_zero() { let linop = LocalLinOp::new( projected_op, vec!["site0".to_string()], - state.clone(), - reference_state, - AnyScalar::new_real(0.0), - AnyScalar::new_real(1.0), + &state, + &reference_state, ); let site0 = "site0".to_string(); @@ -70,8 +64,8 @@ fn test_local_linop_apply_a0_zero() { .tensor(state.node_index(&site0).unwrap()) .unwrap() .clone(); - let y = linop.apply(&x).unwrap(); - assert_eq!(y.external_indices().len(), 0); + let err = linop.apply_projected(&x).unwrap_err(); + assert!(err.to_string().contains("index structure mismatch")); } /// Apply with x whose index structure differs from operator output triggers index mismatch error. @@ -90,15 +84,13 @@ fn test_local_linop_apply_index_mismatch() { let linop = LocalLinOp::new( projected_op, vec!["site0".to_string()], - state, - reference_state, - AnyScalar::new_real(1.0), - AnyScalar::new_real(0.0), + &state, + &reference_state, ); let other = DynIndex::new_dyn(2); let x = TensorDynLen::from_dense(vec![other], vec![1.0, 0.0]).unwrap(); - let err = linop.apply(&x).unwrap_err(); + let err = linop.apply_projected(&x).unwrap_err(); assert!(err.to_string().contains("index structure mismatch")); } @@ -154,10 +146,8 @@ fn test_local_linop_apply_success_mappings() { let linop = LocalLinOp::new( projected_op, vec!["site0".to_string()], - state.clone(), - reference_state, - AnyScalar::new_real(1.0), - AnyScalar::new_real(0.0), + &state, + &reference_state, ); let site0 = "site0".to_string(); @@ -165,7 +155,7 @@ fn test_local_linop_apply_success_mappings() { .tensor(state.node_index(&site0).unwrap()) .unwrap() .clone(); - let y = linop.apply(&x).unwrap(); + let y = linop.apply_projected(&x).unwrap(); let x_ids: HashSet<_> = x .external_indices() .iter() diff --git 
a/crates/tensor4all-treetn/src/linsolve/square/mod.rs b/crates/tensor4all-treetn/src/linsolve/square/mod.rs index 2d6fce7b..bae3a364 100644 --- a/crates/tensor4all-treetn/src/linsolve/square/mod.rs +++ b/crates/tensor4all-treetn/src/linsolve/square/mod.rs @@ -33,12 +33,12 @@ pub use updater::{LinsolveVerifyReport, NodeVerifyDetail, SquareLinsolveUpdater} use std::collections::HashMap; use std::hash::Hash; -use anyhow::Result; +use anyhow::{bail, Result}; -use tensor4all_core::{IndexLike, TensorLike}; +use tensor4all_core::{AnyScalar, IndexLike, TensorLike}; use crate::linsolve::common::LinsolveOptions; -use crate::operator::IndexMapping; +use crate::operator::{apply_linear_operator, ApplyOptions, IndexMapping, LinearOperator}; use crate::{apply_local_update_sweep, CanonicalizationOptions, LocalUpdateSweepPlan, TreeTN}; /// Result of square_linsolve operation. @@ -52,7 +52,7 @@ where pub solution: TreeTN, /// Number of sweeps performed pub sweeps: usize, - /// Final residual norm (if computed) + /// Final relative residual norm `||(a0 I + a1 A)x - b|| / ||b||`. pub residual: Option, /// Converged flag pub converged: bool, @@ -152,9 +152,32 @@ where // Validate inputs before proceeding validate_linsolve_inputs(operator, rhs, &init)?; + if options.a1.is_zero() || operator_is_zero(operator)? { + return solve_identity_term_only( + operator, + rhs, + &init, + options, + input_mapping, + output_mapping, + ); + } + // Canonicalize initial guess towards center let mut x = init.canonicalize([center.clone()], CanonicalizationOptions::default())?; + let needs_residual_operator = options.check_residual || options.convergence_tol.is_some(); + let residual_operator = if needs_residual_operator { + Some(linear_operator_for_residual( + operator, + &x, + input_mapping.clone(), + output_mapping.clone(), + )?) 
+ } else { + None + }; + // Create SquareLinsolveUpdater with or without index mappings let mut updater = match (input_mapping, output_mapping) { (Some(input), Some(output)) => SquareLinsolveUpdater::with_index_mappings( @@ -178,21 +201,228 @@ where let mut final_sweeps = 0; + let mut residual = None; + let mut converged = false; + // Perform sweeps for sweep in 0..options.nfullsweeps { final_sweeps = sweep + 1; apply_local_update_sweep(&mut x, &plan, &mut updater)?; + if let Some(tol) = options.convergence_tol { + let current_residual = relative_linear_system_residual( + residual_operator + .as_ref() + .ok_or_else(|| anyhow::anyhow!("missing residual operator"))?, + &x, + rhs, + options.a0.clone(), + options.a1.clone(), + ApplyOptions::naive(), + )?; + residual = Some(current_residual); + if current_residual < tol { + converged = true; + break; + } + } + } + + if residual.is_none() && options.check_residual { + let final_residual = relative_linear_system_residual( + residual_operator + .as_ref() + .ok_or_else(|| anyhow::anyhow!("missing residual operator"))?, + &x, + rhs, + options.a0.clone(), + options.a1.clone(), + ApplyOptions::naive(), + )?; + converged = options + .convergence_tol + .is_some_and(|tol| final_residual < tol); + residual = Some(final_residual); } - // Note: Residual computation (||Hx - b|| / ||b||) and convergence checking - // are not yet implemented. Currently, all requested sweeps are performed. Ok(SquareLinsolveResult { solution: x, sweeps: final_sweeps, - residual: None, - converged: false, + residual, + converged, + }) +} + +fn operator_is_zero(operator: &TreeTN) -> Result +where + T: TensorLike, + V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug, +{ + let mut operator = operator.clone(); + Ok(operator.norm()? 
<= 1.0e-15) +} + +fn solve_identity_term_only( + operator: &TreeTN, + rhs: &TreeTN, + init: &TreeTN, + options: LinsolveOptions, + input_mapping: Option>>, + output_mapping: Option>>, +) -> Result> +where + T: TensorLike, + T::Index: IndexLike + Clone + Hash + Eq + std::fmt::Debug, + ::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync, + V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug, +{ + if options.a0.is_zero() { + bail!("square_linsolve: a0 and effective operator term are both zero"); + } + + let mut solution = rhs.clone(); + solution.scale(AnyScalar::new_real(1.0) / options.a0.clone())?; + let residual = if options.check_residual || options.convergence_tol.is_some() { + let residual_operator = + linear_operator_for_residual(operator, init, input_mapping, output_mapping)?; + Some(relative_linear_system_residual( + &residual_operator, + &solution, + rhs, + options.a0.clone(), + options.a1.clone(), + ApplyOptions::naive(), + )?) + } else { + None + }; + let converged = options + .convergence_tol + .zip(residual) + .is_some_and(|(tol, residual)| residual < tol); + + Ok(SquareLinsolveResult { + solution, + sweeps: 0, + residual, + converged, }) } +fn linear_operator_for_residual( + operator: &TreeTN, + state: &TreeTN, + input_mapping: Option>>, + output_mapping: Option>>, +) -> Result> +where + T: TensorLike, + T::Index: IndexLike + Clone + Hash + Eq + std::fmt::Debug, + ::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync, + V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug, +{ + match (input_mapping, output_mapping) { + (Some(input), Some(output)) => Ok(LinearOperator::new(operator.clone(), input, output)), + (None, None) => LinearOperator::from_mpo_and_state(operator.clone(), state), + _ => Err(anyhow::anyhow!( + "input_mapping and output_mapping must both be Some or both be None" + )), + } +} + +/// Compute the true relative residual of a TreeTN linear system. 
+/// +/// This evaluates `||(a0 I + a1 A)x - b|| / ||b||` by applying `operator` to +/// `solution`, combining the identity and operator terms, and measuring the +/// TreeTN norm of the resulting residual. When `||b||` is numerically zero, the +/// absolute residual norm is returned to avoid division by zero. +/// +/// # Arguments +/// * `operator` - Linear operator `A` including any input/output index mappings. +/// * `solution` - Candidate solution `x`. +/// * `rhs` - Right-hand side `b`. +/// * `a0` - Identity coefficient. +/// * `a1` - Operator coefficient. +/// * `apply_options` - Options for applying `A`; use [`ApplyOptions::naive`] for +/// an exact residual of the represented TreeTN. +/// +/// # Returns +/// The relative residual norm, or the absolute residual norm for zero RHS. +/// +/// # Errors +/// Returns an error if operator application, TreeTN addition, scaling, or norm +/// computation fails. +/// +/// # Examples +/// ``` +/// use std::collections::HashMap; +/// use tensor4all_core::{DynIndex, TensorDynLen}; +/// use tensor4all_treetn::{ +/// relative_linear_system_residual, ApplyOptions, IndexMapping, LinearOperator, TreeTN, +/// }; +/// +/// let site = DynIndex::new_dyn(2); +/// let s_in = DynIndex::new_dyn(2); +/// let s_out = DynIndex::new_dyn(2); +/// let state_tensor = TensorDynLen::from_dense(vec![site.clone()], vec![3.0_f64, 5.0]).unwrap(); +/// let state = TreeTN::::from_tensors(vec![state_tensor], vec![0]).unwrap(); +/// let mpo_tensor = TensorDynLen::from_dense( +/// vec![s_out.clone(), s_in.clone()], +/// vec![1.0_f64, 0.0, 0.0, 1.0], +/// ).unwrap(); +/// let mpo = TreeTN::::from_tensors(vec![mpo_tensor], vec![0]).unwrap(); +/// let mut input_mapping = HashMap::new(); +/// input_mapping.insert(0usize, IndexMapping { true_index: site.clone(), internal_index: s_in }); +/// let mut output_mapping = HashMap::new(); +/// output_mapping.insert(0usize, IndexMapping { true_index: site, internal_index: s_out }); +/// let operator = 
LinearOperator::new(mpo, input_mapping, output_mapping); +/// +/// let residual = relative_linear_system_residual( +/// &operator, +/// &state, +/// &state, +/// 0.0, +/// 1.0, +/// ApplyOptions::naive(), +/// ).unwrap(); +/// assert!(residual < 1.0e-12); +/// ``` +pub fn relative_linear_system_residual( + operator: &LinearOperator, + solution: &TreeTN, + rhs: &TreeTN, + a0: A0, + a1: A1, + apply_options: ApplyOptions, +) -> Result +where + T: TensorLike, + T::Index: IndexLike + Clone + Hash + Eq + std::fmt::Debug, + ::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync, + V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug, + A0: Into, + A1: Into, +{ + let a0 = a0.into(); + let a1 = a1.into(); + let mut lhs = solution.clone(); + lhs.scale(a0)?; + + if !a1.is_zero() { + let mut ax = apply_linear_operator(operator, solution, apply_options)?; + ax.scale(a1)?; + lhs = lhs.add_reindexed_like_self(&ax)?; + } + + let mut negative_rhs = rhs.clone(); + negative_rhs.scale(AnyScalar::new_real(-1.0))?; + let mut residual = lhs.add_reindexed_like_self(&negative_rhs)?; + + let rhs_norm = rhs.clone().norm()?; + if rhs_norm <= 1.0e-15 { + return residual.norm(); + } + Ok(residual.norm()? / rhs_norm) +} + #[cfg(test)] mod tests; diff --git a/crates/tensor4all-treetn/src/linsolve/square/projected_state.rs b/crates/tensor4all-treetn/src/linsolve/square/projected_state.rs index 5c22dac9..d955bce4 100644 --- a/crates/tensor4all-treetn/src/linsolve/square/projected_state.rs +++ b/crates/tensor4all-treetn/src/linsolve/square/projected_state.rs @@ -13,7 +13,7 @@ use std::hash::Hash; use anyhow::Result; -use tensor4all_core::{AllowedPairs, IndexLike, TensorLike}; +use tensor4all_core::{IndexLike, TensorLike}; use crate::linsolve::common::{EnvironmentCache, NetworkTopology}; use crate::treetn::TreeTN; @@ -70,6 +70,9 @@ where /// /// For the square case, the `reference_state` is used as the bra (conjugated), /// and `rhs` is used as the ket, i.e. 
environments are constructed for ``. + /// If the reference and RHS networks share link indices, the reference links + /// are relabeled for this local contraction so bra/ket boundary links are + /// not accidentally traced out. /// /// # Arguments /// * `region` - The nodes in the local update region @@ -81,6 +84,18 @@ where reference_state: &TreeTN, topology: &NT, ) -> Result { + let reference_storage; + let reference_state = if self.has_link_collision_with_rhs(reference_state, topology)? { + // A bra/ket overlap needs two independent copies of every open link: + // the RHS-side link is contracted with the local RHS tensor, while + // the reference-side link remains as the output local vector space. + self.envs.clear(); + reference_storage = Some(reference_state.sim_linkinds()?); + reference_storage.as_ref().unwrap() + } else { + reference_state + }; + // Ensure environments are computed self.ensure_environments(region, reference_state, topology)?; @@ -114,7 +129,47 @@ where // Use T::contract for optimal contraction ordering let tensor_refs: Vec<&T> = all_tensors.iter().collect(); - T::contract(&tensor_refs, AllowedPairs::All) + T::contract(&tensor_refs) + } + + fn has_link_collision_with_rhs>( + &self, + reference_state: &TreeTN, + topology: &NT, + ) -> Result { + for node in reference_state.node_names() { + for neighbor in topology.neighbors(&node) { + if node > neighbor { + continue; + } + + let Some(ref_edge) = reference_state.edge_between(&node, &neighbor) else { + continue; + }; + let Some(rhs_edge) = self.rhs.edge_between(&node, &neighbor) else { + continue; + }; + let ref_bond = reference_state.bond_index(ref_edge).ok_or_else(|| { + anyhow::anyhow!( + "Reference bond index not found for edge {:?}-{:?}", + node, + neighbor + ) + })?; + let rhs_bond = self.rhs.bond_index(rhs_edge).ok_or_else(|| { + anyhow::anyhow!( + "RHS bond index not found for edge {:?}-{:?}", + node, + neighbor + ) + })?; + + if ref_bond == rhs_bond { + return Ok(true); + } + } + } + 
Ok(false) } /// Ensure environments are computed for neighbors of the region. @@ -157,9 +212,9 @@ where } // Collect child environments - let child_envs: Vec = child_neighbors + let child_envs: Vec<&T> = child_neighbors .iter() - .filter_map(|child| self.envs.get(child, from).cloned()) + .filter_map(|child| self.envs.get(child, from)) .collect(); // Contract bra (reference_state) with ket (RHS) at this node @@ -182,15 +237,15 @@ where let bra_conj = tensor_ref.conj(); // Contract bra and ket - T::contract auto-detects contractable pairs - let bra_ket = T::contract(&[&bra_conj, tensor_b], AllowedPairs::All)?; + let bra_ket = T::contract(&[&bra_conj, tensor_b])?; // Contract bra*ket with child environments using T::contract if child_envs.is_empty() { Ok(bra_ket) } else { let mut all_tensors: Vec<&T> = vec![&bra_ket]; - all_tensors.extend(child_envs.iter()); - T::contract(&all_tensors, AllowedPairs::All) + all_tensors.extend(child_envs); + T::contract(&all_tensors) } } diff --git a/crates/tensor4all-treetn/src/linsolve/square/tests/mod.rs b/crates/tensor4all-treetn/src/linsolve/square/tests/mod.rs index 939b6f94..fa3cea37 100644 --- a/crates/tensor4all-treetn/src/linsolve/square/tests/mod.rs +++ b/crates/tensor4all-treetn/src/linsolve/square/tests/mod.rs @@ -95,7 +95,30 @@ fn test_square_linsolve_zero_sweeps_returns_solution_wrapper() { assert_eq!(result.solution.node_count(), 2); assert_eq!(result.sweeps, 0); - assert_eq!(result.residual, None); + assert!(result.residual.is_some_and(|residual| residual < 1.0e-12)); + assert!(!result.converged); +} + +#[test] +fn test_square_linsolve_can_skip_final_residual() { + let operator = create_simple_2site_mpo(); + let rhs = create_simple_2site_mps(); + let init = create_simple_2site_mps(); + + let result = square_linsolve( + &operator, + &rhs, + init, + &"site0".to_string(), + LinsolveOptions::new(0).with_residual_check(false), + None, + None, + ) + .unwrap(); + + assert_eq!(result.solution.node_count(), 2); + 
assert_eq!(result.sweeps, 0); + assert!(result.residual.is_none()); assert!(!result.converged); } diff --git a/crates/tensor4all-treetn/src/linsolve/square/updater.rs b/crates/tensor4all-treetn/src/linsolve/square/updater.rs index e2c951c6..8ca9de70 100644 --- a/crates/tensor4all-treetn/src/linsolve/square/updater.rs +++ b/crates/tensor4all-treetn/src/linsolve/square/updater.rs @@ -3,26 +3,30 @@ //! Uses GMRES (via tensor4all_core::krylov) to solve the local linear problem at each sweep step. //! This is the V_in = V_out specialized version. +use std::cell::Cell; use std::collections::HashMap; use std::hash::Hash; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::sync::RwLock; +use std::time::Instant; use anyhow::{Context, Result}; -use tensor4all_core::any_scalar::AnyScalar; -use tensor4all_core::krylov::{gmres, GmresOptions}; -use tensor4all_core::{AllowedPairs, FactorizeOptions, IndexLike, TensorLike}; +use tensor4all_core::krylov::{gmres_affine, gmres_affine_with_absolute_tolerance, GmresOptions}; +use tensor4all_core::{FactorizeOptions, IndexLike, TensorLike}; use super::local_linop::LocalLinOp; use super::projected_state::ProjectedState; -use crate::linsolve::common::{LinsolveOptions, ProjectedOperator}; +use crate::linsolve::common::{GmresToleranceMode, LinsolveOptions, ProjectedOperator}; use crate::operator::IndexMapping; use crate::{ factorize_tensor_to_treetn_with, get_boundary_edges, LocalUpdateStep, LocalUpdater, TreeTN, TreeTopology, }; +static LOCAL_SOLVE_TRACE_COUNTER: AtomicUsize = AtomicUsize::new(0); + /// Report from SquareLinsolveUpdater::verify(). 
#[derive(Debug, Clone)] pub struct LinsolveVerifyReport { @@ -365,9 +369,9 @@ where }) .collect::>()?; - // Use TensorLike::contract for contraction + // Use TensorContractionLike::contract for contraction let tensor_refs: Vec<&T> = tensors.iter().collect(); - T::contract(&tensor_refs, AllowedPairs::All) + T::contract(&tensor_refs) } /// Build TreeTopology for the subtree region from the solved tensor. @@ -501,13 +505,25 @@ where /// /// Solves: (a₀ + a₁ * H_local) |x_local⟩ = |b_local⟩ fn solve_local(&mut self, region: &[V], init: &T, state: &TreeTN) -> Result { + let solve_index = LOCAL_SOLVE_TRACE_COUNTER.fetch_add(1, Ordering::Relaxed); + let trace_limit = std::env::var("T4A_LINSOLVE_TRACE_LIMIT") + .ok() + .and_then(|value| value.parse::().ok()); + let trace = trace_limit.is_some_and(|limit| solve_index < limit); + let abort_after = std::env::var("T4A_LINSOLVE_ABORT_AFTER") + .ok() + .and_then(|value| value.parse::().ok()); + let solve_started = Instant::now(); + // Use state's SiteIndexNetwork directly (implements NetworkTopology) let topology = state.site_index_network(); // Get local RHS: gmres_affine( + apply_a, + &rhs_local, + init, + self.options.a0.clone(), + self.options.a1.clone(), + &gmres_options, + )?, + GmresToleranceMode::Absolute => gmres_affine_with_absolute_tolerance( + apply_a, + &rhs_local, + init, + self.options.a0.clone(), + self.options.a1.clone(), + &gmres_options, + self.options.gmres_tol, + )?, + }; + + if trace { + eprintln!( + "T4A local_solve #{solve_index}: region={region:?} mode={:?} rhs_norm={:.6e} init_norm={:.6e} iterations={} residual={:.6e} converged={} apply_calls={} apply_ms={:.3} total_ms={:.3}", + self.options.gmres_tolerance_mode, + rhs_local.norm(), + init.norm(), + result.iterations, + result.residual_norm, + result.converged, + apply_calls.get(), + apply_elapsed_micros.get() as f64 / 1000.0, + solve_started.elapsed().as_secs_f64() * 1000.0, + ); + } + if abort_after.is_some_and(|limit| solve_index + 1 >= limit) { + 
anyhow::bail!("T4A_LINSOLVE_ABORT_AFTER reached after local solve #{solve_index}"); + } Ok(result.solution) } @@ -864,6 +917,48 @@ where ) } + fn replace_reference_boundary_bonds_with_state( + &self, + tensor: T, + region: &[V], + state: &TreeTN, + ) -> Result { + let mut result = tensor; + for node in region { + for neighbor in state.site_index_network().neighbors(node) { + if region.contains(&neighbor) { + continue; + } + + let Some(state_edge) = state.edge_between(node, &neighbor) else { + continue; + }; + let Some(reference_edge) = self.reference_state.edge_between(node, &neighbor) + else { + continue; + }; + let Some(state_bond) = state.bond_index(state_edge) else { + continue; + }; + let Some(reference_bond) = self.reference_state.bond_index(reference_edge) else { + continue; + }; + + if reference_bond == state_bond { + continue; + } + if result + .external_indices() + .iter() + .any(|index| index == reference_bond) + { + result = result.replaceind(reference_bond, state_bond)?; + } + } + } + Ok(result) + } + fn precheck_ref_bra_ket_convention( &mut self, step: &LocalUpdateStep, @@ -873,9 +968,16 @@ where let init_local = self.contract_region(&subtree, &step.nodes)?; let topology = full_treetn_before.site_index_network(); - let rhs_local_raw = - self.projected_state - .local_constant_term(&step.nodes, full_treetn_before, topology)?; + let rhs_local_raw = self.projected_state.local_constant_term( + &step.nodes, + &self.reference_state, + topology, + )?; + let rhs_local_raw = self.replace_reference_boundary_bonds_with_state( + rhs_local_raw, + &step.nodes, + full_treetn_before, + )?; let init_indices = init_local.external_indices(); let rhs_indices = rhs_local_raw.external_indices(); @@ -982,6 +1084,23 @@ where } } +fn local_gmres_options(options: &LinsolveOptions) -> Result { + if options.gmres_restart_dim == 0 { + anyhow::bail!("LinsolveOptions::gmres_restart_dim must be greater than zero"); + } + if options.gmres_max_restarts == 0 { + 
anyhow::bail!("LinsolveOptions::gmres_max_restarts must be greater than zero"); + } + + Ok(GmresOptions { + max_iter: options.gmres_restart_dim, + rtol: options.gmres_tol, + max_restarts: options.gmres_max_restarts, + verbose: false, + check_true_residual: true, + }) +} + #[cfg(test)] mod tests { use tensor4all_core::{DynIndex, TensorDynLen}; @@ -1001,4 +1120,40 @@ mod tests { assert!(updater.index_sets_match(std::slice::from_ref(&i), std::slice::from_ref(&i))); assert!(!updater.index_sets_match(&[i], &[i_prime])); } + + #[test] + fn local_gmres_options_match_krylovkit_restart_convention() { + let options = LinsolveOptions::default() + .with_gmres_restart_dim(30) + .with_gmres_max_restarts(10) + .with_gmres_tol(1.0e-8); + + let gmres_options = local_gmres_options(&options).unwrap(); + + assert_eq!(gmres_options.max_iter, 30); + assert_eq!(gmres_options.max_restarts, 10); + assert_eq!(gmres_options.rtol, 1.0e-8); + } + + #[test] + fn local_gmres_options_does_not_convert_maxiter_to_total_step_limit() { + let options = LinsolveOptions::default() + .with_gmres_restart_dim(30) + .with_gmres_max_restarts(100); + + let gmres_options = local_gmres_options(&options).unwrap(); + + assert_eq!(gmres_options.max_iter, 30); + assert_eq!(gmres_options.max_restarts, 100); + } + + #[test] + fn local_gmres_options_reject_zero_iteration_parameters() { + assert!( + local_gmres_options(&LinsolveOptions::default().with_gmres_restart_dim(0)).is_err() + ); + assert!( + local_gmres_options(&LinsolveOptions::default().with_gmres_max_restarts(0)).is_err() + ); + } } diff --git a/crates/tensor4all-treetn/src/operator/apply.rs b/crates/tensor4all-treetn/src/operator/apply.rs index 9fb5427e..2bfe6420 100644 --- a/crates/tensor4all-treetn/src/operator/apply.rs +++ b/crates/tensor4all-treetn/src/operator/apply.rs @@ -70,8 +70,8 @@ use std::sync::Arc; use anyhow::{Context, Result}; use tensor4all_core::{ - AllowedPairs, DynIndex, IndexLike, LinearizationOrder, SvdTruncationPolicy, TensorDynLen, - 
TensorIndex, TensorLike, + DynIndex, IndexLike, LinearizationOrder, SvdTruncationPolicy, TensorDynLen, TensorIndex, + TensorLike, }; use super::index_mapping::IndexMapping; @@ -804,8 +804,21 @@ where ) })?; - let contracted = T::contract(&[state_tensor, mpo_tensor], AllowedPairs::All) - .with_context(|| format!("apply_linear_operator_naive_local: failed at {:?}", node))?; + let contracted = if mpo + .site_space(node) + .is_some_and(|site_space| site_space.is_empty()) + { + state_tensor.contract_pair(mpo_tensor).with_context(|| { + format!( + "apply_linear_operator_naive_local: failed spectator product at {:?}", + node + ) + })? + } else { + T::contract(&[state_tensor, mpo_tensor]).with_context(|| { + format!("apply_linear_operator_naive_local: failed at {:?}", node) + })? + }; tensors_by_node.insert(node.clone(), contracted); } diff --git a/crates/tensor4all-treetn/src/operator/identity.rs b/crates/tensor4all-treetn/src/operator/identity.rs index da7fa33e..f89c1511 100644 --- a/crates/tensor4all-treetn/src/operator/identity.rs +++ b/crates/tensor4all-treetn/src/operator/identity.rs @@ -3,11 +3,11 @@ //! When composing exclusive operators, gap positions (nodes not covered by any operator) //! need identity tensors that pass information through unchanged. //! -//! This module provides convenience wrappers around `TensorLike::delta()`. +//! This module provides convenience wrappers around `TensorConstructionLike::delta()`. use anyhow::Result; -use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; +use tensor4all_core::{DynIndex, TensorConstructionLike, TensorDynLen}; /// Build an identity operator tensor for a gap node. /// @@ -16,7 +16,7 @@ use tensor4all_core::{DynIndex, TensorDynLen, TensorLike}; /// - Each site index `s` gets a primed version `s'` (output index) /// - The tensor is diagonal: `T[s1, s1', s2, s2', ...] = δ_{s1,s1'} × δ_{s2,s2'} × ...` /// -/// This is a convenience wrapper around `TensorDynLen::delta()`. 
+/// This is a convenience wrapper around `TensorConstructionLike::delta()`. /// /// # Arguments /// @@ -37,7 +37,7 @@ pub fn build_identity_operator_tensor( site_indices: &[DynIndex], output_site_indices: &[DynIndex], ) -> Result { - TensorDynLen::delta(site_indices, output_site_indices) + ::delta(site_indices, output_site_indices) } #[cfg(test)] diff --git a/crates/tensor4all-treetn/src/operator/linear_operator.rs b/crates/tensor4all-treetn/src/operator/linear_operator.rs index 7b1143e2..861323e9 100644 --- a/crates/tensor4all-treetn/src/operator/linear_operator.rs +++ b/crates/tensor4all-treetn/src/operator/linear_operator.rs @@ -20,16 +20,20 @@ //! //! When applying to `x`, it automatically handles the index transformations. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::hash::Hash; use anyhow::Result; -use tensor4all_core::AllowedPairs; +use tensor4all_core::DynIndex; use tensor4all_core::IndexLike; +use tensor4all_core::LinearizationOrder; +use tensor4all_core::TensorDynLen; use tensor4all_core::TensorLike; use super::index_mapping::IndexMapping; +use crate::options::RestructureOptions; +use crate::site_index_network::SiteIndexNetwork; use crate::treetn::TreeTN; /// LinearOperator: Wraps an MPO with index mapping for automatic transformations. @@ -328,14 +332,14 @@ where op_tensor = Some(match op_tensor { None => tensor, - Some(t) => T::contract(&[&t, &tensor], AllowedPairs::All)?, + Some(t) => T::contract(&[&t, &tensor])?, }); } let op_tensor = op_tensor.ok_or_else(|| anyhow::anyhow!("Empty region"))?; // Contract transformed tensor with operator - let contracted = T::contract(&[&transformed, &op_tensor], AllowedPairs::All)?; + let contracted = T::contract(&[&transformed, &op_tensor])?; // Step 3: Replace output indices back to true indices let mut result = contracted; @@ -593,16 +597,417 @@ where output_mapping: self.input_mapping, } } + + /// Restructure the internal MPO while preserving input and output mappings. 
+ /// + /// This is the operator-level counterpart of [`TreeTN::restructure_to`]. + /// The `target` network is expressed in the operator's internal MPO site + /// indices, not in the true input/output indices stored in + /// [`IndexMapping`]. After the MPO is restructured, each mapping is moved + /// to the target node that owns its internal index, preserving the mapping + /// order from the original operator node order. + /// + /// Use this when an operator has the right local indices but its node + /// grouping or topology needs to match another tensor network before + /// selected-index application. + /// + /// # Arguments + /// + /// * `target` - Desired internal MPO site-index network. + /// * `options` - Restructure options passed to [`TreeTN::restructure_to`]. + /// + /// # Returns + /// + /// A new operator whose internal MPO and mapping node keys follow `target`. + /// + /// # Errors + /// + /// Returns an error if the target is incompatible with the internal MPO, or + /// if a mapping's internal index is absent from the target network. 
+ /// + /// # Examples + /// + /// ``` + /// use std::collections::{HashMap, HashSet}; + /// + /// use tensor4all_core::{DynIndex, IndexLike, TensorDynLen}; + /// use tensor4all_treetn::{ + /// IndexMapping, LinearOperator, RestructureOptions, SiteIndexNetwork, TreeTN, + /// }; + /// + /// # fn main() -> anyhow::Result<()> { + /// let x_true = DynIndex::new_dyn(2); + /// let y_true = DynIndex::new_dyn(2); + /// let x_in = DynIndex::new_dyn(2); + /// let x_out = DynIndex::new_dyn(2); + /// let y_in = DynIndex::new_dyn(2); + /// let y_out = DynIndex::new_dyn(2); + /// let bond = DynIndex::new_dyn(1); + /// + /// let left = TensorDynLen::from_dense( + /// vec![x_out.clone(), x_in.clone(), bond.clone()], + /// vec![1.0, 0.0, 0.0, 1.0], + /// )?; + /// let right = TensorDynLen::from_dense( + /// vec![bond, y_out.clone(), y_in.clone()], + /// vec![1.0, 0.0, 0.0, 1.0], + /// )?; + /// let mut mpo = TreeTN::::new(); + /// mpo.add_tensor("x".to_string(), left)?; + /// mpo.add_tensor("y".to_string(), right)?; + /// let x_node = mpo.node_index(&"x".to_string()).unwrap(); + /// let y_node = mpo.node_index(&"y".to_string()).unwrap(); + /// let link = mpo.tensor(x_node).unwrap().indices()[2].clone(); + /// mpo.connect(x_node, &link, y_node, &link)?; + /// + /// let mut input = HashMap::new(); + /// input.insert("x".to_string(), IndexMapping { true_index: x_true, internal_index: x_in }); + /// input.insert("y".to_string(), IndexMapping { true_index: y_true, internal_index: y_in }); + /// let mut output = HashMap::new(); + /// output.insert("x".to_string(), IndexMapping { true_index: DynIndex::new_dyn(2), internal_index: x_out.clone() }); + /// output.insert("y".to_string(), IndexMapping { true_index: DynIndex::new_dyn(2), internal_index: y_out.clone() }); + /// let op = LinearOperator::new(mpo, input, output); + /// + /// let mut target = SiteIndexNetwork::new(); + /// target.add_node("left".to_string(), HashSet::from([y_out, 
op.get_input_mapping(&"y".to_string()).unwrap().internal_index.clone()]))?; + /// target.add_node("right".to_string(), HashSet::from([x_out, op.get_input_mapping(&"x".to_string()).unwrap().internal_index.clone()]))?; + /// target.add_edge(&"left".to_string(), &"right".to_string())?; + /// + /// let moved = op.restructure_to(&target, &RestructureOptions::default())?; + /// assert!(moved.get_input_mapping(&"left".to_string()).is_some()); + /// assert!(moved.get_input_mapping(&"right".to_string()).is_some()); + /// # Ok(()) + /// # } + /// ``` + pub fn restructure_to( + &self, + target: &SiteIndexNetwork, + options: &RestructureOptions, + ) -> Result> + where + TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug, + { + let mpo = self.mpo.restructure_to(target, options)?; + let input_mapping = self.restructure_mapping_nodes(&self.input_mapping, target, "input")?; + let output_mapping = + self.restructure_mapping_nodes(&self.output_mapping, target, "output")?; + Ok(LinearOperator::new_multi( + mpo, + input_mapping, + output_mapping, + )) + } + + fn restructure_mapping_nodes( + &self, + mappings_by_node: &HashMap>>, + target: &SiteIndexNetwork, + kind: &str, + ) -> Result>>> + where + TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug, + { + let mut result: HashMap>> = HashMap::new(); + let mut nodes: Vec = mappings_by_node.keys().cloned().collect(); + nodes.sort(); + + for node in nodes { + let Some(mappings) = mappings_by_node.get(&node) else { + continue; + }; + for mapping in mappings { + let target_node = target + .find_node_by_index(&mapping.internal_index) + .cloned() + .ok_or_else(|| { + anyhow::anyhow!( + "LinearOperator::restructure_to: {kind} internal index {:?} from node {:?} is missing from target", + mapping.internal_index.id(), + node + ) + })?; + result.entry(target_node).or_default().push(mapping.clone()); + } + } + + Ok(result) + } +} + +impl LinearOperator +where + V: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug, 
+{ + /// Replace one fused input mapping with several ordered input mappings. + /// + /// The operator's internal input index is exactly unfused inside the MPO + /// tensor using [`TensorDynLen::unfuse_index`], then the corresponding + /// [`IndexMapping`] entry is replaced by one entry per `new_true_indices`. + /// New internal indices are generated automatically with matching + /// dimensions and the same order as `new_true_indices`. + /// + /// Use this for tensorized operators whose local input dimension is a + /// product, such as turning one dimension-4 input leg into two binary input + /// legs before applying the operator to interleaved QTT groups. + /// + /// # Arguments + /// + /// * `old_true_index` - Existing true input index to replace. + /// * `new_true_indices` - Ordered true input indices whose dimensions + /// multiply to `old_true_index.dim()`. + /// * `order` - Linearization convention used to decode the old fused + /// coordinate into the new coordinates. + /// + /// # Returns + /// + /// A new operator with the input mapping unfused. + /// + /// # Errors + /// + /// Returns an error if the input mapping is missing or ambiguous, if the + /// dimensions do not multiply correctly, or if the internal MPO reshape + /// fails. 
+ /// + /// # Examples + /// + /// ``` + /// use std::collections::HashMap; + /// + /// use tensor4all_core::{DynIndex, IndexLike, LinearizationOrder, TensorDynLen}; + /// use tensor4all_treetn::{IndexMapping, LinearOperator, TreeTN}; + /// + /// # fn main() -> anyhow::Result<()> { + /// let x = DynIndex::new_dyn(4); + /// let y = DynIndex::new_dyn(4); + /// let x_internal = DynIndex::new_dyn(4); + /// let y_internal = DynIndex::new_dyn(4); + /// let tensor = TensorDynLen::from_dense( + /// vec![y_internal.clone(), x_internal.clone()], + /// vec![1.0, 0.0, 0.0, 0.0, + /// 0.0, 1.0, 0.0, 0.0, + /// 0.0, 0.0, 1.0, 0.0, + /// 0.0, 0.0, 0.0, 1.0], + /// )?; + /// let mpo = TreeTN::::from_tensors(vec![tensor], vec![0])?; + /// let mut input = HashMap::new(); + /// input.insert(0, IndexMapping { true_index: x.clone(), internal_index: x_internal }); + /// let mut output = HashMap::new(); + /// output.insert(0, IndexMapping { true_index: y, internal_index: y_internal }); + /// let op = LinearOperator::new(mpo, input, output); + /// + /// let x0 = DynIndex::new_dyn(2); + /// let x1 = DynIndex::new_dyn(2); + /// let unfused = op.unfuse_input_index(&x, &[x0.clone(), x1.clone()], LinearizationOrder::ColumnMajor)?; + /// + /// let mappings = unfused.get_input_mappings(&0).unwrap(); + /// assert_eq!(mappings.len(), 2); + /// assert!(mappings[0].true_index.same_id(&x0)); + /// assert!(mappings[1].true_index.same_id(&x1)); + /// # Ok(()) + /// # } + /// ``` + pub fn unfuse_input_index( + &self, + old_true_index: &DynIndex, + new_true_indices: &[DynIndex], + order: LinearizationOrder, + ) -> Result { + self.unfuse_mapping_index(old_true_index, new_true_indices, order, MappingKind::Input) + } + + /// Replace one fused output mapping with several ordered output mappings. + /// + /// This is the output-space counterpart of + /// [`Self::unfuse_input_index`]. 
The internal MPO output index is exactly + /// reshaped, and the output mapping vector is expanded in the order given + /// by `new_true_indices`. + /// + /// # Arguments + /// + /// * `old_true_index` - Existing true output index to replace. + /// * `new_true_indices` - Ordered true output indices whose dimensions + /// multiply to `old_true_index.dim()`. + /// * `order` - Linearization convention used to decode the old fused + /// coordinate into the new coordinates. + /// + /// # Returns + /// + /// A new operator with the output mapping unfused. + /// + /// # Errors + /// + /// Returns an error if the output mapping is missing or ambiguous, if the + /// dimensions do not multiply correctly, or if the internal MPO reshape + /// fails. + /// + /// # Examples + /// + /// ``` + /// use std::collections::HashMap; + /// + /// use tensor4all_core::{DynIndex, IndexLike, LinearizationOrder, TensorDynLen}; + /// use tensor4all_treetn::{IndexMapping, LinearOperator, TreeTN}; + /// + /// # fn main() -> anyhow::Result<()> { + /// let x = DynIndex::new_dyn(4); + /// let y = DynIndex::new_dyn(4); + /// let x_internal = DynIndex::new_dyn(4); + /// let y_internal = DynIndex::new_dyn(4); + /// let tensor = TensorDynLen::from_dense( + /// vec![y_internal.clone(), x_internal.clone()], + /// vec![1.0, 0.0, 0.0, 0.0, + /// 0.0, 1.0, 0.0, 0.0, + /// 0.0, 0.0, 1.0, 0.0, + /// 0.0, 0.0, 0.0, 1.0], + /// )?; + /// let mpo = TreeTN::::from_tensors(vec![tensor], vec![0])?; + /// let mut input = HashMap::new(); + /// input.insert(0, IndexMapping { true_index: x, internal_index: x_internal }); + /// let mut output = HashMap::new(); + /// output.insert(0, IndexMapping { true_index: y.clone(), internal_index: y_internal }); + /// let op = LinearOperator::new(mpo, input, output); + /// + /// let y0 = DynIndex::new_dyn(2); + /// let y1 = DynIndex::new_dyn(2); + /// let unfused = op.unfuse_output_index(&y, &[y0.clone(), y1.clone()], LinearizationOrder::ColumnMajor)?; + /// + /// let mappings = 
unfused.get_output_mappings(&0).unwrap(); + /// assert_eq!(mappings.len(), 2); + /// assert!(mappings[0].true_index.same_id(&y0)); + /// assert!(mappings[1].true_index.same_id(&y1)); + /// # Ok(()) + /// # } + /// ``` + pub fn unfuse_output_index( + &self, + old_true_index: &DynIndex, + new_true_indices: &[DynIndex], + order: LinearizationOrder, + ) -> Result { + self.unfuse_mapping_index(old_true_index, new_true_indices, order, MappingKind::Output) + } + + fn unfuse_mapping_index( + &self, + old_true_index: &DynIndex, + new_true_indices: &[DynIndex], + order: LinearizationOrder, + kind: MappingKind, + ) -> Result { + anyhow::ensure!( + !new_true_indices.is_empty(), + "LinearOperator::{kind}: replacement indices must not be empty" + ); + let product = new_true_indices + .iter() + .try_fold(1usize, |acc, index| acc.checked_mul(index.dim())) + .ok_or_else(|| anyhow::anyhow!("LinearOperator::{kind}: dimension product overflow"))?; + anyhow::ensure!( + product == old_true_index.dim(), + "LinearOperator::{kind}: replacement dimension product {} does not match old dimension {}", + product, + old_true_index.dim() + ); + + let (node, position, old_internal_index) = self.find_mapping_entry(old_true_index, kind)?; + let new_internal_indices = new_true_indices + .iter() + .map(|index| DynIndex::new_dyn(index.dim())) + .collect::>(); + + let mpo = self.mpo.replace_site_index_with_indices( + &old_internal_index, + &new_internal_indices, + order, + )?; + + let mut result = self.clone(); + result.mpo = mpo; + + let mappings = match kind { + MappingKind::Input => result.input_mapping.get_mut(&node), + MappingKind::Output => result.output_mapping.get_mut(&node), + } + .ok_or_else(|| { + anyhow::anyhow!( + "LinearOperator::{kind}: mapping node {:?} disappeared during unfuse", + node + ) + })?; + + let replacements = new_true_indices + .iter() + .cloned() + .zip(new_internal_indices) + .map(|(true_index, internal_index)| IndexMapping { + true_index, + internal_index, + }) + 
.collect::>(); + mappings.splice(position..=position, replacements); + + Ok(result) + } + + fn find_mapping_entry( + &self, + old_true_index: &DynIndex, + kind: MappingKind, + ) -> Result<(V, usize, DynIndex)> { + let mappings_by_node = match kind { + MappingKind::Input => &self.input_mapping, + MappingKind::Output => &self.output_mapping, + }; + let mut nodes: Vec = mappings_by_node.keys().cloned().collect(); + nodes.sort(); + + let mut found = None; + for node in nodes { + let Some(mappings) = mappings_by_node.get(&node) else { + continue; + }; + for (position, mapping) in mappings.iter().enumerate() { + if mapping.true_index == *old_true_index { + if found.is_some() { + return Err(anyhow::anyhow!( + "LinearOperator::{kind}: true index {:?} appears in more than one mapping", + old_true_index.id() + )); + } + found = Some((node.clone(), position, mapping.internal_index.clone())); + } + } + } + + found.ok_or_else(|| { + anyhow::anyhow!( + "LinearOperator::{kind}: true index {:?} not found", + old_true_index.id() + ) + }) + } +} + +#[derive(Clone, Copy)] +enum MappingKind { + Input, + Output, +} + +impl std::fmt::Display for MappingKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Input => f.write_str("unfuse_input_index"), + Self::Output => f.write_str("unfuse_output_index"), + } + } } // ============================================================================ // Helper methods // ============================================================================ -use std::collections::HashSet; - use crate::operator::Operator; -use crate::SiteIndexNetwork; // Implement Operator trait for LinearOperator impl Operator for LinearOperator diff --git a/crates/tensor4all-treetn/src/operator/linear_operator/tests/mod.rs b/crates/tensor4all-treetn/src/operator/linear_operator/tests/mod.rs index 995337e5..d15baabe 100644 --- a/crates/tensor4all-treetn/src/operator/linear_operator/tests/mod.rs +++ 
b/crates/tensor4all-treetn/src/operator/linear_operator/tests/mod.rs @@ -1,10 +1,13 @@ use super::*; -use std::collections::HashMap; -use tensor4all_core::{DynIndex, IndexLike, TensorDynLen}; +use std::collections::{HashMap, HashSet}; +use tensor4all_core::{ + DynIndex, IndexLike, LinearizationOrder, TensorConstructionLike, TensorDynLen, +}; use crate::operator::index_mapping::IndexMapping; use crate::operator::Operator; use crate::treetn::TreeTN; +use crate::{RestructureOptions, SiteIndexNetwork}; /// Create a simple "MPO" TreeTN with two site indices per node (input + output). /// Structure: single node "A" with indices (s_in_tmp, s_out_tmp) @@ -70,6 +73,54 @@ fn make_linear_operator() -> ( (op, s, s_in_tmp, s_out_tmp) } +fn make_fused_identity_operator() -> ( + LinearOperator, + DynIndex, + DynIndex, + DynIndex, + DynIndex, +) { + let input_true = DynIndex::new_dyn(4); + let output_true = DynIndex::new_dyn(4); + let input_internal = DynIndex::new_dyn(4); + let output_internal = DynIndex::new_dyn(4); + + let mut data = vec![0.0; 16]; + for value in 0..4 { + data[value + 4 * value] = 1.0; + } + let tensor = + TensorDynLen::from_dense(vec![output_internal.clone(), input_internal.clone()], data) + .unwrap(); + let mpo = + TreeTN::::from_tensors(vec![tensor], vec!["A".to_string()]).unwrap(); + + let mut input_mapping = HashMap::new(); + input_mapping.insert( + "A".to_string(), + IndexMapping { + true_index: input_true.clone(), + internal_index: input_internal.clone(), + }, + ); + let mut output_mapping = HashMap::new(); + output_mapping.insert( + "A".to_string(), + IndexMapping { + true_index: output_true.clone(), + internal_index: output_internal.clone(), + }, + ); + + ( + LinearOperator::new(mpo, input_mapping, output_mapping), + input_true, + output_true, + input_internal, + output_internal, + ) +} + #[test] fn test_linear_operator_new() { let (op, _s, _s_in_tmp, _s_out_tmp) = make_linear_operator(); @@ -551,3 +602,107 @@ fn 
test_linear_operator_transpose_preserves_mpo() { assert_eq!(transposed.mpo().node_count(), original_node_count); } + +#[test] +fn test_unfuse_input_and_output_indices_splits_internal_mpo_axes() { + let (op, input_true, output_true, _input_internal, _output_internal) = + make_fused_identity_operator(); + let input0 = DynIndex::new_dyn(2); + let input1 = DynIndex::new_dyn(2); + let output0 = DynIndex::new_dyn(2); + let output1 = DynIndex::new_dyn(2); + + let op = op + .unfuse_output_index( + &output_true, + &[output0.clone(), output1.clone()], + LinearizationOrder::ColumnMajor, + ) + .unwrap() + .unfuse_input_index( + &input_true, + &[input0.clone(), input1.clone()], + LinearizationOrder::ColumnMajor, + ) + .unwrap(); + + let input_mappings = op.get_input_mappings(&"A".to_string()).unwrap(); + assert_eq!(input_mappings.len(), 2); + assert!(input_mappings[0].true_index.same_id(&input0)); + assert!(input_mappings[1].true_index.same_id(&input1)); + + let output_mappings = op.get_output_mappings(&"A".to_string()).unwrap(); + assert_eq!(output_mappings.len(), 2); + assert!(output_mappings[0].true_index.same_id(&output0)); + assert!(output_mappings[1].true_index.same_id(&output1)); + + let output_internal = output_mappings + .iter() + .map(|mapping| mapping.internal_index.clone()) + .collect::>(); + let input_internal = input_mappings + .iter() + .map(|mapping| mapping.internal_index.clone()) + .collect::>(); + let expected = + ::delta(&output_internal, &input_internal).unwrap(); + let actual = op.mpo().contract_to_tensor().unwrap(); + + assert!(actual.distance(&expected).unwrap() < 1.0e-12); +} + +#[test] +fn test_linear_operator_restructure_to_moves_mapping_nodes() { + let (op, s0, s1) = make_two_node_mpo_and_operator(); + let node_a = "A".to_string(); + let node_b = "B".to_string(); + let a_in = op + .get_input_mapping(&node_a) + .unwrap() + .internal_index + .clone(); + let a_out = op + .get_output_mapping(&node_a) + .unwrap() + .internal_index + .clone(); + let b_in = 
op + .get_input_mapping(&node_b) + .unwrap() + .internal_index + .clone(); + let b_out = op + .get_output_mapping(&node_b) + .unwrap() + .internal_index + .clone(); + + let mut target = SiteIndexNetwork::::new(); + target + .add_node("left".to_string(), HashSet::from([b_in, b_out])) + .unwrap(); + target + .add_node("right".to_string(), HashSet::from([a_in, a_out])) + .unwrap(); + target + .add_edge(&"left".to_string(), &"right".to_string()) + .unwrap(); + + let expected = op.mpo().contract_to_tensor().unwrap(); + let moved = op + .restructure_to(&target, &RestructureOptions::default()) + .unwrap(); + let actual = moved.mpo().contract_to_tensor().unwrap(); + + assert!(actual.distance(&expected).unwrap() < 1.0e-12); + assert!(moved + .get_input_mapping(&"left".to_string()) + .unwrap() + .true_index + .same_id(&s1)); + assert!(moved + .get_input_mapping(&"right".to_string()) + .unwrap() + .true_index + .same_id(&s0)); +} diff --git a/crates/tensor4all-treetn/src/simplett_bridge.rs b/crates/tensor4all-treetn/src/simplett_bridge.rs index 08a26fa7..f07bbe54 100644 --- a/crates/tensor4all-treetn/src/simplett_bridge.rs +++ b/crates/tensor4all-treetn/src/simplett_bridge.rs @@ -1,7 +1,7 @@ use anyhow::{ensure, Result}; use tensor4all_core::{DynIndex, IndexLike, TensorDynLen, TensorElement}; use tensor4all_simplett::{ - tensor3_from_data, AbstractTensorTrain, TTScalar, Tensor3Ops, TensorTrain, + tensor3_from_data, tensor3_zeros, AbstractTensorTrain, TTScalar, Tensor3Ops, TensorTrain, }; use crate::TreeTN; @@ -47,7 +47,7 @@ where /// # Examples /// /// ``` -/// use tensor4all_simplett::{tensor3_from_data, TensorTrain}; +/// use tensor4all_simplett::{tensor3_from_data, AbstractTensorTrain, TensorTrain}; /// use tensor4all_treetn::tensor_train_to_treetn_with_names; /// /// let tt = TensorTrain::new(vec![ @@ -249,6 +249,419 @@ where Ok(TensorTrain::new(tensors)?) } +/// Insert a one-hot site into a linear-chain `TreeTN`. 
+/// +/// The input and output chains use node names `0..n-1` / `0..n`. `position` +/// chooses the insertion point: +/// - `0` inserts before the first site; +/// - `len` inserts after the last site; +/// - any `1..len` inserts on the chain edge between `position - 1` and +/// `position`. +/// +/// The inserted site is fixed to `value`, so the new tensor evaluates to the +/// old tensor when the inserted coordinate equals `value` and to zero +/// otherwise. Existing site indices are preserved in order. +/// +/// # Arguments +/// +/// - `treetn`: a linear-chain `TreeTN` whose node names are exactly `0..n`. +/// Each node must carry exactly one site index. +/// - `position`: insertion point in `0..=n`; `0` prepends and `n` appends. +/// - `site_index`: site index to attach to the inserted node. +/// - `value`: fixed coordinate for the inserted one-hot site. +/// +/// # Returns +/// +/// A new linear-chain `TreeTN` with node names `0..=n`. Existing site indices +/// keep their relative order and the inserted `site_index` appears at +/// `position`. +/// +/// # Errors +/// +/// Returns an error if `position` is out of range, `value` is outside +/// `site_index`, the input is not a numbered single-site chain, or the +/// conversion between the chain `TreeTN` and `TensorTrain` fails. 
+/// +/// # Examples +/// +/// ``` +/// use tensor4all_core::DynIndex; +/// use tensor4all_simplett::{tensor3_from_data, AbstractTensorTrain, TensorTrain}; +/// use tensor4all_treetn::{ +/// insert_onehot_site_in_treetn_chain, tensor_train_to_treetn, treetn_to_tensor_train, +/// }; +/// +/// # fn main() -> anyhow::Result<()> { +/// let tt = TensorTrain::new(vec![ +/// tensor3_from_data(vec![1.0_f64, 2.0], 1, 2, 1)?, +/// ])?; +/// let (tree, _) = tensor_train_to_treetn(&tt)?; +/// let fixed = DynIndex::new_dyn(2); +/// +/// let extended = insert_onehot_site_in_treetn_chain::(tree, 0, fixed, 0)?; +/// let roundtrip = treetn_to_tensor_train::(extended)?; +/// +/// assert_eq!(roundtrip.site_dims(), vec![2, 2]); +/// # Ok(()) +/// # } +/// ``` +pub fn insert_onehot_site_in_treetn_chain( + treetn: TreeTN, + position: usize, + site_index: DynIndex, + value: usize, +) -> Result> +where + T: TTScalar + TensorElement + Clone + Default, +{ + let old_site_indices = chain_site_indices(&treetn, "insert_onehot_site_in_treetn_chain")?; + ensure!( + position <= old_site_indices.len(), + "insert_onehot_site_in_treetn_chain: position {} is out of range 0..={}", + position, + old_site_indices.len() + ); + ensure!( + value < site_index.dim(), + "insert_onehot_site_in_treetn_chain: fixed value {} exceeds site dimension {}", + value, + site_index.dim() + ); + + let tt = treetn_to_tensor_train::(treetn)?; + let mut tensors = Vec::with_capacity(tt.len() + 1); + for site in 0..position { + tensors.push(tt.site_tensor(site).clone()); + } + + let bond_dim = if tt.is_empty() || position == 0 { + 1 + } else { + tt.site_tensor(position - 1).right_dim() + }; + let mut inserted = tensor3_zeros::(bond_dim, site_index.dim(), bond_dim); + for bond in 0..bond_dim { + inserted.set3(bond, value, bond, T::one()); + } + tensors.push(inserted); + + for site in position..tt.len() { + tensors.push(tt.site_tensor(site).clone()); + } + + let mut site_indices = old_site_indices; + site_indices.insert(position, 
site_index); + let tt = TensorTrain::new(tensors)?; + tensor_train_to_treetn_with_names_and_site_indices(&tt, (0..tt.len()).collect(), site_indices) +} + +/// Fix a site in a linear-chain `TreeTN` and remove it. +/// +/// `position` selects the chain site to remove, and `value` selects the local +/// coordinate kept at that site. The returned chain has one fewer site, node +/// names `0..n-2`, and all remaining site indices preserved in their original +/// order. +/// +/// # Arguments +/// +/// - `treetn`: a linear-chain `TreeTN` whose node names are exactly `0..n`. +/// Each node must carry exactly one site index. +/// - `position`: chain site to fix and remove, in `0..n`. +/// - `value`: local coordinate retained at the removed site. +/// +/// # Returns +/// +/// A new linear-chain `TreeTN` representing the original tensor restricted to +/// `site[position] == value`, with that site removed from the external index +/// list. +/// +/// # Errors +/// +/// Returns an error if `position` is out of range, `value` is outside the site +/// dimension, removing the only site would require an unsupported scalar +/// zero-site `TreeTN`, the input is not a numbered single-site chain, or the +/// conversion between the chain `TreeTN` and `TensorTrain` fails. 
+/// +/// # Examples +/// +/// ``` +/// use tensor4all_simplett::{tensor3_from_data, AbstractTensorTrain, TensorTrain}; +/// use tensor4all_treetn::{ +/// fix_and_remove_site_from_treetn_chain, tensor_train_to_treetn, treetn_to_tensor_train, +/// }; +/// +/// # fn main() -> anyhow::Result<()> { +/// let tt = TensorTrain::new(vec![ +/// tensor3_from_data(vec![1.0_f64, 2.0], 1, 2, 1)?, +/// tensor3_from_data(vec![10.0_f64, 20.0], 1, 2, 1)?, +/// ])?; +/// let (tree, _) = tensor_train_to_treetn(&tt)?; +/// +/// let reduced = fix_and_remove_site_from_treetn_chain::(tree, 0, 1)?; +/// let roundtrip = treetn_to_tensor_train::(reduced)?; +/// +/// assert_eq!(roundtrip.site_dims(), vec![2]); +/// assert!((roundtrip.evaluate(&[0])? - 20.0).abs() < 1.0e-12); +/// assert!((roundtrip.evaluate(&[1])? - 40.0).abs() < 1.0e-12); +/// # Ok(()) +/// # } +/// ``` +pub fn fix_and_remove_site_from_treetn_chain( + treetn: TreeTN, + position: usize, + value: usize, +) -> Result> +where + T: TTScalar + TensorElement + Clone + Default, +{ + let site_indices = chain_site_indices(&treetn, "fix_and_remove_site_from_treetn_chain")?; + ensure!( + position < site_indices.len(), + "fix_and_remove_site_from_treetn_chain: position {} is out of range 0..{}", + position, + site_indices.len() + ); + ensure!( + site_indices.len() > 1, + "fix_and_remove_site_from_treetn_chain: cannot remove the only site because scalar zero-site TreeTN chains are not supported" + ); + + let tt = treetn_to_tensor_train::(treetn)?; + ensure!( + value < tt.site_dim(position), + "fix_and_remove_site_from_treetn_chain: fixed value {} exceeds site dimension {}", + value, + tt.site_dim(position) + ); + + let reduced_site = fixed_site_matrix(tt.site_tensor(position), value); + remove_site_with_reduced_matrix(tt, site_indices, position, &reduced_site) +} + +/// Contract a site of a linear-chain `TreeTN` with weights and remove it. +/// +/// The removed site is summed as `sum_s weights[s] * tensor[..., s, ...]`. 
+/// Pass already scaled weights, such as `1/d` averaging weights, when a +/// normalized reduction is desired. The returned chain has one fewer site, node +/// names `0..n-2`, and all remaining site indices preserved in their original +/// order. +/// +/// # Arguments +/// +/// - `treetn`: a linear-chain `TreeTN` whose node names are exactly `0..n`. +/// Each node must carry exactly one site index. +/// - `position`: chain site to contract and remove, in `0..n`. +/// - `weights`: one weight per local coordinate of the removed site. +/// +/// # Returns +/// +/// A new linear-chain `TreeTN` representing the weighted local reduction with +/// the selected site removed from the external index list. +/// +/// # Errors +/// +/// Returns an error if `position` is out of range, `weights.len()` does not +/// match the removed site dimension, removing the only site would require an +/// unsupported scalar zero-site `TreeTN`, the input is not a numbered +/// single-site chain, or the conversion between the chain `TreeTN` and +/// `TensorTrain` fails. +/// +/// # Examples +/// +/// ``` +/// use tensor4all_simplett::{tensor3_from_data, AbstractTensorTrain, TensorTrain}; +/// use tensor4all_treetn::{ +/// tensor_train_to_treetn, treetn_to_tensor_train, weighted_remove_site_from_treetn_chain, +/// }; +/// +/// # fn main() -> anyhow::Result<()> { +/// let tt = TensorTrain::new(vec![ +/// tensor3_from_data(vec![1.0_f64, 3.0], 1, 2, 1)?, +/// tensor3_from_data(vec![2.0_f64, 4.0], 1, 2, 1)?, +/// ])?; +/// let (tree, _) = tensor_train_to_treetn(&tt)?; +/// +/// let reduced = weighted_remove_site_from_treetn_chain::(tree, 0, &[0.25, 0.75])?; +/// let roundtrip = treetn_to_tensor_train::(reduced)?; +/// +/// assert_eq!(roundtrip.site_dims(), vec![2]); +/// assert!((roundtrip.evaluate(&[0])? - 5.0).abs() < 1.0e-12); +/// assert!((roundtrip.evaluate(&[1])? 
- 10.0).abs() < 1.0e-12); +/// # Ok(()) +/// # } +/// ``` +pub fn weighted_remove_site_from_treetn_chain( + treetn: TreeTN, + position: usize, + weights: &[T], +) -> Result> +where + T: TTScalar + TensorElement + Clone + Default, +{ + let site_indices = chain_site_indices(&treetn, "weighted_remove_site_from_treetn_chain")?; + ensure!( + position < site_indices.len(), + "weighted_remove_site_from_treetn_chain: position {} is out of range 0..{}", + position, + site_indices.len() + ); + ensure!( + site_indices.len() > 1, + "weighted_remove_site_from_treetn_chain: cannot remove the only site because scalar zero-site TreeTN chains are not supported" + ); + + let tt = treetn_to_tensor_train::(treetn)?; + ensure!( + weights.len() == tt.site_dim(position), + "weighted_remove_site_from_treetn_chain: weights length {} must match site dimension {}", + weights.len(), + tt.site_dim(position) + ); + + let reduced_site = weighted_site_matrix(tt.site_tensor(position), weights); + remove_site_with_reduced_matrix(tt, site_indices, position, &reduced_site) +} + +fn chain_site_indices( + treetn: &TreeTN, + context: &str, +) -> Result> { + let nsites = treetn.node_count(); + let mut node_names = treetn.node_names(); + node_names.sort_unstable(); + ensure!( + node_names == (0..nsites).collect::>(), + "{context}: expected node names 0..{}, got {:?}", + nsites, + node_names + ); + + let mut site_indices = Vec::with_capacity(nsites); + for site in 0..nsites { + let site_space = treetn + .site_space(&site) + .ok_or_else(|| anyhow::anyhow!("{context}: missing site space at node {site}"))?; + ensure!( + site_space.len() == 1, + "{context}: node {site} must have exactly one site index, got {}", + site_space.len() + ); + let site_index = site_space + .iter() + .next() + .ok_or_else(|| anyhow::anyhow!("{context}: node {site} has no site index"))?; + site_indices.push(site_index.clone()); + } + Ok(site_indices) +} + +fn fixed_site_matrix(tensor: &tensor4all_simplett::Tensor3, value: usize) -> Vec 
+where + T: TTScalar + Clone + Default, +{ + tensor.slice_site(value) +} + +fn weighted_site_matrix(tensor: &tensor4all_simplett::Tensor3, weights: &[T]) -> Vec +where + T: TTScalar + Clone + Default, +{ + let mut matrix = vec![T::default(); tensor.left_dim() * tensor.right_dim()]; + for (s, &weight) in weights.iter().enumerate() { + for r in 0..tensor.right_dim() { + for l in 0..tensor.left_dim() { + let offset = l + tensor.left_dim() * r; + matrix[offset] = matrix[offset] + *tensor.get3(l, s, r) * weight; + } + } + } + matrix +} + +fn remove_site_with_reduced_matrix( + tt: TensorTrain, + mut site_indices: Vec, + position: usize, + reduced_site: &[T], +) -> Result> +where + T: TTScalar + TensorElement + Clone + Default, +{ + let mut tensors = Vec::with_capacity(tt.len().saturating_sub(1)); + for site in 0..position { + if site + 1 == position && position + 1 == tt.len() { + tensors.push(absorb_reduced_site_into_left( + tt.site_tensor(site), + reduced_site, + )); + } else { + tensors.push(tt.site_tensor(site).clone()); + } + } + + if position + 1 < tt.len() { + tensors.push(absorb_reduced_site_into_right( + reduced_site, + tt.site_tensor(position), + tt.site_tensor(position + 1), + )); + for site in position + 2..tt.len() { + tensors.push(tt.site_tensor(site).clone()); + } + } + + site_indices.remove(position); + let tt = TensorTrain::new(tensors)?; + tensor_train_to_treetn_with_names_and_site_indices(&tt, (0..tt.len()).collect(), site_indices) +} + +fn absorb_reduced_site_into_right( + reduced_site: &[T], + removed: &tensor4all_simplett::Tensor3, + right: &tensor4all_simplett::Tensor3, +) -> tensor4all_simplett::Tensor3 +where + T: TTScalar + Clone + Default, +{ + let left_dim = removed.left_dim(); + let removed_right_dim = removed.right_dim(); + let mut tensor = tensor3_zeros::(left_dim, right.site_dim(), right.right_dim()); + + for l in 0..left_dim { + for s in 0..right.site_dim() { + for r in 0..right.right_dim() { + let mut value = T::default(); + for bridge 
in 0..removed_right_dim { + value = value + reduced_site[l + left_dim * bridge] * *right.get3(bridge, s, r); + } + tensor.set3(l, s, r, value); + } + } + } + tensor +} + +fn absorb_reduced_site_into_left( + left: &tensor4all_simplett::Tensor3, + reduced_site: &[T], +) -> tensor4all_simplett::Tensor3 +where + T: TTScalar + Clone + Default, +{ + let mut tensor = tensor3_zeros::(left.left_dim(), left.site_dim(), 1); + + for l in 0..left.left_dim() { + for s in 0..left.site_dim() { + let mut value = T::default(); + for (bridge, &weight) in reduced_site.iter().enumerate().take(left.right_dim()) { + value = value + *left.get3(l, s, bridge) * weight; + } + tensor.set3(l, s, 0, value); + } + } + tensor +} + struct ChainSiteMetadata { site_index: DynIndex, left_bond: Option, diff --git a/crates/tensor4all-treetn/src/treetn/addition.rs b/crates/tensor4all-treetn/src/treetn/addition.rs index 6a0f43fd..86744774 100644 --- a/crates/tensor4all-treetn/src/treetn/addition.rs +++ b/crates/tensor4all-treetn/src/treetn/addition.rs @@ -137,17 +137,23 @@ where self.replaceinds(&old_indices, &new_indices) } - /// Add two TreeTNs after aligning the second operand's site index IDs to the first. + /// Add two TreeTNs after reindexing the second operand's site space to match `self`. /// - /// This is useful when two states share the same topology and site dimensions - /// but were constructed with different site index IDs. + /// This method first calls [`reindex_site_space_like`](Self::reindex_site_space_like) + /// on `other`, then performs strict direct-sum addition. Use it only when + /// the node-by-node site-space matching is the intended operation. Use + /// [`add`](Self::add) when site indices must already match exactly. /// /// # Arguments - /// * `other` - The other TreeTN to align and add + /// * `other` - The other TreeTN to reindex and add /// /// # Returns /// The direct-sum addition result with site IDs matching `self`. 
/// + /// # Errors + /// Returns an error if the two TreeTNs cannot be reindexed to the same site + /// space, or if the strict addition fails after reindexing. + /// /// # Examples /// ``` /// use tensor4all_core::{DynIndex, TensorDynLen}; @@ -162,11 +168,11 @@ where /// let state_a = make_chain(DynIndex::new_dyn(2), DynIndex::new_dyn(2)); /// let state_b = make_chain(DynIndex::new_dyn(2), DynIndex::new_dyn(2)); /// - /// let sum = state_a.add_aligned(&state_b).unwrap(); + /// let sum = state_a.add_reindexed_like_self(&state_b).unwrap(); /// assert_eq!(sum.node_count(), 2); /// assert!(sum.share_equivalent_site_index_network(&state_a)); /// ``` - pub fn add_aligned(&self, other: &Self) -> Result + pub fn add_reindexed_like_self(&self, other: &Self) -> Result where V: Ord, T::Index: Clone, @@ -257,7 +263,7 @@ where // Create merged bond index using direct_sum on dummy tensors // For now, we just store dimensions; the actual merged index will be - // created during the direct sum operation using TensorLike::direct_sum + // created during the direct sum operation using TensorContractionLike::direct_sum // // Note: We need a way to create a new index with dim_a + dim_b. // This requires the TensorLike implementation to handle index creation. @@ -290,7 +296,8 @@ where /// This creates a new TreeTN where each tensor is the direct sum of the /// corresponding tensors from self and other, with bond dimensions merged. /// The two networks must share the same topology **and** the same site - /// index IDs. Use [`add_aligned`](Self::add_aligned) if site index IDs differ. + /// index IDs. Use [`add_reindexed_like_self`](Self::add_reindexed_like_self) + /// if site index IDs differ and implicit node-by-node reindexing is intended. 
/// /// # Arguments /// * `other` - The other TreeTN to add @@ -320,6 +327,7 @@ where pub fn add(&self, other: &Self) -> Result where V: Ord, + T::Index: Eq + Hash, ::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync, { @@ -329,6 +337,7 @@ where "Cannot add TreeTNs with different topologies" )); } + self.ensure_same_site_spaces(other)?; // Track merged indices for each edge. // Key: (smaller_node_name, larger_node_name) for canonical ordering @@ -430,6 +439,80 @@ where // Build result TreeTN TreeTN::from_tensors(result_tensors, result_node_names) } + + /// Compute a strict linear combination `a * self + b * other`. + /// + /// This scales two TreeTNs and adds them with [`add`](Self::add), so the two + /// operands must already have the same topology and the same site index IDs. + /// No site reindexing or truncation is performed. + /// + /// # Arguments + /// * `a` - Scalar coefficient multiplying `self`. + /// * `other` - Second TreeTN in the same site-index space. + /// * `b` - Scalar coefficient multiplying `other`. + /// + /// # Returns + /// A direct-sum TreeTN representing `a * self + b * other`. + /// + /// # Errors + /// Returns an error if scaling fails, or if strict direct-sum addition fails + /// because topology or site indices do not match. 
+ /// + /// # Examples + /// ``` + /// use tensor4all_core::{AnyScalar, DynIndex, TensorDynLen}; + /// use tensor4all_treetn::TreeTN; + /// + /// let site = DynIndex::new_dyn(2); + /// let a = TensorDynLen::from_dense(vec![site.clone()], vec![1.0_f64, 2.0]).unwrap(); + /// let b = TensorDynLen::from_dense(vec![site.clone()], vec![3.0_f64, 4.0]).unwrap(); + /// let tn_a = TreeTN::<_, usize>::from_tensors(vec![a], vec![0]).unwrap(); + /// let tn_b = TreeTN::<_, usize>::from_tensors(vec![b], vec![0]).unwrap(); + /// + /// let result = tn_a + /// .axpby(AnyScalar::new_real(2.0), &tn_b, AnyScalar::new_real(-1.0)) + /// .unwrap(); + /// let dense = result.contract_to_tensor().unwrap(); + /// let expected = TensorDynLen::from_dense(vec![site], vec![-1.0, 0.0]).unwrap(); + /// assert!(dense.isapprox(&expected, 1.0e-12, 0.0)); + /// ``` + pub fn axpby(&self, a: A, other: &Self, b: B) -> Result + where + V: Ord, + T::Index: Eq + Hash, + ::Id: + Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync, + A: Into, + B: Into, + { + let mut lhs = self.clone(); + lhs.scale(a.into())?; + let mut rhs = other.clone(); + rhs.scale(b.into())?; + lhs.add(&rhs) + } + + fn ensure_same_site_spaces(&self, other: &Self) -> Result<()> + where + V: Ord, + T::Index: Eq + Hash, + { + for node_name in self.node_names() { + let self_site_space = self + .site_space(&node_name) + .ok_or_else(|| anyhow::anyhow!("site space not found for node {:?}", node_name))?; + let other_site_space = other + .site_space(&node_name) + .ok_or_else(|| anyhow::anyhow!("site space not found for node {:?}", node_name))?; + if self_site_space != other_site_space { + bail!( + "Cannot add TreeTNs with different site indices at node {:?}", + node_name + ); + } + } + Ok(()) + } } #[cfg(test)] diff --git a/crates/tensor4all-treetn/src/treetn/addition/tests/mod.rs b/crates/tensor4all-treetn/src/treetn/addition/tests/mod.rs index 4b380ca3..55700c7f 100644 --- 
a/crates/tensor4all-treetn/src/treetn/addition/tests/mod.rs +++ b/crates/tensor4all-treetn/src/treetn/addition/tests/mod.rs @@ -1,5 +1,5 @@ use std::collections::HashSet; -use tensor4all_core::{DynIndex, IndexLike, TensorDynLen, TensorIndex, TensorLike}; +use tensor4all_core::{DynIndex, IndexLike, TensorDynLen, TensorIndex}; use crate::treetn::TreeTN; @@ -330,13 +330,13 @@ fn test_reindex_site_space_like_rejects_incompatible_inputs() { } #[test] -fn test_add_aligned_accepts_equivalent_site_space_with_different_ids() { +fn test_add_reindexed_like_self_accepts_equivalent_site_space_with_different_ids() { let (tn_a, tn_b) = make_two_matching_treetns_different_site_ids(); let tn_b_aligned = tn_b.reindex_site_space_like(&tn_a).unwrap(); assert!(!tn_a.share_equivalent_site_index_network(&tn_b)); - let sum = tn_a.add_aligned(&tn_b).unwrap(); + let sum = tn_a.add_reindexed_like_self(&tn_b).unwrap(); assert!(sum.share_equivalent_site_index_network(&tn_a)); let dense_sum = sum.contract_to_tensor().unwrap(); @@ -351,11 +351,27 @@ fn test_add_aligned_accepts_equivalent_site_space_with_different_ids() { .unwrap(); assert!( dense_sum.isapprox(&dense_expected, 1e-10, 0.0), - "add_aligned failed: maxabs diff = {}", + "add_reindexed_like_self failed: maxabs diff = {}", dense_sum.distance(&dense_expected).unwrap() ); } +#[test] +fn test_axpby_is_strict_about_site_index_ids() { + let (tn_a, tn_b) = make_two_matching_treetns_different_site_ids(); + let err = tn_a + .axpby( + tensor4all_core::AnyScalar::new_real(1.0), + &tn_b, + tensor4all_core::AnyScalar::new_real(1.0), + ) + .unwrap_err(); + assert!( + err.to_string().contains("site") || err.to_string().contains("index"), + "unexpected error: {err}" + ); +} + #[test] fn test_add_preserves_same_id_prime_pair_site_indices() { let (inputs, outputs, links) = make_same_id_prime_pair_mpo_like_indices(); diff --git a/crates/tensor4all-treetn/src/treetn/cached_evaluator.rs b/crates/tensor4all-treetn/src/treetn/cached_evaluator.rs index 
1a3d559b..0871e8ac 100644 --- a/crates/tensor4all-treetn/src/treetn/cached_evaluator.rs +++ b/crates/tensor4all-treetn/src/treetn/cached_evaluator.rs @@ -7,8 +7,8 @@ use std::hash::Hash; use anyhow::{bail, Context, Result}; use num_complex::Complex64; use tensor4all_core::{ - contract_multi_with_options, AllowedPairs, AnyScalar, ColMajorArrayRef, ContractionOptions, - DynIndex, IndexLike, TensorDynLen, TensorIndex, TensorLike, + contract_with_options, AnyScalar, ColMajorArrayRef, ContractionOptions, DynIndex, IndexLike, + TensorContractionLike, TensorDynLen, TensorIndex, TensorLike, }; use super::TreeTN; @@ -960,9 +960,9 @@ where } let retain = [assignment_index.clone()]; - let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain); + let options = ContractionOptions::new().with_retain_indices(&retain); let operand_refs = operands.iter().collect::>(); - let tensor = contract_multi_with_options(&operand_refs, options).context( + let tensor = contract_with_options(&operand_refs, options).context( "TreeTNCachedEvaluator::evaluate_batch: failed to contract batched directed message", )?; let tensor = ensure_assignment_axis_last(tensor, &assignment_index)?; @@ -1040,9 +1040,9 @@ where operands.remove(0) } else { let retain = [point_index.clone()]; - let options = ContractionOptions::new(AllowedPairs::All).with_retain_indices(&retain); + let options = ContractionOptions::new().with_retain_indices(&retain); let operand_refs = operands.iter().collect::>(); - contract_multi_with_options(&operand_refs, options) + contract_with_options(&operand_refs, options) .context("TreeTNCachedEvaluator::evaluate_batch: failed to contract center batch")? 
}; let result_tensor = ensure_assignment_axis_last(result_tensor, &point_index)?; diff --git a/crates/tensor4all-treetn/src/treetn/contraction.rs b/crates/tensor4all-treetn/src/treetn/contraction.rs index 115d4932..f460baae 100644 --- a/crates/tensor4all-treetn/src/treetn/contraction.rs +++ b/crates/tensor4all-treetn/src/treetn/contraction.rs @@ -15,12 +15,18 @@ use anyhow::{Context, Result}; use crate::algorithm::CanonicalForm; use tensor4all_core::{ - AllowedPairs, Canonical, FactorizeAlg, FactorizeOptions, IndexLike, SvdTruncationPolicy, - TensorIndex, TensorLike, + Canonical, FactorizeAlg, FactorizeOptions, IndexLike, SvdTruncationPolicy, TensorIndex, + TensorLike, }; use super::TreeTN; +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum ZipupTopologyMode { + PruneScalarSubtrees, + PreserveInputTopology, +} + impl TreeTN where T: TensorLike, @@ -186,7 +192,7 @@ where // Contract and store result at `to` // (bond indices are auto-detected via is_contractable) - let contracted = T::contract(&[&to_tensor, &from_tensor], AllowedPairs::All) + let contracted = T::contract(&[&to_tensor, &from_tensor]) .context("Failed to contract along edge")?; tensors.insert(to, contracted); } @@ -270,9 +276,23 @@ where self.contract_zipup_with(other, center, CanonicalForm::Unitary, svd_policy, max_rank) } - /// Contract two TreeTNs with the same topology using the zip-up algorithm with a specified form. + /// Contract two TreeTNs with the same topology using zip-up and a specified canonical form. + /// + /// # Algorithm + /// 1. Process leaves: contract `A[leaf] * B[leaf]`, factorize, store R at parent + /// 2. Process internal nodes: contract incoming factors with `A[node]` and `B[node]`, factorize, store R\_new at parent + /// 3. Process root: contract `[R_list..., A[root], B[root]]`, store as final result + /// 4. 
Set canonical center + /// + /// # Arguments + /// * `other` - The other TreeTN to contract with (must have same topology) + /// * `center` - The center node name towards which to contract + /// * `form` - Canonical form (Unitary/LU/CI) + /// * `svd_policy` - Optional SVD truncation policy + /// * `max_rank` - Optional maximum bond dimension /// - /// See [`contract_zipup`](Self::contract_zipup) for details. + /// # Returns + /// The contracted TreeTN result, or an error if topologies don't match or contraction fails. pub fn contract_zipup_with( &self, other: &Self, @@ -286,37 +306,47 @@ where ::Id: Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync, { - self.contract_zipup_tree_accumulated(other, center, form, svd_policy, max_rank) + self.contract_zipup_impl( + other, + center, + form, + svd_policy, + max_rank, + ZipupTopologyMode::PruneScalarSubtrees, + ) } - /// Contract two TreeTNs using zip-up algorithm with accumulated intermediate tensors. - /// - /// This is an improved version of zip-up contraction that maintains intermediate tensors - /// (environment tensors) as it processes from leaves towards the root, similar to - /// ITensors.jl's MPO zip-up algorithm. - /// - /// # Algorithm - /// 1. Process leaves: contract `A[leaf] * B[leaf]`, factorize, store R at parent - /// 2. Process internal nodes: contract `[R_accumulated..., A[node], B[node]]`, factorize, store R\_new at parent - /// 3. Process root: contract `[R_list..., A[root], B[root]]`, store as final result - /// 4. Set canonical center - /// - /// # Arguments - /// * `other` - The other TreeTN to contract with (must have same topology) - /// * `center` - The center node name towards which to contract - /// * `form` - Canonical form (Unitary/LU/CI) - /// * `svd_policy` - Optional SVD truncation policy - /// * `max_rank` - Optional maximum bond dimension - /// - /// # Returns - /// The contracted TreeTN result, or an error if topologies don't match or contraction fails. 
- pub fn contract_zipup_tree_accumulated( + pub(crate) fn contract_zipup_preserving_topology_with( + &self, + other: &Self, + center: &V, + form: CanonicalForm, + svd_policy: Option, + max_rank: Option, + ) -> Result + where + V: Ord, + ::Id: + Clone + std::hash::Hash + Eq + Ord + std::fmt::Debug + Send + Sync, + { + self.contract_zipup_impl( + other, + center, + form, + svd_policy, + max_rank, + ZipupTopologyMode::PreserveInputTopology, + ) + } + + fn contract_zipup_impl( &self, other: &Self, center: &V, form: CanonicalForm, svd_policy: Option, max_rank: Option, + topology_mode: ZipupTopologyMode, ) -> Result where V: Ord, @@ -326,7 +356,7 @@ where // 1. Verify topologies are compatible if !self.same_topology(other) { return Err(anyhow::anyhow!( - "contract_zipup_tree_accumulated: networks have incompatible topologies" + "contract_zipup_with: networks have incompatible topologies" )); } @@ -336,32 +366,31 @@ where // 3. Get traversal edges from leaves to center (post-order DFS) let edges = tn_a.edges_to_canonicalize_by_names(center).ok_or_else(|| { - anyhow::anyhow!( - "contract_zipup_tree_accumulated: center node {:?} not found", - center - ) + anyhow::anyhow!("contract_zipup_with: center node {:?} not found", center) })?; // 4. 
Handle single node case if edges.is_empty() && self.node_count() == 1 { - let node_idx = tn_a.graph.graph().node_indices().next().ok_or_else(|| { - anyhow::anyhow!("contract_zipup_tree_accumulated: no nodes found") - })?; - let t_a = tn_a.tensor(node_idx).ok_or_else(|| { - anyhow::anyhow!("contract_zipup_tree_accumulated: tensor not found in tn_a") - })?; + let node_idx = tn_a + .graph + .graph() + .node_indices() + .next() + .ok_or_else(|| anyhow::anyhow!("contract_zipup_with: no nodes found"))?; + let t_a = tn_a + .tensor(node_idx) + .ok_or_else(|| anyhow::anyhow!("contract_zipup_with: tensor not found in tn_a"))?; let t_b = tn_b .tensor(tn_b.graph.graph().node_indices().next().ok_or_else(|| { - anyhow::anyhow!("contract_zipup_tree_accumulated: tensor not found in tn_b") + anyhow::anyhow!("contract_zipup_with: tensor not found in tn_b") })?) - .ok_or_else(|| { - anyhow::anyhow!("contract_zipup_tree_accumulated: tensor not found in tn_b") - })?; + .ok_or_else(|| anyhow::anyhow!("contract_zipup_with: tensor not found in tn_b"))?; - let contracted = T::contract(&[t_a, t_b], AllowedPairs::All)?; - let node_name = tn_a.graph.node_name(node_idx).ok_or_else(|| { - anyhow::anyhow!("contract_zipup_tree_accumulated: node name not found") - })?; + let contracted = t_a.contract_pair(t_b)?; + let node_name = tn_a + .graph + .node_name(node_idx) + .ok_or_else(|| anyhow::anyhow!("contract_zipup_with: node name not found"))?; let mut result = TreeTN::new(); result.add_tensor(node_name.clone(), contracted)?; @@ -445,7 +474,8 @@ where let c_temp = if is_leaf { // Leaf node: contract A[source] * B[source] - T::contract(&[&tensor_a, &tensor_b], AllowedPairs::All) + tensor_a + .contract_pair(&tensor_b) .context("Failed to contract leaf tensors")? 
} else { // Internal node: contract [R_accumulated..., A[source], B[source]] @@ -456,8 +486,7 @@ where tensor_list.push(tensor_a); tensor_list.push(tensor_b); let tensor_refs: Vec<&T> = tensor_list.iter().collect(); - T::contract(&tensor_refs, AllowedPairs::All) - .context("Failed to contract internal node tensors")? + T::contract(&tensor_refs).context("Failed to contract internal node tensors")? }; // Factorize child tensor and pass the right factor to destination (even if destination is root) @@ -476,11 +505,34 @@ where .collect(); if left_inds.is_empty() { - // If no left indices remain, pass the tensor directly to destination - intermediate_tensors - .entry(destination_name.clone()) - .or_default() - .push(c_temp); + match topology_mode { + ZipupTopologyMode::PruneScalarSubtrees => { + // If no left indices remain, pass the tensor directly to destination. + intermediate_tensors + .entry(destination_name.clone()) + .or_default() + .push(c_temp); + } + ZipupTopologyMode::PreserveInputTopology => { + // Fit sweeps require C to retain A/B's node set. When a + // subtree has no surviving site indices, keep it connected + // by a dimension-1 dummy link instead of pruning the node. 
+ let (dummy_left, dummy_right) = T::Index::create_dummy_link_pair(); + let left_tensor = T::ones(std::slice::from_ref(&dummy_left)) + .context("Failed to create topology-preserving dummy left tensor")?; + let dummy_right_tensor = T::ones(std::slice::from_ref(&dummy_right)) + .context("Failed to create topology-preserving dummy right tensor")?; + let right_tensor = c_temp + .outer_product(&dummy_right_tensor) + .context("Failed to attach topology-preserving dummy bond")?; + + result_tensors.insert(source_name.clone(), left_tensor); + intermediate_tensors + .entry(destination_name.clone()) + .or_default() + .push(right_tensor); + } + } continue; } @@ -524,8 +576,8 @@ where tensor_list.push(root_tensor_a); tensor_list.push(root_tensor_b); let tensor_refs: Vec<&T> = tensor_list.iter().collect(); - let root_result = T::contract(&tensor_refs, AllowedPairs::All) - .context("Failed to contract root node tensors")?; + let root_result = + T::contract(&tensor_refs).context("Failed to contract root node tensors")?; // Store root result (no factorization needed) result_tensors.insert(root_name.clone(), root_result); @@ -547,7 +599,7 @@ where .tensor(root_b_idx) .ok_or_else(|| anyhow::anyhow!("Root tensor not found in tn_b"))?; - let root_result = T::contract(&[root_tensor_a, root_tensor_b], AllowedPairs::All) + let root_result = T::contract(&[root_tensor_a, root_tensor_b]) .context("Failed to contract root node tensors")?; result_tensors.insert(root_name.clone(), root_result); @@ -570,13 +622,13 @@ where ) { let tensor_a = result.tensor(node_a_idx).ok_or_else(|| { anyhow::anyhow!( - "contract_zipup_tree_accumulated: result tensor not found for node {:?}", + "contract_zipup_with: result tensor not found for node {:?}", source_name ) })?; let tensor_b = result.tensor(node_b_idx).ok_or_else(|| { anyhow::anyhow!( - "contract_zipup_tree_accumulated: result tensor not found for node {:?}", + "contract_zipup_with: result tensor not found for node {:?}", destination_name ) })?; @@ 
-645,9 +697,11 @@ where .contract_to_tensor() .map_err(|e| anyhow::anyhow!("contract_naive: failed to contract tn2: {}", e))?; - // 4. Contract along common indices - // T::contract auto-contracts all is_contractable pairs - T::contract(&[&tensor1, &tensor2], AllowedPairs::All) + // 4. Exact pairwise product over common site indices. If the caller + // explicitly asks for a reference product with no common site indices + // (for example partial_contract with an empty spec), this is the + // corresponding outer product reference. + tensor1.contract_pair(&tensor2) } /// Validate that `canonical_region` and edge `ortho_towards` are consistent. @@ -841,8 +895,10 @@ pub struct ContractionOptions { /// Maximum dense elements allowed for explicit mismatched-topology /// reference fallback in `partial_contract`. /// - /// `None` rejects the fallback. Set this only for small reference/debug - /// cases where full dense materialization is expected and bounded. + /// `None` rejects the fallback; compatible tree-union mismatches are first + /// handled by structural dimension-1 topology alignment without dense + /// materialization. Set this only for small reference/debug cases where + /// full dense materialization is acceptable if structural alignment fails. pub mismatched_topology_dense_limit: Option, } @@ -941,7 +997,7 @@ impl ContractionOptions { } /// Allow `partial_contract` to use its mismatched-topology dense reference - /// fallback up to `max_elements` elements. + /// fallback up to `max_elements` elements if structural alignment fails. /// /// # Arguments /// * `max_elements` - Maximum number of elements allowed in each dense @@ -949,7 +1005,9 @@ impl ContractionOptions { /// remain small and test-sized. /// /// # Returns - /// Updated options with the dense/reference fallback limit enabled. + /// Updated options with the dense/reference fallback limit enabled. 
This + /// does not force dense materialization; compatible tree-union topology + /// mismatches are handled by structural dimension-1 alignment first. /// /// # Examples /// diff --git a/crates/tensor4all-treetn/src/treetn/contraction/tests/mod.rs b/crates/tensor4all-treetn/src/treetn/contraction/tests/mod.rs index 73029906..7dce2cca 100644 --- a/crates/tensor4all-treetn/src/treetn/contraction/tests/mod.rs +++ b/crates/tensor4all-treetn/src/treetn/contraction/tests/mod.rs @@ -1,5 +1,7 @@ use super::*; -use tensor4all_core::{DynIndex, IndexLike, SvdTruncationPolicy, TensorDynLen, TensorIndex}; +use tensor4all_core::{ + DynIndex, IndexLike, SvdTruncationPolicy, TensorContractionLike, TensorDynLen, TensorIndex, +}; /// Helper to create a simple 2-node TreeTN: A -- bond -- B fn make_two_node_treetn() -> (TreeTN, DynIndex, DynIndex, DynIndex) { @@ -158,7 +160,7 @@ fn test_contract_to_tensor_two_nodes() { vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0], ) .unwrap(); - let expected = t0.contract(&t1).unwrap().to_vec::().unwrap(); + let expected = t0.contract_pair(&t1).unwrap().to_vec::().unwrap(); let result_data = result.to_vec::().unwrap(); assert_eq!(result_data.len(), expected.len()); diff --git a/crates/tensor4all-treetn/src/treetn/decompose.rs b/crates/tensor4all-treetn/src/treetn/decompose.rs index fc46c3cf..ad24778a 100644 --- a/crates/tensor4all-treetn/src/treetn/decompose.rs +++ b/crates/tensor4all-treetn/src/treetn/decompose.rs @@ -338,7 +338,7 @@ where )); } - // Perform factorization using TensorLike::factorize + // Perform factorization using TensorFactorizationLike::factorize // left will have the node's physical indices + bond index // right will have bond index + remaining indices let factorize_result = current_tensor diff --git a/crates/tensor4all-treetn/src/treetn/evaluator.rs b/crates/tensor4all-treetn/src/treetn/evaluator.rs index 421e5f7f..1c9ff05a 100644 --- a/crates/tensor4all-treetn/src/treetn/evaluator.rs +++ 
b/crates/tensor4all-treetn/src/treetn/evaluator.rs @@ -307,11 +307,9 @@ where let onehot = T::onehot(&index_vals) .context("TreeTNEvaluator::evaluate_batch: failed to create one-hot tensor")?; - let result = T::contract( - &[&entry.tensor, &onehot], - tensor4all_core::AllowedPairs::All, - ) - .context("TreeTNEvaluator::evaluate_batch: failed to contract tensor with one-hot")?; + let result = T::contract(&[&entry.tensor, &onehot]).context( + "TreeTNEvaluator::evaluate_batch: failed to contract tensor with one-hot", + )?; contracted_tensors.push(result); contracted_names.push(entry.name.clone()); diff --git a/crates/tensor4all-treetn/src/treetn/fit.rs b/crates/tensor4all-treetn/src/treetn/fit.rs index 035589d0..09d9c3df 100644 --- a/crates/tensor4all-treetn/src/treetn/fit.rs +++ b/crates/tensor4all-treetn/src/treetn/fit.rs @@ -34,8 +34,8 @@ use anyhow::Result; use tensor4all_core::{ print_and_reset_contract_profile, print_and_reset_native_einsum_profile, - reset_contract_profile, reset_native_einsum_profile, AllowedPairs, Canonical, FactorizeAlg, - FactorizeOptions, IndexLike, SvdTruncationPolicy, TensorLike, + reset_contract_profile, reset_native_einsum_profile, Canonical, FactorizeAlg, FactorizeOptions, + FactorizeResult, IndexLike, SvdTruncationPolicy, TensorLike, }; use super::localupdate::{LocalUpdateStep, LocalUpdateSweepPlan, LocalUpdater}; @@ -422,9 +422,9 @@ where let tensor_b = tensor_at_node(tn_b, node, "tn_b")?; let tensor_c = tensor_at_node(tn_c, node, "tn_c")?; - // Contract A × B × conj(C) with a single multi-tensor call. + // A, B, and C must form one connected local environment. 
let c_conj = tensor_c.conj(); - let env = T::contract(&[tensor_a, tensor_b, &c_conj], AllowedPairs::All) + let env = T::contract(&[tensor_a, tensor_b, &c_conj]) .map_err(|e| anyhow::anyhow!("contract failed: {}", e))?; if let Some(started) = started { @@ -465,12 +465,13 @@ where return compute_leaf_environment(node, towards, tn_a, tn_b, tn_c); } - // Non-leaf: contract A × B × conj(C) × child_envs in one multi-tensor call. + // Non-leaf: all local tensors and child environments must form one + // connected contraction graph. let c_conj = tensor_c.conj(); let mut tensor_refs: Vec<&T> = vec![tensor_a, tensor_b, &c_conj]; tensor_refs.extend(child_envs.iter()); - let result = T::contract(&tensor_refs, AllowedPairs::All) - .map_err(|e| anyhow::anyhow!("contract failed: {}", e))?; + let result = + T::contract(&tensor_refs).map_err(|e| anyhow::anyhow!("contract failed: {}", e))?; if let Some(started) = started { with_fit_profile(|profile| { @@ -667,13 +668,13 @@ where env_tensors.push(env); } - // Compute optimal 2-site tensor: env × A[u] × B[u] × A[v] × B[v] × env - // Collect all tensors and let contract() find the optimal contraction order + // Compute optimal 2-site tensor: env × A[u] × B[u] × A[v] × B[v] × env. + // Collect all tensors and let contract() find the optimal contraction order. 
let contract_started = fit_profile_enabled().then(Instant::now); let mut tensor_refs: Vec<&T> = vec![a_u, b_u, a_v, b_v]; tensor_refs.extend(env_tensors.iter()); - let ab_uv = T::contract(&tensor_refs, AllowedPairs::All) - .map_err(|e| anyhow::anyhow!("contract failed: {}", e))?; + let ab_uv = + T::contract(&tensor_refs).map_err(|e| anyhow::anyhow!("contract failed: {}", e))?; if let Some(contract_started) = contract_started { with_fit_profile(|profile| { profile.two_site_contract_time += contract_started.elapsed(); @@ -691,8 +692,8 @@ where node_u ) })?; - let mut left_inds: Vec<_> = ab_uv - .external_indices() + let ab_uv_indices = ab_uv.external_indices(); + let mut left_inds: Vec<_> = ab_uv_indices .iter() .filter(|idx| { // Keep site indices of u and link indices to u's other neighbors @@ -746,11 +747,39 @@ where options = options.with_max_rank(cap); } - // Factorize using TensorLike::factorize + // Factorize using TensorFactorizationLike::factorize let factorize_started = fit_profile_enabled().then(Instant::now); - let factorize_result = ab_uv - .factorize(&left_inds, &options) - .map_err(|e| anyhow::anyhow!("Factorization failed: {}", e))?; + let factorize_result = if left_inds.is_empty() || left_inds.len() == ab_uv_indices.len() { + let (dummy_left, dummy_right) = T::Index::create_dummy_link_pair(); + let dummy_left_tensor = T::ones(std::slice::from_ref(&dummy_left)) + .map_err(|e| anyhow::anyhow!("Failed to create dummy left tensor: {e}"))?; + let dummy_right_tensor = T::ones(std::slice::from_ref(&dummy_right)) + .map_err(|e| anyhow::anyhow!("Failed to create dummy right tensor: {e}"))?; + + let (left, right) = if left_inds.is_empty() { + let right = ab_uv + .outer_product(&dummy_right_tensor) + .map_err(|e| anyhow::anyhow!("Failed to attach dummy right bond: {e}"))?; + (dummy_left_tensor, right) + } else { + let left = ab_uv + .outer_product(&dummy_left_tensor) + .map_err(|e| anyhow::anyhow!("Failed to attach dummy left bond: {e}"))?; + (left, 
dummy_right_tensor) + }; + + FactorizeResult { + left, + right, + bond_index: dummy_left, + singular_values: None, + rank: 1, + } + } else { + ab_uv + .factorize(&left_inds, &options) + .map_err(|e| anyhow::anyhow!("Factorization failed: {}", e))? + }; if let Some(factorize_started) = factorize_started { with_fit_profile(|profile| { profile.factorize_time += factorize_started.elapsed(); @@ -958,9 +987,10 @@ where )); } - // Initialize C using the SVD-based zipup contraction. + // Initialize C using the SVD-based zipup contraction while preserving + // the input topology required by variational sweeps. let zipup_started = profile_enabled.then(Instant::now); - let mut tn_c = tn_a.contract_zipup_tree_accumulated( + let mut tn_c = tn_a.contract_zipup_preserving_topology_with( tn_b, center, CanonicalForm::Unitary, diff --git a/crates/tensor4all-treetn/src/treetn/fit/tests/mod.rs b/crates/tensor4all-treetn/src/treetn/fit/tests/mod.rs index e12d2419..b1c9d28f 100644 --- a/crates/tensor4all-treetn/src/treetn/fit/tests/mod.rs +++ b/crates/tensor4all-treetn/src/treetn/fit/tests/mod.rs @@ -5,8 +5,12 @@ use tensor4all_core::{DynIndex, TensorDynLen}; /// Create a simple 2-node TreeTN: A -- bond -- B fn make_two_node_treetn() -> TreeTN { let s0 = DynIndex::new_dyn(2); - let bond = DynIndex::new_dyn(3); let s1 = DynIndex::new_dyn(2); + make_two_node_treetn_with_sites(&s0, &s1) +} + +fn make_two_node_treetn_with_sites(s0: &DynIndex, s1: &DynIndex) -> TreeTN { + let bond = DynIndex::new_dyn(3); let t0 = TensorDynLen::from_dense( vec![s0.clone(), bond.clone()], @@ -26,6 +30,77 @@ fn make_two_node_treetn() -> TreeTN { .unwrap() } +fn make_contractible_two_node_pair() -> (TreeTN, TreeTN) +{ + let s0 = DynIndex::new_dyn(2); + let s1 = DynIndex::new_dyn(2); + ( + make_two_node_treetn_with_sites(&s0, &s1), + make_two_node_treetn_with_sites(&s0, &s1), + ) +} + +fn make_contractible_two_node_pair_with_surviving_sites( +) -> (TreeTN, TreeTN) { + let s0 = DynIndex::new_dyn(2); + let s1 = 
DynIndex::new_dyn(2); + let a0 = DynIndex::new_dyn(2); + let a1 = DynIndex::new_dyn(2); + let b0 = DynIndex::new_dyn(2); + let b1 = DynIndex::new_dyn(2); + let bond_a = DynIndex::new_dyn(2); + let bond_b = DynIndex::new_dyn(2); + + let tn_a = TreeTN::::from_tensors( + vec![ + TensorDynLen::from_dense( + vec![s0.clone(), a0, bond_a.clone()], + (1..=8).map(|value| value as f64 / 8.0).collect(), + ) + .unwrap(), + TensorDynLen::from_dense( + vec![bond_a, s1.clone(), a1], + (1..=8).map(|value| value as f64 / 10.0).collect(), + ) + .unwrap(), + ], + vec!["A".to_string(), "B".to_string()], + ) + .unwrap(); + let tn_b = TreeTN::::from_tensors( + vec![ + TensorDynLen::from_dense( + vec![s0, b0, bond_b.clone()], + (1..=8).map(|value| (value as f64 - 2.0) / 9.0).collect(), + ) + .unwrap(), + TensorDynLen::from_dense( + vec![bond_b, s1, b1], + (1..=8).map(|value| (value as f64 + 1.0) / 11.0).collect(), + ) + .unwrap(), + ], + vec!["A".to_string(), "B".to_string()], + ) + .unwrap(); + (tn_a, tn_b) +} + +fn make_fit_initial_c( + tn_a: &TreeTN, + tn_b: &TreeTN, + center: &str, +) -> TreeTN { + tn_a.contract_zipup_preserving_topology_with( + tn_b, + &center.to_string(), + crate::CanonicalForm::Unitary, + None, + None, + ) + .unwrap() +} + fn make_single_node_treetn() -> TreeTN { let s0 = DynIndex::new_dyn(2); let s1 = DynIndex::new_dyn(3); @@ -160,9 +235,8 @@ fn test_fit_environment_verify_structural_consistency_valid() { #[test] fn test_fit_environment_get_or_compute_caches_leaf_environment() { - let tn_a = make_two_node_treetn(); - let tn_b = make_two_node_treetn(); - let tn_c = make_two_node_treetn(); + let (tn_a, tn_b) = make_contractible_two_node_pair_with_surviving_sites(); + let tn_c = make_fit_initial_c(&tn_a, &tn_b, "A"); let mut env = FitEnvironment::::new(); let from = "A".to_string(); @@ -311,8 +385,7 @@ fn test_contract_fit_rejects_topology_mismatch() { #[test] fn test_contract_fit_matches_naive_contraction_on_two_node_tree() { - let tn_a = make_two_node_treetn(); - let
tn_b = make_two_node_treetn(); + let (tn_a, tn_b) = make_contractible_two_node_pair(); let fitted = contract_fit( &tn_a, @@ -324,7 +397,7 @@ fn test_contract_fit_matches_naive_contraction_on_two_node_tree() { let fitted_dense = fitted.to_dense().unwrap(); let expected_dense = tn_a.contract_naive(&tn_b).unwrap(); - assert!(fitted_dense.distance(&expected_dense).unwrap() < 1e-10); + assert!(fitted_dense.sub(&expected_dense).unwrap().maxabs() < 1e-10); } #[test] @@ -334,8 +407,7 @@ fn test_contract_fit_positive_sweeps_do_not_skip_without_truncation_options() { *state.borrow_mut() = None; }); - let tn_a = make_two_node_treetn(); - let tn_b = make_two_node_treetn(); + let (tn_a, tn_b) = make_contractible_two_node_pair(); let fitted = contract_fit( &tn_a, @@ -357,3 +429,77 @@ fn test_contract_fit_positive_sweeps_do_not_skip_without_truncation_options() { let expected_dense = tn_a.contract_naive(&tn_b).unwrap(); assert!(fitted_dense.distance(&expected_dense).unwrap() < 1e-10); } + +#[test] +fn test_contract_fit_rejects_leaf_site_space_that_contracts_away() { + let left = DynIndex::new_dyn(2); + let right = DynIndex::new_dyn(2); + let shared_left = DynIndex::new_dyn(2); + let shared_mid = DynIndex::new_dyn(2); + let shared_leaf = DynIndex::new_dyn(2); + + let a_ab = DynIndex::new_dyn(2); + let a_bc = DynIndex::new_dyn(2); + let b_ab = DynIndex::new_dyn(2); + let b_bc = DynIndex::new_dyn(2); + + let tn_a = TreeTN::::from_tensors( + vec![ + TensorDynLen::from_dense( + vec![left.clone(), shared_left.clone(), a_ab.clone()], + (1..=8).map(|value| value as f64 / 8.0).collect(), + ) + .unwrap(), + TensorDynLen::from_dense( + vec![a_ab.clone(), shared_mid.clone(), a_bc.clone()], + (1..=8).map(|value| value as f64 / 10.0).collect(), + ) + .unwrap(), + TensorDynLen::from_dense( + vec![a_bc.clone(), shared_leaf.clone()], + vec![0.5, 1.5, -0.5, 2.0], + ) + .unwrap(), + ], + vec!["A".to_string(), "B".to_string(), "C".to_string()], + ) + .unwrap(); + + let tn_b = TreeTN::::from_tensors( 
+ vec![ + TensorDynLen::from_dense( + vec![b_ab.clone(), shared_left.clone()], + vec![1.0, -0.5, 0.25, 0.75], + ) + .unwrap(), + TensorDynLen::from_dense( + vec![ + b_ab.clone(), + shared_mid.clone(), + right.clone(), + b_bc.clone(), + ], + (1..=16).map(|value| (value as f64 - 3.0) / 7.0).collect(), + ) + .unwrap(), + TensorDynLen::from_dense( + vec![b_bc.clone(), shared_leaf.clone()], + vec![2.0, -1.0, 0.25, 0.75], + ) + .unwrap(), + ], + vec!["A".to_string(), "B".to_string(), "C".to_string()], + ) + .unwrap(); + + let err = contract_fit( + &tn_a, + &tn_b, + &"A".to_string(), + FitContractionOptions::new(1), + ) + .unwrap_err() + .to_string(); + + assert!(err.contains("Disconnected tensor network")); +} diff --git a/crates/tensor4all-treetn/src/treetn/localupdate.rs b/crates/tensor4all-treetn/src/treetn/localupdate.rs index b6e1101d..4a86496e 100644 --- a/crates/tensor4all-treetn/src/treetn/localupdate.rs +++ b/crates/tensor4all-treetn/src/treetn/localupdate.rs @@ -13,7 +13,7 @@ use std::hash::Hash; use anyhow::{Context, Result}; -use tensor4all_core::{AllowedPairs, IndexLike, TensorLike}; +use tensor4all_core::{IndexLike, TensorLike}; use super::TreeTN; use crate::node_name_network::NodeNameNetwork; @@ -530,8 +530,7 @@ where let tensor_b = subtree .tensor(idx_b) .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?} in subtree", node_b))?; - let tensor_ab = T::contract(&[tensor_a, tensor_b], AllowedPairs::All) - .context("Failed to contract A and B")?; + let tensor_ab = T::contract(&[tensor_a, tensor_b]).context("Failed to contract A and B")?; // Determine left indices (indices that will remain on A after factorization) // These are: all indices of A except the bond to B diff --git a/crates/tensor4all-treetn/src/treetn/mod.rs b/crates/tensor4all-treetn/src/treetn/mod.rs index e3b26b13..5c1293f9 100644 --- a/crates/tensor4all-treetn/src/treetn/mod.rs +++ b/crates/tensor4all-treetn/src/treetn/mod.rs @@ -32,9 +32,7 @@ use std::hash::Hash; use 
anyhow::{Context, Result}; use crate::algorithm::CanonicalForm; -use tensor4all_core::{ - AllowedPairs, Canonical, FactorizeAlg, FactorizeOptions, IndexLike, TensorLike, -}; +use tensor4all_core::{Canonical, FactorizeAlg, FactorizeOptions, IndexLike, TensorLike}; use crate::named_graph::NamedGraph; use crate::site_index_network::SiteIndexNetwork; @@ -58,8 +56,8 @@ pub use localupdate::{ // Re-export partial contraction types pub use partial_contraction::{ - hadamard, partial_contract, sum_over_indices, weighted_sum_over_index_pairs, - PartialContractionSpec, + hadamard, partial_contract, partial_contract_to_site_network, sum_over_indices, + weighted_sum_over_index_pairs, PartialContractionSpec, }; // Re-export swap types @@ -751,13 +749,12 @@ where .ok_or_else(|| anyhow::anyhow!("Tensor not found for dst node {:?}", dst)) .with_context(|| format!("{}: dst tensor not found", context_name))?; - let updated_dst_tensor = T::contract(&[tensor_dst, &right_tensor], AllowedPairs::All) - .with_context(|| { - format!( - "{}: failed to absorb right factor into dst tensor", - context_name - ) - })?; + let updated_dst_tensor = T::contract(&[tensor_dst, &right_tensor]).with_context(|| { + format!( + "{}: failed to absorb right factor into dst tensor", + context_name + ) + })?; // Update bond index FIRST, so replace_tensor validation matches let new_bond_index = factorize_result.bond_index; @@ -1576,8 +1573,7 @@ where )); } - let tensor_ab = T::contract(&[&tensor_a, &tensor_b], AllowedPairs::All) - .context("swap_on_edge: contract")?; + let tensor_ab = T::contract(&[&tensor_a, &tensor_b]).context("swap_on_edge: contract")?; let ab_indices = tensor_ab.external_indices(); let left_inds: Vec = ab_indices diff --git a/crates/tensor4all-treetn/src/treetn/ops.rs b/crates/tensor4all-treetn/src/treetn/ops.rs index 5e038324..757ebad4 100644 --- a/crates/tensor4all-treetn/src/treetn/ops.rs +++ b/crates/tensor4all-treetn/src/treetn/ops.rs @@ -14,7 +14,7 @@ use std::collections::HashMap; use 
std::hash::Hash; -use tensor4all_core::{AllowedPairs, AnyScalar, ColMajorArrayRef, IndexLike, TensorLike}; +use tensor4all_core::{AnyScalar, ColMajorArrayRef, IndexLike, TensorLike}; use super::{TreeTN, TreeTNEvaluator}; @@ -437,7 +437,7 @@ where node_name ) })?; - env = T::contract(&[&env, &child_env], AllowedPairs::All) + env = T::contract(&[&env, &child_env]) .context("inner: failed to absorb child environment")?; } @@ -445,7 +445,7 @@ where .tensor(node_idx_other) .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", node_name)) .context("inner: other tensor must exist")?; - env = T::contract(&[&env, other_tensor], AllowedPairs::All) + env = T::contract(&[&env, other_tensor]) .context("inner: failed to contract node bra-ket tensors")?; envs.insert(node_name, env); diff --git a/crates/tensor4all-treetn/src/treetn/partial_contraction.rs b/crates/tensor4all-treetn/src/treetn/partial_contraction.rs index ae89e496..65f690ff 100644 --- a/crates/tensor4all-treetn/src/treetn/partial_contraction.rs +++ b/crates/tensor4all-treetn/src/treetn/partial_contraction.rs @@ -12,11 +12,14 @@ use anyhow::{anyhow, bail, Context, Result}; use super::contraction::{contract, ContractionOptions}; use super::decompose::{factorize_tensor_to_treetn_with, TreeTopology}; +use super::swap::SwapOptions; use super::TreeTN; use crate::error::{format_anyhow_error, SelectedIndexContractionError}; +use crate::options::RestructureOptions; +use crate::site_index_network::SiteIndexNetwork; use tensor4all_core::{ - AllowedPairs, AnyScalar, DynIndex, FactorizeAlg, FactorizeOptions, IndexLike, TensorDynLen, - TensorIndex, TensorLike, + tensordot, AnyScalar, DynIndex, FactorizeAlg, FactorizeOptions, IndexLike, + TensorConstructionLike, TensorContractionLike, TensorDynLen, TensorIndex, TensorLike, }; type DiagonalPairApplication = ( @@ -290,19 +293,69 @@ where Ok(TreeTopology::new(nodes, union_edges)) } -fn validate_mismatched_union_topology( - a: &TreeTN, - b: &TreeTN, -) -> Result<()> +fn 
align_to_union_topology( + tn: &TreeTN, + node_names: &[V], + union_edges: &[(V, V)], +) -> Result> where V: Clone + Hash + Eq + Send + Sync + Debug + Ord, + ::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync, { - let node_names = compatible_union_node_names(a, b); - let mut union_edges = sorted_edge_set(a); - union_edges.extend(sorted_edge_set(b)); - union_edges.sort(); - union_edges.dedup(); - validate_union_topology(&node_names, &union_edges) + let existing_nodes: HashSet<_> = tn.node_names().into_iter().collect(); + let existing_edges: HashSet<_> = sorted_edge_set(tn).into_iter().collect(); + let mut structural_links = HashMap::>::new(); + + for (u, v) in union_edges { + if existing_edges.contains(&(u.clone(), v.clone())) { + continue; + } + + let link = DynIndex::new_dyn(1); + structural_links + .entry(u.clone()) + .or_default() + .push(link.clone()); + structural_links.entry(v.clone()).or_default().push(link); + } + + let mut tensors = Vec::with_capacity(node_names.len()); + let mut names = Vec::with_capacity(node_names.len()); + + for node_name in node_names { + let links = structural_links.remove(node_name).unwrap_or_default(); + let tensor = if existing_nodes.contains(node_name) { + let node = tn.node_index(node_name).ok_or_else(|| { + anyhow!( + "partial_contract: missing node {:?} while aligning topology", + node_name + ) + })?; + let mut tensor = tn.tensor(node).cloned().ok_or_else(|| { + anyhow!( + "partial_contract: missing tensor for node {:?} while aligning topology", + node_name + ) + })?; + if !links.is_empty() { + let link_tensor = ::ones(&links) + .context("partial_contract: failed to build dimension-1 structural links")?; + tensor = tensor + .outer_product(&link_tensor) + .context("partial_contract: failed to attach dimension-1 structural links")?; + } + tensor + } else { + ::ones(&links) + .context("partial_contract: failed to build missing-node scalar tensor")? 
+ }; + + tensors.push(tensor); + names.push(node_name.clone()); + } + + TreeTN::from_tensors(tensors, names) + .context("partial_contract: failed to align TreeTN to union topology") } fn dense_element_count(indices: &[DynIndex]) -> Result { @@ -381,7 +434,28 @@ where V: Clone + Hash + Eq + Send + Sync + Debug + Ord, ::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync, { - validate_mismatched_union_topology(a, b)?; + let node_names = compatible_union_node_names(a, b); + let mut union_edges = sorted_edge_set(a); + union_edges.extend(sorted_edge_set(b)); + union_edges.sort(); + union_edges.dedup(); + validate_union_topology(&node_names, &union_edges)?; + + let structural_result = (|| { + let aligned_a = align_to_union_topology(a, &node_names, &union_edges)?; + let aligned_b = align_to_union_topology(b, &node_names, &union_edges)?; + contract(&aligned_a, &aligned_b, center, options.clone()) + .context("partial_contract: failed contraction after aligning mismatched topologies") + })(); + + match structural_result { + Ok(result) => return Ok(result), + Err(err) if options.mismatched_topology_dense_limit.is_none() => { + return Err(err); + } + Err(_) => {} + } + validate_mismatched_dense_reference_fallback(a, b, &options)?; let a_dense = a @@ -392,9 +466,9 @@ where .sim_internal_inds() .contract_to_tensor() .context("partial_contract: failed to contract second mismatched-topology TreeTN")?; - let contracted_tensor = - ::contract(&[&a_dense, &b_dense], AllowedPairs::All) - .context("partial_contract: failed dense contraction for mismatched topologies")?; + let contracted_tensor = a_dense + .contract_pair(&b_dense) + .context("partial_contract: failed dense contraction for mismatched topologies")?; if contracted_tensor.external_indices().is_empty() { let mut result = TreeTN::::new(); @@ -449,7 +523,7 @@ where let unique_current_nodes: HashSet<_> = current_nodes.iter().cloned().collect(); if unique_current_nodes.len() != current_nodes.len() { bail!( - "partial_contract: 
output_order currently requires at most one surviving site index per node" + "partial_contract: output_order currently requires at most one surviving site index per node; use partial_contract_to_site_network with an explicit target network to split surviving indices across nodes" ); } @@ -571,15 +645,18 @@ where idx_b.id() ) })?; - let expanded_tensor = local_tensor - .tensordot(&copy_tensor, &[(idx_a.clone(), idx_a.clone())]) - .with_context(|| { - format!( - "partial_contract: failed to apply diagonal structure for pair {:?} <- {:?}", - idx_a.id(), - idx_b.id() - ) - })?; + let expanded_tensor = tensordot( + &local_tensor, + &copy_tensor, + &[(idx_a.clone(), idx_a.clone())], + ) + .with_context(|| { + format!( + "partial_contract: failed to apply diagonal structure for pair {:?} <- {:?}", + idx_a.id(), + idx_b.id() + ) + })?; a_modified .replace_tensor(node_idx, expanded_tensor) .with_context(|| { @@ -612,6 +689,49 @@ where Ok((a_modified, b_modified, restore_from, restore_to)) } +fn align_contract_pair_site_nodes( + a: &TreeTN, + b: &mut TreeTN, + contract_pairs: &[(DynIndex, DynIndex)], +) -> Result<()> +where + V: Clone + Hash + Eq + Send + Sync + Debug + Ord, + ::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync, +{ + let mut target_assignment = HashMap::new(); + + for (idx_a, _) in contract_pairs { + let left_node = a + .site_index_network() + .find_node_by_index(idx_a) + .cloned() + .ok_or_else(|| { + anyhow!( + "partial_contract: contract pair left index {:?} is not a site index of the first TreeTN", + idx_a.id() + ) + })?; + let right_node = b + .site_index_network() + .find_node_by_index(idx_a) + .cloned() + .ok_or_else(|| { + anyhow!( + "partial_contract: aligned contract index {:?} is not a site index of the second TreeTN", + idx_a.id() + ) + })?; + + if left_node != right_node { + target_assignment.insert(idx_a.clone(), left_node); + } + } + + b.swap_site_indices(&target_assignment, &SwapOptions::default()) + .context("partial_contract: failed to move aligned
contract indices to matching nodes")?; + Ok(()) +} + /// Partially contract two TreeTNs according to the given specification. /// /// # Arguments @@ -689,6 +809,7 @@ where } let mut result = if a_modified.same_topology(&b_modified) { + align_contract_pair_site_nodes(&a_modified, &mut b_modified, &spec.contract_pairs)?; contract(&a_modified, &b_modified, center, options) .context("partial_contract: contraction failed")? } else { @@ -708,6 +829,124 @@ where } } +/// Partially contract two TreeTNs and restructure the result to a target site network. +/// +/// Use this when the surviving site indices need a specific output topology, +/// including cases where several surviving indices initially occupy the same +/// result node. The contraction itself is performed by [`partial_contract`] +/// without `output_order`; the returned TreeTN is then transformed with +/// [`TreeTN::restructure_to`]. +/// +/// # Arguments +/// * `a` - First tensor network. Left indices in `spec` must be site indices of this network. +/// * `b` - Second tensor network. Right indices in `spec` must be site indices of this network. +/// * `spec` - Site-index contraction and diagonal-pair specification. `output_order` +/// must be `None` because `target` supplies the output layout. +/// * `center` - Canonical center node used for the intermediate contraction. +/// * `target` - Target site-index network containing exactly the surviving result +/// indices, assigned to the desired output nodes and topology. +/// * `options` - Contraction algorithm options. +/// * `restructure_options` - Split, swap, and optional final truncation settings +/// used when transforming the intermediate result to `target`. +/// +/// # Returns +/// A TreeTN with node names and site-index assignment matching `target`. +/// +/// # Errors +/// Returns an error if `spec.output_order` is set, if the partial contraction +/// fails, or if the contracted result cannot be restructured to `target`. 
+/// +/// # Examples +/// +/// ``` +/// use std::collections::HashSet; +/// +/// use tensor4all_core::{DynIndex, TensorDynLen, TensorIndex}; +/// use tensor4all_treetn::{ +/// contraction::ContractionOptions, +/// partial_contract_to_site_network, +/// PartialContractionSpec, +/// RestructureOptions, +/// SiteIndexNetwork, +/// TreeTN, +/// }; +/// +/// let i = DynIndex::new_dyn(2); +/// let k_left = DynIndex::new_dyn(2); +/// let k_right = DynIndex::new_dyn(2); +/// let j = DynIndex::new_dyn(2); +/// +/// let a = TreeTN::::from_tensors( +/// vec![TensorDynLen::from_dense( +/// vec![i.clone(), k_left.clone()], +/// vec![1.0, 2.0, 3.0, 4.0], +/// ).unwrap()], +/// vec!["center"], +/// ).unwrap(); +/// let b = TreeTN::::from_tensors( +/// vec![TensorDynLen::from_dense( +/// vec![k_right.clone(), j.clone()], +/// vec![5.0, 6.0, 7.0, 8.0], +/// ).unwrap()], +/// vec!["center"], +/// ).unwrap(); +/// +/// let spec = PartialContractionSpec { +/// contract_pairs: vec![(k_left, k_right)], +/// diagonal_pairs: vec![], +/// output_order: None, +/// }; +/// +/// let mut target = SiteIndexNetwork::new(); +/// target.add_node("0_row", HashSet::from([i.clone()])).unwrap(); +/// target.add_node("1_col", HashSet::from([j.clone()])).unwrap(); +/// target.add_edge(&"0_row", &"1_col").unwrap(); +/// +/// let result = partial_contract_to_site_network( +/// &a, +/// &b, +/// &spec, +/// &"center", +/// &target, +/// ContractionOptions::default(), +/// &RestructureOptions::default(), +/// ).unwrap(); +/// let dense = result.to_dense().unwrap(); +/// +/// assert_eq!(dense.external_indices(), vec![i.clone(), j.clone()]); +/// let expected = vec![23.0, 34.0, 31.0, 46.0]; +/// for (actual, expected) in dense.to_vec::().unwrap().into_iter().zip(expected) { +/// assert!((actual - expected).abs() < 1e-12); +/// } +/// assert_eq!(result.site_index_network().find_node_by_index(&i), Some(&"0_row")); +/// assert_eq!(result.site_index_network().find_node_by_index(&j), Some(&"1_col")); +/// ``` +pub 
fn partial_contract_to_site_network( + a: &TreeTN, + b: &TreeTN, + spec: &PartialContractionSpec, + center: &V, + target: &SiteIndexNetwork, + options: ContractionOptions, + restructure_options: &RestructureOptions, +) -> Result> +where + V: Clone + Hash + Eq + Send + Sync + Debug + Ord, + TargetV: Clone + Hash + Eq + Send + Sync + Debug + Ord, + ::Id: Clone + Hash + Eq + Ord + Debug + Send + Sync, +{ + if spec.output_order.is_some() { + bail!( + "partial_contract_to_site_network: spec.output_order must be None because the target site network defines the output layout" + ); + } + + let result = partial_contract(a, b, spec, center, options)?; + result.restructure_to(target, restructure_options).context( + "partial_contract_to_site_network: failed to restructure result to target site network", + ) +} + /// Multiply two TreeTNs elementwise along selected external index pairs. /// /// This is a convenience wrapper around [`partial_contract`] using diagonal @@ -961,12 +1200,14 @@ where if let Some(mut links) = link_indices_by_node.remove(node) { indices.append(&mut links); } - tensors.push(TensorDynLen::ones(&indices).map_err(|error| { - SelectedIndexContractionError::BuildOnesTensor { - node: format!("{node:?}"), - message: format_anyhow_error(error), - } - })?); + tensors.push( + ::ones(&indices).map_err(|error| { + SelectedIndexContractionError::BuildOnesTensor { + node: format!("{node:?}"), + message: format_anyhow_error(error), + } + })?, + ); } let weights = TreeTN::from_tensors(tensors, node_names).map_err(|error| { diff --git a/crates/tensor4all-treetn/src/treetn/partial_contraction/tests/mod.rs b/crates/tensor4all-treetn/src/treetn/partial_contraction/tests/mod.rs index a6376618..56d46da3 100644 --- a/crates/tensor4all-treetn/src/treetn/partial_contraction/tests/mod.rs +++ b/crates/tensor4all-treetn/src/treetn/partial_contraction/tests/mod.rs @@ -1,8 +1,13 @@ use super::*; use crate::treetn::contraction::{ContractionMethod, ContractionOptions}; -use 
crate::{factorize_tensor_to_treetn, SelectedIndexContractionError, TreeTopology}; +use crate::{ + factorize_tensor_to_treetn, RestructureOptions, SelectedIndexContractionError, + SiteIndexNetwork, TreeTopology, +}; use num_complex::Complex64; -use tensor4all_core::{DynIndex, TensorDynLen}; +use tensor4all_core::{ + DynIndex, FactorizeAlg, SvdTruncationPolicy, TensorContractionLike, TensorDynLen, TensorIndex, +}; struct PartialContractionInputs { tn_a: TreeTN, @@ -78,6 +83,78 @@ fn test_partial_contraction_spec_creation() { assert!(spec.diagonal_pairs.is_empty()); } +#[test] +fn helper_topology_and_dense_limit_paths_are_exercised() { + assert_eq!(canonical_edge(&2usize, &1usize), (1, 2)); + + let empty = validate_union_topology::(&[], &[]).unwrap_err(); + assert!(empty.to_string().contains("at least one node")); + + let unknown = validate_union_topology(&[0usize], &[(0usize, 1usize)]).unwrap_err(); + assert!( + unknown.to_string().contains("unknown node") + || unknown.to_string().contains("incompatible topologies") + ); + + let disconnected = validate_union_topology(&[0usize, 1usize], &[]).unwrap_err(); + assert!(disconnected.to_string().contains("incompatible topologies")); + + let i = DynIndex::new_dyn(2); + let j = DynIndex::new_dyn(3); + assert_eq!(dense_element_count(&[i.clone(), j.clone()]).unwrap(), 6); + assert_eq!( + dense_contract_output_indices(&[i.clone(), j.clone()], std::slice::from_ref(&j)), + vec![i.clone()] + ); + assert!( + ensure_dense_reference_limit("lhs", &[i.clone(), j.clone()], 5) + .unwrap_err() + .to_string() + .contains("exceeding limit") + ); + ensure_dense_reference_limit("lhs", &[i.clone(), j.clone()], 6).unwrap(); +} + +#[test] +fn helper_factorize_options_and_union_topology_paths_are_exercised() { + let policy = SvdTruncationPolicy::new(1.0e-8); + let options = ContractionOptions::default() + .with_factorize_alg(FactorizeAlg::SVD) + .with_max_rank(3) + .with_svd_policy(policy); + let factorize_options = 
factorize_options_from_contraction_options(&options).unwrap(); + assert_eq!(factorize_options.alg, FactorizeAlg::SVD); + assert_eq!(factorize_options.max_rank, Some(3)); + assert_eq!(factorize_options.svd_policy, Some(policy)); + + let qr_options = factorize_options_from_contraction_options( + &ContractionOptions::default() + .with_factorize_alg(FactorizeAlg::QR) + .with_qr_rtol(1.0e-7), + ) + .unwrap(); + assert_eq!(qr_options.alg, FactorizeAlg::QR); + assert_eq!(qr_options.qr_rtol, Some(1.0e-7)); + + let PartialContractionInputs { tn_a, tn_b, .. } = make_partial_contraction_inputs(); + let dense_a = tn_a.contract_to_tensor().unwrap(); + let dense_b = tn_b.contract_to_tensor().unwrap(); + let contracted = dense_a.contract_pair(&dense_b).unwrap(); + let topology = union_result_topology(&tn_a, &tn_b, &contracted).unwrap(); + assert_eq!(topology.nodes.len(), 2); + + validate_mismatched_dense_reference_fallback( + &tn_a, + &tn_b, + &ContractionOptions::default().with_mismatched_topology_dense_limit(4096), + ) + .unwrap(); + let err = + validate_mismatched_dense_reference_fallback(&tn_a, &tn_b, &ContractionOptions::default()) + .unwrap_err(); + assert!(err.to_string().contains("explicit dense/reference limit")); +} + #[test] fn hadamard_multiplies_paired_external_indices() { let left_index = DynIndex::new_dyn(2); @@ -561,6 +638,86 @@ fn test_partial_contract_allows_same_node_in_second_network() { assert!(result.is_ok()); } +#[test] +fn test_partial_contract_moves_misaligned_same_topology_contract_pair() { + let a_row0 = DynIndex::new_dyn(2); + let a_contract = DynIndex::new_dyn(2); + let a_row1 = DynIndex::new_dyn(2); + let a_b01 = DynIndex::new_dyn(1); + let a_b12 = DynIndex::new_dyn(1); + let a_b23 = DynIndex::new_dyn(1); + let a_b34 = DynIndex::new_dyn(1); + + let b_contract = DynIndex::new_dyn(2); + let b_col0 = DynIndex::new_dyn(2); + let b_col1 = DynIndex::new_dyn(2); + let b_b01 = DynIndex::new_dyn(1); + let b_b12 = DynIndex::new_dyn(1); + let b_b23 = 
DynIndex::new_dyn(1); + let b_b34 = DynIndex::new_dyn(1); + + let tn_a = TreeTN::::from_tensors( + vec![ + TensorDynLen::from_dense(vec![a_row0.clone(), a_b01.clone()], vec![1.0; 2]).unwrap(), + TensorDynLen::from_dense( + vec![a_b01.clone(), a_contract.clone(), a_b12.clone()], + vec![1.0; 2], + ) + .unwrap(), + TensorDynLen::from_dense( + vec![a_b12.clone(), a_row1.clone(), a_b23.clone()], + vec![1.0; 2], + ) + .unwrap(), + TensorDynLen::from_dense(vec![a_b23.clone(), a_b34.clone()], vec![1.0]).unwrap(), + TensorDynLen::from_dense(vec![a_b34.clone()], vec![1.0]).unwrap(), + ], + vec![0, 1, 2, 3, 4], + ) + .unwrap(); + + let tn_b = TreeTN::::from_tensors( + vec![ + TensorDynLen::from_dense(vec![b_b01.clone()], vec![1.0]).unwrap(), + TensorDynLen::from_dense(vec![b_b01.clone(), b_b12.clone()], vec![1.0]).unwrap(), + TensorDynLen::from_dense(vec![b_b12.clone(), b_b23.clone()], vec![1.0]).unwrap(), + TensorDynLen::from_dense( + vec![ + b_b23.clone(), + b_contract.clone(), + b_col0.clone(), + b_b34.clone(), + ], + vec![1.0; 4], + ) + .unwrap(), + TensorDynLen::from_dense(vec![b_b34.clone(), b_col1.clone()], vec![1.0; 2]).unwrap(), + ], + vec![0, 1, 2, 3, 4], + ) + .unwrap(); + + let output_order = vec![ + a_row0.clone(), + a_row1.clone(), + b_col0.clone(), + b_col1.clone(), + ]; + let spec = PartialContractionSpec { + contract_pairs: vec![(a_contract, b_contract)], + diagonal_pairs: vec![], + output_order: Some(output_order.clone()), + }; + + let result = + partial_contract(&tn_a, &tn_b, &spec, &0usize, ContractionOptions::default()).unwrap(); + let dense = result.to_dense().unwrap(); + assert_eq!(dense.external_indices(), output_order); + for value in dense.to_vec::().unwrap() { + assert!((value - 2.0).abs() < 1e-12); + } +} + #[test] fn test_partial_contract_allows_compatible_topology_mismatch_with_gap_leaf() { // tn_a has 1 node, tn_b has 2 nodes. 
The union topology is still a tree, @@ -592,7 +749,7 @@ fn test_partial_contract_allows_compatible_topology_mismatch_with_gap_leaf() { &tn_b, &spec, &"A".to_string(), - ContractionOptions::default().with_mismatched_topology_dense_limit(64), + ContractionOptions::default(), ); assert!(result.is_ok(), "{result:?}"); @@ -605,7 +762,7 @@ fn test_partial_contract_allows_compatible_topology_mismatch_with_gap_leaf() { } #[test] -fn test_partial_contract_rejects_mismatched_topology_dense_fallback_without_explicit_limit() { +fn test_partial_contract_aligns_long_mismatched_chain_without_dense_limit() { fn binary_chain(node_count: usize) -> TreeTN { let mut tensors = Vec::with_capacity(node_count); let mut names = Vec::with_capacity(node_count); @@ -640,9 +797,11 @@ fn test_partial_contract_rejects_mismatched_topology_dense_fallback_without_expl output_order: None, }; - let err = - partial_contract(&tn_a, &tn_b, &spec, &0usize, ContractionOptions::default()).unwrap_err(); - assert!(err.to_string().contains("explicit dense/reference limit")); + let result = + partial_contract(&tn_a, &tn_b, &spec, &0usize, ContractionOptions::default()).unwrap(); + assert_eq!(result.node_count(), 25); + assert_eq!(result.edge_count(), 24); + assert_eq!(result.external_indices().len(), 49); } #[test] @@ -801,6 +960,196 @@ fn test_partial_contract_honors_output_order() { assert_eq!(indices[1].id(), a0.id()); } +#[test] +fn test_partial_contract_output_order_rejects_multiple_survivors_on_one_node() { + let a_row0 = DynIndex::new_dyn(2); + let a_contract = DynIndex::new_dyn(2); + let a_row1 = DynIndex::new_dyn(2); + let a_b01 = DynIndex::new_dyn(1); + let a_b12 = DynIndex::new_dyn(1); + let a_b23 = DynIndex::new_dyn(1); + let a_b34 = DynIndex::new_dyn(1); + + let b_contract = DynIndex::new_dyn(2); + let b_col0 = DynIndex::new_dyn(2); + let b_col1 = DynIndex::new_dyn(2); + let b_b01 = DynIndex::new_dyn(1); + let b_b12 = DynIndex::new_dyn(1); + let b_b23 = DynIndex::new_dyn(1); + let b_b34 = 
DynIndex::new_dyn(1); + + let tn_a = TreeTN::::from_tensors( + vec![ + TensorDynLen::from_dense(vec![a_row0.clone(), a_b01.clone()], vec![1.0; 2]).unwrap(), + TensorDynLen::from_dense( + vec![a_b01.clone(), a_contract.clone(), a_b12.clone()], + vec![1.0; 2], + ) + .unwrap(), + TensorDynLen::from_dense( + vec![a_b12.clone(), a_row1.clone(), a_b23.clone()], + vec![1.0; 2], + ) + .unwrap(), + TensorDynLen::from_dense(vec![a_b23.clone(), a_b34.clone()], vec![1.0]).unwrap(), + TensorDynLen::from_dense(vec![a_b34.clone()], vec![1.0]).unwrap(), + ], + vec![0, 1, 2, 3, 4], + ) + .unwrap(); + let tn_b = TreeTN::::from_tensors( + vec![ + TensorDynLen::from_dense(vec![b_b01.clone()], vec![1.0]).unwrap(), + TensorDynLen::from_dense(vec![b_b01.clone(), b_b12.clone()], vec![1.0]).unwrap(), + TensorDynLen::from_dense( + vec![b_b12.clone(), b_col1.clone(), b_b23.clone()], + vec![1.0; 2], + ) + .unwrap(), + TensorDynLen::from_dense( + vec![ + b_b23.clone(), + b_contract.clone(), + b_col0.clone(), + b_b34.clone(), + ], + vec![1.0; 4], + ) + .unwrap(), + TensorDynLen::from_dense(vec![b_b34.clone()], vec![1.0]).unwrap(), + ], + vec![0, 1, 2, 3, 4], + ) + .unwrap(); + + let output_order = vec![ + a_row0.clone(), + a_row1.clone(), + b_col0.clone(), + b_col1.clone(), + ]; + let spec = PartialContractionSpec { + contract_pairs: vec![(a_contract, b_contract)], + diagonal_pairs: vec![], + output_order: Some(output_order.clone()), + }; + + let result = partial_contract(&tn_a, &tn_b, &spec, &0usize, ContractionOptions::default()); + + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("partial_contract_to_site_network")); +} + +#[test] +fn test_partial_contract_to_site_network_splits_onto_explicit_target() { + let i = DynIndex::new_dyn(2); + let k_left = DynIndex::new_dyn(2); + let k_right = DynIndex::new_dyn(2); + let j = DynIndex::new_dyn(2); + + let a = TreeTN::::from_tensors( + vec![ + TensorDynLen::from_dense(vec![i.clone(), k_left.clone()], vec![1.0, 
2.0, 3.0, 4.0]) + .unwrap(), + ], + vec!["center"], + ) + .unwrap(); + let b = TreeTN::::from_tensors( + vec![ + TensorDynLen::from_dense(vec![k_right.clone(), j.clone()], vec![5.0, 6.0, 7.0, 8.0]) + .unwrap(), + ], + vec!["center"], + ) + .unwrap(); + + let spec = PartialContractionSpec { + contract_pairs: vec![(k_left, k_right)], + diagonal_pairs: vec![], + output_order: None, + }; + let mut target = SiteIndexNetwork::new(); + target + .add_node("0_row", std::collections::HashSet::from([i.clone()])) + .unwrap(); + target + .add_node("1_col", std::collections::HashSet::from([j.clone()])) + .unwrap(); + target.add_edge(&"0_row", &"1_col").unwrap(); + + let result = partial_contract_to_site_network( + &a, + &b, + &spec, + &"center", + &target, + ContractionOptions::default(), + &RestructureOptions::default(), + ) + .unwrap(); + let dense = result.to_dense().unwrap(); + + assert_eq!(dense.external_indices(), vec![i.clone(), j.clone()]); + assert_eq!( + result.site_index_network().find_node_by_index(&i), + Some(&"0_row") + ); + assert_eq!( + result.site_index_network().find_node_by_index(&j), + Some(&"1_col") + ); + for (actual, expected) in dense + .to_vec::() + .unwrap() + .into_iter() + .zip([23.0, 34.0, 31.0, 46.0]) + { + assert!((actual - expected).abs() < 1e-12); + } +} + +#[test] +fn test_partial_contract_to_site_network_rejects_output_order() { + let i = DynIndex::new_dyn(2); + let j = DynIndex::new_dyn(2); + let a = TreeTN::::from_tensors( + vec![TensorDynLen::from_dense(vec![i.clone()], vec![1.0, 2.0]).unwrap()], + vec!["center"], + ) + .unwrap(); + let b = TreeTN::::from_tensors( + vec![TensorDynLen::from_dense(vec![j], vec![3.0, 4.0]).unwrap()], + vec!["center"], + ) + .unwrap(); + let spec = PartialContractionSpec { + contract_pairs: vec![], + diagonal_pairs: vec![], + output_order: Some(vec![i.clone()]), + }; + let mut target = SiteIndexNetwork::new(); + target + .add_node("out", std::collections::HashSet::from([i])) + .unwrap(); + + let result = 
partial_contract_to_site_network( + &a, + &b, + &spec, + &"center", + &target, + ContractionOptions::default(), + &RestructureOptions::default(), + ); + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("output_order")); +} + #[test] fn test_partial_contract_complex_diagonal_pair_keeps_left_leg() { let i = DynIndex::new_dyn(2); diff --git a/crates/tensor4all-treetn/src/treetn/restructure/mod.rs b/crates/tensor4all-treetn/src/treetn/restructure/mod.rs index ebdc78e3..efe677cc 100644 --- a/crates/tensor4all-treetn/src/treetn/restructure/mod.rs +++ b/crates/tensor4all-treetn/src/treetn/restructure/mod.rs @@ -829,6 +829,213 @@ where Ok(Some(assignment)) } +fn steiner_tree_indices( + graph: &NamedGraph, + terminals: &HashSet, +) -> HashSet +where + T: TensorLike, + V: Clone + Hash + Eq + Send + Sync + std::fmt::Debug, +{ + if terminals.len() <= 1 { + return terminals.clone(); + } + + let terms: Vec<_> = terminals.iter().copied().collect(); + let root = terms[0]; + let mut result = HashSet::from([root]); + for &term in &terms[1..] 
{ + if let Some((_, path)) = petgraph::algo::astar( + graph.graph(), + root, + |node| node == term, + |_| 1usize, + |_| 0usize, + ) { + result.extend(path); + } + } + result +} + +fn canonical_node_pair(left: &NodeName, right: &NodeName) -> (NodeName, NodeName) +where + NodeName: Clone + Ord, +{ + if left <= right { + (left.clone(), right.clone()) + } else { + (right.clone(), left.clone()) + } +} + +fn choose_site_free_absorption_target( + neighbor_targets: &HashSet, + target_edges: &HashSet<(NodeName, NodeName)>, +) -> Option +where + NodeName: Clone + Hash + Eq + Ord, +{ + if neighbor_targets.is_empty() { + return None; + } + + let mut candidates = neighbor_targets.iter().cloned().collect::>(); + candidates.sort(); + candidates.into_iter().find(|candidate| { + neighbor_targets.iter().all(|other| { + candidate == other || target_edges.contains(&canonical_node_pair(candidate, other)) + }) + }) +} + +fn target_quotient_matches_topology( + current: &SiteIndexNetwork, + target: &SiteIndexNetwork, + site_to_target: &HashMap, + full_graph: &NamedGraph, +) -> Result +where + T: TensorLike, + CurrentV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug, + TargetV: Clone + Hash + Eq + Ord + Send + Sync + std::fmt::Debug, + ::Id: Clone + Hash + Eq + Ord + std::fmt::Debug + Send + Sync, +{ + let mut target_to_current: HashMap> = HashMap::new(); + for current_node_name in current.node_names() { + let site_space = current.site_space(current_node_name).ok_or_else(|| { + anyhow::anyhow!( + "restructure_to: current node {:?} has no registered site space", + current_node_name + ) + })?; + if site_space.is_empty() { + continue; + } + + let target_names: HashSet<_> = site_space + .iter() + .map(|site_idx| { + site_to_target.get(site_idx).cloned().ok_or_else(|| { + anyhow::anyhow!( + "restructure_to: site index {:?} is present in the current network but missing from the target", + site_idx + ) + }) + }) + .collect::>()?; + let Some(target_name) = target_names.into_iter().next() 
else { + continue; + }; + target_to_current + .entry(target_name) + .or_default() + .insert(current_node_name.clone()); + } + + for target_node_name in target.node_names() { + if !target_to_current.contains_key(target_node_name) { + return Ok(false); + } + } + + for current_nodes in target_to_current.values_mut() { + let seed_indices: HashSet<_> = current_nodes + .iter() + .filter_map(|name| full_graph.node_index(name)) + .collect(); + if seed_indices.len() <= 1 { + continue; + } + for idx in steiner_tree_indices(full_graph, &seed_indices) { + let Some(name) = full_graph.node_name(idx) else { + continue; + }; + if current + .site_space(name) + .is_none_or(|site_space| site_space.is_empty()) + { + current_nodes.insert(name.clone()); + } + } + } + + let mut current_to_target: HashMap = HashMap::new(); + for (target_name, current_nodes) in &target_to_current { + for current_node_name in current_nodes { + if let Some(existing) = + current_to_target.insert(current_node_name.clone(), target_name.clone()) + { + if existing != *target_name { + return Ok(false); + } + } + } + } + + let target_edges: HashSet<_> = target + .edges() + .map(|(left, right)| canonical_node_pair(&left, &right)) + .collect(); + + loop { + let mut additions = Vec::<(CurrentV, TargetV)>::new(); + for current_node_name in current.node_names() { + if current_to_target.contains_key(current_node_name) + || current + .site_space(current_node_name) + .is_some_and(|site_space| !site_space.is_empty()) + { + continue; + } + + let neighbor_targets: HashSet = current + .neighbors(current_node_name) + .filter_map(|neighbor| current_to_target.get(&neighbor).cloned()) + .collect(); + if let Some(target_name) = + choose_site_free_absorption_target(&neighbor_targets, &target_edges) + { + additions.push((current_node_name.clone(), target_name)); + } + } + + if additions.is_empty() { + break; + } + + for (current_node_name, target_name) in additions { + current_to_target.insert(current_node_name, target_name); + } + } 
+ + let mut quotient_edges = HashSet::new(); + for edge in full_graph.graph().edge_indices() { + let Some((left_idx, right_idx)) = full_graph.graph().edge_endpoints(edge) else { + continue; + }; + let Some(left_name) = full_graph.node_name(left_idx) else { + continue; + }; + let Some(right_name) = full_graph.node_name(right_idx) else { + continue; + }; + match ( + current_to_target.get(left_name), + current_to_target.get(right_name), + ) { + (Some(left_target), Some(right_target)) if left_target == right_target => {} + (Some(left_target), Some(right_target)) => { + quotient_edges.insert(canonical_node_pair(left_target, right_target)); + } + (None, None) => {} + _ => return Ok(false), + } + } + + Ok(quotient_edges == target_edges) +} + fn clone_tree(tree: &TreeTN) -> Result> where T: TensorLike, @@ -885,12 +1092,19 @@ where bail!("restructure_to: current and target must contain the same site indices"); } - if current_nodes_map_uniquely_to_targets::(current, &site_to_target)? { + let current_nodes_are_unique = + current_nodes_map_uniquely_to_targets::(current, &site_to_target)?; + if current_nodes_are_unique { if target_nodes_span_connected_currents::( current, target, &site_to_current, full_graph, + )? && target_quotient_matches_topology::( + current, + target, + &site_to_target, + full_graph, )? { return Ok(RestructurePlan { kind: RestructurePlanKind::FuseOnly, @@ -908,7 +1122,15 @@ where } } - if target_nodes_map_uniquely_to_currents::(target, &site_to_current)? { + if target_nodes_map_uniquely_to_currents::(target, &site_to_current)? + && (!current_nodes_are_unique + || target_quotient_matches_topology::( + current, + target, + &site_to_target, + full_graph, + )?) 
+ { return Ok(RestructurePlan { kind: RestructurePlanKind::SplitOnly, }); @@ -981,7 +1203,20 @@ where } }?; - apply_final_truncation(result, options) + let result = apply_final_truncation(result, options)?; + if target.edge_count() > 0 + && !result + .site_index_network() + .share_equivalent_site_index_network(target) + { + bail!( + "restructure_to: result topology does not match target: expected edges {:?}, got {:?}", + target.edges().collect::>(), + result.site_index_network().edges().collect::>() + ); + } + + Ok(result) } impl TreeTN @@ -1207,6 +1442,78 @@ mod tests { Ok(()) } + #[test] + fn test_restructure_to_absorbs_site_free_dangling_leaf() -> anyhow::Result<()> { + let left = DynIndex::new_dyn(2); + let right = DynIndex::new_dyn(2); + let b01 = DynIndex::new_dyn(2); + let b12 = DynIndex::new_dyn(2); + + let t0 = + TensorDynLen::from_dense(vec![left.clone(), b01.clone()], vec![1.0, 2.0, 3.0, 4.0])?; + let t1 = TensorDynLen::from_dense( + vec![b01.clone(), right.clone(), b12.clone()], + (1..=8).map(|value| value as f64 / 3.0).collect(), + )?; + let t2 = TensorDynLen::from_dense(vec![b12], vec![0.5, -1.25])?; + let treetn = TreeTN::::from_tensors( + vec![t0, t1, t2], + vec!["A".to_string(), "B".to_string(), "C".to_string()], + )?; + + let before = treetn.to_dense()?; + let mut target: SiteIndexNetwork = SiteIndexNetwork::new(); + target.add_node("A".to_string(), HashSet::from([left]))?; + target.add_node("B".to_string(), HashSet::from([right]))?; + target.add_edge(&"A".to_string(), &"B".to_string())?; + + let result = treetn.restructure_to(&target, &RestructureOptions::default())?; + + assert_eq!(result.node_count(), 2); + assert_eq!(result.edge_count(), 1); + assert!(result + .site_index_network() + .share_equivalent_site_index_network(&target)); + assert!(result.to_dense()?.sub(&before)?.maxabs() < 1.0e-12); + + Ok(()) + } + + #[test] + fn test_restructure_to_absorbs_site_free_internal_node() -> anyhow::Result<()> { + let left = DynIndex::new_dyn(2); + let 
right = DynIndex::new_dyn(2); + let b01 = DynIndex::new_dyn(2); + let b12 = DynIndex::new_dyn(2); + + let t0 = + TensorDynLen::from_dense(vec![left.clone(), b01.clone()], vec![1.0, 2.0, 3.0, 4.0])?; + let t1 = + TensorDynLen::from_dense(vec![b01.clone(), b12.clone()], vec![0.5, -1.0, 1.25, 0.75])?; + let t2 = TensorDynLen::from_dense(vec![b12, right.clone()], vec![2.0, -0.5, 1.5, 3.0])?; + let treetn = TreeTN::::from_tensors( + vec![t0, t1, t2], + vec!["A".to_string(), "M".to_string(), "B".to_string()], + )?; + + let before = treetn.to_dense()?; + let mut target: SiteIndexNetwork = SiteIndexNetwork::new(); + target.add_node("A".to_string(), HashSet::from([left]))?; + target.add_node("B".to_string(), HashSet::from([right]))?; + target.add_edge(&"A".to_string(), &"B".to_string())?; + + let result = treetn.restructure_to(&target, &RestructureOptions::default())?; + + assert_eq!(result.node_count(), 2); + assert_eq!(result.edge_count(), 1); + assert!(result + .site_index_network() + .share_equivalent_site_index_network(&target)); + assert!(result.to_dense()?.sub(&before)?.maxabs() < 1.0e-12); + + Ok(()) + } + #[test] fn test_restructure_to_split_only_matches_target_structure() -> anyhow::Result<()> { let (treetn, left, right) = two_node_chain()?; @@ -1476,6 +1783,45 @@ mod tests { Ok(()) } + #[test] + fn test_restructure_to_site_permutation_preserves_target_path() -> anyhow::Result<()> { + let (treetn, x0, x1, y0, y1) = four_node_interleaved_chain()?; + + let mut target: SiteIndexNetwork = SiteIndexNetwork::new(); + target + .add_node("0".to_string(), HashSet::from([x1.clone()])) + .map_err(anyhow::Error::msg)?; + target + .add_node("1".to_string(), HashSet::from([x0.clone()])) + .map_err(anyhow::Error::msg)?; + target + .add_node("2".to_string(), HashSet::from([y1.clone()])) + .map_err(anyhow::Error::msg)?; + target + .add_node("3".to_string(), HashSet::from([y0.clone()])) + .map_err(anyhow::Error::msg)?; + target + .add_edge(&"0".to_string(), &"1".to_string()) + 
.map_err(anyhow::Error::msg)?; + target + .add_edge(&"1".to_string(), &"2".to_string()) + .map_err(anyhow::Error::msg)?; + target + .add_edge(&"2".to_string(), &"3".to_string()) + .map_err(anyhow::Error::msg)?; + + let result = treetn.restructure_to(&target, &RestructureOptions::default())?; + let dense_expected = treetn.contract_to_tensor()?; + let dense_actual = result.contract_to_tensor()?; + + assert!(result + .site_index_network() + .share_equivalent_site_index_network(&target)); + assert!(dense_actual.distance(&dense_expected).unwrap() < 1e-10); + + Ok(()) + } + // ======================================================================== // Y-shape branching topology tests for restructure_to // ======================================================================== diff --git a/crates/tensor4all-treetn/src/treetn/swap.rs b/crates/tensor4all-treetn/src/treetn/swap.rs index 04cc652a..3f36733d 100644 --- a/crates/tensor4all-treetn/src/treetn/swap.rs +++ b/crates/tensor4all-treetn/src/treetn/swap.rs @@ -21,13 +21,14 @@ use super::{localupdate::LocalUpdateSweepPlan, TreeTN}; /// Factorize a tensor into left and right parts connected by a bond index. /// -/// Extends [`TensorLike::factorize`] to handle degenerate cases where all +/// Extends [`TensorFactorizationLike::factorize`](tensor4all_core::TensorFactorizationLike::factorize) +/// to handle degenerate cases where all /// indices go to one side (empty `left_inds` or `left_inds == all_inds`). /// For these cases a dimension-1 trivial bond is created so that /// `contract(left, right)` recovers the input tensor exactly. /// /// With `Canonical::Left` (the only mode used by swap): -/// - **Normal case**: delegates to `TensorLike::factorize`. +/// - **Normal case**: delegates to `TensorFactorizationLike::factorize`. /// - **Empty `left_inds`**: `left = [1]` (dim-1 scalar isometry), /// `right = tensor ⊗ [1]` (acquires the trivial bond). 
/// - **Full `left_inds`**: `left = (tensor ⊗ [1]) / ‖tensor‖`, @@ -89,7 +90,7 @@ where }); } - // Normal case: delegate to TensorLike::factorize + // Normal case: delegate to TensorFactorizationLike::factorize tensor .factorize(left_inds, factorize_options) .map_err(|e| anyhow::anyhow!("factorize_or_trivial: factorize: {}", e)) diff --git a/crates/tensor4all-treetn/src/treetn/transform.rs b/crates/tensor4all-treetn/src/treetn/transform.rs index cd8aeb01..5cba28fa 100644 --- a/crates/tensor4all-treetn/src/treetn/transform.rs +++ b/crates/tensor4all-treetn/src/treetn/transform.rs @@ -11,7 +11,7 @@ use anyhow::{Context, Result}; use petgraph::stable_graph::{NodeIndex, StableGraph}; use tensor4all_core::{ - index_ops, AllowedPairs, Canonical, FactorizeAlg, FactorizeOptions, IndexLike, TensorLike, + index_ops, Canonical, FactorizeAlg, FactorizeOptions, IndexLike, TensorLike, }; use super::TreeTN; @@ -45,6 +45,37 @@ fn steiner_tree_indices( result } +fn canonical_node_pair(left: &NodeName, right: &NodeName) -> (NodeName, NodeName) +where + NodeName: Clone + Ord, +{ + if left <= right { + (left.clone(), right.clone()) + } else { + (right.clone(), left.clone()) + } +} + +fn choose_site_free_absorption_target( + neighbor_targets: &HashSet, + target_edges: &HashSet<(NodeName, NodeName)>, +) -> Option +where + NodeName: Clone + Hash + Eq + Ord, +{ + if neighbor_targets.is_empty() { + return None; + } + + let mut candidates = neighbor_targets.iter().cloned().collect::>(); + candidates.sort(); + candidates.into_iter().find(|candidate| { + neighbor_targets.iter().all(|other| { + candidate == other || target_edges.contains(&canonical_node_pair(candidate, other)) + }) + }) +} + /// Check if nodes form a connected induced subgraph in the given graph. /// DFS restricted to edges where both endpoints are in `nodes`. 
fn is_connected_subset_on_graph( @@ -209,6 +240,56 @@ where } } + // Step 4b: Absorb site-free dangling subtrees into the unique adjacent + // target group, or into an adjacent target group that preserves all + // requested quotient edges. These nodes carry only gauge/internal + // factors; dropping them would change values. + let target_edges: HashSet<_> = target + .edges() + .map(|(left, right)| canonical_node_pair(&left, &right)) + .collect(); + let mut current_to_target = HashMap::::new(); + for (target_name, current_nodes) in &target_to_current { + for current_node in current_nodes { + current_to_target.insert(current_node.clone(), target_name.clone()); + } + } + loop { + let mut additions = Vec::<(TargetV, V)>::new(); + for current_name in self.node_names() { + if current_to_target.contains_key(&current_name) + || self + .site_space(&current_name) + .is_some_and(|site_space| !site_space.is_empty()) + { + continue; + } + + let neighbor_targets: HashSet = self + .site_index_network + .neighbors(&current_name) + .filter_map(|neighbor| current_to_target.get(&neighbor).cloned()) + .collect(); + if let Some(target_name) = + choose_site_free_absorption_target(&neighbor_targets, &target_edges) + { + additions.push((target_name, current_name)); + } + } + + if additions.is_empty() { + break; + } + + for (target_name, current_name) in additions { + target_to_current + .entry(target_name.clone()) + .or_default() + .insert(current_name.clone()); + current_to_target.insert(current_name, target_name); + } + } + // Step 5: For each target node, contract all its current nodes into one tensor let mut result_tensors: HashMap = HashMap::new(); @@ -330,9 +411,9 @@ where .remove(&to) .ok_or_else(|| anyhow::anyhow!("Tensor not found for node {:?}", to))?; - // Contract using TensorLike::contract + // Contract using TensorContractionLike::contract // (bond indices are auto-detected via is_contractable) - let contracted = T::contract(&[&to_tensor, &from_tensor], AllowedPairs::All) + let contracted = 
T::contract(&[&to_tensor, &from_tensor]) .map_err(|e| anyhow::anyhow!("Failed to contract tensors: {}", e))?; tensors.insert(to, contracted); diff --git a/crates/tensor4all-treetn/tests/ad_treetn.rs b/crates/tensor4all-treetn/tests/ad_treetn.rs index 1395ccc5..a32ca524 100644 --- a/crates/tensor4all-treetn/tests/ad_treetn.rs +++ b/crates/tensor4all-treetn/tests/ad_treetn.rs @@ -1,6 +1,6 @@ //! Tests for reverse-mode automatic differentiation through TreeTN operations. -use tensor4all_core::{contract_multi, AllowedPairs, DynIndex, IndexLike, TensorDynLen}; +use tensor4all_core::{contract, DynIndex, IndexLike, TensorDynLen}; use tensor4all_treetn::TreeTN; fn make_three_site_mps_data() -> (Vec>, Vec>) { @@ -49,7 +49,7 @@ fn backward_ad_to_dense_propagates_gradients() { vec![1.0; dense.indices().iter().map(|i| i.dim()).product::()], ) .unwrap(); - let scalar = contract_multi(&[&dense, &ones], AllowedPairs::All).unwrap(); + let scalar = contract(&[&dense, &ones]).unwrap(); scalar.backward().unwrap(); @@ -83,7 +83,7 @@ fn backward_ad_gradient_matches_finite_diff() { vec![1.0; dense.indices().iter().map(|i| i.dim()).product::()], ) .unwrap(); - let scalar = contract_multi(&[&dense, &ones], AllowedPairs::All).unwrap(); + let scalar = contract(&[&dense, &ones]).unwrap(); scalar.backward().unwrap(); @@ -148,7 +148,7 @@ fn backward_accumulates_until_clear_grad_across_treetn_nodes() { ) .unwrap(); - let scalar = contract_multi(&[&dense, &ones], AllowedPairs::All).unwrap(); + let scalar = contract(&[&dense, &ones]).unwrap(); scalar.backward().unwrap(); let first_grads: Vec> = ttn @@ -165,7 +165,7 @@ fn backward_accumulates_until_clear_grad_across_treetn_nodes() { }) .collect(); - let scalar = contract_multi(&[&dense, &ones], AllowedPairs::All).unwrap(); + let scalar = contract(&[&dense, &ones]).unwrap(); scalar.backward().unwrap(); for (node_pos, (&ni, first_grad)) in ttn.node_indices().iter().zip(&first_grads).enumerate() { diff --git a/crates/tensor4all-treetn/tests/basic.rs 
b/crates/tensor4all-treetn/tests/basic.rs index 031226ac..84a776c3 100644 --- a/crates/tensor4all-treetn/tests/basic.rs +++ b/crates/tensor4all-treetn/tests/basic.rs @@ -2123,7 +2123,7 @@ fn test_fit_vs_naive_5node_chain() { */ // ============================================================================ -// Tests for contract_zipup_tree_accumulated +// Tests for contract_zipup_with // ============================================================================ /// Helper: Create a simple 2-node TreeTN with String node names for testing @@ -2151,13 +2151,14 @@ fn create_two_node_treetn_string() -> TreeTN { /// Test single node contraction #[test] -fn test_zipup_accumulated_single_node() { +fn test_zipup_with_single_node() { let mut tn_a = TreeTN::::new(); let mut tn_b = TreeTN::::new(); - // Use different indices so contraction produces a result with indices + // Use the same site index so this is a real contraction, not an implicit + // outer product of unrelated inputs. let phys_a = DynIndex::new_dyn(2); - let phys_b = DynIndex::new_dyn(2); + let phys_b = phys_a.clone(); let tensor_a = TensorDynLen::from_dense(vec![phys_a.clone()], vec![1.0, 2.0]).unwrap(); let tensor_b = TensorDynLen::from_dense(vec![phys_b.clone()], vec![3.0, 4.0]).unwrap(); @@ -2165,41 +2166,25 @@ fn test_zipup_accumulated_single_node() { tn_b.add_tensor("X".to_string(), tensor_b).unwrap(); let result = tn_a - .contract_zipup_tree_accumulated( - &tn_b, - &"X".to_string(), - CanonicalForm::Unitary, - None, - None, - ) + .contract_zipup_with(&tn_b, &"X".to_string(), CanonicalForm::Unitary, None, None) .unwrap(); assert_eq!(result.node_count(), 1); let result_tensor = result .tensor(result.node_index(&"X".to_string()).unwrap()) .unwrap(); - // When indices are different, result should have 2 indices (outer product) - // When indices are the same, result is a scalar (0 indices) - assert!( - result_tensor.external_indices().is_empty() || result_tensor.external_indices().len() == 2 - ); + 
assert!(result_tensor.external_indices().is_empty()); assert!(result.canonical_region().contains(&"X".to_string())); } /// Test 2-node chain contraction #[test] -fn test_zipup_accumulated_two_node_chain() { +fn test_zipup_with_two_node_chain() { let tn_a = create_two_node_treetn_string(); let tn_b = create_two_node_treetn_string(); let result = tn_a - .contract_zipup_tree_accumulated( - &tn_b, - &"B".to_string(), - CanonicalForm::Unitary, - None, - None, - ) + .contract_zipup_with(&tn_b, &"B".to_string(), CanonicalForm::Unitary, None, None) .unwrap(); assert_eq!(result.node_count(), 2); @@ -2211,7 +2196,7 @@ fn test_zipup_accumulated_two_node_chain() { /// Test 3-node chain contraction #[test] -fn test_zipup_accumulated_three_node_chain() { +fn test_zipup_with_three_node_chain() { let mut tn_a = TreeTN::::new(); let mut tn_b = TreeTN::::new(); @@ -2265,13 +2250,7 @@ fn test_zipup_accumulated_three_node_chain() { tn_b.connect(n_b_b, &bond23_b, n_c_b, &bond23_b).unwrap(); let result = tn_a - .contract_zipup_tree_accumulated( - &tn_b, - &"C".to_string(), - CanonicalForm::Unitary, - None, - None, - ) + .contract_zipup_with(&tn_b, &"C".to_string(), CanonicalForm::Unitary, None, None) .unwrap(); assert_eq!(result.node_count(), 3); @@ -2280,7 +2259,7 @@ fn test_zipup_accumulated_three_node_chain() { /// Test star topology (multiple leaves connected to root) #[test] -fn test_zipup_accumulated_star_topology() { +fn test_zipup_with_star_topology() { let mut tn_a = TreeTN::::new(); let mut tn_b = TreeTN::::new(); @@ -2368,13 +2347,7 @@ fn test_zipup_accumulated_star_topology() { tn_b.connect(n_c_b, &bond_cd_b, n_d_b, &bond_cd_b).unwrap(); let result = tn_a - .contract_zipup_tree_accumulated( - &tn_b, - &"D".to_string(), - CanonicalForm::Unitary, - None, - None, - ) + .contract_zipup_with(&tn_b, &"D".to_string(), CanonicalForm::Unitary, None, None) .unwrap(); assert_eq!(result.node_count(), 4); diff --git a/crates/tensor4all-treetn/tests/bug_swap_values.rs 
b/crates/tensor4all-treetn/tests/bug_swap_values.rs index d2d265a4..b863612a 100644 --- a/crates/tensor4all-treetn/tests/bug_swap_values.rs +++ b/crates/tensor4all-treetn/tests/bug_swap_values.rs @@ -7,7 +7,8 @@ use std::collections::HashMap; use tensor4all_core::{ - common_inds, DynIndex, FactorizeOptions, IndexLike, TensorDynLen, TensorIndex, TensorLike, + common_inds, DynIndex, FactorizeOptions, IndexLike, TensorContractionLike, TensorDynLen, + TensorFactorizationLike, TensorIndex, }; use tensor4all_treetn::{SwapOptions, TreeTN}; @@ -149,7 +150,7 @@ fn test_contract_factorize_roundtrip() { let t1 = treetn.tensor(n1).unwrap().clone(); let t2 = treetn.tensor(n2).unwrap().clone(); - let contracted = t1.contract(&t2).unwrap(); + let contracted = t1.contract_pair(&t2).unwrap(); // Find left_inds for factorization let t1_ids: std::collections::HashSet<_> = t1 @@ -180,7 +181,7 @@ fn test_contract_factorize_roundtrip() { .factorize(&left_inds, &factorize_options) .unwrap(); - let reconstructed = result.left.contract(&result.right).unwrap(); + let reconstructed = result.left.contract_pair(&result.right).unwrap(); let recon_aligned = reconstructed .permute_indices(&contracted.external_indices()) .unwrap(); @@ -206,21 +207,21 @@ fn test_manual_sweep_edge_steps() { // Reference: full contraction of original tensors let full_ref = t[3] - .contract(&t[2]) + .contract_pair(&t[2]) .unwrap() - .contract(&t[1]) + .contract_pair(&t[1]) .unwrap() - .contract(&t[0]) + .contract_pair(&t[0]) .unwrap(); // Helper: contract all 4 tensors and compare to reference let check_full = |_label: &str, tensors: &[TensorDynLen]| -> f64 { let full = tensors[3] - .contract(&tensors[2]) + .contract_pair(&tensors[2]) .unwrap() - .contract(&tensors[1]) + .contract_pair(&tensors[1]) .unwrap() - .contract(&tensors[0]) + .contract_pair(&tensors[0]) .unwrap(); let aligned = full.permute_indices(&full_ref.external_indices()).unwrap(); let neg = full_ref @@ -256,7 +257,7 @@ fn test_manual_sweep_edge_steps() 
{ let fr = t[src] .factorize(&left_inds, &FactorizeOptions::qr()) .unwrap(); - let new_dst = t[dst].contract(&fr.right).unwrap(); + let new_dst = t[dst].contract_pair(&fr.right).unwrap(); t[src] = fr.left; t[dst] = new_dst; }; @@ -296,7 +297,7 @@ fn test_qr_roundtrip_tall_matrix() { .factorize(&[i2.clone(), i3.clone()], &FactorizeOptions::qr()) .unwrap(); - let recon = fr.left.contract(&fr.right).unwrap(); + let recon = fr.left.contract_pair(&fr.right).unwrap(); let recon_aligned = recon.permute_indices(&t.external_indices()).unwrap(); let neg = t.scale(tensor4all_core::AnyScalar::new_real(-1.0)).unwrap(); let diff = recon_aligned.add(&neg).unwrap(); diff --git a/crates/tensor4all-treetn/tests/issue192_regression.rs b/crates/tensor4all-treetn/tests/issue192_regression.rs index 99ea754a..cd740318 100644 --- a/crates/tensor4all-treetn/tests/issue192_regression.rs +++ b/crates/tensor4all-treetn/tests/issue192_regression.rs @@ -178,9 +178,9 @@ fn issue192_regression_no_svd_nan_n5_identity_ones() -> anyhow::Result<()> { let cutoff = 1e-8_f64; let rtol = cutoff.sqrt(); - let krylov_tol = 1e-6_f64; - let krylov_maxiter = 20usize; - let krylov_dim = 30usize; + let gmres_tol = 1e-6_f64; + let gmres_max_restarts = 20usize; + let gmres_restart_dim = 30usize; let (rhs, site_indices) = create_n_site_ones_mps(n_sites, phys_dim, bond_dim)?; let (mpo, s_in_tmp, s_out_tmp) = @@ -206,9 +206,9 @@ fn issue192_regression_no_svd_nan_n5_identity_ones() -> anyhow::Result<()> { let options = LinsolveOptions::default() .with_nfullsweeps(n_sweeps) .with_truncation(truncation) - .with_krylov_tol(krylov_tol) - .with_krylov_maxiter(krylov_maxiter) - .with_krylov_dim(krylov_dim) + .with_gmres_tol(gmres_tol) + .with_gmres_max_restarts(gmres_max_restarts) + .with_gmres_restart_dim(gmres_restart_dim) .with_coefficients(a0, a1); let plan = LocalUpdateSweepPlan::from_treetn(&x, &center, 2) diff --git a/crates/tensor4all-treetn/tests/linsolve.rs b/crates/tensor4all-treetn/tests/linsolve.rs index 
bbf06fa5..f73ffcdd 100644 --- a/crates/tensor4all-treetn/tests/linsolve.rs +++ b/crates/tensor4all-treetn/tests/linsolve.rs @@ -5,10 +5,13 @@ use std::collections::HashMap; -use tensor4all_core::{DynIndex, IndexLike, TensorDynLen, TensorIndex}; +use tensor4all_core::{ + AnyScalar, DynIndex, IndexLike, TensorContractionLike, TensorDynLen, TensorIndex, +}; use tensor4all_treetn::{ - EnvironmentCache, IndexMapping, LinearOperator, LinsolveOptions, NetworkTopology, - ProjectedOperator, ProjectedState, SquareLinsolveUpdater, TreeTN, + relative_linear_system_residual, ApplyOptions, EnvironmentCache, IndexMapping, LinearOperator, + LinsolveOptions, NetworkTopology, ProjectedOperator, ProjectedState, SquareLinsolveUpdater, + TreeTN, }; type FixedSiteMappings = ( @@ -158,11 +161,11 @@ fn test_linsolve_options_default() { let opts = LinsolveOptions::default(); assert_eq!(opts.nfullsweeps, 5); - assert_eq!(opts.krylov_tol, 1e-10); - assert_eq!(opts.krylov_maxiter, 100); - assert_eq!(opts.krylov_dim, 30); - assert_eq!(opts.a0, 0.0); - assert_eq!(opts.a1, 1.0); + assert_eq!(opts.gmres_tol, 1e-10); + assert_eq!(opts.gmres_max_restarts, 100); + assert_eq!(opts.gmres_restart_dim, 30); + assert_eq!(opts.a0, AnyScalar::new_real(0.0)); + assert_eq!(opts.a1, AnyScalar::new_real(1.0)); assert!(opts.convergence_tol.is_none()); } @@ -170,18 +173,18 @@ fn test_linsolve_options_default() { fn test_linsolve_options_builder() { let opts = LinsolveOptions::default() .with_nfullsweeps(5) - .with_krylov_tol(1e-8) - .with_krylov_maxiter(50) - .with_krylov_dim(20) + .with_gmres_tol(1e-8) + .with_gmres_max_restarts(50) + .with_gmres_restart_dim(20) .with_coefficients(1.0, -1.0) .with_convergence_tol(1e-6); assert_eq!(opts.nfullsweeps, 5); - assert_eq!(opts.krylov_tol, 1e-8); - assert_eq!(opts.krylov_maxiter, 50); - assert_eq!(opts.krylov_dim, 20); - assert_eq!(opts.a0, 1.0); - assert_eq!(opts.a1, -1.0); + assert_eq!(opts.gmres_tol, 1e-8); + assert_eq!(opts.gmres_max_restarts, 50); + 
assert_eq!(opts.gmres_restart_dim, 20); + assert_eq!(opts.a0, AnyScalar::new_real(1.0)); + assert_eq!(opts.a1, AnyScalar::new_real(-1.0)); assert_eq!(opts.convergence_tol, Some(1e-6)); } @@ -494,9 +497,9 @@ fn test_diagonal_linsolve_with_mappings(diag_values: &[f64], b_values: &[f64], t // for a 2-site diagonal operator. let options = LinsolveOptions::default() .with_nfullsweeps(5) - .with_krylov_tol(1e-10) - .with_krylov_dim(10) - .with_krylov_maxiter(30) + .with_gmres_tol(1e-10) + .with_gmres_restart_dim(10) + .with_gmres_max_restarts(30) .with_max_rank(4); let nsweeps = options.nfullsweeps; @@ -809,7 +812,7 @@ fn test_linsolve_3site_identity() { // Solve I * x = b let options = LinsolveOptions::default() .with_nfullsweeps(2) - .with_krylov_tol(1e-8) + .with_gmres_tol(1e-8) .with_max_rank(4); let mut updater = SquareLinsolveUpdater::with_index_mappings( @@ -899,6 +902,59 @@ fn create_mpo_with_internal_indices( ) } +/// Create a 3-node operator where the first two nodes are mapped identity MPO +/// sites and the last node is a scalar no-op spectator. 
+fn create_mapped_identity_mpo_with_spectator_node( + phys_dim: usize, +) -> ( + TreeTN, + Vec, + Vec, +) { + let mut mpo = TreeTN::::new(); + + let s0_in_tmp = DynIndex::new_dyn(phys_dim); + let s1_in_tmp = DynIndex::new_dyn(phys_dim); + let s0_out_tmp = DynIndex::new_dyn(phys_dim); + let s1_out_tmp = DynIndex::new_dyn(phys_dim); + let b01 = DynIndex::new_dyn(1); + let b12 = DynIndex::new_dyn(1); + + let mut id_data = vec![0.0; phys_dim * phys_dim]; + for i in 0..phys_dim { + id_data[i * phys_dim + i] = 1.0; + } + + let t0 = TensorDynLen::from_dense( + vec![s0_out_tmp.clone(), s0_in_tmp.clone(), b01.clone()], + id_data.clone(), + ) + .unwrap(); + let t1 = TensorDynLen::from_dense( + vec![ + b01.clone(), + s1_out_tmp.clone(), + s1_in_tmp.clone(), + b12.clone(), + ], + id_data, + ) + .unwrap(); + let t2 = TensorDynLen::from_dense(vec![b12.clone()], vec![1.0]).unwrap(); + + let n0 = mpo.add_tensor("site0", t0).unwrap(); + let n1 = mpo.add_tensor("site1", t1).unwrap(); + let n2 = mpo.add_tensor("site2", t2).unwrap(); + mpo.connect(n0, &b01, n1, &b01).unwrap(); + mpo.connect(n1, &b12, n2, &b12).unwrap(); + + ( + mpo, + vec![s0_in_tmp, s1_in_tmp], + vec![s0_out_tmp, s1_out_tmp], + ) +} + #[test] fn test_linear_operator_creation() { let phys_dim = 2; @@ -1214,7 +1270,7 @@ fn test_linsolve_with_index_mappings_identity() { // Create SquareLinsolveUpdater with index mappings let options = LinsolveOptions::default() .with_nfullsweeps(1) - .with_krylov_tol(1e-10) + .with_gmres_tol(1e-10) .with_max_rank(4); let mut updater = SquareLinsolveUpdater::with_index_mappings( @@ -1265,7 +1321,7 @@ fn test_linsolve_with_index_mappings_diagonal() { // Create SquareLinsolveUpdater with index mappings let options = LinsolveOptions::default() .with_nfullsweeps(3) - .with_krylov_tol(1e-10) + .with_gmres_tol(1e-10) .with_max_rank(4); let mut updater = SquareLinsolveUpdater::with_index_mappings( @@ -1421,7 +1477,7 @@ fn test_linsolve_with_index_mappings_three_site_identity() { // Create 
SquareLinsolveUpdater with index mappings let options = LinsolveOptions::default() .with_nfullsweeps(1) - .with_krylov_tol(1e-10) + .with_gmres_tol(1e-10) .with_max_rank(4); let mut updater = SquareLinsolveUpdater::with_index_mappings( @@ -1478,7 +1534,7 @@ fn test_linsolve_with_index_mappings_three_site_diagonal() { // Create SquareLinsolveUpdater with index mappings let options = LinsolveOptions::default() .with_nfullsweeps(5) - .with_krylov_tol(1e-10) + .with_gmres_tol(1e-10) .with_max_rank(4); let mut updater = SquareLinsolveUpdater::with_index_mappings( @@ -1612,7 +1668,7 @@ fn test_linsolve_pauli_x() { // Solve X * x = b let options = LinsolveOptions::default() .with_nfullsweeps(20) - .with_krylov_tol(1e-12) + .with_gmres_tol(1e-12) .with_max_rank(8); let mut updater = SquareLinsolveUpdater::with_index_mappings( @@ -1781,7 +1837,7 @@ fn test_linsolve_general_matrix() { // Solve A * x = b let options = LinsolveOptions::default() .with_nfullsweeps(30) - .with_krylov_tol(1e-12) + .with_gmres_tol(1e-12) .with_max_rank(8); let mut updater = SquareLinsolveUpdater::with_index_mappings( @@ -1869,7 +1925,7 @@ fn test_linsolve_general_matrix_nonsymmetric() { let options = LinsolveOptions::default() .with_nfullsweeps(30) - .with_krylov_tol(1e-12) + .with_gmres_tol(1e-12) .with_max_rank(8); let mut updater = SquareLinsolveUpdater::with_index_mappings( @@ -2140,7 +2196,7 @@ fn test_linsolve_n_site_identity_impl(n_sites: usize) { // Create SquareLinsolveUpdater with index mappings let options = LinsolveOptions::default() .with_nfullsweeps(1) - .with_krylov_tol(1e-10) + .with_gmres_tol(1e-10) .with_max_rank(4); let mut updater = SquareLinsolveUpdater::with_index_mappings( @@ -2196,10 +2252,11 @@ fn test_square_linsolve_with_mappings_identity() { let init = rhs.clone(); let options = LinsolveOptions::default() .with_nfullsweeps(3) - .with_krylov_tol(1e-10) - .with_krylov_dim(10) - .with_krylov_maxiter(30) - .with_max_rank(4); + .with_gmres_tol(1e-10) + 
.with_gmres_restart_dim(10) + .with_gmres_max_restarts(30) + .with_max_rank(4) + .with_convergence_tol(1e-8); // This previously failed with index mismatch when mappings were not supported let result = square_linsolve( @@ -2215,6 +2272,8 @@ fn test_square_linsolve_with_mappings_identity() { assert_eq!(result.solution.node_count(), 2); assert!(result.sweeps > 0); + assert!(result.converged); + assert!(result.residual.is_some_and(|residual| residual < 1.0e-8)); // For identity operator, solution should match RHS let contracted = result.solution.contract_to_tensor().unwrap(); @@ -2236,6 +2295,171 @@ fn test_square_linsolve_with_mappings_identity() { ); } +#[test] +fn test_relative_linear_system_residual_with_mapped_coefficients() { + let phys_dim = 2; + let (solution, site_indices, _bonds) = create_mps_from_values(&[1.0, 2.0, 3.0, 4.0], phys_dim); + let (mpo, s_in_tmp, s_out_tmp) = create_mpo_with_internal_indices(&[2.0, 2.0], phys_dim); + let (input_mapping, output_mapping) = + create_fixed_site_index_mappings(["site0", "site1"], &site_indices, &s_in_tmp, &s_out_tmp); + let operator = LinearOperator::new(mpo, input_mapping, output_mapping); + let mut rhs = solution.clone(); + rhs.scale(AnyScalar::new_real(7.0)).unwrap(); + + let residual = relative_linear_system_residual( + &operator, + &solution, + &rhs, + AnyScalar::new_real(3.0), + AnyScalar::new_real(1.0), + ApplyOptions::naive(), + ) + .unwrap(); + + assert!(residual < 1.0e-12, "residual={residual}"); +} + +/// Mapped local linsolve should allow operator nodes with no site indices. +/// +/// This is the generic tensor-network form of an operator acting on selected +/// sites while later state sites are carried as spectators. 
+#[test] +fn test_square_linsolve_with_mappings_allows_unmapped_spectator_nodes() { + use tensor4all_treetn::square_linsolve; + + let phys_dim = 2; + let (rhs, site_indices, _bonds) = create_simple_mps_chain(); + let (mpo, s_in_tmp, s_out_tmp) = create_mapped_identity_mpo_with_spectator_node(phys_dim); + + let (input_mapping, output_mapping) = create_fixed_site_index_mappings( + ["site0", "site1"], + &site_indices[..2], + &s_in_tmp, + &s_out_tmp, + ); + assert!(mpo + .site_space(&"site2") + .is_some_and(|space| space.is_empty())); + + let options = LinsolveOptions::default() + .with_nfullsweeps(2) + .with_gmres_tol(1e-10) + .with_gmres_restart_dim(10) + .with_gmres_max_restarts(30) + .with_max_rank(8); + let result = square_linsolve( + &mpo, + &rhs, + rhs.clone(), + &"site0", + options, + Some(input_mapping), + Some(output_mapping), + ) + .unwrap(); + + assert_eq!(result.solution.node_count(), 3); + let got = result + .solution + .contract_to_tensor() + .unwrap() + .permuteinds(&site_indices) + .unwrap() + .to_vec::() + .unwrap(); + let expected = rhs + .contract_to_tensor() + .unwrap() + .permuteinds(&site_indices) + .unwrap() + .to_vec::() + .unwrap(); + let diff_norm: f64 = got + .iter() + .zip(expected.iter()) + .map(|(&got, &expected)| (got - expected).powi(2)) + .sum::() + .sqrt(); + let expected_norm: f64 = expected + .iter() + .map(|&value| value.powi(2)) + .sum::() + .sqrt(); + assert!( + diff_norm / expected_norm < 1e-8, + "spectator-node linsolve relative error = {}", + diff_norm / expected_norm + ); +} + +/// The mapped solver must preserve the identity-term-only linear system. 
+#[test] +fn test_square_linsolve_with_mappings_identity_term_only() { + use tensor4all_treetn::square_linsolve; + + let phys_dim = 2; + let (rhs, site_indices, _bonds) = create_mps_from_values(&[1.0, 2.0, 3.0, 4.0], phys_dim); + let mut init = rhs.clone(); + init.scale(AnyScalar::new_real(1.0 + 1.0e-6)).unwrap(); + let (mpo, s_in_tmp, s_out_tmp) = create_mpo_with_internal_indices(&[0.0, 0.0], phys_dim); + let (input_mapping, output_mapping) = + create_fixed_site_index_mappings(["site0", "site1"], &site_indices, &s_in_tmp, &s_out_tmp); + + let options = LinsolveOptions::default() + .with_nfullsweeps(3) + .with_coefficients(1.0, 1.0) + .with_gmres_tol(1e-12) + .with_gmres_restart_dim(10) + .with_gmres_max_restarts(30) + .with_max_rank(4) + .with_convergence_tol(1e-8); + let result = square_linsolve( + &mpo, + &rhs, + init, + &"site0", + options, + Some(input_mapping), + Some(output_mapping), + ) + .unwrap(); + + assert_eq!(result.sweeps, 0); + let got = result + .solution + .contract_to_tensor() + .unwrap() + .permuteinds(&site_indices) + .unwrap() + .to_vec::() + .unwrap(); + let expected = rhs + .contract_to_tensor() + .unwrap() + .permuteinds(&site_indices) + .unwrap() + .to_vec::() + .unwrap(); + let diff_norm: f64 = got + .iter() + .zip(expected.iter()) + .map(|(&got, &expected)| (got - expected).powi(2)) + .sum::() + .sqrt(); + let expected_norm: f64 = expected + .iter() + .map(|&value| value.powi(2)) + .sum::() + .sqrt(); + assert!( + diff_norm / expected_norm < 1e-8, + "identity-term-only linsolve relative error = {}", + diff_norm / expected_norm + ); + assert!(result.converged); + assert!(result.residual.is_some_and(|residual| residual < 1.0e-8)); +} + /// Test that square_linsolve still works without mappings (backward compat). 
#[test] fn test_square_linsolve_no_mappings_shared_indices() { diff --git a/crates/tensor4all-treetn/tests/linsolve_mpo_xb.rs b/crates/tensor4all-treetn/tests/linsolve_mpo_xb.rs index 9680e659..2eeda0e6 100644 --- a/crates/tensor4all-treetn/tests/linsolve_mpo_xb.rs +++ b/crates/tensor4all-treetn/tests/linsolve_mpo_xb.rs @@ -107,9 +107,9 @@ fn test_linsolve_allows_two_site_indices_per_node_for_rhs_alignment() -> anyhow: let options = LinsolveOptions::default() .with_nfullsweeps(1) - .with_krylov_tol(1e-8) - .with_krylov_maxiter(10) - .with_krylov_dim(10) + .with_gmres_tol(1e-8) + .with_gmres_max_restarts(10) + .with_gmres_restart_dim(10) .with_max_rank(4) .with_coefficients(0.0, 1.0); @@ -154,9 +154,9 @@ fn test_linsolve_precheck_fails_when_init_rhs_index_structure_mismatch() { let options = LinsolveOptions::default() .with_nfullsweeps(1) - .with_krylov_tol(1e-8) - .with_krylov_maxiter(10) - .with_krylov_dim(10) + .with_gmres_tol(1e-8) + .with_gmres_max_restarts(10) + .with_gmres_restart_dim(10) .with_max_rank(4) .with_coefficients(0.0, 1.0); diff --git a/crates/tensor4all-treetn/tests/ops.rs b/crates/tensor4all-treetn/tests/ops.rs index 2a3668e1..0594d107 100644 --- a/crates/tensor4all-treetn/tests/ops.rs +++ b/crates/tensor4all-treetn/tests/ops.rs @@ -2,7 +2,7 @@ use num_complex::Complex64; use tensor4all_core::{ - AnyScalar, ColMajorArrayRef, DynIndex, IndexLike, TensorDynLen, TensorIndex, TensorLike, + AnyScalar, ColMajorArrayRef, DynIndex, IndexLike, TensorDynLen, TensorIndex, }; use tensor4all_treetn::TreeTN; diff --git a/crates/tensor4all-treetn/tests/simplett_bridge.rs b/crates/tensor4all-treetn/tests/simplett_bridge.rs index 3c601bd5..f9d7b115 100644 --- a/crates/tensor4all-treetn/tests/simplett_bridge.rs +++ b/crates/tensor4all-treetn/tests/simplett_bridge.rs @@ -3,8 +3,10 @@ use num_complex::Complex64; use tensor4all_core::{DynIndex, IndexLike, TensorDynLen, TensorIndex}; use tensor4all_simplett::{tensor3_from_data, AbstractTensorTrain, TensorTrain}; use 
tensor4all_treetn::{ + fix_and_remove_site_from_treetn_chain, insert_onehot_site_in_treetn_chain, tensor_train_to_treetn, tensor_train_to_treetn_with_names, - tensor_train_to_treetn_with_names_and_site_indices, treetn_to_tensor_train, TreeTN, + tensor_train_to_treetn_with_names_and_site_indices, treetn_to_tensor_train, + weighted_remove_site_from_treetn_chain, TreeTN, }; fn two_site_tensor_train_f64() -> TensorTrain { @@ -70,6 +72,86 @@ fn two_site_tensor_train_c64() -> TensorTrain { .expect("valid complex two-site tensor train") } +fn dense_offset(indices: &[usize], dims: &[usize]) -> usize { + let mut stride = 1usize; + let mut offset = 0usize; + for (&index, &dim) in indices.iter().zip(dims) { + offset += index * stride; + stride *= dim; + } + offset +} + +fn decode_col_major(mut offset: usize, dims: &[usize]) -> Vec { + dims.iter() + .map(|&dim| { + let index = offset % dim; + offset /= dim; + index + }) + .collect() +} + +fn fixed_removed_dense( + values: &[f64], + dims: &[usize], + position: usize, + fixed_value: usize, +) -> Vec { + let mut new_dims = dims.to_vec(); + new_dims.remove(position); + let total = new_dims.iter().product(); + let mut result = Vec::with_capacity(total); + + for offset in 0..total { + let new_multi = decode_col_major(offset, &new_dims); + let mut old_multi = Vec::with_capacity(dims.len()); + let mut next_new = 0usize; + for site in 0..dims.len() { + if site == position { + old_multi.push(fixed_value); + } else { + old_multi.push(new_multi[next_new]); + next_new += 1; + } + } + result.push(values[dense_offset(&old_multi, dims)]); + } + result +} + +fn weighted_removed_dense( + values: &[f64], + dims: &[usize], + position: usize, + weights: &[f64], +) -> Vec { + let mut new_dims = dims.to_vec(); + new_dims.remove(position); + let total = new_dims.iter().product(); + let mut result = Vec::with_capacity(total); + + for offset in 0..total { + let new_multi = decode_col_major(offset, &new_dims); + let mut value = 0.0; + for 
(removed_value, weight) in weights.iter().enumerate() { + let mut old_multi = Vec::with_capacity(dims.len()); + let mut next_new = 0usize; + for site in 0..dims.len() { + if site == position { + old_multi.push(removed_value); + } else { + old_multi.push(new_multi[next_new]); + next_new += 1; + } + } + value += weight * values[dense_offset(&old_multi, dims)]; + } + result.push(value); + } + result +} + #[test] fn tensor_train_to_treetn_preserves_dense_values() -> Result<()> { let tt = two_site_tensor_train_f64(); @@ -226,6 +308,290 @@ fn treetn_to_tensor_train_rejects_ad_tracked_site_tensor() -> Result<()> { Ok(()) } +#[test] +fn insert_onehot_site_in_treetn_chain_prepends_fixed_site() -> Result<()> { + let tt = two_site_tensor_train_f64(); + let (treetn, old_sites) = tensor_train_to_treetn(&tt)?; + let inserted_site = DynIndex::new_dyn(2); + + let result = insert_onehot_site_in_treetn_chain::(treetn, 0, inserted_site.clone(), 0)?; + let dense = result.contract_to_tensor()?; + let (old_values, _) = tt.fulltensor(); + let expected = TensorDynLen::from_dense( + vec![inserted_site, old_sites[0].clone(), old_sites[1].clone()], + vec![ + old_values[0], + 0.0, + old_values[1], + 0.0, + old_values[2], + 0.0, + old_values[3], + 0.0, + ], + )?; + + assert_eq!(result.node_names(), vec![0, 1, 2]); + assert!(dense.distance(&expected).unwrap() < 1.0e-12); + Ok(()) +} + +#[test] +fn insert_onehot_site_in_treetn_chain_preserves_edge_bond_flow() -> Result<()> { + let tt = two_site_tensor_train_f64(); + let (treetn, old_sites) = tensor_train_to_treetn(&tt)?; + let inserted_site = DynIndex::new_dyn(2); + + let result = insert_onehot_site_in_treetn_chain::(treetn, 1, inserted_site.clone(), 1)?; + let dense = result.contract_to_tensor()?; + let (old_values, _) = tt.fulltensor(); + let expected = TensorDynLen::from_dense( + vec![old_sites[0].clone(), inserted_site, old_sites[1].clone()], + vec![ + 0.0, + 0.0, + old_values[0], + old_values[1], + 0.0, + 0.0, + old_values[2], + 
old_values[3], + ], + )?; + + assert_eq!(result.node_names(), vec![0, 1, 2]); + assert_eq!( + treetn_to_tensor_train::(result.clone())?.site_dims(), + vec![2, 2, 2] + ); + assert!(dense.distance(&expected).unwrap() < 1.0e-12); + Ok(()) +} + +#[test] +fn insert_onehot_site_in_treetn_chain_appends_fixed_site() -> Result<()> { + let tt = two_site_tensor_train_f64(); + let (treetn, old_sites) = tensor_train_to_treetn(&tt)?; + let inserted_site = DynIndex::new_dyn(2); + + let result = insert_onehot_site_in_treetn_chain::(treetn, 2, inserted_site.clone(), 1)?; + let dense = result.contract_to_tensor()?; + let (old_values, _) = tt.fulltensor(); + let expected = TensorDynLen::from_dense( + vec![old_sites[0].clone(), old_sites[1].clone(), inserted_site], + vec![ + 0.0, + 0.0, + 0.0, + 0.0, + old_values[0], + old_values[1], + old_values[2], + old_values[3], + ], + )?; + + assert_eq!(result.node_names(), vec![0, 1, 2]); + assert!(dense.distance(&expected).unwrap() < 1.0e-12); + Ok(()) +} + +#[test] +fn insert_onehot_site_in_treetn_chain_rejects_invalid_position() -> Result<()> { + let tt = two_site_tensor_train_f64(); + let (treetn, _) = tensor_train_to_treetn(&tt)?; + + let err = + insert_onehot_site_in_treetn_chain::(treetn, 3, DynIndex::new_dyn(2), 0).unwrap_err(); + + assert!(err.to_string().contains("position 3 is out of range 0..=2")); + Ok(()) +} + +#[test] +fn insert_onehot_site_in_treetn_chain_rejects_invalid_fixed_value() -> Result<()> { + let tt = two_site_tensor_train_f64(); + let (treetn, _) = tensor_train_to_treetn(&tt)?; + + let err = + insert_onehot_site_in_treetn_chain::(treetn, 0, DynIndex::new_dyn(2), 2).unwrap_err(); + + assert!(err + .to_string() + .contains("fixed value 2 exceeds site dimension 2")); + Ok(()) +} + +#[test] +fn fix_and_remove_site_from_treetn_chain_removes_first_site() -> Result<()> { + let tt = three_site_tensor_train_f64(); + let (treetn, site_indices) = tensor_train_to_treetn(&tt)?; + + let result = 
fix_and_remove_site_from_treetn_chain::(treetn, 0, 1)?; + let dense = result.contract_to_tensor()?; + let (old_values, old_dims) = tt.fulltensor(); + let expected = TensorDynLen::from_dense( + vec![site_indices[1].clone(), site_indices[2].clone()], + fixed_removed_dense(&old_values, &old_dims, 0, 1), + )?; + + assert_eq!(result.node_names(), vec![0, 1]); + assert_eq!( + treetn_to_tensor_train::(result.clone())?.site_dims(), + vec![2, 2] + ); + assert!(dense.distance(&expected).unwrap() < 1.0e-12); + Ok(()) +} + +#[test] +fn fix_and_remove_site_from_treetn_chain_removes_middle_site() -> Result<()> { + let tt = three_site_tensor_train_f64(); + let (treetn, site_indices) = tensor_train_to_treetn(&tt)?; + + let result = fix_and_remove_site_from_treetn_chain::(treetn, 1, 0)?; + let dense = result.contract_to_tensor()?; + let (old_values, old_dims) = tt.fulltensor(); + let expected = TensorDynLen::from_dense( + vec![site_indices[0].clone(), site_indices[2].clone()], + fixed_removed_dense(&old_values, &old_dims, 1, 0), + )?; + + assert_eq!(result.node_names(), vec![0, 1]); + assert_eq!( + treetn_to_tensor_train::(result.clone())?.site_dims(), + vec![2, 2] + ); + assert!(dense.distance(&expected).unwrap() < 1.0e-12); + Ok(()) +} + +#[test] +fn fix_and_remove_site_from_treetn_chain_removes_last_site() -> Result<()> { + let tt = three_site_tensor_train_f64(); + let (treetn, site_indices) = tensor_train_to_treetn(&tt)?; + + let result = fix_and_remove_site_from_treetn_chain::(treetn, 2, 1)?; + let dense = result.contract_to_tensor()?; + let (old_values, old_dims) = tt.fulltensor(); + let expected = TensorDynLen::from_dense( + vec![site_indices[0].clone(), site_indices[1].clone()], + fixed_removed_dense(&old_values, &old_dims, 2, 1), + )?; + + assert_eq!(result.node_names(), vec![0, 1]); + assert_eq!( + treetn_to_tensor_train::(result.clone())?.site_dims(), + vec![2, 2] + ); + assert!(dense.distance(&expected).unwrap() < 1.0e-12); + Ok(()) +} + +#[test] +fn 
weighted_remove_site_from_treetn_chain_removes_middle_site() -> Result<()> { + let tt = three_site_tensor_train_f64(); + let (treetn, site_indices) = tensor_train_to_treetn(&tt)?; + let weights = [0.25, 0.75]; + + let result = weighted_remove_site_from_treetn_chain::(treetn, 1, &weights)?; + let dense = result.contract_to_tensor()?; + let (old_values, old_dims) = tt.fulltensor(); + let expected = TensorDynLen::from_dense( + vec![site_indices[0].clone(), site_indices[2].clone()], + weighted_removed_dense(&old_values, &old_dims, 1, &weights), + )?; + + assert_eq!(result.node_names(), vec![0, 1]); + assert_eq!( + treetn_to_tensor_train::(result.clone())?.site_dims(), + vec![2, 2] + ); + assert!(dense.distance(&expected).unwrap() < 1.0e-12); + Ok(()) +} + +#[test] +fn weighted_remove_site_from_treetn_chain_removes_boundary_site() -> Result<()> { + let tt = three_site_tensor_train_f64(); + let (treetn, site_indices) = tensor_train_to_treetn(&tt)?; + let weights = [1.5, -0.5]; + + let result = weighted_remove_site_from_treetn_chain::(treetn, 2, &weights)?; + let dense = result.contract_to_tensor()?; + let (old_values, old_dims) = tt.fulltensor(); + let expected = TensorDynLen::from_dense( + vec![site_indices[0].clone(), site_indices[1].clone()], + weighted_removed_dense(&old_values, &old_dims, 2, &weights), + )?; + + assert_eq!(result.node_names(), vec![0, 1]); + assert_eq!( + treetn_to_tensor_train::(result.clone())?.site_dims(), + vec![2, 2] + ); + assert!(dense.distance(&expected).unwrap() < 1.0e-12); + Ok(()) +} + +#[test] +fn fix_and_remove_site_from_treetn_chain_rejects_invalid_position() -> Result<()> { + let tt = two_site_tensor_train_f64(); + let (treetn, _) = tensor_train_to_treetn(&tt)?; + + let err = fix_and_remove_site_from_treetn_chain::(treetn, 2, 0).unwrap_err(); + + assert!(err.to_string().contains("position 2 is out of range 0..2")); + Ok(()) +} + +#[test] +fn fix_and_remove_site_from_treetn_chain_rejects_invalid_fixed_value() -> Result<()> { + let tt = 
two_site_tensor_train_f64(); + let (treetn, _) = tensor_train_to_treetn(&tt)?; + + let err = fix_and_remove_site_from_treetn_chain::(treetn, 0, 2).unwrap_err(); + + assert!(err + .to_string() + .contains("fixed value 2 exceeds site dimension 2")); + Ok(()) +} + +#[test] +fn weighted_remove_site_from_treetn_chain_rejects_weight_length_mismatch() -> Result<()> { + let tt = two_site_tensor_train_f64(); + let (treetn, _) = tensor_train_to_treetn(&tt)?; + + let err = weighted_remove_site_from_treetn_chain::(treetn, 1, &[1.0]).unwrap_err(); + + assert!(err + .to_string() + .contains("weights length 1 must match site dimension 2")); + Ok(()) +} + +#[test] +fn remove_site_from_treetn_chain_rejects_single_site_chain() -> Result<()> { + let tt = single_site_tensor_train_f64(); + let (treetn, _) = tensor_train_to_treetn(&tt)?; + + let err = fix_and_remove_site_from_treetn_chain::(treetn, 0, 1).unwrap_err(); + + assert!(err.to_string().contains( + "cannot remove the only site because scalar zero-site TreeTN chains are not supported" + )); + + let (treetn, _) = tensor_train_to_treetn(&tt)?; + let err = + weighted_remove_site_from_treetn_chain::(treetn, 0, &[1.0, 1.0, 1.0]).unwrap_err(); + + assert!(err.to_string().contains( + "cannot remove the only site because scalar zero-site TreeTN chains are not supported" + )); + Ok(()) +} + #[test] fn tensor_train_to_treetn_with_names_rejects_length_mismatch() { let tt = two_site_tensor_train_f64(); diff --git a/docs/book/src/guides/tensor-basics.md b/docs/book/src/guides/tensor-basics.md index 9aaef984..b313a02d 100644 --- a/docs/book/src/guides/tensor-basics.md +++ b/docs/book/src/guides/tensor-basics.md @@ -112,7 +112,7 @@ Think of it as a generalization of matrix multiplication. ### Pairwise contraction ```rust -use tensor4all_core::{TensorDynLen, Index}; +use tensor4all_core::{TensorDynLen, Index, contract}; use tensor4all_core::index::DynId; // A[i,j] and B[j,k] — contracting over j gives C[i,k]. 
@@ -123,18 +123,20 @@ let k = Index::new_dyn(4); let a = TensorDynLen::zeros::(vec![i.clone(), j.clone()]).unwrap(); let b = TensorDynLen::zeros::(vec![j.clone(), k.clone()]).unwrap(); -let c = a.contract(&b).unwrap(); // or equivalently: &a * &b +let c = contract(&[&a, &b]).unwrap(); assert_eq!(c.dims(), vec![2, 4]); // j is summed away ``` ### Multi-tensor contraction -`contract_multi` contracts a list of tensors, handling disconnected components -via outer products. `contract_connected` is the same but returns an error if the -contraction graph is disconnected. +`contract` contracts a connected list of tensors. Disconnected inputs are an +error; use `outer_product` explicitly when a tensor product of disconnected +pieces is intended. ```rust -use tensor4all_core::{TensorDynLen, Index, contract_multi, contract_connected, AllowedPairs}; +use tensor4all_core::{ + TensorDynLen, Index, contract, outer_product, +}; use tensor4all_core::index::DynId; let i = Index::new_dyn(2); @@ -154,21 +156,14 @@ let c: TensorDynLen = TensorDynLen::random::(&mut rng, vec![k.clone(), l.clone()]).unwrap(); // Contract A(i,j) * B(j,k) * C(k,l) -> result(i,l) -let result = contract_multi(&[&a, &b, &c], AllowedPairs::All).unwrap(); +let result = contract(&[&a, &b, &c]).unwrap(); assert_eq!(result.dims().iter().product::(), 2 * 5); // i * l -// Restrict which tensor pairs may contract (useful for tree tensor networks). -// Here only (A,B) and (B,C) are connected, so j and k are contracted. -let pairs = [(0usize, 1usize), (1, 2)]; -let result2 = contract_multi(&[&a, &b, &c], AllowedPairs::Specified(&pairs)).unwrap(); -assert_eq!(result2.dims().iter().product::(), 2 * 5); +// Disconnected products are explicit. +let product = outer_product(&a, &c).unwrap(); +assert_eq!(product.dims().iter().product::(), 2 * 3 * 4 * 5); ``` -`AllowedPairs::All` contracts all tensor pairs with matching indices. 
-`AllowedPairs::Specified` takes a slice of `(usize, usize)` tensor-index pairs -and only contracts between those pairs — useful when the connectivity is known -(e.g. tree tensor networks). - ## Factorization The unified `factorize()` function dispatches to SVD, QR, LU, or CI based on diff --git a/docs/tutorial-code/src/qtt_partial_fourier2d_common.rs b/docs/tutorial-code/src/qtt_partial_fourier2d_common.rs index b2df0255..dfca9df2 100644 --- a/docs/tutorial-code/src/qtt_partial_fourier2d_common.rs +++ b/docs/tutorial-code/src/qtt_partial_fourier2d_common.rs @@ -13,7 +13,7 @@ use std::path::Path; use num_complex::Complex64; use tensor4all_core::index::{DynId, Index, TagSet}; -use tensor4all_core::{ColMajorArrayRef, IndexLike, TensorDynLen, TensorLike}; +use tensor4all_core::{outer_product, ColMajorArrayRef, IndexLike, TensorDynLen}; use tensor4all_quanticstci::{ quanticscrossinterpolate, DiscretizedGrid, QtciOptions, QuanticsTensorCI2, UnfoldingScheme, }; @@ -360,7 +360,7 @@ fn expand_operator_to_interleaved_state( let tensor = tensors_by_node .get_mut(&mid) .ok_or_else(|| format!("missing tensor at expanded node {mid}"))?; - *tensor = tensor.outer_product(&bridge)?; + *tensor = outer_product(tensor, &bridge)?; } } @@ -374,13 +374,13 @@ fn expand_operator_to_interleaved_state( let tensor = tensors_by_node .get_mut(&last_x) .ok_or_else(|| format!("missing tensor at final x node {last_x}"))?; - *tensor = tensor.outer_product(&left_ones)?; + *tensor = outer_product(tensor, &left_ones)?; } { let tensor = tensors_by_node .get_mut(&last_t) .ok_or_else(|| format!("missing tensor at final t node {last_t}"))?; - *tensor = tensor.outer_product(&right_ones)?; + *tensor = outer_product(tensor, &right_ones)?; } } diff --git a/plan/tensor4all-core-api-cleanup.md b/plan/tensor4all-core-api-cleanup.md new file mode 100644 index 00000000..820adbe8 --- /dev/null +++ b/plan/tensor4all-core-api-cleanup.md @@ -0,0 +1,128 @@ +# tensor4all-core API cleanup notes + +## N-ary contraction 
API + +Current direction: + +- Prefer one public N-ary contraction entry point over separate binary and + N-ary APIs. +- Remove public `contract_pair`-style convenience APIs. Binary contraction + should be expressed as N-ary contraction with two operands. +- Prefer borrowed operands, i.e. a slice of tensor references, for the default + API. This avoids forcing callers to move or clone tensors just to contract + existing values. +- Keep an owned variant only as an explicit optimization path for callers that + can transfer ownership. +- The public API should accept structural labels directly instead of building a + string equation and parsing it again downstream. + +Open naming sketch: + +- `contract(&[&a, &b, &c])` +- `contract_owned(vec![a, b, c])` + +Implemented cleanup steps: + +- Added `contract(&[&TensorDynLen])` as the connected-network default entry. +- Added `contract_owned(Vec<TensorDynLen>)` with the same connected-network + semantics. +- Added `contract_with_options` and `contract_owned_with_options` as the + connected-network advanced entries while retained-index users are migrated to + TreeTN-level APIs. +- Removed the legacy `contract_multi*` names and the temporary + `contract_components_and_outer_product*` helpers. +- Simplified `TensorContractionLike::contract` to the connected default + signature `contract(&[&Self])`. Tensor-edge restrictions now live only in + concrete `ContractionOptions` APIs. + +- `tenferro` now has a canonical `EinsumSubscripts { inputs, output }` payload + for `StdTensorOp::NaryEinsum`. +- String APIs in `tenferro` remain compatibility wrappers that parse once into + integer labels. +- `tensor4all-tensorbackend` native einsum and `tensor4all-core` AD-backed + eager contractions now pass integer subscripts to `tenferro`; strings are + retained only for human-readable diagnostics and path reports.
+ +## Retained indices + +Current direction: + +- Remove the retained-index feature from the public contraction API unless a + concrete production use case appears. +- Existing uses are mostly tests and C API plumbing. The feature complicates + graph connectivity, AD behavior, and user-facing semantics. +- Site-index-aware partial contraction should remain a TreeTN-level concept + (`contract_pairs`, `diagonal_pairs`, etc.), not a dense tensor contraction + retained-index feature. + +## Allowed tensor edges + +Current direction: + +- `AllowedPairs` has been removed from `tensor4all-core`. +- Normal public contraction now always considers all tensor pairs and contracts + matching contractable indices across a connected tensor graph. +- Tensor-edge restrictions should live at the TreeTN/topology layer, where the + graph is explicit, instead of leaking into dense `TensorDynLen` contraction. + +## Connected vs disconnected contraction + +Current direction: + +- Prefer connected contraction as the normal public semantic. A default + contraction over disconnected components should error because silently + returning an outer product can hide missing links or index bugs. +- Provide an explicit outer-product/combine API for intentional disconnected + products. +- `contract_multi*` and `contract_components_and_outer_product*` have been + removed from the Rust core API. Callers now use `contract` / + `contract_with_options` for connected networks and spell disconnected + products explicitly with `outer_product`. +- The C API multi-tensor retained-index entry was renamed from + `t4a_tensor_contract_multi` to `t4a_tensor_contract_many_retain` so the + legacy name does not leak across the boundary. + +## Outer product + +Current findings: + +- `outer_product` is not part of the core numerical contraction path. 
Current + uses are mostly shape/topology construction helpers: + - explicit disconnected products after contracting connected components + - default multi-index `delta` construction from pairwise diagonals + - TreeTN dummy links, ones tensors, bridge deltas, and trivial factorization + boundary cases +- Structured tensor storage likely makes several of these uses unnecessary. + For example, multi-index delta/copy tensors and unit-valued dummy + axes can be represented directly instead of multiplying by separate tensors. +- Do not keep `outer_product` as a required `TensorContractionLike` method in + the long-term public API unless a concrete generic use case remains. + +Possible replacements: + +- direct constructors for structured multi-index delta/copy tensors +- an explicit helper to attach unit-valued dummy axes to a tensor +- a clearly named low-level tensor product helper only where intentional + +## End-of-session visibility audit + +Completed: + +- `ContractionSpec`, `ContractionError`, `prepare_contraction`, and + `prepare_contraction_pairs` were removed from the top-level public re-exports + and demoted to `pub(crate)`. +- Direct planning tests that need those helpers moved from integration tests to + `index_ops` unit tests. Public integration tests now cover the user-facing + index APIs only. +- The unused `result_dims` field was removed from `ContractionSpec`; result + shape remains derived from the result indices, avoiding duplicate metadata. + +Remaining public-surface cleanup candidates: + +- `TensorContractionLike::outer_product` is still public because TreeTN and + construction helpers currently use it. Long term, replace those use cases with + explicit structured constructors or dummy-axis helpers before removing it from + the trait. +- `contract_pair` remains as a compatibility API in several examples/tests. 
The + intended public direction is still `contract(&[...])` for connected + contractions plus explicit `outer_product` for disconnected products.