Problems porting discretization from MethodOfLines.jl to NeuralPDE.jl

Hi All,

I’m trying to discretize a PDE with an integral term using NeuralPDE.jl. The problem discretizes fine with MethodOfLines.jl. With NeuralPDE.jl, `discretize` succeeds, but training the PINN (`Optimization.solve`) fails with an error. The code is below — any suggestions?

using ModelingToolkit
using DifferentialEquations
using MethodOfLines
using NeuralPDE
using Lux
using Optimization
using OptimizationOptimisers
using DomainSets
using Plots

# Age-of-infection SIR model.
#   S(t)     susceptibles
#   I(a, t)  infecteds, resolved by infection age `a`
#   R(t)     recovered
# Infection pressure is β·S(t)·∫I(a, t) da; new infecteds enter the age domain at a = 0.
@parameters t a β γ S₀ I₀ R₀
@variables S(..) I(..) R(..)

# Time range and time step.
tmin, tmax, dt = 0.0, 40.0, 0.05
# Age range and age step.
amin, amax, da = 0.0, 40.0, 1.0

Dt = Differential(t)
Da = Differential(a)
# Integral over the full age range [amin, amax].
Ia = Integral(a in DomainSets.ClosedInterval(amin, amax))
domains = [t ∈ (tmin, tmax), a ∈ (amin, amax)]

eqs = [
    Dt(S(t)) ~ -β * S(t) * Ia(I(a, t)),         # susceptibles lost to infection
    Dt(I(a, t)) + Da(I(a, t)) ~ -γ * I(a, t),    # infecteds age and recover at rate γ
    Dt(R(t)) ~ γ * Ia(I(a, t)),                  # recoveries from all ages
]

bcs = [
    S(0) ~ S₀,
    I(0, t) ~ β * S(t) * Ia(I(a, t)),            # newly infected enter at age 0
    I(a, 0) ~ I₀ / (amax - amin),                # initial infecteds spread evenly over ages
    R(0) ~ R₀,
]

p = Dict(β => 0.5, γ => 0.25, S₀ => 0.99, I₀ => 0.01, R₀ => 0.0)
@named pde_system = PDESystem(eqs, bcs, domains,
                              [a, t],
                              [S(t), I(a, t), R(t)],
                              [β, γ, S₀, I₀, R₀];
                              defaults = p);

# Discretize using the method of lines: finite differences in age `a` with step `da`,
# keeping time `t` continuous so the result is an ODE problem.
discretization_mol = MOLFiniteDifference([a => da],
                                          t,
                                          approx_order = 2,
                                          advection_scheme = UpwindScheme(2))
@time prob_mol = discretize(pde_system, discretization_mol)
# With an adaptive solver such as Rodas5, `dt` only sets the initial step size;
# the output is sampled every 0.1 time units.
@time sol_mol = solve(prob_mol, Rodas5(), dt = dt, saveat = 0.1)
t_points = sol_mol.t
S_sol_mol = sol_mol[S(t)]
# Age-resolved infecteds: rows follow the age grid and columns follow `t_points`.
# The age sum over `dims = 1` below relies on this layout.
I_age_mol = sol_mol[I(a, t)]
# Total infecteds ∫ I(a, t) da, computed with the trapezoidal rule on the uniform age grid.
# The first and last ages get half weight. A plain `sum(...) * da` would over-count by
# da * (I(amin, t) + I(amax, t)) / 2, which is not small when da = 1 and I(0, t) is
# the incidence.
I_sol_mol = da .* (vec(sum(I_age_mol; dims = 1)) .-
                   (I_age_mol[1, :] .+ I_age_mol[end, :]) ./ 2)
R_sol_mol = sol_mol[R(t)];

plot(t_points, S_sol_mol, label="S", xlabel="Time", ylabel="Number")
plot!(t_points, I_sol_mol, label="I")
plot!(t_points, R_sol_mol, label="R")

# Discretize by PINN.
# NeuralPDE expects one network per dependent variable, in the same order as the
# PDESystem's depvars [S(t), I(a, t), R(t)]. Each network's input width must equal
# its variable's number of arguments: S and R take (t), while I takes (a, t).
# A single Dense(1, …) chain cannot represent the two-argument I(a, t).
chain = [Lux.Chain(Dense(nin, 16, Lux.σ), Dense(16, 16, Lux.σ), Dense(16, 1))
         for nin in (1, 2, 1)]
# NOTE(review): with QuadratureTraining, the outer loss integration runs inside
# Cubature.jl (CubatureJLh). The top frame of a failure there may hide where the
# integrand actually failed. If training still errors inside Cubature, try
# StochasticTraining or GridTraining to get a more direct error.
discretization_pinn = PhysicsInformedNN(chain, QuadratureTraining())
prob_pinn = discretize(pde_system, discretization_pinn)

# Print the loss every iteration; returning `false` lets the optimizer continue.
callback = function (p, l)
    println("Current loss is: $l")
    return false
end

# `ADAM` is the deprecated spelling of Optimisers' `Adam`.
@time res = Optimization.solve(prob_pinn, OptimizationOptimisers.Adam(0.1);
                               callback = callback, maxiters = 4000)

The last line throws this error:

ERROR: MethodError: no method matching getindex(::Nothing, ::UnitRange{Int64})
The function `getindex` exists, but no method is defined for this combination of argument types.
Stacktrace:
  [1] (::Cubature.var"#17#18"{…})()
    @ Cubature ~/.julia/packages/Cubature/5zwuu/src/Cubature.jl:215
  [2] disable_sigint
    @ ./c.jl:167 [inlined]
  [3] cubature(xscalar::Bool, fscalar::Bool, vectorized::Bool, padaptive::Bool, fdim::Int64, f::IntegralsCubatureExt.var"#3#10"{…}, xmin_::Vector{…}, xmax_::Vector{…}, reqRelError::Float64, reqAbsError::Float64, maxEval::Int64, error_norm::Int32)
    @ Cubature ~/.julia/packages/Cubature/5zwuu/src/Cubature.jl:169
  [4] hcubature_v
    @ ~/.julia/packages/Cubature/5zwuu/src/Cubature.jl:230 [inlined]
  [5] __solvebp_call(prob::IntegralProblem{…}, alg::Integrals.CubatureJLh, sensealg::Integrals.ReCallVJP{…}, domain::Tuple{…}, p::ComponentArrays.ComponentVector{…}; reltol::Float64, abstol::Float64, maxiters::Int64)
    @ IntegralsCubatureExt ~/.julia/packages/Integrals/e3NH3/ext/IntegralsCubatureExt.jl:48
  [6] __solvebp_call
    @ ~/.julia/packages/Integrals/e3NH3/ext/IntegralsCubatureExt.jl:7 [inlined]
  [7] #__solvebp_call#4
    @ ~/.julia/packages/Integrals/e3NH3/src/common.jl:118 [inlined]
  [8] __solvebp_call
    @ ~/.julia/packages/Integrals/e3NH3/src/common.jl:117 [inlined]
  [9] #rrule#28
    @ ~/.julia/packages/Integrals/e3NH3/ext/IntegralsZygoteExt.jl:53 [inlined]
 [10] rrule
    @ ~/.julia/packages/Integrals/e3NH3/ext/IntegralsZygoteExt.jl:49 [inlined]
 [11] rrule
    @ ~/.julia/packages/ChainRulesCore/U6wNx/src/rules.jl:144 [inlined]
 [12] chain_rrule_kw
    @ ~/.julia/packages/Zygote/zowwZ/src/compiler/chainrules.jl:236 [inlined]
 [13] macro expansion
    @ ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0 [inlined]
 [14] _pullback
    @ ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:91 [inlined]
 [15] _apply
    @ ./boot.jl:946 [inlined]
 [16] adjoint
    @ ~/.julia/packages/Zygote/zowwZ/src/lib/lib.jl:202 [inlined]
 [17] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [18] #__solve#50
    @ ~/.julia/packages/Integrals/e3NH3/src/Integrals.jl:69 [inlined]
 [19] _apply
    @ ./boot.jl:946 [inlined]
 [20] adjoint
    @ ~/.julia/packages/Zygote/zowwZ/src/lib/lib.jl:202 [inlined]
 [21] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [22] __solve
    @ ~/.julia/packages/Integrals/e3NH3/src/Integrals.jl:69 [inlined]
 [23] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Integrals.__solve), ::Integrals.IntegralCache{…}, ::Integrals.CubatureJLh, ::Integrals.ReCallVJP{…}, ::Tuple{…}, ::ComponentArrays.ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
 [24] #__solve#52
    @ ~/.julia/packages/Integrals/e3NH3/src/Integrals.jl:92 [inlined]
 [25] _pullback(::Zygote.Context{…}, ::Integrals.var"##__solve#52", ::@Kwargs{…}, ::typeof(Integrals.__solve), ::Integrals.IntegralCache{…}, ::Integrals.ChangeOfVariables{…}, ::Integrals.ReCallVJP{…}, ::Tuple{…}, ::ComponentArrays.ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
 [26] __solve
    @ ~/.julia/packages/Integrals/e3NH3/src/Integrals.jl:79 [inlined]
 [27] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Integrals.__solve), ::Integrals.IntegralCache{…}, ::Integrals.ChangeOfVariables{…}, ::Integrals.ReCallVJP{…}, ::Tuple{…}, ::ComponentArrays.ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
 [28] solve!
    @ ~/.julia/packages/Integrals/e3NH3/src/common.jl:108 [inlined]
 [29] _pullback(ctx::Zygote.Context{…}, f::typeof(solve!), args::Integrals.IntegralCache{…})
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
 [30] #solve#3
    @ ~/.julia/packages/Integrals/e3NH3/src/common.jl:104 [inlined]
 [31] _pullback(::Zygote.Context{…}, ::Integrals.var"##solve#3", ::@Kwargs{…}, ::typeof(solve), ::IntegralProblem{…}, ::Integrals.CubatureJLh)
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
 [32] solve
    @ ~/.julia/packages/Integrals/e3NH3/src/common.jl:101 [inlined]
 [33] #115
    @ ~/.julia/packages/NeuralPDE/nkWKK/src/training_strategies.jl:354 [inlined]
 [34] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#115#118"{…}, ::Vector{…}, ::Vector{…}, ::NeuralPDE.var"#242#243"{…}, ::ComponentArrays.ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
 [35] #116
    @ ~/.julia/packages/NeuralPDE/nkWKK/src/training_strategies.jl:357 [inlined]
 [36] _pullback(ctx::Zygote.Context{…}, f::NeuralPDE.var"#116#120"{…}, args::ComponentArrays.ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
 [37] #308
    @ ./none:0 [inlined]
 [38] _pullback(ctx::Zygote.Context{…}, f::NeuralPDE.var"#308#329"{…}, args::NeuralPDE.var"#116#120"{…})
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
 [39] (::Zygote.var"#667#671"{Zygote.Context{…}, NeuralPDE.var"#308#329"{…}})(args::Function)
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/lib/array.jl:188
 [40] iterate
    @ ./generator.jl:48 [inlined]
 [41] _collect(c::Vector{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
    @ Base ./array.jl:811
 [42] collect_similar
    @ ./array.jl:720 [inlined]
 [43] map
    @ ./abstractarray.jl:3371 [inlined]
 [44] ∇map(cx::Zygote.Context{…}, f::NeuralPDE.var"#308#329"{…}, args::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/lib/array.jl:188
 [45] _pullback(cx::Zygote.Context{false}, ::typeof(collect), g::Base.Generator{Vector{…}, NeuralPDE.var"#308#329"{…}})
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/lib/array.jl:231
 [46] full_loss_function
    @ ~/.julia/packages/NeuralPDE/nkWKK/src/discretize.jl:462 [inlined]
 [47] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#full_loss_function#328"{…}, ::ComponentArrays.ComponentVector{…}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
 [48] pullback(::Function, ::Zygote.Context{…}, ::ComponentArrays.ComponentVector{…}, ::Vararg{…})
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface.jl:90
 [49] pullback(::Function, ::ComponentArrays.ComponentVector{Float64, Vector{…}, Tuple{…}}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface.jl:88
 [50] withgradient(::Function, ::ComponentArrays.ComponentVector{Float64, Vector{…}, Tuple{…}}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface.jl:205
 [51] value_and_gradient
    @ ~/.julia/packages/DifferentiationInterface/zJHX8/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:115 [inlined]
 [52] value_and_gradient!(f::Function, grad::ComponentArrays.ComponentVector{…}, prep::DifferentiationInterface.NoGradientPrep{…}, backend::AutoZygote, x::ComponentArrays.ComponentVector{…}, contexts::DifferentiationInterface.Constant{…})
    @ DifferentiationInterfaceZygoteExt ~/.julia/packages/DifferentiationInterface/zJHX8/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:131
 [53] (::OptimizationZygoteExt.var"#fg!#16"{…})(res::ComponentArrays.ComponentVector{…}, θ::ComponentArrays.ComponentVector{…})
    @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/UXLhR/ext/OptimizationZygoteExt.jl:53
 [54] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/xC7Ic/src/OptimizationOptimisers.jl:101 [inlined]
 [55] macro expansion
    @ ~/.julia/packages/Optimization/e1Lg1/src/utils.jl:32 [inlined]
 [56] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/xC7Ic/src/OptimizationOptimisers.jl:83
 [57] solve!(cache::OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/m1Jrs/src/solve.jl:227
 [58] solve(::OptimizationProblem{…}, ::Adam{…}; kwargs::@Kwargs{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/m1Jrs/src/solve.jl:129
 [59] macro expansion
    @ ./timing.jl:581 [inlined]

We’re rewriting the NeuralPDE parser as a GSoC project this summer, since we know it needs work.

I’d file this under “check back later this summer, once the new parser is merged.” In particular, the handling of integral terms should become much simpler, and as a byproduct it should pick up a lot of bug fixes.