diff --git a/src/solve.jl b/src/solve.jl index 6e3512040..1d2107331 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -810,12 +810,34 @@ function promote_f( if f.tgrad !== nothing && !(f.tgrad isa FunctionWrappersWrappers.FunctionWrappersWrapper) f = @set f.tgrad = wrapfun_jac_iip(f.tgrad, (u0, u0, p, t)) end - # Wrap the Jacobian if present, so its type is also erased + # Wrap the Jacobian if present, so its type is also erased. + # Include both dense and sparse matrix signatures when the function + # has a sparsity pattern, since the solver may use either depending on + # the autodiff configuration (AutoSparse creates sparse J from sparsity). if f.jac !== nothing && !(f.jac isa FunctionWrappersWrappers.FunctionWrappersWrapper) - n = length(u0) - J_proto = f.jac_prototype !== nothing ? similar(f.jac_prototype, uElType) : - zeros(uElType, n, n) - f = @set f.jac = wrapfun_jac_iip(f.jac, (J_proto, u0, p, t)) + if f.jac_prototype !== nothing + J_T = Base.promote_op(similar, typeof(f.jac_prototype), Type{uElType}) + sig = Tuple{J_T, typeof(u0), typeof(p), typeof(t)} + f = @set f.jac = FunctionWrappersWrappers.FunctionWrappersWrapper( + Void(f.jac), (sig,), (Nothing,)) + elseif isdefined(f, :sparsity) && f.sparsity isa AbstractMatrix && + !(f.sparsity isa Matrix) + # The sparsity pattern is a non-dense matrix (e.g. SparseMatrixCSC). + # The solver may call the Jacobian with either a dense or sparse matrix + # depending on the autodiff config, so wrap for both signatures. + dense_sig = Tuple{Matrix{uElType}, typeof(u0), typeof(p), typeof(t)} + sparse_J_T = Base.promote_op(similar, typeof(f.sparsity), Type{uElType}) + sparse_sig = Tuple{sparse_J_T, typeof(u0), typeof(p), typeof(t)} + f = @set f.jac = FunctionWrappersWrappers.FunctionWrappersWrapper( + Void(f.jac), + (dense_sig, sparse_sig), + (Nothing, Nothing) + ) + else + sig = Tuple{Matrix{uElType}, typeof(u0), typeof(p), typeof(t)} + f = @set f.jac = FunctionWrappersWrappers.FunctionWrappersWrapper( + Void(f.jac), (sig,), (Nothing,)) + end end return unwrapped_f(f, wrapfun_iip(f.f, (u0, u0, p, t), Val(CS))) else