Fast f(x), ∂f/∂x and ∂²f/∂x²

Hi, I'm trying to evaluate the value, first derivative and second derivative of a scalar function quickly, in as few forward passes as possible. At the moment I'm using this function:

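using DiffResults, StaticArrays, ForwardDiff
function f∂f∂2f(f,x::T) where T
    # wrap the scalar function so it accepts a length-1 static vector
    _f(z) = f(only(z))
    x_vec =   SVector(x)
    # result container holding value, gradient and Hessian
    ∂result = DiffResults.HessianResult(x_vec)  
    _∂f =  ForwardDiff.hessian!(∂result, _f,x_vec)
    # unpack the 1-element gradient and 1×1 Hessian into scalars
    fx =  DiffResults.value(_∂f)
    ∂f∂x = only(DiffResults.gradient(_∂f))
    ∂²f∂²x =  only(DiffResults.hessian(_∂f))
    return fx,∂f∂x,∂²f∂²x
end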

But even though everything seems to be statically sized, this function allocates:

julia> @btime f∂f∂2f(sin,$Ref(2.0)[])
  78.189 ns (1 allocation: 96 bytes)
(0.9092974268256817, -0.4161468365471424, -0.9092974268256817)

The same version, but computing only the first derivative, doesn't allocate:

function f∂f(f,x::T) where T
    _f(z) = f(only(z))
    x_vec =   SVector(x)
    ∂result = DiffResults.GradientResult(x_vec)  
    _∂f =  ForwardDiff.gradient!(∂result, _f,x_vec)
    fx =  DiffResults.value(_∂f)
    ∂f∂x = only(DiffResults.gradient(_∂f))
    return fx,∂f∂x
end

julia> @btime f∂f(sin,$Ref(2.0)[])
  16.116 ns (0 allocations: 0 bytes)
(0.9092974268256817, -0.4161468365471424)

Any idea how to avoid the allocation in the second-derivative version?


This is experimental now, but maybe you want to try another AD?

using Diffractor
using BenchmarkTools

function fdfd2f(f, x)
    let var"'" = Diffractor.PrimeDerivativeFwd
        fx = f(x)
        dfx = f'(x)
        d2fx = f''(x)
        return fx, dfx, d2fx
    end
end

julia> @btime fdfd2f(sin, $Ref(2.0)[])
  33.272 ns (0 allocations: 0 bytes)
(0.9092974268256817, -0.4161468365471424, -0.9092974268256817)

Interesting, but I'm stuck with ForwardDiff for now, as I have to differentiate through functions that use mutation, and I don't have rrules defined for those.

I believe Diffractor can deal with mutation.

Not in reverse mode, but I think it can in forward mode.

How can I ask Diffractor to do forward-over-forward in that case?
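
For comparison, forward-over-forward can also be written by hand while staying within ForwardDiff, by seeding a dual number whose value is itself a dual number and calling `f` once. The following is only a minimal sketch under some assumptions: `f` is a scalar function that accepts any `Real`, the name `f_df_d2f` is made up for this example, and `ForwardDiff.Tag`, `ForwardDiff.Dual`, `ForwardDiff.value` and `ForwardDiff.partials` are ForwardDiff internals rather than documented public API, so they may change between versions. It has not been benchmarked here, so whether it is allocation-free would need to be checked with `@btime`.

```julia
using ForwardDiff
using ForwardDiff: Dual, Tag, value, partials

# Hypothetical helper (not part of ForwardDiff): value, first and second
# derivative of a scalar function from a single call to f on nested duals.
function f_df_d2f(f::F, x::R) where {F,R<:Real}
    T = typeof(Tag(f, R))
    inner = Dual{T}(x, one(x))          # x + ε₁
    outer = Dual{T}(inner, one(inner))  # (x + ε₁) + ε₂
    out = f(outer)                      # single evaluation of f
    fx   = value(value(out))                # f(x)
    dfx  = partials(value(out), 1)          # ∂f/∂x
    d2fx = partials(partials(out, 1), 1)    # ∂²f/∂x²
    return fx, dfx, d2fx
end
```

Calling `f_df_d2f(sin, 2.0)` should return the same three numbers as the `hessian!`-based version above. Both levels use the same tag here, which assumes `f` does no differentiation of its own internally.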