Hessian-vector products in Flux with Zygote

Hi, I’m trying to update some of my code that uses Flux models and Hessian-vector products (HVPs). With Tracker as the backend (i.e. in older versions of Flux), this was possible, but I’m unsure how to implement the same functionality with Zygote. Any help would be greatly appreciated!

Below is a minimal example that attempts (unsuccessfully) to implement an HVP. Any ideas how to get this to work? I’ve been banging my head against it!

Thanks a lot in advance.

using Flux
using Zygote
using Random
using LinearAlgebra  # provides `dot`, used in Gvp below

student = Chain(
  Dense(10, 10, relu),
  Dense(10, 10, relu))


teacher = Chain(
  Dense(10, 10, relu),
  Dense(10, 10, relu))

data = randn(10,100)

loss(stu, dat) = Flux.mse(stu(dat), teacher(dat))


gr = Zygote.gradient(params(student)) do
    loss(student,data)
end

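# One random direction per parameter array; `similar` allocates fresh storage,
# so the student's weights are not overwritten in place by `randn!`.
some_v = [randn!(similar(p)) for p in params(student)]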

Gvp = () -> sum([dot(gr[p], some_v[k]) for (k,p) in enumerate(params(student))])

Hvp = Zygote.gradient(params(student)) do
    Gvp()
end

ERROR: Need an adjoint for constructor Base.Iterators.Enumerate{Params}. Gradient is of type Array{Nothing,1}

@DR59 Did you ever resolve this? There’s a similar issue on GitHub (https://github.com/FluxML/Zygote.jl/issues/410) by @colinxs that also doesn’t seem to have been fully resolved, although they posted an initial solution involving RecursiveArrayTools.jl. It would be nice if HVPs were easy to perform in Flux.

Whether it’ll be possible to perform second order differentiation of a model as required for an hvp is highly situational. I’d open/bump an issue thread with your specific use case if you have something that isn’t working.

@ToucheSir Thanks for your response. I’m interested in the precise use case of this original thread:

Given a twice-differentiable network f with parameter vector \theta \in R^p and a scalar loss function L, can we calculate the Hessian-vector product (\nabla_\theta^2 L) v for a given vector v \in R^p without calculating the entire Hessian? This is important in many use cases, such as eigenvalue estimation.
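
For reference, these are the two standard identities that avoid forming the Hessian, assuming L is twice continuously differentiable at \theta and v is held constant (the scalar \varepsilon is notation introduced here):

```latex
\[
(\nabla_\theta^2 L)\, v
  = \nabla_\theta \big( \nabla_\theta L(\theta)^\top v \big)
  = \left. \frac{d}{d\varepsilon}\, \nabla_\theta L(\theta + \varepsilon v) \right|_{\varepsilon = 0}
\]
```

The first form is the gradient of a gradient-vector dot product (reverse-over-reverse, which is what the snippet at the top of the thread attempts). The second is a directional derivative of the gradient (forward-over-reverse).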

You can do this in Zygote pretty easily (https://github.com/FluxML/Zygote.jl/issues/115), but there’s a question of how it can be done within the context of a Flux model.
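
For a plain function of a flat vector, a minimal forward-over-reverse sketch could look like the following. The helper name `hvp` is hypothetical (it is not part of Zygote or Flux), ForwardDiff.jl is assumed to be installed alongside Zygote, and the quadratic test function is only a toy check, not a Flux model.

```julia
using Zygote, ForwardDiff, LinearAlgebra

# Hypothetical helper (not a Zygote/Flux API): Hessian-vector product of a
# scalar-valued f at x in direction v, without building the Hessian.
function hvp(f, x::AbstractVector, v::AbstractVector)
    # Forward-mode derivative (ForwardDiff) of the reverse-mode gradient (Zygote)
    # along the line x + t*v, evaluated at t = 0.
    ForwardDiff.derivative(t -> Zygote.gradient(f, x .+ t .* v)[1], 0.0)
end

# Toy check: for f(x) = x'Ax/2 with symmetric A, the Hessian is A itself.
A = [2.0 1.0; 1.0 3.0]
f(x) = 0.5 * dot(x, A * x)
x = [1.0, -1.0]
v = [0.5, 2.0]

@assert hvp(f, x, v) ≈ A * v
```

This only covers the case where the parameters are already a single vector. Whether the same pattern works through a particular Flux model depends on the layers involved and on how the parameters are flattened.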

The million-dollar question is what f and L do, because not all operations are twice-differentiable. See whether https://github.com/FluxML/Flux.jl/issues/1813 helps; if not, the best way to proceed is to post an MWE with the errors you’re running into.
