Optimizing performance of 2D nonlinear diffusion UDE

Let us check this with JET.jl. First, we need a function to analyse:

"""
    test()

Run the whole pipeline (reference simulation followed by UDE training) inside a
function, so that it can be analysed with `JET.@report_opt test()`.
"""
function test()
    #### Generate reference dataset ####
    # NOTE: `B`, `Δx` and `Δy` used to be defined here but were never passed on; the
    # functions called below read globals of those names instead.
    nx = ny = 100
    σ = 1000
    # Gaussian-shaped initial ice thickness centred on the grid
    H₀ = [250 * exp(-((i - nx / 2)^2 + (j - ny / 2)^2) / σ) for i in 1:nx, j in 1:ny]

    temps = [0.0, -0.5, -0.2, -0.1, -0.3, -0.1, -0.2, -0.3, -0.4, -0.1]
    H_ref = ref_glacier(temps, H₀)

    # Train UDE; the last layer maps its output into (minA_out, maxA_out)
    minA_out = 0.3
    maxA_out = 8
    sigmoid_A(x) = minA_out + (maxA_out - minA_out) / (1 + exp(-x))
    UA = FastChain(
        FastDense(1, 3, x -> tanh.(x)),
        FastDense(3, 10, x -> tanh.(x)),
        FastDense(10, 3, x -> tanh.(x)),
        FastDense(3, 1, sigmoid_A),
    )

    return train_iceflow_UDE(H₀, UA, H_ref, temps)
end

Then we run `JET.@report_opt test()` and get something like this:

1 Like

Thanks for your reply, Michael. The reason I’m using an array instead of a tuple is because my model requires updating certain parameters during the forward run (e.g. A). What would be the best way to have multiple parameters passed by reference while being able to update them and avoiding type instabilities? I feel like I’ve been going round in circles fixing one issue which creates another.

But must you replace A, or can you just update A?

julia> context = ([1 2; 3 4], Ref(0));

julia> context[2][] = 99;

julia> context[1] .= 7:8;

julia> context
([7 7; 8 8], Base.RefValue{Int64}(99))
1 Like

For the DiffEq adjoint it’s best to make the parameters a plain vector of numbers and then build all of the structures as views into it, e.g. `A = reshape(@view(p[n:m]), i, j)`. An `ArrayPartition` is a good object for doing this kind of thing, by the way.

1 Like

This is a possibility, but using @views allows me to write all the physical equations with meaningful variable names instead of e.g. context[1]. I’d like a solution that is readable on top of being efficient.

I tried using a tuple and avoiding `@views`, but I still get an error:

....
....
....
!10087 = !DILocation(line: 287, scope: !2750, inlinedAt: !10088)
!10088 = !DILocation(line: 969, scope: !2750, inlinedAt: !10089)
!10089 = !DILocation(line: 148, scope: !2753, inlinedAt: !10090)
!10090 = !DILocation(line: 375, scope: !2755, inlinedAt: !10083)

define internal fastcc { i8* } @fakeaugmented_julia_unalias_11598({ {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* noalias nocapture nonnull sret writeonly align 8 dereferenceable(48) %0, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* nocapture %"'", [1 x {} addrspace(10)*]* noalias nocapture nonnull writeonly align 8 dereferenceable(8) %1, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2, {} addrspace(10)* %"'1", { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* nocapture nonnull readonly align 8 dereferenceable(48) %3, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* nocapture %"'2") unnamed_addr !dbg !10042 {
top:
  %_replacementA = phi {}*** 
  %4 = load i8, i8* inttoptr (i64 4469935449 to i8*), align 1, !dbg !10043, !tbaa !145, !invariant.load !4
  %5 = and i8 %4, 1, !dbg !10045
  %.not.not = icmp eq i8 %5, 0, !dbg !10045
  br i1 %.not.not, label %L4, label %L82, !dbg !10045

L4:                                               ; preds = %top
  %6 = load i8, i8* inttoptr (i64 5579314521 to i8*), align 1, !dbg !10043, !tbaa !145, !invariant.load !4
  %7 = and i8 %6, 1, !dbg !10045
  %.not.not26 = icmp eq i8 %7, 0, !dbg !10045
  br i1 %.not.not26, label %L7, label %L82, !dbg !10045

L7:                                               ; preds = %L4
  %8 = addrspacecast {} addrspace(10)* %2 to {} addrspace(11)*, !dbg !10047
  %9 = call nonnull align 8 {}* @julia.pointer_from_objref({} addrspace(11)* %8) #4, !dbg !10047
  %"'ip_phi" = phi {}* , !dbg !10047
  %10 = bitcast {}* %9 to i64*, !dbg !10047
  %11 = load i64, i64* %10, align 8, !dbg !10047, !tbaa !145, !range !233, !invariant.load !4
  %"'il_phi" = phi i64 , !dbg !10050
  %12 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 0, !dbg !10050
  %13 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %12 unordered, align 8, !dbg !10050, !tbaa !145, !invariant.load !4, !nonnull !4, !dereferenceable !379, !align !380
  %"'il_phi3" = phi {} addrspace(10)* , !dbg !10052
  %"'ipc" = addrspacecast {} addrspace(10)* %"'il_phi3" to {} addrspace(11)*, !dbg !10052
  %14 = addrspacecast {} addrspace(10)* %13 to {} addrspace(11)*, !dbg !10052
  %15 = call {}* @julia.pointer_from_objref({} addrspace(11)* %"'ipc"), !dbg !10052
  %16 = call nonnull align 8 {}* @julia.pointer_from_objref({} addrspace(11)* %14) #4, !dbg !10052
  %17 = bitcast {}* %16 to i64*, !dbg !10052
  %18 = load i64, i64* %17, align 8, !dbg !10052, !tbaa !145, !range !233, !invariant.load !4
  %.not = icmp eq i64 %11, %18, !dbg !10055
  br i1 %.not, label %L20, label %L82, !dbg !10046

L20:                                              ; preds = %L7
  %19 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 1, i32 0, i64 1, !dbg !10058
  %20 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 1, i32 0, i64 0, !dbg !10063
  %21 = load i64, i64 addrspace(11)* %19, align 8, !dbg !10065, !tbaa !145, !invariant.load !4
  %22 = load i64, i64 addrspace(11)* %20, align 8, !dbg !10065, !tbaa !145, !invariant.load !4
  %23 = call { i64, i1 } @llvm.ssub.with.overflow.i64(i64 %21, i64 %22), !dbg !10065
  %24 = extractvalue { i64, i1 } %23, 0, !dbg !10065
  %25 = extractvalue { i64, i1 } %23, 1, !dbg !10065
  br i1 %25, label %L29, label %L32, !dbg !10067

L29:                                              ; preds = %L20
  %26 = call fastcc nonnull {} addrspace(10)* @julia_throw_overflowerr_binaryop_11556() #12, !dbg !10067
  unreachable, !dbg !10067

L32:                                              ; preds = %L20
  %27 = call { i64, i1 } @llvm.sadd.with.overflow.i64(i64 %24, i64 1), !dbg !10068
  %28 = extractvalue { i64, i1 } %27, 1, !dbg !10068
  br i1 %28, label %L36, label %L64, !dbg !10070

L36:                                              ; preds = %L32
  %29 = call fastcc nonnull {} addrspace(10)* @julia_throw_overflowerr_binaryop_11556() #12, !dbg !10070
  unreachable, !dbg !10070

L64:                                              ; preds = %L32
  %30 = extractvalue { i64, i1 } %27, 0, !dbg !10068
  %31 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 1, i32 1, i64 0, i64 0, !dbg !10071
  %32 = load i64, i64 addrspace(11)* %31, align 8, !dbg !10075, !tbaa !145, !invariant.load !4
  %33 = call nonnull {} addrspace(10)* @jl_alloc_array_2d({} addrspace(10)* addrspacecast ({}* inttoptr (i64 5464987568 to {}*) to {} addrspace(10)*), i64 %30, i64 %32) [ "jl_roots"(i64 addrspace(11)* %31) ], !dbg !10075
  %"'mi" = phi {} addrspace(10)* , !dbg !10078
  %34 = call fastcc nonnull {} addrspace(10)* @julia_copyto__11605({} addrspace(10)* nonnull align 16 dereferenceable(40) %33, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* nocapture nonnull readonly align 8 dereferenceable(48) %3), !dbg !10078
  %"'ip_phi7" = phi {} addrspace(10)* , !dbg !10079
  %35 = icmp sgt i64 %30, 0, !dbg !10079
  %36 = select i1 %35, i64 %30, i64 0, !dbg !10086
  %37 = icmp sgt i64 %32, 0, !dbg !10079
  %38 = select i1 %37, i64 %32, i64 0, !dbg !10079
  %39 = getelementptr inbounds [1 x {} addrspace(10)*], [1 x {} addrspace(10)*]* %1, i64 0, i64 0, !dbg !10046
  store {} addrspace(10)* %33, {} addrspace(10)** %39, align 8, !dbg !10046
  %.repack = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 0, !dbg !10046
  store {} addrspace(10)* %33, {} addrspace(10)** %.repack, align 8, !dbg !10046
  %.repack27.repack.repack = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 1, i32 0, i64 0, !dbg !10046
  store i64 1, i64* %.repack27.repack.repack, align 8, !dbg !10046
  %.repack27.repack.repack35 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 1, i32 0, i64 1, !dbg !10046
  store i64 %36, i64* %.repack27.repack.repack35, align 8, !dbg !10046
  %40 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 1, i32 1, i64 0, i64 0, !dbg !10046
  store i64 %38, i64* %40, align 8, !dbg !10046
  %.repack29 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 2, !dbg !10046
  %41 = bitcast i64* %.repack29 to i8*, !dbg !10046
  call void @llvm.memset.p0i8.i64(i8* nocapture nonnull writeonly align 8 dereferenceable(16) %41, i8 0, i64 16, i1 false), !dbg !10046
  ret { i8* } undef, !dbg !10046

L82:                                              ; preds = %L7, %L4, %top
  %42 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 0, !dbg !10046
  %43 = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %42, align 8, !dbg !10046
  %"'il_phi8" = phi {} addrspace(10)* , !dbg !10046
  %44 = getelementptr inbounds [1 x {} addrspace(10)*], [1 x {} addrspace(10)*]* %1, i64 0, i64 0, !dbg !10046
  store {} addrspace(10)* %43, {} addrspace(10)** %44, align 8, !dbg !10046
  %45 = bitcast { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0 to i8*, !dbg !10046
  %46 = bitcast { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3 to i8 addrspace(11)*, !dbg !10046
  call void @llvm.memcpy.p0i8.p11i8.i64(i8* nocapture nonnull writeonly align 8 dereferenceable(48) %45, i8 addrspace(11)* nonnull align 8 dereferenceable(48) %46, i64 48, i1 false), !dbg !10046
  ret { i8* } undef, !dbg !10046

allocsForInversion:                               ; No predecessors!
}

  %"'il_phi3" = phi {} addrspace(10)* , !dbg !164
Assertion failed: (I->use_empty()), function erase, file /workspace/srcdir/Enzyme/enzyme/Enzyme/CacheUtility.cpp, line 72.

signal (6): Abort trap: 6
in expression starting at /Users/Bolib001/Desktop/Jordi/Julia/odinn_toy_model/scripts/examples/MWE_iceflow.jl:215
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 375843513 (Pool: 375774007; Big: 69506); GC: 174

@ChrisRackauckas: thanks, I’ll explore those data structures once I have something working end-to-end with AD for this case.

So I’ve followed your advice and I have implemented this using ArrayPartitions. Enzyme.jl is still crashing, but with a different error.

Here is my new MWE with all the recent updates from this discussion:

using Statistics
using LinearAlgebra
using Random 
using OrdinaryDiffEq
using DiffEqFlux
using Flux
using RecursiveArrayTools
using Infiltrator

const t₁ = 10                  # number of simulation years
const ρ = 900                  # Ice density [kg / m^3]
const g = 9.81                 # Gravitational acceleration [m / s^2]
const n = 3                    # Glen's flow law exponent
const maxA = 8e-16
const minA = 3e-17
const maxT = 1
const minT = -25
# Glen's creep parameter, 1.3e-24 [1 / Pa^3 s] (alternative value: 2e-16) converted to
# [1 / Pa^3 yr]. Built in one expression so it can be `const`: a non-const global read
# inside the solver functions is type-unstable (and `*=` had changed its type).
const A = 1.3f-24 * (60 * 60 * 24 * 365.25)
const C = 0
const α = 0

# Staggered-grid helpers. Averages divide by an integer (`/ 4`, `/ 2`) instead of
# multiplying by a Float64 literal (`0.25`, `0.5`), so Float32 input stays Float32
# instead of being silently promoted to Float64.

# Average of the four corners of each cell: size (nx-1, ny-1)
@views function avg(A)
    return (A[1:end-1, 1:end-1] .+ A[2:end, 1:end-1] .+
            A[1:end-1, 2:end] .+ A[2:end, 2:end]) ./ 4
end
# Average of neighbours along the first dimension: size (nx-1, ny)
@views avg_x(A) = (A[1:end-1, :] .+ A[2:end, :]) ./ 2
# Average of neighbours along the second dimension: size (nx, ny-1)
@views avg_y(A) = (A[:, 1:end-1] .+ A[:, 2:end]) ./ 2
# Forward difference along the first dimension: size (nx-1, ny)
@views diff_x(A) = (A[begin + 1:end, :] .- A[1:end - 1, :])
# Forward difference along the second dimension: size (nx, ny-1)
@views diff_y(A) = (A[:, begin + 1:end] .- A[:, 1:end - 1])
# Interior (non-border) view of a matrix: size (nx-2, ny-2)
@views inn(A) = A[2:end-1, 2:end-1]

"""
    ref_glacier(temps, H₀)

Run the forward SIA ice-flow model from the initial thickness `H₀` over `(0, t₁)` and
return the final thickness as a `Float32` matrix.

`temps` is indexed by simulated year inside `iceflow!`. The bedrock `B` and the
parameters `A`, `C`, `α` are read from globals; `B` must have the same size as `H₀`.
"""
function ref_glacier(temps, H₀)
    # A plain numeric matrix only needs a shallow copy.
    H = copy(H₀)
    # Grid size comes from the input rather than from the non-const globals nx, ny.
    nx, ny = size(H₀)

    # Preallocated work arrays for the staggered-grid solver
    S = zeros(Float32, nx, ny)
    dSdx = zeros(Float32, nx - 1, ny)
    dSdy = zeros(Float32, nx, ny - 1)
    dSdx_edges = zeros(Float32, nx - 1, ny - 2)
    dSdy_edges = zeros(Float32, nx - 2, ny - 1)
    ∇S = zeros(Float32, nx - 1, ny - 1)
    D = zeros(Float32, nx - 1, ny - 1)
    Fx = zeros(Float32, nx - 1, ny - 2)
    Fy = zeros(Float32, nx - 2, ny - 1)
    V = zeros(Float32, nx - 1, ny - 1)
    Vx = zeros(Float32, nx - 1, ny - 1)
    Vy = zeros(Float32, nx - 1, ny - 1)

    # Simulation state; A and current_year are wrapped in vectors so iceflow! can
    # update them in place.
    current_year = 0
    context = ArrayPartition([A], B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges,
                             ∇S, Fx, Fy, Vx, Vy, V, C, α, [current_year])

    # Perform reference simulation with forward model
    println("Running forward PDE ice flow model...\n")
    iceflow_prob = ODEProblem(iceflow!, H, (0.0, t₁), context)
    iceflow_sol = solve(iceflow_prob, BS3(), progress=true, saveat=1.0, progress_steps = 1)

    return Float32.(iceflow_sol[end])
end

"""
    iceflow!(dH, H, context, t)

In-place right-hand side of the forward (reference) ice-flow ODE.

When `t` falls in a year other than the one recorded in `context.x[18]` (and not past
`t₁`), the creep parameter in `context.x[1]` is recomputed from that year's temperature
with `A_fake` and the year is recorded. Then `SIA!` writes the flux divergence into `dH`.
"""
function iceflow!(dH, H, context, t)
    # context.x: A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S,
    #            Fx, Fy, Vx, Vy, V, C, α, current_year
    A = context.x[1]              # 1-element vector, updated in place
    current_year = context.x[18]  # 1-element vector: year A was last computed for
    temps = context.x[7]

    # Get current year for MB and ELA
    year = floor(Int, t) + 1
    # Compare with the stored value (`current_year[]`), not with the vector itself:
    # an Int is never equal to a Vector, which made A be recomputed on every call.
    if year != current_year[] && year <= t₁
        A .= A_fake(Float32(temps[year]))
        current_year[] = year
    end

    # Compute the Shallow Ice Approximation in a staggered grid
    SIA!(dH, H, context)
    return nothing
end

"""
    train_iceflow_UDE(H₀, UA, H_ref, temps)

Train the network `UA` (its weights initialised with `initial_params`) so that the
UDE started from `H₀` reproduces `H_ref`. Returns the result of
`DiffEqFlux.sciml_train` (RMSProp, 5 iterations).
"""
function train_iceflow_UDE(H₀, UA, H_ref, temps)
    # A plain numeric matrix only needs a shallow copy.
    H = copy(H₀)
    # Grid size comes from the input rather than from the non-const globals nx, ny.
    nx, ny = size(H₀)

    # Preallocated work arrays for the staggered-grid solver
    S = zeros(Float32, nx, ny)
    dSdx = zeros(Float32, nx - 1, ny)
    dSdy = zeros(Float32, nx, ny - 1)
    dSdx_edges = zeros(Float32, nx - 1, ny - 2)
    dSdy_edges = zeros(Float32, nx - 2, ny - 1)
    ∇S = zeros(Float32, nx - 1, ny - 1)
    D = zeros(Float32, nx - 1, ny - 1)
    Fx = zeros(Float32, nx - 1, ny - 2)
    Fy = zeros(Float32, nx - 2, ny - 1)
    V = zeros(Float32, nx - 1, ny - 1)
    Vx = zeros(Float32, nx - 1, ny - 1)
    Vy = zeros(Float32, nx - 1, ny - 1)

    # Gather simulation parameters
    current_year = 0
    θ = initial_params(UA)
    context = ArrayPartition([A], B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges,
                             ∇S, Fx, Fy, Vx, Vy, V, C, α, [current_year], H_ref, H, UA, θ)
    loss(θ) = loss_iceflow(UA, θ, H, context) # closure

    println("Training iceflow UDE...")

    return DiffEqFlux.sciml_train(loss, θ, RMSProp(0.01), maxiters = 5)
end

"""
    loss_iceflow(UA, θ, H, context)

Root of the summed squared error between the predicted final thickness and the
reference `context.x[19]`, restricted to cells where the prediction is non-zero.
Prints and returns the loss.
"""
function loss_iceflow(UA, θ, H, context)
    H_pred = predict_iceflow(UA, θ, H, context)

    H_ref = context.x[19]
    # Build the ice mask once and reuse it for both arrays.
    mask = H_pred .!= 0.0
    l_H = sqrt(Flux.Losses.mse(H_pred[mask], H_ref[mask]; agg=sum))
    println("Loss = ", l_H)

    return l_H
end

"""
    predict_iceflow(UA, θ, H, context)

Solve the UDE from `H` over `(0, t₁)` with network weights `θ` and return the final
state. Gradients use a backsolve adjoint with Enzyme vector-Jacobian products.
"""
function predict_iceflow(UA, θ, H, context)
    # Fix `context` in a closure so the solver sees the usual (du, u, p, t) signature,
    # with the weights as the parameters p.
    rhs!(dH, u, p, t) = iceflow_NN!(dH, u, p, t, context)
    problem = ODEProblem(rhs!, H, (0.0, t₁), θ)
    adjoint_method = BacksolveAdjoint(autojacvec=EnzymeVJP())
    solution = solve(problem, BS3(); u0=H, p=θ, save_everystep=false,
                     sensealg=adjoint_method, progress=true, progress_steps=1)
    return solution[end]
end

"""
    iceflow_NN!(dH, H, θ, t, context)

In-place right-hand side of the UDE. Predicts the creep parameter from the current
year's temperature with the network `context.x[21]` and weights `θ`, stores it in
`context.x[1]`, records the year in `context.x[18]`, and writes the SIA flux divergence
into `dH`.
"""
function iceflow_NN!(dH, H, θ, t, context)
    # context.x: A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy,
    #            Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
    A = context.x[1]              # 1-element vector, updated in place
    current_year = context.x[18]  # 1-element vector, bookkeeping only
    UA = context.x[21]
    temps = context.x[7]

    # Clamp so that t == t₁ still maps to the last entry of `temps` instead of
    # silently reusing whatever A was stored last.
    year = min(floor(Int, t) + 1, t₁)
    # A is recomputed from θ on every call (no per-year caching): the adjoint evaluates
    # this function for its VJP w.r.t. θ, and a cached A would not depend on θ.
    # (The former check compared an Int with a `Ref`, so it never skipped either.)
    A .= predict_A̅(UA, θ, [temps[year]]) # FastChain prediction requires explicit parameters
    current_year[] = year

    # Compute the Shallow Ice Approximation in a staggered grid
    SIA!(dH, H, context)

    return nothing
end

"""
    SIA!(dH, H, context)

Evaluate the Shallow Ice Approximation (SIA) right-hand side in place: write the flux
divergence for the ice thickness `H` into the interior of `dH` (its border is not
written). Returns `nothing`.

`context.x` holds, in order: `A` (1-element vector), bedrock `B`, and the work arrays
`S`, `dSdx`, `dSdy`, `D`, then the temperatures, then `dSdx_edges`, `dSdy_edges`, `∇S`,
`Fx`, `Fy`, …; the work arrays are overwritten.
"""
function SIA!(dH, H, context)
    # Retrieve parameters and work arrays
    A = context.x[1]
    B = context.x[2]
    S = context.x[3]
    dSdx = context.x[4]
    dSdy = context.x[5]
    D = context.x[6]
    dSdx_edges = context.x[8]
    dSdy_edges = context.x[9]
    ∇S = context.x[10]
    Fx = context.x[11]
    Fy = context.x[12]

    # Update glacier surface altimetry
    S .= B .+ H

    # All grid variables computed in a staggered grid
    # Compute surface gradients on edges
    dSdx .= diff_x(S) ./ Δx
    dSdy .= diff_y(S) ./ Δy
    ∇S .= (avg_y(dSdx) .^ 2 .+ avg_x(dSdy) .^ 2) .^ ((n - 1) / 2)

    Γ = 2 * A[] * (ρ * g)^n / (n + 2) # 1 / m^3 s
    D .= Γ .* avg(H) .^ (n + 2) .* ∇S

    # Compute flux components; views avoid copying slices of S on every call
    @views dSdx_edges .= diff(S[:, 2:end-1], dims=1) ./ Δx
    @views dSdy_edges .= diff(S[2:end-1, :], dims=2) ./ Δy
    Fx .= .-avg_y(D) .* dSdx_edges
    Fy .= .-avg_x(D) .* dSdy_edges

    # Flux divergence (MB to be added here)
    inn(dH) .= .-(diff(Fx, dims=1) ./ Δx .+ diff(Fy, dims=2) ./ Δy)

    return nothing
end

"""
    A_fake(temp)

Synthetic creep parameter used for the reference run: a quadratic ramp from `minA`
(at `temp == minT`) up to `maxA` (at `temp == maxT`). Broadcasts over arrays.
"""
A_fake(temp) = @. minA + (maxA - minA) * ((temp - minT) / (maxT - minT))^2

"""
    predict_A̅(UA, θ, temp)

Evaluate the network `UA` with parameters `θ` at `temp` and scale its first output
by `1e-16`.
"""
function predict_A̅(UA, θ, temp)
    raw = UA(temp, θ)
    return raw[1] .* 1e-16
end


#### Generate reference dataset ####
# Globals read inside the solver functions (grid size, bedrock, spacing, network output
# bounds) are `const`: a non-const global read inside a function is type-unstable.
const nx = 100
const ny = 100
const B = zeros(Float32, (nx, ny))
const σ = 1000
# Gaussian-shaped initial ice thickness centred on the grid
H₀ = Matrix{Float32}(
    [250 * exp(-((i - nx / 2)^2 + (j - ny / 2)^2) / σ) for i in 1:nx, j in 1:ny],
)
const Δx = 50 # m
const Δy = 50 # m

const temps = Vector{Float32}([0.0, -0.5, -0.2, -0.1, -0.3, -0.1, -0.2, -0.3, -0.4, -0.1])
H_ref = ref_glacier(temps, H₀)

# Train UDE; the last layer maps its output into (minA_out, maxA_out)
const minA_out = 0.3
const maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / (1 + exp(-x))
UA = FastChain(
    FastDense(1, 3, x -> tanh.(x)),
    FastDense(3, 10, x -> tanh.(x)),
    FastDense(10, 3, x -> tanh.(x)),
    FastDense(3, 1, sigmoid_A),
)

iceflow_trained = train_iceflow_UDE(H₀, UA, H_ref, temps)

And here’s the error I’m getting:

└ @ Enzyme.Compiler ~/.julia/packages/Enzyme/g5epq/src/compiler.jl:375
┌ Warning: ("reverse differentiating jl_invoke call without split mode", iceflow_NN!, MethodInstance for iceflow_NN!(::Matrix{Float32}, ::Matrix{Float32}, ::Function, ::Vector{Float32}, ::Float64, ::ArrayPartition{Float64, Tuple{Vector{Float64}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Vector{Int64}, Matrix{Float32}, Matrix{Float32}}}))
└ @ Enzyme.Compiler ~/.julia/packages/Enzyme/g5epq/src/compiler.jl:404
 need   %35 = call token (...) @llvm.julia.gc_preserve_begin({} addrspace(10)* nonnull %34), !dbg !215 via   call void @llvm.julia.gc_preserve_end(token %35), !dbg !245
Assertion failed: (!inst->getType()->isTokenTy()), function is_value_needed_in_reverse, file /workspace/srcdir/Enzyme/enzyme/Enzyme/DifferentialUseAnalysis.h, line 379.

signal (6): Abort trap: 6
in expression starting at /Users/Bolib001/Desktop/Jordi/Julia/odinn_toy_model/scripts/examples/MWE_iceflow.jl:214
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 382360361 (Pool: 382273315; Big: 87046); GC: 196

Some progress, but still not there yet. Any idea on what’s wrong now regarding the “split mode”? Thanks!

I’m going to need to call in the Enzyme devs here since this is something I would now expect to work :sweat: . Seems like there’s a dynamic call it cannot handle somewhere?

@vchuravy @wsmoses

2 Likes

OK, so as discussed, I’m going back to a Zygote implementation of this, with the goal of making it as optimized as possible while waiting for Enzyme to work.

Unlike my previous manual implementation using `pullback`, which worked, I now get an error I cannot debug when using `BacksolveAdjoint(autojacvec=ZygoteVJP())`. To avoid mutation issues, I’m passing the model parameters as a tuple and returning a copy of `dH`.

Here’s a new MWE with the updated implementation:

using Statistics
using LinearAlgebra
using Random 
using OrdinaryDiffEq
using DiffEqFlux
using Flux
using Tullio
using RecursiveArrayTools
using Infiltrator

const t₁ = 10                  # number of simulation years
const ρ = 900f0                # Ice density [kg / m^3]
const g = 9.81f0               # Gravitational acceleration [m / s^2]
const n = 3f0                  # Glen's flow law exponent
const maxA = 8f-16
const minA = 3f-17
const maxT = 1f0
const minT = -25f0
# Glen's creep parameter, 1.3e-24 [1 / Pa^3 s] (alternative value: 2e-16) converted to
# [1 / Pa^3 yr]. Built in one expression so it can be `const` (and stays Float32):
# a non-const global read inside the solver functions is type-unstable.
const A = 1.3f-24 * Float32(60 * 60 * 24 * 365.25)
const C = 0
const α = 0

# Staggered-grid helpers. Averages divide by an integer (`/ 4`, `/ 2`) instead of
# multiplying by a Float64 literal (`0.25`, `0.5`), so Float32 input stays Float32
# instead of being silently promoted to Float64.

# Average of the four corners of each cell: size (nx-1, ny-1)
@views function avg(A)
    return (A[1:end-1, 1:end-1] .+ A[2:end, 1:end-1] .+
            A[1:end-1, 2:end] .+ A[2:end, 2:end]) ./ 4
end
# Average of neighbours along the first dimension: size (nx-1, ny)
@views avg_x(A) = (A[1:end-1, :] .+ A[2:end, :]) ./ 2
# Average of neighbours along the second dimension: size (nx, ny-1)
@views avg_y(A) = (A[:, 1:end-1] .+ A[:, 2:end]) ./ 2
# Forward difference along the first dimension: size (nx-1, ny)
@views diff_x(A) = (A[begin + 1:end, :] .- A[1:end - 1, :])
# Forward difference along the second dimension: size (nx, ny-1)
@views diff_y(A) = (A[:, begin + 1:end] .- A[:, 1:end - 1])
# Interior (non-border) view of a matrix: size (nx-2, ny-2)
@views inn(A) = A[2:end-1, 2:end-1]

"""
    ref_glacier(temps, H₀)

Run the forward SIA ice-flow model from the initial thickness `H₀` over `(0, t₁)` and
return the final thickness as a `Float32` matrix.

`temps` is indexed by simulated year inside `iceflow!`. The bedrock `B` and the
parameters `A`, `C`, `α` are read from globals; `B` must have the same size as `H₀`.
"""
function ref_glacier(temps, H₀)
    # A plain numeric matrix only needs a shallow copy.
    H = copy(H₀)
    # Grid size comes from the input rather than from the non-const globals nx, ny.
    nx, ny = size(H₀)

    # Preallocated work arrays for the staggered-grid solver
    S = zeros(Float32, nx, ny)
    dSdx = zeros(Float32, nx - 1, ny)
    dSdy = zeros(Float32, nx, ny - 1)
    dSdx_edges = zeros(Float32, nx - 1, ny - 2)
    dSdy_edges = zeros(Float32, nx - 2, ny - 1)
    ∇S = zeros(Float32, nx - 1, ny - 1)
    D = zeros(Float32, nx - 1, ny - 1)
    Fx = zeros(Float32, nx - 1, ny - 2)
    Fy = zeros(Float32, nx - 2, ny - 1)
    V = zeros(Float32, nx - 1, ny - 1)
    Vx = zeros(Float32, nx - 1, ny - 1)
    Vy = zeros(Float32, nx - 1, ny - 1)

    # Simulation state; A and current_year are wrapped in vectors so iceflow! can
    # update them in place.
    current_year = 0
    context = ArrayPartition([A], B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges,
                             ∇S, Fx, Fy, Vx, Vy, V, C, α, [current_year])

    # Perform reference simulation with forward model
    println("Running forward PDE ice flow model...\n")
    iceflow_prob = ODEProblem(iceflow!, H, (0.0, t₁), context)
    iceflow_sol = solve(iceflow_prob, BS3(), progress=true, saveat=1.0, progress_steps = 1)

    return Float32.(iceflow_sol[end])
end

"""
    iceflow!(dH, H, context, t)

In-place right-hand side of the forward (reference) ice-flow ODE.

When `t` falls in a year other than the one recorded in `context.x[18]` (and not past
`t₁`), the creep parameter in `context.x[1]` is recomputed from that year's temperature
with `A_fake` and the year is recorded. Then `SIA!` writes the flux divergence into `dH`.
"""
function iceflow!(dH, H, context, t)
    # context.x: A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S,
    #            Fx, Fy, Vx, Vy, V, C, α, current_year
    A = context.x[1]              # 1-element vector, updated in place
    current_year = context.x[18]  # 1-element vector: year A was last computed for
    temps = context.x[7]

    # Get current year for MB and ELA
    year = floor(Int, t) + 1
    # Compare with the stored value. Before, `Ref(vector)[]` returned the vector
    # itself, and an Int is never equal to a Vector, so A was recomputed every call.
    if year != current_year[] && year <= t₁
        A .= A_fake(Float32(temps[year]))
        current_year[] = year
    end

    # Compute the Shallow Ice Approximation in a staggered grid
    SIA!(dH, H, context)
    return nothing
end


"""
    train_iceflow_UDE(H₀, UA, H_ref, temps)

Train the network `UA` (its weights initialised with `initial_params`) so that the
UDE started from `H₀` reproduces `H_ref`. Returns the result of
`DiffEqFlux.sciml_train` (RMSProp, 10 iterations).
"""
function train_iceflow_UDE(H₀, UA, H_ref, temps)
    # A plain numeric matrix only needs a shallow copy.
    H = copy(H₀)
    # Grid size comes from the input rather than from the non-const globals nx, ny.
    nx, ny = size(H₀)

    # Preallocated work arrays (only passed through the context tuple here)
    S = zeros(Float32, nx, ny)
    dSdx = zeros(Float32, nx - 1, ny)
    dSdy = zeros(Float32, nx, ny - 1)
    dSdx_edges = zeros(Float32, nx - 1, ny - 2)
    dSdy_edges = zeros(Float32, nx - 2, ny - 1)
    ∇S = zeros(Float32, nx - 1, ny - 1)
    D = zeros(Float32, nx - 1, ny - 1)
    Fx = zeros(Float32, nx - 1, ny - 2)
    Fy = zeros(Float32, nx - 2, ny - 1)
    V = zeros(Float32, nx - 1, ny - 1)
    Vx = zeros(Float32, nx - 1, ny - 1)
    Vy = zeros(Float32, nx - 1, ny - 1)

    # Gather simulation parameters
    current_year = 0
    θ = initial_params(UA)
    context = (A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy,
               Vx, Vy, V, C, α, current_year, H_ref, H)
    loss(θ) = loss_iceflow(UA, θ, H, context) # closure

    println("Training iceflow UDE...")

    return DiffEqFlux.sciml_train(loss, θ, RMSProp(0.01), maxiters = 10)
end

"""
    loss_iceflow(UA, θ, H, context)

Root of the summed squared error between the predicted final thickness and the
reference `context[19]`, restricted to cells where the prediction is non-zero.
Prints and returns the loss.
"""
function loss_iceflow(UA, θ, H, context)
    H_pred = predict_iceflow(UA, θ, H, context)

    H_ref = context[19]
    # Build the ice mask once and reuse it for both arrays.
    mask = H_pred .!= 0.0
    l_H = sqrt(Flux.Losses.mse(H_pred[mask], H_ref[mask]; agg=sum))
    println("Loss = ", l_H)

    return l_H
end

"""
    predict_iceflow(UA, θ, H, context)

Solve the UDE from `H` over `(0, t₁)` with network weights `θ` and return the final
state. Gradients use a backsolve adjoint with Zygote vector-Jacobian products.
"""
function predict_iceflow(UA, θ, H, context)
    # Fix `context` and `UA` in a closure so the solver sees the usual (du, u, p, t)
    # signature, with the weights as the parameters p.
    ude!(dH, u, p, t) = iceflow_NN!(dH, u, context, UA, p, t)
    problem = ODEProblem(ude!, H, (0.0, t₁), θ)
    adjoint_method = BacksolveAdjoint(autojacvec=ZygoteVJP())
    solution = solve(problem, BS3(); u0=H, p=θ, save_everystep=false,
                     sensealg=adjoint_method, progress=true, progress_steps=1)
    return solution[end]
end

"""
    iceflow_NN!(dH, H, context, UA, θ, t)

In-place right-hand side of the UDE (Zygote path). Predicts the creep parameter from
the current year's temperature with `UA` and weights `θ`, evaluates the tuple method
of `SIA!` on a context holding that value, and copies the result into `dH`.
"""
function iceflow_NN!(dH, H, context, UA, θ, t)
    # context: A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy,
    #          Vx, Vy, V, C, α, current_year, H_ref, H
    # Clamp so that t == t₁ uses the last year's temperature; otherwise the constant A
    # stored in `context` (not the network prediction) was used at the final time.
    year = min(floor(Int, t) + 1, t₁)
    # Recomputed on every call so that A is always a function of θ for the adjoint.
    YA = predict_A̅(UA, θ, [context[7][year]]) # FastChain prediction requires explicit parameters

    # Tuples are immutable: build a local context with the new A and year. The last
    # field is deliberately NOT unpacked into `H`: doing so replaced the current ODE
    # state with the thickness stored in `context`.
    (_, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy,
     Vx, Vy, V, C, α, _, H_ref, H_ctx) = context
    ctx = (YA, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy,
           Vx, Vy, V, C, α, year, H_ref, H_ctx)

    # The tuple method of SIA! returns a new array instead of writing into its `dH`
    # argument, so copy the result into the solver's output buffer. Rebinding
    # `dH = SIA!(...)` left the buffer untouched; NOTE(review): under ZygoteVJP that
    # makes the output independent of θ, likely the source of `vec(::Nothing)`.
    dH[:, :] = SIA!(dH, H, ctx)
    return nothing
end

"""
    SIA!(dH, H, context)

Evaluate the Shallow Ice Approximation (SIA) right-hand side in place: write the flux
divergence for the ice thickness `H` into the interior of `dH` (its border is not
written). Returns `nothing`.

`context.x` holds, in order: `A` (1-element vector), bedrock `B`, and the work arrays
`S`, `dSdx`, `dSdy`, `D`, then the temperatures, then `dSdx_edges`, `dSdy_edges`, `∇S`,
`Fx`, `Fy`, …; the work arrays are overwritten.
"""
function SIA!(dH, H, context)
    # Retrieve parameters and work arrays
    A = context.x[1]
    B = context.x[2]
    S = context.x[3]
    dSdx = context.x[4]
    dSdy = context.x[5]
    D = context.x[6]
    dSdx_edges = context.x[8]
    dSdy_edges = context.x[9]
    ∇S = context.x[10]
    Fx = context.x[11]
    Fy = context.x[12]

    # Update glacier surface altimetry
    S .= B .+ H

    # All grid variables computed in a staggered grid
    # Compute surface gradients on edges
    dSdx .= diff_x(S) ./ Δx
    dSdy .= diff_y(S) ./ Δy
    ∇S .= (avg_y(dSdx) .^ 2 .+ avg_x(dSdy) .^ 2) .^ ((n - 1) / 2)

    # `A` is a 1-element vector: use its scalar value so Γ is a scalar, not a vector
    Γ = 2 * A[] * (ρ * g)^n / (n + 2) # 1 / m^3 s
    D .= Γ .* avg(H) .^ (n + 2) .* ∇S

    # Compute flux components; views avoid copying slices of S on every call
    @views dSdx_edges .= diff(S[:, 2:end-1], dims=1) ./ Δx
    @views dSdy_edges .= diff(S[2:end-1, :], dims=2) ./ Δy
    Fx .= .-avg_y(D) .* dSdx_edges
    Fy .= .-avg_x(D) .* dSdy_edges

    # Flux divergence (MB to be added here)
    inn(dH) .= .-(diff(Fx, dims=1) ./ Δx .+ diff(Fy, dims=2) ./ Δy)

    # Compute velocities
    #Vx = -D./(avg(H) .+ ϵ).*avg_y(dSdx)
    #Vy = -D./(avg(H) .+ ϵ).*avg_x(dSdy)

    return nothing
end

"""
    SIA!(dH, H, context::Tuple)

Non-mutating SIA right-hand side for use with Zygote (despite the `!`). It neither
writes into `dH` nor into the work arrays in `context`. Instead it returns a new
matrix: the flux divergence on the interior, zero-padded by one cell on each side.
Callers must copy the returned matrix into their output buffer.
"""
function SIA!(dH, H, context::Tuple)
    # Only A (field 1) and the bedrock B (field 2) are read: every other quantity is
    # recomputed below as a new array, so the preallocated buffers are not needed.
    A = context[1]
    B = context[2]

    # Update glacier surface altimetry
    S = B .+ H

    # All grid variables computed in a staggered grid
    # Compute surface gradients on edges
    dSdx = diff_x(S) / Δx
    dSdy = diff_y(S) / Δy
    ∇S = (avg_y(dSdx) .^ 2 .+ avg_x(dSdy) .^ 2) .^ ((n - 1) / 2)

    Γ = 2 * A * (ρ * g)^n / (n + 2) # 1 / m^3 s
    D = Γ .* avg(H) .^ (n + 2) .* ∇S

    # Compute flux components
    dSdx_edges = diff(S[:, 2:end - 1], dims=1) / Δx
    dSdy_edges = diff(S[2:end - 1, :], dims=2) / Δy
    Fx = -avg_y(D) .* dSdx_edges
    Fy = -avg_x(D) .* dSdy_edges

    # Flux divergence (MB to be added here): computed once outside @tullio, then
    # zero-padded back to the full grid.
    div_F = .-(diff(Fx, dims=1) ./ Δx .+ diff(Fy, dims=2) ./ Δy)
    @tullio dH[i, j] := div_F[pad(i - 1, 1, 1), pad(j - 1, 1, 1)]

    return dH
end


"""
    A_fake(temp)

Synthetic creep parameter used for the reference run: a quadratic ramp from `minA`
(at `temp == minT`) up to `maxA` (at `temp == maxT`). Broadcasts over arrays.
"""
A_fake(temp) = @. minA + (maxA - minA) * ((temp - minT) / (maxT - minT))^2

"""
    predict_A̅(UA, θ, temp)

Evaluate the network `UA` with parameters `θ` at `temp` and scale its first output
by `1e-16`.
"""
function predict_A̅(UA, θ, temp)
    raw = UA(temp, θ)
    return raw[1] .* 1e-16
end


#### Generate reference dataset ####
# Globals read inside the solver functions (grid size, bedrock, spacing, network output
# bounds) are `const`: a non-const global read inside a function is type-unstable.
const nx = 100
const ny = 100
const B = zeros(Float32, (nx, ny))
const σ = 1000
# Gaussian-shaped initial ice thickness centred on the grid
H₀ = Matrix{Float32}(
    [250 * exp(-((i - nx / 2)^2 + (j - ny / 2)^2) / σ) for i in 1:nx, j in 1:ny],
)
const Δx = 50 # m
const Δy = 50 # m

const temps = Vector{Float32}([0.0, -0.5, -0.2, -0.1, -0.3, -0.1, -0.2, -0.3, -0.4, -0.1])
H_ref = ref_glacier(temps, H₀)

# Train UDE; the last layer maps its output into (minA_out, maxA_out)
const minA_out = 0.3
const maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / (1 + exp(-x))
UA = FastChain(
    FastDense(1, 3, x -> tanh.(x)),
    FastDense(3, 10, x -> tanh.(x)),
    FastDense(10, 3, x -> tanh.(x)),
    FastDense(3, 1, sigmoid_A),
)

iceflow_trained = train_iceflow_UDE(H₀, UA, H_ref, temps)

And this is the new error I’m having:

ERROR: LoadError: MethodError: no method matching vec(::Nothing)
Closest candidates are:
  vec(::FillArrays.Ones{T, N, Axes} where {N, Axes}) where T at /Users/Bolib001/.julia/packages/FillArrays/Vzxer/src/fillalgebra.jl:3
  vec(::Adjoint{var"#s832", var"#s8321"} where {var"#s832"<:Real, var"#s8321"<:(AbstractVector{T} where T)}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:240
  vec(::Tuple{Vararg{MultivariatePolynomials.AbstractVariable, N} where N}) at /Users/Bolib001/.julia/packages/MultivariatePolynomials/vqcb5/src/operators.jl:351
  ...
Stacktrace:
  [1] _vecjacobian!(dλ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, y::Matrix{Float32}, λ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, p::Vector{Float32}, t::Float64, S::DiffEqSensitivity.ODEBacksolveSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Matrix{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, UniformScaling{Bool}}, BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Matrix{Float32}, ODEProblem{Matrix{Float32}, Tuple{Float64, Float64}, true, Vector{Float32}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, 
Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, isautojacvec::ZygoteVJP, dgrad::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, dy::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, W::Nothing)
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/sLLcu/src/derivative_wrappers.jl:448
  [2] #vecjacobian!#37
    @ ~/.julia/packages/DiffEqSensitivity/sLLcu/src/derivative_wrappers.jl:224 [inlined]
  [3] (::DiffEqSensitivity.ODEBacksolveSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Matrix{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, UniformScaling{Bool}}, BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Matrix{Float32}, ODEProblem{Matrix{Float32}, Tuple{Float64, Float64}, true, Vector{Float32}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, 
Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}})(du::Vector{Float32}, u::Vector{Float32}, p::Vector{Float32}, t::Float64)
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/sLLcu/src/backsolve_adjoint.jl:36
...
...
...

Any idea where the error comes from? I find it super hard to debug errors when using the DiffEqFlux.sciml_train wrapper function. Thanks again!

The way I usually debug these things is to simplify the ODE by commenting things out until it works, and then progressively add back lines of code until the error is hit. My guess is that the pullback is calculating a nothing on one of the terms (i.e. no gradient) that then is being used for a gradient calculation, which shouldn’t happen and could be an issue in one of the adjoint definitions. If you help isolate it I can take a deeper look.

1 Like

I have progressively commented everything inside the ODE and the error persists. Could the issue come from the way I’m using a closure to pass additional parameters to the PDE function? Here’s what I’m doing:

"""
    predict_iceflow(UA, θ, H, context)

Solve the UDE over `(0, t₁)` from the initial thickness `H`, using the network parameters
`θ` as the ODE parameters `p`, and return the final state.

`context` and `UA` are captured by the closure below rather than passed through `p`, so
the adjoint treats them as constants; only what enters through `p` (here `θ`) gets
parameter gradients.
"""
function predict_iceflow(UA, θ, H, context)
        
    # Adapt the 6-argument RHS to the (du, u, p, t) signature the ODE solver expects
    iceflow_UDE!(dH, H, θ, t) = iceflow_NN!(dH, H, context, UA, θ, t) # closure
    tspan = (0.0,t₁)
    iceflow_prob = ODEProblem(iceflow_UDE!,H,tspan,θ)
    # Only the final state is needed for the loss, so intermediate steps are not saved;
    # gradients are computed with the backsolve adjoint using Zygote VJPs.
    H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, save_everystep=false, 
                   sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP()), 
                   progress=true, progress_steps = 1)

    return H_pred[end]
end

This appears to work in the forward pass, but crashes during the pullback with the same error as above. Is that closure correct?

1 Like

If context and UA are constants (not dependent on the optimized values) then it should be fine.

context is a tuple, containing the prediction made by UA with the optimized values (the tuple is updated by unpacking and repacking the values):

# Fragment of the UDE right-hand side: re-predict `A` from the network when a new
# simulated year starts. `t`, `θ`, `UA` and `context` come from the enclosing function.
year = floor(Int, t) + 1
    # NOTE(review): `year <= t₁` presumably keeps `year` a valid index into `context[7]`
    # (the temperatures, per the unpacking below) — confirm.
    if year != current_year && year <= t₁
        temp = context[7][year]
        YA = predict_A̅(UA, θ, [temp]) # FastChain prediction requires explicit parameters

        # Unpack and repack tuple to update `A` and `current_year`
        # NOTE(review): tuples are immutable, so this builds a NEW tuple and rebinds only the
        # local name `context`. If `context` is an argument of the RHS (as in the closure of
        # `predict_iceflow`), the next RHS evaluation will not see the new `A`/year — confirm.
        A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H = context
        context = (YA, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, year, H_ref, H)

    end

In fact I’m also using another closure for the loss function (for the same reason):

# Trainable network parameters: the values the optimiser updates.
θ = initial_params(UA)
# Simulation data captured by the loss closure (not part of θ, so treated as constant)
context = (A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H)
loss(θ) = loss_iceflow(UA, θ, H, context) # closure

# 10 RMSProp iterations with learning rate 0.01
iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(0.01), maxiters = 10)

Yeah, so if those are parameters of the differential equation, then they have to come from `p`; otherwise they are treated as constants. This is the point about taking `@view`s of a parameter vector, etc.

1 Like

I don’t know if I shared this with you before, but ComponentArrays.jl is a very good way to set this kind of thing up.

https://jonniedie.github.io/ComponentArrays.jl/dev/examples/DiffEqFlux/

2 Likes

Thanks for the tip, indeed, a ComponentArray makes things clearer than using a tuple.

I have managed to fix that problem and Zygote now seems to be working, but after computing the loss the solver returns the warning `dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable`. I have tried several of the solvers you suggested; BS3() still seems to be the best, but it only works correctly for the forward run that generates the reference dataset.

Is it normal that the plain forward run used to generate the reference dataset is super fast and smooth, while the UDE shows this behaviour? Strangely, the NN weights don’t seem to get updated between iterations (maybe because of the abort?). Moreover, when I get that error `dH` goes to NaN, although I’m pretty sure the PDE model is correct. I have also constrained the NN's output to physically stable values, and I have set a low learning rate to avoid unstable results. What am I missing? I can provide an updated MWE if you want.

Is the UDE that you defined stable? It’s easy to accidentally make a `U(u)` that is positive over your whole solution interval, so if you then have `u' = f(u) + U(u)`, the ODE is suddenly no longer stable with the random initial weights of the neural network. This is especially true for long time spans, which is why we do not suggest using direct single shooting for long time spans and stiff models. Did you follow the docs on the training strategies for these cases?

https://diffeqflux.sciml.ai/dev/examples/multiple_shooting/

https://diffeqflux.sciml.ai/dev/examples/local_minima/

The latter technique is usually required if you aren’t doing multiple shooting since otherwise you can get cases where the first complete solve is unstable with the first weights, in which case the optimization would not be able to start.

Thanks for the links, so much information in the DifferentialEquations.jl and DiffEqFlux.jl docs! I will read everything and try this.

However, I’m fairly sure the problem does not come from unstable initial values. I have designed the NN to only produce predictions within a given range of physically stable values (a sigmoid function with max and min values at the output layer):

# Bounds of the network's output layer
minA_out = 0.3
maxA_out = 8
# Scaled logistic: maps any real x into the open interval (minA_out, maxA_out)
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / ( 1 + exp(-x) )
# 1 → 3 → 10 → 3 → 1 network: tanh hidden layers, bounded output via sigmoid_A
UA = FastChain(
        FastDense(1,3, x->tanh.(x)),
        FastDense(3,10, x->tanh.(x)),
        FastDense(10,3, x->tanh.(x)),
        FastDense(3,1, sigmoid_A)
    )

I have checked the values generated by the NN and they are always physically realistic, so they should not produce instabilities in the ODE. Moreover, the input timeseries I’m using are mostly the same value repeated with a tiny noise, so it’s a very easy initial problem (just to debug it, for now), quite different from the highly varying time series from the multiple shooting example you sent.

:thinking: I’d have to see a picture of how it’s going unstable. Though it could also be from using an explicit method.

I have also tried VCABM() and it appears to be working well (like BS3()), but I still get the same error during sciml_train. Wanna try the new MWE and have a guess? :grinning:

using Statistics
using LinearAlgebra
using Random 
using OrdinaryDiffEq
using DiffEqFlux
using Flux
using Tullio
using RecursiveArrayTools
using ComponentArrays

const t₁ = 10                       # number of simulation years
const ρ = 900f0                     # Ice density [kg / m^3]
const g = 9.81f0                    # Gravitational acceleration [m / s^2]
const n = 3f0                       # Glen's flow law exponent
const maxA = 8f-16                  # A_fake(maxT) == maxA
const minA = 3f-17                  # A_fake(minT) == minA
const maxT = 1f0
const minT = -25f0
# Globals read inside functions (`ref_glacier`, `train_iceflow_UDE`) must be `const`,
# otherwise every access is type-unstable.
# 1.3e-24 [1 / Pa^3 s] converted to [1 / Pa^3 yr]
const A = 1.3f-24 * Float32(60 * 60 * 24 * 365.25)
const C = 0
const α = 0

"""
    avg(M)

Mean of the four corner values of every 2×2 cell of `M`; the result is one smaller
than `M` in each dimension.
"""
function avg(M)
    @views begin
        a11, a21 = M[1:end-1, 1:end-1], M[2:end, 1:end-1]
        a12, a22 = M[1:end-1, 2:end], M[2:end, 2:end]
    end
    return 0.25 .* (a11 .+ a21 .+ a12 .+ a22)
end

"""Mean of neighbouring entries along dimension 1 (one fewer row than `M`)."""
function avg_x(M)
    prev_rows = @view M[1:end-1, :]
    next_rows = @view M[2:end, :]
    return 0.5 .* (prev_rows .+ next_rows)
end

"""Mean of neighbouring entries along dimension 2 (one fewer column than `M`)."""
function avg_y(M)
    prev_cols = @view M[:, 1:end-1]
    next_cols = @view M[:, 2:end]
    return 0.5 .* (prev_cols .+ next_cols)
end

"""Forward difference along dimension 1 (one fewer row than `M`)."""
function diff_x(M)
    next_rows = @view M[begin + 1:end, :]
    prev_rows = @view M[1:end - 1, :]
    return next_rows .- prev_rows
end

"""Forward difference along dimension 2 (one fewer column than `M`)."""
function diff_y(M)
    next_cols = @view M[:, begin + 1:end]
    prev_cols = @view M[:, 1:end - 1]
    return next_cols .- prev_cols
end

"""Interior of `M` without its outer ring, as a VIEW so that `inn(M) .= x` writes into `M`."""
inn(M) = @view M[2:end-1, 2:end-1]

"""
    ref_glacier(temps, H₀)

Run the forward (non-learned) ice-flow model from the initial ice thickness `H₀` over
`(0, t₁)`. The flow parameter `A` is prescribed each year from `temps` via `A_fake`
(see `iceflow!`). Returns the final thickness as a `Float32` matrix; this is the
reference `H_ref` the UDE is trained against.
"""
function ref_glacier(temps, H₀)

    H = copy(H₀)
    # Take the grid size from the input instead of the non-const globals `nx`, `ny`
    nx, ny = size(H₀)

    # Preallocated staggered-grid work arrays, overwritten by `SIA!` on every RHS call
    S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
    dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
    D, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
    V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)

    # Gather simulation parameters. `iceflow!` and `SIA!` read these by POSITION
    # (`context.x[i]`), so the order must not change. `A` and the current year are
    # wrapped in one-element vectors so `iceflow!` can update them in place.
    current_year = Int(0)
    context = ArrayPartition([A], B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S,
                             Fx, Fy, Vx, Vy, V, C, α, [current_year])

    # Perform reference simulation with forward model
    println("Running forward PDE ice flow model...\n")
    iceflow_prob = ODEProblem(iceflow!, H, (0.0, t₁), context)
    iceflow_sol = solve(iceflow_prob, BS3(), reltol=1e-6, progress=true, saveat=1.0,
                        progress_steps = 1)

    return Float32.(iceflow_sol[end])
end

"""
    iceflow!(dH, H, context, t)

In-place right-hand side of the forward (reference) ice-flow ODE.

`context` is the `ArrayPartition` built in `ref_glacier`. When `t` enters a new simulated
year (and `year <= t₁`), the flow parameter `A` (one-element vector `context.x[1]`) is
recomputed from that year's temperature with `A_fake`. The year is then cached in the
one-element vector `context.x[18]`. The thickness tendency is written into `dH` by `SIA!`.
"""
function iceflow!(dH, H, context, t)
    A_store = context.x[1]       # one-element vector holding A
    temps = context.x[7]
    year_store = context.x[18]   # one-element vector holding the last processed year

    # Get current year for MB and ELA
    year = floor(Int, t) + 1
    # Compare with the stored ELEMENT: comparing an Int with the Vector itself
    # (`year != [y]`) is always true, which recomputed A on every RHS call.
    if year != year_store[1] && year <= t₁
        A_store[1] = A_fake(temps[year])
        year_store[1] = year
    end

    # Compute the Shallow Ice Approximation in a staggered grid
    SIA!(dH, H, context)
    return nothing
end


"""
    train_iceflow_UDE(H₀, UA, H_ref, temps)

Train the network `UA`, which predicts `A` from temperature, so that the UDE solution
started from `H₀` matches the reference thickness `H_ref`.

Returns the result of `DiffEqFlux.sciml_train` (10 RMSProp iterations, learning rate 1e-4).
"""
function train_iceflow_UDE(H₀, UA, H_ref, temps)

    # A plain copy is enough for an array of floats
    H = copy(H₀)

    # Trainable parameters live only in θ; everything in `context` is fixed data
    # captured by the loss closure.
    current_year = 0
    θ = initial_params(UA)
    context = ComponentArray(B=B, C=C, α=α, temps=temps,current_year=current_year, H_ref=H_ref)
    loss(θ) = loss_iceflow(θ, UA, H, context) # closure

    println("Training iceflow UDE...")
    iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(0.0001), maxiters = 10)

    return iceflow_trained
end

"""
    loss_iceflow(θ, UA, H, context)

Training loss: solve the UDE from `H` with network parameters `θ` and return the L2 norm
of the error against `context.H_ref`. Only cells where the prediction is nonzero are
included (presumably the ice-covered cells). The loss is printed on every call.
"""
function loss_iceflow(θ, UA, H, context)

    H_pred = predict_iceflow(θ, UA, H, context)

    # Build the mask once and reuse it for both arrays
    mask = H_pred .!= 0.0
    # mse with agg=sum is the sum of squared errors, so its sqrt is the Euclidean norm
    l_H = sqrt(Flux.Losses.mse(H_pred[mask], context.H_ref[mask]; agg=sum))
    println("Loss = ", l_H)

    return l_H
end

"""
    predict_iceflow(θ, UA, H, context)

Integrate the UDE over `(0, t₁)` starting from thickness `H`, with the network
parameters `θ` as ODE parameters, and return the state at the final time. Sensitivities
use `BacksolveAdjoint` with Zygote vector-Jacobian products.
"""
function predict_iceflow(θ, UA, H, context)
    # `context` and `UA` are captured (constants for the adjoint); θ enters as `p`
    ude_rhs!(dH, H, θ, t) = iceflow_NN!(dH, H, θ, t, context, UA)
    prob = ODEProblem(ude_rhs!, H, (0.0, t₁), θ)
    backsolve = BacksolveAdjoint(autojacvec=ZygoteVJP())
    sol = solve(prob, BS3(); u0=H, p=θ, reltol=1e-6, save_everystep=false,
                sensealg=backsolve, progress=true, progress_steps=1)
    return sol[end]
end

"""
    iceflow_NN!(dH, H, θ, t, context, UA)

In-place right-hand side of the UDE. The network `UA` (parameters `θ`) predicts `A` from
the temperature of the current year. The non-mutating `SIA!` method then computes the
thickness tendency, which is copied into `dH`. Prints `A` whenever `t` is an integer.
"""
function iceflow_NN!(dH, H, θ, t, context, UA)
    year = floor(Int, t) + 1
    # Past t₁ (i.e. at t == t₁) fall back to the previous year's temperature;
    # assumes `context.temps` has t₁ entries
    idx = year <= t₁ ? year : year - 1
    temp = context.temps[idx]
    YA = predict_A̅(UA, θ, [temp]) # FastChain prediction requires explicit parameters

    # Debug output, only when the solver happens to evaluate at an integer time
    t % 1 == 0 && println("A: ", YA)

    # The ComponentArray method of SIA! returns a fresh array; copy it into dH
    dH .= SIA!(dH, H, YA, context)
end

"""
    SIA(H, p)

Compute a step of the Shallow Ice Approximation PDE in a forward model
"""

function SIA!(dH, H, context)
    
    # Retrieve parameters
    #A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
    A = context.x[1]
    B = context.x[2]
    S = context.x[3]
    dSdx = context.x[4]
    dSdy = context.x[5]
    D = context.x[6]
    dSdx_edges = context.x[8]
    dSdy_edges = context.x[9]
    ∇S = context.x[10]
    Fx = context.x[11]
    Fy = context.x[12]
    
    # Update glacier surface altimetry
    S .= B .+ H

    # All grid variables computed in a staggered grid
    # Compute surface gradients on edges
    dSdx .= diff_x(S) / Δx
    dSdy .= diff_y(S) / Δy
    ∇S .= (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2) 

    Γ = 2 * A * (ρ * g)^n / (n+2) # 1 / m^3 s 
    D .= Γ .* avg(H).^(n + 2) .* ∇S

    # Compute flux components
    dSdx_edges .= diff(S[:,2:end - 1], dims=1) / Δx
    dSdy_edges .= diff(S[2:end - 1,:], dims=2) / Δy
    Fx .= .-avg_y(D) .* dSdx_edges
    Fy .= .-avg_x(D) .* dSdy_edges 

    #  Flux divergence
    inn(dH) .= .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy) # MB to be added here 
    
end

"""
    SIA!(dH, H, A, context::ComponentArray)

Non-mutating Shallow Ice Approximation right-hand side, written for Zygote, which cannot
differentiate array mutation.

Despite the `!` in its name, `dH` is NOT modified: `@tullio ... :=` below allocates a
fresh array, which is returned. Callers must copy it into their own buffer themselves;
`iceflow_NN!` does this with `dH .= SIA!(dH, H, YA, context)`. `A` is the scalar flow
parameter (the network prediction from `predict_A̅`); `context.B` supplies the bed.
"""
function SIA!(dH, H, A, context::ComponentArray)

    # Retrieve parameters
    B = context.B

    # Update glacier surface altimetry
    S = B .+ H

    # All grid variables computed in a staggered grid
    # Compute surface gradients on edges
    dSdx = diff_x(S) / Δx
    dSdy = diff_y(S) / Δy
    ∇S = (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2)

    Γ = 2 * A * (ρ * g)^n / (n+2) # 1 / m^3 s
    D = Γ .* avg(H).^(n + 2) .* ∇S

    # Compute flux components
    dSdx_edges = diff(S[:,2:end - 1], dims=1) / Δx
    dSdy_edges = diff(S[2:end - 1,:], dims=2) / Δy
    Fx = .-avg_y(D) .* dSdx_edges
    Fy = .-avg_x(D) .* dSdy_edges

    # Flux divergence. The differences are computed once, up front, instead of as
    # `diff(...)` expressions inside the indexed Tullio body.
    dFx = diff(Fx, dims=1)
    dFy = diff(Fy, dims=2)
    # `pad(i-1, 1, 1)` zero-pads by one cell on each side, so `dH` should have the size of
    # `H` with a zero boundary ring — NOTE(review): confirm against Tullio's `pad` docs.
    @tullio dH[i,j] := -(dFx[pad(i-1,1,1),pad(j-1,1,1)] / Δx + dFy[pad(i-1,1,1),pad(j-1,1,1)] / Δy) # MB to be added here

    return dH
end


"""
    A_fake(temp)

Synthetic flow parameter used for the reference run. It grows quadratically from
`minA` at `temp == minT` to `maxA` at `temp == maxT`. Works on scalars and,
element-wise, on arrays.
"""
function A_fake(temp)
    scaled = @. (temp - minT) / (maxT - minT)   # 0 at minT, 1 at maxT
    return @. minA + (maxA - minA) * scaled^2
end

"""
    predict_A̅(UA, θ, temp)

Predict the flow parameter `A` with the network `UA` (parameters `θ`) from the input
vector `temp`, rescaling the network output by 1e-16.

The scale factor is converted to the type of the network output. A `Float32` network
therefore keeps `A`, and with it the whole SIA computation, in `Float32`; the Float64
literal used to promote everything to `Float64`. A `Float64` output is still scaled by
exactly `1e-16`.
"""
function predict_A̅(UA, θ, temp)
    a_nn = UA(temp, θ)[1]
    return a_nn * oftype(a_nn, 1e-16)
end


#### Generate reference dataset ####
# Grid spacing, sizes and network bounds are `const`: they are read as globals inside
# `ref_glacier`, `SIA!` and `sigmoid_A`, which is type-unstable for non-const globals.
const nx = 100
const ny = 100
const B = zeros(Float32, (nx, ny))   # bed elevation (flat); surface is S = B + H
const σ = 1000
H₀ = Matrix{Float32}([ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / σ ) for i in 1:nx, j in 1:ny ])
const Δx = 50 #m
const Δy = 50 #m

const temps =  Vector{Float32}([0.0, -0.5, -0.2, -0.1, -0.3, -0.1, -0.2, -0.3, -0.4, -0.1])
H_ref = ref_glacier(temps, H₀)

# Train UDE
# Float32 bounds keep the network output (and hence A) in Float32, like the model constants
const minA_out = 0.3f0
const maxA_out = 8f0
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / ( 1 + exp(-x) )
# FastDense applies its activation element-wise, so plain `tanh` suffices
UA = FastChain(
        FastDense(1, 3, tanh),
        FastDense(3, 10, tanh),
        FastDense(10, 3, tanh),
        FastDense(3, 1, sigmoid_A)
    )

iceflow_trained = train_iceflow_UDE(H₀, UA, H_ref, temps)

Yeah, I’ll see when I can get to that. I’m in the middle of a big PR so I’m trying to debug from a distance :sweat_smile:

1 Like