ForwardDiff caches usage

The solution to this problem was to call `reinterpret` in the getter functions, as @ChrisRackauckas suggested. Here is code that works.

Be aware that this code is not optimized.

using ForwardDiff
using Random

"""
    modelf(du0::Vector{T}) where T <: Number

Callable model that owns preallocated work buffers for two kinds of evaluation:

- Plain evaluation, with element type `T`.
- ForwardDiff evaluation, with `ForwardDiff.Dual{nothing,T,3}` elements (chunk size fixed to 3).

Fields:
- `common`, `du`: `T`-typed buffers. `common` has 3 zeroed entries; `du` is
  `similar(du0)` and starts uninitialized.
- `common_dual`, `du_dual`: dual-typed buffers of lengths 3 and `length(du0)`.
  They are `reinterpret`ed by the getter functions into the dual type that is
  actually passed in.

The field types are concrete type parameters (rather than the abstract `Vector`), so
field access is type-stable.
"""
mutable struct modelf{S<:Number,D<:ForwardDiff.Dual} <: Function
    common::Vector{S}
    common_dual::Vector{D}
    du::Vector{S}
    du_dual::Vector{D}
    function modelf(du0::Vector{T}) where {T<:Number}
        # Chunk size forced to 3 for the MWE; the ForwardDiff config must use the same
        # chunk so the reinterpreted dual buffers have a matching element size.
        U = ForwardDiff.Dual{nothing,T,3}
        return new{T,U}(zeros(T, 3), zeros(U, 3), similar(du0), zeros(U, length(du0)))
    end
end

# Select the `common` scratch buffer whose element type matches the caller's input.
# Float64 input uses the plain cache directly. Dual input gets the dual cache
# reinterpreted (same memory, no copy) as the requested dual type `D`.
function get_common(model::modelf, ::Type{Float64})
    return model.common
end
function get_common(model::modelf, ::Type{D}) where {D<:ForwardDiff.Dual}
    return reinterpret(D, model.common_dual)
end

# Same selection as `get_common`, but for the output buffer written by `system`.
function get_du(model::modelf, ::Type{Float64})
    return model.du
end
function get_du(model::modelf, ::Type{D}) where {D<:ForwardDiff.Dual}
    return reinterpret(D, model.du_dual)
end

# Evaluate the model at state `u` and time `t`. The result is written into the cached
# buffer chosen by the element type `T` of `u` (via `get_du`/`get_common`).
# The returned array aliases that cache, so the next call overwrites it.
function (f::modelf)(u::AbstractArray{T}, t) where {T}
    du = get_du(f, T)
    system(du, u, get_common(f, T), t)
    return du
end

"""
    component1(u, common)

Store `1.0 * u[1]` into `common[1]` and return the first right-hand-side term
`-0.1 * u[1]`. Mutates `common`.
"""
function component1(u, common)
    # Publish this value through the shared buffer; component3 reads common entries.
    common[1] = 1.0 * u[1]
    return -0.1 * u[1]
end

"""
    component2(u, common)

Store `-0.8 * u[2]` into `common[2]` and return the second right-hand-side term
`-0.4 * u[2]`. Mutates `common`.
"""
function component2(u, common)
    # Publish this value through the shared buffer; component3 reads common entries.
    common[2] = -0.8 * u[2]
    return -0.4 * u[2]
end

"""
    component3(u, common)

Return the third right-hand-side term: a fixed weighted combination of
`common[1:3]` (weights `-0.1`, `0.2`, `0.01`) plus `0.02 * u[3]`. Does not mutate
its arguments.
"""
function component3(u, common)
    # Written out term by term instead of `sum(common .* [-0.1, 0.2, 0.01])` so no
    # weight array is allocated on every model evaluation. The summation order is
    # still left to right.
    mixed = -0.1 * common[1] + 0.2 * common[2] + 0.01 * common[3]
    return mixed + 0.02 * u[3]
end

"""
    system(du, u, common, t)

Fill `du[1:3]` in place with the three right-hand-side components evaluated at `u`.
`common` serves as scratch storage shared between the components. `t` is accepted
for the usual `(du, u, p, t)`-style signature but is not used. Returns `nothing`.
"""
function system(du, u, common, t)
    # Call order matters: component3 reads the `common` entries that
    # component1 and component2 write.
    du[1] = component1(u, common)
    du[2] = component2(u, common)
    du[3] = component3(u, common)
    return nothing
end

# Build the model with Float64 caches and run it once on plain (non-dual) input.
model = modelf(zeros(3))
model(zeros(3), 0)

# ForwardDiff differentiates a one-argument function, so fix the time argument to 0.
fu = function (u)
    return model(u, 0)
end
# Chunk{3} matches the chunk size hard-coded in the modelf constructor, so the
# reinterpreted dual caches line up with the duals ForwardDiff passes in.
jconfig = ForwardDiff.JacobianConfig(fu, rand(3), ForwardDiff.Chunk{3}())
J = ForwardDiff.jacobian(fu, rand(3), jconfig)
2 Likes