Loss function with derivative of output

Hi everyone. I’m trying to train an FNN whose loss function uses the derivative of the output, as well as functions that take the output of the whole NN as their arguments.

Basically, I am struggling with how to express all of this in a way Flux understands. I have spent quite some time reading Flux tutorials, searching for similar questions, and even drawing parallels with similar problems solved in Python, but my code doesn’t run.

I will reproduce my code here and try to explain what I’m trying to do, so that perhaps some of you can kindly shed some light on it.

using Flux
using ReverseDiff

##Train Data
t=vcat(0:0.1:4)
out=zeros(size(t)[1])

##Problem parameters
α=1; C=0.5; β=2; P=5; π0=1

##Building the model
len_hidden=5 #length of hidden layer

x=Chain(Dense(1,len_hidden,σ),Dense(len_hidden,3,relu))
dxdt(x)=x->ReverseDiff.gradient(x)

p=Chain(Dense(1,len_hidden,σ),Dense(len_hidden,3,relu))
dpdt(p)=p->ReverseDiff.gradient(p)

u=Chain(Dense(1,len_hidden,σ),Dense(len_hidden,3,relu))

#H
H(x,p,u)=α*u*x-C*u^2+p*β*x*(1-x)*(P*u-π0)

#Partials
dHdx(H,x)=gradient(H,x,p,u)[1]
dHdp(H,p)=gradient(H,x,p,u)[2]
dHdu(H,u)=gradient(H,x,p,u)[3]

#Full Model
#The last two terms are just impositions on initial condition for x and final condition for p
model(x,p,u)=(dxdt+dHdp)^2+(dpdt-dHdx)^2+(dHdu)^2+(x(0)-0.10)^2+(p(4)-1)^2

#Training
opt=Descent()
parameters=params(x,p,u)
data=[(t,out)]

loss(model,out)=Flux.Losses.mse(model(x,p,u),out)
Flux.train!(loss, parameters, data, opt)

My single-hidden-layer NN should take time as input and output x, p and u, all functions of time.

The loss function needs the time derivative of each output (that is, the derivative with respect to the input, which I am simply trying to obtain with ReverseDiff.gradient). It also needs the partials of the function H, which is itself a function of the outputs: H = H(x,p,u).
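For concreteness, here is a minimal sketch of the kind of derivative-with-respect-to-input I have in mind, on a toy Chain (using ForwardDiff.derivative rather than ReverseDiff, since the input is a single scalar; all names here are just illustrative):

using Flux, ForwardDiff

toy = Chain(Dense(1, 5, σ), Dense(5, 1))   # 1 input (time) -> 1 output

f(τ) = toy([τ])[1]                         # scalar wrapper: the Chain expects a vector input

dfdτ(τ) = ForwardDiff.derivative(f, τ)     # derivative of the output with respect to the input τ

dfdτ(0.5)                                  # evaluate the derivative at τ = 0.5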

I have been changing, tuning and trying to understand how to implement this correctly for a while.

General observations, tips and comments are appreciated. Even if you think I should consider another package or approach for my problem, please feel free to let me know.

Thank you for your collaboration.

Gabriel

Hi, I fixed a couple of things in my code, and even though it runs now, I’m not getting the results I should. I believe my loss function is not set up properly yet. I would also like to watch the loss decrease, to see how close I am getting to the actual solution, but I can’t find a way to monitor how it progresses during training.

Here is my new code:

using Flux, Zygote, ForwardDiff

##Data
t=vcat(0:0.1:4)

##Problem parameters
α = 2; C = 1; β = 0.5; P = 1; π₀ = 0.5

#Initial and final conditions
x₀ = 0.5
p₄ = 1
t₀ = 0
t𝔣 = 4

#Hidden layer length 
len_hidden=5

X = Chain(Dense(1,len_hidden),Dense(len_hidden,1,relu))
x(t) = (t - t₀)*X([t])[1] + x₀         # trial form: enforces x(t₀) = x₀ by construction
dxdt(t) = ForwardDiff.derivative(x,t)  # dx/dt via forward-mode AD

Ρ = Chain(Dense(1,len_hidden),Dense(len_hidden,1,relu))
p(t) = p₄ + (t - t𝔣)*Ρ([t])[1]         # trial form: enforces p(t𝔣) = p₄ by construction
dpdt(t) = ForwardDiff.derivative(p,t)  # dp/dt via forward-mode AD

U = Chain(Dense(1,len_hidden),Dense(len_hidden,1,relu))
u(t) = U([t])[1] 

Θ = Flux.params(X,Ρ,U)

H(x,p,u) = α*u*x - C*u^2 + p*β*x*(1 - x)*(P*u - π₀)

#Partials
dHdx(t) = α*u(t) + p(t)*(1 - x(t))*β*(P*u(t) - π₀) - p(t)*x(t)*β*(P*u(t) - π₀)
dHdp(t) = (1 - x(t))*x(t)*β*(P*u(t) - π₀)
dHdu(t) = α*x(t) - 2*C*u(t) + P*p(t)*β*x(t)*(1 - x(t))

#Loss function
function loss(t)
    # squared residuals of dx/dt = ∂H/∂p, dp/dt = -∂H/∂x and ∂H/∂u = 0
    return (-dxdt(t) + dHdp(t))^2 + (dpdt(t) + dHdx(t))^2 + (dHdu(t))^2
end

opt=Descent()
parameters=Θ
data=t

Flux.train!(loss, parameters, data, opt, cb = () -> println("Training"))
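As a sanity check, the hand-written partials above can be compared against automatic gradients of H at a test point (a minimal sketch, reusing the Zygote import already loaded above):

#Sanity check on the partials
t_test = 1.0
g = Zygote.gradient(H, x(t_test), p(t_test), u(t_test))  # (∂H/∂x, ∂H/∂p, ∂H/∂u) at the test point
g[1] ≈ dHdx(t_test)
g[2] ≈ dHdp(t_test)
g[3] ≈ dHdu(t_test)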

Is the way I wrote the loss function correct? For each time instant in my data vector, is the loss being computed with the updated values of each function inside loss()?

Thank you again!
Gabriel

Generally, when you want to do anything more than a simple fire-and-forget training loop, a custom loop is the way to go (see Training · Flux in the docs). This will give you easy access to the loss value.
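For example, something along these lines (a rough sketch reusing the loss, t and Θ from your post, in the implicit params style you already have; the learning rate and epoch count are placeholders):

opt = Descent(0.01)

for epoch in 1:1000
    # gradient of the mean loss over the time grid with respect to all parameters in Θ
    grads = Flux.gradient(Θ) do
        sum(loss, t) / length(t)
    end
    Flux.Optimise.update!(opt, Θ, grads)

    # print the loss every 100 epochs to watch how it decreases
    if epoch % 100 == 0
        println("epoch $epoch: loss = ", sum(loss, t) / length(t))
    end
end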


Thank you for your advice. I have implemented a new version that follows your answer, and I am much closer to the solution than I was with my original code.