Using NeuralPDE for prediction?

Hi, I’m pretty new to Julia ML and I’m trying to use NeuralPDE.jl to solve a system of ODEs. I’m following this tutorial and am able to get a pretty good fit close to the true solution. But is there a way I can use the pre-trained PINN to predict solutions for a different set of initial conditions directly without having to go through the training process again (which takes a while)? Or is that not the purpose of the package?

For that, treat it as a PDE in which the initial conditions are additional independent variables.
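Written out, the suggestion turns the initial value into an extra input of the unknown function. This is a generic sketch; $f$, $U$, $T$ and $\Omega$ are illustrative notation, not names from the thread:

```latex
\begin{aligned}
&\text{original IVP:} && \dot u(t) = f\big(u(t), t\big), \qquad u(0) = u_0, \\
&\text{PDE form:}     && \partial_t U(t, u_0) = f\big(U(t, u_0), t\big), \qquad U(0, u_0) = u_0, \qquad t \in [0, T],\ u_0 \in \Omega.
\end{aligned}
```

After a single training run over $\Omega$, the prediction for a new initial condition $u_0^\ast \in \Omega$ is just a forward pass $U(t, u_0^\ast)$, with no retraining. For $u_0^\ast$ outside $\Omega$ the network is extrapolating, so its output should not be trusted.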

OK, I'm trying to do that with the code below, but I get this error. I'm not sure what it means; any idea how to resolve it?

```julia
julia> res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters=200)
ERROR: BoundsError: attempt to access Tuple{Matrix{Float64}} at index [2]
```

I'm using the code below (following the method suggested here to specify the initial conditions):

```julia
using NeuralPDE, Lux, ModelingToolkit
using Optimization, OptimizationOptimJL, OptimizationOptimisers
import ModelingToolkit: Interval, infimum, supremum

# PDE in which the initial conditions δ0, ω0 are extra independent variables

@parameters t, δ0, ω0
@variables δ(..), ω(..)

Dt = Differential(t)

P_1 = 0.08;
B_12 = 0.2;
V_1 = V_2 = 1.0;
m_1 = 0.1;
d_1 = 0.05;

# Original 2-state ODE system, now written in terms of (t, δ0, ω0)
eqs  = [Dt(δ(t,δ0,ω0)) - ω(t,δ0,ω0) ~ 0.0,
    m_1 * Dt(ω(t,δ0,ω0)) + d_1 * Dt(δ(t,δ0,ω0)) + B_12 * V_1 * V_2 * sin(δ(t,δ0,ω0)) - P_1 ~ 0.0];

# Initial conditions: at t = 0 the solution equals the (δ0, ω0) inputs
bcs = [δ(0.,δ0,ω0) ~ δ0,
       ω(0.,δ0,ω0) ~ ω0] ;

# Domains: time, plus the range of initial conditions to train over
domains = [t ∈ Interval(0.0,1.0),
    δ0 ∈ Interval(0.0,1.0),
    ω0 ∈ Interval(0.0,1.0)];

# Neural network: 3 inputs (t, δ0, ω0), one network per dependent variable
input_ = length(domains)
n = 15
chain =[Lux.Chain(Dense(input_,n,Lux.σ),Dense(n,n,Lux.σ),Dense(n,1)) for _ in 1:2] # 1:number of @variables

# strategy = QuasiRandomTraining(20)
strategy = QuadratureTraining()
discretization = PhysicsInformedNN(chain, strategy)
# Pass `eqs` (defined above). `eq` is never defined in this script, so a stale
# `eq` left over in the REPL session would silently be used instead.
@named pde_system = PDESystem(eqs,bcs,domains,[t,δ0,ω0],[δ(t,δ0,ω0), ω(t,δ0,ω0)])
prob = discretize(pde_system,discretization)

callback = function (p,l)
    println("Current loss is: $l")
    return false
end

res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters=200)
phi = discretization.phi
```
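Once training succeeds, predicting for a new initial condition is only a matter of evaluating the trained networks at `[t, δ0, ω0]`. Below is a minimal sketch of that step. `predict_trajectory` is a hypothetical helper, not a NeuralPDE function. It assumes that each `phi[i](x, θ)` returns a 1-element output for a 3-element input, with inputs in the same order `[t, δ0, ω0]` given to `PDESystem`. The commented usage at the end also assumes `res.u` is a ComponentArray with a `depvar` field keyed by the dependent-variable names, as in the NeuralPDE systems tutorial; check this against your installed version.

```julia
# Hypothetical helper (not part of NeuralPDE): evaluate the two trained
# sub-networks along a time grid for one new initial condition (δ0, ω0).
# Assumes phi[i](x, θ) returns a 1-element output for x = [t, δ0, ω0].
function predict_trajectory(phi, θs, ts, δ0, ω0; lo = 0.0, hi = 1.0)
    # All three training domains in the post are [0, 1]; inputs outside that
    # box are extrapolation, so warn rather than fail silently.
    inside(x) = lo <= x <= hi
    if !(all(inside, ts) && inside(δ0) && inside(ω0))
        t_range = extrema(ts)
        @warn "Input outside the training domain; prediction is extrapolation" δ0 ω0 t_range
    end
    # One forward pass per time point; no retraining is involved.
    δ = [first(phi[1]([t, δ0, ω0], θs[1])) for t in ts]
    ω = [first(phi[2]([t, δ0, ω0], θs[2])) for t in ts]
    return δ, ω
end

# Usage after training (assumed parameter layout, see the note above):
#   θs = [res.u.depvar.δ, res.u.depvar.ω]
#   δ_pred, ω_pred = predict_trajectory(discretization.phi, θs, 0.0:0.01:1.0, 0.3, 0.2)
```

Evaluating point by point keeps the sketch simple. For long time grids, the networks could instead be evaluated on a 3×N input matrix in one call.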