Adjoint of 2d diffusion using SciMLSensitivity

I am trying to solve the adjoint problem of 2d diffusion using SciMLSensitivity. My test case follows the 1d case given at Partial Differential Equation (PDE) Constrained Optimization · SciMLSensitivity.jl.
However, my case does not work and Optimization.solve gives a very long (and, to me, quite cryptic) stacktrace. Any idea what is causing the issue? Thanks!
Here is the code:

using DelimitedFiles, Plots
using DifferentialEquations, Optimization, OptimizationPolyalgorithms, Zygote
using SciMLSensitivity
# Problem setup parameters:
function meshgrid(x, y)
     X = [i for i in x, j in 1:length(y)]
     Y = [j for i in 1:length(x), j in y]
     X = transpose(X);
     Y = transpose(Y);
     return X, Y
end
deltax = 2.5;
Lx = 500.0;
Ly = 500.0;
x = -Lx:deltax:Lx;
y = -Ly:deltax:Ly;
n = size(x)[1];
m = size(y)[1];
X,Y = meshgrid(x,y);
radius = 50; 
R = sqrt.(X.^2 + Y.^2);
h00 = zeros(m,n);
h00[R.<radius] .= 1;
## Problem Parameters
p = [0.5]    # True solution parameters
xtrs = [deltax, m,n]      # Extra parameters
dt = 0.40 * 1/ p[1]^2 * deltax^2    # CFL condition
t0, tMax = 0.0, 100 * dt
tspan = (t0, tMax)
t = t0:dt:tMax;

function heat(dhdt,h, p, t)
    # Model parameters
    m = size(h)[1];
    n = size(h)[2];
    deltax =  2.5;
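    # 5-point Laplacian (second-order central differences) on the interior; boundary entries of dhdt are not updated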
    for j = 3:n-2
        for i = 3 : m-2 
            dhdt[i,j] =  p[1] *  ((h[i,j+1] + h[i,j-1] -2 * h[i,j]) / deltax^2 + (h[i+1,j] +  h[i-1,j] -2 * h[i,j]) / deltax^2);
        end 
    end 
    return dhdt
end
# Testing Solver on linear PDE
prob = ODEProblem(heat, h00, tspan, p)
sol = solve(prob, Heun(), dt = dt, saveat = t); 

ps = [0.1];   # Initial guess for model parameters
function predict(θ)
    Array(solve(prob, Heun(), p = θ, dt = dt, saveat = t)) 
end
## Defining Loss function
function loss(θ)
    pred = predict(θ)
    l = pred - sol
    return sum(abs2, l), pred # Sum of squared errors
end
l, pred = loss(ps)
size(pred), size(sol), size(t) # Checking sizes

LOSS = []                              # Loss accumulator
PRED = []                              # prediction accumulator
PARS = []                              # parameters accumulator

callback = function (θ, l, pred) #callback function to observe training
    display(l)
    append!(PRED, [pred])
    append!(LOSS, l)
    append!(PARS, [θ])
    false
end

callback(ps, loss(ps)...) # Testing callback function

# Let see prediction vs. Truth
scatter(sol[Int64(floor(m / 2 )),:, end], label = "Truth", size = (800, 500))
plot!(PRED[end][Int64(floor(m / 2)),:, end], lw = 2, label = "Prediction")

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ps) 
res = Optimization.solve(optprob, PolyOpt(), callback = callback)

So it looks like Optimization.OptimizationProblem and Optimization.solve do not support functions with the in-place signature

function RHS(dudt, u, p, t)

but only functions with the out-of-place signature

function RHS(u, p, t)

This is probably due to Zygote, which does not support array mutation (Limitations · Zygote).
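As a quick sanity check, independent of the PDE: Zygote refuses to differentiate any function that writes into an array. A toy illustration (not part of my problem):

using Zygote
f!(y, x) = (y[1] = 2 * x[1]; y[1])
# Zygote.gradient(x -> f!(zeros(1), x), [1.0])
# throws: ERROR: Mutating arrays is not supported ...
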
Indeed, if the code above is rewritten as follows, it works:

using DelimitedFiles, Plots
using DifferentialEquations, Optimization, OptimizationPolyalgorithms, Zygote
using SciMLSensitivity
# Problem setup parameters:
function meshgrid(x, y)
     X = [i for i in x, j in 1:length(y)]
     Y = [j for i in 1:length(x), j in y]
     X = transpose(X);
     Y = transpose(Y);
     return X, Y
end
deltax = 2.5;
Lx = 500.0;
Ly = 500.0;
x = -Lx:deltax:Lx;
y = -Ly:deltax:Ly;
n = size(x)[1];
m = size(y)[1];
X,Y = meshgrid(x,y);
radius = 50; 
R = sqrt.(X.^2 + Y.^2);
h00 = zeros(m,n);
h00[R.<radius] .= 1;
## Problem Parameters
p = [0.5]    # True solution parameters
xtrs = [deltax, m,n]      # Extra parameters
dt = 0.40 * 1/ p[1]^2 * deltax^2    # CFL condition
t0, tMax = 0.0, 100 * dt
tspan = (t0, tMax)
t = t0:dt:tMax;

function heatVect(h, p, t)
    # Model parameters
    m = size(h)[1];
    n = size(h)[2];
    deltax =  2.5;
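    # Vectorized 5-point Laplacian on the interior, padded back to m×n with zero rows/columns below; no mutation, so Zygote can differentiate it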
    dhdt =  p[1] *  ((h[3:m-2,4:n-1] + h[3:m-2,2:n-3] -2 * h[3:m-2,3:n-2]) / deltax^2 + (h[4:m-1,3:n-2] +  h[2:m-3,3:n-2] -2 * h[3:m-2,3:n-2]) / deltax^2);
    dhdt = [zeros(m-4,2) dhdt];
    dhdt = [dhdt zeros(m-4,2)];
    dhdt = [zeros(2,n);dhdt];
    dhdt = [dhdt;zeros(2,n)];
    return dhdt
end
# Testing Solver on linear PDE
prob = ODEProblem(heatVect, h00, tspan, p)
sol = solve(prob, Heun(), dt = dt, saveat = t); 

ps = [0.1];   # Initial guess for model parameters
function predict(θ)
    Array(solve(prob, Heun(), p = θ, dt = dt, saveat = t)) 
end
## Defining Loss function
function loss(θ)
    pred = predict(θ)
    l = pred - sol
    return sum(abs2, l), pred # Sum of squared errors
end
l, pred = loss(ps)
size(pred), size(sol), size(t) # Checking sizes

LOSS = []                              # Loss accumulator
PRED = []                              # prediction accumulator
PARS = []                              # parameters accumulator

callback = function (θ, l, pred) #callback function to observe training
    display(l)
    append!(PRED, [pred])
    append!(LOSS, l)
    append!(PARS, [θ])
    false
end

callback(ps, loss(ps)...) # Testing callback function

# Let see prediction vs. Truth
scatter(sol[Int64(floor(m / 2 )),:, end], label = "Truth", size = (800, 500))
plot!(PRED[end][Int64(floor(m / 2)),:, end], lw = 2, label = "Prediction")

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ps)  
res = Optimization.solve(optprob, PolyOpt(), callback = callback)

Now the question is that, of course, in this way one cannot pre-allocate arrays, nor compute the finite difference stencils with loops (which in my case speeds up the computation of the forward model by quite a lot). Does anybody know how one could keep the performance obtained with pre-allocation and looping while still using Zygote? For example, would something like Zygote.Buffer (sketched below) work here, and would it keep the performance?
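A minimal sketch of what I have in mind (heatBuffer is just an illustrative name, and I have not checked whether it is actually fast):

function heatBuffer(h, p, t)
    m, n = size(h)
    deltax = 2.5
    dhdt = Zygote.Buffer(h)   # mutable buffer that Zygote can differentiate through
    for j = 1:n, i = 1:m
        dhdt[i, j] = 0.0      # zero everything, including the boundary
    end
    for j = 3:n-2, i = 3:m-2
        dhdt[i, j] = p[1] * ((h[i, j+1] + h[i, j-1] - 2 * h[i, j]) / deltax^2 +
                             (h[i+1, j] + h[i-1, j] - 2 * h[i, j]) / deltax^2)
    end
    return copy(dhdt)         # copy turns the Buffer back into an Array
end
# then prob = ODEProblem(heatBuffer, h00, tspan, p) as before
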
Thanks!

Can you please post the stack trace?

Internally it swaps to Enzyme, so if you’re using any recent version the in-place version is supported. I would need to see the stack trace to know what’s going on.
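For example (an illustrative sketch, not necessarily the defaults that get picked automatically; prob_ip, predict_ip and loss_ip are made-up names), the in-place heat from the first snippet can be paired with an explicit sensealg whose vector-Jacobian products are computed by Enzyme:

prob_ip = ODEProblem(heat, h00, tspan, p)   # mutating, looped RHS
function predict_ip(θ)
    Array(solve(prob_ip, Heun(), p = θ, dt = dt, saveat = t,
                sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP())))
end
function loss_ip(θ)
    pred = predict_ip(θ)
    return sum(abs2, pred - Array(sol)), pred
end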

Sure, I attached it as a .jl file because it was too long to be posted…
stacktrace.jl (144.9 KB)