 # Hessian matrix of ML model

Hi.
I would like to compute the Hessian matrix of my Flux model.
I tried something like this without success:

``````
using Flux, Zygote

n = 100
X = rand(2, n)
Y = rand(1, n)
``````
``````
model = Chain(Dense(2,2),Dense(2,1))
loss(x,y) = Flux.mse(model(x),y)
hess = Zygote.hessian(Flux.params(model)) do
    loss(X, Y)
end
``````
``````
ERROR: MethodError: no method matching hessian(::var"#129#130", ::Params)
Closest candidates are:
hessian(::Any, ::AbstractArray) at /user/mpoliti/home/.julia/packages/Zygote/ggM8Z/src/lib/utils.jl:113
Stacktrace:
 top-level scope at REPL:1
``````

Looking at the documentation of `hessian`, it seems to expect:

• a single-argument function, where the argument is an array holding the parameters with respect to which the Hessian should be calculated
• the point at which those parameters should be evaluated
``````
help?> Zygote.hessian
hessian(f, x)

Construct the Hessian ∂²f/∂x², where x is a real number or an array, and f(x) is a real number. When x is an
array, the result is a matrix H[i,j] = ∂²f/∂x[i]∂x[j], using linear indexing x[i] even if the argument is
higher-dimensional.

This uses forward over reverse, ForwardDiff over Zygote, calling hessian_dual(f, x). See hessian_reverse for an
all-Zygote alternative.

Examples
≡≡≡≡≡≡≡≡≡≡

hessian(x -> x[1]*x[2], randn(2))
2×2 Matrix{Float64}:
0.0  1.0
1.0  0.0

hessian(x -> sum(x.^3), [1 2; 3 4])  # uses linear indexing of x
4×4 Matrix{Int64}:
6   0   0   0
0  18   0   0
0   0  12   0
0   0   0  24

hessian(sin, pi/2)
-1.0
``````

So I’m not sure what the solution would be, but it seems you would want to create something along the lines of

``````
x = ...
y = ...
function f(params)
    # Build model based on params
    model = ...
    loss = Flux.mse(model(x), y)
end
initial_params = [...]
Zygote.hessian(f, initial_params)
``````
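To make that sketch concrete, here is a minimal version that avoids Flux layers entirely, so treat it as an illustration rather than the recommended way. The helper `predict` and the packing order of `θ` are my own assumptions. It mirrors `Chain(Dense(2,2), Dense(2,1))` only in that both layers are affine with the default identity activation, and the mean squared error is written out by hand.

```julia
using Zygote

n = 100
X = rand(2, n)
Y = rand(1, n)

# Hypothetical helper: unpack a flat vector of 9 numbers into the weights
# and biases of a 2 -> 2 -> 1 network with identity activations.
function predict(θ, x)
    W1 = reshape(θ[1:4], 2, 2)   # first layer weights
    b1 = θ[5:6]                  # first layer bias
    W2 = reshape(θ[7:8], 1, 2)   # second layer weights
    b2 = θ[9:9]                  # second layer bias
    return W2 * (W1 * x .+ b1) .+ b2
end

# Mean squared error written by hand, as a function of the flat vector only
f(θ) = sum(abs2, predict(θ, X) .- Y) / length(Y)

θ0 = randn(9)
H = Zygote.hessian(f, θ0)   # 9×9 matrix
```

The rows and columns of `H` follow the order in which `θ` is packed (W1, b1, W2, b2).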

Is there a way to compute only the diagonal of the Hessian without computing the whole Hessian matrix?

Maybe something like this?

I don’t think that answer is optimized for the diagonal computation. It just gives a way to compute the full Hessian matrix.

By the way, I cannot make it work.

``````
function f(params)
    # Build model based on params
    params=Flux.params(params)
    loss = Flux.mse(model(x), y)
end

model = Chain(Dense(2,1))
initial_params = [model.W, model.b]
Zygote.hessian(f, initial_params)
``````
``````
ERROR: ArgumentError: Cannot create a dual over scalar type Any. If the type behaves as a scalar, define FowardDiff.can_dual.
Stacktrace:
 throw_cannot_dual(::Type{T} where T) at /user/mpoliti/home/.julia/packages/ForwardDiff/QOqCN/src/dual.jl:36
 ForwardDiff.Dual{Nothing,Any,2}(::Array{Float32,2}, ::ForwardDiff.Partials{2,Any}) at /user/mpoliti/home/.julia/packages/ForwardDiff/QOqCN/src/dual.jl:18
``````

The docs state that `hessian(f, x) ... where x is a real number or an array, ...`, which I interpret to mean that `x` should be an array of scalars, not an array of arrays.
Doing something like

``````
function f2(params)
    # params is [b W]: the first column is the bias, the rest are the weights
    loss = Flux.mse(params[:, 2:end] * x .+ params[:, 1:1], y)
end
initial_params = [model.b model.W]
Zygote.hessian(f2, initial_params)
``````

seems to work for me. I think the problem has something to do with mutating arrays, which Zygote does not like. I have not read into it much, but here is some discussion around the topic.
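For what it's worth, here is a self-contained variant of the same idea that can be checked by hand. It is only a sketch: the data are random, the loss is written out instead of calling `Flux.mse`, and the names `A` and `H_exact` are mine. For a least-squares loss the Hessian does not depend on the parameters, so it can be compared with the closed form `(2/n) * A * A'`, where `A` stacks a row of ones (for the bias) on top of `x`.

```julia
using Zygote

n = 100
x = rand(2, n)
y = rand(1, n)

# params is a 1×3 matrix [b W]: first column is the bias, the rest the weights
f2(params) = sum(abs2, params[:, 2:end] * x .+ params[:, 1:1] .- y) / n

initial_params = rand(1, 3)
H = Zygote.hessian(f2, initial_params)   # 3×3, using linear indexing of params

# Closed-form Hessian of the least-squares loss, for comparison
A = vcat(ones(1, n), x)
H_exact = (2 / n) * A * A'
```

If the two agree up to floating-point error, the parameter packing and the indexing are doing what you expect.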

You may want to try out Zygote#959 which adds a function for computing only the diagonal of the Hessian.

Like the existing function, it won’t understand implicit parameters. Both use ForwardDiff, which needs to change the input, or at least I don’t see an obvious way to make them work with `Params`. You might be able to do something with `Flux.destructure`?
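As a rough illustration of the diagonal-only computation, here is a small sketch. It assumes your installed Zygote provides `Zygote.diaghessian`, which I believe is the function that PR turned into. If your version lacks it, this will not run. The test function is the one from the `hessian` docstring, and the full Hessian is computed only to compare against.

```julia
using Zygote
using LinearAlgebra: diag

f(x) = sum(x .^ 3)
x = [1.0 2.0; 3.0 4.0]

# Returns a tuple with one entry per argument; each entry has the shape of
# that argument and holds ∂²f/∂x[i]², which equals 6 .* x for this f.
d = Zygote.diaghessian(f, x)[1]

# Full 4×4 Hessian, for comparison only
H = Zygote.hessian(f, x)
```

Like `hessian`, this takes explicit array arguments, so the parameters still need to be passed in as plain arrays.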

That seems great, but I am not able to install your branch.

``````
Pkg.add(url="https://github.com/mcabbott/Zygote.jl.git")
``````
``````
ERROR: Unsatisfiable requirements detected for package Adapt [79e6a3ab]:
├─possible versions are: [0.3.0-0.3.1, 0.4.0-0.4.2, 1.0.0-1.0.1, 1.1.0, 2.0.0-2.0.2, 2.1.0, 2.2.0, 2.3.0, 2.4.0, 3.0.0, 3.1.0-3.1.1, 3.2.0, 3.3.0] or uninstalled
├─restricted by compatibility requirements with Tracker [9f7883ad] to versions: [0.3.0-0.3.1, 0.4.0-0.4.2, 1.0.0-1.0.1, 1.1.0, 2.0.0-2.0.2, 2.1.0, 2.2.0, 2.3.0, 2.4.0, 3.0.0, 3.1.0-3.1.1, 3.2.0, 3.3.0]
``````

I think that adding the package without specifying the branch gets you a very old version. Compare:

``````
pkg> activate --temp