Hello!
I’m trying to combine sparsity pattern detection (via Symbolics) with a nested ForwardDiff call that also uses PreallocationTools. A contrived, minimal example is below:
```julia
using ForwardDiff
using PreallocationTools
using Symbolics

function foo(x, cache)
    d = get_tmp(cache, x)  # scratch buffer whose eltype matches x
    d[:] = x
    0.5 * x' * x
end

function residual(r, x, cache)
    # closure over the cache so ForwardDiff only differentiates with respect to x
    function foo_wrap(x)
        foo(x, cache)
    end
    r[:] = ForwardDiff.gradient(foo_wrap, x)
end

cache = DiffCache(zeros(2))
pattern = Symbolics.jacobian_sparsity((r, x) -> residual(r, x, cache), zeros(2), zeros(2))
```
which fails with the following stacktrace:
```
ERROR: ArgumentError: cannot reinterpret `Float64` as `ForwardDiff.Dual{ForwardDiff.Tag{var"#foo_wrap#7"{DiffCache{Vector{Float64}, Vector{Float64}}}, Num}, Num, 2}`, type `ForwardDiff.Dual{ForwardDiff.Tag{var"#foo_wrap#7"{DiffCache{Vector{Float64}, Vector{Float64}}}, Num}, Num, 2}` is not a bits type
Stacktrace:
  [1] (::Base.var"#throwbits#323")(S::Type, T::Type, U::Type)
    @ Base ./reinterpretarray.jl:16
  [2] reinterpret(#unused#::Type{ForwardDiff.Dual{ForwardDiff.Tag{var"#foo_wrap#7"{DiffCache{Vector{Float64}, Vector{Float64}}}, Num}, Num, 2}}, a::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true})
    @ Base ./reinterpretarray.jl:62
  [3] get_tmp(dc::DiffCache{Vector{Float64}, Vector{Float64}}, u::Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#foo_wrap#7"{DiffCache{Vector{Float64}, Vector{Float64}}}, Num}, Num, 2}})
    @ PreallocationTools ~/.julia/packages/PreallocationTools/mJSsc/src/PreallocationTools.jl:124
```
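For comparison, here is a minimal sketch (assuming recent ForwardDiff, PreallocationTools and Symbolics versions) that reruns the same code on ordinary `Float64` input and then checks `isbitstype` on the two kinds of `Dual` involved. The failing call in the stacktrace is `reinterpret`, which only accepts bits types, and `Num` wraps an arbitrary symbolic value, so a `Dual` carrying `Num` values should not qualify. The `Nothing` tag in `ForwardDiff.Dual{Nothing, ...}` is only a convenient concrete tag chosen for this check.

```julia
using ForwardDiff, PreallocationTools, Symbolics

# Same MWE as above, repeated so this block runs on its own.
function foo(x, cache)
    d = get_tmp(cache, x)  # for Dual input: the cache's Float64 buffer reinterpreted as eltype(x)
    d[:] = x
    0.5 * x' * x
end

function residual(r, x, cache)
    foo_wrap(x) = foo(x, cache)
    r[:] = ForwardDiff.gradient(foo_wrap, x)
end

cache = DiffCache(zeros(2))

# With Float64 input the inner Duals hold Float64 values (a bits type), so
# get_tmp may reinterpret the buffer, and the gradient of 0.5x'x is x itself.
r = zeros(2)
residual(r, [1.0, 2.0], cache)
println(r)  # [1.0, 2.0]

# During sparsity detection the inner Duals hold Num values instead.
println(isbitstype(ForwardDiff.Dual{Nothing, Float64, 2}))  # true
println(isbitstype(ForwardDiff.Dual{Nothing, Num, 2}))      # false

# Base.reinterpret rejects non-bits target types with an ArgumentError,
# which is the error raised inside get_tmp above.
err = try
    reinterpret(ForwardDiff.Dual{Nothing, Num, 2}, zeros(6))
    nothing
catch e
    e
end
println(err isa ArgumentError)  # true
```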
Based on the discussion here, I thought I could instead use `cache = FixedSizeDiffCache(zeros(Symbolics.Num, 2), 2)` to make non-bits types work, but this yields a similar error:
```
ERROR: ArgumentError: cannot reinterpret `ForwardDiff.Dual{nothing, Num, 2}` as `ForwardDiff.Dual{ForwardDiff.Tag{var"#foo_wrap#13"{FixedSizeDiffCache{Vector{Num}, Vector{ForwardDiff.Dual{nothing, Num, 2}}}}, Num}, Num, 2}`, type `ForwardDiff.Dual{ForwardDiff.Tag{var"#foo_wrap#13"{FixedSizeDiffCache{Vector{Num}, Vector{ForwardDiff.Dual{nothing, Num, 2}}}}, Num}, Num, 2}` is not a bits type
```
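Judging from this second message, `FixedSizeDiffCache` stores its duals with a `nothing` tag, and `get_tmp` then reinterprets them to the tag of the inner `ForwardDiff.gradient` call. The sketch below uses a hypothetical `MyTag` as a stand-in for that tag. It suggests that this re-tagging `reinterpret` works for `Float64` duals but trips the same bits-type check once the value type is `Num`, whatever element type the cache was built with.

```julia
using ForwardDiff, Symbolics

# Hypothetical stand-in for the ForwardDiff.Tag created by the inner gradient call.
struct MyTag end

# Re-tagging Float64 Duals: both source and target are bits types, so the view is allowed.
src = Vector{ForwardDiff.Dual{nothing, Float64, 2}}(undef, 2)
retagged = reinterpret(ForwardDiff.Dual{MyTag, Float64, 2}, src)
println(length(retagged))  # 2

# Re-tagging Num Duals: neither type is a bits type, so reinterpret throws
# the same kind of ArgumentError as the FixedSizeDiffCache attempt.
src_num = Vector{ForwardDiff.Dual{nothing, Num, 2}}(undef, 2)
err = try
    reinterpret(ForwardDiff.Dual{MyTag, Num, 2}, src_num)
    nothing
catch e
    e
end
println(err isa ArgumentError)  # true
```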
I’m unsure whether this is a supported use case or not. Would anyone happen to have any insights?