# Model parameters not updating when using Flux.update! with a differential equation

Here is an MWE in which I learn a missing term in an equation by representing it with a neural network. I can compute the gradients with Flux, but the model parameters do not change from one epoch to the next.
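For reference, the code below appears to set up the following semi-discrete system. This is a sketch in notation that is not part of the post: $D_1$ and $D_2$ stand for the matrices `∂x1` and `∂x2`, $\odot$ is the elementwise product, and $\phi_\theta$ is the network rebuilt from the flat parameter vector `p`.

$$
\frac{du}{dt} = -u \odot (D_1 u) + \nu\, D_2 u \quad \text{(true equation)}, \qquad
\frac{du}{dt} = -u \odot (D_1 u) + \nu\, \phi_\theta(u) \quad \text{(learned equation)}
$$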

``````julia
using DifferentialEquations, Zygote, LinearAlgebra, SciMLSensitivity
using SparseArrays, Random, Statistics, Optimisers, Flux

begin
    global N = 10                        # number of cells
    global L = 2π                        # total length of 1d sim
    global T = 0.5                       # total time
    global nt = 50                       # number of timesteps
    global dt = T / nt                   # time step
    global dx = L/N                      # spatial step
    global x = 0.0 : dx : (L-dx)         # discretized spatial dimension
    global xgrid = collect(x)            # grid
    global tsteps = 0.0:dt:T             # discretized time dimension
    global tspan = (0,T)                 # end points of time integration for ODEProblem
    global t = collect(tsteps)
    global ν = 0.01                      # viscosity
    global epochs = 5
end

rng = MersenneTwister(1234);

# numerical derivatives
# First Order Backward
function f1_firstOrder_backward(n,dx)
    ∂x1 = (diagm(
        0 =>  1 * ones(n),
        -1 =>  -1 * ones(n-1)
    )) ./ dx
    # periodic boundaries
    ∂x1[1,end] = -1 / dx
    ∂x1[end,1] = 0

    # sparsify
    ∂x1 = sparse(∂x1)

    return ∂x1
end

# Second Order Central
function f2_secondOrder_central(n,dx)
    ∂x2 = (diagm(
        0 => -2 * ones(n),
        -1 => 1 * ones(n-1),
        1 => 1 * ones(n-1)
    )) ./ (dx^2)
    # periodic boundaries
    ∂x2[1,end] = 1 / (dx^2)
    ∂x2[end,1] = 1 / (dx^2)

    # sparsify
    ∂x2 = sparse(∂x2)

    return ∂x2
end

# generate training data
∂x1 = f1_firstOrder_backward(N,dx)
∂x2 = f2_secondOrder_central(N,dx)

function true_eqn(u,p,t)
    u = - u .* (∂x1*u) + ν .* (∂x2*u)
    return u
end

function ic_definition(N)
    u0 = ones(N)
    u0[2:Int(N/2)] .= 2.0
    return u0
end
u0 = ic_definition(N)

p = zeros(N);
true_eqn_prob = ODEProblem(true_eqn, u0, tspan,p)
## Create training data
sol_true_eqn = Array(solve(true_eqn_prob,Tsit5(),alg_hints=[:stiff], saveat=dt));

# learning problem
model = Chain(
    Dense(N,20,relu; init=Flux.randn32(rng)),
    Dense(20,20,relu; init=Flux.randn32(rng)),
    Dense(20,N; init=Flux.randn32(rng)));

p, re = Flux.destructure(model);

function learned_eqn(u,p,t)
    ϕ = re(p)
    u = - u .* (∂x1*u) + (ν .* ϕ(u))
    return u
end

learned_eqn_prob = ODEProblem(learned_eqn, u0, tspan, p)

#traindata
traindata = (u0=u0,y=sol_true_eqn[:,end]);

# Training loop

losses = []
for epoch in 1:epochs
@show epoch

# Step 1: Compute gradient for backprop and loss
ŷ  = solve(learned_eqn_prob, Tsit5(), u0 = u0, p = p, saveat = dt);
l = Flux.mse(ŷ[:,end], traindata.y)
@show l
return l
end

# Step 2: Optimizer update
end
``````

The output shows that the gradient with respect to the parameters `p` does not change, and neither does the loss `l`:

``````
epoch = 1
l = 0.0076257285132982746
epoch = 2
l = 0.0076257285132982746
epoch = 3
l = 0.0076257285132982746
epoch = 4
l = 0.0076257285132982746
epoch = 5
l = 0.0076257285132982746

5-element Vector{Any}:
Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176  …  1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]
Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176  …  1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]
Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176  …  1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]
Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176  …  1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]
Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176  …  1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]
``````

Am I passing the parameters correctly to `Flux.update!`? I am using the latest version of Julia:

``````
Julia Version 1.9.3
Commit bed2cd540a1 (2023-08-24 14:43 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: macOS (x86_64-apple-darwin22.4.0)
CPU: 16 × Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
Threads: 1 on 16 virtual cores
Environment:
``````

I would appreciate any help to overcome this issue! Thank you.

I haven’t run this, but I’m almost sure you want this – `setup` and `update!` need to see the same object:

``````julia
optim = Flux.setup(Flux.Adam(0.01), p);  # once

Flux.update!(optim, p, grads[2])  # in the loop
``````
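To show the pattern end to end, here is a minimal sketch that runs with Flux alone. It is not the code from the question: the ODE solve is replaced by a plain forward pass, and the small model, the data `x` and `y`, and the helper `loss` are made up for illustration. Only one argument is differentiated here, so the gradient for `p` is `grads[1]` rather than `grads[2]`.

```julia
using Flux, Random

rng = MersenneTwister(1234)

# Hypothetical stand-in model; the question's ODE solve is replaced by a
# direct forward pass so the sketch stays self-contained (assumption).
model = Chain(Dense(4 => 8, relu), Dense(8 => 4))
p, re = Flux.destructure(model)          # flat parameter vector + rebuilder

# Made-up input and target, only for illustration.
x = rand(rng, Float32, 4)
y = rand(rng, Float32, 4)

# Loss written as a function of the flat vector p, as in the question.
loss(p) = Flux.mse(re(p)(x), y)

# Set up the optimiser state once, on the same object that update! will see.
optim = Flux.setup(Flux.Adam(0.01), p)

p_before = copy(p)
losses = Float32[]
for epoch in 1:20
    # Loss and gradient with respect to p in one call.
    l, grads = Flux.withgradient(loss, p)
    push!(losses, l)
    # Update p in place using the state that was created for p.
    Flux.update!(optim, p, grads[1])
end
```

After the loop, `p` differs from `p_before` and the entries of `losses` should go down. In the question's setting, the body of `loss` would be the `solve` call followed by `Flux.mse`, with everything else unchanged.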

This solves it! Thank you.
