TL;DR: Getting reliable derivatives/gradients for training PINNs on higher-order ODEs/PDEs.
There’s a certain Diff Eq I want to solve using PINNs. Let’s say a damped harmonic oscillator:
u''(t) + \mu u'(t) + k u(t) = 0
with a Dirichlet BC u(0) = u_0 (= 0 here) and a Neumann BC u'(0) = 0.
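Putting the boundary terms and the ODE residual together, the total loss I minimise over the N collocation points t_i (with weights \lambda_{bc} and \lambda_{ode}, exactly as in the code below) is

L = |u(0) - u_0|^2 + \lambda_{bc} |u'(0)|^2 + \lambda_{ode} \frac{1}{N} \sum_{i=1}^{N} |u''(t_i) + \mu u'(t_i) + k u(t_i)|^2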
I define the architecture and loss functions as:
using Flux, Plots, BenchmarkTools, ProgressMeter
# defining the NN
n_input, n_output, n_hidden, n_layers= 1,1,32,3;
l_input= Dense(n_input, n_hidden, tanh, init=Flux.glorot_normal);
l_hidden= [Dense(n_hidden, n_hidden, tanh, init=Flux.glorot_normal) for i in 1:n_layers-1];
l_output= Dense(n_hidden, n_output, init=Flux.glorot_normal);
uNN= Chain(x->[x], l_input, l_hidden..., l_output, first); # wrap the scalar input into a 1-vector, return a scalar output
# Derivative
function diff(f, t) # central finite difference for numerical differentiation
    ϵ= cbrt(eps(t));
    return (f(t+ϵ)- f(t-ϵ))/2ϵ;
end
duNN(t)= diff(uNN, t);
d2uNN(t)= diff(duNN, t); # a direct second-order FD stencil performs worse, when it doesn't blow up
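# For reference, by a second-order FD stencil I mean (roughly) the standard
# central second difference; note the ϵ^2 in the denominator:
function diff2(f, t)
    ϵ= cbrt(eps(t));
    return (f(t+ϵ)- 2f(t)+ f(t-ϵ))/ϵ^2;
end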
# (μ and k are the ODE coefficients and u0 (= 0.) the boundary value; set before this snippet)
t_bound= 0.; # boundary point
t_coll= range(0.0, 1, length= 50); # collocation points
loss_bc1()= abs2(uNN(t_bound)- u0); # Dirichlet BC
loss_bc2()= abs2(duNN(t_bound)- 0); # Neumann BC
ode(t)= d2uNN(t)+ μ*duNN(t)+ k*uNN(t); # ODE residual
loss_ode()= Flux.mse(ode.(t_coll), zero(t_coll));
λ_bc, λ_ode= 1, 1; # loss weights
loss()= loss_bc1()+ λ_bc*loss_bc2()+ λ_ode*loss_ode();
# Training
opt= Adam(1e-1);
ps= Flux.params(uNN);
n_epochs= 5001;
losses= [loss()];
ProgressMeter.ijulia_behavior(:clear)
p=Progress(n_epochs)
for i in 1:n_epochs
    gs= Flux.gradient(loss, ps);
    Flux.Optimise.update!(opt, ps, gs);
    push!(losses, loss())
    ProgressMeter.next!(p; showvalues= [(:epoch, i), (:train_loss, losses[end-1])])
end
This soon gives a NaN training loss. My first guess was that the second derivative is blowing up because of the \epsilon^2 term in the denominator, so I tried to get some AD libraries to help me out. One of the threads mentioned TaylorDiff.jl. While it looks like it computes the right derivatives, it probably is not compatible with Flux.jl.
Replacing the derivatives with
duNN(t)= TaylorDiff.derivative((t)->uNN(t), t, 1);
d2uNN(t)= TaylorDiff.derivative((t)->uNN(t), t, 2);
gives the error:
Compiling Tuple{Type{TaylorScalar{Float64, Any}}, Float64, Float64}: type Nothing has no field code
Next I tried ForwardDiff.jl. The derivatives now look like:
duNN(t)= ForwardDiff.derivative((t)->uNN(t), t)
d2uNN(t)= ForwardDiff.derivative((t)->duNN(t), t)
This runs perfectly fine during training, except that it never converges. I tried it on a simple first-order ODE, and I think that while ForwardDiff computes the correct derivatives of the network output, it somehow messes with the gradient of the network parameters. Comparing against the FD derivative on that very simple ODE, the FD version converged while the Flux gradients got messed up with ForwardDiff.jl.
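Roughly, the kind of comparison I did looks like this (a minimal sketch with a made-up first-order ODE u'(t) = -u(t), u(0) = 1, and a small throwaway network, not my actual script):
using Flux, ForwardDiff
g= Chain(x->[x], Dense(1, 16, tanh), Dense(16, 1), first);
dg_fd(t)= (g(t+ 1e-4)- g(t- 1e-4))/2e-4;   # finite-difference derivative
dg_ad(t)= ForwardDiff.derivative(g, t);     # ForwardDiff derivative
ts= range(0.0, 1, length= 20);
loss_fd()= sum(abs2, dg_fd.(ts).+ g.(ts))+ abs2(g(0.0)- 1);  # residual of u' = -u plus the BC
loss_ad()= sum(abs2, dg_ad.(ts).+ g.(ts))+ abs2(g(0.0)- 1);
ps= Flux.params(g);
gs_fd= Flux.gradient(loss_fd, ps);  # training with these converged for me
gs_ad= Flux.gradient(loss_ad, ps);  # training with these did not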
All the while, using Flux.gradient() for the time derivatives kills my kernel even with a first-order ODE, and I could not figure out exactly why.
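(To be concrete, by "using Flux.gradient()" I mean defining the derivatives through Zygote itself, along these lines:)
duNN(t)= Flux.gradient(uNN, t)[1];   # Zygote derivative w.r.t. t
d2uNN(t)= Flux.gradient(duNN, t)[1]; # nested Zygote call for the second derivative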
Maybe not a big thing, but ForwardDiff and TaylorDiff give derivatives of the NN that do not agree down to the last floating-point digit. My guess is that Float32 is being used implicitly somewhere in one of them.
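For what it's worth, the network parameters are single precision while the collocation points are double precision (Flux's default initialisers, including glorot_normal, produce Float32 weights), which could be where the mismatch creeps in:
eltype(l_input.weight)  # Float32 -- Flux parameters are single precision by default
eltype(t_coll)          # Float64 -- the inputs are double precision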
There is probably a simple implementation that I have missed, but I could not find one on my own. Thanks in advance!
(I somehow deleted the previous post, sorry for that.)