I’m trying to debug why Zygote and ReverseDiff don’t work with MethodOfLines; however, I’m struggling to break inside the gradient calculation. Zygote and ReverseDiff change the types of some inputs, breakpoints set with Debugger.jl aren’t working, and adding `@infiltrate`
causes precompilation to fail. Simple example:
using DifferentialEquations, ModelingToolkit, MethodOfLines, DomainSets
using PDEBase: add_metadata!
using ModelingToolkit: get_metadata
using Zygote
using ReverseDiff
import AbstractDifferentiation as AD
using SciMLSensitivity
# Method of Manufactured Solutions: candidate exact solution exp(-t) * cos(x).
# Note the argument order is (x, t), the reverse of u(t, x) in the PDE below.
# This function satisfies Dt(u) ~ (α + β) * Dxx(u) only when α + β == 1; the
# parameter values used below (1.2 + 2.1) do not meet that, so it is not the
# exact solution for those parameters.
# Fully fused broadcast: works for scalars and for any broadcast-compatible arrays.
u_exact = (x, t) -> @. exp(-t) * cos(x)
# Symbolic independent variables (x, t) and PDE coefficients (α, β).
@parameters x t α β
# Unknown field, declared with open arguments so it can be applied as u(t, x).
@variables u(..)

# Differential operators: first derivative in t, second derivative in x.
Dt = Differential(t)
Dxx = Differential(x)^2

# 1D diffusion-type PDE with coefficient α + β.
eq = Dt(u(t, x)) ~ (α + β) * Dxx(u(t, x))

# Condition at t = 0, then fixed values at x = 0 and x = 1.
bcs = [
    u(0, x) ~ cos(x),
    u(t, 0) ~ exp(-t),
    u(t, 1) ~ exp(-t) * cos(1),
]

# Space-time domain: t in [0, 1], x in [0, 1].
domains = [
    t ∈ Interval(0.0, 1.0),
    x ∈ Interval(0.0, 1.0),
]

# Parameter values attached to the PDE system.
ps = [α => 1.2, β => 2.1]

# Assemble the PDE system: independent variables ordered (t, x), one unknown u(t, x).
pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)], ps; name = :pdesys)
# Method of lines discretization: uniform spacing `dx` in x; t is left continuous,
# so discretizing yields an ODE system in time.
dx = 0.1
# Spatial finite-difference approximation order. Passed explicitly so that editing
# it actually changes the discretization (2 matches the MOLFiniteDifference default).
order = 2
discretization = MOLFiniteDifference([x => dx], t; approx_order = order)

# Convert the PDE problem into an ODE problem.
# symbolic_discretize returns the semi-discrete system together with its time span.
sys, tspan = symbolic_discretize(pdesys, discretization)
simpsys = structural_simplify(sys)
# Attach the unsimplified system to the simplified system's metadata.
# NOTE(review): presumably required so MethodOfLines can rebuild grid-shaped
# solutions from the simplified system — confirm.
add_metadata!(get_metadata(simpsys), sys)
# Empty u0 map: initial state is taken from the system's defaults
# (NOTE(review): presumably populated by the discretization — confirm).
prob = ODEProblem{true, SciMLBase.FullSpecialize}(simpsys, Pair[], tspan; discretization.kwargs...)
"""
    remake_p(prob::ODEProblem, p; simpsys=simpsys)

Build a new `ODEProblem` from the system `simpsys` that reuses `prob.u0` and
`prob.tspan` but takes its parameter values from `p`.

By default `simpsys` refers to the global simplified system.
"""
function remake_p(prob::ODEProblem, p; simpsys=simpsys)
    return ODEProblem{true, SciMLBase.FullSpecialize}(simpsys, prob.u0, prob.tspan, p)
end
# User-facing parameter order: a parameter vector passed to `pde_solution2` is
# interpreted as [α, β].
param_vars = [α, β]
# Permutation into the simplified system's parameter order: idxs[i] is the position,
# in a user-supplied vector, of the system's i-th parameter. The ordering target
# must be the system's own parameter list; ordering by `param_vars` itself would
# always yield the identity permutation.
idxs = ModelingToolkit.varmap_to_vars(
    [p => i for (i, p) in enumerate(param_vars)],
    ModelingToolkit.parameters(simpsys),
)
# Quick check: reorder a sample parameter vector into system order.
test_p = [1.2, 1.4]
test_p[Int.(idxs)]
"""
    pde_solution2(ps)

Solve the discretized PDE with parameter vector `ps`, given in the order of
`param_vars`. Return the sum, over the saved time points (every 0.1), of the last
row of the solution array.

`ps` is permuted into the system's parameter order using the global `idxs`. The
global `prob` supplies the model, initial state and time span.
"""
function pde_solution2(ps)
    # Bind the reordered vector to a new name instead of rebinding the argument
    # (which also shadows the global `ps` pair list).
    p = ps[Int.(idxs)]
    # NOTE(review): wrap=Val(false) presumably returns a plain ODE solution rather
    # than a MethodOfLines grid-shaped one, so Array(sol) is (states × saved times).
    sol = solve(prob, Tsit5(); saveat = 0.1, p = p, wrap = Val(false))
    return sum(Array(sol)[end, :])
end
using Zygote
# Run the forward pass once before differentiating; the result is discarded.
pde_solution2([1.2, 0.3]);
# AbstractDifferentiation backend wrapping Zygote.
ADzyg = AD.ZygoteBackend()
"""
    grad(ps; backend = ADzyg)

Differentiate `pde_solution2` at `ps` with the AbstractDifferentiation `backend`.

The default is the Zygote backend `ADzyg`. Pass e.g. `backend = AD.ReverseDiffBackend()`
to compare reverse-mode backends on the same input. Returns whatever `AD.gradient`
returns: a tuple with one gradient per argument of `pde_solution2`.
"""
function grad(ps; backend = ADzyg)
    return AD.gradient(backend, pde_solution2, ps)
end

grad(rand(2))