Improving the speed of the forward solve of a Universal Differential Equation (UDE)

Hello everyone,
I have a UDE here, and I am trying to improve its speed during the forward solve so that the optimization is also faster. Right now the optimization part is very slow. Here is working code you can try out:

using DifferentialEquations
using StableRNGs, Lux
using ComponentArrays,SparseArrays

#Defining the neural network
U = Lux.Chain(Lux.Dense(3,30,tanh),Lux.Dense(30,30,tanh),Lux.Dense(30,1))
rng = StableRNG(1111)
_para,st = Lux.setup(rng,U)
const _st_ = st

#Setting the number of states
n = 5

# Finding the parameters
d   = 21e-3 # diameter in m
r   = d/2   # radius in m
L   = 70e-3 # length in m
Δr = r/n

const k = 1.05
const rho_c = 2.85e6
const h = 5.0
const Cbat = 5*3600

# Initializing
AL = zeros(n)
AB = zeros(n)
V  = zeros(n)

for i in 1:n
        AL[i] = 2*π*i*Δr*L
        AB[i] = π*(Δr^2)*(i^2 - ((i-1)^2))
        V[i]  = AB[i]*L
end

# Precomputing M and B
M = zeros(n,n)
B = zeros(n)



for i in 1:n
    if i ==1
        M[1,1] = (((-k*AL[1])/(1.5*Δr))-(2*h*AB[1]))/(rho_c*V[1])
        M[1,2] = (k*AL[1])/(1.5*Δr*rho_c*V[1])
        B[1]   = (2*h*AB[1])/(rho_c*V[1])
    
    elseif i == 2
        M[2,2] = (((-k*AL[1])/(1.5*Δr))-(2*h*AB[2])-((k*AL[2])/Δr))/(rho_c*V[2])
        M[2,1] = (k*AL[1])/(1.5*Δr*rho_c*V[2])
        M[2,3] = (k*AL[2])/(rho_c*V[2]*Δr)
        B[2]   = (2*h*AB[2])/(rho_c*V[2])
    
    elseif i == n
        M[n,n]   = (((-k*AL[n-1])/Δr)-(2*h*AB[n])-(1.5*h*AL[n]))/(rho_c*V[n])
        M[n,n-1] = (((k*AL[n-1])/Δr)+(0.5*h*AL[n]))/(rho_c*V[n])
        B[n]     = ((2*h*AB[n])+(h*AL[n]))/(rho_c*V[n])

    else 
        M[i,i]   = (((-k*AL[i-1])/Δr) - ((k*AL[i])/Δr)-(2*h*AB[i]))/(rho_c*V[i])
        M[i,i-1] = (k*AL[i-1])/(Δr*rho_c*V[i])
        M[i,i+1] = (k*AL[i])/(Δr*rho_c*V[i])
        B[i]     = (2*h*AB[i])/(rho_c*V[i])

    end
end
M = sparse(M)

# Defining the ODE model
function UDE_model!(du,u,p,t,n,M,B,T∞,I)
        


    #Neural Network
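    # NOTE: this comprehension calls the network once per shell and allocates a fresh vector every f-call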
    G = [(sign(I)^2)*((U([u[n+1],u[i],I],p,_st_)[1][1])^2) for i in 1:n]
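    # NOTE: u[1:n] copies, and M*u[1:n] + B.*T∞ + G./rho_c builds several temporary arrays per call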

    du[1:n] .= M*u[1:n] + B.*T∞ + G./rho_c

   
    du[n+1] = -I/Cbat

end

T∞1 = 282.553
T∞2 = 297.67
T∞3 = 297.664

I1 = 5.0
I2 = 5.0
I3 = 10.0

t1 = collect(99076.0:1:102233.0)
t2 = collect(79053.0:1:82426.0)
t3 = collect(105517.0:1:107199.0)

T01 = fill(T∞1,n)
T02 = fill(T∞2,n)
T03 = fill(T∞3,n)

u01 = vcat(T01,1.0)
u02 = vcat(T02,1.0)
u03 = vcat(T03,1.0)



UDE_model1!(du,u,p,t) = UDE_model!(du,u,p,t,n,M,B,T∞1,I1)
UDE_model2!(du,u,p,t) = UDE_model!(du,u,p,t,n,M,B,T∞2,I2)
UDE_model3!(du,u,p,t) = UDE_model!(du,u,p,t,n,M,B,T∞3,I3)
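# NOTE: these closures capture the non-const globals n, M, B, T∞1, I1, etc., which hurts type stability (see the replies below)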


prob1 = ODEProblem(UDE_model1!,u01,(t1[1],t1[end]),_para)
prob2 = ODEProblem(UDE_model2!,u02,(t2[1],t2[end]),_para)
prob3 = ODEProblem(UDE_model3!,u03,(t3[1],t3[end]),_para)


@time inisol1 = solve(prob1,Rosenbrock23(),saveat = t1)
@time inisol2 = solve(prob2,Rosenbrock23(),saveat = t2)
@time inisol3 = solve(prob3,Rosenbrock23(),saveat = t3)

Right now the @time output is:

inisol1 -> 13.82 k allocations: 1.094 MiB
inisol2 -> 15.64 k allocations: 1.242 MiB
inisol3 -> 10.87 k allocations: 891.984 KiB

My colleague who works on UDEs said the allocation counts are high.
So I need some tips on reducing them so that the optimization is faster. Any help would be much appreciated.

The fastest I could get is:

using DifferentialEquations
using StableRNGs, Lux
using ComponentArrays,SparseArrays
using PreallocationTools

struct params{MO,P,S,MT,BT,T,C,C2}
    model ::MO
    ps ::P
    st ::S
    n ::Int
    M ::MT 
    B ::BT
    T∞ ::T
    I ::Int
    cache ::C
    cache2 ::C2
end

# Defining the ODE model
function UDE_model!(du,u,p,t)
    # Extracting parameters
    n = p.n
    M = p.M
    B = p.B
    T∞ = p.T∞
    I = p.I
    rho_c = 2.85e6
    Cbat = 5*3600
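    # get_tmp returns a buffer whose eltype matches u, so the DiffCache also works when the solver passes ForwardDiff duals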
    cache = get_tmp(p.cache,u)
    mul!(cache,M,@view(u[1:n]))
    cache2 = get_tmp(p.cache2,u)
    @views cache2[1,:] .= u[n+1]
    @views cache2[2,:] .= u[1:n]
    @views cache2[3,:] .= I 
    C,_ = p.model(cache2,p.ps,p.st)
    C .= C.^2
    C .= C .* sign(p.I)^2
    for i in 1:n
        du[i] = C[1,i]/rho_c + B[i] * T∞ + cache[i]
    end
    du[n+1] = -I/Cbat
    nothing
end

function initialize()
    #Defining the neural network
    U = Lux.Chain(Lux.Dense(3,30,tanh),Lux.Dense(30,30,tanh),Lux.Dense(30,1))
    rng = StableRNG(1111)
    _para,st = Lux.setup(rng,U)
    _para = f64(_para)
    #Setting the number of states
    n = 5
    # Finding the parameters
    d   = 21e-3 # diameter in m
    r   = d/2   # radius in m
    L   = 70e-3 # length in m
    Δr = r/n
    k = 1.05
    rho_c = 2.85e6
    h = 5.0
    Cbat = 5*3600
    # Initializing
    AL = zeros(n)
    AB = zeros(n)
    V  = zeros(n)
    for i in 1:n
            AL[i] = 2*π*i*Δr*L
            AB[i] = π*(Δr^2)*(i^2 - ((i-1)^2))
            V[i]  = AB[i]*L
    end
    # Precomputing M and B
    M = zeros(n,n)
    B = zeros(n)
    for i in 1:n
        if i ==1
            M[1,1] = (((-k*AL[1])/(1.5*Δr))-(2*h*AB[1]))/(rho_c*V[1])
            M[1,2] = (k*AL[1])/(1.5*Δr*rho_c*V[1])
            B[1]   = (2*h*AB[1])/(rho_c*V[1])
        elseif i == 2
            M[2,2] = (((-k*AL[1])/(1.5*Δr))-(2*h*AB[2])-((k*AL[2])/Δr))/(rho_c*V[2])
            M[2,1] = (k*AL[1])/(1.5*Δr*rho_c*V[2])
            M[2,3] = (k*AL[2])/(rho_c*V[2]*Δr)
            B[2]   = (2*h*AB[2])/(rho_c*V[2])
        
        elseif i == n
            M[n,n]   = (((-k*AL[n-1])/Δr)-(2*h*AB[n])-(1.5*h*AL[n]))/(rho_c*V[n])
            M[n,n-1] = (((k*AL[n-1])/Δr)+(0.5*h*AL[n]))/(rho_c*V[n])
            B[n]     = ((2*h*AB[n])+(h*AL[n]))/(rho_c*V[n])

        else 
            M[i,i]   = (((-k*AL[i-1])/Δr) - ((k*AL[i])/Δr)-(2*h*AB[i]))/(rho_c*V[i])
            M[i,i-1] = (k*AL[i-1])/(Δr*rho_c*V[i])
            M[i,i+1] = (k*AL[i])/(Δr*rho_c*V[i])
            B[i]     = (2*h*AB[i])/(rho_c*V[i])

        end
    end
    M = sparse(M)
    T∞1 = 282.553
    T∞2 = 297.67
    T∞3 = 297.664
    I1 = 5
    I2 = 5
    I3 = 10
    cache1 = DiffCache(zeros(n))
    cache12 = DiffCache(zeros(3,n))
    cache2 = DiffCache(zeros(n))
    cache22 = DiffCache(zeros(3,n))
    cache3 = DiffCache(zeros(n))
    cache32 = DiffCache(zeros(3,n))
    p1 = params(U,_para,st,n,M,B,T∞1,I1,cache1,cache12)
    p2 = params(U,_para,st,n,M,B,T∞2,I2,cache2,cache22)
    p3 = params(U,_para,st,n,M,B,T∞3,I3,cache3,cache32)
    t1 = collect(99076.0:1:102233.0)
    t2 = collect(79053.0:1:82426.0)
    t3 = collect(105517.0:1:107199.0)
    T01 = fill(T∞1,n)
    T02 = fill(T∞2,n)
    T03 = fill(T∞3,n)
    u01 = vcat(T01,1.0)
    u02 = vcat(T02,1.0)
    u03 = vcat(T03,1.0)
    prob1 = ODEProblem(UDE_model!,u01,(t1[1],t1[end]),p1)
    prob2 = ODEProblem(UDE_model!,u02,(t2[1],t2[end]),p2)
    prob3 = ODEProblem(UDE_model!,u03,(t3[1],t3[end]),p3)
    return prob1, prob2, prob3, t1, t2, t3
end

prob1, prob2, prob3, t1, t2, t3 = initialize()
@time inisol1 = solve(prob1,Rosenbrock23(),saveat = t1);
@time inisol2 = solve(prob2,Rosenbrock23(),saveat = t2);
@time inisol3 = solve(prob3,Rosenbrock23(),saveat = t3);

However, you will need to implement the interface if you want to differentiate through that.

You can use const variables if you want, but avoid the closures you used; that's as bad as having globals. Also, I'm not sure caching everything is a good idea when you differentiate through it; you may need to avoid some of the caching I did, since it may hide the derivatives.
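To make the closure point concrete, here is a minimal sketch using the names from the code above:

# Closure over non-const globals: every access to n, M, B, ... is type-unstable
UDE_model1!(du,u,p,t) = UDE_model!(du,u,p,t,n,M,B,T∞1,I1)

# Passing everything through the parameter struct instead keeps the types concrete
p1 = params(U,_para,st,n,M,B,T∞1,I1,cache1,cache12)
prob1 = ODEProblem(UDE_model!,u01,(t1[1],t1[end]),p1)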

Finally, it is still quite a mess. I've never been able to write the interface for ML-related stuff; for example, why does this not work:

using DifferentialEquations
using StableRNGs, Lux
using ComponentArrays,SparseArrays
using PreallocationTools
using SciMLSensitivity
import SciMLStructures as SS
using LinearAlgebra
using Zygote
using Parameters

mutable struct params{MO,P,S,MT,BT,T,C,C2}
    model ::MO
    ps ::P
    st ::S
    n ::Int
    M ::MT 
    B ::BT
    T∞ ::T
    I ::Int
    cache ::C
    cache2 ::C2
end
SS.isscimlstructure(::params) = true
SS.ismutablescimlstructure(::params) = true
SS.hasportion(::SS.Tunable, ::params) = true
function SS.canonicalize(::SS.Tunable, p::params)
  buffer = copy(p.ps)
  repack = let p = p
    function repack(newbuffer)
      SS.replace(SS.Tunable(), p, newbuffer)
    end
  end
  return buffer, repack, false
end
function SS.replace(::SS.Tunable, p::params, newbuffer)
  return params(
    p.model,newbuffer,p.st,p.n,p.M,p.B,p.T∞,p.I,p.cache,p.cache2
  )
end

function SS.replace!(::SS.Tunable, p::params, newbuffer)
    p.ps = newbuffer
    return p
end

function initialize()
    #Defining the neural network
    U = Lux.Chain(Lux.Dense(3,30,tanh),Lux.Dense(30,30,tanh),Lux.Dense(30,1))
    rng = StableRNG(1111)
    _para,st = Lux.setup(rng,U)
    _para = ComponentArray(f64(_para))
    #Setting the number of states
    n = 5
    # Finding the parameters
    d   = 21e-3 # diameter in m
    r   = d/2   # radius in m
    L   = 70e-3 # length in m
    Δr = r/n
    k = 1.05
    rho_c = 2.85e6
    h = 5.0
    Cbat = 5*3600
    # Initializing
    AL = zeros(n)
    AB = zeros(n)
    V  = zeros(n)
    for i in 1:n
            AL[i] = 2*π*i*Δr*L
            AB[i] = π*(Δr^2)*(i^2 - ((i-1)^2))
            V[i]  = AB[i]*L
    end
    # Precomputing M and B
    M = zeros(n,n)
    B = zeros(n)
    for i in 1:n
        if i ==1
            M[1,1] = (((-k*AL[1])/(1.5*Δr))-(2*h*AB[1]))/(rho_c*V[1])
            M[1,2] = (k*AL[1])/(1.5*Δr*rho_c*V[1])
            B[1]   = (2*h*AB[1])/(rho_c*V[1])
        elseif i == 2
            M[2,2] = (((-k*AL[1])/(1.5*Δr))-(2*h*AB[2])-((k*AL[2])/Δr))/(rho_c*V[2])
            M[2,1] = (k*AL[1])/(1.5*Δr*rho_c*V[2])
            M[2,3] = (k*AL[2])/(rho_c*V[2]*Δr)
            B[2]   = (2*h*AB[2])/(rho_c*V[2])
        
        elseif i == n
            M[n,n]   = (((-k*AL[n-1])/Δr)-(2*h*AB[n])-(1.5*h*AL[n]))/(rho_c*V[n])
            M[n,n-1] = (((k*AL[n-1])/Δr)+(0.5*h*AL[n]))/(rho_c*V[n])
            B[n]     = ((2*h*AB[n])+(h*AL[n]))/(rho_c*V[n])

        else 
            M[i,i]   = (((-k*AL[i-1])/Δr) - ((k*AL[i])/Δr)-(2*h*AB[i]))/(rho_c*V[i])
            M[i,i-1] = (k*AL[i-1])/(Δr*rho_c*V[i])
            M[i,i+1] = (k*AL[i])/(Δr*rho_c*V[i])
            B[i]     = (2*h*AB[i])/(rho_c*V[i])

        end
    end
    M = sparse(M)
    T∞1 = 282.553
    T∞2 = 297.67
    T∞3 = 297.664
    I1 = 5
    I2 = 5
    I3 = 10
    cache1 = DiffCache(zeros(n))
    cache12 = DiffCache(zeros(3,n))
    cache2 = DiffCache(zeros(n))
    cache22 = DiffCache(zeros(3,n))
    cache3 = DiffCache(zeros(n))
    cache32 = DiffCache(zeros(3,n))
    p1 = params(U,_para,st,n,M,B,T∞1,I1,cache1,cache12)
    p2 = params(U,_para,st,n,M,B,T∞2,I2,cache2,cache22)
    p3 = params(U,_para,st,n,M,B,T∞3,I3,cache3,cache32)
    t1 = collect(99076.0:1:102233.0)
    t2 = collect(79053.0:1:82426.0)
    t3 = collect(105517.0:1:107199.0)
    T01 = fill(T∞1,n)
    T02 = fill(T∞2,n)
    T03 = fill(T∞3,n)
    u01 = vcat(T01,1.0)
    u02 = vcat(T02,1.0)
    u03 = vcat(T03,1.0)
    prob1 = ODEProblem(UDE_model!,u01,(t1[1],t1[end]),p1)
    prob2 = ODEProblem(UDE_model!,u02,(t2[1],t2[end]),p2)
    prob3 = ODEProblem(UDE_model!,u03,(t3[1],t3[end]),p3)
    return prob1, prob2, prob3, t1, t2, t3, u01, u02, u03, p1, p2, p3
end


# Defining the ODE model
function UDE_model!(du,u,p,t)
    # Extracting parameters
    Parameters.@unpack model,ps,st,n,M,B,T∞,I,cache,cache2 = p
    rho_c = 2.85e6
    Cbat = 5*3600
    cache = get_tmp(cache,u)
    mul!(cache,M,@view(u[1:n]))
    cache2 = get_tmp(cache2,u)
    # fill the NN input batch: [SOC, T_i, I] per shell
    for i in 1:n
        cache2[1,i] = u[n+1]
        cache2[2,i] = u[i]
        cache2[3,i] = I
    end
    C = model(cache2,ps,st)[1].^2 .* (sign(I)^2)
    @views du[1:n] .= C[1,:]./rho_c .+ B .* T∞ .+ cache
    du[n+1] = -I/Cbat
    nothing
end


prob1, prob2, prob3, t1, t2, t3, u01, u02, u03, p1, p2, p3 = initialize()
@time inisol1 = solve(prob1,Rosenbrock23(),saveat = t1);
@time inisol2 = solve(prob2,Rosenbrock23(),saveat = t2);
@time inisol3 = solve(prob3,Rosenbrock23(),saveat = t3);

function run_diff(ps)
    p = p1
    p.ps = ps
    prob = remake(prob1, p = p)
    sol = solve(prob, Rosenbrock23(), saveat = t1)
    return sol.u |> last |> sum
end

run_diff(p1.ps)
Zygote.gradient(run_diff, p1.ps)

For some reason the solve calls similar on the structure, and if I implement that it then calls length, which is nonsense.

Using Rosenbrock23 is almost never a good idea for training; it doesn't have the right properties. The title is also a bit confusing: it says speed of the forward pass, but I think the discussion is about training time? Those can have very different properties and prefer different solvers. I'd probably swap to FBDF and maybe try Reactant compilation first before doing other things. Post your update and I should be able to try it this weekend, if it isn't already solved by what I mentioned.
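For example, swapping the solver is a one-line change (whether FBDF wins here still needs testing on your problem):

@time inisol1 = solve(prob1, FBDF(), saveat = t1);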


Thank you so much for your reply. I will give it a try. Really appreciate your effort 🙂

Thank you for your reply. I am trying to use Reactant, but the following error message pops up:

Precompiling Reactant...
Info Given Reactant was explicitly requested, output will be shown live
ERROR: LoadError: UndefVarError: `RNumber` not defined in `EnzymeCore`
Stacktrace:
  [1] getproperty(x::Module, f::Symbol)
    @ Base .\Base.jl:42
  [2] top-level scope
    @ C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\analyses\activity.jl:76
  [3] include(mod::Module, _path::String)
    @ Base .\Base.jl:557
  [4] include(x::String)
    @ Enzyme.Compiler C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\compiler.jl:1
  [5] top-level scope
    @ C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\compiler.jl:175
  [6] include(mod::Module, _path::String)
    @ Base .\Base.jl:557
  [7] include(x::String)
    @ Enzyme C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\Enzyme.jl:1
  [8] top-level scope
    @ C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\Enzyme.jl:127
  [9] include
    @ .\Base.jl:557 [inlined]
 [10] include_package_for_output(pkg::Base.PkgId, input::String, depot_path::Vector{String}, dl_load_path::Vector{String}, load_path::Vector{String}, concrete_deps::Vector{Pair{Base.PkgId, UInt128}}, source::String)
    @ Base .\loading.jl:2881
 [11] top-level scope
    @ stdin:6
in expression starting at C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\analyses\activity.jl:76
in expression starting at C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\compiler.jl:1
in expression starting at C:\Users\Kalath_A\.julia\packages\Enzyme\owyQj\src\Enzyme.jl:1
in expression starting at stdin:6
ERROR: LoadError: Failed to precompile Enzyme [7da242da-08ed-463a-9acd-ee780be4f1d9] to "C:\\Users\\Kalath_A\\.julia\\compiled\\v1.11\\Enzyme\\jl_2E3E.tmp".
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:35
  [2] compilecache(pkg::Base.PkgId, path::String, internal_stderr::IO, internal_stdout::IO, keep_loaded_modules::Bool; flags::Cmd, cacheflags::Base.CacheFlags, reasons::Dict{String, Int64}, loadable_exts::Nothing)
    @ Base .\loading.jl:3174
  [3] (::Base.var"#1110#1111"{Base.PkgId})()
    @ Base .\loading.jl:2579
  [4] mkpidlock(f::Base.var"#1110#1111"{Base.PkgId}, at::String, pid::Int32; kwopts::@Kwargs{stale_age::Int64, wait::Bool})    
    @ FileWatching.Pidfile C:\Users\Kalath_A\.julia\juliaup\julia-1.11.3+0.x64.w64.mingw32\share\julia\stdlib\v1.11\FileWatching\src\pidfile.jl:95
  [5] #mkpidlock#6
    @ C:\Users\Kalath_A\.julia\juliaup\julia-1.11.3+0.x64.w64.mingw32\share\julia\stdlib\v1.11\FileWatching\src\pidfile.jl:90 [inlined]
  [6] trymkpidlock(::Function, ::Vararg{Any}; kwargs::@Kwargs{stale_age::Int64})
    @ FileWatching.Pidfile C:\Users\Kalath_A\.julia\juliaup\julia-1.11.3+0.x64.w64.mingw32\share\julia\stdlib\v1.11\FileWatching\src\pidfile.jl:116
  [7] #invokelatest#2
    @ .\essentials.jl:1057 [inlined]
  [8] invokelatest
    @ .\essentials.jl:1052 [inlined]
  [9] maybe_cachefile_lock(f::Base.var"#1110#1111"{Base.PkgId}, pkg::Base.PkgId, srcpath::String; stale_age::Int64)
    @ Base .\loading.jl:3698
 [10] maybe_cachefile_lock
    @ .\loading.jl:3695 [inlined]
 [11] _require(pkg::Base.PkgId, env::String)
    @ Base .\loading.jl:2565
 [12] __require_prelocked(uuidkey::Base.PkgId, env::String)
    @ Base .\loading.jl:2388
 [13] #invoke_in_world#3
    @ .\essentials.jl:1089 [inlined]
 [14] invoke_in_world
    @ .\essentials.jl:1086 [inlined]
 [15] _require_prelocked(uuidkey::Base.PkgId, env::String)
    @ Base .\loading.jl:2375
 [16] macro expansion
    @ .\loading.jl:2314 [inlined]
 [17] macro expansion
    @ .\lock.jl:273 [inlined]
 [18] __require(into::Module, mod::Symbol)
    @ Base .\loading.jl:2271
 [19] #invoke_in_world#3
    @ .\essentials.jl:1089 [inlined]
 [20] invoke_in_world
    @ .\essentials.jl:1086 [inlined]
 [21] require(into::Module, mod::Symbol)
    @ Base .\loading.jl:2260
 [22] include
    @ .\Base.jl:557 [inlined]
 [23] include_package_for_output(pkg::Base.PkgId, input::String, depot_path::Vector{String}, dl_load_path::Vector{String}, load_path::Vector{String}, concrete_deps::Vector{Pair{Base.PkgId, UInt128}}, source::Nothing)
    @ Base .\loading.jl:2881
 [24] top-level scope
    @ stdin:6
in expression starting at C:\Users\Kalath_A\.julia\packages\Reactant\yfoXU\src\Reactant.jl:1
in expression starting at stdin:6
  ✗ Enzyme
  ✗ Enzyme → EnzymeGPUArraysCoreExt
  ✗ Reactant
  0 dependencies successfully precompiled in 21 seconds. 61 already precompiled.

ERROR: The following 1 direct dependency failed to precompile:

Reactant

Failed to precompile Reactant [3c362404-f566-11ee-1572-e11a4b42c853] to "C:\\Users\\Kalath_A\\.julia\\compiled\\v1.11\\Reactant\\jl_2C40.tmp".

I couldn't find any information on this error. The Pkg.status() output is:

Status `E:\PhD Ashima\Neural ODE\Julia\env_NODE\Project.toml`
⌃ [b0b7db55] ComponentArrays v0.15.23
⌃ [0c46a032] DifferentialEquations v7.15.0
⌃ [a98d9a8b] Interpolations v0.15.1
⌃ [033835bb] JLD2 v0.5.11
⌃ [d3d80556] LineSearches v7.3.0
⌃ [b2108857] Lux v1.6.0
  [23992714] MAT v0.10.7
⌃ [7f7a1694] Optimization v4.1.0
⌃ [36348300] OptimizationOptimJL v0.4.1
  [42dfb2eb] OptimizationOptimisers v0.3.7
  [58dd65bb] Plotly v0.4.1
⌃ [91a5bcdd] Plots v1.40.9
  [d236fae5] PreallocationTools v0.4.27
⌃ [3c362404] Reactant v0.2.41
⌃ [1ed8b502] SciMLSensitivity v7.72.0
⌃ [860ef19b] StableRNGs v1.0.2
  [10745b16] Statistics v1.11.1
⌅ [e88e6eb3] Zygote v0.6.75
  [37e2e46d] LinearAlgebra v1.11.0
  [9a3f8284] Random v1.11.0
  [2f01184e] SparseArrays v1.11.0
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted 
by compatibility constraints from upgrading. To see why use `status --outdated`

and my Julia version is:

Julia Version 1.11.3
Commit d63adeda50 (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 40 × Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, skylake-avx512)
Threads: 1 default, 0 interactive, 1 GC (on 40 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_VSCODE_REPL = 1

Any idea why this is happening?

Reactant can't be used on Windows. It works (sometimes) under WSL (Windows Subsystem for Linux), but you may hit weird bugs.

Have you looked at the Julia performance tips and the Differential Equations performance tips?

  1. Don’t use global variables, pass them via parameters
  2. You are allocating lots of temporary arrays. Don’t.
  3. Since your matrices are only 5×5, consider using StaticArrays.jl. SparseArrays are probably not beneficial for such small arrays; in this case the matrix seems to be tridiagonal, so you could use the Tridiagonal type (which can still be used with StaticArrays). See the sketch below.
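A minimal sketch of both ideas, assuming the dense 5×5 M and the vector B built above (i.e. before the sparse(M) call):

using LinearAlgebra, StaticArrays

Mtri  = Tridiagonal(M)        # stores only the three diagonals of M
Mstat = SMatrix{5,5}(M)       # fully static 5x5 matrix
Bstat = SVector{5}(B)
Mstat * Bstat                 # static matvec: unrolled, no heap allocation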

OK, got it. Here it is, if you want:

using DifferentialEquations
using StableRNGs, Lux
using ComponentArrays,SparseArrays
using PreallocationTools
using SciMLSensitivity
import SciMLStructures as SS
using LinearAlgebra
using Zygote
using Parameters

mutable struct params{MO,P,S,MT,BT,T,C,C2}
    model ::MO
    ps ::P
    st ::S
    n ::Int
    M ::MT 
    B ::BT
    T∞ ::T
    I ::Int
    cache ::C
    cache2 ::C2
end
SS.isscimlstructure(::params) = true
SS.ismutablescimlstructure(::params) = true
SS.hasportion(::SS.Tunable, ::params) = true
function SS.canonicalize(::SS.Tunable, p::params)
  buffer = p.ps
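  # buffer aliases p.ps directly (no copy), hence the `true` alias flag returned below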
  repack = let p = p
    function repack(newbuffer)
      SS.replace(SS.Tunable(), p, newbuffer)
    end
  end
  return buffer, repack, true
end
function SS.replace(::SS.Tunable, p::params, newbuffer)
  return params(
    p.model,newbuffer,p.st,p.n,p.M,p.B,p.T∞,p.I,p.cache,p.cache2
  )
end

function SS.replace!(::SS.Tunable, p::params, newbuffer)
    p.ps = newbuffer
    return p
end

function initialize()
    #Defining the neural network
    U = Lux.Chain(Lux.Dense(3,30,tanh),Lux.Dense(30,30,tanh),Lux.Dense(30,1))
    rng = StableRNG(1111)
    _para,st = Lux.setup(rng,U)
    _para = ComponentArray(f64(_para))
    #Setting the number of states
    n = 5
    # Finding the parameters
    d   = 21e-3 # diameter in m
    r   = d/2   # radius in m
    L   = 70e-3 # length in m
    Δr = r/n
    k = 1.05
    rho_c = 2.85e6
    h = 5.0
    Cbat = 5*3600
    # Initializing
    AL = zeros(n)
    AB = zeros(n)
    V  = zeros(n)
    for i in 1:n
            AL[i] = 2*π*i*Δr*L
            AB[i] = π*(Δr^2)*(i^2 - ((i-1)^2))
            V[i]  = AB[i]*L
    end
    # Precomputing M and B
    M = zeros(n,n)
    B = zeros(n)
    for i in 1:n
        if i ==1
            M[1,1] = (((-k*AL[1])/(1.5*Δr))-(2*h*AB[1]))/(rho_c*V[1])
            M[1,2] = (k*AL[1])/(1.5*Δr*rho_c*V[1])
            B[1]   = (2*h*AB[1])/(rho_c*V[1])
        elseif i == 2
            M[2,2] = (((-k*AL[1])/(1.5*Δr))-(2*h*AB[2])-((k*AL[2])/Δr))/(rho_c*V[2])
            M[2,1] = (k*AL[1])/(1.5*Δr*rho_c*V[2])
            M[2,3] = (k*AL[2])/(rho_c*V[2]*Δr)
            B[2]   = (2*h*AB[2])/(rho_c*V[2])
        
        elseif i == n
            M[n,n]   = (((-k*AL[n-1])/Δr)-(2*h*AB[n])-(1.5*h*AL[n]))/(rho_c*V[n])
            M[n,n-1] = (((k*AL[n-1])/Δr)+(0.5*h*AL[n]))/(rho_c*V[n])
            B[n]     = ((2*h*AB[n])+(h*AL[n]))/(rho_c*V[n])

        else 
            M[i,i]   = (((-k*AL[i-1])/Δr) - ((k*AL[i])/Δr)-(2*h*AB[i]))/(rho_c*V[i])
            M[i,i-1] = (k*AL[i-1])/(Δr*rho_c*V[i])
            M[i,i+1] = (k*AL[i])/(Δr*rho_c*V[i])
            B[i]     = (2*h*AB[i])/(rho_c*V[i])

        end
    end
    M = sparse(M)
    T∞1 = 282.553
    T∞2 = 297.67
    T∞3 = 297.664
    I1 = 5
    I2 = 5
    I3 = 10
    cache1 = DiffCache(zeros(n))
    cache12 = DiffCache(zeros(3,n))
    cache2 = DiffCache(zeros(n))
    cache22 = DiffCache(zeros(3,n))
    cache3 = DiffCache(zeros(n))
    cache32 = DiffCache(zeros(3,n))
    p1 = params(U,_para,st,n,M,B,T∞1,I1,cache1,cache12)
    p2 = params(U,_para,st,n,M,B,T∞2,I2,cache2,cache22)
    p3 = params(U,_para,st,n,M,B,T∞3,I3,cache3,cache32)
    t1 = collect(99076.0:1:102233.0)
    t2 = collect(79053.0:1:82426.0)
    t3 = collect(105517.0:1:107199.0)
    T01 = fill(T∞1,n)
    T02 = fill(T∞2,n)
    T03 = fill(T∞3,n)
    u01 = vcat(T01,1.0)
    u02 = vcat(T02,1.0)
    u03 = vcat(T03,1.0)
    prob1 = ODEProblem(UDE_model!,u01,(t1[1],t1[end]),p1)
    prob2 = ODEProblem(UDE_model!,u02,(t2[1],t2[end]),p2)
    prob3 = ODEProblem(UDE_model!,u03,(t3[1],t3[end]),p3)
    return prob1, prob2, prob3, t1, t2, t3, u01, u02, u03, p1, p2, p3
end


# Defining the ODE model
function UDE_model!(du,u,p,t)
    # Extracting parameters
    Parameters.@unpack model,ps,st,n,M,B,T∞,I,cache,cache2 = p
    rho_c = 2.85e6
    Cbat = 5*3600
    cache = get_tmp(cache,u)
    mul!(cache,M,@view(u[1:n]))
    cache2 = get_tmp(cache2,u)
    for i in 1:n
        cache2[1,i] = u[n+1]
        cache2[2,i] = u[i]
        cache2[3,i] = I
    end
    C,_ = model(cache2,ps,st)
    for i in 1:n
        du[i] = C[1,i]^2/rho_c + B[i] * T∞ + cache[i]
    end
    du[n+1] = -I/Cbat
    nothing
end


prob1, prob2, prob3, t1, t2, t3, u01, u02, u03, p1, p2, p3 = initialize()
@time inisol1 = solve(prob1,Rosenbrock23(),saveat = t1);
@time inisol2 = solve(prob2,Rosenbrock23(),saveat = t2);
@time inisol3 = solve(prob3,Rosenbrock23(),saveat = t3);

function run_diff(ps,prob,p,t1)
    p = params(p.model,ps,p.st,p.n,p.M,p.B,p.T∞,p.I,p.cache,p.cache2)
    prob = remake(prob, p = p)
    sol = solve(prob)
    return sol.u |> last |> sum
end

Zygote.gradient(ps->run_diff(ps,prob1,p1,t1), p1.ps)

This may still be slower than Reactant.jl, but I think it's OK.
For now it is 3 ms for the forward solve and 7 s for the gradient calculation; that is still quite bad, but it's better.

Just wanted to follow up here to let folks know that we just released support for Reactant on native Windows (though CPU only). CUDA Reactant support for Windows still exists, but requires WSL.


Hey,
Has anyone managed to get Reactant working with a UDE?
I tried compiling the neural network inside my UDE using Reactant.@compile to speed up training, but I encountered an error: MethodError: no method matching Float32(::ForwardDiff.Dual…)
Any help would be much appreciated.

I just merged Reactant support 3 days ago 😅

So if you're talking about ReactantVJP, probably no one has run it yet, because I don't think anyone was following it that closely, haha. I don't think I even added it to the docs yet?

But the remaining Reactant issues are documented in Track remaining ReactantVJP @test_broken items · Issue #1362 · SciML/SciMLSensitivity.jl · GitHub. @wsmoses fixed a few things upstream, so I believe most of it should be working after that; I'll need to re-enable some tests over the weekend. But the MethodError you hit is exactly Problem 1 in that issue, which is nesting ForwardDiff over Reactant. I tried a few things to integrate PreallocationTools to work around that (Fix ReactantVJP with stiff solvers via Enzyme fallback for ForwardDiff.Dual by ChrisRackauckas-Claude · Pull Request #1363 · SciML/SciMLSensitivity.jl · GitHub) but it wasn't working well, so I'll need to talk with @avikpal and @wsmoses about how to best integrate it all together. I think I need to set up some manual dispatching so that I JIT multiple kernels with Reactant.@jit and then choose which one to call based on whether there are duals or not.

I plan to finish this stuff up over the weekend. And then we’re setting up new benchmarks and optimizing once that is done.
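A rough sketch of that dual-based dispatching idea (hypothetical: apply_nn, x_ra, ps_ra, and nn_jit are illustrative names, not an existing API; U, _para, and st are the network objects from the code above):

using ForwardDiff, Reactant

x_ra  = Reactant.to_rarray(zeros(3, 5))       # prototype NN input (3 features × n shells)
ps_ra = Reactant.to_rarray(_para)             # parameters moved to the Reactant device
nn_jit = Reactant.@compile U(x_ra, ps_ra, st) # kernel compiled for plain float inputs

# Dual-valued inputs (e.g. the solver's ForwardDiff Jacobians) take the generic
# Julia path; plain floats call the compiled Reactant kernel.
apply_nn(x::AbstractArray{<:ForwardDiff.Dual}, ps, st) = first(U(x, ps, st))
apply_nn(x, ps, st) = first(nn_jit(Reactant.to_rarray(x), ps_ra, st))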


Hm, is there anything actionable on the Reactant side? Like I said on the SciML issue a few days ago, all the Reactant issues have been fixed, so I'm not aware of anything on our side?

As for ForwardDiff, can you just use Enzyme there? It should be both faster and more compatible.


Not always, and it would mess with a few defaults. The ODE solver cannot default to Enzyme for many, many reasons, such as the requirement that the software work on the day of every new Julia release. So the ODE Jacobians, which should always use forward mode anyway, don't gain a major benefit from Enzyme here, but would have a major issue with maintainability and support.

But we could, say, figure out that we are doing an adjoint with ReactantVJP, and thus force the ODE solver to use Enzyme for Jacobians in the backwards pass. But in that case, it's different from the normal case I know @avikpal had been working on. What Avik made support for was Reactant over other AD calls, lowering all of those to Enzyme, IIRC. The case we need here is Enzyme calling a function that calls a Reactant kernel. Is that case handled and tested?

Yes, as I said: not right now, and I'll get back to this on the weekend. Unless you want to build some tooling to help with the ForwardDiff dispatching.


Can you use Reactant.within_compile() to set the default to Enzyme only when within Reactant? (within_compile is a compile-time constant, so it won't impact other usage.)
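i.e., something like this sketch (AutoEnzyme and AutoForwardDiff are the ADTypes.jl backend types):

using ADTypes, Reactant

# pick the Jacobian backend at trace time; within_compile() folds to a constant
jac_ad() = Reactant.within_compile() ? AutoEnzyme() : AutoForwardDiff()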


It’s not within a Reactant compile though.


Thanks for the suggestion! I tried using Enzyme but it still fails. I am getting an error that says:
MethodError: no method matching nn(::Vector{Float64}, ::ConcretePJRTArray…).

Indeed, I confirmed your suggestion here doesn't work. Here's a Reactant MWE:

So we'll need to do manual Reactant kernel dispatching.


New adjoint benchmarks show some speed improvements. But this case really is asking for ReactantVJP.


Thanks so much for taking the time to look into this issue!