Save and load NeuralPDE model for postprocessing


I want to save a trained NeuralPDE model and then load it to some other script at a later time for analysis and post-processing.

In this reply, it is suggested to save the trained Lux parameters stored in res.u. I am doing that with:

using JLD2
@save "trained_model.jld2" res.u

However, the actual representation of the solution is not stored there, and I am not sure where I should feed these parameters to get the trained model. Should one recreate the entire sys, strategy, discretization, prob sequence and then use the stored pretrained parameters? This is rather slow, as it unnecessarily compiles a bunch of things that are not going to be used. The solution I am thinking of is to save discretization.phi or res.cache.f.f.phi, but can these be saved in the same way as res.u?

Thanks in advance for the help.

You just re-evaluate the neural architecture with the trained parameters.

You don't need the system, strategy, discretization, or problem to evaluate the trained model.

Take, for example, the introductory tutorial on PDEs:

The neural architecture is

dim = 2 # number of dimensions
chain = Lux.Chain(Dense(dim, 16, Lux.σ), Dense(16, 16, Lux.σ), Dense(16, 1))

We can initialize it the standard Lux way:

rng = Random.default_rng()
p, st = Lux.setup(rng, chain)
const _st = st

chain is a Lux model, called as chain(x, p, st), that returns a tuple (output, state). Your trained parameters p are res.u from before, so you just need to evaluate first(chain(x, res.u, _st)) and you're good!
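To make the calling convention concrete, here is a minimal sketch. It assumes freshly initialized parameters as a stand-in for res.u, so it runs without a trained model. Inputs are laid out as a dim × N matrix with one column per point, and the call returns (output, state), so first(...) keeps only the prediction.

```julia
using Lux, Random

dim = 2  # input dimension, as in the tutorial
chain = Lux.Chain(Dense(dim, 16, Lux.σ), Dense(16, 16, Lux.σ), Dense(16, 1))

rng = Random.default_rng()
# Freshly initialized parameters stand in for the trained res.u here
ps, st = Lux.setup(rng, chain)

# A single point is a dim × 1 matrix (one column per point)
x_point = reshape(Float32[0.5, 0.5], dim, 1)
y_point, st_new = chain(x_point, ps, st)  # the call returns (output, state)

# A batch of N points is a dim × N matrix; first(...) keeps only the output
X = rand(rng, Float32, dim, 10)
Y = first(chain(X, ps, st))

println(size(y_point), " ", size(Y))  # prints "(1, 1) (1, 10)"
```

Evaluating many points in one batched call is usually much cheaper than looping over single points, which matters when post-processing on a fine grid.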

In total this code looks like:

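using JLD2, Lux, Random, ComponentArrays
trained_p = load("trained_model.jld2")["res.u"]

# Rebuild the same architecture that was trained
dim = 2
chain = Lux.Chain(Dense(dim, 16, Lux.σ), Dense(16, 16, Lux.σ), Dense(16, 1))

rng = Random.default_rng()
p, st = Lux.setup(rng, chain)
const _st = st

# Evaluate at new points x (a dim × N matrix, one column per point);
# the model returns (output, state), so keep only the output
u_pred = first(chain(x, trained_p, _st))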

ComponentArrays is needed only because the saved trained_p is a ComponentArray by default, so that type has to be loaded to read it back. You could instead convert it to a NamedTuple before saving, which would remove that dependency.
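The NamedTuple route could look roughly like this. It is a sketch under assumptions: a ComponentArray built from freshly initialized parameters stands in for res.u, and the helper to_plain_params is hypothetical (not part of Lux or ComponentArrays). It walks the NamedTuple returned by Lux.setup as a template and copies each leaf out of the ComponentArray via property access, then saves the result with JLD2's save_object.

```julia
using JLD2, Lux, Random, ComponentArrays

dim = 2
chain = Lux.Chain(Dense(dim, 16, Lux.σ), Dense(16, 16, Lux.σ), Dense(16, 1))
ps, st = Lux.setup(Random.default_rng(), chain)
trained_p = ComponentArray(ps)  # stands in for res.u, which is a ComponentArray

# Hypothetical helper: rebuild nested NamedTuples with the same keys as
# `template`, copying each leaf array out of `src` into a plain Array
to_plain_params(template::NamedTuple, src) =
    NamedTuple{keys(template)}(map(k -> to_plain_params(template[k], getproperty(src, k)), keys(template)))
to_plain_params(template::AbstractArray, src) = copyto!(similar(template), src)

# Save only plain Julia types (NamedTuples of Arrays)
path = joinpath(mktempdir(), "trained_params.jld2")
save_object(path, to_plain_params(ps, trained_p))

# In the post-processing script, JLD2 and Lux are enough to use this file
p_loaded = load_object(path)
X = rand(Float32, dim, 5)
y_ca = first(chain(X, trained_p, st))
y_nt = first(chain(X, p_loaded, st))
```

The conversion happens once, at save time, in the script that already has ComponentArrays loaded; the post-processing script then only needs the chain definition and Lux.setup for st.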

Yes, you are right. I was somehow confused and thought that you needed discretization.phi in order to evaluate the solution. Thanks for the clarification.

By the way, why do you do const _st = st? Does it have any convenience or performance benefit?

There’s a tiny performance benefit, and a tiny improvement for Enzyme, because declaring the global const fixes its type, so the compiler can infer and optimize code that uses it. It probably doesn’t matter much, but I instinctively make sure code is always well-typed.
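The effect is visible directly in type inference. A minimal sketch, where the names st_nonconst, st_const, read_nonconst and read_const are illustrative: a function reading a non-const global can only be inferred to return Any, while the same function reading a const global infers a concrete type.

```julia
# A non-const global: its type may change, so the compiler treats it as Any
st_nonconst = (scale = 2.0,)
# A const global: its type is fixed, so code using it can be inferred
const st_const = (scale = 2.0,)

read_nonconst(x) = x * st_nonconst.scale
read_const(x) = x * st_const.scale

# Inferred return types for a Float64 argument
rt_nonconst = only(Base.return_types(read_nonconst, (Float64,)))
rt_const = only(Base.return_types(read_const, (Float64,)))
println(rt_nonconst, " ", rt_const)  # prints "Any Float64"
```

Passing st as a function argument instead of reading a global avoids the issue entirely, since the compiler then specializes on its concrete type.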