# Optimizing performance of 2D nonlinear diffusion UDE

But must you replace `A`, or can you just update `A` in place?

``````
julia> context = ([1 2; 3 4], Ref(0));

julia> context[2][] = 99;

julia> context[1] .= 7:8;

julia> context
([7 7; 8 8], Base.RefValue{Int64}(99))
``````

For the DiffEq adjoint it’s best to make the parameters a plain vector of numbers and then use `A = reshape(@view(p[n:m]), i, j)` etc. to build out all of the structures. An `ArrayPartition` is a good object for doing this kind of thing, BTW.
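
A minimal sketch of that layout, using only Base Julia. The sizes, the index ranges, and the names `A` and `B` below are illustrative assumptions, not the layout of the model in this thread:

```julia
# Hypothetical layout: a 2×3 matrix A followed by a length-4 vector B,
# stored back to back in one flat parameter vector p.
p = collect(1.0:10.0)

# Reshaping a view gives a matrix that shares memory with p (no copy).
A = reshape(@view(p[1:6]), 2, 3)
B = @view p[7:10]

# In-place updates through the named views write straight into p.
A .= 0.0
B[1] = -1.0
```

The physical equations can then keep readable names such as `A` and `B`, while the solver and the adjoint only ever see the flat vector `p`.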

This is a possibility, but using `@views` allows me to write all the physical equations with meaningful variable names instead of e.g. `context[1]`. I’d like a solution that is readable as well as efficient.

I tried using a tuple and avoiding `@views`, and I still get another error:

``````
....
....
....
!10087 = !DILocation(line: 287, scope: !2750, inlinedAt: !10088)
!10088 = !DILocation(line: 969, scope: !2750, inlinedAt: !10089)
!10089 = !DILocation(line: 148, scope: !2753, inlinedAt: !10090)
!10090 = !DILocation(line: 375, scope: !2755, inlinedAt: !10083)

top:
%_replacementA = phi {}***
%4 = load i8, i8* inttoptr (i64 4469935449 to i8*), align 1, !dbg !10043, !tbaa !145, !invariant.load !4
%5 = and i8 %4, 1, !dbg !10045
%.not.not = icmp eq i8 %5, 0, !dbg !10045
br i1 %.not.not, label %L4, label %L82, !dbg !10045

L4:                                               ; preds = %top
%6 = load i8, i8* inttoptr (i64 5579314521 to i8*), align 1, !dbg !10043, !tbaa !145, !invariant.load !4
%7 = and i8 %6, 1, !dbg !10045
%.not.not26 = icmp eq i8 %7, 0, !dbg !10045
br i1 %.not.not26, label %L7, label %L82, !dbg !10045

L7:                                               ; preds = %L4
%9 = call nonnull align 8 {}* @julia.pointer_from_objref({} addrspace(11)* %8) #4, !dbg !10047
%"'ip_phi" = phi {}* , !dbg !10047
%10 = bitcast {}* %9 to i64*, !dbg !10047
%11 = load i64, i64* %10, align 8, !dbg !10047, !tbaa !145, !range !233, !invariant.load !4
%"'il_phi" = phi i64 , !dbg !10050
%12 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 0, !dbg !10050
%"'il_phi3" = phi {} addrspace(10)* , !dbg !10052
%15 = call {}* @julia.pointer_from_objref({} addrspace(11)* %"'ipc"), !dbg !10052
%16 = call nonnull align 8 {}* @julia.pointer_from_objref({} addrspace(11)* %14) #4, !dbg !10052
%17 = bitcast {}* %16 to i64*, !dbg !10052
%18 = load i64, i64* %17, align 8, !dbg !10052, !tbaa !145, !range !233, !invariant.load !4
%.not = icmp eq i64 %11, %18, !dbg !10055
br i1 %.not, label %L20, label %L82, !dbg !10046

L20:                                              ; preds = %L7
%19 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 1, i32 0, i64 1, !dbg !10058
%20 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 1, i32 0, i64 0, !dbg !10063
%23 = call { i64, i1 } @llvm.ssub.with.overflow.i64(i64 %21, i64 %22), !dbg !10065
%24 = extractvalue { i64, i1 } %23, 0, !dbg !10065
%25 = extractvalue { i64, i1 } %23, 1, !dbg !10065
br i1 %25, label %L29, label %L32, !dbg !10067

L29:                                              ; preds = %L20
%26 = call fastcc nonnull {} addrspace(10)* @julia_throw_overflowerr_binaryop_11556() #12, !dbg !10067
unreachable, !dbg !10067

L32:                                              ; preds = %L20
%27 = call { i64, i1 } @llvm.sadd.with.overflow.i64(i64 %24, i64 1), !dbg !10068
%28 = extractvalue { i64, i1 } %27, 1, !dbg !10068
br i1 %28, label %L36, label %L64, !dbg !10070

L36:                                              ; preds = %L32
%29 = call fastcc nonnull {} addrspace(10)* @julia_throw_overflowerr_binaryop_11556() #12, !dbg !10070
unreachable, !dbg !10070

L64:                                              ; preds = %L32
%30 = extractvalue { i64, i1 } %27, 0, !dbg !10068
%31 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 1, i32 1, i64 0, i64 0, !dbg !10071
%33 = call nonnull {} addrspace(10)* @jl_alloc_array_2d({} addrspace(10)* addrspacecast ({}* inttoptr (i64 5464987568 to {}*) to {} addrspace(10)*), i64 %30, i64 %32) [ "jl_roots"(i64 addrspace(11)* %31) ], !dbg !10075
%"'mi" = phi {} addrspace(10)* , !dbg !10078
%34 = call fastcc nonnull {} addrspace(10)* @julia_copyto__11605({} addrspace(10)* nonnull align 16 dereferenceable(40) %33, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* nocapture nonnull readonly align 8 dereferenceable(48) %3), !dbg !10078
%"'ip_phi7" = phi {} addrspace(10)* , !dbg !10079
%35 = icmp sgt i64 %30, 0, !dbg !10079
%36 = select i1 %35, i64 %30, i64 0, !dbg !10086
%37 = icmp sgt i64 %32, 0, !dbg !10079
%38 = select i1 %37, i64 %32, i64 0, !dbg !10079
%39 = getelementptr inbounds [1 x {} addrspace(10)*], [1 x {} addrspace(10)*]* %1, i64 0, i64 0, !dbg !10046
%.repack = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 0, !dbg !10046
%.repack27.repack.repack = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 1, i32 0, i64 0, !dbg !10046
store i64 1, i64* %.repack27.repack.repack, align 8, !dbg !10046
%.repack27.repack.repack35 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 1, i32 0, i64 1, !dbg !10046
store i64 %36, i64* %.repack27.repack.repack35, align 8, !dbg !10046
%40 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 1, i32 1, i64 0, i64 0, !dbg !10046
store i64 %38, i64* %40, align 8, !dbg !10046
%.repack29 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 2, !dbg !10046
%41 = bitcast i64* %.repack29 to i8*, !dbg !10046
call void @llvm.memset.p0i8.i64(i8* nocapture nonnull writeonly align 8 dereferenceable(16) %41, i8 0, i64 16, i1 false), !dbg !10046
ret { i8* } undef, !dbg !10046

L82:                                              ; preds = %L7, %L4, %top
%42 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 0, !dbg !10046
%"'il_phi8" = phi {} addrspace(10)* , !dbg !10046
%44 = getelementptr inbounds [1 x {} addrspace(10)*], [1 x {} addrspace(10)*]* %1, i64 0, i64 0, !dbg !10046
%45 = bitcast { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0 to i8*, !dbg !10046
%46 = bitcast { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3 to i8 addrspace(11)*, !dbg !10046
call void @llvm.memcpy.p0i8.p11i8.i64(i8* nocapture nonnull writeonly align 8 dereferenceable(48) %45, i8 addrspace(11)* nonnull align 8 dereferenceable(48) %46, i64 48, i1 false), !dbg !10046
ret { i8* } undef, !dbg !10046

allocsForInversion:                               ; No predecessors!
}

%"'il_phi3" = phi {} addrspace(10)* , !dbg !164
Assertion failed: (I->use_empty()), function erase, file /workspace/srcdir/Enzyme/enzyme/Enzyme/CacheUtility.cpp, line 72.

signal (6): Abort trap: 6
in expression starting at /Users/Bolib001/Desktop/Jordi/Julia/odinn_toy_model/scripts/examples/MWE_iceflow.jl:215
Allocations: 375843513 (Pool: 375774007; Big: 69506); GC: 174
``````

@ChrisRackauckas: thanks, I’ll explore those data structures once I have something working end-to-end with AD for this case.

So I’ve followed your advice and I have implemented this using `ArrayPartition`s. Enzyme.jl is still crashing, but with a different error.

Here is my new MWE with all the recent updates from this discussion:

``````
using Statistics
using LinearAlgebra
using Random
using OrdinaryDiffEq
using DiffEqFlux
using Flux
using RecursiveArrayTools
using Infiltrator

const t₁ = 10                 # number of simulation years
const ρ = 900                     # Ice density [kg / m^3]
const g = 9.81                    # Gravitational acceleration [m / s^2]
const n = 3                       # Glen's flow law exponent
const maxA = 8e-16
const minA = 3e-17
const maxT = 1
const minT = -25
A = 1.3f-24 #2e-16  1 / Pa^3 s
A *= 60 * 60 * 24 * 365.25 # [1 / Pa^3 yr]
C = 0
α = 0

@views avg(A) = 0.25 .* ( A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end] )
@views avg_x(A) = 0.5 .* ( A[1:end-1,:] .+ A[2:end,:] )
@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )
@views diff_x(A) = (A[begin + 1:end, :] .- A[1:end - 1, :])
@views diff_y(A) = (A[:, begin + 1:end] .- A[:, 1:end - 1])
@views inn(A) = A[2:end-1,2:end-1]

function ref_glacier(temps, H₀)

H = deepcopy(H₀)

# Initialize all matrices for the solver
S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)

# Gather simulation parameters
current_year = Int(0)
context = ArrayPartition([A], B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, [current_year])

# Perform reference simulation with forward model
println("Running forward PDE ice flow model...\n")
iceflow_prob = ODEProblem(iceflow!,H,(0.0,t₁),context)
iceflow_sol = solve(iceflow_prob, BS3(), progress=true, saveat=1.0, progress_steps = 1)

return Float32.(iceflow_sol[end])
end

function iceflow!(dH, H, context,t)
# Unpack parameters
#A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year
current_year = context.x[18]
A = context.x[1]

# Get current year for MB and ELA
year = floor(Int, t) + 1
if year != current_year && year <= t₁
temp = Ref{Float32}(context.x[7][year])
context.x[1] .= A_fake(temp[])
current_year[] = year
end

# Compute the Shallow Ice Approximation in a staggered grid
SIA!(dH, H, context)
end

function train_iceflow_UDE(H₀, UA, H_ref, temps)

# Gather simulation parameters
H = deepcopy(H₀)

# Initialize all matrices for the solver
S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)

# Gather simulation parameters
current_year = 0
θ = initial_params(UA)
context = ArrayPartition([A], B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, [current_year], H_ref, H, UA, θ)
loss(θ) = loss_iceflow(UA, θ, H, context) # closure

println("Training iceflow UDE...")

iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(0.01), maxiters = 5)

return iceflow_trained
end

function loss_iceflow(UA, θ, H, context)

H = predict_iceflow(UA, θ, H, context)

H_ref = context.x[19]
l_H = sqrt(Flux.Losses.mse(H[H .!= 0.0], H_ref[H.!= 0.0]; agg=sum))
println("Loss = ", l_H)

return l_H
end

function predict_iceflow(UA, θ, H, context)

iceflow_UDE!(dH, H, θ, t) = iceflow_NN!(dH, H, θ, t, context) # closure
tspan = (0.0,t₁)
iceflow_prob = ODEProblem(iceflow_UDE!,H,tspan,θ)
H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, save_everystep=false,
progress=true, progress_steps = 1)

return H_pred[end]
end

function iceflow_NN!(dH, H, θ, t, context)

# Unpack parameters
#A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
current_year = Ref(context.x[18])
A = Ref(context.x[1])
UA = context.x[21]

# Get current year for MB and ELA
year = floor(Int, t) + 1
if year != current_year && year <= t₁
temp = context.x[7][year]
A[] .= predict_A̅(UA, θ, [temp]) # FastChain prediction requires explicit parameters
current_year[] .= year
end

# Compute the Shallow Ice Approximation in a staggered grid
SIA!(dH, H, context)

return nothing

end

"""
SIA(H, p)

Compute a step of the Shallow Ice Approximation PDE in a forward model
"""

function SIA!(dH, H, context)

# Retrieve parameters
#A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
A = context.x[1]
B = context.x[2]
S = context.x[3]
dSdx = context.x[4]
dSdy = context.x[5]
D = context.x[6]
dSdx_edges = context.x[8]
dSdy_edges = context.x[9]
∇S = context.x[10]
Fx = context.x[11]
Fy = context.x[12]

# Update glacier surface altimetry
S .= B .+ H

# All grid variables computed in a staggered grid
# Compute surface gradients on edges
dSdx .= diff_x(S) / Δx
dSdy .= diff_y(S) / Δy
∇S .= (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2)

Γ = 2 * A[] * (ρ * g)^n / (n+2) # 1 / m^3 s
D .= Γ .* avg(H).^(n + 2) .* ∇S

# Compute flux components
dSdx_edges .= diff(S[:,2:end - 1], dims=1) / Δx
dSdy_edges .= diff(S[2:end - 1,:], dims=2) / Δy
Fx .= .-avg_y(D) .* dSdx_edges
Fy .= .-avg_x(D) .* dSdy_edges

#  Flux divergence
inn(dH) .= .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy) # MB to be added here

return nothing

end

function A_fake(temp)
return @. minA + (maxA - minA) * ((temp-minT)/(maxT-minT) )^2
end

predict_A̅(UA, θ, temp) = UA(temp, θ)[1] .* 1e-16

#### Generate reference dataset ####
nx = ny = 100
B = zeros(Float32, (nx, ny))
σ = 1000
H₀ = Matrix{Float32}([ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / σ ) for i in 1:nx, j in 1:ny ])
Δx = Δy = 50 #m

temps =  Vector{Float32}([0.0, -0.5, -0.2, -0.1, -0.3, -0.1, -0.2, -0.3, -0.4, -0.1])
H_ref = ref_glacier(temps, H₀)

# Train UDE
minA_out = 0.3
maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / ( 1 + exp(-x) )
UA = FastChain(
FastDense(1,3, x->tanh.(x)),
FastDense(3,10, x->tanh.(x)),
FastDense(10,3, x->tanh.(x)),
FastDense(3,1, sigmoid_A)
)

iceflow_trained = train_iceflow_UDE(H₀, UA, H_ref, temps)
``````

And here’s the error I’m getting:

``````
└ @ Enzyme.Compiler ~/.julia/packages/Enzyme/g5epq/src/compiler.jl:375
┌ Warning: ("reverse differentiating jl_invoke call without split mode", iceflow_NN!, MethodInstance for iceflow_NN!(::Matrix{Float32}, ::Matrix{Float32}, ::Function, ::Vector{Float32}, ::Float64, ::ArrayPartition{Float64, Tuple{Vector{Float64}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Vector{Int64}, Matrix{Float32}, Matrix{Float32}}}))
└ @ Enzyme.Compiler ~/.julia/packages/Enzyme/g5epq/src/compiler.jl:404
need   %35 = call token (...) @llvm.julia.gc_preserve_begin({} addrspace(10)* nonnull %34), !dbg !215 via   call void @llvm.julia.gc_preserve_end(token %35), !dbg !245
Assertion failed: (!inst->getType()->isTokenTy()), function is_value_needed_in_reverse, file /workspace/srcdir/Enzyme/enzyme/Enzyme/DifferentialUseAnalysis.h, line 379.

signal (6): Abort trap: 6
in expression starting at /Users/Bolib001/Desktop/Jordi/Julia/odinn_toy_model/scripts/examples/MWE_iceflow.jl:214
Allocations: 382360361 (Pool: 382273315; Big: 87046); GC: 196
``````

Some progress, but still not there yet. Any idea on what’s wrong now regarding the “split mode”? Thanks!

I’m going to need to call in the Enzyme devs here, since this is something I would now expect to work. It seems like there’s a dynamic call somewhere that it cannot handle?

OK, so as discussed, I’m going back to a Zygote implementation of this, with the goal of making it as optimized as possible while waiting for Enzyme to work.

Unlike my previous manual implementation using `pullback`, which used to work, when I use `BacksolveAdjoint(autojacvec=ZygoteVJP())` I’m getting an error that I cannot debug. In order to avoid mutation issues, I’m passing a tuple of model parameters and I’m returning a copy of `dH`.
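
For readers unfamiliar with the distinction, here is a toy sketch of what "avoiding mutation" means for a right-hand side. It is a 1D diffusion operator in plain Julia, not the SIA model from this thread, and `diffusion_rhs`, `D`, and `Δx` are hypothetical names chosen for illustration:

```julia
# Toy 1D diffusion right-hand side written out of place: it never writes
# into an existing array, so every intermediate is a fresh allocation.
# `D` (diffusivity) and `Δx` (grid spacing) are illustrative parameters.
function diffusion_rhs(u::AbstractVector, D::Real, Δx::Real)
    flux = -D .* diff(u) ./ Δx          # fluxes on the cell edges
    interior = -diff(flux) ./ Δx        # flux divergence on interior nodes
    # Pad with zeros instead of assigning into du[2:end-1].
    return vcat(zero(eltype(u)), interior, zero(eltype(u)))
end

u = [0.0, 1.0, 0.0]
du = diffusion_rhs(u, 1.0, 1.0)
```

The trade-off is extra allocation on every call, which is the price paid for staying within what a non-mutating AD backend can handle.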

Here’s a new MWE with the updated implementation:

``````
using Statistics
using LinearAlgebra
using Random
using OrdinaryDiffEq
using DiffEqFlux
using Flux
using Tullio
using RecursiveArrayTools
using Infiltrator

const t₁ = 10                 # number of simulation years
const ρ = 900f0                     # Ice density [kg / m^3]
const g = 9.81f0                    # Gravitational acceleration [m / s^2]
const n = 3f0                       # Glen's flow law exponent
const maxA = 8f-16
const minA = 3f-17
const maxT = 1f0
const minT = -25f0
A = 1.3f-24 #2e-16  1 / Pa^3 s
A *= Float32(60 * 60 * 24 * 365.25) # [1 / Pa^3 yr]
C = 0
α = 0

@views avg(A) = 0.25 .* ( A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end] )
@views avg_x(A) = 0.5 .* ( A[1:end-1,:] .+ A[2:end,:] )
@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )
@views diff_x(A) = (A[begin + 1:end, :] .- A[1:end - 1, :])
@views diff_y(A) = (A[:, begin + 1:end] .- A[:, 1:end - 1])
@views inn(A) = A[2:end-1,2:end-1]

function ref_glacier(temps, H₀)

H = deepcopy(H₀)

# Initialize all matrices for the solver
S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)

# Gather simulation parameters
current_year = Int(0)
context = ArrayPartition([A], B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, [current_year])

# Perform reference simulation with forward model
println("Running forward PDE ice flow model...\n")
iceflow_prob = ODEProblem(iceflow!,H,(0.0,t₁),context)
iceflow_sol = solve(iceflow_prob, BS3(), progress=true, saveat=1.0, progress_steps = 1)

return Float32.(iceflow_sol[end])
end

function iceflow!(dH, H, context,t)
# Unpack parameters
#A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year
current_year = Ref(context.x[18])
A = Ref(context.x[1])

# Get current year for MB and ELA
year = floor(Int, t) + 1
if year != current_year[] && year <= t₁
temp = Ref{Float32}(context.x[7][year])
A[] .= A_fake(temp[])
current_year[] .= year
end

# Compute the Shallow Ice Approximation in a staggered grid
SIA!(dH, H, context)
end

function train_iceflow_UDE(H₀, UA, H_ref, temps)

# Gather simulation parameters
H = deepcopy(H₀)

# Initialize all matrices for the solver
S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)

# Gather simulation parameters
current_year = 0
θ = initial_params(UA)
context = (A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H)
loss(θ) = loss_iceflow(UA, θ, H, context) # closure

println("Training iceflow UDE...")

iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(0.01), maxiters = 10)

return iceflow_trained
end

function loss_iceflow(UA, θ, H, context)

H = predict_iceflow(UA, θ, H, context)

H_ref = context[19]
l_H = sqrt(Flux.Losses.mse(H[H .!= 0.0], H_ref[H.!= 0.0]; agg=sum))
println("Loss = ", l_H)

return l_H
end

function predict_iceflow(UA, θ, H, context)

iceflow_UDE!(dH, H, θ, t) = iceflow_NN!(dH, H, context, UA, θ, t) # closure
tspan = (0.0,t₁)
iceflow_prob = ODEProblem(iceflow_UDE!,H,tspan,θ)
H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, save_everystep=false,
progress=true, progress_steps = 1)

return H_pred[end]
end

function iceflow_NN!(dH, H, context, UA, θ, t)

# Unpack parameters
#A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
current_year = context[18]
A = context[1]

# Get current year for MB and ELA
year = floor(Int, t) + 1
if year != current_year && year <= t₁
temp = context[7][year]
YA = predict_A̅(UA, θ, [temp]) # FastChain prediction requires explicit parameters

# Unpack and repack tuple to update `A` and `current_year`
A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H = context
context = (YA, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, year, H_ref, H)

end

# Compute the Shallow Ice Approximation in a staggered grid
dH = SIA!(dH, H, context)

end

"""
SIA(H, p)

Compute a step of the Shallow Ice Approximation PDE in a forward model
"""

function SIA!(dH, H, context)

# Retrieve parameters
#A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
A = context.x[1]
B = context.x[2]
S = context.x[3]
dSdx = context.x[4]
dSdy = context.x[5]
D = context.x[6]
dSdx_edges = context.x[8]
dSdy_edges = context.x[9]
∇S = context.x[10]
Fx = context.x[11]
Fy = context.x[12]

# Update glacier surface altimetry
S .= B .+ H

# All grid variables computed in a staggered grid
# Compute surface gradients on edges
dSdx .= diff_x(S) / Δx
dSdy .= diff_y(S) / Δy
∇S .= (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2)

Γ = 2 * A * (ρ * g)^n / (n+2) # 1 / m^3 s
D .= Γ .* avg(H).^(n + 2) .* ∇S

# Compute flux components
dSdx_edges .= diff(S[:,2:end - 1], dims=1) / Δx
dSdy_edges .= diff(S[2:end - 1,:], dims=2) / Δy
Fx .= .-avg_y(D) .* dSdx_edges
Fy .= .-avg_x(D) .* dSdy_edges

#  Flux divergence
inn(dH) .= .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy) # MB to be added here

# Compute velocities
#Vx = -D./(avg(H) .+ ϵ).*avg_y(dSdx)
#Vy = -D./(avg(H) .+ ϵ).*avg_x(dSdy)
end

# Function without mutation for Zygote, with context as a tuple
function SIA!(dH, H, context::Tuple)

# Retrieve parameters
#A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
A = context[1]
B = context[2]
S = context[3]
dSdx = context[4]
dSdy = context[5]
D = context[6]
dSdx_edges = context[8]
dSdy_edges = context[9]
∇S = context[10]
Fx = context[11]
Fy = context[12]

# Update glacier surface altimetry
S = B .+ H

# All grid variables computed in a staggered grid
# Compute surface gradients on edges
dSdx = diff_x(S) / Δx
dSdy = diff_y(S) / Δy
∇S = (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2)

Γ = 2 * A * (ρ * g)^n / (n+2) # 1 / m^3 s
D = Γ .* avg(H).^(n + 2) .* ∇S

# Compute flux components
dSdx_edges = diff(S[:,2:end - 1], dims=1) / Δx
dSdy_edges = diff(S[2:end - 1,:], dims=2) / Δy
Fx = -avg_y(D) .* dSdx_edges
Fy = -avg_x(D) .* dSdy_edges

#  Flux divergence

return dH
end

function A_fake(temp)
return @. minA + (maxA - minA) * ((temp-minT)/(maxT-minT) )^2
end

predict_A̅(UA, θ, temp) = UA(temp, θ)[1] .* 1e-16

#### Generate reference dataset ####
nx = ny = 100
const B = zeros(Float32, (nx, ny))
const σ = 1000
H₀ = Matrix{Float32}([ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / σ ) for i in 1:nx, j in 1:ny ])
Δx = Δy = 50 #m

const temps =  Vector{Float32}([0.0, -0.5, -0.2, -0.1, -0.3, -0.1, -0.2, -0.3, -0.4, -0.1])
H_ref = ref_glacier(temps, H₀)

# Train UDE
minA_out = 0.3
maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / ( 1 + exp(-x) )
UA = FastChain(
FastDense(1,3, x->tanh.(x)),
FastDense(3,10, x->tanh.(x)),
FastDense(10,3, x->tanh.(x)),
FastDense(3,1, sigmoid_A)
)

iceflow_trained = train_iceflow_UDE(H₀, UA, H_ref, temps)
``````

And this is the new error I’m getting:

``````
ERROR: LoadError: MethodError: no method matching vec(::Nothing)
Closest candidates are:
vec(::FillArrays.Ones{T, N, Axes} where {N, Axes}) where T at /Users/Bolib001/.julia/packages/FillArrays/Vzxer/src/fillalgebra.jl:3
vec(::Tuple{Vararg{MultivariatePolynomials.AbstractVariable, N} where N}) at /Users/Bolib001/.julia/packages/MultivariatePolynomials/vqcb5/src/operators.jl:351
...
Stacktrace:
[1] _vecjacobian!(dλ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, y::Matrix{Float32}, λ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, p::Vector{Float32}, t::Float64, S::DiffEqSensitivity.ODEBacksolveSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Matrix{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, UniformScaling{Bool}}, BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Matrix{Float32}, ODEProblem{Matrix{Float32}, Tuple{Float64, Float64}, true, Vector{Float32}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, 
Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, isautojacvec::ZygoteVJP, dgrad::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, dy::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, W::Nothing)
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/sLLcu/src/derivative_wrappers.jl:448
[2] #vecjacobian!#37
@ ~/.julia/packages/DiffEqSensitivity/sLLcu/src/derivative_wrappers.jl:224 [inlined]
[3] (::DiffEqSensitivity.ODEBacksolveSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Matrix{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, UniformScaling{Bool}}, BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Matrix{Float32}, ODEProblem{Matrix{Float32}, Tuple{Float64, Float64}, true, Vector{Float32}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, 
Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}})(du::Vector{Float32}, u::Vector{Float32}, p::Vector{Float32}, t::Float64)
...
...
...
``````

Any idea where the error comes from? I find it super hard to debug errors when using the `DiffEqFlux.sciml_train` wrapper function. Thanks again!

The way I usually debug these things is to simplify the ODE by commenting things out until it works, and then progressively add back lines of code until the error is hit. My guess is that the pullback is returning `nothing` for one of the terms (i.e. no gradient), and that value is then being used in a gradient calculation. That shouldn’t happen and could be an issue in one of the adjoint definitions. If you help isolate it I can take a deeper look.

I have progressively commented everything inside the ODE and the error persists. Could the issue come from the way I’m using a closure to pass additional parameters to the PDE function? Here’s what I’m doing:

``````
function predict_iceflow(UA, θ, H, context)

    iceflow_UDE!(dH, H, θ, t) = iceflow_NN!(dH, H, context, UA, θ, t) # closure
    tspan = (0.0,t₁)
    iceflow_prob = ODEProblem(iceflow_UDE!,H,tspan,θ)
    H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, save_everystep=false,
                   progress=true, progress_steps = 1)

    return H_pred[end]
end
``````

This appears to work in the forward pass, but crashes during the pullback with the same error as above. Is that closure correct?

If `context` and `UA` are constants (not dependent on the optimized values) then it should be fine.

`context` is a tuple, containing the prediction made by `UA` with the optimized values (the tuple is updated by unpacking and repacking the values):

``````
year = floor(Int, t) + 1
if year != current_year && year <= t₁
    temp = context[7][year]
    YA = predict_A̅(UA, θ, [temp]) # FastChain prediction requires explicit parameters

    # Unpack and repack tuple to update `A` and `current_year`
    A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H = context
    context = (YA, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, year, H_ref, H)

end
``````

In fact I’m also using another closure for the loss function (for the same reason):

``````julia
θ = initial_params(UA)
context = (A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H)
loss(θ) = loss_iceflow(UA, θ, H, context) # closure

iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(0.01), maxiters = 10)
``````

Yeah, so if those are parameters of the differential equation then they have to come from `p`; otherwise they are treated as constants. This is the point of building the structures as `@view`s of a single parameter vector.
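A minimal sketch of the flat-vector idea, using only Base Julia. The sizes and the names `p`, `A`, and `w` are made up for illustration and are not taken from the glacier model:

``````julia
# Hypothetical layout: the first 6 entries hold a 2x3 matrix, the last 2 a vector.
p = collect(1.0:8.0)

A = reshape(view(p, 1:6), 2, 3)  # no copy: A aliases p[1:6]
w = view(p, 7:8)                 # no copy: w aliases p[7:8]

# Writing through the named views updates the flat vector...
A[2, 3] = 60.0
w .= 0.0

# ...and updating the flat vector is visible through the views.
p[1] = -1.0

println(p)
println(A[1, 1])
``````

The equations can then be written with meaningful names (`A`, `w`) while the optimizer and the adjoint only ever see the single vector `p`.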


I don’t know if I shared this with you before, but ComponentArrays.jl is a very good way to set this kind of thing up.

https://jonniedie.github.io/ComponentArrays.jl/dev/examples/DiffEqFlux/


Thanks for the tip, indeed, a `ComponentArray` makes things clearer than using a tuple.

I have managed to fix that problem and Zygote now seems to be working, but the solver returns `Warning: dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable` after computing the loss. I have tried several of the solvers you suggested; `BS3()` still seems to be the best, but it only works correctly for the forward run of the reference dataset.

Is it normal that the plain forward run that generates the reference dataset is fast and smooth, while the UDE behaves like this? The strange thing is that the NN weights don’t seem to get updated between iterations (maybe due to the abort?). Moreover, when I get that warning `dH` goes to NaN, but I’m pretty sure the PDE model is correct. I have also constrained the output values of the NN so that they stay physically stable, and I have set a low learning rate to avoid unstable results. What am I missing? I can provide an updated MWE if you want.

Is the UDE that you defined stable? It’s easy to accidentally make a U(u) that is positive over your whole solution interval, so that when you then solve `u' = f(u) + U(u)` the ODE is suddenly no longer stable with the random initial weights of the neural network. This is especially true for long time spans, which is why we do not suggest using direct single shooting for long time spans and stiff models. Did you follow the docs on the training strategies for these cases?

https://diffeqflux.sciml.ai/dev/examples/multiple_shooting/

https://diffeqflux.sciml.ai/dev/examples/local_minima/

The latter technique is usually required if you aren’t doing multiple shooting, since otherwise the first complete solve can be unstable with the initial weights, in which case the optimization cannot start.
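The stability point can be seen on a toy scalar problem. The sketch below is not the glacier model: `f`, `U_zero`, and `U_bad` are hypothetical stand-ins, and a hand-written forward Euler loop replaces the real solver so that it runs with Base Julia only:

``````julia
# Forward Euler on u' = rhs(u), a stand-in for the real ODE solver.
function euler(rhs, u0, tspan, dt)
    u = u0
    for _ in 1:round(Int, (tspan[2] - tspan[1]) / dt)
        u += dt * rhs(u)
    end
    return u
end

f(u) = -u            # the known physics: stable on its own
U_zero(u) = 0.0      # stand-in for a network that contributes nothing
U_bad(u) = 2.0 * u   # stand-in for an untrained network with a positive output

u_stable   = euler(u -> f(u) + U_zero(u), 1.0, (0.0, 10.0), 0.01)
u_unstable = euler(u -> f(u) + U_bad(u),  1.0, (0.0, 10.0), 0.01)

println("f alone:     ", u_stable)    # decays towards zero
println("f + U_bad:   ", u_unstable)  # grows by several orders of magnitude
``````

With `U_bad` the sign of the right-hand side flips and the solution grows exponentially over the same time span, which is the situation where a first single-shooting solve can fail before any training step happens.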

Thanks for the links, so much information in the DifferentialEquations.jl and DiffEqFlux.jl docs! I will read everything and try this.

However, I doubt the problem comes from unstable initial values. I have designed the NN to only produce predictions within a given range of physically stable values (a sigmoid function at the output layer with max and min values):

``````julia
minA_out = 0.3
maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / ( 1 + exp(-x) )
UA = FastChain(
    FastDense(1,3, x->tanh.(x)),
    FastDense(3,10, x->tanh.(x)),
    FastDense(10,3, x->tanh.(x)),
    FastDense(3,1, sigmoid_A)
)
``````

I have checked the values generated by the NN and they are always physically realistic, so they should not produce instabilities in the ODE. Moreover, the input timeseries I’m using are mostly the same value repeated with a tiny noise, so it’s a very easy initial problem (just to debug it, for now), quite different from the highly varying time series from the multiple shooting example you sent.
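The bounded output can be checked in isolation. A minimal sketch, reusing the `sigmoid_A` definition above and sweeping a made-up range of pre-activations `xs` (the network itself is left out):

``````julia
minA_out = 0.3
maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / (1 + exp(-x))

# Hypothetical sweep of pre-activation values fed to the output layer.
xs = -50.0:0.5:50.0
ys = sigmoid_A.(xs)

println(extrema(ys))  # smallest and largest output over the sweep
``````

This only confirms that the output stays inside `[minA_out, maxA_out]`. It says nothing about whether the solver step remains stable for every value in that range.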

I’d have to see a picture of how it’s going unstable. Though it could also be from using an explicit method.

I have also tried `VCABM()` and it appears to be working well (like `BS3()`), but I still get the same error during `sciml_train`. Wanna try the new MWE and have a guess?

``````julia
using Statistics
using LinearAlgebra
using Random
using OrdinaryDiffEq
using DiffEqFlux
using Flux
using Tullio
using RecursiveArrayTools
using ComponentArrays

const t₁ = 10                 # number of simulation years
const ρ = 900f0                     # Ice density [kg / m^3]
const g = 9.81f0                    # Gravitational acceleration [m / s^2]
const n = 3f0                       # Glen's flow law exponent
const maxA = 8f-16
const minA = 3f-17
const maxT = 1f0
const minT = -25f0
A = 1.3f-24 #2e-16  1 / Pa^3 s
A *= Float32(60 * 60 * 24 * 365.25) # [1 / Pa^3 yr]
C = 0
α = 0

@views avg(A) = 0.25 .* ( A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end] )
@views avg_x(A) = 0.5 .* ( A[1:end-1,:] .+ A[2:end,:] )
@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )
@views diff_x(A) = (A[begin + 1:end, :] .- A[1:end - 1, :])
@views diff_y(A) = (A[:, begin + 1:end] .- A[:, 1:end - 1])
@views inn(A) = A[2:end-1,2:end-1]

function ref_glacier(temps, H₀)

    H = deepcopy(H₀)

    # Initialize all matrices for the solver
    S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
    dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
    D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
    V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)

    # Gather simulation parameters
    current_year = Int(0)
    context = ArrayPartition([A], B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, [current_year])

    # Perform reference simulation with forward model
    println("Running forward PDE ice flow model...\n")
    iceflow_prob = ODEProblem(iceflow!,H,(0.0,t₁),context)
    iceflow_sol = solve(iceflow_prob, BS3(), reltol=1e-6, progress=true, saveat=1.0, progress_steps = 1)

    return Float32.(iceflow_sol[end])
end

function iceflow!(dH, H, context,t)
    # Unpack parameters
    #A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year
    current_year = Ref(context.x[18])
    A = Ref(context.x[1])

    # Get current year for MB and ELA
    year = floor(Int, t) + 1
    if year != current_year[] && year <= t₁
        temp = Ref{Float32}(context.x[7][year])
        A[] .= A_fake(temp[])
        current_year[] .= year
    end

    # Compute the Shallow Ice Approximation in a staggered grid
    SIA!(dH, H, context)
end

function train_iceflow_UDE(H₀, UA, H_ref, temps)

    # Copy the initial ice thickness
    H = deepcopy(H₀)

    # Gather simulation parameters
    current_year = 0
    θ = initial_params(UA)
    context = ComponentArray(B=B, C=C, α=α, temps=temps,current_year=current_year, H_ref=H_ref)
    loss(θ) = loss_iceflow(θ, UA, H, context) # closure

    println("Training iceflow UDE...")
    iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(0.0001), maxiters = 10)

    return iceflow_trained
end

function loss_iceflow(θ, UA, H, context)

    H = predict_iceflow(θ, UA, H, context)

    l_H = sqrt(Flux.Losses.mse(H[H .!= 0.0], context.H_ref[H.!= 0.0]; agg=sum))
    println("Loss = ", l_H)

    return l_H
end

function predict_iceflow(θ, UA, H, context)

    iceflow_UDE!(dH, H, θ, t) = iceflow_NN!(dH, H, θ, t, context, UA) # closure
    tspan = (0.0,t₁)
    iceflow_prob = ODEProblem(iceflow_UDE!,H,tspan,θ)
    H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, reltol=1e-6, save_everystep=false,
                   progress=true, progress_steps = 1)

    return H_pred[end]
end

function iceflow_NN!(dH, H, θ, t, context, UA)

    year = floor(Int, t) + 1
    if year <= t₁
        temp = context.temps[year]
    else
        temp = context.temps[year-1]
    end
    YA = predict_A̅(UA, θ, [temp]) # FastChain prediction requires explicit parameters

    if t%1 == 0
        println("A: ", YA)
    end

    # Compute the Shallow Ice Approximation in a staggered grid
    dH .= SIA!(dH, H, YA, context)

end

"""
    SIA!(dH, H, context)

Compute a step of the Shallow Ice Approximation PDE in a forward model
"""
function SIA!(dH, H, context)

    # Retrieve parameters
    #A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
    A = context.x[1]
    B = context.x[2]
    S = context.x[3]
    dSdx = context.x[4]
    dSdy = context.x[5]
    D = context.x[6]
    dSdx_edges = context.x[8]
    dSdy_edges = context.x[9]
    ∇S = context.x[10]
    Fx = context.x[11]
    Fy = context.x[12]

    # Update glacier surface altimetry
    S .= B .+ H

    # All grid variables computed in a staggered grid
    # Compute surface gradients on edges
    dSdx .= diff_x(S) / Δx
    dSdy .= diff_y(S) / Δy
    ∇S .= (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2)

    Γ = 2 * A * (ρ * g)^n / (n+2) # 1 / m^3 s
    D .= Γ .* avg(H).^(n + 2) .* ∇S

    # Compute flux components
    dSdx_edges .= diff(S[:,2:end - 1], dims=1) / Δx
    dSdy_edges .= diff(S[2:end - 1,:], dims=2) / Δy
    Fx .= .-avg_y(D) .* dSdx_edges
    Fy .= .-avg_x(D) .* dSdy_edges

    #  Flux divergence
    inn(dH) .= .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy) # MB to be added here

end

# Function without mutation for Zygote, with context as a ComponentArray
function SIA!(dH, H, A, context::ComponentArray)

    # Retrieve parameters
    B = context.B

    # Update glacier surface altimetry
    S = B .+ H

    # All grid variables computed in a staggered grid
    # Compute surface gradients on edges
    dSdx = diff_x(S) / Δx
    dSdy = diff_y(S) / Δy
    ∇S = (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2)

    Γ = 2 * A * (ρ * g)^n / (n+2) # 1 / m^3 s
    D = Γ .* avg(H).^(n + 2) .* ∇S

    # Compute flux components
    dSdx_edges = diff(S[:,2:end - 1], dims=1) / Δx
    dSdy_edges = diff(S[2:end - 1,:], dims=2) / Δy
    Fx = .-avg_y(D) .* dSdx_edges
    Fy = .-avg_x(D) .* dSdy_edges

    #  Flux divergence

    return dH
end

function A_fake(temp)
    return @. minA + (maxA - minA) * ((temp-minT)/(maxT-minT) )^2
end

predict_A̅(UA, θ, temp) = UA(temp, θ)[1] .* 1e-16

#### Generate reference dataset ####
nx = ny = 100
const B = zeros(Float32, (nx, ny))
const σ = 1000
H₀ = Matrix{Float32}([ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / σ ) for i in 1:nx, j in 1:ny ])
Δx = Δy = 50 #m

const temps =  Vector{Float32}([0.0, -0.5, -0.2, -0.1, -0.3, -0.1, -0.2, -0.3, -0.4, -0.1])
H_ref = ref_glacier(temps, H₀)

# Train UDE
minA_out = 0.3
maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / ( 1 + exp(-x) )
UA = FastChain(
    FastDense(1,3, x->tanh.(x)),
    FastDense(3,10, x->tanh.(x)),
    FastDense(10,3, x->tanh.(x)),
    FastDense(3,1, sigmoid_A)
)

iceflow_trained = train_iceflow_UDE(H₀, UA, H_ref, temps)

``````

Yeah, I’ll see when I can get to that. I’m in the middle of a big PR so I’m trying to debug from a distance.


I’ve continued investigating this and I have a bit more insight. The first forward pass of the UDE works fine: the initial values of the NN are stable and produce a meaningful loss. However, right after that, the ODE solver is called again (I can see the progress bar), but the new initial conditions (i.e. `H`) are exactly the same as at the end of the previous forward run. That is when the gradients go to infinity and everything crashes.

Since `sciml_train` is such a high-level black box, I’m having a hard time understanding what is going on exactly. Is that the pullback running the solver backwards? Is that why the initial conditions match the final conditions of the previous forward run? Not sure if this is of much help, just some extra context that might save you some debugging. Thanks again!

Yes, if you do `BacksolveAdjoint(autojacvec=ZygoteVJP())`. That method will give unstable gradients. As I describe in my talk on adjoints, you never want to use BacksolveAdjoint for any real equations. See the talk starting at:

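Why running the solver backwards is fragile can be sketched on a toy problem. The example below is an illustration only, not the adjoint code itself: it uses a hand-written Euler loop, a made-up decay rate `λ`, and a made-up perturbation of `1e-8` standing in for solver error at the end of the forward pass:

``````julia
# Explicit Euler from t0 to t1 (the step is negative when t1 < t0).
function euler(rhs, u0, t0, t1, nsteps)
    dt = (t1 - t0) / nsteps
    u = u0
    for _ in 1:nsteps
        u += dt * rhs(u)
    end
    return u
end

λ = 5.0
rhs(u) = -λ * u   # decays, so it is stable when integrated forward in time

u0 = 1.0
uT = euler(rhs, u0, 0.0, 5.0, 50_000)               # forward solve
u0_back = euler(rhs, uT + 1e-8, 5.0, 0.0, 50_000)   # reverse solve from a slightly perturbed end state

println("forward end state:         ", uT)
println("reconstructed start state: ", u0_back)     # far from u0 = 1.0
``````

The same equation that damps errors going forward amplifies them going backward, so a tiny error in the final state ends up far from the true initial state. A diffusion-type model is strongly damped forward in time, which is presumably why the reverse solve blows up here.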