Is there an efficient way to compute the Hessian of a NN?

One way to compute the Hessian that usually works pretty efficiently is to apply the reverse-on-forward trick, as outlined in

If one replaces the floating-point type with a dual number type, such as the one described at
http://www.juliadiff.org/ForwardDiff.jl/latest/dev/how_it_works.html#Dual-Number-Implementation-1

then the usual backpropagation returns dual numbers: the value part of each output holds the gradient, and the partial part holds the Hessian-vector product. The "vector" in that product is the list of partials with which your input variables were initialized.

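For example, to compute the first diagonal element of the Hessian, you would initialize the partial of the first element of x to 1.0 and all other partials to 0.0. Backpropagation then leaves the first column of the Hessian in the partials of the gradient, and the first entry of that column is the diagonal element you want.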
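
To make the idea concrete, here is a minimal sketch in plain Python rather than Julia, so that it has no dependencies. Everything in it is an assumption made for illustration: the `Dual` class is a hand-rolled stand-in for a library dual number type, the toy function f(x0, x1) = x0^2 * x1 + x1^3 stands in for a network, and `grad_f` is a hand-written reverse pass for it. The point is only that the same backward pass, fed dual numbers, returns the gradient in the values and a Hessian-vector product in the partials.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """A number of the form value + partial * eps, with eps**2 == 0."""
    value: float
    partial: float = 0.0

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.value + other.value, self.partial + other.partial)

    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule: (a + a'eps)(b + b'eps) = ab + (a b' + a' b) eps
        return Dual(self.value * other.value,
                    self.value * other.partial + self.partial * other.value)

    __rmul__ = __mul__


def _lift(x):
    """Wrap a plain number as a Dual with zero partial."""
    return x if isinstance(x, Dual) else Dual(float(x))


def grad_f(x0, x1):
    """Hand-written reverse pass for f = x0*x0*x1 + x1**3.

    Written with generic arithmetic only, so it accepts floats or Duals.
    """
    # Forward pass: f = b + c, with a = x0*x0, b = a*x1, c = x1**3.
    a = x0 * x0
    # Backward pass, starting from df/df = 1.
    df_db = 1.0
    df_dc = 1.0
    df_da = df_db * x1
    df_dx0 = df_da * (2.0 * x0)
    df_dx1 = df_db * a + df_dc * (3.0 * x1 * x1)
    return df_dx0, df_dx1


def hessian_vector_product(x, v):
    """Return (gradient, H @ v) by seeding the input partials with v."""
    duals = [Dual(xi, vi) for xi, vi in zip(x, v)]
    g = grad_f(*duals)
    return [gi.value for gi in g], [gi.partial for gi in g]


def hessian(x):
    """Full Hessian as a list of columns, one backward pass per column."""
    n = len(x)
    columns = []
    for i in range(n):
        unit = [1.0 if j == i else 0.0 for j in range(n)]
        columns.append(hessian_vector_product(x, unit)[1])
    return columns
```

Calling `hessian_vector_product([2.0, 3.0], [1.0, 0.0])` seeds only the first partial, so the returned partials are the first column of the Hessian at that point. The full Hessian costs one backward pass per input, which is why this approach is most attractive when only a few Hessian-vector products are needed.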
