Weights are not updated

I can’t update weights
Here is an example

struct hess_nn
    net
    re
    p

    # Define inner constructor method
    function hess_nn(net; p = nothing)
        _p, re = Flux.destructure(net)
        if p === nothing
            p = _p
        end
        return new(net, re, p)
    end
end     

function (nn::hess_nn)(x, p = nn.p)
    net, p = nn.net, nn.p
    hess = x -> Zygote.hessian_reverse(x->sum(net(x)), x)
    out = hess.(eachcol(x))  
    return out
end

# Initialize the struct
g = Dense(2, 1)
nn_net = hess_nn(g)
p = nn_net.p
re = nn_net.re

# Data
x = rand(2,256);
y = rand(256);

# some toy loss function
loss(x, y, p) = sum((y .- mean.(nn_net(x, p))).^2)


# Training 
epochs = 10
opt = ADAM(0.01)

Flux.trainable(net::hess_nn) = (net.p,)
println("loss before training $(loss(x, y, p))")

for epoch in 1:epochs
    gs = ReverseDiff.gradient(p -> loss(x, y, p), p) 
    println("$epoch, $gs")    
    Flux.Optimise.update!(opt, p, gs)
end
println("loss after training $(loss(x, target, p))")

Output of the training block is:

loss before training 85.48787253127696
1, Float32[0.0, 0.0, 0.0]
2, Float32[0.0, 0.0, 0.0]
3, Float32[0.0, 0.0, 0.0]
4, Float32[0.0, 0.0, 0.0]
5, Float32[0.0, 0.0, 0.0]
6, Float32[0.0, 0.0, 0.0]
7, Float32[0.0, 0.0, 0.0]
8, Float32[0.0, 0.0, 0.0]
9, Float32[0.0, 0.0, 0.0]
10, Float32[0.0, 0.0, 0.0]

You’re taking a hessian inside of a gradient context and using ReverseDiff instead of Zygote. Honestly, I’m surprised it runs at all! Can you talk a bit more about what you want to do and why it requires this formulation?

See also Hessian inside a Flux loss function

I want to implement Lagrangian Neural Network. I’ve seen this topic you refered, and I can’t figure it out how to apply this to my task

I think the gradient is zero because you are asking for the gradient with respect to the argument p.
But the argument p is never used in your network.
Changing that input has no effect because
you overwrite the variable p with p=nn.p.
Thus the gradient with respect to your input p is zero.

I could be wrong.
It’s a bit hard to follow the code,bi think there are 3 or 4 different variables being called p here (in different scopes, some nested)
I think most of the time they are equal.
It might be useful to give them different names.

2 Likes

Thanks, your remarks are legit. I’ve edited the names and added re(p)(x) in the network to bind it to the parameters. The zero gradients problem’s still present

struct hess_nn
    model
    re
    params

    # Define inner constructor method
    function hess_nn(model; p = nothing)
        _p, re = Flux.destructure(model)
        if p === nothing
            p = _p
        end
        return new(model, re, p)
    end
end     

function (nn::hess_nn)(x, p = nn.params)
    re = nn.re
    hess = x -> Zygote.hessian_reverse(x->sum(re(p)(x)), x)
    out = hess.(eachcol(x))  
    return out
end

# Initialize the struct
g = Dense(2, 1)
model_net = hess_nn(g)
params = model_net.params
re = model_net.re

# Data
x = rand(2,256);
y = rand(256);

# some toy loss function
loss(x, y, p) = sum((y .- mean.(model_net(x, params))).^2)

# Training 
epochs = 10
opt = ADAM(0.01)

Flux.trainable(nn::hess_nn) = (model_net.params,)
println("loss before training $(loss(x, y, params))")
for epoch in 1:epochs
    gs = ReverseDiff.gradient(p -> loss(x, y, p), params) 
    println("$epoch, $gs")    
    Flux.Optimise.update!(opt, params, gs)
end
println("loss after training $(loss(x, y, params))")

Output:

loss before training 83.36982153699346
1, Float32[0.0, 0.0, 0.0]
2, Float32[0.0, 0.0, 0.0]
3, Float32[0.0, 0.0, 0.0]
4, Float32[0.0, 0.0, 0.0]
5, Float32[0.0, 0.0, 0.0]
6, Float32[0.0, 0.0, 0.0]
7, Float32[0.0, 0.0, 0.0]
8, Float32[0.0, 0.0, 0.0]
9, Float32[0.0, 0.0, 0.0]
10, Float32[0.0, 0.0, 0.0]
loss after training 83.36982153699346

If I use x.^2 as input of my network I have non-zero grads.
It is clear because then I have non-zero second derivative with respect to the inputs.

function (nn::hess_nn)(x, p = nn.params)
    re = nn.re
    #hess = x -> Zygote.hessian_reverse(x->sum(re(p)(x)), x)
    hess = x -> Zygote.hessian_reverse(x->sum(re(p)(x.^2)), x)
    out = hess.(eachcol(x))  
    return out
end

Output

loss before training 23.475644124478368
1, Float32[17.670782, 17.670782, 0.0]
2, Float32[15.110788, 15.110788, 0.0]
3, Float32[12.568958, 12.568958, 0.0]
4, Float32[10.062634, 10.062634, 0.0]
5, Float32[7.614392, 7.614392, 0.0]
6, Float32[5.252881, 5.252881, 0.0]
7, Float32[3.0132563, 3.0132563, 0.0]
8, Float32[0.9364736, 0.9364736, 0.0]
9, Float32[-0.9332614, -0.9332614, 0.0]
10, Float32[-2.5531504, -2.5531504, 0.0]
loss after training 22.31493460335003

But when I want to use non-linear activation g = Dense(2, 1, σ) I get the error:

MethodError: no method matching mul!(::Array{Float32,1}, ::Array{Float32,2}, ::Adjoint{Float64,Array{Float64,1}}, ::Bool, ::Bool)
Closest candidates are:
  mul!(::AbstractArray, ::Number, ::AbstractArray, ::Number, ::Number) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/generic.jl:124
  mul!(::AbstractArray, ::AbstractArray, ::Number, ::Number, ::Number) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/generic.jl:126
  mul!(::StridedArray{T, 1}, ::StridedVecOrMat{T}, ::StridedArray{T, 1}, ::Number, ::Number) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/matmul.jl:66
  ...

Stacktrace:
 [1] mul! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/matmul.jl:208 [inlined]
 [2] reverse_mul!(::ReverseDiff.TrackedArray{Float64,Float32,2,Array{Float64,2},Array{Float32,2}}, ::Array{Float32,2}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::Adjoint{Float64,Array{Float64,1}}, ::Array{Float32,1}, ::Adjoint{Float32,Array{Float32,1}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/derivatives/linalg/arithmetic.jl:282
 [3] special_reverse_exec!(::ReverseDiff.SpecialInstruction{typeof(*),Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},Adjoint{Float64,Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float32,2,Array{Float64,2},Array{Float32,2}},Tuple{Array{Float32,1},Adjoint{Float32,Array{Float32,1}}}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/derivatives/linalg/arithmetic.jl:265
 [4] reverse_exec!(::ReverseDiff.SpecialInstruction{typeof(*),Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},Adjoint{Float64,Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float32,2,Array{Float64,2},Array{Float32,2}},Tuple{Array{Float32,1},Adjoint{Float32,Array{Float32,1}}}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/tape.jl:93
 [5] reverse_pass!(::Array{ReverseDiff.AbstractInstruction,1}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/tape.jl:87
 [6] reverse_pass! at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/tape.jl:36 [inlined]
 [7] seeded_reverse_pass!(::Array{Float32,1}, ::ReverseDiff.TrackedReal{Float64,Float32,Nothing}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.GradientTape{var"#107#108",ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},ReverseDiff.TrackedReal{Float64,Float32,Nothing}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/utils.jl:31
 [8] seeded_reverse_pass!(::Array{Float32,1}, ::ReverseDiff.GradientTape{var"#107#108",ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},ReverseDiff.TrackedReal{Float64,Float32,Nothing}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/tape.jl:47
 [9] gradient(::Function, ::Array{Float32,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/gradients.jl:24
 [10] gradient(::Function, ::Array{Float32,1}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/gradients.jl:22
 [11] top-level scope at In[17]:45
 [12] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091


Chain g = Chain(Dense(2, 10), Dense(10,1)) also doesn’t work

MethodError: no method matching mul!(::Array{Float32,1}, ::Array{Float32,2}, ::Adjoint{Float64,Array{Float64,1}}, ::Bool, ::Bool)
Closest candidates are:
  mul!(::AbstractArray, ::Number, ::AbstractArray, ::Number, ::Number) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/generic.jl:124
  mul!(::AbstractArray, ::AbstractArray, ::Number, ::Number, ::Number) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/generic.jl:126
  mul!(::StridedArray{T, 1}, ::StridedVecOrMat{T}, ::StridedArray{T, 1}, ::Number, ::Number) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/matmul.jl:66
  ...

Stacktrace:
 [1] mul! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/matmul.jl:208 [inlined]
 [2] reverse_mul!(::ReverseDiff.TrackedArray{Float64,Float32,2,Array{Float64,2},Array{Float32,2}}, ::Array{Float32,2}, ::ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}}, ::Adjoint{ReverseDiff.TrackedReal{Float64,Float32,Nothing},Array{ReverseDiff.TrackedReal{Float64,Float32,Nothing},1}}, ::Array{Float32,1}, ::Adjoint{Float32,Array{Float32,1}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/derivatives/linalg/arithmetic.jl:282
 [3] special_reverse_exec!(::ReverseDiff.SpecialInstruction{typeof(*),Tuple{ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}},Adjoint{ReverseDiff.TrackedReal{Float64,Float32,Nothing},Array{ReverseDiff.TrackedReal{Float64,Float32,Nothing},1}}},ReverseDiff.TrackedArray{Float64,Float32,2,Array{Float64,2},Array{Float32,2}},Tuple{Array{Float32,1},Adjoint{Float32,Array{Float32,1}}}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/derivatives/linalg/arithmetic.jl:265
 [4] reverse_exec!(::ReverseDiff.SpecialInstruction{typeof(*),Tuple{ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}},Adjoint{ReverseDiff.TrackedReal{Float64,Float32,Nothing},Array{ReverseDiff.TrackedReal{Float64,Float32,Nothing},1}}},ReverseDiff.TrackedArray{Float64,Float32,2,Array{Float64,2},Array{Float32,2}},Tuple{Array{Float32,1},Adjoint{Float32,Array{Float32,1}}}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/tape.jl:93
 [5] reverse_pass!(::Array{ReverseDiff.AbstractInstruction,1}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/tape.jl:87
 [6] reverse_pass! at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/tape.jl:36 [inlined]
 [7] seeded_reverse_pass!(::Array{Float32,1}, ::ReverseDiff.TrackedReal{Float64,Float32,Nothing}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.GradientTape{var"#115#116",ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},ReverseDiff.TrackedReal{Float64,Float32,Nothing}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/utils.jl:31
 [8] seeded_reverse_pass!(::Array{Float32,1}, ::ReverseDiff.GradientTape{var"#115#116",ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},ReverseDiff.TrackedReal{Float64,Float32,Nothing}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/tape.jl:47
 [9] gradient(::Function, ::Array{Float32,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/gradients.jl:24
 [10] gradient(::Function, ::Array{Float32,1}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/gradients.jl:22
 [11] top-level scope at In[18]:45
 [12] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091

You cut off the stack-trace.
we need that to debug an error.

Ah right, I did notice you didn’t sepecify an activation function, but i thought it was fine since first derivative is nonzero, but since you are taking second, makes sense that it is not fine.

I have recovered the stack trace