Optimizing performance of 2D nonlinear diffusion UDE

:thinking: I’d have to see a picture of how it’s going unstable. Though it could also be from using an explicit method.

I have also tried VCABM() and it appears to be working well (like BS3()), but I still get the same error during sciml_train. Wanna try the new MWE and have a guess? :grinning:

using Statistics
using LinearAlgebra
using Random 
using OrdinaryDiffEq
using DiffEqFlux
using Flux
using Tullio
using RecursiveArrayTools
using ComponentArrays

const t₁ = 10                 # number of simulation years 
const ρ = 900f0                     # Ice density [kg / m^3]
const g = 9.81f0                    # Gravitational acceleration [m / s^2]
const n = 3f0                       # Glen's flow law exponent
const maxA = 8f-16
const minA = 3f-17
const maxT = 1f0
const minT = -25f0
A = 1.3f-24                         # Glen's creep parameter [1 / Pa^3 s]
A *= Float32(60 * 60 * 24 * 365.25) # [1 / Pa^3 yr]
C = 0
α = 0

@views avg(A) = 0.25 .* ( A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end] )
@views avg_x(A) = 0.5 .* ( A[1:end-1,:] .+ A[2:end,:] )
@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )
@views diff_x(A) = (A[begin + 1:end, :] .- A[1:end - 1, :])
@views diff_y(A) = (A[:, begin + 1:end] .- A[:, 1:end - 1])
@views inn(A) = A[2:end-1,2:end-1]

function ref_glacier(temps, H₀)
      
    H = deepcopy(H₀)
    
    # Initialize all matrices for the solver
    S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
    dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
    D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
    V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)
    
    # Gather simulation parameters
    current_year = Int(0)
    context = ArrayPartition([A], B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, [current_year])

    # Perform reference simulation with forward model 
    println("Running forward PDE ice flow model...\n")
    iceflow_prob = ODEProblem(iceflow!,H,(0.0,t₁),context)
    iceflow_sol = solve(iceflow_prob, BS3(), reltol=1e-6, progress=true, saveat=1.0, progress_steps = 1)

    return Float32.(iceflow_sol[end])
end

function iceflow!(dH, H, context, t)
    # Unpack parameters
    #A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year 
    A = context.x[1]             # one-element vector holding A
    current_year = context.x[18] # one-element vector holding the current year

    # Get current year for MB and ELA
    year = floor(Int, t) + 1
    if year != current_year[1] && year <= t₁
        temp = context.x[7][year]
        A[1] = A_fake(temp)
        current_year[1] = year
    end

    # Compute the Shallow Ice Approximation in a staggered grid
    SIA!(dH, H, context)
end


function train_iceflow_UDE(H₀, UA, H_ref, temps)
    
    # Gather simulation parameters
    H = deepcopy(H₀)
    current_year = 0
    θ = initial_params(UA)
    context = ComponentArray(B=B, C=C, α=α, temps=temps, current_year=current_year, H_ref=H_ref)
    loss(θ) = loss_iceflow(θ, UA, H, context) # closure

    println("Training iceflow UDE...")
    iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(0.0001), maxiters = 10)

    return iceflow_trained
end

function loss_iceflow(θ, UA, H, context)

    H = predict_iceflow(θ, UA, H, context)

    l_H = sqrt(Flux.Losses.mse(H[H .!= 0.0], context.H_ref[H .!= 0.0]; agg=sum))
    println("Loss = ", l_H)

    return l_H
end

function predict_iceflow(θ, UA, H, context)
    
    iceflow_UDE!(dH, H, θ, t) = iceflow_NN!(dH, H, θ, t, context, UA) # closure
    tspan = (0.0,t₁)
    iceflow_prob = ODEProblem(iceflow_UDE!,H,tspan,θ)
    H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, reltol=1e-6, save_everystep=false, 
                   sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP()), 
                   progress=true, progress_steps = 1)

    return H_pred[end]
end

function iceflow_NN!(dH, H, θ, t, context, UA)
    
    year = floor(Int, t) + 1
    if year <= t₁
        temp = context.temps[year]
    else
        temp = context.temps[year-1]
    end
    YA = predict_A̅(UA, θ, [temp]) # FastChain prediction requires explicit parameters

    if t%1 == 0
        println("A: ", YA)
    end

    # Compute the Shallow Ice Approximation in a staggered grid
    dH .= SIA!(dH, H, YA, context)

end  

"""
    SIA(H, p)

Compute a step of the Shallow Ice Approximation PDE in a forward model
"""

function SIA!(dH, H, context)
    
    # Retrieve parameters
    #A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
    A = context.x[1][1] # scalar A stored in a one-element vector
    B = context.x[2]
    S = context.x[3]
    dSdx = context.x[4]
    dSdy = context.x[5]
    D = context.x[6]
    dSdx_edges = context.x[8]
    dSdy_edges = context.x[9]
    ∇S = context.x[10]
    Fx = context.x[11]
    Fy = context.x[12]
    
    # Update glacier surface altimetry
    S .= B .+ H

    # All grid variables computed in a staggered grid
    # Compute surface gradients on edges
    dSdx .= diff_x(S) / Δx
    dSdy .= diff_y(S) / Δy
    ∇S .= (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2) 

    Γ = 2 * A * (ρ * g)^n / (n+2) # [1 / m^3 yr], since A was converted to per-year units
    D .= Γ .* avg(H).^(n + 2) .* ∇S

    # Compute flux components
    dSdx_edges .= diff(S[:,2:end - 1], dims=1) / Δx
    dSdy_edges .= diff(S[2:end - 1,:], dims=2) / Δy
    Fx .= .-avg_y(D) .* dSdx_edges
    Fy .= .-avg_x(D) .* dSdy_edges 

    #  Flux divergence
    inn(dH) .= .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy) # MB to be added here 
    
end

# Non-mutating version for Zygote, dispatching on a ComponentArray context
# (despite the `!`, the `@tullio ... :=` below builds a fresh dH rather than mutating the argument)
function SIA!(dH, H, A, context::ComponentArray)
    
    # Retrieve parameters
    B = context.B

    # Update glacier surface altimetry
    S = B .+ H

    # All grid variables computed in a staggered grid
    # Compute surface gradients on edges
    dSdx = diff_x(S) / Δx
    dSdy = diff_y(S) / Δy
    ∇S = (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2) 

    Γ = 2 * A * (ρ * g)^n / (n+2) # [1 / m^3 yr], since A was converted to per-year units
    D = Γ .* avg(H).^(n + 2) .* ∇S

    # Compute flux components
    dSdx_edges = diff(S[:,2:end - 1], dims=1) / Δx
    dSdy_edges = diff(S[2:end - 1,:], dims=2) / Δy
    Fx = .-avg_y(D) .* dSdx_edges
    Fy = .-avg_x(D) .* dSdy_edges 

    #  Flux divergence
    Fxx = diff(Fx, dims=1)
    Fyy = diff(Fy, dims=2)
    @tullio dH[i,j] := -(Fxx[pad(i-1,1,1),pad(j-1,1,1)] / Δx + Fyy[pad(i-1,1,1),pad(j-1,1,1)] / Δy) # MB to be added here

    return dH
end


function A_fake(temp)
    return @. minA + (maxA - minA) * ((temp-minT)/(maxT-minT) )^2
end

predict_A̅(UA, θ, temp) = UA(temp, θ)[1] .* 1f-16 # rescale NN output, keeping Float32


#### Generate reference dataset ####
const nx = 100
const ny = 100
const B = zeros(Float32, (nx, ny))
const σ = 1000
H₀ = Matrix{Float32}([ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / σ ) for i in 1:nx, j in 1:ny ])
const Δx = 50 # [m]
const Δy = 50 # [m]

const temps =  Vector{Float32}([0.0, -0.5, -0.2, -0.1, -0.3, -0.1, -0.2, -0.3, -0.4, -0.1])
H_ref = ref_glacier(temps, H₀)

# Train UDE
const minA_out = 0.3
const maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / ( 1 + exp(-x) )
UA = FastChain(
        FastDense(1,3, x->tanh.(x)),
        FastDense(3,10, x->tanh.(x)),
        FastDense(10,3, x->tanh.(x)),
        FastDense(3,1, sigmoid_A)
    )

iceflow_trained = train_iceflow_UDE(H₀, UA, H_ref, temps)

Yeah, I’ll see when I can get to that. I’m in the middle of a big PR so I’m trying to debug from a distance :sweat_smile:

I’ve continued investigating this and I have a little more insight. The first forward pass of the UDE works fine: the initial values of the NN are stable and produce a meaningful loss. However, right after, the ODE solver is called again (I can see the progress bar), but the new initial conditions (i.e. H) are exactly the same as at the end of the previous forward run. And that is when the gradients go to infinity and everything crashes.

Since sciml_train is such a high-level black box, I’m having a hard time understanding what exactly is going on. Is that the pullback calling the solver backwards? Is that why the initial conditions match the final conditions of the previous forward run? Not sure if this is of much help, just some extra context that might save you some debugging. Thanks again!

Yes, if you do BacksolveAdjoint(autojacvec=ZygoteVJP()). That method will give unstable gradients. As I describe in my talk on adjoints, you never want to use BacksolveAdjoint for any real equations. See the talk starting at:

OK, now I’m confused. In your talk you say that the solution of the PDE needs to be in the loss function, but that is exactly what I am doing in my UDE. If I shouldn’t use BacksolveAdjoint(autojacvec=ZygoteVJP()) what should I use in order to use Zygote as we discussed? I think I might have misunderstood the sensealgs.

There are two parts to the sensealg (well, many, but let’s simplify): the adjoint choice and the VJP choice. Zygote is the right VJP here, so ZygoteVJP is :ok_hand:. BacksolveAdjoint is unstable for the reasons mentioned in the video, so you probably want to use InterpolatingAdjoint(autojacvec=ZygoteVJP()), or, if you have enough memory, QuadratureAdjoint(autojacvec=ZygoteVJP()), which will be much faster if you’re using an implicit solver. But if you don’t choose a sensealg, it would probably have automatically defaulted to InterpolatingAdjoint(autojacvec=ZygoteVJP()), which is why I normally say to just use the defaults for this kind of thing (on out-of-place definitions it defaults to ZygoteVJP, and then with parameters it chooses the interpolating adjoint in order to be safe for non-trivial equations).
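To make that concrete, here’s what the solve call from the MWE would look like with the adjoint swapped (a minimal sketch, only the sensealg changes; not tested on your model):

H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, reltol=1e-6, save_everystep=false,
               sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())) # stable choice

# or, trading memory for speed (mainly pays off with implicit solvers):
H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, reltol=1e-6, save_everystep=false,
               sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP()))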

Oh I see! OK, now everything’s starting to fall into place. Thanks so much for having taken the time to explain this. After re-watching that part of the talk and re-reading the documentation I think I’m starting to make sense of this, plus I managed to make it work! I realize my confusion came from the fact that I never saw the ZygoteVJP used in the two adjoint methods you mentioned, so I just thought it was something specific to BacksolveAdjoint.

Here are my conclusions so far:

  • Now I totally understand why BacksolveAdjoint is so slow. I did a test using checkpointing=true, and indeed, the backsolve became stable, but with a very high memory cost resulting in a very inefficient solution.

  • InterpolatingAdjoint(autojacvec=ZygoteVJP()) works like a charm. Quite fast on the forward pass, and pretty fast for the pullback. For now, I’m quite happy with this. I’ve tested it with VCABM() and it seemed reasonable. I guess from here onwards I could probably optimize it further into a viable solution.

  • QuadratureAdjoint(autojacvec=ZygoteVJP()) is not working. I tried with both VCABM() and BS3() and it crashes during the pullback with (iceflow_UDE! is my UDE function):

ERROR: LoadError: MethodError: no method matching (::var"#iceflow_UDE!#314"{ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(B = ViewAxis(1:10000, ShapedAxis((100, 100), NamedTuple())), C = 10001, α = 10002, temps = 10003:10012, current_year = 10013, H = ViewAxis(10014:20013, ShapedAxis((100, 100), NamedTuple())), H_ref = ViewAxis(20014:30013, ShapedAxis((100, 100), NamedTuple())), θ = 30014:30096)}}}, FastChain{Tuple{FastDense{var"#322#325", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#323#326", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#324#327", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}})(::Matrix{Float32}, ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, ::Float64)
Closest candidates are:
  (::var"#iceflow_UDE!#314")(::Any, ::Any, ::Any, ::Any) at /Users/Bolib001/Desktop/Jordi/Julia/odinn_toy_model/scripts/examples/MWE_iceflow.jl:134

Is it worth exploring QuadratureAdjoint for this problem? Maybe I should use it with specific solvers? For now I can use InterpolatingAdjoint, but I’m really curious if QuadratureAdjoint could be faster.

Oh that’s interesting. I didn’t know about this issue. The MethodError shows your in-place iceflow_UDE! being called with three arguments, so this is indicating that QuadratureAdjoint’s ZygoteVJP path is only compatible with out-of-place f(u,p,t) forms and not in-place f(du,u,p,t) ones. That seems to be a missing spot in the testing matrix and would be worth making an MWE and issue to fix.
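A quick way to test that (a sketch reusing the MWE’s names, not verified here): since the @tullio ... := inside SIA! builds and returns a fresh dH, an out-of-place f(u,p,t) closure in predict_iceflow could look like

iceflow_UDE(H, θ, t) = begin                      # out-of-place closure over context and UA
    year = min(floor(Int, t) + 1, t₁)             # clamp to the last simulated year
    A = predict_A̅(UA, θ, [context.temps[year]])   # NN prediction of A
    SIA!(similar(H), H, A, context)               # the dH argument is unused; := returns a new array
end
iceflow_prob = ODEProblem(iceflow_UDE, H, (0.0, t₁), θ)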

That said,

If you were using an implicit method with a Newton solver, like TRBDF2, yes it would give a major speedup. With VCABM, no it’s not going to give one in general. So InterpolatingAdjoint with ZygoteVJP is a good choice for that kind of setup + model.
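In code, that pairing would be something like (sketch, assuming the out-of-place wrapper above):

H_pred = solve(iceflow_prob, TRBDF2(), u0=H, p=θ, reltol=1e-6, save_everystep=false,
               sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP()))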

And this complexity is why I am trying to push more and more into just building better default handling because the choice of adjoints is far more complex than most people should have to deal with :sweat_smile:

Now that I have a working version to train the UDEs, I’m realizing once again how slow the training is. Even using 24 cores on a powerful machine, every epoch is really slow.

Besides waiting for the fix for the parallelization of the adjoints in the backpropagation, what else could be done to accelerate this? Since Zygote doesn’t allow mutation, most of the tricks to accelerate DifferentialEquations.jl are out of the question.

The chosen solver is BS3, right? I’d play with that a bit; it’s fairly rare that that one is decent. What’s the stiffness like? Do you have an eigenvalue estimate?
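One rough way to get an estimate (a sketch with a made-up helper: power iteration on finite-difference Jacobian-vector products):

using LinearAlgebra

# Estimate the dominant Jacobian eigenvalue magnitude of an in-place RHS f!(du, u, p, t).
# Large values relative to 1/dt indicate stiffness.
function dominant_eig_estimate(f!, u0, p, t; iters = 20, ϵ = 1f-4)
    du0 = similar(u0); du = similar(u0)
    f!(du0, u0, p, t)
    v = randn(eltype(u0), size(u0)...); v ./= norm(v)
    λ = zero(real(eltype(u0)))
    for _ in 1:iters
        f!(du, u0 .+ ϵ .* v, p, t)
        Jv = (du .- du0) ./ ϵ   # J*v ≈ (f(u + ϵv) - f(u)) / ϵ
        λ = norm(Jv)            # ‖J*v‖ → |λ_max| as v aligns with the dominant eigenvector
        v = Jv ./ λ
    end
    return λ
end

# e.g. dominant_eig_estimate(iceflow!, H₀, context, 0.0)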

I have done a benchmark for a short simulation with different solvers. Here are the results:

  • BS3(): 129.758 s (8984491 allocations: 359.19 GiB)

  • OwrenZen3(): 178.252 s (9053362 allocations: 364.58 GiB)

  • RK4(): 241.572 s (16666521 allocations: 710.28 GiB)

  • Ralston(): 104.069 s (7575698 allocations: 284.61 GiB)

  • Heun(): 108.007 s (7584535 allocations: 284.61 GiB)

  • Midpoint(): 108.288 s (7569720 allocations: 284.72 GiB)

  • ROCK2(): 74.382 s (3538753 allocations: 197.26 GiB)

  • ROCK4(): 43.860 s (3108538 allocations: 117.33 GiB)

I have tested other solvers, but whenever they gave terrible results I just stopped the simulation and kept on testing others. With your knowledge of the full palette of solvers, is there anything else I should try? Is there anything similar to ROCK4 worth testing? Thanks!

Anything implicit needs to go through the whole gamut of optimizing the linear solver (supplying sparsity patterns, preconditioners, etc.) before assessing them. Did you do that part?
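For reference, that tuning might look something like this (a rough sketch reusing the MWE’s names; the Symbolics-based sparsity detection and the GMRES linear solver are illustrative assumptions, not something tested here):

using OrdinaryDiffEq, LinearSolve
using Symbolics

# Detect the Jacobian sparsity pattern of the in-place RHS (assumes the RHS can be traced symbolically)
du0 = similar(H₀)
jac_sparsity = Symbolics.jacobian_sparsity((du, u) -> iceflow!(du, u, context, 0.0), du0, H₀)

# Attach the sparsity pattern and use a Krylov method for the implicit linear solves
f_sparse = ODEFunction(iceflow!; jac_prototype = float.(jac_sparsity))
prob = ODEProblem(f_sparse, H₀, (0.0, t₁), context)
sol = solve(prob, TRBDF2(linsolve = KrylovJL_GMRES()), reltol = 1e-6)

Adding a preconditioner on top of that (e.g. an incomplete LU of the sparse Jacobian) would usually be the next step.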

https://diffeq.sciml.ai/stable/solvers/ode_solve/#Stabilized-Explicit-Methods

If RKC is the right direction, ESERK5 might be worth a try. I haven’t seen it beat out ROCK methods though.

No, so far I just used the default values. I did try ESERK5 but it was pretty slow. I have tried all the Stabilized Explicit Methods, and ROCK4 was the best by far. So far the best-performing methods are all explicit. If I optimized the linear solver for the implicit ones, could they potentially beat ROCK4?

Yes it could potentially beat it, depending on the eigenvalue structure of the problem.