
Deferred computation for precompilable tags? #812

@longemen3000


I was looking at #807 and this comment caught my attention:

The pattern (outer derivative of an inner scalar derivative whose closure captures the outer point) is the canonical scalar nested derivative — it isn't exotic, and it's exactly what would be expected to suffer from precompile-induced tagcount inversion in downstream packages.

When the inner derivative is seeded with a literal Float64, the inner tag is Tag{, Float64} with depth 1 — even though that closure captures a Dual{T_middle, ..., 2} (the outer-nesting context is in F, not in V)

As of ForwardDiff.jl v1.4, the closure-over-AD pattern is solved via the runtime `tagcount` mechanism.
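The quoted point about where the nesting information lives can be checked directly. The sketch below assumes that `ForwardDiff.Tag(f, V)` builds a `Tag{typeof(f), V}` (internals may differ in v1.4). `OuterMarker` and `Captures` are hypothetical names used only for this example.

```julia
using ForwardDiff

# Hypothetical stand-in for the outer function; only its type is used to build a tag
struct OuterMarker end

# Hypothetical callable struct playing the role of a closure that captures the outer point
struct Captures{X}
    x::X
end
(c::Captures)(y) = sin(c.x * y)

# An outer Dual, as the outer derivative would pass it in
const OuterTag = ForwardDiff.Tag{OuterMarker,Float64}
xd = ForwardDiff.Dual{OuterTag}(0.5, 1.0)

# The tag the inner derivative builds when seeded with a literal Float64:
# the outer tag shows up only inside F (the closure type), while V is plain Float64
T_inner = typeof(ForwardDiff.Tag(Captures(xd), Float64))
println(T_inner)
```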

Another well-known AD framework with forward mode is Enzyme, which explicitly provides a separate function for nested differentiation: `autodiff_deferred`.

Is it possible to implement a similar mechanism in ForwardDiff? Instead of recursing over the closure type, a `Deferred{F,P}` struct could be created at closure construction, storing, alongside the closure, tag information that can be exposed to the outer ForwardDiff contexts. For example:

```julia
using ForwardDiff
const Tag = ForwardDiff.Tag

struct OuterF end
struct InnerF
    x_dual::ForwardDiff.Dual{Tag{OuterF, Float64}, Float64, 1}
end
(c::InnerF)(y) = sin(c.x_dual * y)
(::OuterF)(x_dual::ForwardDiff.Dual{Tag{OuterF, Float64}, Float64, 1}) =
    ForwardDiff.derivative(InnerF(x_dual), 1.0)

# Force tagcount to be evaluated in INVERTED order (simulating a precompile race):
ForwardDiff.tagcount(Tag{InnerF, Float64})    # returns 0
ForwardDiff.tagcount(Tag{OuterF, Float64})    # returns 1

# Analytic: d/dx (d/dy sin(x*y)|_{y=1}) at x=0.5 = cos(0.5) - 0.5*sin(0.5) ≈ 0.6378697925882713
ForwardDiff.derivative(OuterF(), 0.5)
```
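For comparison, here is a sketch of the same nested derivative written with a plain anonymous closure, in a fresh session where nothing forces the tag order. It should reproduce the analytic value from the comment above. `f_outer` is a hypothetical name.

```julia
using ForwardDiff

# Outer function: derivative in y of sin(x*y) at y = 1, with the closure capturing x
f_outer(x) = ForwardDiff.derivative(y -> sin(x * y), 1.0)

# d/dy sin(x*y) at y = 1 is x*cos(x); its derivative in x is cos(x) - x*sin(x)
d = ForwardDiff.derivative(f_outer, 0.5)
expected = cos(0.5) - 0.5 * sin(0.5)
println(d, " vs ", expected)
```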

The current code path creates two tags, `Tag{InnerF, Float64}` and `Tag{OuterF, Float64}`, and neither type records that one derivative is nested inside the other. Suppose instead we had something like:

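```julia
# Proposed wrapper (sketch): the closure constructor F lives in the type domain,
# the captured context p is stored as a field
struct Deferred{F,P}
    p::P # context, e.g. the outer Dual captured by the closure
end
Deferred(::Type{F}, p::P) where {F,P} = Deferred{F,P}(p)

# Build the closure only when called, then evaluate it at x
(d::Deferred{F})(x) where {F} = F(d.p)(x)

#TODO: generalize for abstract arrays and other contexts.
deferred_valtype(::Type{V}, ::Type{P}) where {V<:Number,P<:Number} = promote_type(V, P)

# Build the inner tag from the closure constructor and the promoted value type,
# so the outer context ends up in V instead of being hidden in F
function ForwardDiff.Tag(::Deferred{F,P}, ::Type{V}) where {F,P,V}
    W = deferred_valtype(V, P)
    ForwardDiff.tagcount(ForwardDiff.Tag{F,W}) # register the tag, as Tag(f, V) does
    return ForwardDiff.Tag{F,W}()
end

# use Deferred instead of instantiating InnerF directly
(::OuterF)(x_dual::ForwardDiff.Dual{Tag{OuterF, Float64}, Float64, 1}) =
    ForwardDiff.derivative(Deferred(InnerF, x_dual), 1.0)

# a macro for user convenience (hypothetical, @deferred does not exist yet):
#(::OuterF)(x_dual::ForwardDiff.Dual{Tag{OuterF, Float64}, Float64, 1}) =
#    ForwardDiff.derivative(@deferred(InnerF(x_dual)), 1.0)

ForwardDiff.derivative(OuterF(), 0.5) # ≈ 0.6378697925882713
```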

In this case, the tags created would be `Tag{OuterF, Float64}` and `Tag{InnerF, Dual{Tag{OuterF, Float64}, Float64, 1}}`. The outer nesting context now appears in `V` rather than only in `F`, because we transformed the problem into one that is more amenable to what ForwardDiff already does. This also gives a path toward tag comparisons that do not depend on a global counter whose values could change between sessions.
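A minimal sketch of what a type-only comparison could look like for tags shaped like the `Deferred` ones described above. `mentions_tag` and `order_from_types` are hypothetical helpers, not ForwardDiff API; ForwardDiff itself orders tags using `tagcount`. `Outer` and `Inner` are placeholder types.

```julia
using ForwardDiff

const FDTag = ForwardDiff.Tag
const FDDual = ForwardDiff.Dual

# Hypothetical helper: does type X mention tag T anywhere in its Dual/Tag nesting?
# Only value types (V) are inspected; closure types (F) are deliberately ignored.
mentions_tag(::Type, ::Type) = false
mentions_tag(::Type{FDDual{S,W,N}}, ::Type{T}) where {S,W,N,T} =
    S === T || mentions_tag(S, T) || mentions_tag(W, T)
mentions_tag(::Type{FDTag{F,V}}, ::Type{T}) where {F,V,T} = mentions_tag(V, T)

# Hypothetical ordering from types alone:
#   true    -> T1 is the outer tag (T2's value type was built from T1)
#   false   -> T2 is the outer tag
#   nothing -> the types do not decide; a counter such as tagcount is still needed
function order_from_types(::Type{T1}, ::Type{T2}) where {T1<:FDTag,T2<:FDTag}
    mentions_tag(T2, T1) && return true
    mentions_tag(T1, T2) && return false
    return nothing
end

struct Outer end
struct Inner end

const TOuter = FDTag{Outer,Float64}
const TInner = FDTag{Inner,FDDual{TOuter,Float64,1}}  # shape of a Deferred-style inner tag
const TInnerPlain = FDTag{Inner,Float64}              # shape of the current inner tag

println(order_from_types(TOuter, TInner))       # decided by the types
println(order_from_types(TInnerPlain, TOuter))  # undecided
```

The `nothing` branch corresponds to the current `Tag{InnerF, Float64}` situation, where only a counter can break the tie.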
