No, not on the return. Internal
I think the issue comes from all the @views:
A = Ref{Float32}(context[1])
B = @view context[2][:,:]
S = @view context[3][:,:]
dSdx = @view context[4][:,:]
dSdy = @view context[5][:,:]
D = @view context[6][:,:]
dSdx_edges = @view context[8][:,:]
dSdy_edges = @view context[9][:,:]
∇S = @view context[10][:,:]
Fx = @view context[11][:,:]
Fy = @view context[12][:,:]
They appear as type Any, which might induce type inference issues for Enzyme.jl.
@code_warntype SIA!(zeros(Float32,nx,ny),H,p)
Variables
#self#::Core.Const(SIA!)
dH::Matrix{Float32}
H::Matrix{Float64}
context::Vector{Any}
Γ::Float64
Fy::Any
Fx::Any
∇S::Any
dSdy_edges::Any
dSdx_edges::Any
D::Any
dSdy::Any
dSdx::Any
S::Any
B::Any
A::Base.RefValue{Float32}
@_17::Any
@_18::Any
@_19::Any
@_20::Any
@_21::Any
@_22::Any
@_23::Any
@_24::Any
@_25::Any
@_26::Any
Is there any way to avoid this issue when using @views? The original matrices are all specified with type Float32.
I think context has an abstract element type (it is a Vector{Any}), and thus every context[10] etc. will be a type instability. You could make context a tuple instead.
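A minimal sketch of why the container type matters, using toy values rather than the glacier model (the names ctx_vec, ctx_tup, and second_entry are made up for illustration). When indexing a Vector{Any}, inference can only conclude Any, whereas a tuple carries the type of each slot:

```julia
# Toy stand-ins for `context`; the values are illustrative only.
ctx_vec = Any[1.3f0, zeros(Float32, 2, 2)]   # Vector{Any}
ctx_tup = (1.3f0, zeros(Float32, 2, 2))      # Tuple{Float32, Matrix{Float32}}

# Hypothetical accessor with a literal index, as in `context[2]`.
second_entry(ctx) = ctx[2]

# Ask inference what it can prove about the return type for each container.
inferred_vec = only(Base.return_types(second_entry, (typeof(ctx_vec),)))
inferred_tup = only(Base.return_types(second_entry, (typeof(ctx_tup),)))

println(inferred_vec)  # what inference knows when the container is a Vector{Any}
println(inferred_tup)  # what inference knows when the container is a tuple
```

The same check can be applied to any accessor in the model to see which container choice keeps the types concrete.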
Using @view here is probably harmful. Just indexing will retrieve the array, without copying, but the view makes for a more complicated type:
julia> context = [[1 2; 3 4], 0]
2-element Vector{Any}:
[1 2; 3 4]
0
julia> @view context[1][:,:]
2×2 view(::Matrix{Int64}, :, :) with eltype Int64:
1 2
3 4
julia> @btime $context[1] # this doesn't copy
min 2.833 ns, mean 2.981 ns (0 allocations)
2×2 Matrix{Int64}:
1 2
3 4
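The same point can be sketched as a short script, under the assumption that the goal is only to get a named handle on each stored array. Plain indexing returns the stored array itself, so in-place updates through the new name are visible in context, and the view adds a wrapper type without avoiding any copy:

```julia
context = Any[[1 2; 3 4], 0]

M = context[1]                 # plain indexing: same object, no copy
V = @view context[1][:, :]     # wraps the same memory in a SubArray

M .= 0                         # mutate through the plain binding

println(M === context[1])      # identity check between M and the stored array
println(typeof(M))
println(typeof(V))
println(all(iszero, V))        # the view reads the same memory as M
```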
Let us check this with JET.jl. First we need a function to check:
function test()
#### Generate reference dataset ####
nx = ny = 100
B = zeros(Float64, (nx, ny))
σ = 1000
H₀ = [ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / σ ) for i in 1:nx, j in 1:ny ]
Δx = Δy = 50 #m
temps = [0.0, -0.5, -0.2, -0.1, -0.3, -0.1, -0.2, -0.3, -0.4, -0.1]
H_ref = ref_glacier(temps, H₀)
# Train UDE
minA_out = 0.3
maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / ( 1 + exp(-x) )
UA = FastChain(
FastDense(1,3, x->tanh.(x)),
FastDense(3,10, x->tanh.(x)),
FastDense(10,3, x->tanh.(x)),
FastDense(3,1, sigmoid_A)
)
iceflow_trained = train_iceflow_UDE(H₀, UA, H_ref, temps)
end
Then we run JET.@report_opt test() and get something like this.
Thanks for your reply, Michael. The reason I’m using an array instead of a tuple is that my model requires updating certain parameters during the forward run (e.g. A). What would be the best way to have multiple parameters passed by reference while being able to update them and avoiding type instabilities? I feel like I’ve been going round in circles, fixing one issue only to create another.
But must you replace A, or can you just update A?
julia> context = ([1 2; 3 4], Ref(0));
julia> context[2][] = 99;
julia> context[1] .= 7:8;
julia> context
([7 7; 8 8], Base.RefValue{Int64}(99))
For the DiffEq adjoint it’s best to just make it a vector of numbers. I would recommend you just make the parameters a vector of numbers and then use A = reshape(@view p[n:m],i,j) etc. to build out all of the structures. An ArrayPartition is a good object for doing this kind of thing, BTW.
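A small sketch of the flat-vector idea, with a made-up layout (the first four entries form a 2×2 matrix, the rest a vector). The names p, A, and b are illustrative and not taken from the model. The reshaped view shares memory with p, so writes go through to the parameter vector:

```julia
# Flat parameter vector; the layout (first 4 entries = A, rest = b) is an assumption.
p = collect(Float32, 1:10)

A = reshape(view(p, 1:4), 2, 2)  # 2×2 matrix backed by p[1:4]
b = view(p, 5:10)                # remaining entries as a vector

A[1, 1] = 100f0                  # writes through to p[1]
b .= 0f0                         # writes through to p[5:10]

println(p)
```

The function form view(p, 1:4) is used here instead of the macro only to keep the argument list unambiguous.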
This is a possibility, but using @views allows me to write all the physical equations with meaningful variable names instead of e.g. context[1]. I’d like a solution that is readable on top of being efficient.
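One possible way to keep readable names without views, offered as a sketch rather than as what was used in this thread, is a NamedTuple with a Ref for any scalar that must be updated. The field selection and the helper update_surface! below are hypothetical and much smaller than the real context:

```julia
# Hypothetical, shortened context: a NamedTuple keeps concrete field types.
context = (A = Ref(1.3f0), B = zeros(Float32, 3, 3), S = zeros(Float32, 3, 3))

# Hypothetical helper mirroring the first line of the SIA step.
function update_surface!(context, H)
    A, B, S = context.A, context.B, context.S  # readable local names, no views
    S .= B .+ H                                 # in-place update of the stored array
    return 2 * A[]                              # scalars are read through the Ref
end

H = ones(Float32, 3, 3)
context.A[] = 2.0f0          # update the scalar without rebuilding the tuple
Γ = update_surface!(context, H)
```

Because the local names are plain bindings to the stored arrays, the equations can be written with S, B, and so on, while the container type stays concrete.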
I tried using a tuple and avoiding @views and I still get another error:
....
....
....
!10087 = !DILocation(line: 287, scope: !2750, inlinedAt: !10088)
!10088 = !DILocation(line: 969, scope: !2750, inlinedAt: !10089)
!10089 = !DILocation(line: 148, scope: !2753, inlinedAt: !10090)
!10090 = !DILocation(line: 375, scope: !2755, inlinedAt: !10083)
define internal fastcc { i8* } @fakeaugmented_julia_unalias_11598({ {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* noalias nocapture nonnull sret writeonly align 8 dereferenceable(48) %0, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* nocapture %"'", [1 x {} addrspace(10)*]* noalias nocapture nonnull writeonly align 8 dereferenceable(8) %1, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2, {} addrspace(10)* %"'1", { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* nocapture nonnull readonly align 8 dereferenceable(48) %3, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* nocapture %"'2") unnamed_addr !dbg !10042 {
top:
%_replacementA = phi {}***
%4 = load i8, i8* inttoptr (i64 4469935449 to i8*), align 1, !dbg !10043, !tbaa !145, !invariant.load !4
%5 = and i8 %4, 1, !dbg !10045
%.not.not = icmp eq i8 %5, 0, !dbg !10045
br i1 %.not.not, label %L4, label %L82, !dbg !10045
L4: ; preds = %top
%6 = load i8, i8* inttoptr (i64 5579314521 to i8*), align 1, !dbg !10043, !tbaa !145, !invariant.load !4
%7 = and i8 %6, 1, !dbg !10045
%.not.not26 = icmp eq i8 %7, 0, !dbg !10045
br i1 %.not.not26, label %L7, label %L82, !dbg !10045
L7: ; preds = %L4
%8 = addrspacecast {} addrspace(10)* %2 to {} addrspace(11)*, !dbg !10047
%9 = call nonnull align 8 {}* @julia.pointer_from_objref({} addrspace(11)* %8) #4, !dbg !10047
%"'ip_phi" = phi {}* , !dbg !10047
%10 = bitcast {}* %9 to i64*, !dbg !10047
%11 = load i64, i64* %10, align 8, !dbg !10047, !tbaa !145, !range !233, !invariant.load !4
%"'il_phi" = phi i64 , !dbg !10050
%12 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 0, !dbg !10050
%13 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %12 unordered, align 8, !dbg !10050, !tbaa !145, !invariant.load !4, !nonnull !4, !dereferenceable !379, !align !380
%"'il_phi3" = phi {} addrspace(10)* , !dbg !10052
%"'ipc" = addrspacecast {} addrspace(10)* %"'il_phi3" to {} addrspace(11)*, !dbg !10052
%14 = addrspacecast {} addrspace(10)* %13 to {} addrspace(11)*, !dbg !10052
%15 = call {}* @julia.pointer_from_objref({} addrspace(11)* %"'ipc"), !dbg !10052
%16 = call nonnull align 8 {}* @julia.pointer_from_objref({} addrspace(11)* %14) #4, !dbg !10052
%17 = bitcast {}* %16 to i64*, !dbg !10052
%18 = load i64, i64* %17, align 8, !dbg !10052, !tbaa !145, !range !233, !invariant.load !4
%.not = icmp eq i64 %11, %18, !dbg !10055
br i1 %.not, label %L20, label %L82, !dbg !10046
L20: ; preds = %L7
%19 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 1, i32 0, i64 1, !dbg !10058
%20 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 1, i32 0, i64 0, !dbg !10063
%21 = load i64, i64 addrspace(11)* %19, align 8, !dbg !10065, !tbaa !145, !invariant.load !4
%22 = load i64, i64 addrspace(11)* %20, align 8, !dbg !10065, !tbaa !145, !invariant.load !4
%23 = call { i64, i1 } @llvm.ssub.with.overflow.i64(i64 %21, i64 %22), !dbg !10065
%24 = extractvalue { i64, i1 } %23, 0, !dbg !10065
%25 = extractvalue { i64, i1 } %23, 1, !dbg !10065
br i1 %25, label %L29, label %L32, !dbg !10067
L29: ; preds = %L20
%26 = call fastcc nonnull {} addrspace(10)* @julia_throw_overflowerr_binaryop_11556() #12, !dbg !10067
unreachable, !dbg !10067
L32: ; preds = %L20
%27 = call { i64, i1 } @llvm.sadd.with.overflow.i64(i64 %24, i64 1), !dbg !10068
%28 = extractvalue { i64, i1 } %27, 1, !dbg !10068
br i1 %28, label %L36, label %L64, !dbg !10070
L36: ; preds = %L32
%29 = call fastcc nonnull {} addrspace(10)* @julia_throw_overflowerr_binaryop_11556() #12, !dbg !10070
unreachable, !dbg !10070
L64: ; preds = %L32
%30 = extractvalue { i64, i1 } %27, 0, !dbg !10068
%31 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 1, i32 1, i64 0, i64 0, !dbg !10071
%32 = load i64, i64 addrspace(11)* %31, align 8, !dbg !10075, !tbaa !145, !invariant.load !4
%33 = call nonnull {} addrspace(10)* @jl_alloc_array_2d({} addrspace(10)* addrspacecast ({}* inttoptr (i64 5464987568 to {}*) to {} addrspace(10)*), i64 %30, i64 %32) [ "jl_roots"(i64 addrspace(11)* %31) ], !dbg !10075
%"'mi" = phi {} addrspace(10)* , !dbg !10078
%34 = call fastcc nonnull {} addrspace(10)* @julia_copyto__11605({} addrspace(10)* nonnull align 16 dereferenceable(40) %33, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* nocapture nonnull readonly align 8 dereferenceable(48) %3), !dbg !10078
%"'ip_phi7" = phi {} addrspace(10)* , !dbg !10079
%35 = icmp sgt i64 %30, 0, !dbg !10079
%36 = select i1 %35, i64 %30, i64 0, !dbg !10086
%37 = icmp sgt i64 %32, 0, !dbg !10079
%38 = select i1 %37, i64 %32, i64 0, !dbg !10079
%39 = getelementptr inbounds [1 x {} addrspace(10)*], [1 x {} addrspace(10)*]* %1, i64 0, i64 0, !dbg !10046
store {} addrspace(10)* %33, {} addrspace(10)** %39, align 8, !dbg !10046
%.repack = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 0, !dbg !10046
store {} addrspace(10)* %33, {} addrspace(10)** %.repack, align 8, !dbg !10046
%.repack27.repack.repack = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 1, i32 0, i64 0, !dbg !10046
store i64 1, i64* %.repack27.repack.repack, align 8, !dbg !10046
%.repack27.repack.repack35 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 1, i32 0, i64 1, !dbg !10046
store i64 %36, i64* %.repack27.repack.repack35, align 8, !dbg !10046
%40 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 1, i32 1, i64 0, i64 0, !dbg !10046
store i64 %38, i64* %40, align 8, !dbg !10046
%.repack29 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0, i64 0, i32 2, !dbg !10046
%41 = bitcast i64* %.repack29 to i8*, !dbg !10046
call void @llvm.memset.p0i8.i64(i8* nocapture nonnull writeonly align 8 dereferenceable(16) %41, i8 0, i64 16, i1 false), !dbg !10046
ret { i8* } undef, !dbg !10046
L82: ; preds = %L7, %L4, %top
%42 = getelementptr inbounds { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }, { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3, i64 0, i32 0, !dbg !10046
%43 = load {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %42, align 8, !dbg !10046
%"'il_phi8" = phi {} addrspace(10)* , !dbg !10046
%44 = getelementptr inbounds [1 x {} addrspace(10)*], [1 x {} addrspace(10)*]* %1, i64 0, i64 0, !dbg !10046
store {} addrspace(10)* %43, {} addrspace(10)** %44, align 8, !dbg !10046
%45 = bitcast { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 }* %0 to i8*, !dbg !10046
%46 = bitcast { {} addrspace(10)*, { [2 x i64], [1 x [1 x i64]] }, i64, i64 } addrspace(11)* %3 to i8 addrspace(11)*, !dbg !10046
call void @llvm.memcpy.p0i8.p11i8.i64(i8* nocapture nonnull writeonly align 8 dereferenceable(48) %45, i8 addrspace(11)* nonnull align 8 dereferenceable(48) %46, i64 48, i1 false), !dbg !10046
ret { i8* } undef, !dbg !10046
allocsForInversion: ; No predecessors!
}
%"'il_phi3" = phi {} addrspace(10)* , !dbg !164
Assertion failed: (I->use_empty()), function erase, file /workspace/srcdir/Enzyme/enzyme/Enzyme/CacheUtility.cpp, line 72.
signal (6): Abort trap: 6
in expression starting at /Users/Bolib001/Desktop/Jordi/Julia/odinn_toy_model/scripts/examples/MWE_iceflow.jl:215
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 375843513 (Pool: 375774007; Big: 69506); GC: 174
@ChrisRackauckas: thanks, I’ll explore those data structures once I have something working end-to-end with AD for this case.
So I’ve followed your advice and implemented this using ArrayPartitions. Enzyme.jl is still crashing, but with a different error.
Here is my new MWE with all the recent updates from this discussion:
using Statistics
using LinearAlgebra
using Random
using OrdinaryDiffEq
using DiffEqFlux
using Flux
using RecursiveArrayTools
using Infiltrator
const t₁ = 10 # number of simulation years
const ρ = 900 # Ice density [kg / m^3]
const g = 9.81 # Gravitational acceleration [m / s^2]
const n = 3 # Glen's flow law exponent
const maxA = 8e-16
const minA = 3e-17
const maxT = 1
const minT = -25
A = 1.3f-24 #2e-16 1 / Pa^3 s
A *= 60 * 60 * 24 * 365.25 # [1 / Pa^3 yr]
C = 0
α = 0
@views avg(A) = 0.25 .* ( A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end] )
@views avg_x(A) = 0.5 .* ( A[1:end-1,:] .+ A[2:end,:] )
@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )
@views diff_x(A) = (A[begin + 1:end, :] .- A[1:end - 1, :])
@views diff_y(A) = (A[:, begin + 1:end] .- A[:, 1:end - 1])
@views inn(A) = A[2:end-1,2:end-1]
function ref_glacier(temps, H₀)
H = deepcopy(H₀)
# Initialize all matrices for the solver
S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)
# Gather simulation parameters
current_year = Int(0)
context = ArrayPartition([A], B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, [current_year])
# Perform reference simulation with forward model
println("Running forward PDE ice flow model...\n")
iceflow_prob = ODEProblem(iceflow!,H,(0.0,t₁),context)
iceflow_sol = solve(iceflow_prob, BS3(), progress=true, saveat=1.0, progress_steps = 1)
return Float32.(iceflow_sol[end])
end
function iceflow!(dH, H, context,t)
# Unpack parameters
#A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year
current_year = context.x[18]
A = context.x[1]
# Get current year for MB and ELA
year = floor(Int, t) + 1
if year != current_year[] && year <= t₁ # current_year is a 1-element Vector, so read its value with []
temp = Ref{Float32}(context.x[7][year])
context.x[1] .= A_fake(temp[])
current_year[] = year
end
# Compute the Shallow Ice Approximation in a staggered grid
SIA!(dH, H, context)
end
function train_iceflow_UDE(H₀, UA, H_ref, temps)
# Gather simulation parameters
H = deepcopy(H₀)
# Initialize all matrices for the solver
S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)
# Gather simulation parameters
current_year = 0
θ = initial_params(UA)
context = ArrayPartition([A], B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, [current_year], H_ref, H, UA, θ)
loss(θ) = loss_iceflow(UA, θ, H, context) # closure
println("Training iceflow UDE...")
iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(0.01), maxiters = 5)
return iceflow_trained
end
function loss_iceflow(UA, θ, H, context)
H = predict_iceflow(UA, θ, H, context)
H_ref = context.x[19]
l_H = sqrt(Flux.Losses.mse(H[H .!= 0.0], H_ref[H.!= 0.0]; agg=sum))
println("Loss = ", l_H)
return l_H
end
function predict_iceflow(UA, θ, H, context)
iceflow_UDE!(dH, H, θ, t) = iceflow_NN!(dH, H, θ, t, context) # closure
tspan = (0.0,t₁)
iceflow_prob = ODEProblem(iceflow_UDE!,H,tspan,θ)
H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, save_everystep=false,
sensealg = BacksolveAdjoint(autojacvec=EnzymeVJP()),
progress=true, progress_steps = 1)
return H_pred[end]
end
function iceflow_NN!(dH, H, θ, t, context)
# Unpack parameters
#A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
current_year = Ref(context.x[18])
A = Ref(context.x[1])
UA = context.x[21]
# Get current year for MB and ELA
year = floor(Int, t) + 1
if year != current_year && year <= t₁
temp = context.x[7][year]
A[] .= predict_A̅(UA, θ, [temp]) # FastChain prediction requires explicit parameters
current_year[] .= year
end
# Compute the Shallow Ice Approximation in a staggered grid
SIA!(dH, H, context)
return nothing
end
"""
SIA!(dH, H, context)
Compute a step of the Shallow Ice Approximation PDE in a forward model
"""
function SIA!(dH, H, context)
# Retrieve parameters
#A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
A = context.x[1]
B = context.x[2]
S = context.x[3]
dSdx = context.x[4]
dSdy = context.x[5]
D = context.x[6]
dSdx_edges = context.x[8]
dSdy_edges = context.x[9]
∇S = context.x[10]
Fx = context.x[11]
Fy = context.x[12]
# Update glacier surface altimetry
S .= B .+ H
# All grid variables computed in a staggered grid
# Compute surface gradients on edges
dSdx .= diff_x(S) / Δx
dSdy .= diff_y(S) / Δy
∇S .= (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2)
Γ = 2 * A[] * (ρ * g)^n / (n+2) # 1 / m^3 s
D .= Γ .* avg(H).^(n + 2) .* ∇S
# Compute flux components
dSdx_edges .= diff(S[:,2:end - 1], dims=1) / Δx
dSdy_edges .= diff(S[2:end - 1,:], dims=2) / Δy
Fx .= .-avg_y(D) .* dSdx_edges
Fy .= .-avg_x(D) .* dSdy_edges
# Flux divergence
inn(dH) .= .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy) # MB to be added here
return nothing
end
function A_fake(temp)
return @. minA + (maxA - minA) * ((temp-minT)/(maxT-minT) )^2
end
predict_A̅(UA, θ, temp) = UA(temp, θ)[1] .* 1e-16
#### Generate reference dataset ####
nx = ny = 100
B = zeros(Float32, (nx, ny))
σ = 1000
H₀ = Matrix{Float32}([ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / σ ) for i in 1:nx, j in 1:ny ])
Δx = Δy = 50 #m
temps = Vector{Float32}([0.0, -0.5, -0.2, -0.1, -0.3, -0.1, -0.2, -0.3, -0.4, -0.1])
H_ref = ref_glacier(temps, H₀)
# Train UDE
minA_out = 0.3
maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / ( 1 + exp(-x) )
UA = FastChain(
FastDense(1,3, x->tanh.(x)),
FastDense(3,10, x->tanh.(x)),
FastDense(10,3, x->tanh.(x)),
FastDense(3,1, sigmoid_A)
)
iceflow_trained = train_iceflow_UDE(H₀, UA, H_ref, temps)
And here’s the error I’m getting:
└ @ Enzyme.Compiler ~/.julia/packages/Enzyme/g5epq/src/compiler.jl:375
┌ Warning: ("reverse differentiating jl_invoke call without split mode", iceflow_NN!, MethodInstance for iceflow_NN!(::Matrix{Float32}, ::Matrix{Float32}, ::Function, ::Vector{Float32}, ::Float64, ::ArrayPartition{Float64, Tuple{Vector{Float64}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Vector{Int64}, Matrix{Float32}, Matrix{Float32}}}))
└ @ Enzyme.Compiler ~/.julia/packages/Enzyme/g5epq/src/compiler.jl:404
need %35 = call token (...) @llvm.julia.gc_preserve_begin({} addrspace(10)* nonnull %34), !dbg !215 via call void @llvm.julia.gc_preserve_end(token %35), !dbg !245
Assertion failed: (!inst->getType()->isTokenTy()), function is_value_needed_in_reverse, file /workspace/srcdir/Enzyme/enzyme/Enzyme/DifferentialUseAnalysis.h, line 379.
signal (6): Abort trap: 6
in expression starting at /Users/Bolib001/Desktop/Jordi/Julia/odinn_toy_model/scripts/examples/MWE_iceflow.jl:214
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 382360361 (Pool: 382273315; Big: 87046); GC: 196
Some progress, but still not there yet. Any idea on what’s wrong now regarding the “split mode”? Thanks!
I’m going to need to call in the Enzyme devs here since this is something I would now expect to work. Seems like there’s a dynamic call it cannot handle somewhere?
OK, so as discussed, I’m going back to a Zygote implementation of this, with the goal of making it as optimized as possible while waiting for Enzyme to work.
Unlike my previous manual implementation using pullback, which used to work, when I use BacksolveAdjoint(autojacvec=ZygoteVJP()) I’m getting an error that I cannot debug. In order to avoid mutation issues, I’m passing a tuple of model parameters and I’m returning a copy of dH.
Here’s a new MWE with the updated implementation:
using Statistics
using LinearAlgebra
using Random
using OrdinaryDiffEq
using DiffEqFlux
using Flux
using Tullio
using RecursiveArrayTools
using Infiltrator
const t₁ = 10 # number of simulation years
const ρ = 900f0 # Ice density [kg / m^3]
const g = 9.81f0 # Gravitational acceleration [m / s^2]
const n = 3f0 # Glen's flow law exponent
const maxA = 8f-16
const minA = 3f-17
const maxT = 1f0
const minT = -25f0
A = 1.3f-24 #2e-16 1 / Pa^3 s
A *= Float32(60 * 60 * 24 * 365.25) # [1 / Pa^3 yr]
C = 0
α = 0
@views avg(A) = 0.25 .* ( A[1:end-1,1:end-1] .+ A[2:end,1:end-1] .+ A[1:end-1,2:end] .+ A[2:end,2:end] )
@views avg_x(A) = 0.5 .* ( A[1:end-1,:] .+ A[2:end,:] )
@views avg_y(A) = 0.5 .* ( A[:,1:end-1] .+ A[:,2:end] )
@views diff_x(A) = (A[begin + 1:end, :] .- A[1:end - 1, :])
@views diff_y(A) = (A[:, begin + 1:end] .- A[:, 1:end - 1])
@views inn(A) = A[2:end-1,2:end-1]
function ref_glacier(temps, H₀)
H = deepcopy(H₀)
# Initialize all matrices for the solver
S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)
# Gather simulation parameters
current_year = Int(0)
context = ArrayPartition([A], B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, [current_year])
# Perform reference simulation with forward model
println("Running forward PDE ice flow model...\n")
iceflow_prob = ODEProblem(iceflow!,H,(0.0,t₁),context)
iceflow_sol = solve(iceflow_prob, BS3(), progress=true, saveat=1.0, progress_steps = 1)
return Float32.(iceflow_sol[end])
end
function iceflow!(dH, H, context,t)
# Unpack parameters
#A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year
current_year = Ref(context.x[18])
A = Ref(context.x[1])
# Get current year for MB and ELA
year = floor(Int, t) + 1
if year != current_year[] && year <= t₁
temp = Ref{Float32}(context.x[7][year])
A[] .= A_fake(temp[])
current_year[] .= year
end
# Compute the Shallow Ice Approximation in a staggered grid
SIA!(dH, H, context)
end
function train_iceflow_UDE(H₀, UA, H_ref, temps)
# Gather simulation parameters
H = deepcopy(H₀)
# Initialize all matrices for the solver
S, dSdx, dSdy = zeros(Float32,nx,ny),zeros(Float32,nx-1,ny),zeros(Float32,nx,ny-1)
dSdx_edges, dSdy_edges, ∇S = zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1),zeros(Float32,nx-1,ny-1)
D, dH, Fx, Fy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-2,ny-2),zeros(Float32,nx-1,ny-2),zeros(Float32,nx-2,ny-1)
V, Vx, Vy = zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1),zeros(Float32,nx-1,ny-1)
# Gather simulation parameters
current_year = 0
θ = initial_params(UA)
context = (A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H)
loss(θ) = loss_iceflow(UA, θ, H, context) # closure
println("Training iceflow UDE...")
iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(0.01), maxiters = 10)
return iceflow_trained
end
function loss_iceflow(UA, θ, H, context)
H = predict_iceflow(UA, θ, H, context)
H_ref = context[19]
l_H = sqrt(Flux.Losses.mse(H[H .!= 0.0], H_ref[H.!= 0.0]; agg=sum))
println("Loss = ", l_H)
return l_H
end
function predict_iceflow(UA, θ, H, context)
iceflow_UDE!(dH, H, θ, t) = iceflow_NN!(dH, H, context, UA, θ, t) # closure
tspan = (0.0,t₁)
iceflow_prob = ODEProblem(iceflow_UDE!,H,tspan,θ)
H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, save_everystep=false,
sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP()),
progress=true, progress_steps = 1)
return H_pred[end]
end
function iceflow_NN!(dH, H, context, UA, θ, t)
# Unpack parameters
#A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
current_year = context[18]
A = context[1]
# Get current year for MB and ELA
year = floor(Int, t) + 1
if year != current_year && year <= t₁
temp = context[7][year]
YA = predict_A̅(UA, θ, [temp]) # FastChain prediction requires explicit parameters
# Unpack and repack tuple to update `A` and `current_year`
A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H = context
context = (YA, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, year, H_ref, H)
end
# Compute the Shallow Ice Approximation in a staggered grid
dH = SIA!(dH, H, context)
end
"""
SIA!(dH, H, context)
Compute a step of the Shallow Ice Approximation PDE in a forward model
"""
function SIA!(dH, H, context)
# Retrieve parameters
#A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
A = context.x[1]
B = context.x[2]
S = context.x[3]
dSdx = context.x[4]
dSdy = context.x[5]
D = context.x[6]
dSdx_edges = context.x[8]
dSdy_edges = context.x[9]
∇S = context.x[10]
Fx = context.x[11]
Fy = context.x[12]
# Update glacier surface altimetry
S .= B .+ H
# All grid variables computed in a staggered grid
# Compute surface gradients on edges
dSdx .= diff_x(S) / Δx
dSdy .= diff_y(S) / Δy
∇S .= (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2)
Γ = 2 * A * (ρ * g)^n / (n+2) # 1 / m^3 s
D .= Γ .* avg(H).^(n + 2) .* ∇S
# Compute flux components
dSdx_edges .= diff(S[:,2:end - 1], dims=1) / Δx
dSdy_edges .= diff(S[2:end - 1,:], dims=2) / Δy
Fx .= .-avg_y(D) .* dSdx_edges
Fy .= .-avg_x(D) .* dSdy_edges
# Flux divergence
inn(dH) .= .-(diff(Fx, dims=1) / Δx .+ diff(Fy, dims=2) / Δy) # MB to be added here
# Compute velocities
#Vx = -D./(avg(H) .+ ϵ).*avg_y(dSdx)
#Vy = -D./(avg(H) .+ ϵ).*avg_x(dSdy)
end
# Function without mutation for Zygote, with context as a tuple
function SIA!(dH, H, context::Tuple)
# Retrieve parameters
#A, B, S, dSdx, dSdy, D, norm_temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H, UA, θ
A = context[1]
B = context[2]
S = context[3]
dSdx = context[4]
dSdy = context[5]
D = context[6]
dSdx_edges = context[8]
dSdy_edges = context[9]
∇S = context[10]
Fx = context[11]
Fy = context[12]
# Update glacier surface altimetry
S = B .+ H
# All grid variables computed in a staggered grid
# Compute surface gradients on edges
dSdx = diff_x(S) / Δx
dSdy = diff_y(S) / Δy
∇S = (avg_y(dSdx).^2 .+ avg_x(dSdy).^2).^((n - 1)/2)
Γ = 2 * A * (ρ * g)^n / (n+2) # 1 / m^3 s
D = Γ .* avg(H).^(n + 2) .* ∇S
# Compute flux components
dSdx_edges = diff(S[:,2:end - 1], dims=1) / Δx
dSdy_edges = diff(S[2:end - 1,:], dims=2) / Δy
Fx = -avg_y(D) .* dSdx_edges
Fy = -avg_x(D) .* dSdy_edges
# Flux divergence
@tullio dH[i,j] := -(diff(Fx, dims=1)[pad(i-1,1,1),pad(j-1,1,1)] / Δx + diff(Fy, dims=2)[pad(i-1,1,1),pad(j-1,1,1)] / Δy) # MB to be added here
return dH
end
function A_fake(temp)
return @. minA + (maxA - minA) * ((temp-minT)/(maxT-minT) )^2
end
predict_A̅(UA, θ, temp) = UA(temp, θ)[1] .* 1e-16
#### Generate reference dataset ####
nx = ny = 100
const B = zeros(Float32, (nx, ny))
const σ = 1000
H₀ = Matrix{Float32}([ 250 * exp( - ( (i - nx/2)^2 + (j - ny/2)^2 ) / σ ) for i in 1:nx, j in 1:ny ])
Δx = Δy = 50 #m
const temps = Vector{Float32}([0.0, -0.5, -0.2, -0.1, -0.3, -0.1, -0.2, -0.3, -0.4, -0.1])
H_ref = ref_glacier(temps, H₀)
# Train UDE
minA_out = 0.3
maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / ( 1 + exp(-x) )
UA = FastChain(
FastDense(1,3, x->tanh.(x)),
FastDense(3,10, x->tanh.(x)),
FastDense(10,3, x->tanh.(x)),
FastDense(3,1, sigmoid_A)
)
iceflow_trained = train_iceflow_UDE(H₀, UA, H_ref, temps)
And this is the new error I’m having:
ERROR: LoadError: MethodError: no method matching vec(::Nothing)
Closest candidates are:
vec(::FillArrays.Ones{T, N, Axes} where {N, Axes}) where T at /Users/Bolib001/.julia/packages/FillArrays/Vzxer/src/fillalgebra.jl:3
vec(::Adjoint{var"#s832", var"#s8321"} where {var"#s832"<:Real, var"#s8321"<:(AbstractVector{T} where T)}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:240
vec(::Tuple{Vararg{MultivariatePolynomials.AbstractVariable, N} where N}) at /Users/Bolib001/.julia/packages/MultivariatePolynomials/vqcb5/src/operators.jl:351
...
Stacktrace:
[1] _vecjacobian!(dλ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, y::Matrix{Float32}, λ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, p::Vector{Float32}, t::Float64, S::DiffEqSensitivity.ODEBacksolveSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Matrix{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, UniformScaling{Bool}}, BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Matrix{Float32}, ODEProblem{Matrix{Float32}, Tuple{Float64, Float64}, true, Vector{Float32}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, 
Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, isautojacvec::ZygoteVJP, dgrad::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, dy::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, W::Nothing)
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/sLLcu/src/derivative_wrappers.jl:448
[2] #vecjacobian!#37
@ ~/.julia/packages/DiffEqSensitivity/sLLcu/src/derivative_wrappers.jl:224 [inlined]
[3] (::DiffEqSensitivity.ODEBacksolveSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Matrix{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, UniformScaling{Bool}}, BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP, Bool}, Matrix{Float32}, ODEProblem{Matrix{Float32}, Tuple{Float64, Float64}, true, Vector{Float32}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, var"#iceflow_UDE!#6"{FastChain{Tuple{FastDense{var"#14#17", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#15#18", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{var"#16#19", DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}, FastDense{typeof(sigmoid_A), DiffEqFlux.var"#initial_params#90"{Vector{Float32}}}}}, Tuple{Float32, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, Matrix{Float32}, 
Matrix{Float32}, Int64, Int64, Int64, Matrix{Float32}, Matrix{Float32}}}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}})(du::Vector{Float32}, u::Vector{Float32}, p::Vector{Float32}, t::Float64)
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/sLLcu/src/backsolve_adjoint.jl:36
...
...
...
Any idea where the error comes from? I find it super hard to debug errors when using the DiffEqFlux.sciml_train wrapper function. Thanks again!
The way I usually debug these things is to simplify the ODE by commenting things out until it works, and then progressively add back lines of code until the error is hit. My guess is that the pullback is calculating a nothing on one of the terms (i.e. no gradient) that then is being used for a gradient calculation, which shouldn’t happen and could be an issue in one of the adjoint definitions. If you help isolate it I can take a deeper look.
I have progressively commented out everything inside the ODE and the error persists. Could the issue come from the way I’m using a closure to pass additional parameters to the PDE function? Here’s what I’m doing:
function predict_iceflow(UA, θ, H, context)
    iceflow_UDE!(dH, H, θ, t) = iceflow_NN!(dH, H, context, UA, θ, t) # closure
    tspan = (0.0, t₁)
    iceflow_prob = ODEProblem(iceflow_UDE!, H, tspan, θ)
    H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, save_everystep=false,
                   sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP()),
                   progress=true, progress_steps = 1)
    return H_pred[end]
end
This appears to work in the forward pass, but crashes during the pullback with the same error as above. Is that closure correct?
If context and UA are constants (not dependent on the optimized values) then it should be fine.
context is a tuple, containing the prediction made by UA with the optimized values (the tuple is updated by unpacking and repacking the values):
year = floor(Int, t) + 1
if year != current_year && year <= t₁
    temp = context[7][year]
    YA = predict_A̅(UA, θ, [temp]) # FastChain prediction requires explicit parameters
    # Unpack and repack tuple to update `A` and `current_year`
    A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H = context
    context = (YA, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, year, H_ref, H)
end
In fact I’m also using another closure for the loss function (for the same reason):
θ = initial_params(UA)
context = (A, B, S, dSdx, dSdy, D, temps, dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, current_year, H_ref, H)
loss(θ) = loss_iceflow(UA, θ, H, context) # closure
iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(0.01), maxiters = 10)
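As a side note on the unpack-and-repack step above: one possible alternative is a NamedTuple updated with Base's merge, so that fields are addressed by name rather than by position. This is only a sketch. The three-field context and the helper update_context are hypothetical stand-ins for the 20-element tuple in the thread, and whether this interacts well with Zygote inside the adjoint is not checked here.

```julia
# Hypothetical, shortened context: only three of the fields used in the thread.
context = (A = 1.0f0, current_year = 0, temps = Float32[0.0, -0.5, -0.2])

# Build a new NamedTuple with `A` and `current_year` replaced.
# All other fields are carried over unchanged (arrays are not copied).
update_context(ctx, A, year) = merge(ctx, (A = A, current_year = year))

new_context = update_context(context, 2.5f0, 1)
```

The original context is left untouched, since NamedTuples are immutable, and the fields that were not named in the update keep pointing at the same arrays.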
Yeah, so if those are parameters of the differential equation then they have to come from p, otherwise they are constant. This is the piece about @view from a vector etc.
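To make the "constant unless it comes from p" point concrete, here is a deliberately tiny sketch that uses a finite-difference derivative instead of Zygote. Everything in it is hypothetical: predict_A stands in for the network, and the two loss functions are toy scalars, not the iceflow loss.

```julia
# Hypothetical stand-in for the network: A depends on the parameter θ.
predict_A(θ) = θ^2

# Variant 1: A is computed once outside and captured by the closure,
# so from the point of view of the loss it is a constant.
A_fixed = predict_A(1.0)
loss_captured(θ) = (3.0 * A_fixed - 6.0)^2

# Variant 2: A is recomputed from θ inside the loss.
loss_from_p(θ) = (3.0 * predict_A(θ) - 6.0)^2

# Central finite-difference approximation of df/dθ.
fd(f, θ; h = 1e-6) = (f(θ + h) - f(θ - h)) / (2h)

g_captured = fd(loss_captured, 1.0)  # no dependence on θ
g_from_p   = fd(loss_from_p, 1.0)    # analytic value is 2 * (3 - 6) * 6 = -36
```

In the first variant the derivative with respect to θ vanishes, which is the scalar analogue of an adjoint seeing no gradient for a captured value.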
I don’t know if I shared this with you before, but ComponentArrays.jl is a very good way to set this kind of thing up.
https://jonniedie.github.io/ComponentArrays.jl/dev/examples/DiffEqFlux/
Thanks for the tip. Indeed, a ComponentArray makes things clearer than using a tuple.
I have managed to fix that problem, and Zygote now seems to be working, but the solver is returning "Warning: dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable" after computing the loss. I have tried several of the solvers that you suggested; BS3() still seems to be the best, but it only seems to work correctly for the forward run of the reference dataset.
Is it normal that the plain forward run that generates the reference dataset is super fast and smooth, while the UDE shows this behaviour? The strange thing is that the NN weights don’t seem to get updated at each iteration (maybe due to the abort?). Moreover, when I get that warning dH goes to NaN, but I’m pretty sure the PDE model is correct. I have also constrained the output values of the NN in order to give physically stable values, and I have set a low learning rate to avoid unstable results. What am I missing? I can provide an updated MWE if you want.
Is the UDE that you defined stable? It’s easy to accidentally make a U(u) that is positive over your whole solution interval, and so if you then solve u' = f(u) + U(u), suddenly the ODE is no longer stable with the random initial weights of the neural network. This is especially true for long time spans, which is why we do not suggest using direct single shooting for long time spans and stiff models. Did you follow the docs on the training strategies for these cases?
https://diffeqflux.sciml.ai/dev/examples/multiple_shooting/
https://diffeqflux.sciml.ai/dev/examples/local_minima/
The latter technique is usually required if you aren’t doing multiple shooting since otherwise you can get cases where the first complete solve is unstable with the first weights, in which case the optimization would not be able to start.
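The stability concern above can be illustrated with a toy scalar problem. This is only a sketch under loud assumptions: f, U_bad and the forward-Euler helper euler are all hypothetical, with U_bad standing in for an untrained network whose output stays positive along the trajectory.

```julia
# Forward-Euler integration of u' = rhs(u) from u0 over [0, T].
function euler(rhs, u0, T; dt = 1e-3)
    u = u0
    for _ in 1:round(Int, T / dt)
        u += dt * rhs(u)
    end
    return u
end

f(u) = -u           # known, stable part of the model (hypothetical)
U_bad(u) = 2.0 * u  # stand-in for an untrained network; positive while u > 0

u_known = euler(f, 1.0, 10.0)                      # decays towards zero
u_ude   = euler(u -> f(u) + U_bad(u), 1.0, 10.0)   # u' = +u, grows exponentially
```

Here the added term flips the sign of the right-hand side, so the solution grows by several orders of magnitude over the same time span on which the known model decays.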
Thanks for the links, so much information in the DifferentialEquations.jl and DiffEqFlux.jl docs! I will read everything and try this.
However, I’m fairly sure the problem does not come from unstable initial values. I have designed the NN to only produce predictions within a given range of physically stable values (sigmoid function at the output layer with max and min values):
minA_out = 0.3
maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / ( 1 + exp(-x) )
UA = FastChain(
    FastDense(1,3, x->tanh.(x)),
    FastDense(3,10, x->tanh.(x)),
    FastDense(10,3, x->tanh.(x)),
    FastDense(3,1, sigmoid_A)
)
I have checked the values generated by the NN and they are always physically realistic, so they should not produce instabilities in the ODE. Moreover, the input timeseries I’m using are mostly the same value repeated with a tiny noise, so it’s a very easy initial problem (just to debug it, for now), quite different from the highly varying time series from the multiple shooting example you sent.
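A quick way to double-check the claimed output range is to evaluate sigmoid_A on a wide grid of pre-activations. This sketch reuses the definitions from the post; the grid xs and the tolerance are arbitrary choices made for illustration. It only confirms that A stays within bounds, not that the resulting ODE solve is stable.

```julia
minA_out = 0.3
maxA_out = 8
sigmoid_A(x) = minA_out + (maxA_out - minA_out) / (1 + exp(-x))

# Evaluate on a wide, arbitrary grid of pre-activation values.
xs = range(-50, 50; length = 1001)
ys = sigmoid_A.(xs)

lo, hi = extrema(ys)
# Small tolerance for floating-point rounding at the saturated ends.
in_bounds = (lo >= minA_out - 1e-12) && (hi <= maxA_out + 1e-12)
```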