Is there an efficient way to compute the Hessian of a NN?

Flux.jacobian(::Chain, ::AbstractArray) gives the full Jacobian of a neural network. Flux.hessian, however, throws the error “output is not scalar”.

Let m = Chain(Dense(5,3,elu), Dense(3,3)). What I can do is compute the gradient of each value in the Jacobian matrix with respect to each input using Flux.gradient(x -> Flux.jacobian(m, x)[i], input), where i is the index of the Jacobian entry we want to differentiate and input is the input vector fed to the neural network (i.e. the point at which we want to differentiate). The Jacobian is 3×5, which means that to cover every i I have to recompute the same Jacobian 15 times. That’s not efficient.

Moreover, I’m in fact only interested in what I would call the “second-order Jacobian”, not the full Hessian: that is, the diagonal of each output’s Hessian, d²y/dx². So computing the full Hessian is wasteful anyway. In other words, what I would like is to compute the second-order derivative of each output of a neural network with respect to each input.
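
For concreteness, here is a minimal finite-difference sketch of what I mean (the helper name and the toy f are made up for illustration; for the real problem, f would be the network m):

# Finite-difference sketch of the "second-order Jacobian":
# entry (j, i) approximates d²y_j / dx_i² for y = f(x).
function second_order_jacobian_fd(f, x; h=1e-4)
    y = f(x)
    out = zeros(length(y), length(x))
    for i in eachindex(x)
        e = zeros(length(x)); e[i] = h
        # central second difference along coordinate i
        out[:, i] = (f(x .+ e) .- 2 .* y .+ f(x .- e)) ./ h^2
    end
    return out
end

f(x) = [sum(abs2, x), prod(x), x[1]^3]   # toy stand-in for the network
second_order_jacobian_fd(f, rand(5))     # 3×5: one column of d²y_j/dx_i² per input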

3 Likes

One way to compute the Hessian that usually works pretty efficiently is to apply the reverse-on-forward trick, as outlined in the AD survey by Baydin et al.

If one replaces the floating-point type with a dual number type, such as the one described here:
http://www.juliadiff.org/ForwardDiff.jl/latest/dev/how_it_works.html#Dual-Number-Implementation-1

then the usual backpropagation will, in addition to giving the gradient in the value of the dual-number output, encode the Hessian-vector product in the dual-number partials; the “vector” in the Hessian-vector product is the list of partials with which your variables were initialized.

To compute the first diagonal element of the Hessian, you would initialize the partial of the first element of x to 1.0 and all other partials to 0.0.
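
A minimal sketch of that seeding, with a toy scalar f standing in for your function; I use ForwardDiff’s Dual for the seed and Zygote for the backward pass just to keep the example self-contained (any backpropagation that can run on dual numbers would do):

using ForwardDiff: Dual, value, partials
using Zygote

f(x) = sum(abs2, x) + x[1] * x[2]    # toy scalar function
x = rand(3)
v = [1.0, 0.0, 0.0]                  # partial of x[1] is 1.0, all other partials 0.0
xdual = Dual.(x, v)                  # attach the seed direction as dual-number partials
g = Zygote.gradient(f, xdual)[1]     # ordinary backprop, carried out on dual numbers
value.(g)                            # the gradient of f at x
partials.(g, 1)                      # the Hessian-vector product H(x) * v
partials(g[1], 1)                    # first diagonal element of the Hessian, d²f/dx₁²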

3 Likes

Reverse mode is going to give you columns, and I don’t think you need that. Using double forward mode will be the fastest here. You’ll need to mapchildren to remove the tracker information (or use the Flux#zygote branch) and then just forward diff (or use a hyperdual).

If you do want Forward-over-Reverse for Hessian-vector products, though, it is implemented in SparseDiffTools.jl:

https://github.com/JuliaDiffEq/SparseDiffTools.jl#jacobian-vector-and-hessian-vector-products

but note that our tests don’t show that using Zygote here is the fastest yet :man_shrugging:

2 Likes

Thank you both for your answers, and thank you for this nice paper; I understand AD much better now. I wanted to try both the reverse-on-forward and double-forward methods, but I encounter a stack overflow when I use ForwardDiff methods on a neural network. Only the jacobian method takes an AbstractArray as input and accepts an AbstractArray as output; the others assume one of the two is Real. However:

using Flux, ForwardDiff

f = Chain(Dense(10, 5, elu), Dense(5, 3))
x = rand(10)
ForwardDiff.jacobian(f, x)

This throws a stack overflow. I assume this method does n forward passes, one for each input, but the underlying ForwardDiff.gradient cannot take a function with multi-dimensional output as its argument. Isn’t it strange for a forward-accumulation mode not to be able to do that? According to the paper from Baydin et al. it is quite straightforward. The same goes for ForwardDiff.derivative, which only handles functions with Real inputs.
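
For reference, the input/output shapes each ForwardDiff entry point accepts, as I understand them:

ForwardDiff.derivative(x -> [x^2, sin(x)], 1.0)    # Real input, scalar or array output
ForwardDiff.gradient(x -> sum(abs2, x), rand(3))   # array input, scalar output only
ForwardDiff.jacobian(x -> x .^ 2, rand(3))         # array input, array output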

Am I missing something?

That will stack overflow :man_shrugging: it’s an issue in flux (pun intended).

I opened an issue, then. By the way, double forward mode is indeed exactly what I need: the first pass computes [dy_1/dx_1 … dy_m/dx_1], and differentiating that again with respect to x_1 gives the d²y_j/dx_1² terms.
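
A minimal nested-ForwardDiff sketch of that computation on a plain Julia function (a made-up stand-in for the network, since this doesn’t work on a ::Chain yet):

using ForwardDiff

g(x) = [sum(abs2, x), prod(x), x[1]^3]   # toy stand-in for the network m
x = rand(5)

# d²y_j/dx_i² for every output j and input i: a second directional derivative
# along each coordinate direction, i.e. double forward mode.
function second_order_jacobian(g, x)
    out = zeros(length(g(x)), length(x))
    for i in eachindex(x)
        v = zeros(length(x)); v[i] = 1.0
        out[:, i] = ForwardDiff.derivative(
            t -> ForwardDiff.derivative(s -> g(x .+ s .* v), t), 0.0)
    end
    return out
end

second_order_jacobian(g, x)   # 3×5 matrix of pure second derivatives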

So I’ll go with this mode once I figure out how to use it with a ::Chain. Reverse-on-forward would be best if you need the full Hessian.

It’s hardcoded in Flux not to work (and the stack overflow… it should throw a better error), but kind of for a good reason. We are working on a solution to this right now, for other reasons, so stay tuned. We may start chatting about it in #autodiff.

2 Likes

Ha, okay! Great, thank you, I’ll keep track of the progress then!

1 Like

Did you have any luck, @HenriDeh? I’m facing similar issues.

Hi, no, unfortunately I haven’t heard of any changes regarding this. I know that Flux is getting a complete revamp with Zygote and other things, so this might not be their priority at the moment. The revamp could also fix this in the process, I don’t know. I have put this aspect of my project on hold in the meantime.

1 Like

I had this same issue and resolved it using the sum function:

using Flux

m = Chain(Dense(2, 1))
x = zeros(2)
# Flux.hessian(m, x) throws "ERROR: Function output is not scalar"
# Flux.hessian(v->m(v)[1], x) throws "ERROR: Nested AD not defined for getindex"
Flux.hessian(v->sum(m(v)), x) # works

Tracked 2×2 Array{Float64,2}:
 0.0  0.0
 0.0  0.0

This only works if your network has a single output. With multiple output nodes, summing mixes the outputs together, so you cannot recover the per-output second derivatives this way. (The all-zero result above is expected, since the example network is purely linear.)

2 Likes