I’d have to see a picture of how it’s going unstable. Though it could also be from using an explicit method.
I have also tried VCABM()
and it appears to be working well (like BS3()
), but I still get the same error during sciml_train
. Wanna try the new MWE and have a guess?
using Statistics
using LinearAlgebra
using Random
using OrdinaryDiffEq
using DiffEqFlux
using Flux
using Tullio
using RecursiveArrayTools
using ComponentArrays
const t₁ = 10 # number of simulation years
const ρ = 900f0 # Ice density [kg / m^3]
const g = 9.81f0 # Gravitational acceleration [m / s^2]
const n = 3f0 # Glen's flow law exponent
const maxA = 8f-16
const minA = 3f-17
const maxT = 1f0
const minT = -25f0
A = 1.3f-24 #2e-16 1 / Pa^3 s
A *= Float32(60 * 60 * 24 * 365.25) # [1 / Pa^3 yr]
C = 0
α = 0
@views avg(A) = 0.25 .* ( A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end] )
@views avg_x(A) = 0.5 .* ( A[1:end-1,:] .+ A[2:end,:] )
@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )
@views diff_x(A) = (A[begin + 1:end, :] .- A[1:end - 1, :])
@views diff_y(A) = (A[:, begin + 1:end] .- A[:, 1:end - 1])
@views inn(A) = A[2:end-1,2:end-1]
function ref_glacier(temps, H₀)
H = deepcopy(H₀)
# Initialize all matrices for the solver
S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)
# Gather simulation parameters
current_year = Int(0)
context = ArrayPartition([A], B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, [current_year])
# Perform reference simulation with forward model
println("Running forward PDE ice flow model...\n")
iceflow_prob = ODEProblem(iceflow!,H,(0.0,t₁),context)
iceflow_sol = solve(iceflow_prob, BS3(), reltol=1e-6, progress=true, saveat=1.0, progress_steps = 1)
return Float32.(iceflow_sol[end])
end
function iceflow!(dH, H, context,t)
# Unpack parameters
#A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year
current_year = Ref(context.x[18])
A = Ref(context.x[1])
# Get current year for MB and ELA
year = floor(Int, t) + 1
if year != current_year[] && year <= t₁
temp = Ref{Float32}(context.x[7][year])
A[] .= A_fake(temp[])
current_year[] .= year
end
# Compute the Shallow Ice Approximation in a staggered grid
SIA!(dH, H, context)
end
function train_iceflow_UDE(H₀, UA, H_ref, temps)
# Gather simulation parameters
H = deepcopy(H₀)
# Gather simulation parameters
current_year = 0
θ = initial_params(UA)
context = ComponentArray(B=B, C=C, α=α, temps=temps,current_year=current_year, H_ref=H_ref)
loss(θ) = loss_iceflow(θ, UA, H, context) # closure
println("Training iceflow UDE...")
iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(0.0001), maxiters = 10)
return iceflow_trained
end
function loss_iceflow(θ, UA, H, context)
H = predict_iceflow(θ, UA, H, context)
l_H = sqrt(Flux.Losses.mse(H[H .!= 0.0], context.H_ref[H.!= 0.0]; agg=sum))
println("Loss = ", l_H)
return l_H
end
function predict_iceflow(θ, UA, H, context)
iceflow_UDE!(dH, H, θ, t) = iceflow_NN!(dH, H, θ, t, context, UA) # closure
tspan = (0.0,t₁)
iceflow_prob = ODEProblem(iceflow_UDE!,H,tspan,θ)
H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, reltol=1e-6, save_everystep=false,
sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP()),
progress=true, progress_steps = 1)
return H_pred[end]
end
function iceflow_NN!(dH, H, θ, t, context, UA)
year = floor(Int, t) + 1
if year <= t₁
temp = context.temps[year]
else
temp = context.temps[year-1]
end
YA = predict_A̅(UA, θ, [temp]) # FastChain prediction requires explicit parameters
if t%1 == 0
println("A: ", YA)
end
# Compute the Shallow Ice Approximation in a staggered grid
dH .= SIA!(dH, H, YA, context)
end
"""
SIA(H, p)
Compute a step of the Shallow Ice Approximation PDE in a forward model
"""
function SIA!(dH, H, context)
# Retrieve parameters
#A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
A = context.x[1]
B = context.x[2]
S = context.x[3]
dSdx = context.x[4]
dSdy = context.x[5]
D = context.x[6]
dSdx_edges = context.x[8]
dSdy_edges = context.x[9]
∇S = context.x[10]
Fx = context.x[11]
Fy = context.x[12]
# Update glacier surface altimetry
S .= B .+ H
# All grid variables computed in a staggered grid
# Compute surface gradients on edges
dSdx .= diff_x(S) / Δx
dSdy .= diff_y(S) / Δy
∇S .= (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2)
Γ = 2 * A * (ρ * g)^n / (n+2) # 1 / m^3 s
D .= Γ .* avg(H).^(n + 2) .* ∇S
# Compute flux components
dSdx_edges .= diff(S[:,2:end - 1], dims=1) / Δx
dSdy_edges .= diff(S[2:end - 1,:], dims=2) / Δy
Fx .= .-avg_y(D) .* dSdx_edges
Fy .= .-avg_x(D) .* dSdy_edges
# Flux divergence
inn(dH) .= .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy) # MB to be added here
end
# Function without mutation for Zygote, with context as a tuple
function SIA!(dH, H, A, context::ComponentArray)
# Retrieve parameters
B = context.B
# Update glacier surface altimetry
S = B .+ H
# All grid variables computed in a staggered grid
# Compute surface gradients on edges
dSdx = diff_x(S) / Δx
dSdy = diff_y(S) / Δy
∇S = (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2)
Γ = 2 * A * (ρ * g)^n / (n+2) # 1 / m^3 s
D = Γ .* avg(H).^(n + 2) .* ∇S
# Compute flux components
dSdx_edges = diff(S[:,2:end - 1], dims=1) / Δx
dSdy_edges = diff(S[2:end - 1,:], dims=2) / Δy
Fx = .-avg_y(D) .* dSdx_edges
Fy = .-avg_x(D) .* dSdy_edges
# Flux divergence
@tullio dH[i,j] := -(diff(Fx, dims=1)[pad(i-1,1,1),pad(j-1,1,1)] / Δx + diff(Fy, dims=2)[pad(i-1,1,1),pad(j-1,1,1)] / Δy) # MB to be added here
return dH
end
function A_fake(temp)
return @. minA + (maxA - minA) * ((temp-minT)/(maxT-minT) )^2
end
predict_A̅(UA, θ, temp) = UA(temp, θ)[1] .* 1e-16
#### Generate reference dataset ####
nx = ny = 100
const B = zeros(Float32, (nx, ny))
const σ = 1000
H₀ = Matrix{Float32}([ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / σ ) for i in 1:nx, j in 1:ny ])
Δx = Δy = 50 #m
const temps = Vector{Float32}([0.0, -0.5, -0.2, -0.1, -0.3, -0.1, -0.2, -0.3, -0.4, -0.1])
H_ref = ref_glacier(temps, H₀)
# Train UDE
minA_out = 0.3
maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / ( 1 + exp(-x) )
UA = FastChain(
FastDense(1,3, x->tanh.(x)),
FastDense(3,10, x->tanh.(x)),
FastDense(10,3, x->tanh.(x)),
FastDense(3,1, sigmoid_A)
)
iceflow_trained = train_iceflow_UDE(H₀, UA, H_ref, temps)
Yeah, I’ll see when I can get to that. I’m in the middle of a big PR so I’m trying to debug from a distance
I’ve continued investigating this and I have a little bit more of insight. The first forward pass of the UDE works fine, the initial values of the NN are stable and produce a meaningful loss. However, right after, the ODE solver is called again (I can see the progress bar), but the new initial conditions (i.e. H
) are exactly the same as at the end of the previous forward run. And then is when the gradients go to infinity and everything crashes.
Since sciml_train
is such a high level black box, I’m having a hard time understanding what is going on exactly. Is that the pullback calling the solver backwards? Is that why the initial conditions match the final conditions of the previous forward run? Not sure if this is of much help, just some extra context that might avoid you trying to debug it. Thanks again!
Yes, if you do BacksolveAdjoint(autojacvec=ZygoteVJP())
. That method is will give unstable gradients. I describe this in my talk on adjoints that you never want to use BacksolveAdjoint for any real equations. See the talk starting at:
OK, now I’m confused. In your talk you say that the solution of the PDE needs to be in the loss function, but that is exactly what I am doing in my UDE. If I shouldn’t use BacksolveAdjoint(autojacvec=ZygoteVJP())
what should I use in order to use Zygote as we discussed? I think I might have misunderstood the sensealgs.
There are two parts to the sensealg (well many, but let’s simplify). There’s the adjoint choice and the VJP choice. Zygote is the right VJP here, so ZygoteVJP is . BacksolveAdjoint is unstable for the reasons mentioned in the video, so you probably want to use InterpolatingAdjoint(autojacvec=ZygoteVJP())
, or if you have enough memory QuadratureAdjoint(autojacvec=ZygoteVJP())
will be much faster if you’re using an implicit solver. But if you don’t choose a sensealg, it would have probably automatically defaulted to InterpolatingAdjoint(autojacvec=ZygoteVJP())
which is why I normally say to just use the defaults for this kind of thing (on out-of-place, it defaults to Zygote, and then with parameters it chooses the interpolating adjoint in order to be safe for non-trivial equations).
Oh I see! OK, now everything’s starting to fall into place. Thanks so much for having taken the time to explain this. After re-watching that part of the talk and re-reading the documentation I think I’m starting to make sense of this, plus I managed to make it work! I realize my confusion came from the fact that I never saw the ZygoteVJP
used in the two adjoint methods you mentioned, so I just thought it was something specific to BacksolveAdjoint
.
Here are my conclusions so far:
-
Now I totally understand why
BacksolveAdjoint
is so slow. I did a test usingcheckpointing=true
, and indeed, the backsolve became stable, but with a very high memory cost resulting in a very inefficient solution. -
InterpolatingAdjoint(autojacvec=ZygoteVJP())
works like a charm. Quite fast on the forward pass, and pretty fast for the pullback. For now, I’m quite happy with this. I’ve tested it withVCABM()
and it seemed reasonable. I guess from here onwards I could probably optimize it further into a viable solution. -
QuadratureAdjoint(autojacvec=ZygoteVJP())
is not working. I tried with bothVCABM()
andBS3()
and it crashes during the pullback with (iceflow_UDE!
is my UDE function):
ERROR: LoadError: MethodError: no method matching (::var"#iceflow_UDE!#314"{ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(B = ViewAxis(1:10000, ShapedAxis((100, 100), NamedTuple())), C = 10001, α = 10002, temps = 10003:10012, current_year = 10013, H = ViewAxis(10014:20013, ShapedAxis((100, 100), NamedTuple())), H_ref = ViewAxis(20014:30013, ShapedAxis((100, 100), NamedTuple())), θ = 30014:30096)}}}, FastChain{Tuple{FastDense{var"#322#325", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#323#326", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#324#327", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}})(::Matrix{Float32}, ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, ::Float64)
Closest candidates are:
(::var"#iceflow_UDE!#314")(::Any, ::Any, ::Any, ::Any) at /Users/Bolib001/Desktop/Jordi/Julia/odinn_toy_model/scripts/examples/MWE_iceflow.jl:134
Is it worth exploring QuadratureAdjoint
for this problem? Maybe I should use it with specific solvers? For now I can use InterpolatingAdjoint
, but I’m really curious if QuadratureAdjoint
could be faster.
Oh that’s interesting. I didn’t know about this issue. This is indicating that QuadratureAdjoint is only compatible with f(du,u,p,t)
forms and not f(u,p,t)
. That seems to be a missing spot in the testing matrix and would be worth making an MWE and issue to fix.
That said,
If you were using an implicit method with a Newton solver, like TRBDF2
, yes it would give a major speedup. With VCABM
, no it’s not going to give one in general. So InterpolatingAdjoint
with ZygoteVJP
is a good choice for that kind of setup + model.
And this complexity is why I am trying to push more and more into just building better default handling because the choice of adjoints is far more complex than most people should have to deal with
Now that I have a working version to train the UDEs, I’m realizing once again how slow the training is. Even using 24 cores in a powerful machine every epoch is really slow.
Besides waiting for the fix on the parallelization of the adjoints for the backpropagation, what other things could be done to accelerate this? Since Zygote doesn’t allow mutation, most of the tricks to accelerate DifferentialEquations.jl are out of question.
The chosen solver is BS3, right? I’d play with that a bit, it’s fairly rare that one is decent. What’s the stiffness like? Do you have an eigenvalue estimate?
I have done a benchmark for a short simulation with different solvers. Here are the results:
-
BS3(): 129.758 s (8984491 allocations: 359.19 GiB)
-
OwrenZen3() : 178.252 s (9053362 allocations: 364.58 GiB)
-
RK4() : 241.572 s (16666521 allocations: 710.28 GiB)
-
Ralston() : 104.069 s (7575698 allocations: 284.61 GiB)
-
Heun() : 108.007 s (7584535 allocations: 284.61 GiB)
-
Midpoint(): 108.288 s (7569720 allocations: 284.72 GiB)
-
ROCK2() : 74.382 s (3538753 allocations: 197.26 GiB)
-
ROCK4(): 43.860 s (3108538 allocations: 117.33 GiB)
I have tested other solvers, but whenever they gave terrible results I just stopped the simulation and kept on testing others. With your knowledge of the full palette of solvers, is there anything else I should try? Is there anything similar to ROCK4 worth testing? Thanks!
Anything implicit needs to go through the whole gambit of optimizing the linear solver (supplying sparsity patterns, and preconditioners, etc.) before assessing them. Did you do that part?
https://diffeq.sciml.ai/stable/solvers/ode_solve/#Stabilized-Explicit-Methods
If RKC is the right direction, ESERK5 might be worth a try. I haven’t seen it beat out ROCK methods though.
No, so far I just used the default values. I did try ESERK5 but it was pretty slow. I have tried all Stabilized Explicit Methods, but ROCK4 was the best by far. So far I see the best performing methods are all explicit. If I optimized the linear solver for the implicit ones could they potentially beat ROCK4?
Yes it could potentially beat it, depending on the eigenvalue structure of the problem.