Hi,
I’m trying to train a PINN using NeuralPDE.jl. I was able to start the training and reached what seems to me a quite small loss (< 1e-7). But when I plotted the network solution against the analytic solution, the two looked very different.
That’s why I wanted to understand the loss function. Sadly, I was not able to find anything about how it is defined, except that it is a combination of the boundary condition loss and the loss for the PDE. So I would be very happy if someone here could help me understand how the loss is defined.
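To make the question concrete, here is the form I would expect the loss to take for my problem. This is only my assumption (a plain, unweighted mean of squared residuals plus the squared boundary error), not something I found in the NeuralPDE.jl documentation. Here u_θ is the network output, t_i are the N training points sampled from the domain [0, 2/δ], and δ is the initial value:

```latex
\mathcal{L}(\theta)
  = \underbrace{\frac{1}{N}\sum_{i=1}^{N}
      \left( \frac{d u_\theta}{dt}(t_i) - u_\theta(t_i)^2 + u_\theta(t_i)^3 \right)^{2}}_{\text{PDE loss}}
  + \underbrace{\left( u_\theta(0) - \delta \right)^{2}}_{\text{boundary condition loss}}
```

If the actual definition differs (for example through extra weights on the two terms or a different normalisation), that is exactly the part I would like to understand.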
For you to better understand the problem, I have prepared my minimal working example:
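using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimisers, Plots, Random
using LambertW  # lambertw (used in the analytic solution below) is assumed to come from LambertW.jl
import ModelingToolkit: Interval, infimum, supremum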
@parameters t
@variables u(..)
Dt = Differential(t)
eq = Dt(u(t)) ~ (u(t))^2 - (u(t))^3
δ = 0.01
bcs = [u(0.0) ~ δ]
domains = [t ∈ Interval(0.0, 2 / δ)]
chain = Lux.Chain(Dense(1, 8, Lux.σ), Dense(8, 8, Lux.σ), Dense(8, 1))
initParams, initState = Lux.setup(Random.default_rng(), chain)
optimiser = OptimizationOptimisers.Adam(0.01)
discretization = PhysicsInformedNN(chain, QuasiRandomTraining(1000))
@named pde_system = PDESystem(eq, bcs, domains, [t], [u(t)])
prob = discretize(pde_system, discretization)
callback = function (params, loss_val)
    ps, st = Lux.setup(Random.default_rng(), chain)
    y, st = Lux.apply(chain, [0.5], params, st)
    println("loss = ", loss_val)
    # stop optimization when loss value is smaller than 1e-7
    return loss_val < 1e-7
end
params_trained = solve(prob, optimiser, callback=callback, maxiters=2000, save_best=true)
state_trained = discretization.phi
analytic_sol_func(t) = 1 / (lambertw((1 / δ - 1) * exp(1 / δ - 1 - t)) + 1)
# plot result
dx = 0.001
xs = [infimum(d.domain):(dx/10):supremum(d.domain) for d in domains][1]
u_real = analytic_sol_func.(xs)
state_with_params(x) = state_trained(x, params_trained.u)
u_predict = first.(state_with_params.(xs))
x_plot = collect(xs)
plot(x_plot, u_real, label="real")
plot!(x_plot, u_predict, label="predict")
And here is an image of the final plot:
The blue line in the image is the analytical solution and the red/orange line is the output of the network. As you can see, the two differ quite a lot.
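As a sanity check on what a loss below 1e-7 can mean here, I put together the small sketch below. It does not use NeuralPDE.jl at all. It assumes the loss is the mean squared ODE residual plus the squared boundary error (my assumption, not the library's confirmed definition), and evaluates that loss for the constant function u(t) = δ, which is clearly not the analytic solution. The names `u_const`, `du_const`, `residual`, and the uniform grid `ts` are hypothetical helpers for this illustration only.

```julia
# Hypothetical sanity check, not NeuralPDE.jl's actual loss code.
# Assumption: loss = mean squared ODE residual + squared boundary error.
δ = 0.01
tmax = 2 / δ

# Candidate "solution": the constant function u(t) = δ, so du/dt = 0 everywhere.
u_const(t) = δ
du_const(t) = 0.0

# ODE residual r(t) = du/dt - (u^2 - u^3)
residual(t) = du_const(t) - (u_const(t)^2 - u_const(t)^3)

# Collocation points (assumed uniform here; the MWE samples them quasi-randomly)
ts = range(0.0, tmax; length = 1000)

pde_loss = sum(abs2, residual.(ts)) / length(ts)
bc_loss = abs2(u_const(0.0) - δ)
total_loss = pde_loss + bc_loss

println("pde_loss   = ", pde_loss)
println("bc_loss    = ", bc_loss)
println("total_loss = ", total_loss)
```

Under these assumptions the residual of the constant function is δ² − δ³ ≈ 9.9e-5, so its squared value is about 9.8e-9, which is already below my stopping threshold of 1e-7 even though the function never rises towards 1. If the real loss is defined similarly, a loss below 1e-7 would not by itself tell me that the network has learned the correct solution.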