Model parameters not updating when using Flux.update with differential equation

Here is a MWE, where I am learning a missing term in an equation by representing it with a neural network model. I can compute the gradients with Flux, but the model does not update after each epoch.

using DifferentialEquations, Zygote, LinearAlgebra, SciMLSensitivity
using SparseArrays, Random, Statistics, Optimisers, Flux

begin 
    global N = 10                      # number of cells 
    global L = 2π                       # total length of 1d sim
    global T = 0.5                       # total time 
    global nt = 50                # number of timesteps
    global dt = T / nt                     # time step
    global dx = L/N                      # spatial step
    global x = 0.0 : dx : (L-dx)         # discretized spatial dimension 
    global xgrid = collect(x)            # grid 
    global tsteps = 0.0:dt:T             # discretized time dimension 
    global tspan = (0,T)                 # end points of time integration for ODEProbem
    global t = collect(tsteps)
    global ν = 0.01                       # viscosity
    global epochs = 5
end 

rng = MersenneTwister(1234); 

#numerical derivatives
# First Order Backward
function  f1_firstOrder_backward(n,dx)
    ∂x1 = (diagm(
                0 =>  1 * ones(n),
                -1 =>  -1 * ones(n-1)
         )) ./ dx
    # periodic boundaries
    ∂x1[1,end] = -1 / dx 
    ∂x1[end,1] = 0

    # sparsify
    ∂x1 = sparse(∂x1)
    
    return ∂x1
end

# Second Order Central
function f2_secondOrder_central(n,dx)
    ∂x2 = (diagm(
                0 => -2 * ones(n),
                -1 => 1 * ones(n-1),
                 1 => 1 * ones(n-1)
                )) ./ (dx^2)
    # periodic boundaries
    ∂x2[1,end] = 1 / (dx^2)
    ∂x2[end,1] = 1 / (dx^2)

    # sparsify
    ∂x2 = sparse(∂x2)    

    return ∂x2
end

# generate training data
∂x1 = f1_firstOrder_backward(N,dx)
∂x2 = f2_secondOrder_central(N,dx)

function true_eqn(u,p,t)
    u = - u .* (∂x1*u) + ν .* (∂x2*u)
    return u
end

function ic_definition(N)
    u0 = ones(N)
    u0[2:Int(N/2)] .= 2.0
    return u0
end
u0 = ic_definition(N)

p = zeros(N);
true_eqn_prob = ODEProblem(true_eqn, u0, tspan,p)
## Create training data 
sol_true_eqn = Array(solve(true_eqn_prob,Tsit5(),alg_hints=[:stiff], saveat=dt));

# learning problem
model = Chain(
            Dense(N,20,relu; init=Flux.randn32(rng)),
            Dense(20,20,relu; init=Flux.randn32(rng)),
            Dense(20,N; init=Flux.randn32(rng)));

p, re = Flux.destructure(model);

function learned_eqn(u,p,t)
    ϕ = re(p)
    u = - u .* (∂x1*u) + (ν .* ϕ(u))
    return u
end 

learned_eqn_prob = ODEProblem(learned_eqn, u0, tspan, p)

#traindata
traindata = (u0=u0,y=sol_true_eqn[:,end]);

# Training loop
optim = Flux.setup(Flux.Adam(0.01), model);  

losses = []
grads_list_u = []
grads_list_p = []
for epoch in 1:epochs
    @show epoch

    # Step 1: Compute gradient for backprop and loss
    grads = Flux.gradient(u0,p) do u0,p
        ŷ  = solve(learned_eqn_prob, Tsit5(), u0 = u0, p = p, saveat = dt);
        l = Flux.mse(ŷ[:,end], traindata.y) 
        @show l
        return l
    end
    
    # Step 2: Optimizer update
    Flux.update!(optim, learned_eqn_prob, grads[2])
    push!(grads_list_u, grads[1])
    push!(grads_list_p, grads[2])
end
@show grads_list_u
@show grads_list_p

The output shows that gradient with respect to parameters p does not change, and neither does the loss l:

epoch = 1
l = 
0.0076257285132982746
epoch = 2
l = 0.0076257285132982746
epoch = 3
l = 0.0076257285132982746
epoch = 4
l = 0.0076257285132982746
epoch = 5
l = 0.0076257285132982746

5-element Vector{Any}:
 Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176  …  1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]
 Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176  …  1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]
 Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176  …  1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]
 Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176  …  1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]
 Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176  …  1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]

Am I passing the parameters correctly with Flux.update? I am using the latest version

Julia Version 1.9.3
Commit bed2cd540a1 (2023-08-24 14:43 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (x86_64-apple-darwin22.4.0)
  CPU: 16 × Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
  Threads: 1 on 16 virtual cores
Environment:
  JULIA_NUM_THREADS = 

I would appreciate any help to overcome this issue! Thank you.

I haven’t run this, but I’m almost sure you want this – setup and update! need to see the same object:

optim = Flux.setup(Flux.Adam(0.01), p);  # once

Flux.update!(optim, p, grads[2])  # in the loop

This solves it! Thank you.

1 Like