Weights are not updated

I can’t update the weights. Here is an example:

using Flux, Zygote, ReverseDiff
using Statistics: mean

struct hess_nn
    net
    re
    p

    # Define inner constructor method
    function hess_nn(net; p = nothing)
        _p, re = Flux.destructure(net)
        if p === nothing
            p = _p
        end
        return new(net, re, p)
    end
end     

function (nn::hess_nn)(x, p = nn.p)
    net, p = nn.net, nn.p
    hess = x -> Zygote.hessian_reverse(x->sum(net(x)), x)
    out = hess.(eachcol(x))  
    return out
end

# Initialize the struct
g = Dense(2, 1)
nn_net = hess_nn(g)
p = nn_net.p
re = nn_net.re

# Data
x = rand(2,256);
y = rand(256);

# some toy loss function
loss(x, y, p) = sum((y .- mean.(nn_net(x, p))).^2)


# Training 
epochs = 10
opt = ADAM(0.01)

Flux.trainable(net::hess_nn) = (net.p,)
println("loss before training $(loss(x, y, p))")

for epoch in 1:epochs
    gs = ReverseDiff.gradient(p -> loss(x, y, p), p) 
    println("$epoch, $gs")    
    Flux.Optimise.update!(opt, p, gs)
end
println("loss after training $(loss(x, target, p))")

Output of the training block is:

loss before training 85.48787253127696
1, Float32[0.0, 0.0, 0.0]
2, Float32[0.0, 0.0, 0.0]
3, Float32[0.0, 0.0, 0.0]
4, Float32[0.0, 0.0, 0.0]
5, Float32[0.0, 0.0, 0.0]
6, Float32[0.0, 0.0, 0.0]
7, Float32[0.0, 0.0, 0.0]
8, Float32[0.0, 0.0, 0.0]
9, Float32[0.0, 0.0, 0.0]
10, Float32[0.0, 0.0, 0.0]

You’re taking a Hessian inside a gradient context and using ReverseDiff instead of Zygote. Honestly, I’m surprised it runs at all! Can you say a bit more about what you want to do and why it requires this formulation?

See also Hessian inside a Flux loss function

I want to implement a Lagrangian Neural Network. I’ve seen the topic you referred to, but I can’t figure out how to apply it to my task.

I think the gradient is zero because you are asking for the gradient with respect to the argument p, but the argument p is never used in your network. Changing that input has no effect, because you overwrite the variable p with p = nn.p. Thus the gradient with respect to your input p is zero.

I could be wrong. It’s a bit hard to follow the code; I think there are 3 or 4 different variables called p here (in different scopes, some nested). I think most of the time they are equal. It might be useful to give them different names.
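
For example, something along these lines (just a sketch of what I mean, reusing the re you already get from Flux.destructure) would make the argument p actually reach the network:

function (nn::hess_nn)(x, p = nn.p)
    re = nn.re
    # rebuild the layer from the p that was actually passed in, instead of
    # re-reading nn.p, so the outer gradient w.r.t. p is no longer trivially zero
    hess = x -> Zygote.hessian_reverse(x -> sum(re(p)(x)), x)
    return hess.(eachcol(x))
end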

Thanks, your remarks are legit. I’ve renamed the variables and added re(p)(x) in the network to bind it to the parameters, but the zero-gradient problem is still present:

struct hess_nn
    model
    re
    params

    # Define inner constructor method
    function hess_nn(model; p = nothing)
        _p, re = Flux.destructure(model)
        if p === nothing
            p = _p
        end
        return new(model, re, p)
    end
end     

function (nn::hess_nn)(x, p = nn.params)
    re = nn.re
    hess = x -> Zygote.hessian_reverse(x->sum(re(p)(x)), x)
    out = hess.(eachcol(x))  
    return out
end

# Initialize the struct
g = Dense(2, 1)
model_net = hess_nn(g)
params = model_net.params
re = model_net.re

# Data
x = rand(2,256);
y = rand(256);

# some toy loss function
loss(x, y, p) = sum((y .- mean.(model_net(x, p))).^2)

# Training 
epochs = 10
opt = ADAM(0.01)

Flux.trainable(nn::hess_nn) = (nn.params,)
println("loss before training $(loss(x, y, params))")
for epoch in 1:epochs
    gs = ReverseDiff.gradient(p -> loss(x, y, p), params) 
    println("$epoch, $gs")    
    Flux.Optimise.update!(opt, params, gs)
end
println("loss after training $(loss(x, y, params))")

Output:

loss before training 83.36982153699346
1, Float32[0.0, 0.0, 0.0]
2, Float32[0.0, 0.0, 0.0]
3, Float32[0.0, 0.0, 0.0]
4, Float32[0.0, 0.0, 0.0]
5, Float32[0.0, 0.0, 0.0]
6, Float32[0.0, 0.0, 0.0]
7, Float32[0.0, 0.0, 0.0]
8, Float32[0.0, 0.0, 0.0]
9, Float32[0.0, 0.0, 0.0]
10, Float32[0.0, 0.0, 0.0]
loss after training 83.36982153699346

If I use x.^2 as the input to my network, I get non-zero gradients. That makes sense, because then the output has a non-zero second derivative with respect to the inputs.

function (nn::hess_nn)(x, p = nn.params)
    re = nn.re
    #hess = x -> Zygote.hessian_reverse(x->sum(re(p)(x)), x)
    hess = x -> Zygote.hessian_reverse(x->sum(re(p)(x.^2)), x)
    out = hess.(eachcol(x))  
    return out
end

Output

loss before training 23.475644124478368
1, Float32[17.670782, 17.670782, 0.0]
2, Float32[15.110788, 15.110788, 0.0]
3, Float32[12.568958, 12.568958, 0.0]
4, Float32[10.062634, 10.062634, 0.0]
5, Float32[7.614392, 7.614392, 0.0]
6, Float32[5.252881, 5.252881, 0.0]
7, Float32[3.0132563, 3.0132563, 0.0]
8, Float32[0.9364736, 0.9364736, 0.0]
9, Float32[-0.9332614, -0.9332614, 0.0]
10, Float32[-2.5531504, -2.5531504, 0.0]
loss after training 22.31493460335003
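
Just to double-check that reasoning (a quick sketch, reusing g and x from above): the Hessian of the purely linear Dense(2, 1) with respect to its input is identically zero, so the loss built from it cannot depend on the parameters at all.

# Hessian of a purely linear layer w.r.t. its input: a 2×2 matrix of zeros
Zygote.hessian_reverse(z -> sum(g(z)), x[:, 1])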

But when I want to use a non-linear activation, g = Dense(2, 1, σ), I get this error:

MethodError: no method matching mul!(::Array{Float32,1}, ::Array{Float32,2}, ::Adjoint{Float64,Array{Float64,1}}, ::Bool, ::Bool)
Closest candidates are:
  mul!(::AbstractArray, ::Number, ::AbstractArray, ::Number, ::Number) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/generic.jl:124
  mul!(::AbstractArray, ::AbstractArray, ::Number, ::Number, ::Number) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/generic.jl:126
  mul!(::StridedArray{T, 1}, ::StridedVecOrMat{T}, ::StridedArray{T, 1}, ::Number, ::Number) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/matmul.jl:66
  ...

Stacktrace:
 [1] mul! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/matmul.jl:208 [inlined]
 [2] reverse_mul!(::ReverseDiff.TrackedArray{Float64,Float32,2,Array{Float64,2},Array{Float32,2}}, ::Array{Float32,2}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::Adjoint{Float64,Array{Float64,1}}, ::Array{Float32,1}, ::Adjoint{Float32,Array{Float32,1}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/derivatives/linalg/arithmetic.jl:282
 [3] special_reverse_exec!(::ReverseDiff.SpecialInstruction{typeof(*),Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},Adjoint{Float64,Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float32,2,Array{Float64,2},Array{Float32,2}},Tuple{Array{Float32,1},Adjoint{Float32,Array{Float32,1}}}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/derivatives/linalg/arithmetic.jl:265
 [4] reverse_exec!(::ReverseDiff.SpecialInstruction{typeof(*),Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},Adjoint{Float64,Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float32,2,Array{Float64,2},Array{Float32,2}},Tuple{Array{Float32,1},Adjoint{Float32,Array{Float32,1}}}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/tape.jl:93
 [5] reverse_pass!(::Array{ReverseDiff.AbstractInstruction,1}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/tape.jl:87
 [6] reverse_pass! at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/tape.jl:36 [inlined]
 [7] seeded_reverse_pass!(::Array{Float32,1}, ::ReverseDiff.TrackedReal{Float64,Float32,Nothing}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.GradientTape{var"#107#108",ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},ReverseDiff.TrackedReal{Float64,Float32,Nothing}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/utils.jl:31
 [8] seeded_reverse_pass!(::Array{Float32,1}, ::ReverseDiff.GradientTape{var"#107#108",ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},ReverseDiff.TrackedReal{Float64,Float32,Nothing}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/tape.jl:47
 [9] gradient(::Function, ::Array{Float32,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/gradients.jl:24
 [10] gradient(::Function, ::Array{Float32,1}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/gradients.jl:22
 [11] top-level scope at In[17]:45
 [12] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091


A chain, g = Chain(Dense(2, 10), Dense(10, 1)), also doesn’t work:

MethodError: no method matching mul!(::Array{Float32,1}, ::Array{Float32,2}, ::Adjoint{Float64,Array{Float64,1}}, ::Bool, ::Bool)
Closest candidates are:
  mul!(::AbstractArray, ::Number, ::AbstractArray, ::Number, ::Number) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/generic.jl:124
  mul!(::AbstractArray, ::AbstractArray, ::Number, ::Number, ::Number) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/generic.jl:126
  mul!(::StridedArray{T, 1}, ::StridedVecOrMat{T}, ::StridedArray{T, 1}, ::Number, ::Number) where T<:Union{Complex{Float32}, Complex{Float64}, Float32, Float64} at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/matmul.jl:66
  ...

Stacktrace:
 [1] mul! at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/matmul.jl:208 [inlined]
 [2] reverse_mul!(::ReverseDiff.TrackedArray{Float64,Float32,2,Array{Float64,2},Array{Float32,2}}, ::Array{Float32,2}, ::ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}}, ::Adjoint{ReverseDiff.TrackedReal{Float64,Float32,Nothing},Array{ReverseDiff.TrackedReal{Float64,Float32,Nothing},1}}, ::Array{Float32,1}, ::Adjoint{Float32,Array{Float32,1}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/derivatives/linalg/arithmetic.jl:282
 [3] special_reverse_exec!(::ReverseDiff.SpecialInstruction{typeof(*),Tuple{ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}},Adjoint{ReverseDiff.TrackedReal{Float64,Float32,Nothing},Array{ReverseDiff.TrackedReal{Float64,Float32,Nothing},1}}},ReverseDiff.TrackedArray{Float64,Float32,2,Array{Float64,2},Array{Float32,2}},Tuple{Array{Float32,1},Adjoint{Float32,Array{Float32,1}}}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/derivatives/linalg/arithmetic.jl:265
 [4] reverse_exec!(::ReverseDiff.SpecialInstruction{typeof(*),Tuple{ReverseDiff.TrackedArray{Float64,Float32,1,Array{Float64,1},Array{Float32,1}},Adjoint{ReverseDiff.TrackedReal{Float64,Float32,Nothing},Array{ReverseDiff.TrackedReal{Float64,Float32,Nothing},1}}},ReverseDiff.TrackedArray{Float64,Float32,2,Array{Float64,2},Array{Float32,2}},Tuple{Array{Float32,1},Adjoint{Float32,Array{Float32,1}}}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/tape.jl:93
 [5] reverse_pass!(::Array{ReverseDiff.AbstractInstruction,1}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/tape.jl:87
 [6] reverse_pass! at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/tape.jl:36 [inlined]
 [7] seeded_reverse_pass!(::Array{Float32,1}, ::ReverseDiff.TrackedReal{Float64,Float32,Nothing}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.GradientTape{var"#115#116",ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},ReverseDiff.TrackedReal{Float64,Float32,Nothing}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/utils.jl:31
 [8] seeded_reverse_pass!(::Array{Float32,1}, ::ReverseDiff.GradientTape{var"#115#116",ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},ReverseDiff.TrackedReal{Float64,Float32,Nothing}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/tape.jl:47
 [9] gradient(::Function, ::Array{Float32,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/gradients.jl:24
 [10] gradient(::Function, ::Array{Float32,1}) at /home/solar/.julia/packages/ReverseDiff/iHmB4/src/api/gradients.jl:22
 [11] top-level scope at In[18]:45
 [12] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091

You cut off the stack trace. We need that to debug the error.

Ah right, I did notice you didn’t specify an activation function, but I thought that was fine since the first derivative is non-zero. Since you are taking the second derivative, it makes sense that it is not.
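
Concretely, for Dense(2, 1) the output is w'x + b, so its Hessian with respect to x is identically zero, while with σ it is σ''(w'x + b) * w * w', which is generally nonzero. A rough way to see the values (a sketch; I’m using Zygote.hessian, the forward-over-reverse variant, rather than hessian_reverse, but the point is the same):

z = rand(Float32, 2)
Zygote.hessian(v -> sum(Dense(2, 1)(v)), z)      # 2×2 zeros: linear layer
Zygote.hessian(v -> sum(Dense(2, 1, σ)(v)), z)   # generally nonzero with σ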

I have recovered the stack trace.