Type-Stable Jacobian of a Hessian

I’ve run into a type-stability issue when trying to compute third-order derivatives with ForwardDiff.

I have a potential function that takes a vector as input. Computing the Hessian with ForwardDiff.hessian works well and is entirely type-stable for my use case. However, I also need the Jacobian of the Hessian (i.e. a third-order derivative tensor). Ideally, I would like to construct both a HessianConfig and a JacobianConfig for the Hessian, so that I can reuse them and keep the cost per evaluation low — this part of the code runs very frequently.
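
For context, here is a minimal sketch of the config-reuse pattern for the Hessian on its own. The potential `V` and the input `x` are hypothetical stand-ins for illustration, not the actual code in question:

```julia
using ForwardDiff

# Hypothetical toy potential standing in for the user-supplied one.
V(x) = sum(abs2, x) / 2 + sum(x .^ 4) / 4

x = [0.5, -1.0, 2.0]

# Build the config once; it preallocates the dual-number work buffers.
cfg = ForwardDiff.HessianConfig(V, x)

# Reuse the same config on every evaluation at inputs of this size and type.
H = ForwardDiff.hessian(V, x, cfg)
```

For this toy potential the Hessian is diagonal with entries 1 + 3x_i^2, which makes the result easy to check by hand.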

Unfortunately, it seems that nested ForwardDiff calls of this sort are not supported, or at least I haven’t been able to make them type-stable.

Right now my workaround is to cache Hessian configs keyed by input dual types, which avoids repeated allocations, but the resulting Jacobian-of-Hessian function is still type-unstable.

Has anyone found an alternative approach for computing third derivatives efficiently, or a way to make this pattern type-stable? Or is there something fundamental I’m misunderstanding about how ForwardDiff handles nested AD?

Any suggestions or pointers would be appreciated!

I have a few ideas but before answering: how big is your problem in terms of input dimension? And how complicated is your function?
I'm asking in order to figure out whether ForwardDiff is the right framework at all.

The input dimension can be quite large (it depends on the user), and the code needs to be general so that users can supply whatever potential function they want.

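Solved: Looking at the ForwardDiff source code, this can be done in much the same way as the Hessian computation, by setting up the configs carefully. The JacobianConfig and the HessianConfig share one chunk and one tag, the HessianConfig is built on the JacobianConfig's dual buffer (jac_cfg.duals), and tag checking is turned off with Val{false}().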

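using ForwardDiff

# Build a closure that returns the Jacobian of the Hessian of Vfunc.
# The keyword f is kept from the original signature; it is unused here.
function gen_∂G∂θ_fwd(Vfunc, x; f=identity)
    # Share one chunk size and one tag across every nesting level.
    chunk = ForwardDiff.Chunk(x)
    tag = ForwardDiff.Tag(Vfunc, eltype(x))
    # Outer config: differentiates the Hessian with respect to x.
    jac_cfg = ForwardDiff.JacobianConfig(Vfunc, x, chunk, tag)
    # Inner config: built on the outer dual buffer, so it accepts dual inputs.
    hess_cfg = ForwardDiff.HessianConfig(Vfunc, jac_cfg.duals, chunk, tag)

    d = length(x)
    # Preallocated output; the d×d Hessian is flattened into d^2 rows.
    out = zeros(eltype(x), d^2, d)

    function ∂G∂θ_fwd(y)
        # Val{false}() skips the tag check, which these hand-built configs would not pass.
        hess = z -> ForwardDiff.hessian(Vfunc, z, hess_cfg, Val{false}())
        ForwardDiff.jacobian!(out, hess, y, jac_cfg, Val{false}())
        return out  # reused buffer: copy it if the result must outlive the next call
    end

    return ∂G∂θ_fwd
end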
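
To sanity-check the output layout, it can help to compare against the plain nested call without preallocated configs. The following is only a sketch: the toy potential `V` is a hypothetical example chosen because its third derivative is known in closed form, and the names `T_ref` and `T3` are introduced here for illustration.

```julia
using ForwardDiff

# Hypothetical toy potential with a known third derivative:
# ∂³V/∂x_i∂x_j∂x_k equals 6 when i == j == k, and 0 otherwise.
V(x) = sum(x .^ 3)

x = [1.0, 2.0]
d = length(x)

# Reference result from the plain nested call (no preallocated configs).
T_ref = ForwardDiff.jacobian(z -> ForwardDiff.hessian(V, z), x)

# The d×d Hessian is flattened in column-major order into d^2 rows,
# so the result can be viewed as a d×d×d tensor.
T3 = reshape(T_ref, d, d, d)
```

The closure returned by gen_∂G∂θ_fwd(V, x) should agree with T_ref when evaluated at the same point, which gives a quick correctness check before benchmarking or inspecting type stability with @code_warntype.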