The fastest I could get is:
using DifferentialEquations
using StableRNGs, Lux
using ComponentArrays, SparseArrays
using PreallocationTools
using LinearAlgebra # for mul!
struct params{MO,P,S,MT,BT,T,C,C2}
    model::MO
    ps::P
    st::S
    n::Int
    M::MT
    B::BT
    T∞::T
    I::Int
    cache::C
    cache2::C2
end
# Defining the ODE model
function UDE_model!(du, u, p, t)
    # Extracting parameters
    n = p.n
    M = p.M
    B = p.B
    T∞ = p.T∞
    I = p.I
    rho_c = 2.85e6
    Cbat = 5 * 3600
    # Conduction/convection part: cache = M * T
    cache = get_tmp(p.cache, u)
    mul!(cache, M, @view(u[1:n]))
    # Network input: state of charge, temperatures, current
    cache2 = get_tmp(p.cache2, u)
    @views cache2[1, :] .= u[n+1]
    @views cache2[2, :] .= u[1:n]
    @views cache2[3, :] .= I
    C, _ = p.model(cache2, p.ps, p.st)
    C .= C .^ 2           # keep the learned heat source non-negative
    C .= C .* sign(I)^2   # and zero when no current flows
    for i in 1:n
        du[i] = C[1, i] / rho_c + B[i] * T∞ + cache[i]
    end
    du[n+1] = -I / Cbat   # state of charge via coulomb counting
    nothing
end
function initialize()
    # Defining the neural network
    U = Lux.Chain(Lux.Dense(3, 30, tanh), Lux.Dense(30, 30, tanh), Lux.Dense(30, 1))
    rng = StableRNG(1111)
    _para, st = Lux.setup(rng, U)
    _para = f64(_para)
    # Setting the number of radial temperature states
    n = 5
    # Cell geometry and thermal parameters
    d = 21e-3       # diameter in m
    r = d / 2
    L = 70e-3       # length in m
    Δr = r / n
    k = 1.05        # thermal conductivity
    rho_c = 2.85e6  # volumetric heat capacity
    h = 5.0         # convection coefficient
    Cbat = 5 * 3600
    # Shell surface areas and volumes
    AL = zeros(n)
    AB = zeros(n)
    V = zeros(n)
    for i in 1:n
        AL[i] = 2 * π * i * Δr * L
        AB[i] = π * (Δr^2) * (i^2 - ((i - 1)^2))
        V[i] = AB[i] * L
    end
    # Precomputing M and B
    M = zeros(n, n)
    B = zeros(n)
    for i in 1:n
        if i == 1
            M[1, 1] = (((-k * AL[1]) / (1.5 * Δr)) - (2 * h * AB[1])) / (rho_c * V[1])
            M[1, 2] = (k * AL[1]) / (1.5 * Δr * rho_c * V[1])
            B[1] = (2 * h * AB[1]) / (rho_c * V[1])
        elseif i == 2
            M[2, 2] = (((-k * AL[1]) / (1.5 * Δr)) - (2 * h * AB[2]) - ((k * AL[2]) / Δr)) / (rho_c * V[2])
            M[2, 1] = (k * AL[1]) / (1.5 * Δr * rho_c * V[2])
            M[2, 3] = (k * AL[2]) / (rho_c * V[2] * Δr)
            B[2] = (2 * h * AB[2]) / (rho_c * V[2])
        elseif i == n
            M[n, n] = (((-k * AL[n-1]) / Δr) - (2 * h * AB[n]) - (1.5 * h * AL[n])) / (rho_c * V[n])
            M[n, n-1] = (((k * AL[n-1]) / Δr) + (0.5 * h * AL[n])) / (rho_c * V[n])
            B[n] = ((2 * h * AB[n]) + (h * AL[n])) / (rho_c * V[n])
        else
            M[i, i] = (((-k * AL[i-1]) / Δr) - ((k * AL[i]) / Δr) - (2 * h * AB[i])) / (rho_c * V[i])
            M[i, i-1] = (k * AL[i-1]) / (Δr * rho_c * V[i])
            M[i, i+1] = (k * AL[i]) / (Δr * rho_c * V[i])
            B[i] = (2 * h * AB[i]) / (rho_c * V[i])
        end
    end
    M = sparse(M)
    # Ambient temperatures and currents for the three experiments
    T∞1 = 282.553
    T∞2 = 297.67
    T∞3 = 297.664
    I1 = 5
    I2 = 5
    I3 = 10
    # One pair of caches per problem so they never alias
    cache1 = DiffCache(zeros(n))
    cache12 = DiffCache(zeros(3, n))
    cache2 = DiffCache(zeros(n))
    cache22 = DiffCache(zeros(3, n))
    cache3 = DiffCache(zeros(n))
    cache32 = DiffCache(zeros(3, n))
    p1 = params(U, _para, st, n, M, B, T∞1, I1, cache1, cache12)
    p2 = params(U, _para, st, n, M, B, T∞2, I2, cache2, cache22)
    p3 = params(U, _para, st, n, M, B, T∞3, I3, cache3, cache32)
    t1 = collect(99076.0:1:102233.0)
    t2 = collect(79053.0:1:82426.0)
    t3 = collect(105517.0:1:107199.0)
    T01 = fill(T∞1, n)
    T02 = fill(T∞2, n)
    T03 = fill(T∞3, n)
    u01 = vcat(T01, 1.0)
    u02 = vcat(T02, 1.0)
    u03 = vcat(T03, 1.0)
    prob1 = ODEProblem(UDE_model!, u01, (t1[1], t1[end]), p1)
    prob2 = ODEProblem(UDE_model!, u02, (t2[1], t2[end]), p2)
    prob3 = ODEProblem(UDE_model!, u03, (t3[1], t3[end]), p3)
    return prob1, prob2, prob3, t1, t2, t3
end
prob1, prob2, prob3, t1, t2, t3 = initialize()
@time inisol1 = solve(prob1, Rosenbrock23(), saveat = t1);
@time inisol2 = solve(prob2, Rosenbrock23(), saveat = t2);
@time inisol3 = solve(prob3, Rosenbrock23(), saveat = t3);
However, you will need to implement the SciMLStructures interface if you want to differentiate through that.
You can use const variables if you want, but avoid the closure you wrote; that's as bad as using globals. Also, I'm not sure caching everything is a good idea once you differentiate through it: you may need to remove some of the caching I added, because the caches can hide the derivatives from the AD system.
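To make the closure point concrete, a throwaway sketch (the scalar a is made up, nothing to do with your model): a right-hand side that captures a non-const global is type-unstable, a const capture is stable but frozen, and passing the value through p keeps it visible to the solver and to AD.
using DifferentialEquations

a = 2.0
rhs_bad!(du, u, p, t) = (du .= a .* u)      # captures the non-const global a: type-unstable
const A = 2.0
rhs_const!(du, u, p, t) = (du .= A .* u)    # type-stable, but A can't be tuned
rhs_good!(du, u, p, t) = (du .= p.a .* u)   # parameters travel through p
prob = ODEProblem(rhs_good!, [1.0], (0.0, 1.0), (a = 2.0,))
solve(prob, Tsit5())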
Finally, it is still quite a mess; I've never managed to get the interface working for ML-related structs. For example, why does this not work:
using DifferentialEquations
using StableRNGs, Lux
using ComponentArrays, SparseArrays
using PreallocationTools
using SciMLSensitivity
import SciMLStructures as SS
using LinearAlgebra
using Zygote
using Parameters
mutable struct params{MO,P,S,MT,BT,T,C,C2}
    model::MO
    ps::P
    st::S
    n::Int
    M::MT
    B::BT
    T∞::T
    I::Int
    cache::C
    cache2::C2
end
SS.isscimlstructure(::params) = true
SS.ismutablescimlstructure(::params) = true
SS.hasportion(::SS.Tunable, ::params) = true
function SS.canonicalize(::SS.Tunable, p::params)
    buffer = copy(p.ps)
    repack = let p = p
        function repack(newbuffer)
            SS.replace(SS.Tunable(), p, newbuffer)
        end
    end
    return buffer, repack, false
end
function SS.replace(::SS.Tunable, p::params, newbuffer)
    return params(
        p.model, newbuffer, p.st, p.n, p.M, p.B, p.T∞, p.I, p.cache, p.cache2
    )
end
function SS.replace!(::SS.Tunable, p::params, newbuffer)
    p.ps = newbuffer
    return p
end
function initialize()
    # Defining the neural network
    U = Lux.Chain(Lux.Dense(3, 30, tanh), Lux.Dense(30, 30, tanh), Lux.Dense(30, 1))
    rng = StableRNG(1111)
    _para, st = Lux.setup(rng, U)
    _para = ComponentArray(f64(_para))
    # Setting the number of radial temperature states
    n = 5
    # Cell geometry and thermal parameters
    d = 21e-3       # diameter in m
    r = d / 2
    L = 70e-3       # length in m
    Δr = r / n
    k = 1.05        # thermal conductivity
    rho_c = 2.85e6  # volumetric heat capacity
    h = 5.0         # convection coefficient
    Cbat = 5 * 3600
    # Shell surface areas and volumes
    AL = zeros(n)
    AB = zeros(n)
    V = zeros(n)
    for i in 1:n
        AL[i] = 2 * π * i * Δr * L
        AB[i] = π * (Δr^2) * (i^2 - ((i - 1)^2))
        V[i] = AB[i] * L
    end
    # Precomputing M and B
    M = zeros(n, n)
    B = zeros(n)
    for i in 1:n
        if i == 1
            M[1, 1] = (((-k * AL[1]) / (1.5 * Δr)) - (2 * h * AB[1])) / (rho_c * V[1])
            M[1, 2] = (k * AL[1]) / (1.5 * Δr * rho_c * V[1])
            B[1] = (2 * h * AB[1]) / (rho_c * V[1])
        elseif i == 2
            M[2, 2] = (((-k * AL[1]) / (1.5 * Δr)) - (2 * h * AB[2]) - ((k * AL[2]) / Δr)) / (rho_c * V[2])
            M[2, 1] = (k * AL[1]) / (1.5 * Δr * rho_c * V[2])
            M[2, 3] = (k * AL[2]) / (rho_c * V[2] * Δr)
            B[2] = (2 * h * AB[2]) / (rho_c * V[2])
        elseif i == n
            M[n, n] = (((-k * AL[n-1]) / Δr) - (2 * h * AB[n]) - (1.5 * h * AL[n])) / (rho_c * V[n])
            M[n, n-1] = (((k * AL[n-1]) / Δr) + (0.5 * h * AL[n])) / (rho_c * V[n])
            B[n] = ((2 * h * AB[n]) + (h * AL[n])) / (rho_c * V[n])
        else
            M[i, i] = (((-k * AL[i-1]) / Δr) - ((k * AL[i]) / Δr) - (2 * h * AB[i])) / (rho_c * V[i])
            M[i, i-1] = (k * AL[i-1]) / (Δr * rho_c * V[i])
            M[i, i+1] = (k * AL[i]) / (Δr * rho_c * V[i])
            B[i] = (2 * h * AB[i]) / (rho_c * V[i])
        end
    end
    M = sparse(M)
    # Ambient temperatures and currents for the three experiments
    T∞1 = 282.553
    T∞2 = 297.67
    T∞3 = 297.664
    I1 = 5
    I2 = 5
    I3 = 10
    # One pair of caches per problem so they never alias
    cache1 = DiffCache(zeros(n))
    cache12 = DiffCache(zeros(3, n))
    cache2 = DiffCache(zeros(n))
    cache22 = DiffCache(zeros(3, n))
    cache3 = DiffCache(zeros(n))
    cache32 = DiffCache(zeros(3, n))
    p1 = params(U, _para, st, n, M, B, T∞1, I1, cache1, cache12)
    p2 = params(U, _para, st, n, M, B, T∞2, I2, cache2, cache22)
    p3 = params(U, _para, st, n, M, B, T∞3, I3, cache3, cache32)
    t1 = collect(99076.0:1:102233.0)
    t2 = collect(79053.0:1:82426.0)
    t3 = collect(105517.0:1:107199.0)
    T01 = fill(T∞1, n)
    T02 = fill(T∞2, n)
    T03 = fill(T∞3, n)
    u01 = vcat(T01, 1.0)
    u02 = vcat(T02, 1.0)
    u03 = vcat(T03, 1.0)
    prob1 = ODEProblem(UDE_model!, u01, (t1[1], t1[end]), p1)
    prob2 = ODEProblem(UDE_model!, u02, (t2[1], t2[end]), p2)
    prob3 = ODEProblem(UDE_model!, u03, (t3[1], t3[end]), p3)
    return prob1, prob2, prob3, t1, t2, t3, u01, u02, u03, p1, p2, p3
end
# Defining the ODE model
function UDE_model!(du, u, p, t)
    # Extracting parameters
    Parameters.@unpack model, ps, st, n, M, B, T∞, I, cache, cache2 = p
    rho_c = 2.85e6
    Cbat = 5 * 3600
    cache = get_tmp(cache, u)
    mul!(cache, M, @view(u[1:n]))
    cache2 = get_tmp(cache2, u)
    # Network input: state of charge, temperatures, current
    @views cache2[1, :] .= u[n+1]
    @views cache2[2, :] .= u[1:n]
    @views cache2[3, :] .= I
    C = model(cache2, ps, st)[1] .^ 2 .* (sign(I)^2)
    @views du[1:n] .= C[1, :] ./ rho_c .+ B .* T∞ .+ cache
    du[n+1] = -I / Cbat
    nothing
end
prob1, prob2, prob3, t1, t2, t3, u01, u02, u03, p1, p2, p3 = initialize()
@time inisol1 = solve(prob1, Rosenbrock23(), saveat = t1);
@time inisol2 = solve(prob2, Rosenbrock23(), saveat = t2);
@time inisol3 = solve(prob3, Rosenbrock23(), saveat = t3);
function run_diff(ps)
    p = p1
    p.ps = ps
    prob = remake(prob1, p = p)
    sol = solve(prob, Rosenbrock23(), saveat = t1)
    return sol.u |> last |> sum
end
run_diff(p1.ps)
Zygote.gradient(run_diff, p1.ps)
For some reason the solve call ends up calling similar on the struct, and if I implement that, it then calls length on it, which is nonsense.
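For what it's worth, one way to sidestep the interface question entirely is to make p the ComponentArray of weights itself and move everything non-tunable into a callable struct, so solve only ever sees an array it already knows how to similar/length. A sketch of that pattern (UDERhs, prob1b and loss are my names; it assumes U, st, n, M, B, T∞1, I1, the caches, and _para built inside initialize() are available at top level, and the in-place caching may still need to go for some reverse-mode vjp backends):
struct UDERhs{MO,S,MT,BT,T,C,C2}
    model::MO
    st::S
    n::Int
    M::MT
    B::BT
    T∞::T
    I::Int
    cache::C
    cache2::C2
end

function (f::UDERhs)(du, u, ps, t)
    # ps is now just the ComponentArray of network weights
    n = f.n
    rho_c = 2.85e6
    Cbat = 5 * 3600
    cache = get_tmp(f.cache, u)
    mul!(cache, f.M, @view(u[1:n]))
    cache2 = get_tmp(f.cache2, u)
    @views cache2[1, :] .= u[n+1]
    @views cache2[2, :] .= u[1:n]
    @views cache2[3, :] .= f.I
    C = f.model(cache2, ps, f.st)[1] .^ 2 .* sign(f.I)^2
    @views du[1:n] .= C[1, :] ./ rho_c .+ f.B .* f.T∞ .+ cache
    du[n+1] = -f.I / Cbat
    return nothing
end

rhs1 = UDERhs(U, st, n, M, B, T∞1, I1, cache1, cache12)
prob1b = ODEProblem(rhs1, u01, (t1[1], t1[end]), _para)
loss(ps) = sum(last(solve(remake(prob1b, p = ps), Rosenbrock23(), saveat = t1).u))
loss(_para)
Zygote.gradient(loss, _para)
This mirrors what the usual Lux UDE tutorials do, where p passed to the ODEProblem is just the ComponentVector of weights and remake swaps it in for each gradient evaluation.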