But I always get into trouble when I try to implement the same operation for a neural net:
using Flux, Zygote

g = Dense(2, 1, softplus)  # the error below shows the layer was built with softplus
hess = x -> Zygote.hessian.(x -> sum(g(x)), x)
hess(rand(2, 3))
Error:
MethodError: no method matching (::Dense{typeof(softplus),Array{Float32,2},Array{Float32,1}})(::ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#74#75"{var"#206#208"},Float64},Float64,1})
Closest candidates are:
Any(::AbstractArray{T,N} where N) where {T<:Union{Float32, Float64}, W<:(AbstractArray{T,N} where N)} at /home/solar/.julia/packages/Flux/goUGu/src/layers/basic.jl:134
Any(::AbstractArray{var"#s127",N} where N where var"#s127"<:AbstractFloat) where {T<:Union{Float32, Float64}, W<:(AbstractArray{T,N} where N)} at /home/solar/.julia/packages/Flux/goUGu/src/layers/basic.jl:137
Any(::AbstractArray) at /home/solar/.julia/packages/Flux/goUGu/src/layers/basic.jl:121
Is it possible to get the Hessian of a neural network and keep its dimensions, i.e. (2, 2, 3) here, or (2, 2, m) in the general case?
In the first example, do you mean to take the Hessian w.r.t. each column of the matrix? The Hessian of a function from $\mathbb{R}^2 \to \mathbb{R}$ lives in $\mathbb{R}^{2 \times 2}$. I think you perhaps meant to do this:
julia> hess = x -> Zygote.hessian(x -> sum(x.^3), x)  # note that it's without broadcasting
julia> hess.(eachcol(rand(2, 3)))
3-element Array{Array{Float64,2},1}:
[1.6979767651037765 0.0; 0.0 3.4428948492135]
[4.9107531966282725 0.0; 0.0 0.1841352060564092]
[1.7349169169298424 0.0; 0.0 2.680030582634431]
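(Sanity check: for $f(x) = \sum_i x_i^3$ the Hessian is diagonal, $\partial^2 f / \partial x_i \partial x_j = 6 x_i \delta_{ij}$, which is why the off-diagonal entries above are exactly zero.)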
Otherwise, your code was broadcasting `Zygote.hessian` over every element of the matrix, i.e. computing the second derivative of a scalar function at 6 separate points. That is also why the `Dense` version throws a `MethodError`: the layer only accepts arrays, not the scalar `ForwardDiff.Dual`s the broadcast feeds it. Similarly:
julia> g = Dense(2, 1, softplus)
julia> hess = x -> Zygote.hessian(x->sum(g(x)), x) # again without broadcasting
julia> hess.(eachcol(rand(2, 3)))
3-element Array{Array{Float64,2},1}:
[0.16390895842649078 -0.0006955633608079606; -0.0006955633608079607 2.9516897279012466e-6]
[0.15883581950267744 -0.0006740350099870638; -0.0006740350099870637 2.8603321096637323e-6]
[0.14517203907327184 -0.0006160514493076655; -0.0006160514493076656 2.614273317484523e-6]
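If you specifically want the (2, 2, m) array shape from your question, one option (a sketch, reusing the `hess` defined above) is to concatenate the per-column 2×2 Hessians along a third dimension:

julia> hs = hess.(eachcol(rand(2, 3)));

julia> H = cat(hs..., dims=3);  # stack the 2×2 Hessians into a 2×2×3 array

julia> size(H)
(2, 2, 3)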