Thanks @tomaklutfu, I realize the \lambda's I initialised are not the best ones and they do zero everything out. However, I tried other values as well, and while I don't claim to have found the right balance, the ansatz solution u(t=0) + (t - 0) u_{NN}(t) also does not work very well (a rough sketch of what I mean is included after the MWE below). I had looked at the thread you mentioned before posting my question and could not find much help there. My bigger concern, however, is that even for a first-order derivative, ForwardDiff.derivative() interferes with the gradients Flux needs for training, which goes beyond my understanding. Here is an MWE for a simpler version where I try to solve u'(t) = -k t with u(0) = 100:
using Flux, ForwardDiff, LinearAlgebra, Plots, BenchmarkTools, ProgressMeter
k= 2.;
f(t)= -k*t;
u0= 100;
u_NN= Chain(t->[t],
Dense(1, 15, tanh),
Dense(15, 25, tanh),
Dense(25, 15, tanh),
Dense(15,1),
first
)
tgrid= sort(rand(100) .*10.);
ϵ= eps(Float64)* 1e10;  # finite-difference step, ≈ 2.2e-6
# three ways to compute du_NN/dt
du_NN2(t)= ForwardDiff.derivative((t)-> u_NN(t), t);  # forward-mode AD
du_NN(t)= sum(gradient((t)-> u_NN(t), t));  # Zygote; found somewhere that sum() instead of indexing circumvents the getindex() problem
du_NN3(t)= (u_NN(t+ ϵ)- u_NN(t- ϵ))/ 2ϵ  # central finite difference
@show du_NN2(0.5)
@show du_NN3(0.5)
loss_ODE()= Flux.mse(f.(tgrid), du_NN3.(tgrid));  # ODE residual, built on the finite-difference derivative
loss_BC()= abs2(u_NN(0.)- u0)  # initial-condition penalty
loss()= loss_ODE()+ 0.001* loss_BC();
opt= Adam(0.01);
ps= Flux.params(u_NN)
epochs = 5000;
lsi= loss()
losses= [lsi];
ProgressMeter.ijulia_behavior(:clear)
p=Progress(epochs)
for i in 1:epochs
global tgrid= sort(rand(50) .*10.);  # resample collocation points every epoch
gs= gradient(loss, ps);
Flux.Optimise.update!(opt, ps, gs);
lsi= loss();
push!(losses, lsi)
ProgressMeter.next!(p; showvalues = [(:epoch,i),
(:train_loss,lsi),
(:grad_norm, norm(gs, 2))])
end
# plot(losses)
xg2= 0:0.01:10;
plot(xg2, u_NN.(xg2), label= "NN soln")
plot!(xg2, -k/2 .*xg2.^2 .+ u0, label= "analytical soln")  # u(t) = u0 - k*t^2/2
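For reference, here is roughly what I tried for the ansatz mentioned above (a rough sketch; u_trial, du_trial, and loss_ansatz are just illustrative names, everything else reuses the MWE definitions). It satisfies u(0) = u0 by construction, so the boundary-condition term drops out of the loss, but it did not train much better for me:
# rough sketch of the ansatz u(0) + t*u_NN(t); illustrative names only, rest as in the MWE
u_trial(t)= u0 + t* u_NN(t)  # u_trial(0) == u0 by construction
du_trial(t)= (u_trial(t+ ϵ)- u_trial(t- ϵ))/ 2ϵ  # same central difference as du_NN3
loss_ansatz()= Flux.mse(f.(tgrid), du_trial.(tgrid))  # no boundary-condition term needed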
The MWE above stops converging well when I swap the ForwardDiff-based derivative into the ODE loss:
loss_ODE()= Flux.mse(f.(tgrid), du_NN2.(tgrid));
The derivatives computed by numerical differencing and by ForwardDiff come out very close to each other, but the latter somehow interferes with the Flux gradients needed to train the network. One can easily collect the norm of the parameter gradients under the two loss definitions and see that this is the case.
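Something like the following rough sketch is what I mean (loss_fd, loss_fwd, and gradnorm are just illustrative names; everything else reuses the MWE definitions above):
# the two ODE losses differ only in how the derivative of the network is computed
loss_fd()= Flux.mse(f.(tgrid), du_NN3.(tgrid))+ 0.001* loss_BC()   # finite-difference derivative
loss_fwd()= Flux.mse(f.(tgrid), du_NN2.(tgrid))+ 0.001* loss_BC()  # ForwardDiff derivative

# total 2-norm of the parameter gradients, treating missing (`nothing`) entries as zero
function gradnorm(gs, ps)
    s= 0.0
    for p in ps
        g= gs[p]
        g === nothing && continue
        s+= sum(abs2, g)
    end
    return sqrt(s)
end

@show gradnorm(gradient(loss_fd, ps), ps)
@show gradnorm(gradient(loss_fwd, ps), ps)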