Hessian inside a Flux loss function

MWE (or as minimal as I can think to make it):

using Flux
import ForwardDiff, Zygote

g = Dense(2, 1, softplus)
loss(m, x) = sum(abs2, Zygote.hessian(x->sum(m(x)), x))
gradient(g, rand(2)) do m, x
   loss(m, x)
end

Both Zygote.hessian and ForwardDiff.hessian throw the following error:

ERROR: LoadError: Mutating arrays is not supported
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.var"#399#400")(::Nothing) at /home/arun/.julia/packages/Zygote/IsBxF/src/lib/array.jl:58
 [3] (::Zygote.var"#2265#back#401"{Zygote.var"#399#400"})(::Nothing) at /home/arun/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [4] forward_jacobian at /home/arun/.julia/packages/Zygote/IsBxF/src/lib/forward.jl:25 [inlined]
 [5] (::typeof(∂(forward_jacobian)))(::Tuple{Nothing,Array{Float64,2}}) at /home/arun/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [6] forward_jacobian at /home/arun/.julia/packages/Zygote/IsBxF/src/lib/forward.jl:38 [inlined]
 [7] (::typeof(∂(forward_jacobian)))(::Tuple{Nothing,Array{Float64,2}}) at /home/arun/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [8] hessian_dual at /home/arun/.julia/packages/Zygote/IsBxF/src/lib/grad.jl:74 [inlined]
 [9] (::typeof(∂(hessian_dual)))(::Array{Float64,2}) at /home/arun/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [10] hessian at /home/arun/.julia/packages/Zygote/IsBxF/src/lib/grad.jl:72 [inlined]
 [11] loss at /tmp/test.jl:5 [inlined]
 [12] (::typeof(∂(loss)))(::Float64) at /home/arun/.julia/packages/Zygote/IsBxF/src/compiler/interface2.jl:0
 [13] #67 at /tmp/test.jl:7 [inlined]
 [14] (::Zygote.var"#41#42"{typeof(∂(#67))})(::Float64) at /home/arun/.julia/packages/Zygote/IsBxF/src/compiler/interface.jl:41
 [15] gradient(::Function, ::Dense{typeof(softplus),Array{Float32,2},Array{Float32,1}}, ::Vararg{Any,N} where N) at /home/arun/.julia/packages/Zygote/IsBxF/src/compiler/interface.jl:59
 [16] top-level scope at /tmp/test.jl:6
 [17] include(::String) at ./client.jl:457
 [18] top-level scope at REPL[9]:1
in expression starting at /tmp/test.jl:6

Any ideas how I can go about solving this issue? Thanks!

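Zygote.hessian is forward-over-reverse: it takes a gradient with Zygote and then differentiates that with ForwardDiff (through Zygote's internal forward_jacobian). Wrapping it in an outer Zygote gradient means Zygote has to differentiate forward_jacobian itself, which mutates arrays; that is where the "Mutating arrays is not supported" error in your trace comes from.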

Zygote.hessian_reverse uses Zygote twice. You can try taking the 3rd derivative using ForwardDiff on the result of that, although whether this can work beyond toy examples you’ll have to see:

julia> g = x -> exp.(ones(2,2) * x);

julia> loss(m, x) = sum(abs2, Zygote.hessian_reverse(x->sum(m(x)), x));

julia> ForwardDiff.gradient(rand(2)) do x
         loss(g, x)
       end
2-element Vector{Float64}:

Thanks! This does work!

But how can I use ForwardDiff.gradient with respect to the network's parameters, so that I can train it? That is, I can't seem to do something like this (which Flux would normally allow with Zygote):

gs = ForwardDiff.gradient(params(g)) do
   loss(g, x)
end

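Just realized that we can use Flux.destructure: it returns a flat vector of the network's parameters together with a function that rebuilds the network from such a vector, so the loss becomes an ordinary function of a parameter vector, which is what ForwardDiff.gradient takes.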
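A minimal sketch of the destructure idea, assuming Flux 0.13 or later (for the `Dense(2 => 1, ...)` pair syntax). The names `plain_loss`, `x`, and the 0.1 step size are hypothetical and only for illustration. A simple sum-of-squares loss is used so the example only exercises the destructure/ForwardDiff step; the Hessian-based `loss` from the reply above could be substituted, with the same caveat about whether the third-order nesting works beyond toy examples.

```julia
using Flux
import ForwardDiff

# Assumes Flux >= 0.13; `plain_loss`, `x` and the step size are hypothetical.
g = Dense(2 => 1, softplus)
x = rand(Float32, 2)

# θ is a flat Vector of all parameters (weight, then bias);
# re(θ) rebuilds a Dense layer holding those values.
θ, re = Flux.destructure(g)

# Loss written as a function of the parameter vector alone,
# which is the form ForwardDiff.gradient expects.
plain_loss(m, x) = sum(abs2, m(x))
gs = ForwardDiff.gradient(p -> plain_loss(re(p), x), θ)

# One plain gradient-descent step on the flat vector, then rebuild the model.
θ_new = θ .- 0.1f0 .* gs
g_new = re(θ_new)
```

The point of this design is that ForwardDiff only ever sees an ordinary `Vector`, so none of Zygote's implicit `params` machinery is involved; training amounts to updating the flat vector and calling `re` on it.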