I implemented the changes recommended here, which have made a big difference in runtime and made it realistic to start considering larger N. I’ve been benchmarking against Numba, and once N exceeds roughly 200, Numba outpaces Julia. Profiling indicates this is probably due to the matrix-multiplication steps. Is mul! the fastest way to do them in Julia? I also tried rewriting the products as explicit for loops (Julia 1 code below), which is very fast for smaller N but whose runtime grows quickly with N.
Numba:
import numba
import numpy
import time as timer
from scipy import integrate
from matplotlib import pyplot
@numba.njit
def forward_RHS(u, t, *p):
    # Right-hand side of the 4N-dimensional neuron-network ODE, in the
    # (u, t, *args) form expected by scipy.integrate.odeint.
    # State layout: u = [v_s; u_s; w_s; s], four stacked length-N segments
    # with N = p[2].  p[3]..p[19] are scalar model constants; p[20] is the
    # input-weight matrix (N x n_inputs), p[21] the recurrent-weight matrix
    # (N x N), p[22] the tonic input vector, p[23] the drive frequencies
    # (n_inputs x 4) -- see runprob below for how the tuple is packed.
    N = p[2]
    v_s = u[0:N]
    u_s = u[N:2*N]
    w_s = u[2*N:3*N]
    s = u[3*N:4*N]
    # p[24] is a preallocated 4N scratch vector reused as the output du on
    # every call, so the RHS allocates nothing per evaluation.
    du = p[24]
    # Membrane equation: quadratic nonlinearity times (v - 1), linear gating
    # by u_s and w_s, sinusoidal external drive mixed through p[20],
    # recurrent synaptic input p[21]@s, and tonic drive p[22]; the whole
    # numerator is scaled by 1/p[16] (cm_s in the Julia version).
    du[0:N] = ((p[3]*(v_s - p[4])**2.0 + p[5])*(v_s - 1.0)
               - p[12]*u_s*(v_s - p[14]) - p[13]*w_s*(v_s - p[15])
               + p[20]@(numpy.sin(p[23][:, 0]*t) + numpy.sin(p[23][:, 1]*t) - numpy.sin(p[23][:, 2]*t) - numpy.sin(p[23][:, 3]*t))
               + p[21]@s
               + p[22])/p[16]
    # Recovery variables relax toward voltage-dependent targets with time
    # constants p[17] and p[18].
    du[N:2*N] = (numpy.exp(p[6]*(v_s - p[7])) + p[8] - u_s)/p[17]
    du[2*N:3*N] = (p[9]*(v_s - p[10])**2.0 + p[11] - w_s)/p[18]
    # Synaptic gate: decays with time constant p[19]; driven only while v_s
    # lies in (0.25, 0.3) AND is rising (dv/dt > 0) -- a discontinuous
    # spike-onset detector built from boolean masks.
    du[3*N:4*N] = (-s + (v_s > 0.25)*(v_s < 0.3)*(du[0:N] > 0.0)*du[0:N]/0.05)/p[19]
    return du
@numba.njit
def forward_jacobian(u, t, *p):
    # Analytic Jacobian of forward_RHS, passed to odeint as Dfun.
    # Returns the dense (4N, 4N) matrix J[i, j] = d(du[i])/d(u[j]).
    N = p[2]
    v_s = u[0:N]
    u_s = u[N:2*N]
    w_s = u[2*N:3*N]
    s = u[3*N:4*N]
    J = numpy.zeros((4*N, 4*N))
    # Identity used to place per-neuron (diagonal) derivative entries.
    mask = numpy.identity(N)
    # d(dv)/dv: product rule on (p3*(v-p4)^2 + p5)*(v-1) plus the gating terms.
    J[0:N, 0:N] = (p[5] + 2.0*p[3]*(v_s - 1.0)*(v_s - p[4]) + p[3]*(v_s - p[4])**2.0 - p[12]*u_s - p[13]*w_s)/p[16]*mask
    J[0:N, N:2*N] = -(p[12]*(v_s - p[14]))/p[16]*mask
    J[0:N, 2*N:3*N] = -(p[13]*(v_s - p[15]))/p[16]*mask
    # d(dv)/ds: dv contains p[21]@s / p[16], so this block is the FULL dense
    # recurrent-weight matrix scaled by 1/p[16].
    # BUG FIX: the original read `self.recurrent_weights_s.sum(axis=1)*mask`,
    # which (a) references `self`, undefined inside an njit free function
    # (NameError / njit compile failure), and (b) would collapse the dense
    # coupling to a diagonal of row sums, which is not the derivative.
    J[0:N, 3*N:4*N] = p[21]/p[16]
    J[N:2*N, 0:N] = p[6]*numpy.exp(p[6]*(v_s - p[7]))/p[17]*mask
    J[N:2*N, N:2*N] = -mask/p[17]
    J[2*N:3*N, 0:N] = 2.0*(p[9]*(v_s - p[10]))/p[18]*mask
    J[2*N:3*N, 2*N:3*N] = -mask/p[18]
    # NOTE(review): du[3N:4N] also depends on v_s and dv through the boolean
    # gate, but that term is discontinuous; only the smooth -s/p[19] part is
    # differentiated here, matching the original's intent.
    J[3*N:4*N, 3*N:4*N] = -mask/p[19]
    return J
def runprob(N):
    """Assemble random network parameters of size N and integrate 0..500.

    Returns the odeint solution array of shape (5000, 4*N).  The parameter
    tuple layout must match the positional p[...] indices used inside
    forward_RHS / forward_jacobian.
    """
    # Scalar model constants p[3]..p[19].
    scalars = (-500, 0.2, 0.0, 100.0, 0.7, 0.0, 15.0, 0.24, 0.02,
               1700.0, 8.0, 0.19, -0.5,
               25.0, 6.0, 12.0, 10.0)
    # Random coupling: p[20] input weights (N x 2), p[21] recurrent weights
    # (N x N), p[22] tonic inputs, p[23] drive frequencies (2 x 4).
    input_weights = 2.8*(numpy.random.rand(N, 2) - 0.5)
    recurrent_weights = 0.5*(numpy.random.rand(N, N) - 0.8)
    tonic_inputs = 0.14*(numpy.random.rand(N,) - 0.5)
    frequencies = numpy.random.rand(2, 4)
    # p[24] is the preallocated du buffer; p[25] and p[26] are not referenced
    # by forward_RHS / forward_jacobian above (kept for tuple compatibility).
    parameters = ((0, 2, N) + scalars
                  + (input_weights, recurrent_weights, tonic_inputs,
                     frequencies,
                     numpy.zeros(4*N),
                     numpy.zeros((4*N, 4*N)),
                     numpy.identity(N)))
    times = numpy.linspace(0.0, 500.0, 5000)
    y = integrate.odeint(forward_RHS, numpy.zeros(4*N), times,
                         args=parameters, Dfun=forward_jacobian,
                         rtol=1e-8, atol=1e-8, mxstep=50000)
    return y
# IPython cell magic: benchmark the full N = 50 solve (runs only in
# IPython/Jupyter, not plain Python).
%timeit runprob(50)
Julia 1 (for loop):
using OrdinaryDiffEq, BenchmarkTools, LSODA, Plots, Profile
using LinearAlgebra
using Parameters
# Container for all model parameters and preallocated work buffers passed to
# the RHS as `p`.  Field order defines the positional constructor used in
# runprob1, so it must not change.
@with_kw struct ODEParameters
    number_of_inputs::Int               # columns of input_weights_s; length of input_cache
    N::Int                              # neurons; the state vector has length 4N
    # Scalar model constants used in the membrane/recovery equations.
    a_s::Float64
    b_s::Float64
    c_s::Float64
    d_s::Float64
    e_s::Float64
    f_s::Float64
    g_s::Float64
    h_s::Float64
    i_s::Float64
    j_s::Float64
    k_s::Float64
    l_s::Float64
    m_s::Float64
    cm_s::Float64                       # membrane-equation divisor
    tau_u_s::Float64                    # time constant for u_s
    tau_w_s::Float64                    # time constant for w_s
    tau_s::Float64                      # time constant for the synaptic gate s
    # Network coupling.
    input_weights_s::Array{Float64,2}   # N x number_of_inputs
    recurrent_weights_s::Array{Float64,2}  # N x N
    tonic_inputs_s::Vector{Float64}     # length N
    frequencies::Array{Float64,2}       # number_of_inputs x 4 drive frequencies
    # Preallocated scratch vectors reused by the RHS to avoid allocations.
    network_inputs_s::Vector{Float64}   # length N
    recurrent_inputs_s::Vector{Float64} # length N
    input_cache::Vector{Float64}        # length number_of_inputs
end
function forward_RHS_1(du, u, p, t)
    # In-place ODE right-hand side (hand-written-loop variant): overwrites du
    # with the derivative of the stacked state u = [v_s; u_s; w_s; s].
    #
    # BUG FIX: the original inner loops assigned with `=` instead of
    # accumulating with `+=`, so network_inputs_s[i] / recurrent_inputs_s[i]
    # held only the LAST term of each matrix-vector product instead of the
    # full sum over j.
    N = p.N
    v_s = view(u, 1:N)
    u_s = view(u, N+1:2*N)
    w_s = view(u, 2*N+1:3*N)
    s = view(u, 3*N+1:4*N)
    # Sinusoidal external drive, evaluated into a preallocated buffer.
    p.input_cache .= sin.(t.*view(p.frequencies, :, 1)) .+ sin.(t.*view(p.frequencies, :, 2)) .-
                     sin.(t.*view(p.frequencies, :, 3)) .- sin.(t.*view(p.frequencies, :, 4))
    # Hand-rolled mat-vec products: network_inputs_s = W_in * input_cache and
    # recurrent_inputs_s = W_rec * s.  (Row-wise traversal of column-major
    # arrays; for large N the BLAS route via mul! in forward_RHS_2 is faster.)
    @inbounds for i in 1:N
        acc = 0.0
        for j in 1:p.number_of_inputs
            acc += p.input_weights_s[i, j]*p.input_cache[j]
        end
        p.network_inputs_s[i] = acc
        acc = 0.0
        for j in 1:N
            acc += p.recurrent_weights_s[i, j]*s[j]
        end
        p.recurrent_inputs_s[i] = acc
    end
    # Membrane potential dynamics.
    du[1:N] .= ((p.a_s.*(v_s .- p.b_s).^2 .+ p.c_s).*(v_s .- 1.0)
                .- p.j_s.*u_s.*(v_s .- p.l_s) .- p.k_s.*w_s.*(v_s .- p.m_s)
                .+ p.network_inputs_s
                .+ p.recurrent_inputs_s
                .+ p.tonic_inputs_s)./p.cm_s
    # Recovery variables relax toward voltage-dependent targets.
    du[N+1:2*N] .= (exp.(p.d_s.*(v_s .- p.e_s)) .+ p.f_s .- u_s)./p.tau_u_s
    du[2*N+1:3*N] .= (p.g_s.*(v_s .- p.h_s).^2 .+ p.i_s .- w_s)./p.tau_w_s
    # Synaptic gate: decays; driven only while v_s ∈ (0.25, 0.3) and rising.
    du[3*N+1:4*N] .= (-1 .* s .+ (v_s .> 0.25).*(v_s .< 0.3).*(view(du, 1:N) .> 0.0).*view(du, 1:N)./0.05)./p.tau_s
    nothing
end
function runprob1(N)
    # Build a random ODEParameters instance of size N and integrate the
    # loop-based RHS over t ∈ [0, 500] with LSODA.
    input_weights = 2.8 .* (rand(Float64, (N, 2)) .- 0.5)
    recurrent_weights = 0.5 .* (rand(Float64, (N, N)) .- 0.8)
    tonic_inputs = 0.14 .* (rand(Float64, (N,)) .- 0.5)
    drive_frequencies = rand(Float64, (2, 4))
    params = ODEParameters(2, N,
                           -500, 0.2, 0.0, 100.0, 0.7, 0.0, 15.0, 0.24, 0.02,
                           1700.0, 8.0, 0.19, -0.5,
                           25.0, 6.0, 12.0, 10.0,
                           input_weights, recurrent_weights, tonic_inputs,
                           drive_frequencies,
                           zeros(N), zeros(N),  # network/recurrent scratch
                           zeros(2))            # input_cache (2 inputs)
    prob = ODEProblem(forward_RHS_1, zeros(4*N), (0.0, 500.0), params)
    return solve(prob, lsoda(), saveat = 0.1, reltol = 1e-8, abstol = 1e-8)
end
# Benchmark the loop-based implementation; the $ interpolation keeps the
# non-constant global N out of the timed expression.
N = 50
@btime runprob1($N);
Julia 2:
using OrdinaryDiffEq, BenchmarkTools, LSODA, Plots, Profile
using LinearAlgebra
using Parameters
# Parameter container for the mul!-based variant; identical layout to the
# struct in the first listing.  Field order defines the positional
# constructor used in runprob2, so it must not change.
@with_kw struct ODEParameters
    number_of_inputs::Int               # columns of input_weights_s; length of input_cache
    N::Int                              # neurons; the state vector has length 4N
    # Scalar model constants used in the membrane/recovery equations.
    a_s::Float64
    b_s::Float64
    c_s::Float64
    d_s::Float64
    e_s::Float64
    f_s::Float64
    g_s::Float64
    h_s::Float64
    i_s::Float64
    j_s::Float64
    k_s::Float64
    l_s::Float64
    m_s::Float64
    cm_s::Float64                       # membrane-equation divisor
    tau_u_s::Float64                    # time constant for u_s
    tau_w_s::Float64                    # time constant for w_s
    tau_s::Float64                      # time constant for the synaptic gate s
    # Network coupling.
    input_weights_s::Array{Float64,2}   # N x number_of_inputs
    recurrent_weights_s::Array{Float64,2}  # N x N
    tonic_inputs_s::Vector{Float64}     # length N
    frequencies::Array{Float64,2}       # number_of_inputs x 4 drive frequencies
    # Preallocated scratch vectors reused by the RHS to avoid allocations.
    network_inputs_s::Vector{Float64}   # length N
    recurrent_inputs_s::Vector{Float64} # length N
    input_cache::Vector{Float64}        # length number_of_inputs
end
function forward_RHS_2(du, u, p, t)
    # In-place ODE right-hand side (BLAS variant): same dynamics as the loop
    # version, but the two matrix-vector products go through mul! so they hit
    # optimized BLAS kernels.  Writes the derivative of [v_s; u_s; w_s; s]
    # into du and allocates nothing per call.
    N = p.N
    @views begin
        v_s = u[1:N]
        u_s = u[N+1:2N]
        w_s = u[2N+1:3N]
        s = u[3N+1:4N]
        # Sinusoidal external drive, written into the preallocated cache.
        freqs = p.frequencies
        p.input_cache .= sin.(freqs[:, 1] .* t) .+ sin.(freqs[:, 2] .* t) .-
                         sin.(freqs[:, 3] .* t) .- sin.(freqs[:, 4] .* t)
        # network_inputs_s = W_in * input_cache; recurrent_inputs_s = W_rec * s.
        mul!(p.network_inputs_s, p.input_weights_s, p.input_cache)
        mul!(p.recurrent_inputs_s, p.recurrent_weights_s, s)
        # Membrane potential dynamics.
        dv = du[1:N]
        dv .= ((p.a_s .* (v_s .- p.b_s).^2 .+ p.c_s) .* (v_s .- 1.0)
               .- p.j_s .* u_s .* (v_s .- p.l_s)
               .- p.k_s .* w_s .* (v_s .- p.m_s)
               .+ p.network_inputs_s .+ p.recurrent_inputs_s
               .+ p.tonic_inputs_s) ./ p.cm_s
        # Recovery variables.
        du[N+1:2N] .= (exp.(p.d_s .* (v_s .- p.e_s)) .+ p.f_s .- u_s) ./ p.tau_u_s
        du[2N+1:3N] .= (p.g_s .* (v_s .- p.h_s).^2 .+ p.i_s .- w_s) ./ p.tau_w_s
        # Synaptic gate: decays; driven only while v_s ∈ (0.25, 0.3) and rising.
        du[3N+1:4N] .= ((v_s .> 0.25) .* (v_s .< 0.3) .* (dv .> 0.0) .* dv ./ 0.05 .- s) ./ p.tau_s
    end
    return nothing
end
function runprob2(N)
    # Build a random ODEParameters instance of size N and integrate the
    # mul!-based RHS over t ∈ [0, 500] with LSODA.
    W_in = 2.8 .* (rand(Float64, (N, 2)) .- 0.5)
    W_rec = 0.5 .* (rand(Float64, (N, N)) .- 0.8)
    tonic = 0.14 .* (rand(Float64, (N,)) .- 0.5)
    freqs = rand(Float64, (2, 4))
    params = ODEParameters(2, N,
                           -500, 0.2, 0.0, 100.0, 0.7, 0.0, 15.0, 0.24, 0.02,
                           1700.0, 8.0, 0.19, -0.5,
                           25.0, 6.0, 12.0, 10.0,
                           W_in, W_rec, tonic,
                           freqs,
                           zeros(N), zeros(N),  # network/recurrent scratch
                           zeros(2))            # input_cache (2 inputs)
    prob = ODEProblem(forward_RHS_2, zeros(4*N), (0.0, 500.0), params)
    return solve(prob, lsoda(), saveat = 0.1, reltol = 1e-8, abstol = 1e-8)
end
# Benchmark the mul!-based implementation; the $ interpolation keeps the
# non-constant global N out of the timed expression.
N = 50
@btime runprob2($N);
Numba+SciPy
N = 50: 1.96 s ± 179 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
N = 100: 4.6 s ± 220 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
N = 300: 18.8 s ± 769 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
N = 500: 35.5 s ± 2.42 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
Julia 1 (for loop)
N = 50: 792.417 ms (426951 allocations: 41.67 MiB)
N = 100: 3.925 s (842437 allocations: 80.75 MiB)
N = 300: 58.075 s (1994026 allocations: 200.40 MiB)
N = 500: 262.799 s (2912560 allocations: 303.50 MiB)
Julia 2
N = 50: 834.853 ms (540793 allocations: 50.35 MiB)
N = 100: 15.337 s (962251 allocations: 89.89 MiB)
N = 300: 31.940 s (1420379 allocations: 156.64 MiB)
N = 500: 55.922 s (1731853 allocations: 213.42 MiB)
If this is the most I can expect that’s probably okay, I’ll just set up a switch between Numba and Julia in my final code when N is large enough.