Errors when running a Universal Differential Equation (UDE)

Hello,
I am building a UDE as part of my work in Julia. I am using the following example as a reference:
https://docs.sciml.ai/Overview/stable/showcase/missing_physics/

Unfortunately, I am getting a warning message and an error during implementation. As I am new to this topic, I am not able to work out where I am going wrong. The following is the code I am using:

using OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationOptimisers, OptimizationOptimJL, LineSearches
using Statistics
using StableRNGs, JLD2, Lux, Zygote, Plots, ComponentArrays

# Set a random seed for reproducible behaviour
rng = StableRNG(11)

# Helper function: find the index at which the discharge ends (current drops to zero)
function find_discharge_end(Current_data,start=5)
    for i in start:length(Current_data)
        if abs(Current_data[i]) == 0
            return i 
        end
    end
    return -1 
end

# This function returns the discharge current value for each C-rate
function current_val(Crate)
    if Crate == "0p5C"
        return 0.5*5.0
    elseif Crate == "1C"
        return 1.0*5.0
    elseif Crate == "2C"
        return 2.0*5.0
    elseif Crate == "1p5C"
        return 1.5*5.0
    end
end


# Training conditions

Crate1,Temp1 = "1C",10
Crate2,Temp2 = "0p5C",25
Crate3,Temp3 = "2C",0
Crate4,Temp4 = "1C",25
Crate5,Temp5 = "0p5C",0
Crate6,Temp6 = "2C",10

# Loading data
data_file = load("Datasets_ashima.jld2")["Datasets"]
data1  = data_file["$(Crate1)_T$(Temp1)"] 
data2  = data_file["$(Crate2)_T$(Temp2)"]
data3  = data_file["$(Crate3)_T$(Temp3)"]
data4  = data_file["$(Crate4)_T$(Temp4)"]
data5  = data_file["$(Crate5)_T$(Temp5)"]
data6  = data_file["$(Crate6)_T$(Temp6)"]

# Finding the end of discharge index value and current value
n1,I1 = find_discharge_end(data1["current"]),current_val(Crate1)
n2,I2 = find_discharge_end(data2["current"]),current_val(Crate2)
n3,I3 = find_discharge_end(data3["current"]),current_val(Crate3)
n4,I4 = find_discharge_end(data4["current"]),current_val(Crate4)
n5,I5 = find_discharge_end(data5["current"]),current_val(Crate5)
n6,I6 = find_discharge_end(data6["current"]),current_val(Crate6)

t1,T1,T∞1 = data1["time"][2:n1],data1["temperature"][2:n1],data1["temperature"][1]
t2,T2,T∞2 = data2["time"][2:n2],data2["temperature"][2:n2],data2["temperature"][1]
t3,T3,T∞3 = data3["time"][2:n3],data3["temperature"][2:n3],data3["temperature"][1]
t4,T4,T∞4 = data4["time"][2:n4],data4["temperature"][2:n4],data4["temperature"][1]
t5,T5,T∞5 = data5["time"][2:n5],data5["temperature"][2:n5],data5["temperature"][1]
t6,T6,T∞6 = data6["time"][2:n6],data6["temperature"][2:n6],data6["temperature"][1]

# Defining the neural network
const NN = Lux.Chain(Lux.Dense(3,20,tanh),Lux.Dense(20,20,tanh),Lux.Dense(20,1)) # The const ensures type stability and prevents accidental reassignment of NN

# Get the initial parameters and state variables of the Model
para,st = Lux.setup(rng,NN)
const _st = st

# Defining the hybrid Model
function NODE_model!(du,u,p,t,T∞,I)
    Cbat  = 5*3600  # Battery capacity in As, based on nominal voltage and energy
    du[1] = -I/Cbat # Estimate the SOC of the battery

    C₁ = -0.00153 # Units: s⁻¹
    C₂ = 0.020306 # Units: K/J
    G  = I*(NN([u[1],u[2],I],p,_st)[1][1]) # Inputs to the neural network: SOC, cell temperature, current
    du[2] = (C₁*(u[2]-T∞)) + (C₂*G) # G is in W here
end

# Closure with known parameter
NODE_model1!(du,u,p,t) = NODE_model!(du,u,p,t,T∞1,I1)
NODE_model2!(du,u,p,t) = NODE_model!(du,u,p,t,T∞2,I2)
NODE_model3!(du,u,p,t) = NODE_model!(du,u,p,t,T∞3,I3)
NODE_model4!(du,u,p,t) = NODE_model!(du,u,p,t,T∞4,I4)
NODE_model5!(du,u,p,t) = NODE_model!(du,u,p,t,T∞5,I5)
NODE_model6!(du,u,p,t) = NODE_model!(du,u,p,t,T∞6,I6)

# Define the problem

prob1 = ODEProblem(NODE_model1!,[1.0,T∞1],(t1[1],t1[end]),para)
prob2 = ODEProblem(NODE_model2!,[1.0,T∞2],(t2[1],t2[end]),para)
prob3 = ODEProblem(NODE_model3!,[1.0,T∞3],(t3[1],t3[end]),para)
prob4 = ODEProblem(NODE_model4!,[1.0,T∞4],(t4[1],t4[end]),para)
prob5 = ODEProblem(NODE_model5!,[1.0,T∞5],(t5[1],t5[end]),para)
prob6 = ODEProblem(NODE_model6!,[1.0,T∞6],(t6[1],t6[end]),para)




# Function that predicts the state and calculates the loss

α = 1
function loss_NODE(θ)
    N_dataset = 6
    Solver = Tsit5()

    if α%N_dataset ==0
        _prob1 = remake(prob1,p=θ)
        sol = Array(solve(_prob1,Solver,saveat=t1,abstol=1e-6,reltol=1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        loss1 = mean(abs2,T1.-sol[2,:])
        return loss1

    elseif α%N_dataset ==1
        _prob2 = remake(prob2,p=θ)
        sol = Array(solve(_prob2,Solver,saveat=t2,abstol=1e-6,reltol=1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        loss2 = mean(abs2,T2.-sol[2,:])
        return loss2

    elseif α%N_dataset ==2
        _prob3 = remake(prob3,p=θ)
        sol = Array(solve(_prob3,Solver,saveat=t3,abstol=1e-6,reltol=1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        loss3 = mean(abs2,T3.-sol[2,:])
        return loss3

    elseif α%N_dataset ==3
        _prob4 = remake(prob4,p=θ)
        sol = Array(solve(_prob4,Solver,saveat=t4,abstol=1e-6,reltol=1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        loss4 = mean(abs2,T4.-sol[2,:])
        return loss4

    elseif α%N_dataset ==4
        _prob5 = remake(prob5,p=θ)
        sol = Array(solve(_prob5,Solver,saveat=t5,abstol=1e-6,reltol=1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        loss5 = mean(abs2,T5.-sol[2,:])
        return loss5

    elseif α%N_dataset ==5
        _prob6 = remake(prob6,p=θ)
        sol = Array(solve(_prob6,Solver,saveat=t6,abstol=1e-6,reltol=1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        loss6 = mean(abs2,T6.-sol[2,:])
        return loss6
    end
end

# Defining a callback function to monitor the training process
plot_ = plot(framestyle = :box, legend = :none, xlabel = "Iteration",ylabel = "Loss (RMSE)",title = "Neural Network Training")
itera = 0

callback = function (state,l)
    global α +=1
    global itera +=1
    colors_ = [:red,:blue,:green,:purple,:orange,:black]
    println("RMSE Loss at iteration $(itera) is $(sqrt(l)) ")
    scatter!(plot_,[itera],[sqrt(l)],markersize=4,markercolor = colors_[α%6+1])
    display(plot_)

    return false
end

# Training
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,k) -> loss_NODE(x),adtype)
optprob = Optimization.OptimizationProblem(optf,ComponentVector{Float64}(para)) # The component vector ensures that the parameters get a structured format

# Optimizing the parameters
res1 = Optimization.solve(optprob,OptimizationOptimisers.Adam(),callback=callback,maxiters = 500)
para_adam = res1.u 

First comes the following warning message:

Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).
│ 
│ 1. If this was not the desired behavior overload the dispatch on `m`.
│ 
│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt C:\Users\Kalath_A\.julia\packages\LuxCore\8mVob\ext\LuxCoreArrayInterfaceReverseDiffExt.jl:10

Then, after that, the error message pops up:

RMSE Loss at iteration 1 is 2.4709837988316155 
ERROR: UndefVarError: `dλ` not defined in local scope
Suggestion: check for an assignment to a local variable that shadows a global of the same name.
Stacktrace:
  [1] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::QuadratureAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, callback::Nothing, kwargs::@Kwargs{…})
    @ SciMLSensitivity C:\Users\Kalath_A\.julia\packages\SciMLSensitivity\RQ8Av\src\quadrature_adjoint.jl:402
  [2] _adjoint_sensitivities
    @ C:\Users\Kalath_A\.julia\packages\SciMLSensitivity\RQ8Av\src\quadrature_adjoint.jl:337 [inlined]
  [3] #adjoint_sensitivities#63
    @ C:\Users\Kalath_A\.julia\packages\SciMLSensitivity\RQ8Av\src\sensitivity_interface.jl:401 [inlined]
  [4] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#323"{…})(Δ::ODESolution{…})
    @ SciMLSensitivity C:\Users\Kalath_A\.julia\packages\SciMLSensitivity\RQ8Av\src\concrete_solve.jl:627
  [5] ZBack
    @ C:\Users\Kalath_A\.julia\packages\Zygote\TWpme\src\compiler\chainrules.jl:212 [inlined]
  [6] (::Zygote.var"#kw_zpullback#56"{…})(dy::ODESolution{…})
    @ Zygote C:\Users\Kalath_A\.julia\packages\Zygote\TWpme\src\compiler\chainrules.jl:238
  [7] #295
    @ C:\Users\Kalath_A\.julia\packages\Zygote\TWpme\src\lib\lib.jl:205 [inlined]
  [8] (::Zygote.var"#2169#back#297"{…})(Δ::ODESolution{…})
    @ Zygote C:\Users\Kalath_A\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:72
  [9] #solve#51
    @ C:\Users\Kalath_A\.julia\packages\DiffEqBase\R2Vjs\src\solve.jl:1038 [inlined]
 [10] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote C:\Users\Kalath_A\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
 [11] #295
    @ C:\Users\Kalath_A\.julia\packages\Zygote\TWpme\src\lib\lib.jl:205 [inlined]
 [12] (::Zygote.var"#2169#back#297"{…})(Δ::ODESolution{…})
    @ Zygote C:\Users\Kalath_A\.julia\packages\ZygoteRules\CkVIK\src\adjoint.jl:72
 [13] solve
    @ C:\Users\Kalath_A\.julia\packages\DiffEqBase\R2Vjs\src\solve.jl:1028 [inlined]
 [14] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote C:\Users\Kalath_A\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
 [15] loss_NODE
    @ c:\Users\Kalath_A\OneDrive - University of Warwick\PhD\ML Notebooks\Neural ODE\Julia\T Mixed\With Qgen multiplied with I\updated_code.jl:128 [inlined]
 [16] (::Zygote.Pullback{Tuple{typeof(loss_NODE), ComponentVector{Float64, Vector{…}, Tuple{…}}}, Any})(Δ::Float64)
    @ Zygote C:\Users\Kalath_A\.julia\packages\Zygote\TWpme\src\compiler\interface2.jl:0
 [17] #13
    @ c:\Users\Kalath_A\OneDrive - University of Warwick\PhD\ML Notebooks\Neural ODE\Julia\T Mixed\With Qgen multiplied with I\updated_code.jl:169 [inlined]
 [18] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote C:\Users\Kalath_A\.julia\packages\Zygote\TWpme\src\compiler\interface.jl:91
 [19] withgradient(::Function, ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}}, ::Vararg{Any})
    @ Zygote C:\Users\Kalath_A\.julia\packages\Zygote\TWpme\src\compiler\interface.jl:213
 [20] value_and_gradient
    @ C:\Users\Kalath_A\.julia\packages\DifferentiationInterface\TtV2Z\ext\DifferentiationInterfaceZygoteExt\DifferentiationInterfaceZygoteExt.jl:118 [inlined]
 [21] value_and_gradient!(f::Function, grad::ComponentVector{…}, prep::DifferentiationInterface.NoGradientPrep, backend::AutoZygote, x::ComponentVector{…}, contexts::DifferentiationInterface.Constant{…})
    @ DifferentiationInterfaceZygoteExt C:\Users\Kalath_A\.julia\packages\DifferentiationInterface\TtV2Z\ext\DifferentiationInterfaceZygoteExt\DifferentiationInterfaceZygoteExt.jl:143
 [22] (::OptimizationZygoteExt.var"#fg!#16"{…})(res::ComponentVector{…}, θ::ComponentVector{…})
    @ OptimizationZygoteExt C:\Users\Kalath_A\.julia\packages\OptimizationBase\gvXsf\ext\OptimizationZygoteExt.jl:53
 [23] macro expansion
    @ C:\Users\Kalath_A\.julia\packages\OptimizationOptimisers\xC7Ic\src\OptimizationOptimisers.jl:101 [inlined]
 [24] macro expansion
    @ C:\Users\Kalath_A\.julia\packages\Optimization\6Asog\src\utils.jl:32 [inlined]
 [25] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers C:\Users\Kalath_A\.julia\packages\OptimizationOptimisers\xC7Ic\src\OptimizationOptimisers.jl:83
 [26] solve!(cache::OptimizationCache{…})
    @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\3fgw8\src\solve.jl:187
 [27] solve(::OptimizationProblem{…}, ::Optimisers.Adam; kwargs::@Kwargs{…})
    @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\3fgw8\src\solve.jl:95
 [28] top-level scope
    @ c:\Users\Kalath_A\OneDrive - University of Warwick\PhD\ML Notebooks\Neural ODE\Julia\T Mixed\With Qgen multiplied with I\updated_code.jl:173
Some type information was truncated. Use `show(err)` to see complete types.

Does anyone know why this warning and error message pop up? I am following the UDE example mentioned earlier as a reference, and that example runs without any errors. In the example, Vern7() is used to solve the ODE; I tried that too, but the same warning and error appear. I am reading up on the theory to see whether learning more about Automatic Differentiation (AD) would help in debugging this.

Any help would be much appreciated

I don’t know anything about UDEs, but this part looks like it could cause some issues. Do the examples really use global variables like that?
You might want to use a loop instead of global itera += 1, and maybe pass α as an input to the loss function; see the sketch below.
The code has many global variables in general, which could be the cause of the issues.
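
For example, here is a rough, untested sketch of what I mean. Instead of cycling through the datasets with a counter, you could sum the loss over all six of them, so no counter is needed at all. The names probs, ts, and Ts are placeholders grouping the problems, time vectors, and temperature series you already defined.

# Untested sketch: placeholder names probs/ts/Ts group the objects defined above
probs = (prob1, prob2, prob3, prob4, prob5, prob6)
ts    = (t1, t2, t3, t4, t5, t6)
Ts    = (T1, T2, T3, T4, T5, T6)

function loss_NODE(θ)
    total = zero(eltype(θ))
    for i in 1:6
        _prob = remake(probs[i], p = θ)
        sol = Array(solve(_prob, Tsit5(), saveat = ts[i], abstol = 1e-6, reltol = 1e-6,
                          sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        total += mean(abs2, Ts[i] .- sol[2, :]) # accumulate the per-dataset MSE
    end
    return total
end

Note that summing over all datasets per iteration is a different training scheme from rotating through them one at a time, but it removes the global state entirely.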

I updated the code such that I am not defining any global variables.

function loss_NODE(θ)
    N_dataset = 6
    Solver = Tsit5()
    α = length(losses) + 1
    if α%N_dataset ==0
        _prob1 = remake(prob1,p=θ)
        sol = Array(solve(_prob1,Solver,saveat=t1,abstol=1e-6,reltol=1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        loss1 = mean(abs2,T1.-sol[2,:])
        return loss1

    elseif α%N_dataset ==1
        _prob2 = remake(prob2,p=θ)
        sol = Array(solve(_prob2,Solver,saveat=t2,abstol=1e-6,reltol=1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        loss2 = mean(abs2,T2.-sol[2,:])
        return loss2

    elseif α%N_dataset ==2
        _prob3 = remake(prob3,p=θ)
        sol = Array(solve(_prob3,Solver,saveat=t3,abstol=1e-6,reltol=1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        loss3 = mean(abs2,T3.-sol[2,:])
        return loss3

    elseif α%N_dataset ==3
        _prob4 = remake(prob4,p=θ)
        sol = Array(solve(_prob4,Solver,saveat=t4,abstol=1e-6,reltol=1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        loss4 = mean(abs2,T4.-sol[2,:])
        return loss4

    elseif α%N_dataset ==4
        _prob5 = remake(prob5,p=θ)
        sol = Array(solve(_prob5,Solver,saveat=t5,abstol=1e-6,reltol=1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        loss5 = mean(abs2,T5.-sol[2,:])
        return loss5

    elseif α%N_dataset ==5
        _prob6 = remake(prob6,p=θ)
        sol = Array(solve(_prob6,Solver,saveat=t6,abstol=1e-6,reltol=1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        loss6 = mean(abs2,T6.-sol[2,:])
        return loss6
    end
end

# Defining a callback function to monitor the training process
plot_ = plot(framestyle = :box, legend = :none, xlabel = "Iteration",ylabel = "Loss (RMSE)",title = "Neural Network Training")

losses = Float64[]
callback = function (state,l)
    
    push!(losses,l)
    α = length(losses)
    itera = length(losses)
    colors_ = [:red,:blue,:green,:purple,:orange,:black]
    println("RMSE Loss at iteration $(itera) is $(sqrt(l)) ")
    scatter!(plot_,[itera],[sqrt(l)],markersize=4,markercolor = colors_[α%6+1])
    display(plot_)

    return false
end

# Training
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss_NODE(x),adtype)
optprob = Optimization.OptimizationProblem(optf,ComponentVector{Float64}(para)) # The component vector ensures that the parameters get a structured format

# Optimizing the parameters
res1 = Optimization.solve(optprob,OptimizationOptimisers.Adam(),callback=callback,maxiters = 500)
para_adam = res1.u 

This is the updated code. I defined a variable named losses to keep track of the iterations, but the same error occurs.

ERROR: UndefVarError: `dλ` not defined in local scope
Suggestion: check for an assignment to a local variable that shadows a global of the same name.

I wondered what dλ is. I cannot find any info on that.

The error comes from SciMLSensitivity.jl/src/quadrature_adjoint.jl (commit f9e7b58ce6d9c8c20ac96510d1a67d91f0f5bba9) in the SciML/SciMLSensitivity.jl repository on GitHub.

I think the best approach is to create an issue on the GitHub repository and ask the developers directly why it is not working.

Other tips:

  • The example is still using global variables. I would consider any variable that is not passed to a function as an input, or constructed inside the function, to be a global variable. This includes variables like prob1 and losses.
  • It will be easier to support you if you provide a minimal working example (MWE). This is a small example that reproduces the error and that others can run on their own computer. Your current example cannot be run without the data; a sketch of what such an MWE might look like is given after this list.
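
As a purely illustrative, untested sketch: something along these lines, with synthetic data in place of the JLD2 file and a single operating condition, should be enough for others to reproduce the adjoint error. The synthetic temperature curve is made up; the model constants are copied from your post.

using OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationOptimisers
using Lux, Zygote, ComponentArrays, StableRNGs, Statistics

rng = StableRNG(11)

# Synthetic "measurement": made-up temperature rise replacing the JLD2 data
t_data = collect(0.0:10.0:3600.0)
T_data = 25.0 .+ 5.0 .* (1 .- exp.(-t_data ./ 1000.0))
T∞, I = 25.0, 5.0

NN = Lux.Chain(Lux.Dense(3,20,tanh), Lux.Dense(20,20,tanh), Lux.Dense(20,1))
para, st = Lux.setup(rng, NN)

function rhs!(du, u, p, t)
    Cbat  = 5*3600
    du[1] = -I/Cbat
    G     = I*(NN([u[1], u[2], I], p, st)[1][1])
    du[2] = -0.00153*(u[2] - T∞) + 0.020306*G
end

prob = ODEProblem(rhs!, [1.0, T∞], (t_data[1], t_data[end]), para)

function loss(θ)
    sol = Array(solve(remake(prob, p = θ), Tsit5(), saveat = t_data,
                      abstol = 1e-6, reltol = 1e-6,
                      sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
    return mean(abs2, T_data .- sol[2, :])
end

optf    = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(para))
res     = Optimization.solve(optprob, OptimizationOptimisers.Adam(), maxiters = 10)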

Thank you for your suggestions. I will provide a minimal working example (MWE) in another post to see if others also have the same issue as me :slight_smile:

I’ll need a reproducer for this if it’s hitting the branch for handling dt=0. It’s a very odd case.