From 73450a5cf5f3e66ce033fe30d21306b99b61e09c Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Sat, 21 Feb 2026 10:50:53 -0500 Subject: [PATCH] Add VF64-specialized LinearCache for shorter stack traces Implement the VF64 pattern for LinearCache and DefaultLinearSolverInit to reduce type string lengths in stack traces from ~1500 chars to ~500 chars per instance. This addresses the LinearSolve.jl component of DifferentialEquations.jl#1128. New types: - LinearCacheVF64{Tp, Tc, Tlv, S}: 4 type params vs 12 in LinearCache, hardcodes A::Matrix{Float64}, b/u::Vector{Float64}, Pl/Pr::IdentityOperator, abstol/reltol::Float64, assumptions::OperatorAssumptions{Bool} - DefaultLinearSolverInitVF64{TA}: 1 type param vs 25 in DefaultLinearSolverInit, hardcodes all 24 factorization cache slot types for Matrix{Float64} - LinearCacheType = Union{LinearCache, LinearCacheVF64} for dispatch The VF64 cache is automatically constructed when init() detects Matrix{Float64} + Vector{Float64} + DefaultLinearSolver. All existing dispatch is preserved through the Union type. Co-Authored-By: Chris Rackauckas Co-Authored-By: Claude Opus 4.6 --- ext/LinearSolveAMDGPUExt.jl | 4 +- ext/LinearSolveAlgebraicMultigridExt.jl | 4 +- ext/LinearSolveBLISExt.jl | 4 +- ext/LinearSolveCUDAExt.jl | 8 +- ext/LinearSolveCUSOLVERRFExt.jl | 2 +- ext/LinearSolveCliqueTreesExt.jl | 2 +- ext/LinearSolveFastLapackInterfaceExt.jl | 4 +- ext/LinearSolveHYPREExt.jl | 6 +- ext/LinearSolveIterativeSolversExt.jl | 4 +- ext/LinearSolveKrylovKitExt.jl | 4 +- ext/LinearSolveMetalExt.jl | 6 +- ext/LinearSolvePETScExt.jl | 4 +- ext/LinearSolvePardisoExt.jl | 2 +- ext/LinearSolveRecursiveFactorizationExt.jl | 6 +- ext/LinearSolveSparseArraysExt.jl | 6 +- ext/LinearSolveSparspakExt.jl | 2 +- src/LinearSolve.jl | 6 + src/appleaccelerate.jl | 4 +- src/common.jl | 63 ++++++- src/default.jl | 16 +- src/factorization.jl | 19 ++- src/iterative_wrappers.jl | 2 +- src/mkl.jl | 4 +- src/openblas.jl | 4 +- src/simplegmres.jl | 8 +- src/simplelu.jl | 2 +- src/solve_function.jl | 6 +- src/vf64_types.jl | 178 ++++++++++++++++++++ 28 files changed, 309 insertions(+), 71 deletions(-) create mode 100644 src/vf64_types.jl diff --git a/ext/LinearSolveAMDGPUExt.jl b/ext/LinearSolveAMDGPUExt.jl index 7826f33c9..1827411ca 100644 --- a/ext/LinearSolveAMDGPUExt.jl +++ b/ext/LinearSolveAMDGPUExt.jl @@ -7,7 +7,7 @@ using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase # LU Factorization function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::AMDGPUOffloadLUFactorization; + cache::LinearSolve.LinearCacheType, alg::AMDGPUOffloadLUFactorization; kwargs... ) if cache.isfresh @@ -36,7 +36,7 @@ end # QR Factorization function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::AMDGPUOffloadQRFactorization; + cache::LinearSolve.LinearCacheType, alg::AMDGPUOffloadQRFactorization; kwargs... ) if cache.isfresh diff --git a/ext/LinearSolveAlgebraicMultigridExt.jl b/ext/LinearSolveAlgebraicMultigridExt.jl index 05d607456..4444d380e 100644 --- a/ext/LinearSolveAlgebraicMultigridExt.jl +++ b/ext/LinearSolveAlgebraicMultigridExt.jl @@ -1,7 +1,7 @@ module LinearSolveAlgebraicMultigridExt using LinearSolve, AlgebraicMultigrid, LinearAlgebra -using LinearSolve: LinearCache, LinearVerbosity, OperatorAssumptions +using LinearSolve: LinearCache, LinearCacheType, LinearVerbosity, OperatorAssumptions using SciMLBase: SciMLBase, ReturnCode function LinearSolve.init_cacheval( @@ -19,7 +19,7 @@ function LinearSolve.init_cacheval( return SciMLBase.init(amg_alg, A, b; alg.kwargs...) end -function SciMLBase.solve!(cache::LinearCache, alg::AlgebraicMultigridJL; kwargs...) +function SciMLBase.solve!(cache::LinearCacheType, alg::AlgebraicMultigridJL; kwargs...) if cache.isfresh cache.cacheval = LinearSolve.init_cacheval( alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, diff --git a/ext/LinearSolveBLISExt.jl b/ext/LinearSolveBLISExt.jl index a8a4a16c2..6291c109c 100644 --- a/ext/LinearSolveBLISExt.jl +++ b/ext/LinearSolveBLISExt.jl @@ -9,7 +9,7 @@ using LinearSolve using LinearAlgebra: BlasInt, LU using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, @blasfunc, chkargsok -using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval, LinearCache, SciMLBase, LinearVerbosity, get_blas_operation_info, blas_info_msg +using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval, LinearCache, LinearCacheType, SciMLBase, LinearVerbosity, get_blas_operation_info, blas_info_msg using SciMLLogging: SciMLLogging, @SciMLMessage using SciMLBase: ReturnCode @@ -272,7 +272,7 @@ function LinearSolve.init_cacheval( end function SciMLBase.solve!( - cache::LinearCache, alg::BLISLUFactorization; + cache::LinearCacheType, alg::BLISLUFactorization; kwargs... ) A = cache.A diff --git a/ext/LinearSolveCUDAExt.jl b/ext/LinearSolveCUDAExt.jl index 3c445777a..b936e7613 100644 --- a/ext/LinearSolveCUDAExt.jl +++ b/ext/LinearSolveCUDAExt.jl @@ -58,7 +58,7 @@ function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR) end function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::CudaOffloadLUFactorization; + cache::LinearSolve.LinearCacheType, alg::CudaOffloadLUFactorization; kwargs... ) if cache.isfresh @@ -92,7 +92,7 @@ function LinearSolve.init_cacheval( end function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::CudaOffloadQRFactorization; + cache::LinearSolve.LinearCacheType, alg::CudaOffloadQRFactorization; kwargs... ) if cache.isfresh @@ -120,7 +120,7 @@ end # Keep the deprecated CudaOffloadFactorization working by forwarding to QR function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization; + cache::LinearSolve.LinearCacheType, alg::CudaOffloadFactorization; kwargs... ) if cache.isfresh @@ -164,7 +164,7 @@ end # Mixed precision CUDA LU implementation function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization; + cache::LinearSolve.LinearCacheType, alg::CUDAOffload32MixedLUFactorization; kwargs... ) if cache.isfresh diff --git a/ext/LinearSolveCUSOLVERRFExt.jl b/ext/LinearSolveCUSOLVERRFExt.jl index 582312c2e..0d01c6662 100644 --- a/ext/LinearSolveCUSOLVERRFExt.jl +++ b/ext/LinearSolveCUSOLVERRFExt.jl @@ -31,7 +31,7 @@ function LinearSolve.init_cacheval( return RFLU(A_gpu; nrhs = nrhs, symbolic = symbolic) end -function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::LinearSolve.CUSOLVERRFFactorization; kwargs...) +function SciMLBase.solve!(cache::LinearSolve.LinearCacheType, alg::LinearSolve.CUSOLVERRFFactorization; kwargs...) A = cache.A # Convert to appropriate GPU format if needed diff --git a/ext/LinearSolveCliqueTreesExt.jl b/ext/LinearSolveCliqueTreesExt.jl index e77d24fc5..081ea5c9b 100644 --- a/ext/LinearSolveCliqueTreesExt.jl +++ b/ext/LinearSolveCliqueTreesExt.jl @@ -26,7 +26,7 @@ function LinearSolve.init_cacheval( return makefactor(A, alg.alg, alg.snd) end -function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CliqueTreesFactorization; kwargs...) +function SciMLBase.solve!(cache::LinearSolve.LinearCacheType, alg::CliqueTreesFactorization; kwargs...) A = cache.A u = cache.u b = cache.b diff --git a/ext/LinearSolveFastLapackInterfaceExt.jl b/ext/LinearSolveFastLapackInterfaceExt.jl index c8b8680ff..de15c9b50 100644 --- a/ext/LinearSolveFastLapackInterfaceExt.jl +++ b/ext/LinearSolveFastLapackInterfaceExt.jl @@ -21,7 +21,7 @@ function LinearSolve.init_cacheval( end function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::FastLUFactorization; kwargs... + cache::LinearSolve.LinearCacheType, alg::FastLUFactorization; kwargs... ) A = cache.A A = convert(AbstractMatrix, A) @@ -78,7 +78,7 @@ function LinearSolve.init_cacheval( end function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::FastQRFactorization{P}; + cache::LinearSolve.LinearCacheType, alg::FastQRFactorization{P}; kwargs... ) where {P} A = cache.A diff --git a/ext/LinearSolveHYPREExt.jl b/ext/LinearSolveHYPREExt.jl index c95e86af6..cbbbb81f4 100644 --- a/ext/LinearSolveHYPREExt.jl +++ b/ext/LinearSolveHYPREExt.jl @@ -3,7 +3,7 @@ module LinearSolveHYPREExt using LinearAlgebra using HYPRE.LibHYPRE: HYPRE_Complex using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector -using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve, +using LinearSolve: HYPREAlgorithm, LinearCache, LinearCacheType, LinearProblem, LinearSolve, OperatorAssumptions, default_tol, init_cacheval, __issquare, __conditioning, LinearSolveAdjoint, LinearVerbosity using SciMLLogging: SciMLLogging, verbosity_to_int, @SciMLMessage @@ -176,7 +176,7 @@ create_solver(::Type{S}, comm) where {S <: COMM_SOLVERS} = S(comm) const NO_COMM_SOLVERS = Union{HYPRE.BoomerAMG, HYPRE.Hybrid, HYPRE.ILU} create_solver(::Type{S}, comm) where {S <: NO_COMM_SOLVERS} = S() -function create_solver(alg::HYPREAlgorithm, cache::LinearCache) +function create_solver(alg::HYPREAlgorithm, cache::LinearCacheType) # If the solver is already instantiated, return it directly if alg.solver isa HYPRE.HYPRESolver return alg.solver @@ -231,7 +231,7 @@ function create_solver(alg::HYPREAlgorithm, cache::LinearCache) end # TODO: How are args... and kwargs... supposed to be used here? -function SciMLBase.solve!(cache::LinearCache, alg::HYPREAlgorithm, args...; kwargs...) +function SciMLBase.solve!(cache::LinearCacheType, alg::HYPREAlgorithm, args...; kwargs...) # It is possible to reach here without HYPRE.Init() being called if HYPRE structures are # only to be created here internally (i.e. when cache.A::SparseMatrixCSC and not a # ::HYPREMatrix created externally by the user). Be nice to the user and call it :) diff --git a/ext/LinearSolveIterativeSolversExt.jl b/ext/LinearSolveIterativeSolversExt.jl index cafdc03d4..3dba7e8a8 100644 --- a/ext/LinearSolveIterativeSolversExt.jl +++ b/ext/LinearSolveIterativeSolversExt.jl @@ -1,7 +1,7 @@ module LinearSolveIterativeSolversExt using LinearSolve, LinearAlgebra -using LinearSolve: LinearCache, DEFAULT_PRECS, LinearVerbosity +using LinearSolve: LinearCache, LinearCacheType, DEFAULT_PRECS, LinearVerbosity import LinearSolve: IterativeSolversJL using SciMLLogging: SciMLLogging, @SciMLMessage @@ -132,7 +132,7 @@ function LinearSolve.init_cacheval( return iterable end -function SciMLBase.solve!(cache::LinearCache, alg::IterativeSolversJL; kwargs...) +function SciMLBase.solve!(cache::LinearCacheType, alg::IterativeSolversJL; kwargs...) if cache.precsisfresh && !isnothing(alg.precs) Pl, Pr = alg.precs(cache.Pl, cache.Pr) cache.Pl = Pl diff --git a/ext/LinearSolveKrylovKitExt.jl b/ext/LinearSolveKrylovKitExt.jl index 4e4b24998..c404fa5ef 100644 --- a/ext/LinearSolveKrylovKitExt.jl +++ b/ext/LinearSolveKrylovKitExt.jl @@ -1,7 +1,7 @@ module LinearSolveKrylovKitExt using LinearSolve, KrylovKit, LinearAlgebra -using LinearSolve: LinearCache, DEFAULT_PRECS +using LinearSolve: LinearCache, LinearCacheType, DEFAULT_PRECS using SciMLLogging: SciMLLogging, @SciMLMessage, verbosity_to_int function LinearSolve.KrylovKitJL( @@ -24,7 +24,7 @@ end LinearSolve.default_alias_A(::KrylovKitJL, ::Any, ::Any) = true LinearSolve.default_alias_b(::KrylovKitJL, ::Any, ::Any) = true -function SciMLBase.solve!(cache::LinearCache, alg::KrylovKitJL; kwargs...) +function SciMLBase.solve!(cache::LinearCacheType, alg::KrylovKitJL; kwargs...) atol = float(cache.abstol) rtol = float(cache.reltol) maxiter = cache.maxiters diff --git a/ext/LinearSolveMetalExt.jl b/ext/LinearSolveMetalExt.jl index 2a989e582..b7d82c81c 100644 --- a/ext/LinearSolveMetalExt.jl +++ b/ext/LinearSolveMetalExt.jl @@ -4,7 +4,7 @@ using Metal, LinearSolve using LinearAlgebra, SciMLBase using SciMLBase: AbstractSciMLOperator using LinearSolve: ArrayInterface, MKLLUFactorization, MetalOffload32MixedLUFactorization, - @get_cacheval, LinearCache, SciMLBase, OperatorAssumptions, LinearVerbosity + @get_cacheval, LinearCache, LinearCacheType, SciMLBase, OperatorAssumptions, LinearVerbosity @static if Sys.isapple() @@ -24,7 +24,7 @@ function LinearSolve.init_cacheval( end function SciMLBase.solve!( - cache::LinearCache, alg::MetalLUFactorization; + cache::LinearCacheType, alg::MetalLUFactorization; kwargs... ) A = cache.A @@ -63,7 +63,7 @@ function LinearSolve.init_cacheval( end function SciMLBase.solve!( - cache::LinearCache, alg::MetalOffload32MixedLUFactorization; + cache::LinearCacheType, alg::MetalOffload32MixedLUFactorization; kwargs... ) A = cache.A diff --git a/ext/LinearSolvePETScExt.jl b/ext/LinearSolvePETScExt.jl index a4505fae1..ae37af95e 100644 --- a/ext/LinearSolvePETScExt.jl +++ b/ext/LinearSolvePETScExt.jl @@ -5,7 +5,7 @@ using PETSc using PETSc: MPI using PETSc: petsclibs using SparseArrays: SparseMatrixCSC, sparse -using LinearSolve: PETScAlgorithm, LinearCache, LinearProblem, LinearSolve, +using LinearSolve: PETScAlgorithm, LinearCache, LinearCacheType, LinearProblem, LinearSolve, OperatorAssumptions, default_tol, init_cacheval, __issquare, __conditioning, LinearSolveAdjoint, LinearVerbosity using SciMLLogging: SciMLLogging, verbosity_to_int, @SciMLMessage @@ -98,7 +98,7 @@ function pc_type_string(pc_type::Symbol) return get(pc_types, pc_type, string(pc_type)) end -function SciMLBase.solve!(cache::LinearCache, alg::PETScAlgorithm, args...; kwargs...) +function SciMLBase.solve!(cache::LinearCacheType, alg::PETScAlgorithm, args...; kwargs...) pcache = cache.cacheval # Get element type from the problem diff --git a/ext/LinearSolvePardisoExt.jl b/ext/LinearSolvePardisoExt.jl index 1c20ddec9..4f2da8ea4 100644 --- a/ext/LinearSolvePardisoExt.jl +++ b/ext/LinearSolvePardisoExt.jl @@ -140,7 +140,7 @@ function LinearSolve.init_cacheval( return solver end -function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::PardisoJL; kwargs...) +function SciMLBase.solve!(cache::LinearSolve.LinearCacheType, alg::PardisoJL; kwargs...) (; A, b, u) = cache A = convert(AbstractMatrix, A) if cache.isfresh diff --git a/ext/LinearSolveRecursiveFactorizationExt.jl b/ext/LinearSolveRecursiveFactorizationExt.jl index bee4bb413..8ec82da71 100644 --- a/ext/LinearSolveRecursiveFactorizationExt.jl +++ b/ext/LinearSolveRecursiveFactorizationExt.jl @@ -10,7 +10,7 @@ using SciMLLogging: @SciMLMessage LinearSolve.userecursivefactorization(A::Union{Nothing, AbstractMatrix}) = true function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::RFLUFactorization{P, T}; + cache::LinearSolve.LinearCacheType, alg::RFLUFactorization{P, T}; kwargs... ) where {P, T} A = cache.A @@ -63,7 +63,7 @@ function LinearSolve.init_cacheval( end function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::RF32MixedLUFactorization{P, T}; + cache::LinearSolve.LinearCacheType, alg::RF32MixedLUFactorization{P, T}; kwargs... ) where {P, T} A = cache.A @@ -115,7 +115,7 @@ function SciMLBase.solve!( end function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::ButterflyFactorization; + cache::LinearSolve.LinearCacheType, alg::ButterflyFactorization; kwargs... ) cache_A = cache.A diff --git a/ext/LinearSolveSparseArraysExt.jl b/ext/LinearSolveSparseArraysExt.jl index 6f72fc5fb..aa0efe124 100644 --- a/ext/LinearSolveSparseArraysExt.jl +++ b/ext/LinearSolveSparseArraysExt.jl @@ -233,7 +233,7 @@ end end function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::UMFPACKFactorization; kwargs... + cache::LinearSolve.LinearCacheType, alg::UMFPACKFactorization; kwargs... ) A = cache.A A = convert(AbstractMatrix, A) @@ -284,7 +284,7 @@ end else function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::UMFPACKFactorization; kwargs... + cache::LinearSolve.LinearCacheType, alg::UMFPACKFactorization; kwargs... ) error("UMFPACKFactorization requires GPL libraries (UMFPACK). Rebuild Julia with USE_GPL_LIBS=1 or use an alternative algorithm like SparspakFactorization") end @@ -399,7 +399,7 @@ function LinearSolve.init_cacheval( end end -function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::KLUFactorization; kwargs...) +function SciMLBase.solve!(cache::LinearSolve.LinearCacheType, alg::KLUFactorization; kwargs...) A = cache.A A = convert(AbstractMatrix, A) if cache.isfresh diff --git a/ext/LinearSolveSparspakExt.jl b/ext/LinearSolveSparspakExt.jl index 513419bd8..ff7c730d4 100644 --- a/ext/LinearSolveSparspakExt.jl +++ b/ext/LinearSolveSparspakExt.jl @@ -47,7 +47,7 @@ function LinearSolve.init_cacheval( end function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::SparspakFactorization; kwargs... + cache::LinearSolve.LinearCacheType, alg::SparspakFactorization; kwargs... ) A = cache.A if cache.isfresh diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index a36de6f9d..3788c2510 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -388,6 +388,9 @@ end const BLASELTYPES = Union{Float32, Float64, ComplexF32, ComplexF64} function defaultalg_symbol end +function defaultalg_symbol(::Type{T}) where {T} + return Base.typename(SciMLBase.parameterless_type(T)).name +end include("verbosity.jl") include("blas_logging.jl") @@ -405,6 +408,7 @@ include("preconditioners.jl") include("preferences.jl") include("solve_function.jl") include("default.jl") +include("vf64_types.jl") include("init.jl") include("adjoint.jl") @@ -537,4 +541,6 @@ export LinearSolveAdjoint export LinearVerbosity +export LinearCacheVF64, LinearCacheType, DefaultLinearSolverInitVF64, DefaultLinearSolverInitType + end diff --git a/src/appleaccelerate.jl b/src/appleaccelerate.jl index ce61c1ea7..f6c5e39d0 100644 --- a/src/appleaccelerate.jl +++ b/src/appleaccelerate.jl @@ -300,7 +300,7 @@ function LinearSolve.init_cacheval( end function SciMLBase.solve!( - cache::LinearCache, alg::AppleAccelerateLUFactorization; + cache::LinearCacheType, alg::AppleAccelerateLUFactorization; kwargs... ) __appleaccelerate_isavailable() || @@ -407,7 +407,7 @@ function LinearSolve.init_cacheval( end function SciMLBase.solve!( - cache::LinearCache, alg::AppleAccelerate32MixedLUFactorization; + cache::LinearCacheType, alg::AppleAccelerate32MixedLUFactorization; kwargs... ) __appleaccelerate_isavailable() || diff --git a/src/common.jl b/src/common.jl index 27d5f5c20..9f37324c7 100644 --- a/src/common.jl +++ b/src/common.jl @@ -126,7 +126,51 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, Tlv <: Linear sensealg::S end -function Base.setproperty!(cache::LinearCache, name::Symbol, x) +""" + LinearCacheVF64{Tp, Tc, Tlv, S} + +VF64-specialized variant of `LinearCache` for the common case of in-place `Vector{Float64}` +problems with `DefaultLinearSolver`. Reduces type parameter count from 12 to 4 by hardcoding +`A::Matrix{Float64}`, `b::Vector{Float64}`, `u::Vector{Float64}`, +`Pl::IdentityOperator`, `Pr::IdentityOperator`, `abstol/reltol::Float64`, +and `assumptions::OperatorAssumptions{Bool}`. + +This dramatically shortens type strings in stack traces (from ~1500 chars to ~100 chars) +while maintaining identical behavior through Union-based dispatch with `LinearCache`. +All field names match `LinearCache` exactly, so all existing field access works unchanged. +""" +mutable struct LinearCacheVF64{Tp, Tc, Tlv <: LinearVerbosity, S} + A::Matrix{Float64} + b::Vector{Float64} + u::Vector{Float64} + p::Tp + alg::DefaultLinearSolver + cacheval::Tc + isfresh::Bool + precsisfresh::Bool + Pl::IdentityOperator + Pr::IdentityOperator + abstol::Float64 + reltol::Float64 + maxiters::Int + verbose::Tlv + assumptions::OperatorAssumptions{Bool} + sensealg::S +end + +""" + LinearCacheType + +Union type for dispatch compatibility between `LinearCache` and `LinearCacheVF64`. +All methods that previously dispatched on `LinearCache` should dispatch on +`LinearCacheType` to support both variants transparently. +""" +const LinearCacheType = Union{LinearCache, LinearCacheVF64} + +# Stub for VF64 cache construction - overridden in vf64_types.jl +_try_build_vf64_cache(args...) = nothing + +function Base.setproperty!(cache::LinearCacheType, name::Symbol, x) if name === :A setfield!(cache, :isfresh, true) setfield!(cache, :precsisfresh, true) @@ -136,14 +180,12 @@ function Base.setproperty!(cache::LinearCache, name::Symbol, x) # In case there is something that needs to be done when b is updated update_cacheval!(cache, :b, x) elseif name === :cacheval && cache.alg isa DefaultLinearSolver - @assert cache.cacheval isa DefaultLinearSolverInit return __setfield!(cache.cacheval, cache.alg, x) - # return setfield!(cache.cacheval, Symbol(cache.alg.alg), x) end return setfield!(cache, name, x) end -function update_cacheval!(cache::LinearCache, name::Symbol, x) +function update_cacheval!(cache::LinearCacheType, name::Symbol, x) return update_cacheval!(cache, cache.cacheval, name, x) end update_cacheval!(cache, cacheval, name::Symbol, x) = cacheval @@ -401,6 +443,15 @@ function __init( precsisfresh = false Tc = typeof(cacheval) + # Try VF64 specialization for the common Vector{Float64} + DefaultLinearSolver path + vf64_cache = _try_build_vf64_cache( + A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, + Pl, Pr, abstol, reltol, maxiters, verbose_spec, assumptions, sensealg + ) + if vf64_cache !== nothing + return vf64_cache + end + cache = LinearCache{ typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc, typeof(Pl), typeof(Pr), typeof(reltol), typeof(verbose_spec), typeof(assumptions.issq), @@ -413,7 +464,7 @@ function __init( end function SciMLBase.reinit!( - cache::LinearCache; + cache::LinearCacheType; A = nothing, b = cache.b, u = cache.u, @@ -463,7 +514,7 @@ function SciMLBase.solve( return solve!(init(prob, alg, args...; kwargs...)) end -function SciMLBase.solve!(cache::LinearCache, args...; kwargs...) +function SciMLBase.solve!(cache::LinearCacheType, args...; kwargs...) return solve!(cache, cache.alg, args...; kwargs...) end diff --git a/src/default.jl b/src/default.jl index d7ccf5b47..a9402b3e9 100644 --- a/src/default.jl +++ b/src/default.jl @@ -476,7 +476,7 @@ function SciMLBase.init( end function SciMLBase.solve!( - cache::LinearCache, alg::Nothing, + cache::LinearCacheType, alg::Nothing, args...; assump::OperatorAssumptions = OperatorAssumptions(), kwargs... ) @@ -555,12 +555,10 @@ end return Expr(:call, :DefaultLinearSolverInit, caches..., :A_original) end -function defaultalg_symbol(::Type{T}) where {T} - return Base.typename(SciMLBase.parameterless_type(T)).name -end -defaultalg_symbol(::Type{<:GenericFactorization{typeof(ldlt!)}}) = :LDLtFactorization - -defaultalg_symbol(::Type{<:QRFactorization{ColumnNorm}}) = :QRFactorizationPivoted +## NOTE: The generic defaultalg_symbol(::Type{T}) method is defined in LinearSolve.jl +## before factorization.jl is included, so that @generated functions can use it. +## The specialized overloads below are at the end of factorization.jl (after the +## types they dispatch on are defined). """ if alg.alg === DefaultAlgorithmChoice.LUFactorization @@ -570,7 +568,7 @@ else end """ @generated function SciMLBase.solve!( - cache::LinearCache, alg::DefaultLinearSolver, + cache::LinearCacheType, alg::DefaultLinearSolver, args...; assump::OperatorAssumptions = OperatorAssumptions(), kwargs... @@ -822,7 +820,7 @@ else end ``` """ -@generated function defaultalg_adjoint_eval(cache::LinearCache, dy) +@generated function defaultalg_adjoint_eval(cache::LinearCacheType, dy) ex = :() for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T)) newex = if alg in Symbol.( diff --git a/src/factorization.jl b/src/factorization.jl index 904ef5b00..03dce2b59 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -1,5 +1,5 @@ @generated function SciMLBase.solve!( - cache::LinearCache, alg::AbstractFactorization; + cache::LinearCacheType, alg::AbstractFactorization; kwargs... ) return quote @@ -146,7 +146,7 @@ end GenericLUFactorization() = GenericLUFactorization(RowMaximum()) -function SciMLBase.solve!(cache::LinearCache, alg::LUFactorization; kwargs...) +function SciMLBase.solve!(cache::LinearCacheType, alg::LUFactorization; kwargs...) A = cache.A A = convert(AbstractMatrix, A) if cache.isfresh @@ -229,7 +229,7 @@ function init_cacheval( end function SciMLBase.solve!( - cache::LinearSolve.LinearCache, alg::GenericLUFactorization; + cache::LinearSolve.LinearCacheType, alg::GenericLUFactorization; kwargs... ) A = cache.A @@ -1129,7 +1129,7 @@ function init_cacheval( return nothing end -function SciMLBase.solve!(cache::LinearCache, alg::CHOLMODFactorization; kwargs...) +function SciMLBase.solve!(cache::LinearCacheType, alg::CHOLMODFactorization; kwargs...) A = cache.A A = convert(AbstractMatrix, A) @@ -1220,7 +1220,7 @@ function init_cacheval( return nothing end -function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization; kwargs...) +function SciMLBase.solve!(cache::LinearCacheType, alg::NormalCholeskyFactorization; kwargs...) A = cache.A A = convert(AbstractMatrix, A) if cache.isfresh @@ -1286,7 +1286,7 @@ function init_cacheval( end function SciMLBase.solve!( - cache::LinearCache, alg::NormalBunchKaufmanFactorization; + cache::LinearCacheType, alg::NormalBunchKaufmanFactorization; kwargs... ) A = cache.A @@ -1317,7 +1317,7 @@ function init_cacheval( end function SciMLBase.solve!( - cache::LinearCache, alg::DiagonalFactorization; + cache::LinearCacheType, alg::DiagonalFactorization; kwargs... ) A = convert(AbstractMatrix, cache.A) @@ -1466,3 +1466,8 @@ for alg in vcat( ) end end + +# Specialized defaultalg_symbol overloads (must be after GenericFactorization and +# QRFactorization are defined, but before any @generated code is invoked for these types) +defaultalg_symbol(::Type{<:GenericFactorization{typeof(ldlt!)}}) = :LDLtFactorization +defaultalg_symbol(::Type{<:QRFactorization{ColumnNorm}}) = :QRFactorizationPivoted diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl index c7e405734..7b1cffcc4 100644 --- a/src/iterative_wrappers.jl +++ b/src/iterative_wrappers.jl @@ -276,7 +276,7 @@ function init_cacheval( return nothing end -function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...) +function SciMLBase.solve!(cache::LinearCacheType, alg::KrylovJL; kwargs...) if cache.precsisfresh && !isnothing(alg.precs) Pl, Pr = alg.precs(cache.A, cache.p) cache.Pl = Pl diff --git a/src/mkl.jl b/src/mkl.jl index b0d17db7c..53bcb98ae 100644 --- a/src/mkl.jl +++ b/src/mkl.jl @@ -290,7 +290,7 @@ function LinearSolve.init_cacheval( end function SciMLBase.solve!( - cache::LinearCache, alg::MKLLUFactorization; + cache::LinearCacheType, alg::MKLLUFactorization; kwargs... ) __mkl_isavailable() || @@ -411,7 +411,7 @@ function LinearSolve.init_cacheval( end function SciMLBase.solve!( - cache::LinearCache, alg::MKL32MixedLUFactorization; + cache::LinearCacheType, alg::MKL32MixedLUFactorization; kwargs... ) __mkl_isavailable() || diff --git a/src/openblas.jl b/src/openblas.jl index 2c16f0d0e..88dc498a5 100644 --- a/src/openblas.jl +++ b/src/openblas.jl @@ -311,7 +311,7 @@ function LinearSolve.init_cacheval( end function SciMLBase.solve!( - cache::LinearCache, alg::OpenBLASLUFactorization; + cache::LinearCacheType, alg::OpenBLASLUFactorization; kwargs... ) __openblas_isavailable() || @@ -417,7 +417,7 @@ function LinearSolve.init_cacheval( end function SciMLBase.solve!( - cache::LinearCache, alg::OpenBLAS32MixedLUFactorization; + cache::LinearCacheType, alg::OpenBLAS32MixedLUFactorization; kwargs... ) __openblas_isavailable() || diff --git a/src/simplegmres.jl b/src/simplegmres.jl index 0d6a22973..c460ea7e4 100644 --- a/src/simplegmres.jl +++ b/src/simplegmres.jl @@ -87,7 +87,7 @@ end warm_start::Bool end -function update_cacheval!(cache::LinearCache, cacheval::SimpleGMRESCache, name::Symbol, x) +function update_cacheval!(cache::LinearCacheType, cacheval::SimpleGMRESCache, name::Symbol, x) (name != :b || cache.isfresh) && return cacheval vec(cacheval.w) .= vec(x) fill!(cacheval.x, 0) @@ -151,7 +151,7 @@ _norm2(x, dims) = .√(sum(abs2, x; dims)) default_alias_A(::SimpleGMRES, ::Any, ::Any) = false default_alias_b(::SimpleGMRES, ::Any, ::Any) = false -function SciMLBase.solve!(cache::LinearCache, alg::SimpleGMRES; kwargs...) +function SciMLBase.solve!(cache::LinearCacheType, alg::SimpleGMRES; kwargs...) if cache.isfresh solver = init_cacheval( alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, @@ -230,7 +230,7 @@ function _init_cacheval( ) end -function SciMLBase.solve!(cache::SimpleGMRESCache{false}, lincache::LinearCache) +function SciMLBase.solve!(cache::SimpleGMRESCache{false}, lincache::LinearCacheType) (; memory, n, restart, maxiters, blocksize, ε, PlisI, PrisI, Pl, Pr) = cache (; Δx, q, p, x, A, b, abstol, reltol, w, V, s, c, z, R, β, warm_start) = cache @@ -474,7 +474,7 @@ function _init_cacheval( ) end -function SciMLBase.solve!(cache::SimpleGMRESCache{true}, lincache::LinearCache) +function SciMLBase.solve!(cache::SimpleGMRESCache{true}, lincache::LinearCacheType) (; memory, n, restart, maxiters, blocksize, ε, PlisI, PrisI, Pl, Pr) = cache (; Δx, q, p, x, A, b, abstol, reltol, w, V, s, c, z, R, β, warm_start) = cache bsize = n ÷ blocksize diff --git a/src/simplelu.jl b/src/simplelu.jl index e9ffc6b54..be4f051b2 100644 --- a/src/simplelu.jl +++ b/src/simplelu.jl @@ -205,7 +205,7 @@ end default_alias_A(::SimpleLUFactorization, ::Any, ::Any) = true default_alias_b(::SimpleLUFactorization, ::Any, ::Any) = true -function SciMLBase.solve!(cache::LinearCache, alg::SimpleLUFactorization; kwargs...) +function SciMLBase.solve!(cache::LinearCacheType, alg::SimpleLUFactorization; kwargs...) if cache.isfresh cache.cacheval.A .= cache.A simplelu_factorize!(cache.cacheval, alg.pivot) diff --git a/src/solve_function.jl b/src/solve_function.jl index a6e9e4257..999a6fdea 100644 --- a/src/solve_function.jl +++ b/src/solve_function.jl @@ -46,7 +46,7 @@ struct LinearSolveFunction{F} <: AbstractSolveFunction end function SciMLBase.solve!( - cache::LinearCache, alg::LinearSolveFunction, + cache::LinearCacheType, alg::LinearSolveFunction, args...; kwargs... ) (; A, b, u, p, isfresh, Pl, Pr, cacheval) = cache @@ -99,7 +99,7 @@ struct DirectLdiv!{cache} <: AbstractSolveFunction end # Default solve! for non-caching or matrix types that don't need caching -function SciMLBase.solve!(cache::LinearCache, alg::DirectLdiv!{false}, args...; kwargs...) +function SciMLBase.solve!(cache::LinearCacheType, alg::DirectLdiv!{false}, args...; kwargs...) (; A, b, u) = cache ldiv!(u, A, b) return SciMLBase.build_linear_solution(alg, u, nothing, cache) @@ -107,7 +107,7 @@ end # For caching DirectLdiv! with general matrices, just use regular ldiv! # (caching is only needed for specific matrix types like Tridiagonal) -function SciMLBase.solve!(cache::LinearCache, alg::DirectLdiv!{true}, args...; kwargs...) +function SciMLBase.solve!(cache::LinearCacheType, alg::DirectLdiv!{true}, args...; kwargs...) (; A, b, u) = cache ldiv!(u, A, b) return SciMLBase.build_linear_solution(alg, u, nothing, cache) diff --git a/src/vf64_types.jl b/src/vf64_types.jl new file mode 100644 index 000000000..9be1e1a23 --- /dev/null +++ b/src/vf64_types.jl @@ -0,0 +1,178 @@ +# VF64 specialized types for reducing type parameter counts +# when A::Matrix{Float64} and b::Vector{Float64} with DefaultLinearSolver. +# +# This file defines DefaultLinearSolverInitVF64 which hardcodes all 24 factorization +# cache types for Matrix{Float64}, reducing the type parameter count from 25 to 1 +# (only A_backup type remains parameterized). +# +# Together with LinearCacheVF64 (defined in common.jl), this reduces the total +# LinearCache type string from ~1500 chars to ~100 chars in stack traces. + +# Type alias for the default LinearVerbosity (Standard preset) +const _DefaultLinearVerbosity = typeof(LinearVerbosity()) + +# Compute concrete factorization types for Matrix{Float64} at module load time +# and define DefaultLinearSolverInitVF64 with hardcoded types. +let + _A = [1.0 0.0; 0.0 1.0] + _b = [1.0, 1.0] + _u = [0.0, 0.0] + _Pl = IdentityOperator(2) + _Pr = IdentityOperator(2) + _verbose = LinearVerbosity() + _assump = OperatorAssumptions(true) + _alg = DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization) + + _cacheval = _init_default_cacheval( + _alg, _A, _b, _u, _Pl, _Pr, 2, sqrt(eps()), sqrt(eps()), + _verbose, _assump, _A + ) + + _T = typeof(_cacheval) + _tparams = _T.parameters + + @eval begin + """ + DefaultLinearSolverInitVF64{TA} + + VF64-specialized variant of `DefaultLinearSolverInit` for the common case of + `Matrix{Float64}` linear systems. All 24 factorization cache slot types are + hardcoded to their concrete types for `Matrix{Float64}`, reducing the type + parameter count from 25 to 1 (only `A_backup::TA` remains parameterized). + + Field names match `DefaultLinearSolverInit` exactly for transparent dispatch. + """ + mutable struct DefaultLinearSolverInitVF64{TA} + LUFactorization::$(_tparams[1]) + QRFactorization::$(_tparams[2]) + DiagonalFactorization::$(_tparams[3]) + var"DirectLdiv!"::$(_tparams[4]) + SparspakFactorization::$(_tparams[5]) + KLUFactorization::$(_tparams[6]) + UMFPACKFactorization::$(_tparams[7]) + KrylovJL_GMRES::$(_tparams[8]) + GenericLUFactorization::$(_tparams[9]) + RFLUFactorization::$(_tparams[10]) + LDLtFactorization::$(_tparams[11]) + BunchKaufmanFactorization::$(_tparams[12]) + CHOLMODFactorization::$(_tparams[13]) + SVDFactorization::$(_tparams[14]) + CholeskyFactorization::$(_tparams[15]) + NormalCholeskyFactorization::$(_tparams[16]) + AppleAccelerateLUFactorization::$(_tparams[17]) + MKLLUFactorization::$(_tparams[18]) + QRFactorizationPivoted::$(_tparams[19]) + KrylovJL_CRAIGMR::$(_tparams[20]) + KrylovJL_LSMR::$(_tparams[21]) + BLISLUFactorization::$(_tparams[22]) + CudaOffloadLUFactorization::$(_tparams[23]) + MetalLUFactorization::$(_tparams[24]) + A_backup::TA + end + end +end + +""" + DefaultLinearSolverInitType + +Union type for dispatch compatibility between `DefaultLinearSolverInit` and +`DefaultLinearSolverInitVF64`. +""" +const DefaultLinearSolverInitType = Union{DefaultLinearSolverInit, DefaultLinearSolverInitVF64} + +# Extend the trait for VF64 variant +_is_default_linear_solver_init(::DefaultLinearSolverInitVF64) = true + +# __setfield! for DefaultLinearSolverInitVF64 - same generated logic as DefaultLinearSolverInit +@generated function __setfield!(cache::DefaultLinearSolverInitVF64, alg::DefaultLinearSolver, v) + ex = :() + for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T)) + newex = quote + setfield!(cache, $(Meta.quot(alg)), v) + end + alg_enum = getproperty(LinearSolve.DefaultAlgorithmChoice, alg) + ex = if ex == :() + Expr( + :elseif, :(alg.alg == $(alg_enum)), newex, + :(error("Algorithm Choice not Allowed")) + ) + else + Expr(:elseif, :(alg.alg == $(alg_enum)), newex, ex) + end + end + return ex = Expr(:if, ex.args...) +end + +# Handle special case of Column-pivoted QR fallback for LU +function __setfield!( + cache::DefaultLinearSolverInitVF64, + alg::DefaultLinearSolver, v::LinearAlgebra.QRPivoted + ) + return setfield!(cache, :QRFactorizationPivoted, v) +end + +""" + _convert_to_vf64_cacheval(cache::DefaultLinearSolverInit) + +Convert a generic `DefaultLinearSolverInit` to `DefaultLinearSolverInitVF64` +by copying all field values. This is called during `LinearCacheVF64` construction. +""" +function _convert_to_vf64_cacheval(cache::DefaultLinearSolverInit) + return DefaultLinearSolverInitVF64( + cache.LUFactorization, + cache.QRFactorization, + cache.DiagonalFactorization, + getfield(cache, Symbol("DirectLdiv!")), + cache.SparspakFactorization, + cache.KLUFactorization, + cache.UMFPACKFactorization, + cache.KrylovJL_GMRES, + cache.GenericLUFactorization, + cache.RFLUFactorization, + cache.LDLtFactorization, + cache.BunchKaufmanFactorization, + cache.CHOLMODFactorization, + cache.SVDFactorization, + cache.CholeskyFactorization, + cache.NormalCholeskyFactorization, + cache.AppleAccelerateLUFactorization, + cache.MKLLUFactorization, + cache.QRFactorizationPivoted, + cache.KrylovJL_CRAIGMR, + cache.KrylovJL_LSMR, + cache.BLISLUFactorization, + cache.CudaOffloadLUFactorization, + cache.MetalLUFactorization, + cache.A_backup, + ) +end + +""" + _try_build_vf64_cache(A, b, u, p, alg, cacheval, isfresh, precsisfresh, + Pl, Pr, abstol, reltol, maxiters, verbose, assumptions, sensealg) + +Attempt to construct a `LinearCacheVF64` if all types match the VF64 pattern: +- `A::Matrix{Float64}`, `b::Vector{Float64}`, `u::Vector{Float64}` +- `alg::DefaultLinearSolver` +- `cacheval::DefaultLinearSolverInit` +- `Pl::IdentityOperator`, `Pr::IdentityOperator` +- `abstol::Float64`, `reltol::Float64` +- `assumptions::OperatorAssumptions{Bool}` + +Returns `nothing` if the types don't match, allowing fallback to generic `LinearCache`. +""" +function _try_build_vf64_cache( + A::Matrix{Float64}, b::Vector{Float64}, u::Vector{Float64}, + p, alg::DefaultLinearSolver, cacheval::DefaultLinearSolverInit, + isfresh::Bool, precsisfresh::Bool, + Pl::IdentityOperator, Pr::IdentityOperator, + abstol::Float64, reltol::Float64, maxiters::Int, + verbose, assumptions::OperatorAssumptions{Bool}, + sensealg + ) + vf64_cacheval = _convert_to_vf64_cacheval(cacheval) + return LinearCacheVF64{typeof(p), typeof(vf64_cacheval), typeof(verbose), typeof(sensealg)}( + A, b, u, p, alg, vf64_cacheval, isfresh, precsisfresh, + Pl, Pr, abstol, reltol, maxiters, verbose, assumptions, sensealg + ) +end