Allocations with nested Jacobians with StaticArrays and ForwardDiff

I am working on a code where we compute nested Jacobians with StaticArrays; however, I cannot manage to do it without allocations, which seem to come from some internal type instability.

The MWE below shows the allocations, along with a weird behavior I have also seen while debugging our larger code: the type instability and the allocations disappear if, after benchmarking (or simply calling) my function, I redefine J_flat:

using ForwardDiff, StaticArrays, Chairmarks

@inline f(x) = SVector(x[1]^2 * x[2], sin(x[1]) + x[2]^3)

function J_flat(x) 
    y = ForwardDiff.jacobian(f, x)
    SA[y[1], y[2]]
end

function nested_jacobian(x0::SVector{2, T}) where T
    ForwardDiff.jacobian(J_flat, x0)
end

x0 = SVector(1.0, 2.0)
@code_warntype nested_jacobian(x0) # shows a type instability
@b nested_jacobian($x0) # 131.447 ns (4 allocs: 176 bytes)

# Redefine J_flat with an identical body, after nested_jacobian has been called
function J_flat(x)
    y = ForwardDiff.jacobian(f, x)
    SA[y[1], y[2]]
end

@code_warntype nested_jacobian(x0) # type instability is gone
@b nested_jacobian($x0) # 26.594 ns

Redefining J_flat without first benchmarking or calling nested_jacobian does not result in an allocation-free call:

using ForwardDiff, StaticArrays, Chairmarks

@inline f(x) = SVector(x[1]^2 * x[2], sin(x[1]) + x[2]^3)

function J_flat(x) 
    y = ForwardDiff.jacobian(f, x)
    SA[y[1], y[2]]
end

function nested_jacobian(x0::SVector{2, T}) where T
    ForwardDiff.jacobian(J_flat, x0)
end

x0 = SVector(1.0, 2.0)

# Redefine J_flat with an identical body, before nested_jacobian has ever been called
function J_flat(x)
    y = ForwardDiff.jacobian(f, x)
    SA[y[1], y[2]]
end

@code_warntype nested_jacobian(x0) # shows a type instability
@b nested_jacobian($x0) # 129.829 ns (4 allocs: 176 bytes)

Is there any workaround to avoid these allocations?

I think there have been a few threads on this over the years, but the only one I can find at the moment is my own from quite a while ago now: Type stability for higher derivatives in ForwardDiff

As far as I’m aware, there is currently no easy workaround for this. However, if some manual intervention is feasible in your case, then you can do something like this:

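using ForwardDiff, StaticArrays

f(x) = SVector(x[1]^2 * x[2], sin(x[1]) + x[2]^3)

function J_flat(x)
    # Differentiate each component of f with respect to x[1] only, holding x[2] fixed.
    # These are entries [1] and [2] (the first column) of the full Jacobian of f.
    y1 = ForwardDiff.derivative(x1 -> f(SA[x1, x[2]])[1], x[1])
    y2 = ForwardDiff.derivative(x1 -> f(SA[x1, x[2]])[2], x[1])

    SA[y1, y2]
end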

function nested_jacobian(x0::SVector{2, T}) where T
    ForwardDiff.jacobian(J_flat, x0)
end

x0 = SVector(1.0, 2.0)
@code_warntype nested_jacobian(x0) # Is now type stable

Edit: Forgot to get first index of the output of f
Edit 2: I think I had the wrong indices. I guess this showcases why doing this manually is less than ideal.
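
Since the indices are easy to get wrong, here is a minimal sanity check that compares the manual version against the original jacobian-based one. The names J_flat_ref and J_flat_manual are just labels I picked for this sketch, and the hand-derived matrix assumes the same f and x0 as in the MWE. It only checks values, not type stability or allocations.

```julia
using ForwardDiff, StaticArrays

f(x) = SVector(x[1]^2 * x[2], sin(x[1]) + x[2]^3)

# Original formulation: linear indices 1 and 2 of the 2x2 Jacobian,
# i.e. its first column (StaticArrays matrices are column-major).
function J_flat_ref(x)
    y = ForwardDiff.jacobian(f, x)
    SA[y[1], y[2]]
end

# Manual formulation: derivative of each component of f with respect to x[1].
function J_flat_manual(x)
    y1 = ForwardDiff.derivative(x1 -> f(SA[x1, x[2]])[1], x[1])
    y2 = ForwardDiff.derivative(x1 -> f(SA[x1, x[2]])[2], x[1])
    SA[y1, y2]
end

x0 = SVector(1.0, 2.0)

# Both formulations should agree on the flattened first column...
@assert J_flat_ref(x0) ≈ J_flat_manual(x0)

# ...and on the nested Jacobian. Derived by hand:
# d/dx of [2*x1*x2, cos(x1)] = [2*x2  2*x1; -sin(x1)  0]
J_expected = SA[4.0 2.0; -sin(1.0) 0.0]
@assert ForwardDiff.jacobian(J_flat_manual, x0) ≈ J_expected
@assert ForwardDiff.jacobian(J_flat_ref, x0) ≈ J_expected
```

If the indices in the manual version were swapped or pointed at the wrong input, the first assertion would fail.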