Help resolving perturbation confusion in `ForwardDiff.jl` with manual higher-order derivatives

Hey ForwardDiffers,

I’m wondering if I could get some help resolving a perturbation confusion issue I’m having with some code. I use the following setup to compute higher-order derivatives with ForwardDiff in a fast and non-allocating way:

```julia
using ForwardDiff

# Wrap x in N nested layers of (untagged) Duals, each seeded with a unit partial.
fwd_ho(x, ::Val{0}) = x
fwd_ho(x, ::Val{1}) = ForwardDiff.Dual(x, (1.0,))
fwd_ho(x, ::Val{N}) where {N} = ForwardDiff.Dual(fwd_ho(x, Val(N-1)), (1.0,))

# Follow the first partial down N layers.
get_partial(x, ::Val{0}) = x
get_partial(x, ::Val{1}) = x.partials[1]
get_partial(x, ::Val{N}) where {N} = get_partial(x.partials[1], Val(N-1))

# Follow the value down N layers.
get_value(x, ::Val{0}) = x
get_value(x, ::Val{1}) = x.value
get_value(x, ::Val{N}) where {N} = get_value(x.value, Val(N-1))

# Returns (f(x), f'(x), ..., f^(O)(x)).
@generated function hoderivatives(fn::F, x, ::Val{O}) where {F,O}
  quote
    dx  = fwd_ho(x, Val($O))
    fdx = fn(dx)
    tmp = fdx
    # Peel off one layer per step; step j yields the (O - j + 1)-th derivative.
    dxs = Base.Cartesian.@ntuple $O j -> begin
      layer = get_partial(tmp, Val($O - j + 1))
      tmp   = tmp.value
      layer
    end
    reverse((dxs..., tmp))
  end
end
```
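To make the layering concrete, this is what `fwd_ho(x, Val(2))` amounts to when written out by hand, with `sin` and `x = 0.3` chosen only for illustration. The value of the value is f(x), each single-layer partial is f'(x), and following the first partial down both layers picks out the mixed ε₁ε₂ coefficient, which is f''(x). It uses only the `.value` / `.partials[1]` field access that the setup above already relies on.

```julia
using ForwardDiff

x = 0.3

# Same structure fwd_ho(x, Val(2)) builds: the inner layer seeds ε₁,
# the outer layer seeds ε₂ (both untagged).
d2 = ForwardDiff.Dual(ForwardDiff.Dual(x, (1.0,)), (1.0,))

y = sin(d2)

f0       = y.value.value            # sin(x)
f1       = y.value.partials[1]      # cos(x), the ε₁ coefficient
f1_outer = y.partials[1].value      # cos(x) again, the ε₂ coefficient
f2       = y.partials[1].partials[1] # -sin(x), the ε₁ε₂ coefficient
```

`hoderivatives` generalizes exactly this: the value chain gives f, and peeling one `.value` layer at a time before following the partials gives each successive derivative.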

It works great for my use case, which asks for 5 or 6 derivatives of a function; TaylorDiff.jl only starts to beat it at around order 10. It also doesn’t allocate, which is important. But if I pass in a function that itself calls `ForwardDiff.derivative`, I get perturbation confusion errors. Here is an example:

```julia
# this is how I format my functions. In my setting, I am always going to
# need to work with partial derivatives of fun with respect to parameters.
fun(x, params) = params[1]*exp(-params[2]*abs(x)) # for example

# these will normally be some kind of closure created in a closed scope,
# but for this example I just create them as regular functions.
fun(x)      = fun(x, (1.0, 1.0))
dfun_dp2(x) = ForwardDiff.derivative(p2 -> fun(x, (1.0, p2)), 1.0)

# this works, and makes no allocations when used correctly.
@time derivs1 = hoderivatives(fun, 0.01, Val(5))

# this doesn't work and causes perturbation confusion:
derivs2 = hoderivatives(dfun_dp2, 0.01, Val(5))
```
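For anyone testing a fix, these are the values `derivs2` should contain. For x > 0 and `params = (1.0, 1.0)`, `dfun_dp2(x) = -x*exp(-x)`, so its k-th derivative is (-1)^(k+1) (x - k) e^(-x). The helper `dfun_dp2_closed` is hypothetical, written only for this check; orders 0 and 1 are spot-checked against `dfun_dp2` itself and a plain nested `ForwardDiff.derivative` call (which uses ForwardDiff's own tags rather than the untagged Duals from `fwd_ho`).

```julia
using ForwardDiff

fun(x, params) = params[1]*exp(-params[2]*abs(x))
dfun_dp2(x) = ForwardDiff.derivative(p2 -> fun(x, (1.0, p2)), 1.0)

# Hypothetical closed-form reference, valid for x > 0 only:
# dfun_dp2(x) = -x*exp(-x), whose k-th derivative is (-1)^(k+1) * (x - k) * exp(-x).
dfun_dp2_closed(x, k) = (-1)^(k + 1) * (x - k) * exp(-x)

# Target for hoderivatives(dfun_dp2, 0.01, Val(5)): orders 0 through 5.
expected = ntuple(i -> dfun_dp2_closed(0.01, i - 1), 6)

# Orders 0 and 1 computed with plain (tagged) ForwardDiff calls.
order0 = dfun_dp2(0.01)
order1 = ForwardDiff.derivative(dfun_dp2, 0.01)
```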

I have tried adding a variety of manual tags to the `Dual` types created by `fwd_ho`, but nothing seems to help. If anybody has thoughts on how to resolve this issue, I’d appreciate hearing them!

Also, I should say: I am aware of TaylorDiff.jl and TaylorSeries.jl and so on, but I would like to stick with ForwardDiff.jl and get this code working.