Here is an MWE in which I am learning a missing term in an equation by representing it with a neural network model. I can compute the gradients with Flux, but the model does not update after each epoch.
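To make the setup concrete: the true equation below is the 1D viscous Burgers equation, and the learned equation replaces its diffusion term with the neural network (written $\mathrm{NN}_\theta$ here; that name does not appear in the code, it is just shorthand for the Flux model):

$$\partial_t u = -u\,\partial_x u + \nu\,\partial_{xx} u \qquad \text{(true equation)}$$

$$\partial_t u = -u\,\partial_x u + \nu\,\mathrm{NN}_\theta(u) \qquad \text{(learned equation)}$$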
using DifferentialEquations, Zygote, LinearAlgebra, SciMLSensitivity
using SparseArrays, Random, Statistics, Optimisers, Flux
begin
    global N = 10           # number of cells
    global L = 2π           # total length of 1d sim
    global T = 0.5          # total time
    global nt = 50          # number of timesteps
    global dt = T / nt      # time step
    global dx = L/N         # spatial step
    global x = 0.0 : dx : (L-dx)   # discretized spatial dimension
    global xgrid = collect(x)      # grid
    global tsteps = 0.0:dt:T       # discretized time dimension
    global tspan = (0,T)           # end points of time integration for ODEProblem
    global t = collect(tsteps)
    global ν = 0.01         # viscosity
    global epochs = 5
end
rng = MersenneTwister(1234);
# numerical derivatives
# First Order Backward
function f1_firstOrder_backward(n, dx)
    ∂x1 = (diagm(
        0  =>  1 * ones(n),
        -1 => -1 * ones(n-1)
    )) ./ dx
    # periodic boundaries
    ∂x1[1,end] = -1 / dx
    ∂x1[end,1] = 0
    # sparsify
    ∂x1 = sparse(∂x1)
    return ∂x1
end
# Second Order Central
function f2_secondOrder_central(n, dx)
    ∂x2 = (diagm(
        0  => -2 * ones(n),
        -1 =>  1 * ones(n-1),
        1  =>  1 * ones(n-1)
    )) ./ (dx^2)
    # periodic boundaries
    ∂x2[1,end] = 1 / (dx^2)
    ∂x2[end,1] = 1 / (dx^2)
    # sparsify
    ∂x2 = sparse(∂x2)
    return ∂x2
end
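# For reference, with the periodic wrap these operators implement
#   (∂x1*u)[i] ≈ (u[i] - u[i-1]) / dx               # first-order backward difference
#   (∂x2*u)[i] ≈ (u[i+1] - 2u[i] + u[i-1]) / dx^2   # second-order central difference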
# generate training data
∂x1 = f1_firstOrder_backward(N,dx)
∂x2 = f2_secondOrder_central(N,dx)
function true_eqn(u,p,t)
    u = - u .* (∂x1*u) + ν .* (∂x2*u)
    return u
end
function ic_definition(N)
    u0 = ones(N)
    u0[2:Int(N/2)] .= 2.0
    return u0
end
u0 = ic_definition(N)
p = zeros(N);
true_eqn_prob = ODEProblem(true_eqn, u0, tspan, p)
## Create training data
sol_true_eqn = Array(solve(true_eqn_prob, Tsit5(), alg_hints=[:stiff], saveat=dt));
# learning problem
model = Chain(
    Dense(N, 20, relu; init=Flux.randn32(rng)),
    Dense(20, 20, relu; init=Flux.randn32(rng)),
    Dense(20, N; init=Flux.randn32(rng)));
p, re = Flux.destructure(model);
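# Flux.destructure gives the flat parameter vector p together with a function re
# that rebuilds the Chain from any flat vector of the same length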
function learned_eqn(u,p,t)
    ϕ = re(p)
    u = - u .* (∂x1*u) + (ν .* ϕ(u))
    return u
end
learned_eqn_prob = ODEProblem(learned_eqn, u0, tspan, p)
#traindata
traindata = (u0=u0,y=sol_true_eqn[:,end]);
# Training loop
optim = Flux.setup(Flux.Adam(0.01), model);
losses = []
grads_list_u = []
grads_list_p = []
for epoch in 1:epochs
    @show epoch
    # Step 1: Compute gradient for backprop and loss
    grads = Flux.gradient(u0, p) do u0, p
        ŷ = solve(learned_eqn_prob, Tsit5(), u0 = u0, p = p, saveat = dt);
        l = Flux.mse(ŷ[:,end], traindata.y)
        @show l
        return l
    end
    # Step 2: Optimizer update
    Flux.update!(optim, learned_eqn_prob, grads[2])
    push!(grads_list_u, grads[1])
    push!(grads_list_p, grads[2])
end
@show grads_list_u
@show grads_list_p
The output shows that the gradient with respect to the parameters p does not change from epoch to epoch, and neither does the loss l:
epoch = 1
l = 0.0076257285132982746
epoch = 2
l = 0.0076257285132982746
epoch = 3
l = 0.0076257285132982746
epoch = 4
l = 0.0076257285132982746
epoch = 5
l = 0.0076257285132982746
5-element Vector{Any}:
Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176 … 1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]
Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176 … 1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]
Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176 … 1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]
Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176 … 1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]
Float32[0.0008227346, 0.0, 0.0006256843, 0.0, 0.0, 0.00025020575, 0.0, -0.00065144146, 0.0, 0.0005012176 … 1.2426534f-5, -6.6746696f-5, -2.937876f-5, -3.75617f-5, -2.5478084f-5, -1.5328378f-5, -4.5170782f-5, -3.2007174f-5, 0.00012620524, 0.000120991856]
Am I passing the parameters correctly with Flux.update!? I am using the latest Julia version:
Julia Version 1.9.3
Commit bed2cd540a1 (2023-08-24 14:43 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: macOS (x86_64-apple-darwin22.4.0)
CPU: 16 × Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
Threads: 1 on 16 virtual cores
Environment:
JULIA_NUM_THREADS =
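For reference, my (possibly incorrect) understanding of the usual Flux.setup / Flux.update! pattern, without destructure and with a placeholder loss(m) standing in for my ODE-based loss, is roughly:

opt_state = Flux.setup(Flux.Adam(0.01), model)
grads = Flux.gradient(m -> loss(m), model)   # loss(m) is a placeholder that calls the model
Flux.update!(opt_state, model, grads[1])

In my loop, however, the gradient is taken with respect to the flat vector p from Flux.destructure, so I am not sure whether setting up the optimiser on model and then calling Flux.update!(optim, learned_eqn_prob, grads[2]) is the right combination of arguments.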
I would appreciate any help with this issue! Thank you.