Hi All,
I’m trying to discretize a PDE system with integral terms (an age-structured SIR model) using NeuralPDE.jl. The problem discretizes fine with MethodOfLines.jl, but training the NeuralPDE (PINN) discretization fails with an error. The code is below — any suggestions?
using ModelingToolkit
using DifferentialEquations
using MethodOfLines
using NeuralPDE
using Lux
using Optimization
using OptimizationOptimisers
using DomainSets
using Plots
# Age-structured SIR model: S(t) and R(t) depend on time only, while the infected
# compartment I(a, t) is resolved by age `a` as well as time `t`.
# β multiplies the infection term and γ the removal term; S₀, I₀, R₀ set the
# initial/boundary values.
@parameters t a β γ S₀ I₀ R₀
@variables S(..) I(..) R(..)
Dt, Da = Differential(t), Differential(a)

# Time range [tmin, tmax] and the time step later handed to the ODE solver.
tmin, tmax, dt = 0.0, 40.0, 0.05
# Age range [amin, amax] and the age-grid spacing used by the finite-difference scheme.
amin, amax, da = 0.0, 40.0, 1.0

# Integral operator ∫ (·) da over the full age range.
Ia = Integral(a in DomainSets.ClosedInterval(amin, amax))
domains = [t ∈ (tmin, tmax), a ∈ (amin, amax)]

# Total infected population summed over all ages, ∫ I(a, t) da.
infected_total = Ia(I(a, t))

eqs = [
    Dt(S(t)) ~ -β * S(t) * infected_total,      # new infections leave S
    Dt(I(a, t)) + Da(I(a, t)) ~ -γ * I(a, t),   # I moves through age at unit speed, decays at rate γ
    Dt(R(t)) ~ γ * infected_total               # removed infecteds enter R
]

bcs = [
    S(0) ~ S₀,                                  # initial susceptibles
    I(0, t) ~ β * S(t) * infected_total,        # new infections enter at age 0
    I(a, 0) ~ I₀ / (amax - amin),               # initial infecteds spread evenly over ages
    R(0) ~ R₀                                   # initial removed
]

# Numerical values substituted for the symbolic parameters.
p = Dict(β => 0.5, γ => 0.25, S₀ => 0.99, I₀ => 0.01, R₀ => 0.0)

@named pde_system = PDESystem(
    eqs,                      # equations
    bcs,                      # boundary / initial conditions
    domains,                  # domains of the independent variables
    [a, t],                   # independent variables
    [S(t), I(a, t), R(t)],    # dependent variables
    [β, γ, S₀, I₀, R₀];       # parameters
    defaults = p
);
# Discretize using the method of lines: finite differences in age `a` with spacing
# `da`; time `t` is left continuous and integrated as an ODE system.
discretization_mol = MOLFiniteDifference([a => da],
                                         t,
                                         approx_order = 2,
                                         advection_scheme = UpwindScheme(2))
@time prob_mol = discretize(pde_system, discretization_mol)
@time sol_mol = solve(prob_mol, Rodas5(), dt = dt, saveat = 0.1)

t_points = sol_mol.t
S_sol_mol = sol_mol[S(t)]

# Total infected over all ages, ∫ I(a, t) da, via the trapezoidal rule on the
# uniform age grid. A plain `sum(...) * da` gives both end nodes full weight and
# overstates the integral by da * (I(amin, t) + I(amax, t)) / 2.
# Rows are age nodes and columns are saved times (the same layout the original
# `sum(..., dims=1)` relied on).
I_by_age = sol_mol[I(a, t)]
I_endpoints = I_by_age[1, :] .+ I_by_age[end, :]
I_sol_mol = da .* (vec(sum(I_by_age, dims = 1)) .- I_endpoints ./ 2)

R_sol_mol = sol_mol[R(t)];

plot(t_points, S_sol_mol, label = "S", xlabel = "Time", ylabel = "Number")
plot!(t_points, I_sol_mol, label = "I")
plot!(t_points, R_sol_mol, label = "R")
# Discretize by PINN.
# NeuralPDE expects one network per dependent variable, in the same order as the
# PDESystem's dependent variables [S(t), I(a, t), R(t)]. Each network's input
# width must equal the number of arguments of its variable: S and R take only t
# (1 input), while I takes (a, t) (2 inputs). A single 1-input chain cannot
# represent I(a, t).
chain = [Lux.Chain(Dense(n_in, 16, Lux.σ), Dense(16, 16, Lux.σ), Dense(16, 1))
         for n_in in (1, 2, 1)]
discretization_pinn = PhysicsInformedNN(chain, QuadratureTraining())
prob_pinn = discretize(pde_system, discretization_pinn)

# Print the loss every iteration; returning `false` lets optimization continue.
callback = function (state, l)
    println("Current loss is: $l")
    return false
end

# `Adam` is the current name of the optimizer previously aliased as `ADAM`.
@time res = Optimization.solve(prob_pinn, OptimizationOptimisers.Adam(0.1);
                               callback = callback, maxiters = 4000)
The last line (the `Optimization.solve` call that trains the PINN) throws this error:
ERROR: MethodError: no method matching getindex(::Nothing, ::UnitRange{Int64})
The function `getindex` exists, but no method is defined for this combination of argument types.
Stacktrace:
[1] (::Cubature.var"#17#18"{…})()
@ Cubature ~/.julia/packages/Cubature/5zwuu/src/Cubature.jl:215
[2] disable_sigint
@ ./c.jl:167 [inlined]
[3] cubature(xscalar::Bool, fscalar::Bool, vectorized::Bool, padaptive::Bool, fdim::Int64, f::IntegralsCubatureExt.var"#3#10"{…}, xmin_::Vector{…}, xmax_::Vector{…}, reqRelError::Float64, reqAbsError::Float64, maxEval::Int64, error_norm::Int32)
@ Cubature ~/.julia/packages/Cubature/5zwuu/src/Cubature.jl:169
[4] hcubature_v
@ ~/.julia/packages/Cubature/5zwuu/src/Cubature.jl:230 [inlined]
[5] __solvebp_call(prob::IntegralProblem{…}, alg::Integrals.CubatureJLh, sensealg::Integrals.ReCallVJP{…}, domain::Tuple{…}, p::ComponentArrays.ComponentVector{…}; reltol::Float64, abstol::Float64, maxiters::Int64)
@ IntegralsCubatureExt ~/.julia/packages/Integrals/e3NH3/ext/IntegralsCubatureExt.jl:48
[6] __solvebp_call
@ ~/.julia/packages/Integrals/e3NH3/ext/IntegralsCubatureExt.jl:7 [inlined]
[7] #__solvebp_call#4
@ ~/.julia/packages/Integrals/e3NH3/src/common.jl:118 [inlined]
[8] __solvebp_call
@ ~/.julia/packages/Integrals/e3NH3/src/common.jl:117 [inlined]
[9] #rrule#28
@ ~/.julia/packages/Integrals/e3NH3/ext/IntegralsZygoteExt.jl:53 [inlined]
[10] rrule
@ ~/.julia/packages/Integrals/e3NH3/ext/IntegralsZygoteExt.jl:49 [inlined]
[11] rrule
@ ~/.julia/packages/ChainRulesCore/U6wNx/src/rules.jl:144 [inlined]
[12] chain_rrule_kw
@ ~/.julia/packages/Zygote/zowwZ/src/compiler/chainrules.jl:236 [inlined]
[13] macro expansion
@ ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0 [inlined]
[14] _pullback
@ ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:91 [inlined]
[15] _apply
@ ./boot.jl:946 [inlined]
[16] adjoint
@ ~/.julia/packages/Zygote/zowwZ/src/lib/lib.jl:202 [inlined]
[17] _pullback
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
[18] #__solve#50
@ ~/.julia/packages/Integrals/e3NH3/src/Integrals.jl:69 [inlined]
[19] _apply
@ ./boot.jl:946 [inlined]
[20] adjoint
@ ~/.julia/packages/Zygote/zowwZ/src/lib/lib.jl:202 [inlined]
[21] _pullback
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
[22] __solve
@ ~/.julia/packages/Integrals/e3NH3/src/Integrals.jl:69 [inlined]
[23] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Integrals.__solve), ::Integrals.IntegralCache{…}, ::Integrals.CubatureJLh, ::Integrals.ReCallVJP{…}, ::Tuple{…}, ::ComponentArrays.ComponentVector{…})
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
[24] #__solve#52
@ ~/.julia/packages/Integrals/e3NH3/src/Integrals.jl:92 [inlined]
[25] _pullback(::Zygote.Context{…}, ::Integrals.var"##__solve#52", ::@Kwargs{…}, ::typeof(Integrals.__solve), ::Integrals.IntegralCache{…}, ::Integrals.ChangeOfVariables{…}, ::Integrals.ReCallVJP{…}, ::Tuple{…}, ::ComponentArrays.ComponentVector{…})
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
[26] __solve
@ ~/.julia/packages/Integrals/e3NH3/src/Integrals.jl:79 [inlined]
[27] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(Integrals.__solve), ::Integrals.IntegralCache{…}, ::Integrals.ChangeOfVariables{…}, ::Integrals.ReCallVJP{…}, ::Tuple{…}, ::ComponentArrays.ComponentVector{…})
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
[28] solve!
@ ~/.julia/packages/Integrals/e3NH3/src/common.jl:108 [inlined]
[29] _pullback(ctx::Zygote.Context{…}, f::typeof(solve!), args::Integrals.IntegralCache{…})
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
[30] #solve#3
@ ~/.julia/packages/Integrals/e3NH3/src/common.jl:104 [inlined]
[31] _pullback(::Zygote.Context{…}, ::Integrals.var"##solve#3", ::@Kwargs{…}, ::typeof(solve), ::IntegralProblem{…}, ::Integrals.CubatureJLh)
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
[32] solve
@ ~/.julia/packages/Integrals/e3NH3/src/common.jl:101 [inlined]
[33] #115
@ ~/.julia/packages/NeuralPDE/nkWKK/src/training_strategies.jl:354 [inlined]
[34] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#115#118"{…}, ::Vector{…}, ::Vector{…}, ::NeuralPDE.var"#242#243"{…}, ::ComponentArrays.ComponentVector{…})
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
[35] #116
@ ~/.julia/packages/NeuralPDE/nkWKK/src/training_strategies.jl:357 [inlined]
[36] _pullback(ctx::Zygote.Context{…}, f::NeuralPDE.var"#116#120"{…}, args::ComponentArrays.ComponentVector{…})
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
[37] #308
@ ./none:0 [inlined]
[38] _pullback(ctx::Zygote.Context{…}, f::NeuralPDE.var"#308#329"{…}, args::NeuralPDE.var"#116#120"{…})
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
[39] (::Zygote.var"#667#671"{Zygote.Context{…}, NeuralPDE.var"#308#329"{…}})(args::Function)
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/lib/array.jl:188
[40] iterate
@ ./generator.jl:48 [inlined]
[41] _collect(c::Vector{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
@ Base ./array.jl:811
[42] collect_similar
@ ./array.jl:720 [inlined]
[43] map
@ ./abstractarray.jl:3371 [inlined]
[44] ∇map(cx::Zygote.Context{…}, f::NeuralPDE.var"#308#329"{…}, args::Vector{…})
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/lib/array.jl:188
[45] _pullback(cx::Zygote.Context{false}, ::typeof(collect), g::Base.Generator{Vector{…}, NeuralPDE.var"#308#329"{…}})
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/lib/array.jl:231
[46] full_loss_function
@ ~/.julia/packages/NeuralPDE/nkWKK/src/discretize.jl:462 [inlined]
[47] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#full_loss_function#328"{…}, ::ComponentArrays.ComponentVector{…}, ::SciMLBase.NullParameters)
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface2.jl:0
[48] pullback(::Function, ::Zygote.Context{…}, ::ComponentArrays.ComponentVector{…}, ::Vararg{…})
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface.jl:90
[49] pullback(::Function, ::ComponentArrays.ComponentVector{Float64, Vector{…}, Tuple{…}}, ::SciMLBase.NullParameters)
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface.jl:88
[50] withgradient(::Function, ::ComponentArrays.ComponentVector{Float64, Vector{…}, Tuple{…}}, ::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/zowwZ/src/compiler/interface.jl:205
[51] value_and_gradient
@ ~/.julia/packages/DifferentiationInterface/zJHX8/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:115 [inlined]
[52] value_and_gradient!(f::Function, grad::ComponentArrays.ComponentVector{…}, prep::DifferentiationInterface.NoGradientPrep{…}, backend::AutoZygote, x::ComponentArrays.ComponentVector{…}, contexts::DifferentiationInterface.Constant{…})
@ DifferentiationInterfaceZygoteExt ~/.julia/packages/DifferentiationInterface/zJHX8/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:131
[53] (::OptimizationZygoteExt.var"#fg!#16"{…})(res::ComponentArrays.ComponentVector{…}, θ::ComponentArrays.ComponentVector{…})
@ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/UXLhR/ext/OptimizationZygoteExt.jl:53
[54] macro expansion
@ ~/.julia/packages/OptimizationOptimisers/xC7Ic/src/OptimizationOptimisers.jl:101 [inlined]
[55] macro expansion
@ ~/.julia/packages/Optimization/e1Lg1/src/utils.jl:32 [inlined]
[56] __solve(cache::OptimizationCache{…})
@ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/xC7Ic/src/OptimizationOptimisers.jl:83
[57] solve!(cache::OptimizationCache{…})
@ SciMLBase ~/.julia/packages/SciMLBase/m1Jrs/src/solve.jl:227
[58] solve(::OptimizationProblem{…}, ::Adam{…}; kwargs::@Kwargs{…})
@ SciMLBase ~/.julia/packages/SciMLBase/m1Jrs/src/solve.jl:129
[59] macro expansion
@ ./timing.jl:581 [inlined]