Optimizing performance of 2D nonlinear diffusion UDE

OK, so as discussed, I’m going back to a Zygote implementation of this, with the goal of making it as optimized as possible while waiting for Enzyme to work.

Unlike my manual previous implementation of this using pullback which used to work, when I use BacksolveAdjoint(autojacvec=ZygoteVJP()) I’m getting an error that I cannot debug. In order to avoid mutation issues, I’m passing a tuple of model parameters and I’m returning a copy of dH.

Here’s a new MWE with the updated implementation:

using Statistics
using LinearAlgebra
using Random 
using OrdinaryDiffEq
using DiffEqFlux
using Flux
using Tullio
using RecursiveArrayTools
using Infiltrator

const t₁ = 10                 # number of simulation years 
const ρ = 900f0                     # Ice density [kg / m^3]
const g = 9.81f0                    # Gravitational acceleration [m / s^2]
const n = 3f0                       # Glen's flow law exponent
const maxA = 8f-16
const minA = 3f-17
const maxT = 1f0
const minT = -25f0
A = 1.3f-24 #2e-16  1 / Pa^3 s
A *= Float32(60 * 60 * 24 * 365.25) # [1 / Pa^3 yr]
C = 0
α = 0

@views avg(A) = 0.25 .* ( A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end] )
@views avg_x(A) = 0.5 .* ( A[1:end-1,:] .+ A[2:end,:] )
@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )
@views diff_x(A) = (A[begin + 1:end, :] .- A[1:end - 1, :])
@views diff_y(A) = (A[:, begin + 1:end] .- A[:, 1:end - 1])
@views inn(A) = A[2:end-1,2:end-1]

function ref_glacier(temps, H₀)
      
    H = deepcopy(H₀)
    
    # Initialize all matrices for the solver
    S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
    dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
    D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
    V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)
    
    # Gather simulation parameters
    current_year = Int(0)
    context = ArrayPartition([A], B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, [current_year])

    # Perform reference simulation with forward model 
    println("Running forward PDE ice flow model...\n")
    iceflow_prob = ODEProblem(iceflow!,H,(0.0,t₁),context)
    iceflow_sol = solve(iceflow_prob, BS3(), progress=true, saveat=1.0, progress_steps = 1)

    return Float32.(iceflow_sol[end])
end

function iceflow!(dH, H, context,t)
    # Unpack parameters
    #A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year 
    current_year = Ref(context.x[18])
    A = Ref(context.x[1])
    
    # Get current year for MB and ELA
    year = floor(Int, t) + 1
    if year != current_year[] && year <= t₁
        temp = Ref{Float32}(context.x[7][year])
        A[] .= A_fake(temp[])
        current_year[] .= year
    end

    # Compute the Shallow Ice Approximation in a staggered grid
    SIA!(dH, H, context)
end      


function train_iceflow_UDE(H₀, UA, H_ref, temps)
    
    # Gather simulation parameters
    H = deepcopy(H₀)
    
    # Initialize all matrices for the solver
    S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
    dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
    D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
    V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)
    
    # Gather simulation parameters
    current_year = 0
    θ = initial_params(UA)
    context = (A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H)
    loss(θ) = loss_iceflow(UA, θ, H, context) # closure

    println("Training iceflow UDE...")

    iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(0.01), maxiters = 10)

    return iceflow_trained
end

function loss_iceflow(UA, θ, H, context)
    
    H = predict_iceflow(UA, θ, H, context)
    
    H_ref = context[19]
    l_H = sqrt(Flux.Losses.mse(H[H .!= 0.0], H_ref[H.!= 0.0]; agg=sum))
    println("Loss = ", l_H)

    return l_H
end

function predict_iceflow(UA, θ, H, context)
        
    iceflow_UDE!(dH, H, θ, t) = iceflow_NN!(dH, H, context, UA, θ, t) # closure
    tspan = (0.0,t₁)
    iceflow_prob = ODEProblem(iceflow_UDE!,H,tspan,θ)
    H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, save_everystep=false, 
                   sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP()), 
                   progress=true, progress_steps = 1)

    return H_pred[end]
end

function iceflow_NN!(dH, H, context, UA, θ, t)
    
    # Unpack parameters
    #A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
    current_year = context[18]
    A = context[1]
    
    # Get current year for MB and ELA
    year = floor(Int, t) + 1
    if year != current_year && year <= t₁
        temp = context[7][year]
        YA = predict_A̅(UA, θ, [temp]) # FastChain prediction requires explicit parameters

        # Unpack and repack tuple to update `A` and `current_year`
        A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H = context
        context = (YA, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, year, H_ref, H)

    end

    # Compute the Shallow Ice Approximation in a staggered grid
    dH = SIA!(dH, H, context)

end  

"""
    SIA(H, p)

Compute a step of the Shallow Ice Approximation PDE in a forward model
"""

function SIA!(dH, H, context)
    
    # Retrieve parameters
    #A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
    A = context.x[1]
    B = context.x[2]
    S = context.x[3]
    dSdx = context.x[4]
    dSdy = context.x[5]
    D = context.x[6]
    dSdx_edges = context.x[8]
    dSdy_edges = context.x[9]
    ∇S = context.x[10]
    Fx = context.x[11]
    Fy = context.x[12]
    
    # Update glacier surface altimetry
    S .= B .+ H

    # All grid variables computed in a staggered grid
    # Compute surface gradients on edges
    dSdx .= diff_x(S) / Δx
    dSdy .= diff_y(S) / Δy
    ∇S .= (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2) 

    Γ = 2 * A * (ρ * g)^n / (n+2) # 1 / m^3 s 
    D .= Γ .* avg(H).^(n + 2) .* ∇S

    # Compute flux components
    dSdx_edges .= diff(S[:,2:end - 1], dims=1) / Δx
    dSdy_edges .= diff(S[2:end - 1,:], dims=2) / Δy
    Fx .= .-avg_y(D) .* dSdx_edges
    Fy .= .-avg_x(D) .* dSdy_edges 

    #  Flux divergence
    inn(dH) .= .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy) # MB to be added here 
    
    # Compute velocities    
    #Vx = -D./(avg(H) .+ ϵ).*avg_y(dSdx)
    #Vy = -D./(avg(H) .+ ϵ).*avg_x(dSdy)
end

# Function without mutation for Zygote, with context as a tuple
function SIA!(dH, H, context::Tuple)
    
    # Retrieve parameters
    #A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
    A = context[1]
    B = context[2]
    S = context[3]
    dSdx = context[4]
    dSdy = context[5]
    D = context[6]
    dSdx_edges = context[8]
    dSdy_edges = context[9]
    ∇S = context[10]
    Fx = context[11]
    Fy = context[12]
    
    # Update glacier surface altimetry
    S = B .+ H

    # All grid variables computed in a staggered grid
    # Compute surface gradients on edges
    dSdx = diff_x(S) / Δx
    dSdy = diff_y(S) / Δy
    ∇S = (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2) 

    Γ = 2 * A * (ρ * g)^n / (n+2) # 1 / m^3 s 
    D = Γ .* avg(H).^(n + 2) .* ∇S

    # Compute flux components
    dSdx_edges = diff(S[:,2:end - 1], dims=1) / Δx
    dSdy_edges = diff(S[2:end - 1,:], dims=2) / Δy
    Fx = -avg_y(D) .* dSdx_edges
    Fy = -avg_x(D) .* dSdy_edges 

    #  Flux divergence
    @tullio dH[i,j] := -(diff(Fx, dims=1)[pad(i-1,1,1),pad(j-1,1,1)] / Δx + diff(Fy, dims=2)[pad(i-1,1,1),pad(j-1,1,1)] / Δy) # MB to be added here 

    return dH
end


function A_fake(temp)
    return @. minA + (maxA - minA) * ((temp-minT)/(maxT-minT) )^2
end

predict_A̅(UA, θ, temp) = UA(temp, θ)[1] .* 1e-16


#### Generate reference dataset ####
nx = ny = 100
const B = zeros(Float32, (nx, ny))
const σ = 1000
H₀ = Matrix{Float32}([ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / σ ) for i in 1:nx, j in 1:ny ])    
Δx = Δy = 50 #m   

const temps =  Vector{Float32}([0.0, -0.5, -0.2, -0.1, -0.3, -0.1, -0.2, -0.3, -0.4, -0.1])
H_ref = ref_glacier(temps, H₀)

# Train UDE
minA_out = 0.3
maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / ( 1 + exp(-x) )
UA = FastChain(
        FastDense(1,3, x->tanh.(x)),
        FastDense(3,10, x->tanh.(x)),
        FastDense(10,3, x->tanh.(x)),
        FastDense(3,1, sigmoid_A)
    )

iceflow_trained = train_iceflow_UDE(H₀, UA, H_ref, temps)

And this is the new error I’m having:

ERROR: LoadError: MethodError: no method matching vec(::Nothing)
Closest candidates are:
  vec(::FillArrays.Ones{T, N, Axes} where {N, Axes}) where T at /Users/Bolib001/.julia/packages/FillArrays/Vzxer/src/fillalgebra.jl:3
  vec(::Adjoint{var"#s832", var"#s8321"} where {var"#s832"<:Real, var"#s8321"<:(AbstractVector{T} where T)}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:240
  vec(::Tuple{Vararg{MultivariatePolynomials.AbstractVariable, N} where N}) at /Users/Bolib001/.julia/packages/MultivariatePolynomials/vqcb5/src/operators.jl:351
  ...
Stacktrace:
  [1] _vecjacobian!(dλ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, y::Matrix{Float32}, λ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, p::Vector{Float32}, t::Float64, S::DiffEqSensitivity.ODEBacksolveSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Matrix{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, UniformScaling{Bool}}, BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Matrix{Float32}, ODEProblem{Matrix{Float32}, Tuple{Float64, Float64}, true, Vector{Float32}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, isautojacvec::ZygoteVJP, dgrad::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, dy::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, W::Nothing)
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/sLLcu/src/derivative_wrappers.jl:448
  [2] #vecjacobian!#37
    @ ~/.julia/packages/DiffEqSensitivity/sLLcu/src/derivative_wrappers.jl:224 [inlined]
  [3] (::DiffEqSensitivity.ODEBacksolveSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Matrix{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, UniformScaling{Bool}}, BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Matrix{Float32}, ODEProblem{Matrix{Float32}, Tuple{Float64, Float64}, true, Vector{Float32}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}})(du::Vector{Float32}, u::Vector{Float32}, p::Vector{Float32}, t::Float64)
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/sLLcu/src/backsolve_adjoint.jl:36
...
...
...

Any idea where the error comes from? I find it super hard to debug errors when using the DiffEqFlux.sciml_train wrapper function. Thanks again!