Breakpoint or @infiltrate inside a gradient calculation (Zygote or ReverseDiff)?

I’m trying to debug why Zygote and ReverseDiff don’t work with MethodOfLines, but I’m struggling to break inside the gradient calculation, where Zygote and ReverseDiff change the types of some inputs. Breakpoints set with Debugger don’t work, and @infiltrate causes precompilation to fail. Simple example:

using DifferentialEquations, ModelingToolkit, MethodOfLines, DomainSets

using PDEBase: add_metadata!
using ModelingToolkit: get_metadata

using Zygote
using ReverseDiff
import AbstractDifferentiation as AD
using SciMLSensitivity
# Method of Manufactured Solutions: exact solution
u_exact = (x,t) -> exp.(-t) * cos.(x)

# Parameters, variables, and derivatives
@parameters x t
@variables u(..)
@parameters α β
Dt = Differential(t)
Dxx = Differential(x)^2

# 1D PDE and boundary conditions
eq = Dt(u(t, x)) ~ (α + β) * Dxx(u(t, x))
bcs = [u(0, x) ~ cos(x),
        u(t, 0) ~ exp(-t),
        u(t, 1) ~ exp(-t) * cos(1)]

# Space and time domains
domains = [t ∈ Interval(0.0, 1.0),
           x ∈ Interval(0.0, 1.0)]

# Parameters
ps = [α => 1.2, β => 2.1]
# PDE system
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)], ps)

# Method of lines discretization
dx = 0.1
order = 2
discretization = MOLFiniteDifference([x => dx], t)

# Convert the PDE problem into an ODE problem
sys, tspan = symbolic_discretize(pdesys, discretization)
simpsys = structural_simplify(sys)
add_metadata!(get_metadata(simpsys), sys)
prob = ODEProblem{true, SciMLBase.FullSpecialize}(simpsys, Pair[], tspan; discretization.kwargs...)


function remake_p(prob::ODEProblem, p; simpsys=simpsys)
    tspan = prob.tspan
    u0 = prob.u0

    _f = prob.f
    ODEProblem{true, SciMLBase.FullSpecialize}(simpsys, u0, tspan, p)#, prob.problem_type )
end

param_vars = [α, β]
idxs = ModelingToolkit.varmap_to_vars([param_vars[1] => 1, param_vars[2] => 2], param_vars)

test_p = [1.2, 1.4]
test_p[Int.(idxs)]
function pde_solution2(ps)
    #_prob = remake_p(prob, ps)
    ps = ps[Int.(idxs)]
    sol = solve(prob, Tsit5(), saveat=0.1, p=ps, wrap=Val(false))
    return sum(Array(sol)[end, :])
end

pde_solution2([1.2, 0.3]);
ADzyg = AD.ZygoteBackend()
function grad(ps)
    AD.gradient(ADzyg, pde_solution2, ps)
end

grad(rand(2))
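Short of a working breakpoint, one low-tech way to see which argument types each backend actually passes in is to record typeof at the top of the function being differentiated. Below is a minimal sketch on a toy loss rather than the MethodOfLines problem, assuming ReverseDiff, Zygote, and ChainRulesCore are installed; seen_types and loss are hypothetical names. ChainRulesCore's ignore_derivatives keeps the bookkeeping out of Zygote's trace, since Zygote generally does not support mutating arrays inside code it differentiates.

```julia
using ReverseDiff, Zygote
using ChainRulesCore: ignore_derivatives

# Hypothetical log: records typeof(ps) on every call to `loss`.
const seen_types = Type[]

function loss(ps)
    # Hidden from Zygote via ignore_derivatives; ReverseDiff simply executes it.
    ignore_derivatives() do
        push!(seen_types, typeof(ps))
    end
    return sum(ps .* ps)
end

g_rd = ReverseDiff.gradient(loss, [1.0, 2.0])  # ps arrives as a ReverseDiff.TrackedArray
g_zy, = Zygote.gradient(loss, [1.0, 2.0])     # ps arrives as a plain Vector{Float64}

foreach(println, seen_types)
```

The same few lines placed at the top of pde_solution2 should show what each backend hands to it, without needing Debugger or Infiltrator at all.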

What error are you running into?

Revise throws an "@infiltrate not defined" error when I try to load a package that contains @infiltrate. I'm not sure why; I had assumed it was a precompilation error…

Yes, Infiltrator is in that package’s dependencies (in its Project.toml).

Did you add using Infiltrator to the package code in addition to adding it as a dependency?
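For reference, the dependency route needs both pieces: the entry in Project.toml and a using Infiltrator inside the module itself. A minimal sketch, where MyPkg and loss are hypothetical names and Infiltrator is assumed to be installed; the conditional form @infiltrate debug only stops when debug is true, so the function runs straight through by default.

```julia
# Hypothetical package source (e.g. MyPkg/src/MyPkg.jl); Infiltrator must also
# be listed under [deps] in MyPkg's Project.toml.
module MyPkg

using Infiltrator  # the [deps] entry alone does not bring @infiltrate into scope

function loss(ps; debug = false)
    y = sum(ps .* ps)
    @infiltrate debug  # breakpoint that only triggers when debug == true
    return y
end

end # module

MyPkg.loss([1.0, 2.0])  # no breakpoint hit with the default debug = false
```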

If you use Revise, you don’t need to add Infiltrator to your package’s dependencies at all. Instead, I’d recommend loading it in your REPL session (using Infiltrator) and putting Main.@infiltrate statements into your function.
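A minimal sketch of that pattern, assuming Infiltrator is installed; MyDevPkg and loss are hypothetical stand-ins for your own package and function. In a real workflow the module would be a package loaded through Revise, and you would add the Main.@infiltrate line after loading, so that Revise evaluates it in the running session, where Main already has Infiltrator. As far as I understand, writing it into the source before loading fails during precompilation because the Main seen there has no Infiltrator.

```julia
using Infiltrator  # loaded in Main (e.g. in the REPL); not a dependency of the package

# Hypothetical stand-in for a package under development that Revise is tracking.
module MyDevPkg

function loss(ps; debug = false)
    y = sum(ps .* ps)
    # Qualified macro call: @infiltrate is looked up in Main,
    # so this module never loads Infiltrator itself.
    Main.@infiltrate debug  # stops only when debug == true
    return y
end

end # module

MyDevPkg.loss([1.0, 2.0])                  # runs straight through
# MyDevPkg.loss([1.0, 2.0]; debug = true)  # would open the infiltrate> prompt in the REPL
```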

Also check out the other possibilities for using Infiltrator here.
