Initialize some functions at every remake in PDEProblem


Is there a way to initialize a function in MethodOfLines so that it is “compiled” once per parameter set? Here’s an example:

using ModelingToolkit, MethodOfLines, LinearAlgebra, OrdinaryDiffEq, DomainSets
using ModelingToolkit: Differential
using DifferentialEquations
using ComponentArrays
using DataInterpolations
using Memoize

@parameters t,x
@parameters Tvec
@parameters α

@variables f(..)

∂_t = Differential(t)
∂_x = Differential(x)

@memoize interpf(tvec::Vector{Float64}) = LinearInterpolation(tvec, [1.,2.,3.,4.,5.])
interpnonmem(tvec::Vector{Float64}) =  LinearInterpolation(tvec, [1.,2.,3.,4.,5.])

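function Temperature(t, x, Tvec)
    # Builds a new interpolation object on every call; swap in interpf to use the memoized version
    interp = interpnonmem(Tvec)
    return interp(x)
end
@register_symbolic Temperature(t,x,Tvec)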

eqs = [∂_x(f(t,x)) ~ α * Temperature(t,x, Tvec)]

domain = [t ∈ Interval(0.,10.), x ∈ Interval(0., 10)]
bcs = [f(t,0.) ~ 0., f(0.,x) ~ 0.]

tvec_def = [280.,290.,300.,310.,320]
α_def = 3.
ps = [Tvec => tvec_def,
        α => α_def]
@named pdesys = PDESystem(eqs, bcs, domain, [t,x], [f(t,x)], ps)

discretization = MOLFiniteDifference([x => 0.025], t)

prob = discretize(pdesys, discretization)

function pde_solution(ps = [tvec_def,α_def]; prob = prob)
    # Swap in the new parameter set without re-discretizing
    updated_prob = remake(prob, p = ps)
    # Solution of the ODE system
    sol = solve(updated_prob, TRBDF2())
    return sol
end


In this example (without Memoize), a new LinearInterpolation is built at each timestep for each grid point in x. Since it’s the same interpolation being called every time, wouldn’t there be a way to “initialize” it once?
I tried using Memoize but the performance difference was negligible. In my other “heavier” setup, the performance difference before and after adding the LinearInterpolation was huge.

Update: I’ve benchmarked Memoize on my bigger problem, and with it I get one third of the memory usage and twice the speed, so it probably works. Still, the performance gap between the version without the array parameter and the version with it is really large: I went from 62 ms per run to 3 s per run. When I replace the interpolation with ifs, it goes down to 1.4 s per run, with 30% less memory used.
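For readers wondering what “replacing the interpolation with ifs” could look like, here is a minimal sketch in plain Julia. It is not the code that was benchmarked above: the name `temperature_branches`, the fixed knots 1.0 to 5.0 (taken from the example) and the clamping outside the knot range are all assumptions made for illustration.

```julia
# Hypothetical helper: piecewise-linear interpolation over fixed knots,
# written with plain comparisons so no interpolation object is constructed.
# Assumption: values outside the knot range are clamped to the end values.
function temperature_branches(x, Tvec, knots = (1.0, 2.0, 3.0, 4.0, 5.0))
    n = length(knots)
    x <= knots[1] && return Tvec[1]   # clamp below the first knot
    x >= knots[n] && return Tvec[n]   # clamp above the last knot
    i = 1
    while x > knots[i + 1]            # find the segment containing x
        i += 1
    end
    w = (x - knots[i]) / (knots[i + 1] - knots[i])
    return (1 - w) * Tvec[i] + w * Tvec[i + 1]
end

Tvec_example = [280.0, 290.0, 300.0, 310.0, 320.0]
temperature_branches(1.5, Tvec_example)   # halfway between 280 and 290
```

Whether a function like this can be registered and traced efficiently by MethodOfLines with an array parameter is a separate question that this sketch does not answer.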

I looked into this earlier this year, and while it is a known missing feature, it seems to still be in progress.

While your problem may be different from mine, I only ever needed on the order of 10 boundary condition sets at a time. My solution was to initialize all runs with their different datasets at once in a multithreaded loop, define the interpolation inside the loop, and save the discretized problems in a vector for easy access. I could then modify non-BC parameters across all problems without re-discretizing. Some generic code similar to what I used is below; it works well enough, though it’s not very elegant.

#define some input_data matrix or data structure such that it can be indexed for each BC set.
probvec = []
solvec = []
lk = ReentrantLock() #guards the shared vectors, since push! from several threads is not thread-safe
Threads.@threads for i = 1:n_sims
     #just need to somehow index with loop such that t_input, x_input change with i
     t_input = input_data[i,1,:]
     x_input = input_data[i,2,:]
     x_BC = LinearInterpolation(x_input,t_input)

     #rest of code
     #relevant BC
     x(t,0) ~ x_BC(t)

     #more code, discretization
     prob = ModelingToolkit.discretize(pdesys,discretization);
     lock(lk) do
         push!(probvec, (i, prob)) #saved as tuple for indexing since multithreading means code can finish in random orders
     end

     sol = solve(prob) #done to run compilation and speed up future solves with different non-BC parameters
     lock(lk) do
         push!(solvec, (i, sol)) #saved to check whether solve ran properly, but not really necessary
     end
end

#strip probs out of tuple and put in correct order
#definitely could be done more elegantly, but I really just pretend to be a programmer
ordered_prob_tuples = sort(probvec, by = first) #sort on the loop index only
ordered_probs = []
for i = 1:size(ordered_prob_tuples,1)
    push!(ordered_probs, ordered_prob_tuples[i][2])
end

#remake for non-BC parameters is fast, uses i to find correct problem 
#from problem vector and remake with some set of parameters from param_vector
newprob = remake(ordered_probs[i], p = param_vector)
sol = solve(newprob)
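As a possible simplification of the bookkeeping above, the results can be written into preallocated slots so that no tuples or sorting are needed. This is a sketch in plain Julia, not the code I actually ran: `build_problem` is a hypothetical stand-in for the expensive per-dataset setup (discretization and first solve), and `n_sims = 8` is an arbitrary value.

```julia
# Hypothetical stand-in for the per-dataset setup (discretize + first solve);
# it only exists to make this sketch runnable on its own.
build_problem(i) = (id = i, payload = i^2)

n_sims = 8
# One slot per simulation. Each iteration writes only to its own index,
# so no lock is needed and the results are already in loop order.
probs = Vector{Any}(undef, n_sims)
Threads.@threads for i in 1:n_sims
    probs[i] = build_problem(i)
end
```

After the loop, `probs[i]` can be passed to `remake` in the same way as `ordered_probs[i]`.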

I still hope to figure out a better way to pass in BCs, since I will eventually want to modify those parametrically, but I haven’t had much luck yet beyond some extremely low-quality experiments with series approximations. If you end up figuring out something better, please post about it. I’d be very interested to hear!

Sadly, I’m using my model as “the last layer” of a NN with thousands of points. The scale and the problem itself rule out such a solution. For example, a 3 s runtime is currently prohibitively long, and I’ll also need to calculate the gradient…

This isn’t supported at the moment. Can you create an issue containing an example of what you’d like to work?