Partial Autodiff in NN using Flux

I have been using Flux to model a NN whose loss function depends on the NN's own output and on other functions that are themselves built from that output, in a sort of recurrent way. My data is just a regular grid of points. I have made improvements to my original code based on a previous answer found here.

I have been wrestling with my code for a couple of weeks now, and even though this seems to be a common problem structure (I have seen people working on Physics-Informed NNs construct similar problems), I cannot figure out why my loss function will not drop below a certain threshold during training. I wonder if my code is actually doing what I want.

Here it is:

using Flux, ForwardDiff

#Data
#Every initial condition is paired with every time point, so (ic, time) forms a regular 5x10 grid
time = repeat(collect(0.1:0.4:4), inner=1, outer=5)   #10 time points, tiled 5 times
ic = repeat(collect(0.1:0.2:0.9), inner=10, outer=1)  #5 initial conditions, each repeated 10 times

#Problem parameters
α = 2; C = 1; β = 0.5; P = 1; π₀ = 0.5

#Hidden layer length
len_hidden = 5

#NNs
#These are the three functions I want to approximate. 
#They map time and ic to a real
NNₓ = Chain(Dense(2,len_hidden,relu),Dense(len_hidden,1)) 
NNp = Chain(Dense(2,len_hidden,relu),Dense(len_hidden,1))
NNᵤ = Chain(Dense(2,len_hidden,relu),Dense(len_hidden,1))

#Optimizer
opt = ADAM(0.001, (0.9, 0.999))

#Parameters
Θ = params(NNₓ,NNp,NNᵤ)

#Training
L = [] #Stores the loss value at each iteration
K = 1000 #How many times I call the optimization update
for k in 1:K
    
    #Get the NN's and create one function for each 
    x(i,j) = NNₓ([i,j])[1] 
    p(i,j) = NNp([i,j])[1]
    u(i,j) = NNᵤ([i,j])[1]
    
    #Hamiltonians
    #These are also functions of the two inputs, and also contain the 3 approximations 
    #as above, x, p and u
    dHdx(i,j) = α*u(i,j) + (1 - x(i,j))*β*p(i,j)*(P*u(i,j) - π₀) - x(i,j)*β*p(i,j)*(P*u(i,j) - π₀) #Ok
    dHdp(i,j) = (1 - x(i,j))*x(i,j)*β*(P*u(i,j) - π₀) #Ok 
    dHdu(i,j) = α*x(i,j) - 2*C*u(i,j) + P*p(i,j)*β*x(i,j)*(1 - x(i,j)) #Ok

    #Here I want to compute the time derivatives of x and p
    #Since they are derivatives wrt the inputs, I am computing the AutoDiff for each with 
    #respect to the input relative to time
    dxdt(i,j) = ForwardDiff.derivative(t -> x(i,t), j)
    dpdt(i,j) = ForwardDiff.derivative(t -> p(i,t), j)
        
    #Loss function, computing for all points in the "grid" at once
    loss(ic,time) = sum((dHdp.(ic,time) .- dxdt.(ic,time)).^2 .+
                        (dpdt.(ic,time) .+ dHdx.(ic,time)).^2 .+
                        (dHdu.(ic,time)).^2 .+
                        (x.(ic,0) .- ic).^2 .+    #initial condition x(ic, 0) = ic
                        (p.(ic,4) .- 1).^2)       #terminal condition p(ic, 4) = 1
        
    grad = Flux.gradient(Θ) do
        loss(ic,time)
    end

    Flux.Optimise.update!(opt, Θ, grad)

    #Keep track of the loss so the training curve can be inspected afterwards
    push!(L, loss(ic,time))
end

The loss decreases down to a certain threshold and then starts increasing again. Even so, the variable u is far from behaving the way I expected. I know the parameters are being updated, but I suspect that either the automatic differentiation is not doing what I intend, or defining these functions inside the for loop keeps the computation from happening in the order it should.
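For what it's worth, this is the kind of spot check I have in mind for the inner derivative (a hypothetical snippet, not part of the training code above): compare the ForwardDiff time derivative of NNₓ with a central finite difference at a single point. The point (i₀, j₀) and the step h are arbitrary choices.

#Sanity check of the inner ForwardDiff derivative at one arbitrary point
i₀, j₀, h = 0.5, 1.3, 1e-4
ad = ForwardDiff.derivative(t -> NNₓ([i₀, t])[1], j₀)     #AD time derivative of the x-network
fd = (NNₓ([i₀, j₀ + h])[1] - NNₓ([i₀, j₀ - h])[1]) / (2h) #central finite difference
println("ForwardDiff: $ad   finite difference: $fd")

If those two numbers agree, the ForwardDiff part is at least computing the derivative I think it is, and my suspicion would shift to how Flux.gradient propagates through it (or to the loop structure).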

ANY comments/inputs/criticisms are welcome. Thank you!

Did you ever find a solution to this?

I would be really interested, since I am looking at a similar problem. I also want to use the gradient of the neural network as part of the loss function.
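To make it concrete, this is roughly the pattern I mean (a minimal toy sketch, not code from the post above; the network, the points ts, and the cos target are just placeholders): the scalar output of a small network is differentiated with respect to its input with ForwardDiff, and that derivative enters the loss.

using Flux, ForwardDiff

NN = Chain(Dense(1, 8, tanh), Dense(8, 1))
f(t) = NN([t])[1]                          #scalar input -> scalar output
dfdt(t) = ForwardDiff.derivative(f, t)     #derivative of the output w.r.t. the input

ts = collect(0.0:0.1:1.0)
loss() = sum(abs2, dfdt.(ts) .- cos.(ts))  #toy loss built from the input derivative

What I cannot tell is whether Flux.gradient of such a loss with respect to the network parameters propagates correctly through the ForwardDiff call, which seems to be the same question as above.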