Optimizing performance of 2D nonlinear diffusion UDE

I have a UDE based on a 2D nonlinear ice flow diffusion PDE. It works correctly for both the forward and backward passes, but despite using a semi-implicit solver, the memory usage of backpropagation is too large even for very short simulations, and the process gets killed. My model parallelizes the multiple simulations of each training batch using pmap (not shown), but this doesn’t help, because the memory blow-up happens inside each individual simulation.
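For reference, reading it off the SIA function below (and assuming B is the bed elevation), the PDE being discretised appears to be the Shallow Ice Approximation:

$$
\frac{\partial H}{\partial t} = \nabla \cdot \left( D \, \nabla S \right), \qquad S = B + H,
$$

$$
D = \Gamma \, H^{n+2} \, \lvert \nabla S \rvert^{n-1}, \qquad \Gamma = \frac{2A\,(\rho g)^{n}}{n+2},
$$

where H is the ice thickness, A the creep parameter predicted by the neural network, ρ the ice density, g gravity and n Glen’s flow law exponent. In the code, D and the fluxes live on a staggered grid, and the mass-balance source term is still marked as “to be added”.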

I’m still very new to Zygote, so I was wondering: what would be the best way to reduce Zygote’s memory usage in a case like this?

Here’s a rough overview of what I’m doing:

# Loss function
function loss(H, glacier_ref, UA, p, t, t₁)
  
    H = iceflow!(H, UA, p,t,t₁)
    l_H = sqrt(Flux.Losses.mse(H[H .!= 0.0], glacier_ref["H"][end][H.!= 0.0]; agg=sum))

    return l_H
end

# Ice flow UDE
function iceflow!(H, UA, p,t,t₁)

    # Retrieve input variables  
    let                  
    current_year = 0
    total_iter = 0
    t_step = 0
    temps = p[6]

    # Forward scheme implementation
    while t < t₁
        let
        iter = 1
        err = 2 * tolnl

        Hold = copy(H)
        dHdt = zeros(nx, ny)

        # Get current year for MB and ELA
        year = floor(Int, t) + 1

        if year != current_year

            # Predict value of `A`
            temp = [temps[year]]'
                    
            ŶA = predict_A̅(UA, temp)

            ## Unpack and repack tuple with updated A value
            Δx, Δy, Γ, A, B, temps, C, α = p
            p = (Δx, Δy, Γ, ŶA, B, temps, C, α)
            current_year = year
        end
           
        while err > tolnl && iter < itMax+1
       
            Err = copy(H)

            # Compute the Shallow Ice Approximation in a staggered grid
            F, dτ = SIA(H, p)

            # Compute the residual ice thickness for the inertia
            @tullio ResH[i,j] := -(H[i,j] - Hold[i,j])/Δt + F[pad(i-1,1,1),pad(j-1,1,1)]

            dHdt_ = copy(dHdt)
            @tullio dHdt[i,j] := dHdt_[i,j]*damp + ResH[i,j]
                            
            # We keep local copies for tullio
            H_ = copy(H)
            
            # Update the ice thickness
            @tullio H[i,j] := max(0.0, H_[i,j] + dHdt[i,j]*dτ)

            iter += 1
            total_iter += 1

        end 
          
        t += Δt
        t_step += 1
    
        end # let
    end 

    return H

    end   # let

end

# Shallow Ice Approximation
function SIA(H, p)
    Δx, Δy, Γ, A, B, temps, C, α = p

    # Update glacier surface altimetry
    S = B .+ H

    # All grid variables computed in a staggered grid
    # Compute surface gradients on edges
    dSdx  = diff(S, dims=1) / Δx
    dSdy  = diff(S, dims=2) / Δy
    ∇S² = avg_y(dSdx).^2 .+ avg_x(dSdy).^2

    Γ = 2 * A * (ρ * g)^n / (n+2) # 1 / m^3 s 

    D = Γ .* avg(H).^(n + 2) .* ∇S².^((n - 1)/2)

    # Compute flux components
    dSdx_edges = diff(S[:,2:end - 1], dims=1) / Δx
    dSdy_edges = diff(S[2:end - 1,:], dims=2) / Δy
    Fx = .-avg_y(D) .* dSdx_edges
    Fy = .-avg_x(D) .* dSdy_edges    
    #  Flux divergence
    F = .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy) # MB to be added here 

    # Compute dτ for the implicit method
    #dτ = dτsc * min.( 10.0 , 1.0./(1.0/Δt .+ 1.0./(cfl./(ϵ .+ avg(D)))))
    D_max = 3000000
    current_D_max = maximum(D) 
    if D_max < current_D_max
        error("Increase Maximum diffusivity. Required value must be larger than $current_D_max")
    end
    dτ = dτsc * min( 10.0 , 1.0/(1.0/Δt + 1.0/(cfl/(ϵ + D_max))))

    return F,  dτ

end

# Compute pullback for ice flow UDE
loss_UA, back_UA = Zygote.pullback(() -> loss(H, glacier_ref, UA, p, t, t₁), θ)

Thanks in advance!

What are the appropriate parameters to test this?

That’s a very inefficient time-stepping method. Doing direct AD through a semi-implicit Euler method is going to be both compute- and memory-inefficient. Why not use something like KenCarp4 instead? Then, in the sensealg, you want to turn on checkpointing to reduce memory usage. That would also avoid doing direct AD through the implicit solver, which of course can be handled in O(1) instead of O(n) as is done here.
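To make the checkpointing idea concrete, here is a toy sketch in plain Julia. It is not how DiffEqSensitivity implements checkpointing; it only illustrates the memory trade-off on a scalar ODE du/dt = -k u², stepped with explicit Euler and differentiated with a hand-written discrete adjoint. `grad_full` stores every state (memory grows with the number of steps N), while `grad_checkpointed` stores only the state at the start of each m-step segment and recomputes each segment during the reverse sweep, keeping roughly N/m + m values at the cost of one extra forward pass. All names here are made up for illustration.

```julia
# Toy problem: du/dt = -k*u^2, integrated with explicit Euler
euler_step(u, k, Δt) = u - Δt * k * u^2
# Derivative of one Euler step with respect to its input state
d_euler_step(u, k, Δt) = 1 - 2 * Δt * k * u

function forward(u0, k, Δt, N)
    u = u0
    for _ in 1:N
        u = euler_step(u, k, Δt)
    end
    return u
end

# Reverse sweep that stores every state: memory grows like N
function grad_full(u0, k, Δt, N)
    us = Vector{Float64}(undef, N + 1)
    us[1] = u0
    for i in 1:N
        us[i+1] = euler_step(us[i], k, Δt)
    end
    λ = 1.0                              # adjoint dL/du_N for the loss L = u_N
    for i in N:-1:1
        λ *= d_euler_step(us[i], k, Δt)
    end
    return λ                             # dL/du0
end

# Reverse sweep with a checkpoint every m steps: stores about N/m + m states
function grad_checkpointed(u0, k, Δt, N, m)
    nseg = cld(N, m)
    checkpoints = Vector{Float64}(undef, nseg)
    u = u0
    for s in 1:nseg                      # forward pass: keep only segment start states
        checkpoints[s] = u
        for _ in 1:min(m, N - (s - 1) * m)
            u = euler_step(u, k, Δt)
        end
    end
    seg = Vector{Float64}(undef, m)      # reused buffer for one recomputed segment
    λ = 1.0
    for s in nseg:-1:1                   # reverse pass, last segment first
        len = min(m, N - (s - 1) * m)
        seg[1] = checkpoints[s]
        for j in 1:len-1                 # recompute the segment's states from its checkpoint
            seg[j+1] = euler_step(seg[j], k, Δt)
        end
        for j in len:-1:1                # backpropagate through the segment
            λ *= d_euler_step(seg[j], k, Δt)
        end
    end
    return λ
end

g_full = grad_full(1.0, 0.5, 0.01, 1000)
g_ckpt = grad_checkpointed(1.0, 0.5, 0.01, 1000, 32)
```

Both functions return the same gradient, which can be checked against a central finite difference of `forward`. With N = 1000 and m = 32, the checkpointed version keeps 32 checkpoints plus a 32-element buffer instead of 1001 states.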


@goerch I don’t have an MWE at hand, but I was just looking for general advice on strategies, so there’s no real need to run my code.

Thanks @ChrisRackauckas for the advice. I’ve spent some time rewriting my code to fit into the DifferentialEquations.jl solvers, and I must say it was worth it. My code is now cleaner, and I can test a whole bunch of different solvers.

So far I have focused on optimizing the forward model. KenCarp4 doesn’t seem to work: it just doesn’t show any progress, and I don’t get any errors or warnings. BS3 and Vern7 have turned out to be the best so far (though I should probably test more). Next, I tried to follow your advice on how to correctly optimize the functions passed to the solvers. What I noticed is that, despite using @views, I still see a lot of memory allocations slowing down the code. Here’s the new version of my code with DifferentialEquations.jl:

function iceflow!(dH, H, p,t)
    # Retrieve parameters
    #A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year 
    current_year = @view p[end]
    A = @view p[1]
    
    # Get current year for MB and ELA
    year = floor(Int, t) + 1
    if year != current_year && year <= t₁

        temp = @view p[7][year]
        A .= A_fake(temp)

        # Unpack and repack tuple with updated A value
        current_year .= year
    end

    # Compute the Shallow Ice Approximation in a staggered grid
    SIA!(dH, H, p)
end  

function SIA!(dH, H, p)
    # Retrieve parameters
    #A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year  
    
    A = @view p[1]
    B = @view p[2][:,:]
    S = @view p[3][:,:]
    dSdx = @view p[4][:,:]
    dSdy = @view p[5][:,:]
    D = @view p[6][:,:]
    dSdx_edges = @view p[8][:,:]
    dSdy_edges = @view p[9][:,:]
    ∇S = @view p[10][:,:]
    Fx = @view p[11][:,:]
    Fy = @view p[12][:,:]
    Vx = @view p[13][:,:]
    Vy = @view p[14][:,:]

    # Update glacier surface altimetry
    S .= B .+ H

    # All grid variables computed in a staggered grid
    # Compute surface gradients on edges
    dSdx .= diff(S, dims=1) ./ Δx
    dSdy .= diff(S, dims=2) / Δy
    ∇S .= (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2) 

    Γ = 2 * A * (ρ * g)^n / (n+2) # 1 / m^3 s 
    
    D .= Γ .* avg(H).^(n + 2) .* ∇S

    # Compute flux components
    dSdx_edges .= diff(S[:,2:end - 1], dims=1) / Δx
    dSdy_edges .= diff(S[2:end - 1,:], dims=2) / Δy
    Fx .= .-avg_y(D) .* dSdx_edges
    Fy .= .-avg_x(D) .* dSdy_edges 

    #  Flux divergence
    inn(dH) .= .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy) # MB to be added here 

    # Compute velocities    
    Vx .= -D./(avg(H) .+ ϵ).*avg_y(dSdx)
    Vy .= -D./(avg(H) .+ ϵ).*avg_x(dSdy)
end

@views avg(A) = 0.25 .* ( A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end] )
@views avg_x(A) = 0.5 .* ( A[1:end-1,:] .+ A[2:end,:] )
@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )

# Initialize all matrices for the solver
S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)
    
# Gather simulation parameters
current_year = 0
p = [A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year]

# Perform reference simulation with forward model 
iceflow_prob = ODEProblem(iceflow!,H,(0.0,t₁),p)
iceflow_sol = solve(iceflow_prob, BS3(), progress=true, progress_steps = 1)

When I check memory allocations with @allocated or @time, I can see memory being allocated in operations like dSdx .= diff(S, dims=1) ./ Δx. However, operations like S .= B .+ H allocate nothing. A single forward run with solve allocates up to several GB. What is going on? What am I doing wrong? Once I have an optimized forward model working, I’ll try to add AD for a UDE.

Then, in the sensealg, you want to turn on checkpointing to reduce memory usage. That would also avoid doing direct AD through the implicit solver, which of course can be handled in O(1) instead of O(n) as is done here.

What do you mean by the sensealg?

Thanks in advance!

Did you try JET.jl already?


diff(S, dims=1) allocates a new array. Instead, you could try something like:

dSdx .= @views (S[begin + 1:end, :] .- S[1:end - 1, :]) ./ Δx

Thanks for your reply. The thing is that even operations like this seem to be allocating memory:

@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )

Fx .= .-avg_y(D) .* dSdx_edges

So it’s not just the diff function: something allocates even with @views and indexing. The only operation that didn’t seem to allocate any memory was a plain broadcast sum of matrix views: S .= B .+ H.

It’s still the same problem. Your avg_y function has the same problem as diff: it materialises the result into a new array before returning.

One workaround would be to allow the avg_y function to accept an output array:

function avg_y!(u, A)
    u .= @views 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )
    u
end

and then to allocate such an output array before this function is called.
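A sketch of that approach for all four staggered-grid helpers, assuming the output buffers are allocated once, outside the ODE right-hand side (the `!` names are hypothetical, mirroring avg_y! above). A quick way to check that nothing is materialised is to measure @allocated inside a function after a warm-up call, since measuring at global scope also counts compilation and the handling of non-constant globals:

```julia
# In-place staggered-grid helpers: each writes into a preallocated `out`
# instead of returning a freshly allocated array.
function diff_x!(out, A, Δ)
    @views out .= (A[2:end, :] .- A[1:end-1, :]) ./ Δ
    return out
end

function diff_y!(out, A, Δ)
    @views out .= (A[:, 2:end] .- A[:, 1:end-1]) ./ Δ
    return out
end

function avg_x!(out, A)
    @views out .= 0.5f0 .* (A[1:end-1, :] .+ A[2:end, :])
    return out
end

function avg_y!(out, A)
    @views out .= 0.5f0 .* (A[:, 1:end-1] .+ A[:, 2:end])
    return out
end

# Allocations of one call, measured after a warm-up call so compilation is excluded
function allocs_of_diff_x!(out, A, Δ)
    diff_x!(out, A, Δ)
    return @allocated diff_x!(out, A, Δ)
end

nx, ny = 100, 100
S = rand(Float32, nx, ny)
Δx = 50f0
dSdx = zeros(Float32, nx - 1, ny)   # allocated once, reused on every call

diff_x!(dSdx, S, Δx)
allocs = allocs_of_diff_x!(dSdx, S, Δx)
```

In the ODE right-hand side, these buffers would live in p (or in a closure) and be reused on every call, as the code above already does for S and dSdx. The `0.5f0` literal keeps the arithmetic in Float32 for the Float32 buffers.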


How many ODEs? If it’s huge, it’s most likely just stuck doing matrix inverses, and you’ll want to swap to an iterative linear solver like GMRES. Otherwise, you might want to try modelingtoolkitize to generate the analytical sparse Jacobian (and parallelize it), or just run sparsity detection to get the sparsity pattern. Another thing to try is CVODE_BDF(linear_solver=:GMRES) to see what happens. If this PDE generally has real eigenvalues, I’d suggest trying ROCK2() or ROCK4(). Also, if it’s huge, try QNDF or FBDF instead of KenCarp.

Since the solver needs to work well for the pullback, I have moved on to implementing the full UDE with DiffEqFlux.jl before spending more time testing different solvers and chasing the memory allocation issues above.

I’m currently struggling to find a way to pass all my PDE matrices (p) to the loss function and to the solver wrapped by sciml_train. All the examples I’ve come across only optimize the NN parameters (θ), but my model requires the full p array with all the matrices needed for the staggered grid. I’m sure it must be something quite straightforward, so could someone please point out the best way to do this? This is what I have so far:

function loss_iceflow(p)
    H = predict_iceflow(p)
    H_ref = @view p[19][:,:]
    l_H = sqrt(Flux.Losses.mse(H[H .!= 0.0], H_ref[H.!= 0.0]; agg=sum))
    return l_H
end

function predict_iceflow(p)       
    H = @view p[20][:,:]    
    tspan = (0.0,t₁)
    iceflow_prob = ODEProblem(iceflow_UDE!,H,tspan,p)
    H_pred = solve(iceflow_prob, BS3(), u0=H, p=p, saveat=1.0, 
                   sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP()), 
                   progress=true, progress_steps = 1)

    return H_pred[end]
end

function iceflow_UDE!(dH, H, p,t)
    # Unpack parameters
    #A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
    current_year = p[18]
    A = @view p[1]
    UA = p[21]
    θ = p[22]
    
    # Get current year for MB and ELA
    year = floor(Int, t) + 1
    if year != current_year && year <= t₁
        A = predict_A̅(UA, θ, [temp]) # FastChain prediction requires explicit parameters
        current_year = year
    end

    # Compute the Shallow Ice Approximation in a staggered grid
    SIA!(dH, H, p)

end  

loss(θ) = loss_iceflow(p)
iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(hyparams.η), cb=callback, maxiters = 10)

and this crashes with:

ERROR: LoadError: MethodError: no method matching (::var"#loss#10"{Vector{Any}})(::Vector{Float32})
Closest candidates are:
  (::var"#loss#10")() at /Users/Bolib001/Desktop/Jordi/Julia/odinn_toy_model/scripts/helpers/iceflow_DiffEqs.jl:101
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context, f::var"#loss#10"{Vector{Any}}, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:9
  [3] _pullback
    @ ~/.julia/packages/DiffEqFlux/jpIWG/src/train.jl:84 [inlined]
  [4] _pullback(::Zygote.Context, ::DiffEqFlux.var"#82#87"{var"#loss#10"{Vector{Any}}}, ::Vector{Float32}, ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
  [5] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:804
  [6] adjoint
    @ ~/.julia/packages/Zygote/AlLTp/src/lib/lib.jl:200 [inlined]
  [7] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
  [8] _pullback
    @ ~/.julia/packages/SciMLBase/7GnZA/src/problems/basic_problems.jl:107 [inlined]
  [9] _pullback(::Zygote.Context, ::OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{var"#loss#10"{Vector{Any}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::Vector{Float32}, ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [10] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:804
 [11] adjoint
    @ ~/.julia/packages/Zygote/AlLTp/src/lib/lib.jl:200 [inlined]
 [12] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [13] _pullback
    @ ~/.julia/packages/GalacticOptim/DHxE0/src/function/zygote.jl:6 [inlined]
 [14] _pullback(ctx::Zygote.Context, f::GalacticOptim.var"#260#270"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{var"#loss#10"{Vector{Any}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [15] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:804
 [16] adjoint
    @ ~/.julia/packages/Zygote/AlLTp/src/lib/lib.jl:200 [inlined]
 [17] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [18] _pullback
    @ ~/.julia/packages/GalacticOptim/DHxE0/src/function/zygote.jl:8 [inlined]
 [19] _pullback(ctx::Zygote.Context, f::GalacticOptim.var"#263#273"{Tuple{}, GalacticOptim.var"#260#270"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{var"#loss#10"{Vector{Any}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [20] _pullback(f::Function, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface.jl:34
 [21] pullback(f::Function, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface.jl:40
 [22] gradient(f::Function, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/AlLTp/src/compiler/interface.jl:75
 [23] (::GalacticOptim.var"#261#271"{GalacticOptim.var"#260#270"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{var"#loss#10"{Vector{Any}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}})(::Vector{Float32}, ::Vector{Float32})
    @ GalacticOptim ~/.julia/packages/GalacticOptim/DHxE0/src/function/zygote.jl:8
 [24] macro expansion
    @ ~/.julia/packages/GalacticOptim/DHxE0/src/solve/flux.jl:41 [inlined]
 [25] macro expansion
    @ ~/.julia/packages/GalacticOptim/DHxE0/src/utils.jl:35 [inlined]
 [26] __solve(prob::OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{var"#loss#10"{Vector{Any}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#261#271"{GalacticOptim.var"#260#270"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{var"#loss#10"{Vector{Any}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#264#274"{GalacticOptim.var"#260#270"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{var"#loss#10"{Vector{Any}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#269#279", Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Symbol, typeof(callback), Tuple{Symbol}, NamedTuple{(:cb,), Tuple{typeof(callback)}}}}, opt::RMSProp, data::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ GalacticOptim ~/.julia/packages/GalacticOptim/DHxE0/src/solve/flux.jl:39
 [27] #solve#476
    @ ~/.julia/packages/SciMLBase/7GnZA/src/solve.jl:3 [inlined]
 [28] sciml_train(::var"#loss#10"{Vector{Any}}, ::Vector{Float32}, ::RMSProp, ::Nothing; lower_bounds::Nothing, upper_bounds::Nothing, maxiters::Int64, kwargs::Base.Iterators.Pairs{Symbol, typeof(callback), Tuple{Symbol}, NamedTuple{(:cb,), Tuple{typeof(callback)}}})
    @ DiffEqFlux ~/.julia/packages/DiffEqFlux/jpIWG/src/train.jl:89
 [29] train_iceflow_UDE(H₀::Matrix{Float32}, UA::Function, H_refs::Vector{Matrix{Float32}}, temp_series::Vector{Any}, hyparams::Hyperparameters, idx::Int64)
    @ Main ~/Desktop/Jordi/Julia/odinn_toy_model/scripts/helpers/iceflow_DiffEqs.jl:103
 [30] #8
    @ ~/Desktop/Jordi/Julia/odinn_toy_model/scripts/helpers/iceflow_DiffEqs.jl:72 [inlined]
 [31] iterate
    @ ./generator.jl:47 [inlined]
 [32] _collect(c::Vector{Int64}, itr::Base.Generator{Vector{Int64}, var"#8#9"{Matrix{Float32}, FastChain{Tuple{FastDense{var"#12#20", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#13#21", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#14#22", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#sigmoid_A#17"{Int64, Float64}, DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Vector{Matrix{Float32}}, Vector{Any}, Hyperparameters}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
    @ Base ./array.jl:691
 [33] collect_similar(cont::Vector{Int64}, itr::Base.Generator{Vector{Int64}, var"#8#9"{Matrix{Float32}, FastChain{Tuple{FastDense{var"#12#20", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#13#21", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#14#22", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#sigmoid_A#17"{Int64, Float64}, DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Vector{Matrix{Float32}}, Vector{Any}, Hyperparameters}})
    @ Base ./array.jl:606
 [34] map(f::Function, A::Vector{Int64})
    @ Base ./abstractarray.jl:2294
 [35] train_batch_iceflow_UDE(H₀::Matrix{Float32}, UA::Function, H_refs::Vector{Matrix{Float32}}, temp_series::Vector{Any}, hyparams::Hyperparameters, idxs::Vector{Int64})
    @ Main ~/Desktop/Jordi/Julia/odinn_toy_model/scripts/helpers/iceflow_DiffEqs.jl:72
 [36] macro expansion
    @ ./timing.jl:210 [inlined]
 [37] top-level scope
    @ ~/Desktop/Jordi/Julia/odinn_toy_model/scripts/ice_dynamics_DiffEqs.jl:218
in expression starting at /Users/Bolib001/Desktop/Jordi/Julia/odinn_toy_model/scripts/ice_dynamics_DiffEqs.jl:190

For my previous manual implementation with Zygote.jl (which I’m trying to implement with DiffEqFlux.jl), I just did:

loss_UA, back_UA = Zygote.pullback(() -> loss(H, H_ref, UA, p, t, t₁), θ)

My guess, looking at the error, is that sciml_train passes θ to loss by default, which is not the input I need, since my loss requires p instead. What is the right way to pass parameters other than θ to the loss function? Once I get my head around this, I’ll try to make a small PR to clarify this in the docs. Thanks in advance!
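A minimal sketch of the usual pattern, using only Base Julia: keep everything that is not trained in a container, capture it in a closure, and hand the optimizer a function of θ alone, since sciml_train calls loss(θ). `IceflowData`, `toy_predict` and `make_loss` are hypothetical stand-ins; in the real model `toy_predict` would build and solve the ODEProblem.

```julia
# Hypothetical container for the fixed (non-trained) data the loss needs
struct IceflowData{M<:AbstractMatrix{Float64}}
    H₀::M      # initial ice thickness
    H_ref::M   # reference thickness to fit
end

# Toy stand-in for predict_iceflow: scales the initial thickness by θ[1].
# In the real model this would build and solve the ODEProblem.
toy_predict(θ, data::IceflowData) = θ[1] .* data.H₀

# Returns a one-argument function loss(θ) that captures `data` in a closure,
# which is the signature an optimizer such as sciml_train calls.
function make_loss(data::IceflowData)
    return function (θ)
        H = toy_predict(θ, data)
        mask = H .!= 0.0
        # Same quantity as sqrt(Flux.Losses.mse(...; agg=sum)): root of the summed squared errors
        return sqrt(sum(abs2, H[mask] .- data.H_ref[mask]))
    end
end

H₀ = [250.0 * exp(-((i - 5)^2 + (j - 5)^2) / 10) for i in 1:10, j in 1:10]
data = IceflowData(H₀, 2.0 .* H₀)
loss = make_loss(data)
```

The resulting `loss` is what gets passed as the first argument to `DiffEqFlux.sciml_train(loss, θ, ...)`; only θ is optimized, while `data` travels along inside the closure.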

What’s the full current code, so I can copy/paste and test it?

Hi Chris, I believe I have managed to work around the issue of passing all the matrices by using closures (see MWE below). Now the forward run works correctly, but I’m getting the classic “Mutating arrays is not allowed” error from Zygote. Here is an MWE to reproduce the issue:

using Statistics
using LinearAlgebra
using Random 
using OrdinaryDiffEq
using DiffEqFlux
using Flux

const t₁ = 10                 # number of simulation years 
const ρ = 900                     # Ice density [kg / m^3]
const g = 9.81                    # Gravitational acceleration [m / s^2]
const n = 3                       # Glen's flow law exponent
const maxA = 8e-16
const minA = 3e-17
const maxT = 1
const minT = -25
A = 1.3e-24 #2e-16  1 / Pa^3 s
A *= 60 * 60 * 24 * 365.25 # [1 / Pa^3 yr]
C = 0
α = 0

@views avg(A) = 0.25 .* ( A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end] )
@views avg_x(A) = 0.5 .* ( A[1:end-1,:] .+ A[2:end,:] )
@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )
@views diff_x(A) = (A[begin + 1:end, :] .- A[1:end - 1, :])
@views diff_y(A) = (A[:, begin + 1:end] .- A[:, 1:end - 1])
@views inn(A) = A[2:end-1,2:end-1]

function ref_glacier(temps, H₀)
      
    H = deepcopy(H₀)
    
    # Initialize all matrices for the solver
    S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
    dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
    D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
    V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)
    
    # Gather simulation parameters
    current_year = 0
    p = [A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year]

    # Perform reference simulation with forward model 
    println("Running forward PDE ice flow model...\n")
    iceflow_prob = ODEProblem(iceflow!,H,(0.0,t₁),p)
    iceflow_sol = solve(iceflow_prob, BS3(), progress=true, saveat=1.0, progress_steps = 1)

    return iceflow_sol[end] 
end

function iceflow!(dH, H, p,t)
    # Unpack parameters
    #A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year 
    current_year = @view p[18]
    A = @view p[1]
    
    # Get current year for MB and ELA
    year = floor(Int, t) + 1
    if year != current_year && year <= t₁
        temp = @view p[7][year]
        A .= A_fake(temp)
        current_year .= year
    end

    # Compute the Shallow Ice Approximation in a staggered grid
    SIA!(dH, H, p)
end  

function train_iceflow_UDE(H₀, UA, H_ref, temps)
    
    # Gather simulation parameters
    H = deepcopy(H₀)
    
    # Initialize all matrices for the solver
    S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
    dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
    D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
    V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)
    
    # Gather simulation parameters
    current_year = 0
    θ = initial_params(UA)
    context = [A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ]
    loss(θ) = loss_iceflow(UA, θ, H, context) # closure

    iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(0.01), maxiters = 100)

    return iceflow_trained
end

function loss_iceflow(UA, θ, H, context)
    
    H = predict_iceflow(UA, θ, H, context)
    
    H_ref = @view context[19][:,:]
    l_H = sqrt(Flux.Losses.mse(H[H .!= 0.0], H_ref[H.!= 0.0]; agg=sum))

    return l_H
end

function predict_iceflow(UA, θ, H, context)
        
    iceflow_UDE!(dH, H, θ, t) = iceflow_NN!(dH, H, θ, t, context) # closure
    tspan = (0.0,t₁)
    iceflow_prob = ODEProblem(iceflow_UDE!,H,tspan,θ)
    H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, saveat=1.0, 
                   sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP()), 
                   progress=true, progress_steps = 1)

    return H_pred[end]
end

function iceflow_NN!(dH, H, θ, t, context)
    
    # Unpack parameters
    #A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
    current_year = context[18]
    A = @view context[1]
    UA = context[21]
    dH = @view context[20][:,:]
    
    # Get current year for MB and ELA
    year = floor(Int, t) + 1
    if year != current_year && year <= t₁
        temp = context[7][year]
        A .= predict_A̅(UA, θ, [temp]) # FastChain prediction requires explicit parameters
        current_year = year
    end

    # Compute the Shallow Ice Approximation in a staggered grid
    SIA!(dH, H, context)

end  

"""
    SIA!(dH, H, context)

Compute, in place, the right-hand side dH of the Shallow Ice Approximation PDE in a forward model
"""
function SIA!(dH, H, context)
    
    # Retrieve parameters
    #A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
    A = Ref{Float32}(context[1])
    B = @view context[2][:,:]
    S = @view context[3][:,:]
    dSdx = @view context[4][:,:]
    dSdy = @view context[5][:,:]
    D = @view context[6][:,:]
    dSdx_edges = @view context[8][:,:]
    dSdy_edges = @view context[9][:,:]
    ∇S = @view context[10][:,:]
    Fx = @view context[11][:,:]
    Fy = @view context[12][:,:]
    Vx = @view context[13][:,:]
    Vy = @view context[14][:,:]
    
    # Update glacier surface altimetry
    S .= B .+ H

    # All grid variables computed in a staggered grid
    # Compute surface gradients on edges
    dSdx .= diff_x(S) / Δx
    dSdy .= diff_y(S) / Δy
    ∇S .= (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2) 

    Γ = 2 * A[] * (ρ * g)^n / (n+2) # 1 / m^3 s 
    D .= Γ .* avg(H).^(n + 2) .* ∇S

    # Compute flux components
    dSdx_edges .= diff(S[:,2:end - 1], dims=1) / Δx
    dSdy_edges .= diff(S[2:end - 1,:], dims=2) / Δy
    Fx .= .-avg_y(D) .* dSdx_edges
    Fy .= .-avg_x(D) .* dSdy_edges 

    #  Flux divergence
    inn(dH) .= .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy) # MB to be added here 
    
end

function A_fake(temp)
    return @. minA + (maxA - minA) * ((temp-minT)/(maxT-minT) )^2
end

predict_A̅(UA, θ, temp) = UA(temp, θ)[1] .* 1e-16


#### Generate reference dataset ####
nx = ny = 100
B = zeros(Float64, (nx, ny))
σ = 1000
H₀ = [ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / σ ) for i in 1:nx, j in 1:ny ]    
Δx = Δy = 50 #m   

temps = [0.0, -0.5, -0.2, -0.1, -0.3, -0.1, -0.2, -0.3, -0.4, -0.1]
H_ref = ref_glacier(temps, H₀)

# Train UDE
minA_out = 0.3
maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / ( 1 + exp(-x) )
UA = FastChain(
        FastDense(1,3, x->tanh.(x)),
        FastDense(3,10, x->tanh.(x)),
        FastDense(10,3, x->tanh.(x)),
        FastDense(3,1, sigmoid_A)
    )

iceflow_trained = train_iceflow_UDE(H₀, UA, H_ref, temps)

In my manual Zygote implementation I circumvented this by copying the matrices at each step and by using intermediate buffers in the operations that updated a given matrix. That is, of course, very inefficient. Do I still need to do the same with DiffEqFlux.jl, or is there a way to pass everything by reference using @views, as in the MWE above? Thanks again for your help!
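For the Zygote route, one option is an out-of-place right-hand side that builds every intermediate array instead of writing into preallocated buffers. Below is a sketch of the MWE’s SIA! rewritten that way; the function name, the zero padding of the outer boundary, and passing A, B, Δx, Δy as plain arguments are assumptions. It allocates on every call, so it gives up the memory savings of the in-place version in exchange for having no in-place writes. It has not been run through Zygote here; the checks below only confirm that it behaves sensibly (zero tendency on a flat surface, zero boundary, and approximate mass conservation when the ice does not reach the edges).

```julia
const ρ = 900.0   # ice density [kg / m^3]
const g = 9.81    # gravitational acceleration [m / s^2]
const n = 3       # Glen's flow law exponent

# Non-mutating staggered-grid helpers (slicing makes copies; nothing is written in place)
avg(A)    = 0.25 .* (A[1:end-1, 1:end-1] .+ A[2:end, 1:end-1] .+ A[1:end-1, 2:end] .+ A[2:end, 2:end])
avg_x(A)  = 0.5 .* (A[1:end-1, :] .+ A[2:end, :])
avg_y(A)  = 0.5 .* (A[:, 1:end-1] .+ A[:, 2:end])
diff_x(A) = A[2:end, :] .- A[1:end-1, :]
diff_y(A) = A[:, 2:end] .- A[:, 1:end-1]

# Out-of-place SIA: returns dH/dt, with zero tendency on the outer boundary
function SIA(H, A, B, Δx, Δy)
    S = B .+ H
    dSdx = diff_x(S) ./ Δx
    dSdy = diff_y(S) ./ Δy
    slope_term = (avg_y(dSdx) .^ 2 .+ avg_x(dSdy) .^ 2) .^ ((n - 1) / 2)
    Γ = 2 * A * (ρ * g)^n / (n + 2)
    D = Γ .* avg(H) .^ (n + 2) .* slope_term
    Fx = .-avg_y(D) .* diff_x(S[:, 2:end-1]) ./ Δx
    Fy = .-avg_x(D) .* diff_y(S[2:end-1, :]) ./ Δy
    inner = .-(diff_x(Fx) ./ Δx .+ diff_y(Fy) ./ Δy)
    # Pad with zeros by concatenation instead of writing into inn(dH)
    zcol = zeros(eltype(inner), size(inner, 1), 1)
    zrow = zeros(eltype(inner), 1, size(inner, 2) + 2)
    return vcat(zrow, hcat(zcol, inner, zcol), zrow)
end

nx = ny = 30
B = zeros(nx, ny)
H = [250 * exp(-((i - nx / 2)^2 + (j - ny / 2)^2) / 20) for i in 1:nx, j in 1:ny]
A = 1.3e-24 * 60 * 60 * 24 * 365.25   # [1 / Pa^3 yr], as in the MWE
dH = SIA(H, A, B, 50.0, 50.0)
```

Inside the UDE, A would come from the neural network for the current year, and the ODEProblem would be built from an out-of-place function f(H, θ, t) that returns this array.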

What does EnzymeVJP() spit out? Zygote won’t work here, but in theory Enzyme can.

With sensealg = BacksolveAdjoint(autojacvec=EnzymeVJP()) it gives a very lengthy low-level error. The only meaningful parts I could see were at the very end:

could not deduce type of integer   %.repack31 = getelementptr inbounds { {} addrspace(10)*, [1 x i64], i64, i64 }, { {} addrspace(10)*, [1 x i64], i64, i64 } addrspace(10)* %52, i64 0, i32 2, !dbg !104 num:8 q:{[]:Pointer} 
warning: /Users/Bolib001/Desktop/Jordi/Julia/odinn_toy_model/scripts/examples/MWE_iceflow.jl:128:0: failed to deduce type of value   %.repack31 = getelementptr inbounds { {} addrspace(10)*, [1 x i64], i64, i64 }, { {} addrspace(10)*, [1 x i64], i64, i64 } addrspace(10)* %52, i64 0, i32 2, !dbg !104
Assertion failed: (0 && "could not deduce type of integer"), function firstPointer, file /workspace/srcdir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp, line 4724.

signal (6): Abort trap: 6
in expression starting at /Users/Bolib001/Desktop/Jordi/Julia/odinn_toy_model/scripts/examples/MWE_iceflow.jl:212
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 356126791 (Pool: 356060610; Big: 66181); GC: 170
zsh: abort      julia

Looks like that’s from a type instability. If you check the type stability of your f function, do you get inference issues?

Hmm, apparently yes:

return type SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, UnitRange{Int64}}, false} does not match inferred return type Any

I don’t quite know how to interpret that, since the iceflow! and SIA! functions don’t have a return statement; everything is passed by reference with @views.

What if you return nothing at the end of them?

I get a very similar error:

could not deduce type of integer   %.repack31 = getelementptr inbounds { {} addrspace(10)*, [1 x i64], i64, i64 }, { {} addrspace(10)*, [1 x i64], i64, i64 } addrspace(10)* %52, i64 0, i32 2, !dbg !105 num:8 q:{[]:Pointer} 
warning: /Users/Bolib001/Desktop/Jordi/Julia/odinn_toy_model/scripts/examples/MWE_iceflow.jl:129:0: failed to deduce type of value   %.repack31 = getelementptr inbounds { {} addrspace(10)*, [1 x i64], i64, i64 }, { {} addrspace(10)*, [1 x i64], i64, i64 } addrspace(10)* %52, i64 0, i32 2, !dbg !105
Assertion failed: (0 && "could not deduce type of integer"), function firstPointer, file /workspace/srcdir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp, line 4724.

signal (6): Abort trap: 6
in expression starting at /Users/Bolib001/Desktop/Jordi/Julia/odinn_toy_model/scripts/examples/MWE_iceflow.jl:217
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 381551632 (Pool: 381487983; Big: 63649); GC: 176

Do you still have internal instabilities?

No, @inferred doesn’t report anything now.
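Two things may help in reading these results. Without an explicit return, a Julia function returns the value of its last expression; in SIA! that is `inn(dH) .= ...`, which returns the SubArray seen in the @inferred message. Also, @inferred only checks the return type: a function ending in `return nothing` is always inferred to return Nothing, so it passes @inferred even if its body is full of dynamic dispatch, for example from pulling arrays out of a Vector{Any} like the MWE’s p/context. A minimal sketch of the difference, with a hypothetical typed container standing in for context (@code_warntype, or JET.jl as suggested earlier in the thread, shows the instabilities inside the body):

```julia
using Test

# Parameters stored like the MWE's `context`: a Vector{Any}
p_any = Any[1.0, zeros(4, 4), zeros(4, 4)]

function rhs_any!(dH, H, p)
    B = p[2]          # inferred only as Any, so every use below is dynamically dispatched
    S = p[3]
    S .= B .+ H
    dH .= S
    return nothing
end

# Hypothetical typed alternative: field types are known to the compiler
struct SIAParams{T<:AbstractFloat}
    A::Base.RefValue{T}   # mutable scalar slot, e.g. for the predicted A
    B::Matrix{T}
    S::Matrix{T}
end

function rhs_typed!(dH, H, p::SIAParams)
    p.S .= p.B .+ H
    dH .= p.S
    return nothing
end

H, dH = rand(4, 4), zeros(4, 4)
p_typed = SIAParams(Ref(1.0), zeros(4, 4), zeros(4, 4))

# Both calls pass @inferred, because both functions return `nothing`...
@inferred rhs_any!(dH, H, p_any)
@inferred rhs_typed!(dH, H, p_typed)

# ...but only the typed container lets the compiler infer the arrays used inside the body
types_any = Base.return_types(p -> p[2], (Vector{Any},))
types_typed = Base.return_types(p -> p.B, (SIAParams{Float64},))
```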