Skip to content
This repository was archived by the owner on May 12, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FlexUnits = "76e01b6b-c995-4ce6-8559-91e72a3d4e95"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -51,6 +52,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
[extensions]
DiffEqBaseCUDAExt = "CUDA"
DiffEqBaseChainRulesCoreExt = "ChainRulesCore"
DiffEqBaseDynamicQuantitiesExt = "DynamicQuantities"
DiffEqBaseEnzymeExt = ["ChainRulesCore", "Enzyme"]
DiffEqBaseFlexUnitsExt = "FlexUnits"
DiffEqBaseForwardDiffExt = ["ForwardDiff"]
Expand All @@ -74,6 +76,7 @@ ChainRulesCore = "1"
ConcreteStructs = "0.2.3"
DifferentiationInterface = "0.7"
Distributions = "0.25"
DynamicQuantities = "1"
DocStringExtensions = "0.9"
Enzyme = "0.13.100"
FastBroadcast = "0.3.5"
Expand Down Expand Up @@ -119,6 +122,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FlexUnits = "76e01b6b-c995-4ce6-8559-91e72a3d4e95"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -138,4 +142,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Distributed", "Measurements", "Unitful", "FlexUnits", "LabelledArrays", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Test", "Distributions", "Aqua"]
test = ["Distributed", "Measurements", "Unitful", "FlexUnits", "LabelledArrays", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Pkg", "Random", "ReverseDiff", "StaticArrays", "SafeTestsets", "Test", "Distributions", "DynamicQuantities", "Aqua"]
67 changes: 67 additions & 0 deletions ext/DiffEqBaseDynamicQuantitiesExt.jl

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MilesCranmerBot I think you forgot to include unittests in this PR? Please add them and ensure 100% code coverage here.

Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
module DiffEqBaseDynamicQuantitiesExt

using DiffEqBase
using DynamicQuantities
using LinearAlgebra
import DiffEqBase: default_factorize

@inline DiffEqBase.ODE_DEFAULT_NORM(u::UnionAbstractQuantity, t) = abs(ustrip(u))
@inline function DiffEqBase.UNITLESS_ABS2(x::UnionAbstractQuantity)
return real(abs2(ustrip(x)))
end

DiffEqBase._rate_prototype(u, t::UnionAbstractQuantity, onet) = u / oneunit(t)
DiffEqBase.timedepentdtmin(t::UnionAbstractQuantity, dtmin) =
abs(ustrip(dtmin / oneunit(t)) * oneunit(t))

# Rosenbrock/SDIRK solvers form W/J matrices with Quantity eltype. Factorize/solve in
# value-space (Float64), but return solutions with the RHS units.
struct DQUnitlessLU{F, UT}
F::F
ut::UT
end

@inline function _infer_ut(A::AbstractMatrix{<:UnionAbstractQuantity})
@inbounds for a in A
va = ustrip(a)
if !iszero(va)
return oneunit(inv(a))
end
end
return oneunit(1.0)
end

function default_factorize(A::AbstractMatrix{<:UnionAbstractQuantity})
isempty(A) && return DQUnitlessLU(
lu(Matrix{Float64}(undef, 0, 0); check = false),
oneunit(1.0),
)
ut = _infer_ut(A)
return DQUnitlessLU(lu(ustrip.(A); check = false), ut)
end

function LinearAlgebra.ldiv!(
x::AbstractVector{<:UnionAbstractQuantity},
W::DQUnitlessLU,
b::AbstractVector{<:UnionAbstractQuantity},
)
vb = ustrip.(b)
vx = similar(vb)
LinearAlgebra.ldiv!(vx, W.F, vb)
@inbounds for i in eachindex(x)
x[i] = vx[i] * (oneunit(b[i]) * W.ut)
end
return x
end

function Base.:(\)(W::DQUnitlessLU, b::AbstractVector{<:UnionAbstractQuantity})
vb = ustrip.(b)
vx = W.F \ vb
out = similar(b)
@inbounds for i in eachindex(out)
out[i] = vx[i] * (oneunit(b[i]) * W.ut)
end
return out
end

end
Comment on lines +43 to +67

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are piracy

@MilesCranmer MilesCranmer Mar 17, 2026

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DQUnitlessLU is defined above, so I think it's fine, no?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, okay.

44 changes: 44 additions & 0 deletions test/dynamicquantities_ext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using Test
using DiffEqBase
using DynamicQuantities
using LinearAlgebra

@testset "DiffEqBaseDynamicQuantitiesExt" begin
# Basic quantity hooks
q = 3.0u"m"
@test DiffEqBase.ODE_DEFAULT_NORM(q, 0.0) == 3.0

qc = (3.0 + 4.0im)u"m"
@test DiffEqBase.UNITLESS_ABS2(qc) == 25.0

r = DiffEqBase._rate_prototype(2.0u"m", 4.0u"s", 1)
@test isapprox(ustrip(r), 2.0)
@test oneunit(r) == oneunit(1.0u"m") / oneunit(1.0u"s")

dt = DiffEqBase.timedepentdtmin(1.0u"s", 1.0u"ms")
@test isapprox(ustrip(dt), 0.001)
@test oneunit(dt) == oneunit(1.0u"s")

# Factorization bridge for Quantity matrices
A = [2.0u"m" 0.0u"m"; 0.0u"m" 4.0u"m"]
b = [4.0u"m", 8.0u"m"]

W = DiffEqBase.default_factorize(A)
x = W \ b
@test maximum(abs.(ustrip.(A * x .- b))) ≤ 1e-12

x2 = similar(b)
ldiv!(x2, W, b)
@test maximum(abs.(ustrip.(A * x2 .- b))) ≤ 1e-12

# _infer_ut fallback when all entries are zero
Az = fill(0.0u"m", 2, 2)
Wz = DiffEqBase.default_factorize(Az)
@test Wz.ut == oneunit(1.0)

# empty-matrix path (coverage + sanity)
A0 = Matrix{typeof(1.0u"m")}(undef, 0, 0)
W0 = DiffEqBase.default_factorize(A0)
b0 = typeof(1.0u"m")[]
@test length(W0 \ b0) == 0
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ end
@time @safetestset "Utils" include("utils.jl")
@time @safetestset "ForwardDiff Dual Detection" include("forwarddiff_dual_detection.jl")
@time @safetestset "ODE default norm" include("ode_default_norm.jl")
@time @safetestset "DynamicQuantities extension" include("dynamicquantities_ext.jl")
@time @safetestset "ODE default unstable check" include("ode_default_unstable_check.jl")
@time @safetestset "Problem Kwargs Merging" include("problem_kwargs_merging.jl")
end
Expand Down
Loading