Training a neural network within method of lines and ODE solve framework

I am trying to train a neural network that is part of a PDE, which I am solving with the method of lines (MOL) and an ODE solver. Below is the forward pass, i.e. obtaining the solution with the ODE solver, which runs fine.

import ModelingToolkit: Interval, infimum, supremum
using DiffEqSensitivity
using OrdinaryDiffEq, ModelingToolkit, MethodOfLines, DomainSets
using Plots, Flux
using Interpolations

x_dom = 0:0.1:2
dim = 1
net = Flux.Chain(Dense(dim, 16, Flux.σ), Dense(16, 16, Flux.σ), Dense(16, 1))
p = Flux.params(net)
# p, re = Flux.destructure(net)

# Evaluate the network on the spatial grid and wrap the values in a cubic-spline
# interpolant; this interpolant is what gets registered symbolically below.
A = [net([x1])[1] for x1 in x_dom]
itp = interpolate(A, BSpline(Cubic(Line(OnGrid()))))
itps = scale(itp, x_dom)
net_func(x) = itps(x)

@parameters t x
@variables u(..) 
@register_symbolic net_func(x)

Dxx = Differential(x)^2
Dt = Differential(t)

t_min = 0.0
t_max = 2.0
x_min = 0.0
x_max = 2.0

eq  = Dt(u(t,x)) ~ Dxx(u(t,x)) + net_func(x)

bcs = [u(t_min,x) ~ sin(2*pi*x),
       u(t,x_min) ~ 0.0,
       u(t,x_max) ~ 1.0]

domains = [t ∈ Interval(t_min,t_max),
           x ∈ Interval(x_min,x_max)]

dx = 0.1
order = 2
discretization = MOLFiniteDifference([x => dx], t, approx_order = order)
@named pde_system = PDESystem(eq,bcs,domains,[t,x],[u(t, x)])
prob = discretize(pde_system,discretization)
sol = solve(prob, Tsit5(), p = p, saveat = 0.1)

However, I am not able to train this network. Below is the code for the loss function I am using, chosen just to check whether a backward pass is possible.

function predict_sol()
    Array(solve(prob, Tsit5(), p = p, progress = true))
end

function loss_func()
    _sol = predict_sol()
    loss = sum(abs2,_sol)
    @show loss
    loss
end

data = Iterators.repeated((), 2)
opt = ADAM(0.025)
 
Flux.train!(loss_func, p, data, opt)

I am curious whether it is possible to train a network within the MOL framework; if so, can someone help me figure out where I am going wrong?


This should work now and will auto-choose EnzymeVJP. Let me know if you continue to see any issues.
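If you want to pin the adjoint configuration down explicitly rather than rely on the automatic selection, something along these lines should be close. This is only a sketch using the InterpolatingAdjoint and EnzymeVJP types that DiffEqSensitivity provides; the automatic default may already be equivalent:

using DiffEqSensitivity

function predict_sol()
    # explicitly request an interpolating adjoint with Enzyme-based VJPs
    Array(solve(prob, Tsit5(), p = p,
                sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()),
                saveat = 0.1))
end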


Is there a way to register the neural network as symbolic without interpolation? I think interpolation will not work for training the network.

DataInterpolations.jl works fine with training.
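For what it's worth, here is a minimal, self-contained sketch of that idea, separate from the PDE setup above: a CubicSpline from DataInterpolations.jl built from the network output can be differentiated with Zygote, so the Flux parameters stay in the graph. spline_loss is just an illustrative helper, and this assumes a DataInterpolations.jl version with AD rules:

using DataInterpolations, Flux, Zygote

x_dom = 0:0.1:2
net = Flux.Chain(Dense(1, 16, Flux.σ), Dense(16, 16, Flux.σ), Dense(16, 1))
p, re = Flux.destructure(net)

function spline_loss(p)
    A = [re(p)([x1])[1] for x1 in x_dom]   # network output on the grid
    itp = CubicSpline(A, collect(x_dom))   # spline through the grid values
    sum(abs2, itp.(0.05:0.1:1.95))         # evaluate between grid points
end

g = Zygote.gradient(spline_loss, p)[1]     # gradient w.r.t. the network weights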