Improving the speed for the forward solve of a Universal Differential Equation (UDE)

Hello everyone,
I have a UDE here, and I am trying to improve its speed during the forward solve so that the optimization is also faster. Right now optimization part is very slow. Here is a working code you can try out

using DifferentialEquations
using StableRNGs, Lux
using ComponentArrays,SparseArrays

#Defining the neural network
U = Lux.Chain(Lux.Dense(3,30,tanh),Lux.Dense(30,30,tanh),Lux.Dense(30,1))
rng = StableRNG(1111)
_para,st = Lux.setup(rng,U)
const _st_ = st

#Setting the number of states
n = 5

# Finding the parameters
d   = 21e-3 # diameters in m
r   = d/2
L   = 70e-3 # radius in m
Δr = r/n

const k = 1.05
const rho_c = 2.85e6
const h = 5.0
const Cbat = 5*3600

# Initializing
AL = zeros(n)
AB = zeros(n)
V  = zeros(n)

for i in 1:n
        AL[i] = 2*π*i*Δr*L
        AB[i] = π*(Δr^2)*(i^2 - ((i-1)^2))
        V[i]  = AB[i]*L
end

# Precomputing M and B
M = zeros(n,n)
B = zeros(n)



for i in 1:n
    if i ==1
        M[1,1] = (((-k*AL[1])/(1.5*Δr))-(2*h*AB[1]))/(rho_c*V[1])
        M[1,2] = (k*AL[1])/(1.5*Δr*rho_c*V[1])
        B[1]   = (2*h*AB[1])/(rho_c*V[1])
    
    elseif i == 2
        M[2,2] = (((-k*AL[1])/(1.5*Δr))-(2*h*AB[2])-((k*AL[2])/Δr))/(rho_c*V[2])
        M[2,1] = (k*AL[1])/(1.5*Δr*rho_c*V[2])
        M[2,3] = (k*AL[2])/(rho_c*V[2]*Δr)
        B[2]   = (2*h*AB[2])/(rho_c*V[2])
    
    elseif i == n
        M[n,n]   = (((-k*AL[n-1])/Δr)-(2*h*AB[n])-(1.5*h*AL[n]))/(rho_c*V[n])
        M[n,n-1] = (((k*AL[n-1])/Δr)+(0.5*h*AL[n]))/(rho_c*V[n])
        B[n]     = ((2*h*AB[n])+(h*AL[n]))/(rho_c*V[n])

    else 
        M[i,i]   = (((-k*AL[i-1])/Δr) - ((k*AL[i])/Δr)-(2*h*AB[i]))/(rho_c*V[i])
        M[i,i-1] = (k*AL[i-1])/(Δr*rho_c*V[i])
        M[i,i+1] = (k*AL[i])/(Δr*rho_c*V[i])
        B[i]     = (2*h*AB[i])/(rho_c*V[i])

    end
end
M = sparse(M)

# Defining the ODE model
function UDE_model!(du,u,p,t,n,M,B,T∞,I)
        


    #Neural Network
    G = [(sign(I)^2)*((U([u[n+1],u[i],I],p,_st_)[1][1])^2) for i in 1:n]

    du[1:n] .= M*u[1:n] + B.*T∞ + G./rho_c

   
    du[n+1] = -I/Cbat

end

T∞1 = 282.553
T∞2 = 297.67
T∞3 = 297.664

I1 = 5.0
I2 = 5.0
I3 = 10.0

t1 = collect(99076.0:1:102233.0)
t2 = collect(79053.0:1:82426.0)
t3 = collect(105517.0:1:107199.0)

T01 = fill(T∞1,n)
T02 = fill(T∞2,n)
T03 = fill(T∞3,n)

u01 = vcat(T01,1.0)
u02 = vcat(T02,1.0)
u03 = vcat(T03,1.0)



UDE_model1!(du,u,p,t) = UDE_model!(du,u,p,t,n,M,B,T∞1,I1)
UDE_model2!(du,u,p,t) = UDE_model!(du,u,p,t,n,M,B,T∞2,I2)
UDE_model3!(du,u,p,t) = UDE_model!(du,u,p,t,n,M,B,T∞3,I3)


prob1 = ODEProblem(UDE_model1!,u01,(t1[1],t1[end]),_para)
prob2 = ODEProblem(UDE_model2!,u02,(t2[1],t2[end]),_para)
prob3 = ODEProblem(UDE_model3!,u03,(t3[1],t3[end]),_para)


@time inisol1 = solve(prob1,Rosenbrock23(),saveat = t1)
@time inisol2 = solve(prob2,Rosenbrock23(),saveat = t2)
@time inisol3 = solve(prob3,Rosenbrock23(),saveat = t3)

Right now the speed is

inisol1 -> 13.82 k allocations: 1.094 MiB
inisol2 -> 15.64 k allocations: 1.242 MiB
inisol3 -> 10.87 k allocations: 891.984 KiB

My collegue who works on UDE said the allocations are high now.
So I need some tips on improving it so that my optimization is faster. Any help would be much appreciated.

the fastest I could get is :

using DifferentialEquations
using StableRNGs, Lux
using ComponentArrays,SparseArrays
using PreallocationTools

struct params{MO,P,S,MT,BT,T,C,C2}
    model ::MO
    ps ::P
    st ::S
    n ::Int
    M ::MT 
    B ::BT
    T∞ ::T
    I ::Int
    cache ::C
    cache2 ::C2
end

# Defining the ODE model
function UDE_model!(du,u,p,t)
    # Extracting parameters
    n = p.n
    M = p.M
    B = p.B
    T∞ = p.T∞
    I = p.I
    rho_c = 2.85e6
    Cbat = 5*3600
    cache = get_tmp(p.cache,u)
    mul!(cache,M,@view(u[1:n]))
    cache2 = get_tmp(p.cache2,u)
    @views cache2[1,:] .= u[n+1]
    @views cache2[2,:] .= u[1:n]
    @views cache2[3,:] .= I 
    C,_ = p.model(cache2,p.ps,p.st)
    C .= C.^2
    C .= C .* sign(p.I)^2
    for i in 1:n
        du[i] = C[1,i]/rho_c + B[i] * T∞ + cache[i]
    end
    du[n+1] = -I/Cbat
    nothing
end

function initialize()
    #Defining the neural network
    U = Lux.Chain(Lux.Dense(3,30,tanh),Lux.Dense(30,30,tanh),Lux.Dense(30,1))
    rng = StableRNG(1111)
    _para,st = Lux.setup(rng,U)
    _para = f64(_para)
    #Setting the number of states
    n = 5
    # Finding the parameters
    d   = 21e-3 # diameters in m
    r   = d/2
    L   = 70e-3 # radius in m
    Δr = r/n
    k = 1.05
    rho_c = 2.85e6
    h = 5.0
    Cbat = 5*3600
    # Initializing
    AL = zeros(n)
    AB = zeros(n)
    V  = zeros(n)
    for i in 1:n
            AL[i] = 2*π*i*Δr*L
            AB[i] = π*(Δr^2)*(i^2 - ((i-1)^2))
            V[i]  = AB[i]*L
    end
    # Precomputing M and B
    M = zeros(n,n)
    B = zeros(n)
    for i in 1:n
        if i ==1
            M[1,1] = (((-k*AL[1])/(1.5*Δr))-(2*h*AB[1]))/(rho_c*V[1])
            M[1,2] = (k*AL[1])/(1.5*Δr*rho_c*V[1])
            B[1]   = (2*h*AB[1])/(rho_c*V[1])
        elseif i == 2
            M[2,2] = (((-k*AL[1])/(1.5*Δr))-(2*h*AB[2])-((k*AL[2])/Δr))/(rho_c*V[2])
            M[2,1] = (k*AL[1])/(1.5*Δr*rho_c*V[2])
            M[2,3] = (k*AL[2])/(rho_c*V[2]*Δr)
            B[2]   = (2*h*AB[2])/(rho_c*V[2])
        
        elseif i == n
            M[n,n]   = (((-k*AL[n-1])/Δr)-(2*h*AB[n])-(1.5*h*AL[n]))/(rho_c*V[n])
            M[n,n-1] = (((k*AL[n-1])/Δr)+(0.5*h*AL[n]))/(rho_c*V[n])
            B[n]     = ((2*h*AB[n])+(h*AL[n]))/(rho_c*V[n])

        else 
            M[i,i]   = (((-k*AL[i-1])/Δr) - ((k*AL[i])/Δr)-(2*h*AB[i]))/(rho_c*V[i])
            M[i,i-1] = (k*AL[i-1])/(Δr*rho_c*V[i])
            M[i,i+1] = (k*AL[i])/(Δr*rho_c*V[i])
            B[i]     = (2*h*AB[i])/(rho_c*V[i])

        end
    end
    M = sparse(M)
    T∞1 = 282.553
    T∞2 = 297.67
    T∞3 = 297.664
    I1 = 5
    I2 = 5
    I3 = 10
    cache1 = DiffCache(zeros(n))
    cache12 = DiffCache(zeros(3,n))
    cache2 = DiffCache(zeros(n))
    cache22 = DiffCache(zeros(3,n))
    cache3 = DiffCache(zeros(n))
    cache32 = DiffCache(zeros(3,n))
    p1 = params(U,_para,st,n,M,B,T∞1,I1,cache1,cache12)
    p2 = params(U,_para,st,n,M,B,T∞2,I2,cache2,cache22)
    p3 = params(U,_para,st,n,M,B,T∞3,I3,cache3,cache32)
    t1 = collect(99076.0:1:102233.0)
    t2 = collect(79053.0:1:82426.0)
    t3 = collect(105517.0:1:107199.0)
    T01 = fill(T∞1,n)
    T02 = fill(T∞2,n)
    T03 = fill(T∞3,n)
    u01 = vcat(T01,1.0)
    u02 = vcat(T02,1.0)
    u03 = vcat(T03,1.0)
    prob1 = ODEProblem(UDE_model!,u01,(t1[1],t1[end]),p1)
    prob2 = ODEProblem(UDE_model!,u02,(t2[1],t2[end]),p2)
    prob3 = ODEProblem(UDE_model!,u03,(t3[1],t3[end]),p3)
    return prob1, prob2, prob3, t1, t2, t3
end

prob1, prob2, prob3, t1, t2, t3 = initialize()
@time inisol1 = solve(prob1,Rosenbrock23(),saveat = t1);
@time inisol2 = solve(prob2,Rosenbrock23(),saveat = t2);
@time inisol3 = solve(prob3,Rosenbrock23(),saveat = t3);

however you will need to implement the interface if you want to diff through that :

you can use const var if you want but avoid the closure you did that’s as bad as having global, also I’m not sure caching everything is good when you will diff through it you may need to avoid some caching I did because they may hide the derivatives.

Finally it is still quite a mess I think I’ve never been able to write the interface for ML related stuff like why this not work :

using DifferentialEquations
using StableRNGs, Lux
using ComponentArrays,SparseArrays
using PreallocationTools
using SciMLSensitivity
import SciMLStructures as SS
using LinearAlgebra
using Zygote
using Parameters

mutable struct params{MO,P,S,MT,BT,T,C,C2}
    model ::MO
    ps ::P
    st ::S
    n ::Int
    M ::MT 
    B ::BT
    T∞ ::T
    I ::Int
    cache ::C
    cache2 ::C2
end
SS.isscimlstructure(::params) = true
SS.ismutablescimlstructure(::params) = true
SS.hasportion(::SS.Tunable, ::params) = true
function SS.canonicalize(::SS.Tunable, p::params)
  buffer = copy(p.ps)
  repack = let p = p
    function repack(newbuffer)
      SS.replace(SS.Tunable(), p, newbuffer)
    end
  end
  return buffer, repack, false
end
function SS.replace(::SS.Tunable, p::params, newbuffer)
  return params(
    p.model,newbuffer,p.st,p.n,p.M,p.B,p.T∞,p.I,p.cache,p.cache2
  )
end

function SS.replace!(::SS.Tunable, p::params, newbuffer)
    p.ps = newbuffer
    return p
end

function initialize()
    #Defining the neural network
    U = Lux.Chain(Lux.Dense(3,30,tanh),Lux.Dense(30,30,tanh),Lux.Dense(30,1))
    rng = StableRNG(1111)
    _para,st = Lux.setup(rng,U)
    _para = ComponentArray(f64(_para))
    #Setting the number of states
    n = 5
    # Finding the parameters
    d   = 21e-3 # diameters in m
    r   = d/2
    L   = 70e-3 # radius in m
    Δr = r/n
    k = 1.05
    rho_c = 2.85e6
    h = 5.0
    Cbat = 5*3600
    # Initializing
    AL = zeros(n)
    AB = zeros(n)
    V  = zeros(n)
    for i in 1:n
            AL[i] = 2*π*i*Δr*L
            AB[i] = π*(Δr^2)*(i^2 - ((i-1)^2))
            V[i]  = AB[i]*L
    end
    # Precomputing M and B
    M = zeros(n,n)
    B = zeros(n)
    for i in 1:n
        if i ==1
            M[1,1] = (((-k*AL[1])/(1.5*Δr))-(2*h*AB[1]))/(rho_c*V[1])
            M[1,2] = (k*AL[1])/(1.5*Δr*rho_c*V[1])
            B[1]   = (2*h*AB[1])/(rho_c*V[1])
        elseif i == 2
            M[2,2] = (((-k*AL[1])/(1.5*Δr))-(2*h*AB[2])-((k*AL[2])/Δr))/(rho_c*V[2])
            M[2,1] = (k*AL[1])/(1.5*Δr*rho_c*V[2])
            M[2,3] = (k*AL[2])/(rho_c*V[2]*Δr)
            B[2]   = (2*h*AB[2])/(rho_c*V[2])
        
        elseif i == n
            M[n,n]   = (((-k*AL[n-1])/Δr)-(2*h*AB[n])-(1.5*h*AL[n]))/(rho_c*V[n])
            M[n,n-1] = (((k*AL[n-1])/Δr)+(0.5*h*AL[n]))/(rho_c*V[n])
            B[n]     = ((2*h*AB[n])+(h*AL[n]))/(rho_c*V[n])

        else 
            M[i,i]   = (((-k*AL[i-1])/Δr) - ((k*AL[i])/Δr)-(2*h*AB[i]))/(rho_c*V[i])
            M[i,i-1] = (k*AL[i-1])/(Δr*rho_c*V[i])
            M[i,i+1] = (k*AL[i])/(Δr*rho_c*V[i])
            B[i]     = (2*h*AB[i])/(rho_c*V[i])

        end
    end
    M = sparse(M)
    T∞1 = 282.553
    T∞2 = 297.67
    T∞3 = 297.664
    I1 = 5
    I2 = 5
    I3 = 10
    cache1 = DiffCache(zeros(n))
    cache12 = DiffCache(zeros(3,n))
    cache2 = DiffCache(zeros(n))
    cache22 = DiffCache(zeros(3,n))
    cache3 = DiffCache(zeros(n))
    cache32 = DiffCache(zeros(3,n))
    p1 = params(U,_para,st,n,M,B,T∞1,I1,cache1,cache12)
    p2 = params(U,_para,st,n,M,B,T∞2,I2,cache2,cache22)
    p3 = params(U,_para,st,n,M,B,T∞3,I3,cache3,cache32)
    t1 = collect(99076.0:1:102233.0)
    t2 = collect(79053.0:1:82426.0)
    t3 = collect(105517.0:1:107199.0)
    T01 = fill(T∞1,n)
    T02 = fill(T∞2,n)
    T03 = fill(T∞3,n)
    u01 = vcat(T01,1.0)
    u02 = vcat(T02,1.0)
    u03 = vcat(T03,1.0)
    prob1 = ODEProblem(UDE_model!,u01,(t1[1],t1[end]),p1)
    prob2 = ODEProblem(UDE_model!,u02,(t2[1],t2[end]),p2)
    prob3 = ODEProblem(UDE_model!,u03,(t3[1],t3[end]),p3)
    return prob1, prob2, prob3, t1, t2, t3, u01, u02, u03, p1, p2, p3
end


# Defining the ODE model
function UDE_model!(du,u,p,t)
    # Extracting parameters
    Parameters.@unpack model,ps,st,n,M,B,T∞,I,cache,cache2 = p
    rho_c = 2.85e6
    Cbat = 5*3600
    cache = get_tmp(cache,u)
    mul!(cache,M,@view(u[1:n]))
    cache2 = get_tmp(cache2,u)
    C = model(cache2,ps,st)[1].^2 .* (sign(I)^2)
    @views du[1:n] .= C[1,:]./rho_c .+ B .* T∞ .+ cache
    du[n+1] = -I/Cbat
    nothing
end


prob1, prob2, prob3, t1, t2, t3, u01, u02, u03, p1, p2, p3 = initialize()
@time inisol1 = solve(prob1,Rosenbrock23(),saveat = t1);
@time inisol2 = solve(prob2,Rosenbrock23(),saveat = t2);
@time inisol3 = solve(prob3,Rosenbrock23(),saveat = t3);

function run_diff(ps)
    p = p1
    p.ps = ps
    prob = remake(prob1, p = p)
    sol = solve(prob, Rosenbrock23(), saveat = t1)
    return sol.u |> last |> sum
end

run_diff(p1.ps)
Zygote.gradient(run_diff, p1.ps)

for some reason the solve call similar on the structure and if I implement it it will call length which is nonsense

Using Rosenbrock23 is almost never a good idea with training. It didn’t have the right properties. The title is also a bit confusing too, it says speed of the forward pass but then I think the discussion is on training time? Those can have very different properties, and prefer different solvers. I’d probably swap out to FBDF and maybe try Reactant compiling first before doing other things. Post your update and I should be able to try it this weekend if this isn’t already solved by what I mentioned

Thank you so much for your reply. I will give it a try. Really appreciate your effort :slight_smile:

Thank you for your reply. I am trying to use Reactant. But the following error message pops up

Precompiling Reactant...
Info Given Reactant was explicitly requested, output will be shown live
ERROR: LoadError: UndefVarError: `RNumber` not defined in `EnzymeCore`
Stacktrace:
  [1] getproperty(x::Module, f::Symbol)
    @ Base .\Base.jl:42
  [2] top-level scope
    @ C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\analyses\activity.jl:76
  [3] include(mod::Module, _path::String)
    @ Base .\Base.jl:557
  [4] include(x::String)
    @ Enzyme.Compiler C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\compiler.jl:1
  [5] top-level scope
    @ C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\compiler.jl:175
  [6] include(mod::Module, _path::String)
    @ Base .\Base.jl:557
  [7] include(x::String)
    @ Enzyme C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\Enzyme.jl:1
  [8] top-level scope
    @ C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\Enzyme.jl:127
  [9] include
    @ .\Base.jl:557 [inlined]
 [10] include_package_for_output(pkg::Base.PkgId, input::String, depot_path::Vector{String}, dl_load_path::Vector{String}, load_path::Vector{String}, concrete_deps::Vector{Pair{Base.PkgId, UInt128}}, source::String)
    @ Base .\loading.jl:2881
 [11] top-level scope
    @ stdin:6
in expression starting at C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\analyses\activity.jl:76
in expression starting at C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\compiler.jl:1
in expression starting at C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\Enzyme.jl:1
in expression starting at stdin:6
ERROR: LoadError: Failed to precompile Enzyme [7da242da-08ed-463a-9acd-ee780be4f1d9] to "C:\\Users\\Kalath_A\\.julia\\compiled\\v1.11\\Enzyme\\jl_2E3E.tmp".
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] compilecache(pkg::Base.PkgId, path::String, internal_stderr::IO, internal_stdout::IO, keep_loaded_modules::Bool; flags::Cmd, cacheflags::Base.CacheFlags, reasons::Dict{String, Int64}, loadable_exts::Nothing)
    @ Base .\loading.jl:3174
  [3] (::Base.var"#1110#1111"{Base.PkgId})()
    @ Base .\loading.jl:2579
  [4] mkpidlock(f::Base.var"#1110#1111"{Base.PkgId}, at::String, pid::Int32; kwopts::@Kwargs{stale_age::Int64, wait::Bool})    
    @ FileWatching.Pidfile C:\Users\Kalath_A\.julia\juliaup\julia-1.11.3+0.x64.w64.mingw32\share\julia\stdlib\v1.11\FileWatching\src\pidfile.jl:95
  [5] #mkpidlock#6
    @ C:\Users\Kalath_A\.julia\juliaup\julia-1.11.3+0.x64.w64.mingw32\share\julia\stdlib\v1.11\FileWatching\src\pidfile.jl:90 [inlined]
  [6] trymkpidlock(::Function, ::Vararg{Any}; kwargs::@Kwargs{stale_age::Int64})
    @ FileWatching.Pidfile C:\Users\Kalath_A\.julia\juliaup\julia-1.11.3+0.x64.w64.mingw32\share\julia\stdlib\v1.11\FileWatching\src\pidfile.jl:116PUArraysCoreExt
  [7] #invokelatest#2
    @ .\essentials.jl:1057 [inlined]
  [8] invokelatest
    @ .\essentials.jl:1052 [inlined]
  [9] maybe_cachefile_lock(f::Base.var"#1110#1111"{Base.PkgId}, pkg::Base.PkgId, srcpath::String; stale_age::Int64)
    @ Base .\loading.jl:3698
 [10] maybe_cachefile_lock
    @ .\loading.jl:3695 [inlined]
 [11] _require(pkg::Base.PkgId, env::String)
    @ Base .\loading.jl:2565
 [12] __require_prelocked(uuidkey::Base.PkgId, env::String)
    @ Base .\loading.jl:2388
 [13] #invoke_in_world#3
    @ .\essentials.jl:1089 [inlined]
 [14] invoke_in_world
    @ .\essentials.jl:1086 [inlined]
 [15] _require_prelocked(uuidkey::Base.PkgId, env::String)
    @ Base .\loading.jl:2375
 [16] macro expansion
    @ .\loading.jl:2314 [inlined]
 [17] macro expansion
    @ .\lock.jl:273 [inlined]
 [18] __require(into::Module, mod::Symbol)
    @ Base .\loading.jl:2271
 [19] #invoke_in_world#3
    @ .\essentials.jl:1089 [inlined]
 [20] invoke_in_world
    @ .\essentials.jl:1086 [inlined]
 [21] require(into::Module, mod::Symbol)
    @ Base .\loading.jl:2260
 [22] include
    @ .\Base.jl:557 [inlined]
 [23] include_package_for_output(pkg::Base.PkgId, input::String, depot_path::Vector{String}, dl_load_path::Vector{String}, load_path::Vector{String}, concrete_deps::Vector{Pair{Base.PkgId, UInt128}}, source::Nothing)
    @ Base .\loading.jl:2881
 [24] top-level scope
    @ stdin:6
in expression starting at C:\Users\Kalath_A\.julia\packages\Reactant\yfoXU\src\Reactant.jl:1
in expression starting at stdin:6
  ✗ Enzyme
  ✗ Enzyme → EnzymeGPUArraysCoreExt
  ✗ Reactant
  0 dependencies successfully precompiled in 21 seconds. 61 already precompiled.

ERROR: The following 1 direct dependency failed to precompile:

Reactant

Failed to precompile Reactant [3c362404-f566-11ee-1572-e11a4b42c853] to "C:\\Users\\Kalath_A\\.julia\\compiled\\v1.11\\Reactant\\jl_2C40.tmp".
ERROR: LoadError: UndefVarError: `RNumber` not defined in `EnzymeCore`
Stacktrace:
  [1] getproperty(x::Module, f::Symbol)
    @ Base .\Base.jl:42
  [2] top-level scope
    @ C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\analyses\activity.jl:76
  [3] include(mod::Module, _path::String)
    @ Base .\Base.jl:557
  [4] include(x::String)
    @ Enzyme.Compiler C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\compiler.jl:1
  [5] top-level scope
    @ C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\compiler.jl:175
  [6] include(mod::Module, _path::String)
    @ Base .\Base.jl:557
  [7] include(x::String)
    @ Enzyme C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\Enzyme.jl:1
  [8] top-level scope
    @ C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\Enzyme.jl:127
  [9] include
    @ .\Base.jl:557 [inlined]
 [10] include_package_for_output(pkg::Base.PkgId, input::String, depot_path::Vector{String}, dl_load_path::Vector{String}, load_path::Vector{String}, concrete_deps::Vector{Pair{Base.PkgId, UInt128}}, source::String)
    @ Base .\loading.jl:2881
 [11] top-level scope
    @ stdin:6
in expression starting at C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\analyses\activity.jl:76
in expression starting at C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\compiler.jl:1
in expression starting at C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\Enzyme.jl:1
in expression starting at stdin:6
ERROR: LoadError: Failed to precompile Enzyme [7da242da-08ed-463a-9acd-ee780be4f1d9] to "C:\\Users\\Kalath_A\\.julia\\compiled\\v1.11\\Enzyme\\jl_2E3E.tmp".
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] compilecache(pkg::Base.PkgId, path::String, internal_stderr::IO, internal_stdout::IO, keep_loaded_modules::Bool; flags::Cmd, cacheflags::Base.CacheFlags, reasons::Dict{String, Int64}, loadable_exts::Nothing)
    @ Base .\loading.jl:3174
  [3] (::Base.var"#1110#1111"{Base.PkgId})()
    @ Base .\loading.jl:2579
  [4] mkpidlock(f::Base.var"#1110#1111"{Base.PkgId}, at::String, pid::Int32; kwopts::@Kwargs{stale_age::Int64, wait::Bool})    
    @ FileWatching.Pidfile C:\Users\Kalath_A\.julia\juliaup\julia-1.11.3+0.x64.w64.mingw32\share\julia\stdlib\v1.11\FileWatching\src\pidfile.jl:95
  [5] #mkpidlock#6
    @ C:\Users\Kalath_A\.julia\juliaup\julia-1.11.3+0.x64.w64.mingw32\share\julia\stdlib\v1.11\FileWatching\src\pidfile.jl:90 [inlined]
  [6] trymkpidlock(::Function, ::Vararg{Any}; kwargs::@Kwargs{stale_age::Int64})
    @ FileWatching.Pidfile C:\Users\Kalath_A\.julia\juliaup\julia-1.11.3+0.x64.w64.mingw32\share\julia\stdlib\v1.11\FileWatching\src\pidfile.jl:116
  [7] #invokelatest#2
    @ .\essentials.jl:1057 [inlined]
  [8] invokelatest
    @ .\essentials.jl:1052 [inlined]
  [9] maybe_cachefile_lock(f::Base.var"#1110#1111"{Base.PkgId}, pkg::Base.PkgId, srcpath::String; stale_age::Int64)
    @ Base .\loading.jl:3698
 [10] maybe_cachefile_lock
    @ .\loading.jl:3695 [inlined]
 [11] _require(pkg::Base.PkgId, env::String)
    @ Base .\loading.jl:2565
 [12] __require_prelocked(uuidkey::Base.PkgId, env::String)
    @ Base .\loading.jl:2388
 [13] #invoke_in_world#3
    @ .\essentials.jl:1089 [inlined]
 [14] invoke_in_world
    @ .\essentials.jl:1086 [inlined]
 [15] _require_prelocked(uuidkey::Base.PkgId, env::String)
    @ Base .\loading.jl:2375
 [16] macro expansion
    @ .\loading.jl:2314 [inlined]
 [17] macro expansion
    @ .\lock.jl:273 [inlined]
 [18] __require(into::Module, mod::Symbol)
    @ Base .\loading.jl:2271
 [19] #invoke_in_world#3
    @ .\essentials.jl:1089 [inlined]
 [20] invoke_in_world
    @ .\essentials.jl:1086 [inlined]
 [21] require(into::Module, mod::Symbol)
    @ Base .\loading.jl:2260
 [22] include
    @ .\Base.jl:557 [inlined]
 [23] include_package_for_output(pkg::Base.PkgId, input::String, depot_path::Vector{String}, dl_load_path::Vector{String}, load_path::Vector{String}, concrete_deps::Vector{Pair{Base.PkgId, UInt128}}, source::Nothing)
    @ Base .\loading.jl:2881
 [24] top-level scope
    @ stdin:6
in expression starting at C:\Users\Kalath_A\.julia\packages\Reactant\yfoXU\src\Reactant.jl:1
in expression starting at stdin:

I couldn’t find any information on this error. The Pkg.status() is

Status `E:\PhD Ashima\Neural ODE\Julia\env_NODE\Project.toml`
⌃ [b0b7db55] ComponentArrays v0.15.23
⌃ [0c46a032] DifferentialEquations v7.15.0
⌃ [a98d9a8b] Interpolations v0.15.1
⌃ [033835bb] JLD2 v0.5.11
⌃ [d3d80556] LineSearches v7.3.0
⌃ [b2108857] Lux v1.6.0
  [23992714] MAT v0.10.7
⌃ [7f7a1694] Optimization v4.1.0
⌃ [36348300] OptimizationOptimJL v0.4.1
  [42dfb2eb] OptimizationOptimisers v0.3.7
  [58dd65bb] Plotly v0.4.1
⌃ [91a5bcdd] Plots v1.40.9
  [d236fae5] PreallocationTools v0.4.27
⌃ [3c362404] Reactant v0.2.41
⌃ [1ed8b502] SciMLSensitivity v7.72.0
⌃ [860ef19b] StableRNGs v1.0.2
  [10745b16] Statistics v1.11.1
⌅ [e88e6eb3] Zygote v0.6.75
  [37e2e46d] LinearAlgebra v1.11.0
  [9a3f8284] Random v1.11.0
  [2f01184e] SparseArrays v1.11.0
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted 
by compatibility constraints from upgrading. To see why use `status --outdated`

and my Julia version is

Julia Version 1.11.3
Commit d63adeda50 (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 40 × Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, skylake-avx512)
Threads: 1 default, 0 interactive, 1 GC (on 40 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_VSCODE_REPL = 1

Any idea why this is happening?

Reactant can’t be used on windows it works (sometimes) with wsl sublinux but you may hit weird bugs

Have you looked at the Julia performance tips and the Differential Equations performance tips?

  1. Don’t use global variables, pass them via parameters
  2. You are allocating lots of temporary arrays. Don’t.
  3. Since your matrices are only 5x5, consider using StaticArrays.jl. SparseArrays are probably not beneficial for such small arrays; in this case the array seems to be tridiagonal so you could use the Tridiagonal type (which can still be used with StaticArrays).