The answer is that Zygote.Buffer silently returns zero gradients. It happens in this piece of code:
I fixed @JordiBolibar's code by making it fully non-mutating.
## Environment and packages
using Statistics
using LinearAlgebra
using Random
using OrdinaryDiffEq
using DiffEqFlux
using Flux
using Tullio
using RecursiveArrayTools
using Infiltrator
using Plots
# Simulation length
const t₁ = 5 # number of simulation years

# Physical constants, all Float32 literals
const ρ = 900f0  # Ice density [kg / m^3]
const g = 9.81f0 # Gravitational acceleration [m / s^2]
const n = 3f0    # Glen's flow law exponent

# Extreme values; presumably the admissible ranges of `A` and of temperature — TODO confirm
const minA = 3f-17
const maxA = 8f-16
const minT = -25f0
const maxT = 1f0

# 1.3f-24 is given per second [1 / Pa^3 s] (the author also noted 2e-16 as an alternative);
# multiplying by the number of seconds in a 365.25-day year converts it to [1 / Pa^3 yr].
A = 1.3f-24 * Float32(60 * 60 * 24 * 365.25)

# Plain (non-const) integer globals, both zero; they are stored in the solver context.
C = 0
α = 0
# Staggered-grid averaging helpers. Each returns a newly allocated array and leaves `A`
# untouched; `@views` keeps the slices from being copied before the fused broadcast.
# The sums are divided by an integer literal rather than multiplied by a Float64 literal,
# so Float32 input stays Float32 (integer input still yields Float64, as before).

# Mean of every 2×2 neighbourhood; the result is one smaller than `A` in both dimensions.
@views avg(A) =
    (A[1:end-1, 1:end-1] .+ A[2:end, 1:end-1] .+ A[1:end-1, 2:end] .+ A[2:end, 2:end]) ./ 4

# Mean of consecutive rows (first dimension); the result has one fewer row than `A`.
@views avg_x(A) = (A[1:end-1, :] .+ A[2:end, :]) ./ 2

# Mean of consecutive columns (second dimension); the result has one fewer column than `A`.
@views avg_y(A) = (A[:, 1:end-1] .+ A[:, 2:end]) ./ 2
# Difference between consecutive rows, `A[i+1, j] - A[i, j]`; one fewer row than `A`.
@views function diff_x(A)
    next_rows = A[begin + 1:end, :]
    prev_rows = A[1:end - 1, :]
    return next_rows .- prev_rows
end

# Difference between consecutive columns, `A[i, j+1] - A[i, j]`; one fewer column than `A`.
@views function diff_y(A)
    next_cols = A[:, begin + 1:end]
    prev_cols = A[:, 1:end - 1]
    return next_cols .- prev_cols
end

# Interior of `A` with the outermost row and column on every side dropped.
# Because of `@views` this is a view into `A`, not a copy.
@views function inn(A)
    return A[2:end-1, 2:end-1]
end
# Run the forward ice-flow model (`iceflow!`) once per temperature series and return the
# ensemble solution converted element-wise to Float32.
#
# Arguments
#   temp_series: indexable collection; element `i` is the series used for trajectory `i`
#   H₀:          initial state; it is deep-copied, so the caller's array is not modified
#   ensemble:    ensemble algorithm passed to `solve` (defaults to the global `ensemble`)
#
# Reads the globals `nx`, `ny`, `A`, `B`, `C`, `α` and `t₁`.
function generate_ref_dataset(temp_series, H₀, ensemble=ensemble)
    # Compute reference dataset in parallel
    H = deepcopy(H₀)

    # Initialize all matrices for the solver (zero-filled Float32 work arrays)
    S = zeros(Float32, nx, ny)
    dSdx = zeros(Float32, nx - 1, ny)
    dSdy = zeros(Float32, nx, ny - 1)
    dSdx_edges = zeros(Float32, nx - 1, ny - 2)
    dSdy_edges = zeros(Float32, nx - 2, ny - 1)
    ∇S = zeros(Float32, nx - 1, ny - 1)
    D = zeros(Float32, nx - 1, ny - 1)
    Fx = zeros(Float32, nx - 1, ny - 2)
    Fy = zeros(Float32, nx - 2, ny - 1)
    V = zeros(Float32, nx - 1, ny - 1)
    Vx = zeros(Float32, nx - 1, ny - 1)
    Vy = zeros(Float32, nx - 1, ny - 1)

    # Gather simulation parameters. The slot order matters: slot 7 (the temperature
    # series) is addressed by position below.
    current_year = 0
    context = ArrayPartition([A], B, S, dSdx, dSdy, D, copy(temp_series[1]), dSdx_edges, dSdy_edges, ∇S, Fx, Fy, Vx, Vy, V, C, α, [current_year])

    # Build the problem for trajectory `i` by writing its temperature series into slot 7.
    # NOTE(review): every trajectory receives the same `context` object, so this relies on
    # trajectories not sharing memory while they run — confirm for the chosen `ensemble`.
    function prob_iceflow_func(prob, i, repeat, context, temp_series) # closure
        println("Processing temp series #$i ≈ ", mean(temp_series[i]))
        context.x[7] .= temp_series[i] # We set the temp_series for the ith trajectory
        return remake(prob, p=context)
    end

    prob_func(prob, i, repeat) = prob_iceflow_func(prob, i, repeat, context, temp_series) # closure

    # Perform reference simulation with forward model
    println("Running forward PDE ice flow model...\n")
    iceflow_prob = ODEProblem(iceflow!, H, (0.0, t₁), context)
    ensemble_prob = EnsembleProblem(iceflow_prob, prob_func = prob_func)
    iceflow_sol = solve(ensemble_prob, BS3(), ensemble, trajectories = length(temp_series),
                        pmap_batch_size=length(temp_series), reltol=1e-6,
                        progress=true, saveat=1.0, progress_steps = 50)

    return Float32.(iceflow_sol)
end
# Train the network `UA` so that the predicted ice state matches the reference data.
#
# Arguments
#   H₀:          initial state; deep-copied before use
#   UA:          model whose flat parameter vector is obtained with `initial_params`
#   H_refs:      reference solutions, forwarded unchanged to `loss_iceflow`
#   temp_series: temperature series, forwarded to the loss through `context`
#
# Returns whatever `DiffEqFlux.sciml_train` returns after 10 RMSProp iterations.
# Reads the globals `B` and `callback`.
function train_iceflow_UDE(H₀, UA, H_refs, temp_series)
    H = deepcopy(H₀)
    current_year = 0f0
    θ = initial_params(UA)

    # Positional tuple consumed downstream: (B, H, current_year, temp_series)
    context = (B, H, current_year, temp_series)
    loss(θ) = loss_iceflow(θ, context, UA, H_refs) # closure

    # Sanity check: the loss has to react to a change in the parameters. The perturbation
    # is drawn in θ's own element type so the perturbed vector is not promoted to Float64.
    x = loss(θ)
    y = loss(θ + rand(eltype(θ), length(θ)))
    @show x,y,x==y
    # @infiltrate

    println("Training iceflow UDE...")
    iceflow_trained = DiffEqFlux.sciml_train(loss, θ, RMSProp(0.01f0), cb=callback, maxiters = 10)
    return iceflow_trained
end
# Training callback: print the current loss `l`; the parameters `θ` are ignored.
# It returns `false` explicitly because `sciml_train` expects a Bool halt flag from the
# callback (`true` would stop the optimisation); without it the value of `@show l`,
# i.e. the loss itself, would be returned.
callback = function (θ, l) # callback function to observe training
    @show l
    return false
end
# Loss for the ice-flow UDE: mean squared error between the final predicted state and the
# final reference state, evaluated only where the prediction is nonzero.
#
# Arguments
#   θ:       flat parameter vector of `UA`
#   context: tuple `(B, H, current_year, temp_series)`
#   UA:      model passed to `predict_iceflow` and `predict_A̅`
#   H_refs:  reference solutions; only `H_refs[5][end]` is used
#
# Also prints the predicted and the reference `A` for the mean of temperature series 5.
# NOTE(review): series index 5 is hard-coded here, while `predict_iceflow` integrates
# with `temp_series[1]` — confirm that this mismatch is intended.
function loss_iceflow(θ, context, UA, H_refs)
    # Take the series from `context` rather than from a global of the same name.
    temp_series = context[4]

    H_preds = predict_iceflow(θ, UA, context)

    mean_temp = mean(temp_series[5])
    A_pred = predict_A̅(UA, θ, [mean_temp])
    A_ref = A_fake(mean_temp)
    println("Predicted A: ", A_pred)
    println("True A: ", A_ref)

    H = H_preds.u[end]
    H_ref = H_refs[5][end]

    # One mask, built from the prediction, selects the same cells in both arrays.
    nonzero = H .!= 0.0
    l_H_avg = Flux.Losses.mse(H[nonzero], H_ref[nonzero]; agg=mean)
    return l_H_avg
end
# Integrate the ice-flow UDE from `context[2]` over `(0.0, t₁)` with parameters `θ` and
# return the solution of the single (non-ensemble) problem; gradients use
# `QuadratureAdjoint(autojacvec=ZygoteVJP())`.
#
# Arguments
#   θ:        flat parameter vector of `UA`
#   UA:       model forwarded to `iceflow_NN`
#   context:  tuple `(B, H, current_year, temp_series)`
#   ensemble: currently unused, because the ensemble `solve` is commented out below
#
# NOTE(review): this file holds another `predict_iceflow` method with the same signature
# that uses a different `sensealg`; whichever definition is evaluated last is the one in
# effect.
function predict_iceflow(θ, UA, context, ensemble=ensemble)
    # (B, H, current_year, temp_series)
    H = context[2]
    # Bound locally so that `iceflow_UDE` below does not fall back on a global
    # `temp_series`.
    temp_series = context[4]
    tspan = (0.0, t₁)

    function prob_iceflow_func(prob, i, repeat, context, UA) # closure
        # Separate local name: assigning to `temp_series` here would rebind the variable
        # captured from the enclosing function.
        series = context[4]
        println("Processing temp series #$i ≈ ", mean(series[i]))
        # We add the ith temperature series
        iceflow_UDE_batch(H, θ, t) = iceflow_NN(H, θ, t, context, series[i], UA) # closure
        return remake(prob, f=iceflow_UDE_batch)
    end

    prob_func(prob, i, repeat) = prob_iceflow_func(prob, i, repeat, context, UA)

    iceflow_UDE(H, θ, t) = iceflow_NN(H, θ, t, context, temp_series[1], UA) # closure
    iceflow_prob = ODEProblem(iceflow_UDE, H, tspan, θ)
    ensemble_prob = EnsembleProblem(iceflow_prob, prob_func = prob_func)

    # H_pred = solve(ensemble_prob, BS3(), ensemble, trajectories = length(temp_series),
    #                 pmap_batch_size=length(temp_series), u0=H, p=θ, reltol=1e-6,
    #                 sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP()), save_everystep=false,
    #                 progress=true, progress_steps = 50)
    H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, reltol=1e-6,
                   sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP()), save_everystep=false,
                   progress=true, progress_steps = 50)

    return H_pred
end
# Integrate the ice-flow UDE from `context[2]` over `(0.0, t₁)` with parameters `θ` and
# return the solution of the single (non-ensemble) problem; gradients use
# `InterpolatingAdjoint(autojacvec=ZygoteVJP())`.
#
# Arguments
#   θ:        flat parameter vector of `UA`
#   UA:       model forwarded to `iceflow_NN`
#   context:  tuple `(B, H, current_year, temp_series)`
#   ensemble: currently unused, because the ensemble `solve` is commented out below
#
# NOTE(review): this file holds another `predict_iceflow` method with the same signature
# that uses a different `sensealg`; whichever definition is evaluated last is the one in
# effect.
function predict_iceflow(θ, UA, context, ensemble=ensemble)
    # (B, H, current_year, temp_series)
    H = context[2]
    # Bound locally so that `iceflow_UDE` below does not fall back on a global
    # `temp_series`.
    temp_series = context[4]
    tspan = (0.0, t₁)

    function prob_iceflow_func(prob, i, repeat, context, UA) # closure
        # Separate local name: assigning to `temp_series` here would rebind the variable
        # captured from the enclosing function.
        series = context[4]
        println("Processing temp series #$i ≈ ", mean(series[i]))
        # We add the ith temperature series
        iceflow_UDE_batch(H, θ, t) = iceflow_NN(H, θ, t, context, series[i], UA) # closure
        # Pass the closure defined on the previous line; no `iceflow_UDE_batch!` exists.
        return remake(prob, f=iceflow_UDE_batch)
    end

    prob_func(prob, i, repeat) = prob_iceflow_func(prob, i, repeat, context, UA)

    iceflow_UDE(H, θ, t) = iceflow_NN(H, θ, t, context, temp_series[1], UA) # closure
    iceflow_prob = ODEProblem(iceflow_UDE, H, tspan, θ)
    ensemble_prob = EnsembleProblem(iceflow_prob, prob_func = prob_func)

    # H_pred = solve(ensemble_prob, BS3(), ensemble, trajectories = length(temp_series),
    #                 pmap_batch_size=length(temp_series), u0=H, p=θ, reltol=1e-6,
    #                 sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP()), save_everystep=false,
    #                 progress=true, progress_steps = 50)
    H_pred = solve(iceflow_prob, BS3(), u0=H, p=θ, reltol=1e-6,
                   sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP()), save_everystep=false,
                   progress=true, progress_steps = 50)

    return H_pred
end
To catch this issue, I am adding an error on the DiffEq side, because apparently Zygote doesn’t realize it can’t handle this case.