Hi all!
I’m trying to squeeze performance out of two sets of ODEs that I need to solve many times at tight tolerances (~1e-10). I’ve benefited enormously from reading other posts and the Performance Tips page, and I think I’ve already optimized quite a bit. I’m hoping that someone with more experienced eyes might spot further room for improvement.
I think the main thing slowing me down is interpolating the forward solution, which the reverse problem does on every right-hand-side (and Jacobian) evaluation. Is this bottleneck unavoidable?
Here is my (not-so-minimal) MWE:
using OrdinaryDiffEq, Parameters, LinearAlgebra, BenchmarkTools, MKL
# ---- problem dimensions -----------------------------------------------------
number_of_inputs = 2
network_size = 64
number_of_targets = 2
simulation_length = 50.0

# ---- random weights / biases ------------------------------------------------
# The rand calls below run in the same order as before, so a given RNG state
# yields the same arrays. Weight entries are shifted to lie in [-0.5, 0.5);
# the recurrent matrices are additionally halved to [-0.25, 0.25).
U_s = rand(network_size, number_of_inputs) .- 0.5
U_d = rand(network_size, number_of_inputs) .- 0.5
O = rand(number_of_targets, network_size) .- 0.5
R_s = 0.5 .* (rand(network_size, network_size) .- 0.5)
R_d = 0.5 .* (rand(network_size, network_size) .- 0.5)
I_s = rand(network_size)
I_d = rand(network_size)

# Angular frequencies of the four sine terms per input channel, in [0, π/20).
frequencies = π .* rand(2, 4) ./ 20

# Fixed 2×2 coupling matrix for the target block.
A = [-0.7  0.36;
     -2.3 -0.1]
# ---- model constants --------------------------------------------------------
# Names ending in `_s` enter the equations for state block u[1:N] and names
# ending in `_d` the equations for u[N+1:2N] (presumably two compartments —
# naming convention only, TODO confirm). With v the block's own voltage-like
# state, the forward RHS uses them as:
#   (a*(v-b)^2 + c)*(v-1)      polynomial term
#   (d*(v-e)^3 + f)            drive of the first auxiliary block
#   (g*(v-h)^2 + i)            drive of the second auxiliary block
#   j*w*(v-l), k*w*(v-m)       terms scaled by the auxiliary states
const a_s, b_s, c_s = -1000.0, 0.3, 0.04
const d_s, e_s, f_s = 5.0, 0.3, 0.0
const g_s, h_s, i_s = 2.0, 0.3, 0.02
const j_s, k_s, l_s, m_s = 2500.0, 100.0, 0.3, 0.0

const a_d, b_d, c_d = -200.0, 0.3, 0.08
const d_d, e_d, f_d = 1.0, 0.3, 0.0
const g_d, h_d, i_d = 2.0, 0.3, 0.08
const j_d, k_d, l_d, m_d = 200.0, 130.0, 0.25, 0.26

# Divisors of the two voltage-like equations.
const cm_s, cm_d = 25.0, 25.0

# Time constants of the auxiliary blocks.
const tau_u_s, tau_w_s = 2.0, 5.0
const tau_u_d, tau_w_d = 0.6, 30.0

# Cross-block coupling strengths and offsets.
const sigma_s_d, sigma_d_s = 7.35, 8.505
const mu_s, mu_d = 0.2707, 0.2679

# z_max*(tanh(q1*(w - 0.05)) + 1) drive and its time constant (also reused by
# the target block).
const q1, z_max, tau_s = 100.0, 5.0, 10.0

# Weights of the extra terms in the reverse-pass source (r1, r2) and loss (r3).
const r1, r2, r3 = 0.01, 0.01, 0.0
# Parameter bundle for `forward_RHS!` / `forward_jacobian!`, built by keyword via
# Parameters.jl's `@with_kw`. Fields without a default must be supplied; the
# trailing buffers default to zero vectors sized from the dimension fields and
# are overwritten in place on every RHS call.
@with_kw mutable struct ForwardParameters
    # -- dimensions --
    number_of_inputs::Int
    network_size::Int
    number_of_targets::Int
    # -- weights and constant inputs --
    U_s::Matrix{Float64}
    U_d::Matrix{Float64}
    O::Matrix{Float64}
    R_s::Matrix{Float64}
    R_d::Matrix{Float64}
    I_s::Vector{Float64}
    I_d::Vector{Float64}
    # -- input-signal frequencies and target coupling --
    frequencies::Matrix{Float64} = zeros(2, 4)
    A::Matrix{Float64} = zeros(2, 2)
    # -- scratch buffers --
    input_cache::Vector{Float64} = zeros(number_of_inputs)
    inputs_s::Vector{Float64} = zeros(network_size)
    recurrents_s::Vector{Float64} = zeros(network_size)
    inputs_d::Vector{Float64} = zeros(network_size)
    recurrents_d::Vector{Float64} = zeros(network_size)
    target_cache::Vector{Float64} = zeros(number_of_targets)
end
# Parameter bundle for `reverse_RHS!` / `reverse_jacobian!`, built by keyword via
# Parameters.jl's `@with_kw`. Holds the forward solution (interpolated on every
# RHS/Jacobian call), the weights, and preallocated scratch buffers.
#
# `S` is the concrete type of the stored forward solution. The former annotation
# `OrdinaryDiffEq.ODECompositeSolution` names a type with unbound parameters, so
# the field was abstractly typed. Every `p.forward_solution(p.u, t)` in the hot
# loop then went through dynamic dispatch. Parametrizing lets the compiler
# specialize on the actual solution type. It also drops the dependency on that
# type name, which may not exist in newer OrdinaryDiffEq/SciMLBase releases
# (NOTE(review): confirm against the installed version).
@with_kw mutable struct ReverseParameters{S}
    # -- dimensions --
    number_of_inputs::Int
    network_size::Int
    number_of_targets::Int
    # -- weights and constant inputs --
    U_s::Matrix{Float64}
    U_d::Matrix{Float64}
    O::Matrix{Float64}
    R_s::Matrix{Float64}
    R_d::Matrix{Float64}
    I_s::Vector{Float64}
    I_d::Vector{Float64}
    # -- forward solution, called as an in-place interpolant --
    forward_solution::S
    # buffer receiving the interpolated forward state (length 7N + targets)
    u::Vector{Float64} = zeros(7*network_size+number_of_targets)
    frequencies::Matrix{Float64} = zeros(2, 4)
    A::Matrix{Float64} = zeros(2, 2)
    # -- scratch buffers --
    input_cache::Vector{Float64} = zeros(number_of_inputs)
    inputs_s::Vector{Float64} = zeros(network_size)
    recurrents_s::Vector{Float64} = zeros(network_size)
    total_s::Vector{Float64} = zeros(network_size)
    inputs_d::Vector{Float64} = zeros(network_size)
    recurrents_d::Vector{Float64} = zeros(network_size)
    total_d::Vector{Float64} = zeros(network_size)
    outputs::Vector{Float64} = zeros(number_of_targets)
    output_cache_s::Vector{Float64} = zeros(network_size)
    output_cache_d::Vector{Float64} = zeros(network_size)
    loss_cache::Vector{Float64} = zeros(number_of_targets)
    dl_ds1::Vector{Float64} = zeros(network_size)
    # Materialized transposes. `permutedims` already returns a `Matrix`, so the
    # strictly typed positional constructor of a parametric struct does not need
    # to convert a lazy `Transpose`.
    O_T::Matrix{Float64} = permutedims(O)
    R_s_T::Matrix{Float64} = permutedims(R_s)
    R_d_T::Matrix{Float64} = permutedims(R_d)
    reverse_cache_s::Vector{Float64} = zeros(network_size)
    reverse_cache_d::Vector{Float64} = zeros(network_size)
end
# functions
# In-place right-hand side of the forward system: writes f(u, p, t) into `du`.
#
# State layout (N = p.network_size, nt = p.number_of_targets):
#   u[1:N]          first voltage-like block (`_s` constants)
#   u[N+1:2N]       second voltage-like block (`_d` constants)
#   u[2N+1:3N]      block multiplied by the recurrent matrices R_s and R_d
#   u[3N+1:7N]      four auxiliary blocks driven by the first two blocks
#   u[7N+1:7N+nt]   target block, coupled through p.A
# NOTE(review): the target equations read p.input_cache[1:nt] under @inbounds,
# so this assumes number_of_targets <= number_of_inputs.
function forward_RHS!(du, u, p, t)
    N = p.network_size
    nt = p.number_of_targets
    freqs = p.frequencies

    # Sum-of-sines input signal, one value per input channel.
    @inbounds @simd for k in 1:p.number_of_inputs
        p.input_cache[k] = sin(freqs[k, 1]*t) - sin(freqs[k, 2]*t) + sin(freqs[k, 3]*t) - sin(freqs[k, 4]*t)
    end

    z = view(u, 2N+1:3N)
    BLAS.gemv!('N', 1.0, p.U_s, p.input_cache, 0.0, p.inputs_s)
    BLAS.gemv!('N', 1.0, p.R_s, z, 0.0, p.recurrents_s)
    BLAS.gemv!('N', 1.0, p.U_d, p.input_cache, 0.0, p.inputs_d)
    BLAS.gemv!('N', 1.0, p.R_d, z, 0.0, p.recurrents_d)
    BLAS.gemv!('N', 1.0, p.A, view(u, 7N+1:7N+nt), 0.0, p.target_cache)

    @inbounds @simd for k in 1:N
        vs = u[k]
        vd = u[N+k]
        du[k] = ((a_s*(vs - b_s)^2 + c_s)*(vs - 1.0)
                 - j_s*u[3N+k]*(vs - l_s)
                 - k_s*u[5N+k]*(vs - m_s)
                 + p.inputs_s[k] + p.recurrents_s[k] + p.I_s[k] + sigma_d_s*(vd - mu_d))/cm_s
        du[N+k] = ((a_d*(vd - b_d)^2 + c_d)*(vd - 1.0)
                   - j_d*u[4N+k]*(vd - l_d)
                   - k_d*u[6N+k]*(vd - m_d)
                   + p.inputs_d[k] + p.recurrents_d[k] + p.I_d[k] + sigma_s_d*(vs - mu_s))/cm_d
        du[2N+k] = (z_max*(tanh(q1*(u[3N+k] - 0.05)) + 1.0) - u[2N+k])/tau_s
        du[3N+k] = (d_s*(vs - e_s)^3 + f_s - u[3N+k])/tau_u_s
        du[4N+k] = (d_d*(vd - e_d)^3 + f_d - u[4N+k])/tau_u_d
        du[5N+k] = (g_s*(vs - h_s)^2 + i_s - u[5N+k])/tau_w_s
        du[6N+k] = (g_d*(vd - h_d)^2 + i_d - u[6N+k])/tau_w_d
    end

    # Target block: relaxes towards A*target + input with time constant tau_s.
    @inbounds @simd for k in 1:nt
        du[7N+k] = (p.target_cache[k] + p.input_cache[k] - u[7N+k])/tau_s
    end
    nothing
end
# In-place right-hand side of the reverse (adjoint-style) system.
#
# `u` (length 7N) is the reverse state λ. The forward state x at time t is
# interpolated into the buffer p.u on every call, so `p.forward_solution` must
# be a dense solution. Per block, the λ terms mirror −J_f(x)ᵀ λ, where J_f is
# the Jacobian of `forward_RHS!`. The exception is block 1:N: when x[i] >= 0.45
# its diagonal factor keeps only the j_s/k_s terms (a deliberate gate — intent
# not visible here). Block 2N+1:3N also receives the source −dl_ds1 (output
# error through O, plus r1/r2 regularization).
#
# Fix: the cross-coupling between blocks 1:N and N+1:2N previously used the
# untransposed coefficients. In forward_RHS!, du[N+i] contains
# sigma_s_d*(u[i] - mu_s)/cm_d, so J_f[N+i, i] = sigma_s_d/cm_d and that
# coefficient multiplies λ[N+i] in dλ[i]. Symmetrically, J_f[i, N+i] =
# sigma_d_s/cm_s multiplies λ[i] in dλ[N+i]. This now matches reverse_jacobian!.
function reverse_RHS!(du, u, p, t)
    N = p.network_size
    nt = p.number_of_targets
    x = p.u
    p.forward_solution(x, t)   # in-place interpolation of the forward state

    freqs = p.frequencies
    @inbounds @simd for k in 1:p.number_of_inputs
        p.input_cache[k] = sin(freqs[k, 1]*t) - sin(freqs[k, 2]*t) + sin(freqs[k, 3]*t) - sin(freqs[k, 4]*t)
    end

    z = view(x, 2N+1:3N)
    BLAS.gemv!('N', 1.0, p.U_s, p.input_cache, 0.0, p.inputs_s)
    BLAS.gemv!('N', 1.0, p.R_s, z, 0.0, p.recurrents_s)
    @. p.total_s = p.inputs_s + p.recurrents_s + p.I_s
    BLAS.gemv!('N', 1.0, p.R_s_T, p.total_s, 0.0, p.output_cache_s)
    BLAS.gemv!('N', 1.0, p.U_d, p.input_cache, 0.0, p.inputs_d)
    BLAS.gemv!('N', 1.0, p.R_d, z, 0.0, p.recurrents_d)
    @. p.total_d = p.inputs_d + p.recurrents_d + p.I_d
    BLAS.gemv!('N', 1.0, p.R_d_T, p.total_d, 0.0, p.output_cache_d)
    BLAS.gemv!('N', 1.0, p.O, z, 0.0, p.outputs)
    @views @. p.loss_cache = p.outputs - x[7N+1:7N+nt] + r3*p.outputs
    BLAS.gemv!('N', 1.0, p.O_T, p.loss_cache, 0.0, p.dl_ds1)
    BLAS.gemv!('N', 1.0, p.R_s_T, view(u, 1:N), 0.0, p.reverse_cache_s)
    BLAS.gemv!('N', 1.0, p.R_d_T, view(u, N+1:2N), 0.0, p.reverse_cache_d)

    @inbounds @simd for k in 1:N
        p.dl_ds1[k] += r1*(p.output_cache_s[k] + p.output_cache_d[k]) + r2*x[2N+k]
    end

    @inbounds @simd for k in 1:N
        vs = x[k]
        vd = x[N+k]
        if vs < 0.45
            du[k] = (-1*u[k]*(2*a_s*(vs - b_s)*(vs - 1) + a_s*(vs - b_s)^2 + c_s - j_s*x[3N+k] - k_s*x[5N+k])/cm_s
                     - u[N+k]*sigma_s_d/cm_d
                     - u[3N+k]*3*(d_s*(vs - e_s)^2)/tau_u_s
                     - u[5N+k]*2*(g_s*(vs - h_s))/tau_w_s)
        else
            du[k] = (u[k]*(j_s*x[3N+k] + k_s*x[5N+k])/cm_s
                     - u[N+k]*sigma_s_d/cm_d
                     - u[3N+k]*3*(d_s*(vs - e_s)^2)/tau_u_s
                     - u[5N+k]*2*(g_s*(vs - h_s))/tau_w_s)
        end
        du[N+k] = (-1*u[k]*sigma_d_s/cm_s
                   - u[N+k]*(2*a_d*(vd - b_d)*(vd - 1) + a_d*(vd - b_d)^2 + c_d - j_d*x[4N+k] - k_d*x[6N+k])/cm_d
                   - u[4N+k]*3*(d_d*(vd - e_d)^2)/tau_u_d
                   - u[6N+k]*2*(g_d*(vd - h_d))/tau_w_d)
        du[2N+k] = u[2N+k]/tau_s - p.reverse_cache_s[k]/cm_s - p.reverse_cache_d[k]/cm_d - p.dl_ds1[k]
        du[3N+k] = u[3N+k]/tau_u_s + u[k]*(j_s*(vs - l_s))/cm_s - u[2N+k]*z_max*q1*sech(q1*(x[3N+k] - 0.05))^2/tau_s
        du[4N+k] = u[4N+k]/tau_u_d + u[N+k]*(j_d*(vd - l_d))/cm_d
        du[5N+k] = u[5N+k]/tau_w_s + u[k]*(k_s*(vs - m_s))/cm_s
        du[6N+k] = u[6N+k]/tau_w_d + u[N+k]*(k_d*(vd - m_d))/cm_d
    end
    nothing
end
# In-place Jacobian J = ∂f/∂u of `forward_RHS!` (same state layout; N =
# p.network_size, nt = p.number_of_targets).
#
# Fixes relative to the previous version:
#  * Target rows now include the p.A coupling. forward_RHS! has
#    du[7N+i] = ((A*y)[i] + input[i] - y[i])/tau_s with y = u[7N+1:7N+nt], so
#    ∂du[7N+i]/∂u[7N+k] = (A[i,k] - δ_ik)/tau_s. Previously only the -1/tau_s
#    diagonal was written.
#  * J[N+i, i] divides by cm_d, matching the sigma_s_d*(...)/cm_d term of
#    du[N+i]. Numerically unchanged while cm_s == cm_d.
#  * J is zeroed first, so entries not written below are exactly zero whatever
#    the buffer held before.
#  * The dense recurrent block is filled column by column (Julia arrays are
#    column-major).
function forward_jacobian!(J, u, p, t)
    N = p.network_size
    nt = p.number_of_targets
    fill!(J, zero(eltype(J)))

    @inbounds for i in 1:N
        vs = u[i]
        vd = u[N+i]
        # rows of the first voltage-like block
        J[i, i] = (2*a_s*(vs - b_s)*(vs - 1) +
                   a_s*(vs - b_s)^2 + c_s - j_s*u[3N+i] - k_s*u[5N+i])/cm_s
        J[i, N+i] = sigma_d_s/cm_s
        J[i, 3N+i] = j_s*(l_s - vs)/cm_s
        J[i, 5N+i] = k_s*(m_s - vs)/cm_s
        # rows of the second voltage-like block
        J[N+i, i] = sigma_s_d/cm_d
        J[N+i, N+i] = (2*a_d*(vd - b_d)*(vd - 1) +
                       a_d*(vd - b_d)^2 + c_d - j_d*u[4N+i] - k_d*u[6N+i])/cm_d
        J[N+i, 4N+i] = j_d*(l_d - vd)/cm_d
        J[N+i, 6N+i] = k_d*(m_d - vd)/cm_d
        # tanh-driven block and auxiliary blocks
        J[2N+i, 2N+i] = -1/tau_s
        J[2N+i, 3N+i] = z_max*q1*sech(q1*(u[3N+i] - 0.05))^2/tau_s
        J[3N+i, i] = 3*(d_s*(vs - e_s)^2)/tau_u_s
        J[3N+i, 3N+i] = -1/tau_u_s
        J[4N+i, N+i] = 3*(d_d*(vd - e_d)^2)/tau_u_d
        J[4N+i, 4N+i] = -1/tau_u_d
        J[5N+i, i] = 2*(g_s*(vs - h_s))/tau_w_s
        J[5N+i, 5N+i] = -1/tau_w_s
        J[6N+i, N+i] = 2*(g_d*(vd - h_d))/tau_w_d
        J[6N+i, 6N+i] = -1/tau_w_d
    end

    # Recurrent coupling: ∂(R*u[2N+1:3N])/∂u[2N+j] = R[:, j].
    @inbounds for j in 1:N, i in 1:N
        J[i, 2N+j] = p.R_s[i, j]/cm_s
        J[N+i, 2N+j] = p.R_d[i, j]/cm_d
    end

    # Target block: (A - I)/tau_s.
    @inbounds for k in 1:nt, i in 1:nt
        J[7N+i, 7N+k] = (p.A[i, k] - (i == k))/tau_s
    end
    nothing
end
# In-place Jacobian of `reverse_RHS!` with respect to the reverse state λ = `u`
# (length 7N). The forward state x at time t is interpolated into p.u first.
# The λ blocks carry the entries of −J_f(x)ᵀ, with J_f the forward Jacobian.
#
# Changes relative to the previous version:
#  * J[i, i] follows the same `x[i] < 0.45` gate as reverse_RHS!. For
#    x[i] >= 0.45 that RHS uses u[i]*(j_s*x[3N+i] + k_s*x[5N+i])/cm_s, whose
#    derivative is (j_s*x[3N+i] + k_s*x[5N+i])/cm_s. Previously the ungated
#    expression was always used.
#  * J is zeroed first so unwritten entries are exactly zero.
#  * The dense R_sᵀ / R_dᵀ blocks are filled column by column (column-major).
function reverse_jacobian!(J, u, p, t)
    N = p.network_size
    x = p.u
    p.forward_solution(x, t)   # in-place interpolation of the forward state
    fill!(J, zero(eltype(J)))

    @inbounds for i in 1:N
        vs = x[i]
        vd = x[N+i]
        if vs < 0.45
            J[i, i] = -1*(2*a_s*(vs - b_s)*(vs - 1) +
                          a_s*(vs - b_s)^2 + c_s - j_s*x[3N+i] - k_s*x[5N+i])/cm_s
        else
            J[i, i] = (j_s*x[3N+i] + k_s*x[5N+i])/cm_s
        end
        J[i, N+i] = -1*sigma_s_d/cm_d
        J[i, 3N+i] = -3*(d_s*(vs - e_s)^2)/tau_u_s
        J[i, 5N+i] = -2*(g_s*(vs - h_s))/tau_w_s
        J[N+i, i] = -1*sigma_d_s/cm_s
        J[N+i, N+i] = -1*(2*a_d*(vd - b_d)*(vd - 1) +
                          a_d*(vd - b_d)^2 + c_d - j_d*x[4N+i] - k_d*x[6N+i])/cm_d
        J[N+i, 4N+i] = -3*(d_d*(vd - e_d)^2)/tau_u_d
        J[N+i, 6N+i] = -2*(g_d*(vd - h_d))/tau_w_d
        J[2N+i, 2N+i] = 1/tau_s
        J[3N+i, i] = j_s*(vs - l_s)/cm_s
        J[3N+i, 2N+i] = -1*z_max*q1*sech(q1*(x[3N+i] - 0.05))^2/tau_s
        J[3N+i, 3N+i] = 1/tau_u_s
        J[4N+i, N+i] = j_d*(vd - l_d)/cm_d
        J[4N+i, 4N+i] = 1/tau_u_d
        J[5N+i, i] = k_s*(vs - m_s)/cm_s
        J[5N+i, 5N+i] = 1/tau_w_s
        J[6N+i, N+i] = k_d*(vd - m_d)/cm_d
        J[6N+i, 6N+i] = 1/tau_w_d
    end

    # ∂(−R_sᵀ λ[1:N]/cm_s − R_dᵀ λ[N+1:2N]/cm_d)/∂λ for rows 2N+1:3N.
    @inbounds for j in 1:N, i in 1:N
        J[2N+i, j] = -1*p.R_s_T[i, j]/cm_s
        J[2N+i, N+j] = -1*p.R_d_T[i, j]/cm_d
    end
    nothing
end
# ---- forward problem ---------------------------------------------------------
forward_parameters = ForwardParameters(number_of_inputs = number_of_inputs, network_size = network_size, number_of_targets = number_of_targets,
                                       U_s = U_s, U_d = U_d, O = O,
                                       R_s = R_s, R_d = R_d,
                                       I_s = I_s, I_d = I_d,
                                       frequencies = frequencies, A = A)
forward_ODEFunction = ODEFunction(forward_RHS!, jac = forward_jacobian!)
forward_ODE = ODEProblem(forward_ODEFunction, zeros(7*network_size + number_of_targets), (0.0, simulation_length), forward_parameters)
# dense = true is required: reverse_RHS! interpolates this solution at arbitrary t.
forward_solution = solve(forward_ODE, AutoVern7(Rodas5()), dense = true, reltol = 1e-10, abstol = 1e-10)

# ---- reverse problem (integrated from simulation_length back to 0) -----------
reverse_parameters = ReverseParameters(number_of_inputs = number_of_inputs, network_size = network_size, number_of_targets = number_of_targets,
                                       U_s = U_s, U_d = U_d, O = O,
                                       R_s = R_s, R_d = R_d,
                                       I_s = I_s, I_d = I_d,
                                       forward_solution = forward_solution,
                                       frequencies = frequencies, A = A)
reverse_ODEFunction = ODEFunction(reverse_RHS!, jac = reverse_jacobian!)
reverse_ODE = ODEProblem(reverse_ODEFunction, zeros(7*network_size), (simulation_length, 0.0), reverse_parameters)
reverse_solution = solve(reverse_ODE, Vern8(), dense = true, reltol = 1e-10, abstol = 1e-10)

# ---- benchmarks ----------------------------------------------------------------
# Interpolate the (untyped, non-const) global problems with `$` so BenchmarkTools
# times the solve itself rather than global-variable lookup and dispatch.
@benchmark solve($forward_ODE, AutoVern7(Rodas5()), dense = true, reltol = 1e-10, abstol = 1e-10)
@benchmark solve($reverse_ODE, Vern8(), dense = true, reltol = 1e-10, abstol = 1e-10)