Higher-order derivatives / automatic differentiation

TL;DR: Getting gradients to train PINNs on higher-order ODEs/PDEs.

There’s a certain differential equation I want to solve using PINNs, say a damped harmonic oscillator:
$$u''(t) + \mu u'(t) + k\,u(t) = 0$$
with a Dirichlet condition $u(0)=u_0$ and a Neumann condition $u'(0)=0$.
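For checking a trained network it helps to have the exact solution at hand. A minimal sketch, assuming the underdamped case $\mu^2 < 4k$ and the conditions $u(0)=u_0$, $u'(0)=0$: with $\omega = \sqrt{k - \mu^2/4}$, the solution is $u(t) = u_0 e^{-\mu t/2}\left(\cos\omega t + \tfrac{\mu}{2\omega}\sin\omega t\right)$. The name `u_exact` and the parameter values below are illustrative. Note that for $u_0 = 0$ the solution is identically zero.

```julia
# Exact solution of u'' + μ u' + k u = 0 with u(0) = u0, u'(0) = 0.
# Assumption: underdamped case μ^2 < 4k only.
function u_exact(t; μ, k, u0)
    μ^2 < 4k || throw(ArgumentError("sketch only covers μ^2 < 4k"))
    ω = sqrt(k - μ^2 / 4)                      # damped angular frequency
    return u0 * exp(-μ * t / 2) * (cos(ω * t) + μ / (2ω) * sin(ω * t))
end

# Usage with illustrative values: check the conditions and the ODE residual
# using central differences.
μ, k, u0 = 0.5, 4.0, 1.0
h = 1e-4
u(t) = u_exact(t; μ = μ, k = k, u0 = u0)
du(t) = (u(t + h) - u(t - h)) / 2h
d2u(t) = (u(t + h) - 2u(t) + u(t - h)) / h^2
residual = maximum(abs(d2u(t) + μ * du(t) + k * u(t)) for t in 0:0.1:1)
@show u(0.0) du(0.0) residual
```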
I define the architecture and loss functions as:

using Flux, Plots, BenchmarkTools, ProgressMeter
# defining the NN
n_input, n_output, n_hidden, n_layers= 1,1,32,3;
l_input= Dense(n_input, n_hidden, tanh, init=Flux.glorot_normal);
l_hidden= [Dense(n_hidden, n_hidden, tanh, init=Flux.glorot_normal) for i in 1:n_layers-1];
l_output= Dense(n_hidden, n_output, init=Flux.glorot_normal);
uNN= Chain([x->[x], l_input, l_hidden..., l_output, first]...);

# Derivative
function diff(f, t) # central finite difference
    ϵ= cbrt(eps(t));
    return (f(t+ϵ)- f(t-ϵ))/2ϵ;
end
duNN(t)= diff(uNN, t);
d2uNN(t)= diff(duNN, t); # creating a FD stencil for second order performs worse, when it doesn't blow up

t_bound= 0.; # boundary point
t_coll= range(0.0, 1, length= 50); #collocation points

loss_bc1()= abs2(uNN(t_bound)- u0); # Dirichlet BC; u0 must be defined beforehand
loss_bc2()= abs2(duNN(t_bound)- 0); # Neumann BC
ode(t)= d2uNN(t)+ μ*duNN(t)+ k*uNN(t); # ODE residual; μ and k must be defined beforehand
loss_ode()= Flux.mse(ode.(t_coll), zero(t_coll));
λ_bc, λ_ode= 1, 1;
loss()= loss_bc1()+ λ_bc*loss_bc2()+ λ_ode*loss_ode();

# Training
opt= Adam(1e-1);
ps= Flux.params(uNN);
n_epochs= 5001;
losses= [loss()];
p= Progress(n_epochs); # progress bar used by next! below

for i in 1:n_epochs
    gs= Flux.gradient(loss, ps);
    Flux.Optimise.update!(opt, ps, gs);
    push!(losses, loss())
    ProgressMeter.next!(p; showvalues = [(:epoch,i)])
end

This soon gives a NaN training loss. My first guess was that the second derivative blows up because of the $\epsilon^2$ term in the denominator, so I tried some AD libraries instead. One of the threads mentioned TaylorDiff.jl. While it appears to compute the right derivatives, it is probably not compatible with Flux.jl.
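A quick Base-only check of the $\epsilon^2$ guess: the nested central difference (as in `diff`/`duNN`/`d2uNN` above, with $\epsilon$ = `cbrt(eps(t))`) divides rounding error in f by roughly $(2\epsilon)^2 \approx 10^{-10}$. That is harmless in Float64 but not if f is evaluated in Float32, whose rounding error is about $10^{-8}$. Here `sin` evaluated in Float32 is an assumed stand-in for a network with Float32 weights; it is not the actual `uNN`.

```julia
# Nested central differences with the same step rule as diff() above.
function central(f, t)
    ϵ = cbrt(eps(t))
    return (f(t + ϵ) - f(t - ϵ)) / 2ϵ
end
second(f, t) = central(s -> central(f, s), t)

sin64(t) = sin(t)             # evaluated in Float64
sin32(t) = sin(Float32(t))    # same function, evaluated in Float32

t = 0.7
exact = -sin(t)               # true second derivative of sin
err64 = abs(second(sin64, t) - exact)
err32 = abs(second(sin32, t) - exact)
@show err64 err32
```

In Float32 the result is quantised in steps of several hundred, so it is either 0 or wildly wrong. That fits the NaN/blow-up behaviour if the network really runs in Float32.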
Replacing the derivatives with the TaylorDiff.jl versions

using TaylorDiff
duNN(t)= TaylorDiff.derivative((t)->uNN(t), t, 1);
d2uNN(t)= TaylorDiff.derivative((t)->uNN(t), t, 2);

gives the error:
Compiling Tuple{Type{TaylorScalar{Float64, Any}}, Float64, Float64}: type Nothing has no field code
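TaylorDiff.jl is Taylor-mode AD: it pushes truncated Taylor expansions through the network instead of taking differences, so no $\epsilon^2$ division appears. Below is a minimal hand-rolled sketch of that idea. It is not TaylorDiff's implementation or API. A hypothetical `Jet2` type carries u, u' and u'' through dense tanh layers in one forward pass. `Jet2`, `dense`, `net` and the weights are made up for illustration. Whether Flux.gradient differentiates cleanly through such a type has not been checked here.

```julia
# Second-order jet: value, first and second derivative w.r.t. one scalar input t.
struct Jet2{T<:Real}
    v::T     # u(t)
    d::T     # u'(t)
    dd::T    # u''(t)
end

# Sum, scaling and chain rule, all truncated at second order
Base.:+(a::Jet2, b::Jet2) = Jet2(a.v + b.v, a.d + b.d, a.dd + b.dd)
Base.:+(a::Jet2, c::Real) = Jet2(a.v + c, a.d, a.dd)
Base.:*(c::Real, a::Jet2) = Jet2(c * a.v, c * a.d, c * a.dd)
function Base.tanh(a::Jet2)
    y = tanh(a.v)
    g1 = 1 - y^2          # tanh'
    g2 = -2y * g1         # tanh''
    return Jet2(y, g1 * a.d, g2 * a.d^2 + g1 * a.dd)
end

# Dense layer on a vector of inputs (works for Float64 or Jet2 entries)
dense(W, b, act, x) = [act(sum(W[i, j] * x[j] for j in eachindex(x)) + b[i]) for i in axes(W, 1)]

# Hypothetical 1 -> 4 -> 4 -> 1 tanh network with fixed weights
W1, b1 = reshape([0.5, -1.0, 0.8, 1.2], 4, 1), [0.1, -0.2, 0.0, 0.3]
W2 = [0.3 -0.2 0.5 0.1; -0.4 0.6 0.2 -0.3; 0.1 0.2 -0.5 0.7; 0.6 -0.1 0.3 0.2]
b2 = [0.0, 0.1, -0.1, 0.2]
W3, b3 = [1.0 -0.5 0.7 0.4], [0.05]
net(x) = only(dense(W3, b3, identity, dense(W2, b2, tanh, dense(W1, b1, tanh, [x]))))

t = 0.3
jt = net(Jet2(t, 1.0, 0.0))   # seed: dt/dt = 1, d²t/dt² = 0
@show jt.v jt.d jt.dd
```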
Next I tried ForwardDiff.jl. The derivatives now look like:

using ForwardDiff
duNN(t)= ForwardDiff.derivative((t)->uNN(t), t)
d2uNN(t)= ForwardDiff.derivative((t)->duNN(t), t)

This runs fine during training, but it never converges. I tried it on a simple first-order ODE and suspect that, while it computes the correct derivatives with respect to t, it somehow corrupts the gradients of the network parameters. I compared it with an FD derivative on a very simple ODE: the FD version converged, while with ForwardDiff.jl the Flux gradients went wrong.
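To tell whether it is the parameter gradient that goes wrong (rather than the derivative in t), one can compare the AD gradient entry by entry with a central-difference estimate. Below is a Base-only sketch of the finite-difference side. It uses a hypothetical two-parameter model u(t) = a·tanh(bt), fitted to u'(t) = -kt as in the MWE further down, and a hand-derived gradient to compare against. `fd_param_grad`, `dudt`, `L` and `L_grad` are illustrative names. With Flux, one could flatten the network with `θ, re = Flux.destructure(u_NN)`, write the loss as a function of `re(θ)`, and compare `fd_param_grad` against the gradient of that function. That route is untested here.

```julia
# Central-difference estimate of ∂L/∂θᵢ for each entry of a flat parameter vector θ.
function fd_param_grad(L, θ::Vector{Float64}; h = cbrt(eps(Float64)))
    g = similar(θ)
    for i in eachindex(θ)
        θp = copy(θ); θp[i] += h
        θm = copy(θ); θm[i] -= h
        g[i] = (L(θp) - L(θm)) / 2h
    end
    return g
end

# Hypothetical toy model u(t) = a*tanh(b*t) with hand-derived u'(t)
k = 2.0
ts = range(0, 1; length = 20)
dudt(θ, t) = θ[1] * θ[2] * (1 - tanh(θ[2] * t)^2)
L(θ) = sum(abs2(dudt(θ, t) + k * t) for t in ts) / length(ts)   # ODE residual loss

# Hand-derived gradient of L, for comparison
function L_grad(θ)
    a, b = θ
    ga = gb = 0.0
    for t in ts
        s = 1 - tanh(b * t)^2              # sech²(bt)
        r = a * b * s + k * t              # residual at t
        ds_db = -2 * tanh(b * t) * s * t
        ga += 2r * b * s
        gb += 2r * (a * s + a * b * ds_db)
    end
    return [ga, gb] ./ length(ts)
end

θ = [0.7, -1.3]
@show fd_param_grad(L, θ) L_grad(θ)
```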
Meanwhile, using Flux.gradient() for the derivative in t kills my kernel, even for a first-order ODE, and I could not figure out why.
Maybe not a big deal, but ForwardDiff and TaylorDiff give derivatives of the NN that do not agree to the last digits. My guess is that Float32 is being used implicitly somewhere in one of them.
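For scale: a value computed in Float32 typically differs from its Float64 counterpart at the level of eps(Float32) ≈ 1.2e-7 relative, whereas eps(Float64) ≈ 2.2e-16. A small Base-only illustration, with an arbitrary tanh expression standing in for a layer (the constants are arbitrary):

```julia
x = 0.1
y64 = tanh(3.0 * x + 0.2)                  # whole computation in Float64
y32 = tanh(3.0f0 * Float32(x) + 0.2f0)     # same expression in Float32
reldiff = abs(Float64(y32) - y64) / abs(y64)
@show eps(Float32) eps(Float64) reldiff
```

A mismatch around 1e-7 between two derivative implementations would therefore be consistent with one of them running in Float32.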
There is probably a simple implementation that I have missed, but I could not find one. Thanks in advance!
(I somehow deleted the previous post, sorry for that.)

I am skeptical that the ODE loss works as written. It dominates the boundary cost, so the optimiser heads for a local minimum where u′′, u′ and u are zero everywhere except at the boundary. I tried this in this question, and following the ODE cost did not really work, even with a much smaller λ_ode. Judging from the search results, PINN implementations have been asked about several times on this Discourse; you may find your answer in one of those threads.
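One way to take the boundary terms out of the balance is a trial function that satisfies the conditions by construction, so that only the ODE loss remains. The ansatz $u(0) + t\,u_{NN}(t)$ mentioned in the following post fixes only $u(0)$. For the second-order problem with $u'(0)=0$ as well, one option is $u_0 + t^2 N(t)$: its derivative $2tN(t) + t^2 N'(t)$ vanishes at $t=0$ for any $N$. Below is a Base-only check, where `trial` and the stand-in `N` for the network output are hypothetical.

```julia
# Trial function satisfying u(0) = u0 and u'(0) = 0 for any N.
trial(N, t, u0) = u0 + t^2 * N(t)

N(t) = 1 + sin(3t) - 0.5t      # hypothetical stand-in for the network output
u0 = 2.0
h = 1e-5
u(t) = trial(N, t, u0)
du0 = (u(h) - u(-h)) / 2h      # central difference for u'(0)
@show u(0.0) du0
```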

Thanks @tomaklutfu. I realize that the $\lambda$ values I’ve chosen are not the best ones, and it does zero out everything. However, I tried other values as well, and while I don’t claim to have found the right balance, the ansatz $u(0) + t\,u_{NN}(t)$ also does not work very well. I had looked at the thread you mentioned before posting my question and could not find much help there. My bigger concern, however, is that even for a first-order derivative ForwardDiff.derivative() messes up the gradients, which is somewhat beyond my understanding. Here’s a MWE for a simpler problem, where I try to solve $u'(t) = -kt$ with $u(0) = 100$:

using Flux, Plots, BenchmarkTools, ProgressMeter, ForwardDiff, LinearAlgebra # ForwardDiff for du_NN2, LinearAlgebra for norm
k= 2.;
f(t)= -k*t;

u0= 100;

u_NN= Chain(t->[t],
    Dense(1, 15, tanh),
    Dense(15, 25, tanh),
    Dense(25, 15, tanh),
    Dense(15, 1), first); # scalar output, as in uNN above

tgrid= sort(rand(100) .*10.);
ϵ= eps(Float64)* 1e10;

# gradients ()

du_NN2(t)= ForwardDiff.derivative((t)-> u_NN(t), t);
du_NN(t)= sum(gradient((t)-> u_NN(t), t)); # sum() over the returned tuple instead of indexing, found somewhere as a way around a getindex() problem
du_NN3(t)= (u_NN(t+ ϵ)- u_NN(t- ϵ))/ 2ϵ
@show du_NN2(0.5) 
@show du_NN3(0.5)

loss_ODE()= Flux.mse(f.(tgrid), du_NN3.(tgrid));
loss_BC()= abs2(u_NN(0.)- u0)

loss()= loss_ODE()+ 0.001* loss_BC();

opt= Adam(0.01);
ps= Flux.params(u_NN)

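epochs = 5000;
lsi= loss()
losses= [lsi];
p= Progress(epochs); # progress bar used by next! below
for i in 1:epochs
    global tgrid= sort(rand(50) .*10.); # resample collocation points; global so loss_ODE sees them
    gs= gradient(loss, ps);
    Flux.Optimise.update!(opt, ps, gs);
    global lsi= loss();
    push!(losses, lsi)
    ProgressMeter.next!(p; showvalues = [(:epoch,i),
            (:grad_norm, norm(gs, 2))])
end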

# plot(losses)

xg2= 0:0.01:10;
plot(xg2, u_NN.(xg2), label= "NN soln")
plot!(xg2, -k/2 .*xg2.^2 .+ u0, label= "analytical soln") # u(t) = u0 - k t^2/2

This does not converge well when I use the ForwardDiff derivative instead:
loss_ODE()= Flux.mse(f.(tgrid), du_NN2.(tgrid));
The derivatives computed by numerical differentiation and by ForwardDiff come out very close, but the latter somehow corrupts the Flux gradients needed to train the NN. This is easy to see by logging the gradient norm for both implementations.

I compared the gradient norms of both in the for-loop, and the ForwardDiff one seems to diverge from the numerical one after some point.

Right, for now I’m going ahead with FiniteDiff.jl. Things at least seem consistent that way, until a more robust AD framework for this emerges in Julia.

I seem to remember that it is Flux that uses Float32 weights by default, if that helps.