Gradient not being computed when training NN using Flux

Hello everyone, I’m new to Flux and I’m following the course on book.sciml.ai. Right now I’m on lesson three, specifically the part that discusses modelling Hooke’s Law and constructing a PINN, and I’m having trouble training the model.

I think it’s because the book is from 2020 and Flux 0.13 introduced some breaking changes, but no matter how I change the code the gradient always comes out as nothing. I altered a few things from the book; this is the code I’m running:

using Flux, DifferentialEquations

k = 1.0
force(dx,x,k,t) = -k*x + 0.1sin(x) # True force
prob = SecondOrderODEProblem(force,1.0,0.0,(0.0,10.0),k)
sol = solve(prob)

# Generate the dataset of specific points for the neural network
# Here we have a limited number of data points
t = 0:3.3:10
position_data = [state[2] for state in sol(t)]
force_data = [force(state[1],state[2],k,t) for state in sol(t)]

NNForce = Chain(
    x -> [x], # Transform the input into a 1-element array
    Dense(1, 32, tanh),
    Dense(32, 1),
    first # Extract the first element of the output
)

loss(nn) = sum(abs2, nn(position_data[i]) - force_data[i] for i in 1:length(position_data))
loss(NNForce) |> x -> println("Initial loss: ", x)

# Standard gradient descent
opt = Flux.setup(Descent(0.01), NNForce)

# Training loop
for i in 1:20
    ∂loss∂m = gradient(loss, NNForce)[1]
    Flux.Optimisers.update!(opt, NNForce, ∂loss∂m[1])
    i%10 == 0 ? println("loss: ", loss(NNForce)) : nothing
end

My versioninfo and status are as follows:

julia> versioninfo()
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 10 × Apple M4
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, apple-m1)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

(intro-sciml) pkg> status
Status `~/Developer/intro-sciml/Project.toml`
  [0c46a032] DifferentialEquations v7.15.0
  [7da242da] Enzyme v0.13.28
  [587475ba] Flux v0.16.2
  [91a5bcdd] Plots v1.40.9
  [10745b16] Statistics v1.11.1

I think you want either of those [1]s, but not both.

I’m surprised that Optimisers.update! does not object more forcefully to the resulting mismatch of nested structures, but it does complain briefly. In fact, I think that warning is a code path meant to catch the case where you forgot [1] entirely, and here it makes things worse, as it ends up keeping just the nothing?

julia> t = 0:3.3:10;

julia> force_data = randn(Float32, size(t));  # fake data, without running ODE code

julia> position_data = randn(Float32, size(t));

julia> gradient(loss, NNForce)[1]
(layers = (nothing, (weight = Float32[-0.00031716842; 0.005882237; … ; 0.02040809; -0.017797563;;], bias = Float32[-0.00015526835, 0.0045354646, 0.0001758635, 0.000911925, -0.00093348324, -0.0028021857, -0.00037156977, -0.0002479069, 0.00634047, 0.0022809654  …  0.00086749066, -0.015178025, -0.0015061237, 0.005903747, 0.0004341714, 0.011407629, 0.005850576, 0.0060465317, 0.019239068, 0.0012870021], σ = nothing), (weight = Float32[-0.014764998 0.015171923 … -0.022916555 0.0054369], bias = Float32[-0.0074519515], σ = nothing), nothing),)

julia> Flux.state(NNForce)  # model's structure, must match the gradient!
(layers = ((), (weight = Float32[-0.28203613; 0.30942923; … ; -0.40744054; 0.09964981;;], bias = Float32[0.0046208426, -0.032450862, 0.06144477, 0.024944201, 0.057604324, 0.048263744, -0.009535714, -0.05626719, -0.03320607, -0.02485606  …  -0.0106274625, 0.07197578, -0.042995136, -0.017393522, -0.021419354, -0.046221554, -0.06806464, -0.016840763, -0.06592702, 0.049628224], σ = ()), (weight = Float32[-0.007292727 0.16620831 … 0.41103598 -0.30187216], bias = Float32[-0.16961181], σ = ()), ()),)

julia> gradient(loss, NNForce)[1][1]  # wrong!
(nothing, (weight = Float32[-0.00031716842; 0.005882237; … ; 0.02040809; -0.017797563;;], bias = Float32[-0.00015526835, 0.0045354646, 0.0001758635, 0.000911925, -0.00093348324, -0.0028021857, -0.00037156977, -0.0002479069, 0.00634047, 0.0022809654  …  0.00086749066, -0.015178025, -0.0015061237, 0.005903747, 0.0004341714, 0.011407629, 0.005850576, 0.0060465317, 0.019239068, 0.0012870021], σ = nothing), (weight = Float32[-0.014764998 0.015171923 … -0.022916555 0.0054369], bias = Float32[-0.0074519515], σ = nothing), nothing)

julia> for i in 1:10
           ∂loss∂m = gradient(loss, NNForce)[1]
           Flux.Optimisers.update!(opt, NNForce, ∂loss∂m[1])  # as above, with [1][1]
           println("loss: ", loss(NNForce))
       end
┌ Warning: explicit `update!(opt, model, grad)` wants the gradient for the model alone,
│ not the whole tuple from `gradient(m -> loss(m, x, y), model)`. You probably want `grads[1]`.
└ @ Flux ~/.julia/packages/Flux/BkG8S/src/layers/basic.jl:87
loss: 2.9432225
┌ Warning: explicit `update!(opt, model, grad)` wants the gradient for the model alone,
│ not the whole tuple from `gradient(m -> loss(m, x, y), model)`. You probably want `grads[1]`.
└ @ Flux ~/.julia/packages/Flux/BkG8S/src/layers/basic.jl:87
loss: 2.9432225
...
loss: 2.9432225
loss: 2.9432225
loss: 2.9432225
loss: 2.9432225
loss: 2.9432225
loss: 2.9432225
loss: 2.9432225
loss: 2.9432225  # loss has not decreased

julia> for i in 1:10
           ∂loss∂m = gradient(loss, NNForce)[1]
           Flux.Optimisers.update!(opt, NNForce, ∂loss∂m)  # corrected
           println("loss: ", loss(NNForce))
       end
loss: 2.3963256
loss: 2.0603027
loss: 1.853425
loss: 1.7260386
loss: 1.6475848
loss: 1.599206
loss: 1.5692846
loss: 1.5506897
loss: 1.5390539
loss: 1.5317094  # now it learns?
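
For reference, this is the pattern that explicit-style training in current Flux expects (a minimal sketch, reusing the names from your code):

opt = Flux.setup(Descent(0.01), NNForce)

for epoch in 1:200
    grads = Flux.gradient(loss, NNForce)  # a Tuple, one entry per argument of loss
    Flux.update!(opt, NNForce, grads[1])  # index once to get the entry for NNForce
end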

It works now, thanks!

But this is what I find strange: while following the previous segment of the same lesson I came across this error while training: ERROR: type NamedTuple has no field layers. The structure of that code was pretty similar; the only real difference is that we’re learning a function g(t) = 1.0 + NN(t). The only way to make that work was the double [1] that was in my code before.

Weird that it works in one case but not in the other…

I’m not exactly sure what happened, but it shouldn’t be necessary to guess by sprinkling [1]s around and waiting for errors. The gradient must have the same fields as the model, and you can print both out and check. (Flux.state above is one way to reveal what the structure of types like Chain really is.)
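
For example, a quick way to check (a minimal sketch, reusing the loss and NNForce from your post, and assuming Flux is recent enough to have Flux.state):

grad = gradient(loss, NNForce)[1]   # gradient for the model itself
state = Flux.state(NNForce)         # the model's own structure

keys(state)   # (:layers,) for a Chain
keys(grad)    # should also be (:layers,); if you instead see positional
              # indices, you have indexed one level too deep

If those don’t line up field-for-field, update! has no sensible way to pair gradients with parameters.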

But glad it now works!
