I don’t know about that direction, but the other direction (optimizing FEM techniques with NNs) is something people are exploring; see for example https://www.sciencedirect.com/science/article/abs/pii/S0263823118303537, and searching for that topic turns up around 20 other papers. I know the dolfin-adjoint people are getting in on this as well, but I don’t know if they’ve put out a paper quite yet.
Pretty much all PDEs are inherently stiff under some measure. CFL constants bound the maximal step size for stability, and that bound is itself a measure of stiffness. There are of course many different ways to handle this stiffness. One very common way is to use implicit methods (for certain classes of PDEs), but other choices, like multirate methods (for semi-stiff equations), do exist. That said, implicit methods have traditionally done extremely well, so I think you’ll increasingly see them mixed with neural network approaches. They will lag behind other methods in development, though, mostly because they are much more difficult to build.
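To make the step-size/stiffness point concrete, here is a minimal plain-Python sketch (illustrative only, not from any library) on the linear test equation y' = -λy with λ = 1000. Forward (explicit) Euler is only stable for h < 2/λ, a CFL-like restriction, while backward (implicit) Euler stays stable for any h > 0. The values of λ and h are arbitrary choices for the demo.

```python
# Linear test problem y' = -lam * y, exact solution y(t) = exp(-lam * t).
# Assumption: lam and the step sizes are illustrative values only.

def explicit_euler(lam, y0, h, n_steps):
    """Forward Euler: y_{n+1} = y_n + h * f(y_n) = (1 - h*lam) * y_n."""
    y = y0
    for _ in range(n_steps):
        y = y + h * (-lam * y)
    return y

def implicit_euler(lam, y0, h, n_steps):
    """Backward Euler: y_{n+1} = y_n + h * f(y_{n+1}).
    For this linear problem it solves in closed form: y_{n+1} = y_n / (1 + h*lam).
    A nonlinear problem would need a Newton solve at every step instead."""
    y = y0
    for _ in range(n_steps):
        y = y / (1 + h * lam)
    return y

lam = 1000.0   # stiff: very fast decay rate
h = 0.01       # 5x larger than the explicit stability limit 2/lam = 0.002
n = 100        # integrate to t = 1

y_exp = explicit_euler(lam, 1.0, h, n)
y_imp = implicit_euler(lam, 1.0, h, n)
print(f"explicit Euler: {y_exp:.3e}")  # amplification |1 - h*lam| = 9 per step: blows up
print(f"implicit Euler: {y_imp:.3e}")  # damping 1/11 per step: decays like the true solution
```

The price of the implicit method is the solve inside each step, which is trivial here but is exactly what makes implicit methods harder to build (and to combine with neural networks) in the general nonlinear case.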
Another place where this will likely show up is in Hessian calculations, due to gradient pathologies. The paper [2001.04536] "Understanding and mitigating gradient pathologies in physics-informed neural networks" details the gradient issues of PINNs quite well.
You definitely can, and should, differentiate numerical solvers. DifferentialEquations.jl provides differentiable numerical solvers (DiffEqSensitivity.jl) that NeuralNetDiffEq.jl and DiffEqFlux.jl build on for the neural-based methods. A lot of papers are easily generalized by thinking of them as problems of differentiable numerical methods.
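As a toy illustration of what "differentiating a numerical solver" means, the sketch below pushes forward-mode dual numbers through a fixed-step RK4 solve of y' = -p*y and reads off dy(T)/dp, which can be checked against the analytic value -T*exp(-p*T). This is a hypothetical minimal example in plain Python, not how DiffEqSensitivity.jl is implemented; the `Dual` class and `rk4_solve` helper are defined here just for the demo.

```python
import math

class Dual:
    """Forward-mode AD number val + der*eps with eps**2 == 0 (minimal: +, *, unary -)."""
    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        o = Dual.lift(other)
        return Dual(self.val + o.val, self.der + o.der)
    __radd__ = __add__

    def __mul__(self, other):
        o = Dual.lift(other)
        # Product rule carried along with the value.
        return Dual(self.val * o.val, self.der * o.val + self.val * o.der)
    __rmul__ = __mul__

    def __neg__(self):
        return Dual(-self.val, -self.der)

def rk4_solve(f, y0, p, t_end, n_steps):
    """Fixed-step classic RK4 for y' = f(y, p); works on floats or Duals."""
    h = t_end / n_steps
    y = y0
    for _ in range(n_steps):
        k1 = f(y, p)
        k2 = f(y + (h / 2) * k1, p)
        k3 = f(y + (h / 2) * k2, p)
        k4 = f(y + h * k3, p)
        y = y + (h / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
    return y

def decay(y, p):
    return -p * y  # y' = -p*y, so y(T) = exp(-p*T) and dy(T)/dp = -T*exp(-p*T)

p = Dual(1.5, 1.0)  # seed dp/dp = 1 so .der tracks d(.)/dp
yT = rk4_solve(decay, 1.0, p, 2.0, 200)
print("y(T)     =", yT.val, "exact", math.exp(-3.0))
print("dy(T)/dp =", yT.der, "exact", -2.0 * math.exp(-3.0))
```

The solver code itself never mentions derivatives: because every operation in the time-stepping loop is differentiable, the gradient comes out "for free", which is the property the neural-network methods built on these solvers rely on.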
Want to do the Multistep PINN? That method is just DiffEqFlux with a method like `VCABM3` and `adaptive=false`. How do you make it adaptive? Don’t set `adaptive=false`. How do you make it handle stiff equations? Change the ODE solver choice to a stiff solver.
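For readers who haven't seen Adams methods, here is a minimal plain-Python sketch of a fixed-step, third-order Adams-Bashforth-Moulton predictor-corrector (PECE). As I understand it, `VCABM3` is a variable-coefficient method from this same family, and the variable coefficients are what allow adaptive steps; the fixed `h` below plays the role of `adaptive=false`. This is an illustration under stated assumptions, not the DiffEq implementation: startup values come from RK4, and the function names are hypothetical.

```python
import math

def rk4_step(f, t, y, h):
    """One classic RK4 step, used only to bootstrap the multistep history."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def abm3_fixed(f, t0, y0, h, n_steps):
    """Fixed-step 3rd-order Adams-Bashforth-Moulton (PECE). Returns y at t0 + n_steps*h."""
    ts, ys = [t0], [y0]
    # A 3-step history is needed, so take the first two steps with RK4.
    for _ in range(min(2, n_steps)):
        ys.append(rk4_step(f, ts[-1], ys[-1], h))
        ts.append(ts[-1] + h)
    fs = [f(t, y) for t, y in zip(ts, ys)]
    for _ in range(n_steps - 2):
        t, y = ts[-1], ys[-1]
        # Predict with 3-step Adams-Bashforth (explicit).
        y_pred = y + h / 12 * (23 * fs[-1] - 16 * fs[-2] + 5 * fs[-3])
        # Evaluate, then correct with 2-step (3rd-order) Adams-Moulton.
        f_pred = f(t + h, y_pred)
        y_new = y + h / 12 * (5 * f_pred + 8 * fs[-1] - fs[-2])
        # Final Evaluate: store f at the corrected value for the next step.
        ts.append(t + h)
        ys.append(y_new)
        fs.append(f(t + h, y_new))
    return ys[-1]

# y' = -y, y(0) = 1, so y(1) = exp(-1); halving h should cut the error by about 2**3.
for h, n in [(0.02, 50), (0.01, 100)]:
    err = abs(abm3_fixed(lambda t, y: -y, 0.0, 1.0, h, n) - math.exp(-1.0))
    print(h, err)
```

Like any Adams method, this is explicit at heart (the corrector is only iterated once), so it inherits an explicit stability limit; that is why handling stiff equations means switching to a different solver rather than tweaking this one.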
Similarly, the UDE paper explains how this method is just a specific SDE solved with a differentiable implementation of Euler-Maruyama (`EM()` in DiffEq), so how do you generalize it to adaptive time stepping? `LambaEM()`. Stiff equations? `ImplicitEM()`, or `SROCK()`, and so on. So it turns out that many of these methods can be generalized if you just think of them as differentiable DE solvers, and the advantage is that you then get all of the solver's optimizations directly.
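The fixed-step Euler-Maruyama scheme mentioned above is short enough to write out. Below is a minimal plain-Python sketch for a scalar SDE dX = drift(X, t) dt + diffusion(X, t) dW, not the DiffEq implementation. The test problem (geometric Brownian motion with illustrative μ = 0.5, σ = 0.2) is my own choice because its exact mean E[X_T] = X_0 e^{μT} gives an easy sanity check.

```python
import math
import random

def euler_maruyama(drift, diffusion, x0, t_end, n_steps, rng):
    """Fixed-step Euler-Maruyama for dX = drift(X, t) dt + diffusion(X, t) dW."""
    h = t_end / n_steps
    x, t = x0, 0.0
    for _ in range(n_steps):
        dW = rng.gauss(0.0, math.sqrt(h))  # Brownian increment ~ N(0, h)
        x = x + drift(x, t) * h + diffusion(x, t) * dW
        t += h
    return x

# Assumed example: geometric Brownian motion dX = mu*X dt + sigma*X dW, X_0 = 1.
mu, sigma = 0.5, 0.2
rng = random.Random(0)  # fixed seed for reproducibility
samples = [
    euler_maruyama(lambda x, t: mu * x, lambda x, t: sigma * x, 1.0, 1.0, 100, rng)
    for _ in range(2000)
]
mean = sum(samples) / len(samples)
print(mean, math.exp(mu))  # Monte Carlo mean vs. exact E[X_1] = e^mu
```

The adaptive and stiff variants named in the post change how the step size and the drift update are chosen; they are not shown here.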