NN with an ODE in the loss function

Dear all,

I’m running a simple NN that can also be understood as a linear regression problem (there are specific reasons for posing it as a NN). Instead of a classic least-squares loss, my loss function involves solving an ODE.

I’m trying to do this using Flux, and my code is below. The input data is contained in a matrix called InputData.

using DiffEqFlux, OrdinaryDiffEq, Flux, LinearAlgebra, Dierckx  # LinearAlgebra for dot, Dierckx for Spline1D

tspan = (0.0, 4.0)                  # integration time span
init_cond_range = 0.05:0.05:0.95    # initial conditions to sweep over

# Neural network: a single Dense layer, i.e. a linear map (M and N are defined elsewhere)
NN = Chain(Dense((M+1)*(N+1), 1))
W = params(NN)
opt = ADAM()

# Control approximation: linear in the input
control_app(W, Input) = dot(W, Input)

# Solve the trajectory for a given initial condition and control weights
function SolveODE(x₀, W, input)
    g(x, p, t) = 1/2*x*(1 - x)*(control_app(W, input) - 1/2)
    prob = ODEProblem(g, x₀, tspan)
    solve(prob, Tsit5(), reltol=1e-8, abstol=1e-8)
end

timesize = size(time, 1)   # time is the vector of observation times, defined elsewhere

# Optimization loop
j = 0
for x₀ in init_cond_range
    for i in (j*80 + 1):(j*80 + timesize)
        sol   = SolveODE(x₀, W[1], InputData[i, :])   # solve once and reuse the solution
        x_spl = Spline1D(sol.t, sol.u; k=3, bc="nearest", s=0.0)

        grads = Flux.gradient(W) do
            loss_sum += -0.05*(2*control_app(W[1], InputData[i, :])*x_spl(i) - control_app(W[1], InputData[i, :])^2)
        end

        Flux.Optimise.update!(opt, W, grads)
    end
end

I want to update the loss for each row of my matrix, so I compute the loss inside the innermost loop, which iterates over the rows of InputData.

I have constructed this problem in many different ways and always seem to get the same error, which is unfamiliar to me. It reads:

MethodError: no method matching +(::IRTools.Inner.Undefined, ::Float64)
Closest candidates are:
  +(::Any, ::Any, !Matched::Any, !Matched::Any...) at operators.jl:538
  +(!Matched::ChainRulesCore.One, ::Any) at /Users/Gabriel/.julia/packages/ChainRulesCore/qbmEe/src/differential_arithmetic.jl:146
  +(!Matched::Missing, ::Number) at missing.jl:115
  ...

Stacktrace:
 [1] macro expansion at /Users/Gabriel/.julia/packages/Zygote/zowrf/src/compiler/interface2.jl:0 [inlined]
 [2] _pullback(::Zygote.Context, ::typeof(+), ::IRTools.Inner.Undefined, ::Float64) at /Users/Gabriel/.julia/packages/Zygote/zowrf/src/compiler/interface2.jl:9
 [3] #9 at ./In[33]:34 [inlined]
 [4] _pullback(::Zygote.Context, ::var"#9#10"{Int64,Spline1D}) at /Users/Gabriel/.julia/packages/Zygote/zowrf/src/compiler/interface2.jl:0
 [5] pullback(::Function, ::Params) at /Users/Gabriel/.julia/packages/Zygote/zowrf/src/compiler/interface.jl:250
 [6] gradient(::Function, ::Params) at /Users/Gabriel/.julia/packages/Zygote/zowrf/src/compiler/interface.jl:58
 [7] top-level scope at In[33]:33
 [8] include_string(::Function, ::Module, ::String, ::String) at ./loading.jl:1091

I have come across this no method matching +(::IRTools.Inner.Undefined, … error many times while trying to work around this issue.

Please help me understand what I am doing wrong, or give me hints about where I should look more closely. Thank you very much,

Gabriel

What the heck is this spline for? Why not use saveat? The problem is probably the non-differentiability of sol.t.
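For example, something along these lines (a minimal sketch; the save times ts, the 0.5 initial condition, and the 0.3 parameter value are made up here):

using OrdinaryDiffEq

tspan = (0.0, 4.0)
ts = range(tspan[1], tspan[2]; length=81)   # hypothetical evaluation times, known before the solve

g(x, p, t) = 1/2*x*(1 - x)*(p - 1/2)        # p stands in for the control value
prob = ODEProblem(g, 0.5, tspan, 0.3)
sol = solve(prob, Tsit5(); saveat=ts)       # sol.u holds the solution exactly at the times in ts

sol.u[1]   # read off a saved value directly, no spline needed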

Thank you for the reply. I included the spline because I don’t want the solution only at specific time instants; I want the trajectory of x as a function of time. I will reconsider using the saveat option.

Thank you for your help.

Gabriel

sol(t) is already a function of time, and it will work well with pure autodiff, but not with adjoints, for the same reason. It will be a higher-quality interpolation than doing a direct spline interpolation over the saved points.

I don’t see you using this in a way that is incompatible with saveat though, since the i values are known before the solve.
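For reference, a minimal sketch of the dense output (the 0.5 initial condition and the 0.3 control value are placeholders):

using OrdinaryDiffEq

g(x, p, t) = 1/2*x*(1 - x)*(0.3 - 1/2)
prob = ODEProblem(g, 0.5, (0.0, 4.0))
sol = solve(prob, Tsit5(); reltol=1e-8, abstol=1e-8)

sol(2.0)   # continuous-in-time evaluation at t = 2.0, differentiable under pure autodiff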


Thank you for your time.
Gabriel