# Partial Autodiff in NN using Flux

I have been using Flux to model a NN whose loss function depends on the network's own output and on other functions that themselves depend on that output, in a sort of recurrent way. My data is just a regular grid of points. I have made improvements to my original code based on a previous answer found here.

I have been wrestling with my code for a couple of weeks now, and even though this seems to be a common problem structure (I have seen people working on physics-informed NNs construct similar problems), my loss function is not dropping below a certain threshold during training, and I am not sure why. I wonder if my code is actually doing what I want.
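To make the core pattern concrete, here is a minimal, self-contained sketch of differentiating a two-input Flux network with respect to its second (time-like) input via `ForwardDiff.derivative` — the network, layer width, and sample point are placeholders, not the actual model:

```julia
using Flux, ForwardDiff

# Hypothetical 2-input scalar network, same shape as the ones below
nn = Chain(Dense(2, 5, relu), Dense(5, 1))
f(i, j) = nn([i, j])[1]   # extract the scalar output

# Time derivative at a sample point (i, j) = (0.5, 1.0):
# close over i and differentiate with respect to j only
dfdt = ForwardDiff.derivative(j -> f(0.5, j), 1.0)
```

Because `ForwardDiff.derivative` pushes a dual number through the closure, only the dependence on `j` contributes; `i` is held fixed by the closure.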

Here it is:

```julia
using Flux, ForwardDiff

#Data
time = repeat(collect(0.1:0.4:4), inner=1, outer=5)
ic = repeat(collect(0.1:0.2:0.9), inner=1, outer=10)

#Problem parameters
α = 2; C = 1; β = 0.5; P = 1; π₀ = 0.5

#Hidden layer length
len_hidden = 5

#NNs
#These are the three functions I want to approximate.
#They map time and ic to a real
NNₓ = Chain(Dense(2,len_hidden,relu),Dense(len_hidden,1))
NNp = Chain(Dense(2,len_hidden,relu),Dense(len_hidden,1))
NNᵤ = Chain(Dense(2,len_hidden,relu),Dense(len_hidden,1))

#Optimizer
opt = ADAM() #assumed optimiser; the original choice is not shown
#Parameters
Θ = params(NNₓ,NNp,NNᵤ)

#Training
L=[]
K=1000 #How many times I call the optimization update
for k in 1:K

#Get the NN's and create one function for each
x(i,j) = NNₓ([i,j])
p(i,j) = NNp([i,j])
u(i,j) = NNᵤ([i,j])

#Hamiltonians
#These are also functions of the two inputs, and also contain the 3 approximations
#as above, x, p and u
dHdx(i,j) = α*u(i,j) + (1 - x(i,j))*β*p(i,j)*(P*u(i,j) - π₀) - x(i,j)*β*p(i,j)*(P*u(i,j) - π₀) #Ok
dHdp(i,j) = (1 - x(i,j))*x(i,j)*β*(P*u(i,j) - π₀) #Ok
dHdu(i,j) = α*x(i,j) - 2*C*u(i,j) + P*p(i,j)*β*x(i,j)*(1 - x(i,j)) #Ok

#Here I want to compute the time derivatives of x and p
#Since they are derivatives wrt the inputs, I am computing the AutoDiff for each with
#respect to the input relative to time
dxdt(i,j) = ForwardDiff.derivative(j -> x(i,j),j)
dpdt(i,j) = ForwardDiff.derivative(j -> p(i,j),j)

#Loss function, computed for all points in the "grid" at once:
#the squared residuals of the Hamiltonian system plus the boundary terms
loss(ic,time) = sum(
    (dHdp.(ic,time) .- dxdt.(ic,time)).^2 .+  #state equation ẋ = ∂H/∂p
    (dpdt.(ic,time) .+ dHdx.(ic,time)).^2 .+  #costate equation ṗ = -∂H/∂x
    (dHdu.(ic,time)).^2 .+                    #optimality condition ∂H/∂u = 0
    (x.(ic,0) .- ic).^2 .+                    #initial condition x(ic,0) = ic
    (p.(ic,4) .- 1).^2                        #terminal condition p(ic,4) = 1
)