I am trying to solve the adjoint problem of 2-D diffusion using SciMLSensitivity. My test case follows the 1-D example in the SciMLSensitivity.jl tutorial "Partial Differential Equation (PDE) Constrained Optimization".
However, my case does not work: `Optimization.solve` throws a very long stack trace that I find quite cryptic. Any idea what is causing the issue? Thanks!
Here is the code:
using DelimitedFiles, Plots
using DifferentialEquations, Optimization, OptimizationPolyalgorithms, Zygote
using SciMLSensitivity
# Problem setup parameters:
"""
    meshgrid(x, y)

Build 2-D coordinate matrices from the 1-D coordinate vectors `x` and `y`.

Both returned matrices have size `(length(y), length(x))`: rows follow `y`, columns
follow `x`, so `X[j, i] == x[i]` and `Y[j, i] == y[j]`. They are returned as lazy
`transpose` wrappers of the underlying `length(x) × length(y)` matrices.
"""
function meshgrid(x, y)
    nx = length(x)
    ny = length(y)
    # Both are nx × ny: Xcols varies along the first index, Ycols along the second.
    Xcols = [xv for xv in x, _ in 1:ny]
    Ycols = [yv for _ in 1:nx, yv in y]
    return transpose(Xcols), transpose(Ycols)
end
# Uniform square grid on [-Lx, Lx] × [-Ly, Ly] with spacing deltax.
deltax = 2.5;
Lx = 500.0;
Ly = 500.0;
x = -Lx:deltax:Lx;
y = -Ly:deltax:Ly;
n = length(x);   # number of grid points along x (columns of the state matrix)
m = length(y);   # number of grid points along y (rows of the state matrix)
X, Y = meshgrid(x, y);
# Initial condition: h = 1 inside a disk of radius `radius` centred at the origin, 0 elsewhere.
radius = 50;
R = sqrt.(X .^ 2 .+ Y .^ 2);
h00 = zeros(m, n);
h00[R .< radius] .= 1;
## Problem Parameters
p = [0.5] # True diffusion coefficient D, used to generate the reference solution
xtrs = [deltax, m, n] # Extra parameters (not read by `heat`, which hard-codes its spacing)
# Step size from the explicit-scheme stability limit for 2-D diffusion,
# D * dt / deltax^2 <= 1/4 (here 0.20 leaves a safety margin). Note the limit scales
# with deltax^2 / D, not deltax^2 / D^2.
dt = 0.20 * deltax^2 / p[1]
t0, tMax = 0.0, 100 * dt
tspan = (t0, tMax)
t = t0:dt:tMax;   # save times shared by the reference solve and every prediction
"""
    heat(dhdt, h, p, t)

In-place right-hand side of the 2-D linear diffusion equation

    ∂h/∂t = p[1] * (∂²h/∂x² + ∂²h/∂y²)

discretised with second-order central differences on a uniform grid of spacing
`deltax = 2.5`. The spacing is hard-coded here and must match the grid spacing used to
build `h`.

Only the interior points `3:m-2` × `3:n-2` evolve. The time derivative of the two
outermost rings of cells is explicitly set to zero, so those cells keep their initial
values. Every entry of `dhdt` is written on each call: solvers and adjoint VJP
routines may pass in a reused or uninitialised buffer, so untouched entries are not
guaranteed to be zero.

`t` is unused. Returns `dhdt`.
"""
function heat(dhdt, h, p, t)
    m = size(h, 1)
    n = size(h, 2)
    deltax = 2.5
    # Zero everything first so boundary entries never carry stale/uninitialised values.
    fill!(dhdt, zero(eltype(dhdt)))
    # Column-major order: the inner loop runs over the first (row) index.
    for j in 3:(n - 2)
        for i in 3:(m - 2)
            dhdt[i, j] = p[1] * ((h[i, j + 1] + h[i, j - 1] - 2 * h[i, j]) / deltax^2 +
                                 (h[i + 1, j] + h[i - 1, j] - 2 * h[i, j]) / deltax^2)
        end
    end
    return dhdt
end
# Build the forward problem from the true parameters and solve it once to obtain the
# reference ("truth") trajectory, saved at the times in `t`.
prob = ODEProblem(heat, h00, tspan, p)
sol = solve(prob, Heun(); dt = dt, saveat = t);
# Deliberately wrong starting value for the diffusion coefficient to be recovered.
ps = [0.1];
"""
    predict(θ)

Re-solve `prob` with parameters `θ`, using the same integrator (`Heun()`), initial step
`dt` and save times `t` as the reference solve. Return the saved states converted to a
plain `Array`.
"""
function predict(θ)
    trajectory = solve(prob, Heun(); p = θ, dt = dt, saveat = t)
    return Array(trajectory)
end
## Defining Loss function
"""
    loss(θ)

Return `(sse, pred)`. `pred = predict(θ)` is the trajectory predicted with parameters
`θ`. `sse` is the sum (not the mean) of squared differences between `pred` and the
reference solution `sol`. The prediction is returned as a second value so the optimizer
callback receives it as an extra argument.
"""
function loss(θ)
    # Solve the forward problem only once: a second `predict` call would double both the
    # forward solve and the adjoint work during differentiation.
    pred = predict(θ)
    # Convert the reference solution to a plain array so both operands are ordinary
    # arrays of the same shape.
    residual = pred - Array(sol)
    return sum(abs2, residual), pred
end
# Evaluate the loss once at the initial guess, and check that the prediction has the same
# shape as the reference solution before starting the optimization. (A bare
# `size(...)` expression only displays in the REPL and checks nothing in a script.)
l, pred = loss(ps)
size(pred) == size(sol) || throw(DimensionMismatch(
    "prediction size $(size(pred)) differs from reference size $(size(sol))"))
# Histories filled by `callback`; left as Vector{Any} so they accept whatever the
# optimizer hands to the callback.
LOSS = [] # Loss accumulator
PRED = [] # prediction accumulator
PARS = [] # parameters accumulator
# Optimization callback: displays the current loss and records the loss, prediction and
# parameters of each iteration in the global histories LOSS, PRED and PARS.
# Returning `false` lets the optimization continue.
callback = function (θ, l, pred)
    display(l)
    push!(PRED, pred)
    push!(LOSS, l)
    # Store a copy: some optimizers may update their parameter vector in place, which
    # would otherwise make every recorded entry alias the same (final) vector.
    push!(PARS, copy(θ))
    return false
end
# Exercise the callback once with the initial guess (also seeds the histories).
callback(ps, loss(ps)...)
# Compare prediction vs. truth along the middle row at the final saved time.
scatter(sol[m ÷ 2, :, end]; label = "Truth", size = (800, 500))
plot!(PRED[end][m ÷ 2, :, end]; lw = 2, label = "Prediction")
# Fit the parameters by reverse-mode AD (Zygote) through the ODE solve.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, ps)
res = Optimization.solve(optprob, PolyOpt(); callback = callback)