Type stability for higher derivatives in ForwardDiff

I'm trying to compute all third-order partial derivatives of a function with ForwardDiff, but I can't get the computation to be type-stable, and I don't understand why. Here is a minimal example:

```julia
using ForwardDiff
using StaticArrays

function tf(x)
    return exp(0.1 * x[1] + 0.2 * x[2])
end

function get_derivatives(f)

    ∇f(x) = ForwardDiff.gradient(f, x)     # first derivatives: length-N vector
    ∇²f(x) = ForwardDiff.hessian(f, x)     # second derivatives: N×N matrix
    ∇³f(x) = ForwardDiff.jacobian(∇²f, x)  # Jacobian of the Hessian: (N*N)×N matrix

    return ∇f, ∇²f, ∇³f

end

∇tf, ∇²tf, ∇³tf = get_derivatives(tf)

@code_warntype ∇³tf(@SVector [1.0, 1.0])
```

The output of @code_warntype is:

```
Variables
  #self#::Core.Const(var"#∇³f#329"{var"#∇²f#328"{typeof(tf)}}(var"#∇²f#328"{typeof(tf)}(tf)))
  x::SVector{2, Float64}

Body::Any
1 ─ %1 = ForwardDiff.jacobian::Core.Const(ForwardDiff.jacobian)
│   %2 = Core.getfield(#self#, :∇²f)::Core.Const(var"#∇²f#328"{typeof(tf)}(tf))
│   %3 = (%1)(%2, x)::Any
└──      return %3
```
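The `Body::Any` line means inference could not work out what `ForwardDiff.jacobian(∇²f, x)` returns. The sketch below (redefining the same `tf` and `get_derivatives` so it runs on its own) checks the same thing programmatically with `Base.return_types`, and compares it with `third_asserted`, a hypothetical wrapper that adds a return-type assertion. To be clear about what this does and does not do: the assertion does not fix inference inside ForwardDiff, it only guarantees callers a concrete type (an `SMatrix{4, 2, Float64, 8}` for a length-2 input). It also assumes the result has the same element type as `x`, which holds for `Float64` inputs but not, for example, for `Int` inputs.

```julia
using ForwardDiff, StaticArrays

tf(x) = exp(0.1 * x[1] + 0.2 * x[2])

function get_derivatives(f)
    ∇f(x) = ForwardDiff.gradient(f, x)
    ∇²f(x) = ForwardDiff.hessian(f, x)
    ∇³f(x) = ForwardDiff.jacobian(∇²f, x)
    return ∇f, ∇²f, ∇³f
end

# Hypothetical wrapper: same third derivative, plus a return-type assertion.
# For x::SVector{N,T}, the Jacobian of the N×N Hessian is an (N*N)×N static matrix.
# `∇²f::F` forces specialization on the function argument.
function third_asserted(∇²f::F, x::SVector{N,T}) where {F,N,T}
    return ForwardDiff.jacobian(∇²f, x)::SMatrix{N * N, N, T, N * N * N}
end

∇tf, ∇²tf, ∇³tf = get_derivatives(tf)
x0 = @SVector [1.0, 1.0]
argtypes = (typeof(x0),)

# Base.return_types reports what inference concludes for the given argument types:
# a concrete type means the call is inferred, Any means it is not.
for (name, g) in (("gradient", ∇tf), ("hessian", ∇²tf), ("third", ∇³tf))
    println(rpad(name, 9), Base.return_types(g, argtypes))
end
println(rpad("asserted", 9), Base.return_types(third_asserted, (typeof(∇²tf), typeof(x0))))

# 4×2 result; row i + 2(j-1), column k holds ∂³f/∂x_i∂x_j∂x_k
D3 = third_asserted(∇²tf, x0)
```

The row layout follows from `jacobian` flattening the matrix-valued `∇²f` in column-major order, so `reshape(Array(D3), 2, 2, 2)` recovers the full tensor `T[i, j, k]`.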

The gradient and Hessian are both type-stable. While searching for related topics I read about the ForwardDiff chunk size, so I also tried the following, which passes explicit configs with a chunk size of 1:

```julia
function get_derivatives_2(f)

    x0 = SVector{2}([1.0, 1.0])  # sample point, used only to build the configs

    # Gradient with an explicit chunk size of 1
    cfg_grad = ForwardDiff.GradientConfig(f, x0, ForwardDiff.Chunk{1}())
    ∇f(x) = ForwardDiff.gradient(f, x, cfg_grad)

    # Hessian as the Jacobian of the gradient
    cfg_jac_1 = ForwardDiff.JacobianConfig(∇f, x0, ForwardDiff.Chunk{1}())
    ∇²f(x) = ForwardDiff.jacobian(∇f, x, cfg_jac_1)

    # Third derivatives as the Jacobian of the Hessian
    cfg_jac_2 = ForwardDiff.JacobianConfig(∇²f, x0, ForwardDiff.Chunk{1}())
    ∇³f(x) = ForwardDiff.jacobian(∇²f, x, cfg_jac_2)

    return ∇f, ∇²f, ∇³f

end
```

However, this gave the same result: the gradient and Hessian are type-stable, but the third derivative is not.
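One plausible explanation, not confirmed here, is that in both versions `∇³f` ends up calling `ForwardDiff.jacobian` from inside another `ForwardDiff.jacobian` call. In `get_derivatives_2` this happens explicitly, and as far as I can tell `hessian` on a `StaticArray` is itself implemented on top of `jacobian`. Each level adds a more deeply nested `Dual` type, and Julia's inference may widen such growing recursive signatures to `Any`. The chunk size does not change that structure. Under that assumption, the sketch below avoids nesting `jacobian` inside `jacobian`: it builds each slice ∂H/∂x_k as a scalar `ForwardDiff.derivative` of the Hessian along the k-th unit vector. `unitvec`, `hessian_slice` and `third_derivatives` are hypothetical helper names, and whether the result is actually inferred on Julia 1.6.2 / ForwardDiff 0.10.19 is untested; check it with `@code_warntype third_derivatives(tf, x0)`.

```julia
using ForwardDiff, StaticArrays

tf(x) = exp(0.1 * x[1] + 0.2 * x[2])

# Hypothetical helper: the k-th unit vector with the same size and eltype as x.
unitvec(::SVector{N,T}, k) where {N,T} =
    SVector{N,T}(ntuple(i -> i == k ? one(T) : zero(T), Val(N)))

# Hypothetical helper: ∂H/∂x_k, computed as d/dt hessian(f, x + t*e_k) at t = 0.
# `f::F` forces specialization on the function argument.
function hessian_slice(f::F, x::SVector, k) where {F}
    e = unitvec(x, k)
    return ForwardDiff.derivative(t -> ForwardDiff.hessian(f, x + t * e), zero(eltype(x)))
end

# Hypothetical helper: stack vec(∂H/∂x_k) as columns, giving the same (N*N)×N
# layout that ForwardDiff.jacobian(∇²f, x) produces.
function third_derivatives(f::F, x::SVector{N}) where {F,N}
    return hcat(ntuple(k -> vec(hessian_slice(f, x, k)), Val(N))...)
end

x0 = @SVector [1.0, 1.0]
D3 = third_derivatives(tf, x0)   # 4×2 static matrix
```

This trades one Hessian evaluation with N-partial outer duals for N Hessian evaluations with single-partial outer duals, which makes little difference for N = 2.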

Julia version: 1.6.2
ForwardDiff version: v0.10.19

Any help with this would be much appreciated.


Yes, I ran into the same problem, but in my case when needing the primal, first, and second derivatives. From what I understand, when using StaticArrays the chunk size is always the length of the SArray.
