# Squeezing performance out of ODE solves with interpolations

Hi all!

I’m trying to squeeze performance gains out of two sets of ODEs that I need to solve many times at tight tolerances (~1e-10). I’ve benefited enormously from reading other posts and the performance tips page, and I think I’ve optimized quite a bit. I’m hoping that someone with more seasoned eyes might see further room for improvement.

I think the main thing slowing me down is interpolating the forward solution inside the reverse solve: every evaluation of the reverse right-hand side (and Jacobian) calls the forward solution’s interpolant. Is this bottleneck unavoidable?

Here is my (not so) MWE:

using OrdinaryDiffEq, Parameters, LinearAlgebra, BenchmarkTools, MKL

# Problem sizes and integration horizon.
number_of_inputs = 2
network_size = 64
number_of_targets = 2
simulation_length = 50.0

# Input coupling matrices (network_size × number_of_inputs) and output
# matrix (number_of_targets × network_size), entries uniform in [-0.5, 0.5).
U_s = rand(network_size, number_of_inputs) .- 0.5
U_d = rand(network_size, number_of_inputs) .- 0.5
O = rand(number_of_targets, network_size) .- 0.5

# Recurrent coupling matrices, entries uniform in [-0.25, 0.25).
R_s = 0.5 * (rand(network_size, network_size) .- 0.5)
R_d = 0.5 * (rand(network_size, network_size) .- 0.5)

# Constant drive vectors, entries uniform in [0, 1).
I_s = rand(network_size)
I_d = rand(network_size)

# Sinusoid frequencies (row = input, column = one of the four sines summed in
# the RHS), uniform in [0, π/20).
frequencies = pi * rand(2, 4) / 20

# Linear coupling between the target variables.
A = [-0.7 0.36; -2.3 -0.1]

# Model constants, shared by the forward and reverse systems.
#
# With N = network_size, suffix `_s` marks coefficients of the equations for
# v = u[i] together with w = u[3N+i] and z = u[5N+i]; suffix `_d` marks those for
# v = u[N+i] with w = u[4N+i] and z = u[6N+i]. In forward_RHS! they enter as
#   dv = ((a*(v - b)^2 + c)*(v - 1) - j*w*(v - l) - k*z*(v - m) + drive)/cm
#   dw = (d*(v - e)^3 + f - w)/tau_u
#   dz = (g*(v - h)^2 + i - z)/tau_w
const a_s, b_s, c_s, d_s, e_s, f_s, g_s = -1000.0, 0.3, 0.04, 5.0, 0.3, 0.0, 2.0
const h_s, i_s, j_s, k_s, l_s, m_s = 0.3, 0.02, 2500.0, 100.0, 0.3, 0.0

const a_d, b_d, c_d, d_d, e_d, f_d, g_d = -200.0, 0.3, 0.08, 1.0, 0.3, 0.0, 2.0
const h_d, i_d, j_d, k_d, l_d, m_d = 0.3, 0.08, 200.0, 130.0, 0.25, 0.26

# Divisors `cm` of du[1:N] and du[N+1:2N].
const cm_s, cm_d = 25.0, 25.0

# Time constants of the cubic (tau_u) and quadratic (tau_w) relaxation blocks.
const tau_u_s, tau_w_s = 2.0, 5.0
const tau_u_d, tau_w_d = 0.6, 30.0

# Cross-coupling of the first two blocks: sigma_d_s multiplies (u[N+i] - mu_d)
# in du[i]; sigma_s_d multiplies (u[i] - mu_s) in du[N+i].
const sigma_s_d, sigma_d_s = 7.35, 8.505
const mu_s, mu_d = 0.2707, 0.2679

# Gain (q1) and amplitude (z_max) of the tanh drive of u[2N+1:3N].
const q1 = 100.0
const z_max = 5.0

# Time constant of u[2N+1:3N] and of the target variables.
const tau_s = 10.0

# Weights of the additional terms assembled in reverse_RHS!
# (r1: Rᵀ-mapped drive, r2: u[2N+1:3N] term, r3: extra output term in the loss).
const r1, r2, r3 = 0.01, 0.01, 0.0

# Parameters and preallocated scratch buffers for `forward_RHS!` and
# `forward_jacobian!`, built with the keyword constructor generated by
# Parameters.@with_kw. Buffer fields default to zero vectors sized from the
# dimension fields declared first.
@with_kw mutable struct ForwardParameters
    # Dimensions.
    number_of_inputs::Int
    network_size::Int
    number_of_targets::Int

    # Input coupling matrices and output matrix.
    U_s::Matrix{Float64}
    U_d::Matrix{Float64}
    O::Matrix{Float64}

    # Recurrent coupling matrices.
    R_s::Matrix{Float64}
    R_d::Matrix{Float64}

    # Constant drives.
    I_s::Vector{Float64}
    I_d::Vector{Float64}

    # Input sinusoid frequencies and target coupling matrix.
    frequencies::Matrix{Float64} = zeros(2, 4)
    A::Matrix{Float64} = zeros(2, 2)

    # Scratch buffers, overwritten on every forward_RHS! call.
    input_cache::Vector{Float64} = zeros(number_of_inputs)
    inputs_s::Vector{Float64} = zeros(network_size)
    recurrents_s::Vector{Float64} = zeros(network_size)
    inputs_d::Vector{Float64} = zeros(network_size)
    recurrents_d::Vector{Float64} = zeros(network_size)
    target_cache::Vector{Float64} = zeros(number_of_targets)
end

# Parameters, the stored forward solution, and preallocated scratch buffers for
# `reverse_RHS!` and `reverse_jacobian!`.
#
# The forward solution's type is the struct parameter `S`. The previous field
# annotation `OrdinaryDiffEq.ODECompositeSolution` names a non-concrete
# (UnionAll) type, so every `p.forward_solution(p.u, t)` call in the RHS and
# Jacobian was dispatched at run time. With `S` the call is inferred and
# compiled statically. The keyword constructor infers `S` from the
# `forward_solution` argument, so existing keyword calls are unchanged. Any
# solution object supporting in-place interpolation `sol(out, t)` is accepted.
#
# NOTE(review): for parametric types the positional constructor is typed, so the
# non-`S` fields must be passed as exactly the declared types. That is why the
# transposed copies use `permutedims` (which returns a `Matrix`) rather than the
# lazy `transpose`.
@with_kw mutable struct ReverseParameters{S}
    # Dimensions.
    number_of_inputs::Int
    network_size::Int
    number_of_targets::Int

    # Input coupling matrices and output matrix.
    U_s::Array{Float64,2}
    U_d::Array{Float64,2}
    O::Array{Float64,2}

    # Recurrent coupling matrices.
    R_s::Array{Float64,2}
    R_d::Array{Float64,2}

    # Constant drives.
    I_s::Vector{Float64}
    I_d::Vector{Float64}

    # Dense forward solution, interpolated at each reverse-time evaluation.
    forward_solution::S

    # Receives the interpolated forward state (7N + T entries).
    u::Vector{Float64} = zeros(7*network_size+number_of_targets)

    frequencies::Array{Float64,2} = zeros(2, 4)
    A::Array{Float64,2} = zeros(2, 2)
    input_cache::Vector{Float64} = zeros(number_of_inputs)

    inputs_s::Vector{Float64} = zeros(network_size)
    recurrents_s::Vector{Float64} = zeros(network_size)
    total_s::Vector{Float64} = zeros(network_size)

    inputs_d::Vector{Float64} = zeros(network_size)
    recurrents_d::Vector{Float64} = zeros(network_size)
    total_d::Vector{Float64} = zeros(network_size)

    outputs::Vector{Float64} = zeros(number_of_targets)

    output_cache_s::Vector{Float64} = zeros(network_size)
    output_cache_d::Vector{Float64} = zeros(network_size)

    loss_cache::Vector{Float64} = zeros(number_of_targets)
    dl_ds1::Vector{Float64} = zeros(network_size)

    # Explicit transposed copies, used as the 'N' operand of BLAS.gemv!.
    O_T::Array{Float64,2} = permutedims(O)
    R_s_T::Array{Float64,2} = permutedims(R_s)
    R_d_T::Array{Float64,2} = permutedims(R_d)

    reverse_cache_s::Vector{Float64} = zeros(network_size)
    reverse_cache_d::Vector{Float64} = zeros(network_size)
end

# functions

# In-place right-hand side du = f(u, p, t) of the forward system, in the
# `f!(du, u, p, t)` form passed to ODEFunction below.
#
# With N = p.network_size and T = p.number_of_targets, `u` and `du` have 7N + T
# entries: seven length-N blocks followed by T target entries.
#   du[1:N]        cubic *_s equation, driven by U_s*input, R_s*u[2N+1:3N], I_s
#                  and sigma_d_s*(u[N+i] - mu_d)
#   du[N+1:2N]     cubic *_d equation, driven by U_d*input, R_d*u[2N+1:3N], I_d
#                  and sigma_s_d*(u[i] - mu_s)
#   du[2N+1:3N]    (z_max*(tanh(q1*(u[3N+i] - 0.05)) + 1) - u[2N+i])/tau_s
#   du[3N+1:4N]    (d_s*(u[i] - e_s)^3 + f_s - u[3N+i])/tau_u_s
#   du[4N+1:5N]    (d_d*(u[N+i] - e_d)^3 + f_d - u[4N+i])/tau_u_d
#   du[5N+1:6N]    (g_s*(u[i] - h_s)^2 + i_s - u[5N+i])/tau_w_s
#   du[6N+1:7N]    (g_d*(u[N+i] - h_d)^2 + i_d - u[6N+i])/tau_w_d
#   du[7N+1:7N+T]  (A*u[7N+1:7N+T] + input - u[7N+1:7N+T])/tau_s
# Every entry of du is overwritten. The scratch vectors input_cache, inputs_s,
# recurrents_s, inputs_d, recurrents_d and target_cache in `p` are overwritten
# as a side effect. Returns `nothing`.
function forward_RHS!(du, u, p, t)
# Input signal i at time t: four sinusoids with alternating signs.
@inbounds @simd for i in 1:p.number_of_inputs
p.input_cache[i] = sin(p.frequencies[i, 1]*t) - sin(p.frequencies[i, 2]*t) + sin(p.frequencies[i, 3]*t) - sin(p.frequencies[i, 4]*t)
end

# inputs_s = U_s*input, recurrents_s = R_s*u[2N+1:3N] (in-place BLAS; @views avoids copying the slice).
BLAS.gemv!('N', 1.0, p.U_s, p.input_cache, 0.0, p.inputs_s)
@views BLAS.gemv!('N', 1.0, p.R_s, u[2*p.network_size+1:3*p.network_size], 0.0, p.recurrents_s)

# Same for the *_d block.
BLAS.gemv!('N', 1.0, p.U_d, p.input_cache, 0.0, p.inputs_d)
@views BLAS.gemv!('N', 1.0, p.R_d, u[2*p.network_size+1:3*p.network_size], 0.0, p.recurrents_d)

# target_cache = A*u[7N+1:7N+T].
@views BLAS.gemv!('N', 1.0, p.A, u[7*p.network_size+1:7*p.network_size+p.number_of_targets], 0.0, p.target_cache)

@inbounds @simd for i in 1:p.network_size
du[i] = ((a_s*(u[i] - b_s)^2 + c_s)*(u[i] - 1.0)
- j_s*u[3*p.network_size+i]*(u[i] - l_s)
- k_s*u[5*p.network_size+i]*(u[i] - m_s)
+ p.inputs_s[i] + p.recurrents_s[i] + p.I_s[i] + sigma_d_s*(u[p.network_size+i] - mu_d))/cm_s

du[p.network_size+i] = ((a_d*(u[p.network_size+i] - b_d)^2 + c_d)*(u[p.network_size+i] - 1.0)
- j_d*u[4*p.network_size+i]*(u[p.network_size+i] - l_d)
- k_d*u[6*p.network_size+i]*(u[p.network_size+i] - m_d)
+ p.inputs_d[i] + p.recurrents_d[i] + p.I_d[i] + sigma_s_d*(u[i] - mu_s))/cm_d

du[2*p.network_size+i] = (z_max*(tanh(q1*(u[3*p.network_size+i] - 0.05)) + 1.0) - u[2*p.network_size+i])/tau_s

du[3*p.network_size+i] = (d_s*(u[i] - e_s)^3 + f_s - u[3*p.network_size+i])/tau_u_s

du[4*p.network_size+i] = (d_d*(u[p.network_size+i] - e_d)^3 + f_d - u[4*p.network_size+i])/tau_u_d

du[5*p.network_size+i] = (g_s*(u[i] - h_s)^2 + i_s - u[5*p.network_size+i])/tau_w_s

du[6*p.network_size+i] = (g_d*(u[p.network_size+i] - h_d)^2 + i_d - u[6*p.network_size+i])/tau_w_d
end

# NOTE(review): input_cache is indexed with the *target* index i, so this assumes
# number_of_targets <= number_of_inputs; @inbounds hides any violation.
@inbounds @simd for i in 1:p.number_of_targets
du[7*p.network_size+i] = (p.target_cache[i] + p.input_cache[i] - u[7*p.network_size+i])/tau_s
end

nothing
end

# In-place right-hand side of the reverse (adjoint) system. The script below
# integrates it backwards from t = simulation_length to t = 0.
#
# `u`/`du` hold the 7N adjoint entries (N = p.network_size). On every call the
# forward state at time `t` is interpolated in place into `p.u`; all
# coefficients below are evaluated at that forward state.
#
# The part linear in `u` is -Jᵀu, where J is the forward Jacobian. The one
# exception is the ad-hoc branch on `p.u[i] < 0.45` for du[i], a deliberate
# stabilisation that drops the cubic-term derivative. du[2N+1:3N] also carries
# the source term dl_ds1: the output error mapped back through Oᵀ, plus r1- and
# r2-weighted terms.
#
# Fix: the cross-coupling terms between the first two blocks were transposed.
# The forward equations contain sigma_d_s*(u[N+i] - mu_d)/cm_s in du[i] and
# sigma_s_d*(u[i] - mu_s)/cm_d in du[N+i]. The transposed coupling is therefore:
#   * sigma_s_d/cm_d in du[i], multiplying u[N+i];
#   * sigma_d_s/cm_s in du[N+i], multiplying u[i].
# This matches `reverse_jacobian!`.
#
# Overwrites every entry of `du` and the scratch buffers in `p`; returns nothing.
function reverse_RHS!(du, u, p, t)
    # Forward state at time t, written in place into p.u.
    p.forward_solution(p.u, t)

    # Input signals at time t (same construction as in forward_RHS!).
    @inbounds @simd for i in 1:p.number_of_inputs
        p.input_cache[i] = sin(p.frequencies[i, 1]*t) - sin(p.frequencies[i, 2]*t) + sin(p.frequencies[i, 3]*t) - sin(p.frequencies[i, 4]*t)
    end

    # output_cache_s = R_sᵀ*(U_s*input + R_s*p.u[2N+1:3N] + I_s); likewise for *_d.
    BLAS.gemv!('N', 1.0, p.U_s, p.input_cache, 0.0, p.inputs_s)
    @views BLAS.gemv!('N', 1.0, p.R_s, p.u[2*p.network_size+1:3*p.network_size], 0.0, p.recurrents_s)
    @. p.total_s = p.inputs_s + p.recurrents_s + p.I_s
    BLAS.gemv!('N', 1.0, p.R_s_T, p.total_s, 0.0, p.output_cache_s)

    BLAS.gemv!('N', 1.0, p.U_d, p.input_cache, 0.0, p.inputs_d)
    @views BLAS.gemv!('N', 1.0, p.R_d, p.u[2*p.network_size+1:3*p.network_size], 0.0, p.recurrents_d)
    @. p.total_d = p.inputs_d + p.recurrents_d + p.I_d
    BLAS.gemv!('N', 1.0, p.R_d_T, p.total_d, 0.0, p.output_cache_d)

    # dl_ds1 = Oᵀ*(O*p.u[2N+1:3N] - p.u[7N+1:7N+T] + r3*O*p.u[2N+1:3N]).
    @views BLAS.gemv!('N', 1.0, p.O, p.u[2*p.network_size+1:3*p.network_size], 0.0, p.outputs)

    @views @. p.loss_cache = p.outputs - p.u[7*p.network_size+1:7*p.network_size+p.number_of_targets] + r3*p.outputs
    BLAS.gemv!('N', 1.0, p.O_T, p.loss_cache, 0.0, p.dl_ds1)

    # Transposed recurrent couplings applied to the adjoint state.
    @views BLAS.gemv!('N', 1.0, p.R_s_T, u[1:p.network_size], 0.0, p.reverse_cache_s)
    @views BLAS.gemv!('N', 1.0, p.R_d_T, u[p.network_size+1:2*p.network_size], 0.0, p.reverse_cache_d)

    @inbounds @simd for i in 1:p.network_size
        p.dl_ds1[i] += r1*(p.output_cache_s[i] + p.output_cache_d[i]) + r2*p.u[2*p.network_size+i]
    end

    @inbounds @simd for i in 1:p.network_size
        # Ad-hoc stabilisation: for p.u[i] >= 0.45 the derivative of the cubic term
        # (a_s*(v - b_s)^2 + c_s)*(v - 1) is dropped from the u[i] coefficient.
        if p.u[i] < 0.45
            du[i] = (-1*u[i]*(2*a_s*(p.u[i] - b_s)*(p.u[i] - 1) + a_s*(p.u[i] - b_s)^2 + c_s - j_s*p.u[3*p.network_size+i] - k_s*p.u[5*p.network_size+i])/cm_s
                - u[p.network_size+i]*sigma_s_d/cm_d
                - u[3*p.network_size+i]*3*(d_s*(p.u[i] - e_s)^2)/tau_u_s
                - u[5*p.network_size+i]*2*(g_s*(p.u[i] - h_s))/tau_w_s)
        else
            du[i] = (u[i]*(j_s*p.u[3*p.network_size+i] + k_s*p.u[5*p.network_size+i])/cm_s
                - u[p.network_size+i]*sigma_s_d/cm_d
                - u[3*p.network_size+i]*3*(d_s*(p.u[i] - e_s)^2)/tau_u_s
                - u[5*p.network_size+i]*2*(g_s*(p.u[i] - h_s))/tau_w_s)
        end

        du[p.network_size+i] = (-1*u[i]*sigma_d_s/cm_s
            - u[p.network_size+i]*(2*a_d*(p.u[p.network_size+i] - b_d)*(p.u[p.network_size+i] - 1) + a_d*(p.u[p.network_size+i] - b_d)^2 + c_d - j_d*p.u[4*p.network_size+i] - k_d*p.u[6*p.network_size+i])/cm_d
            - u[4*p.network_size+i]*3*(d_d*(p.u[p.network_size+i] - e_d)^2)/tau_u_d
            - u[6*p.network_size+i]*2*(g_d*(p.u[p.network_size+i] - h_d))/tau_w_d)

        du[2*p.network_size+i] = u[2*p.network_size+i]/tau_s - p.reverse_cache_s[i]/cm_s - p.reverse_cache_d[i]/cm_d - p.dl_ds1[i]

        du[3*p.network_size+i] = u[3*p.network_size+i]/tau_u_s + u[i]*(j_s*(p.u[i] - l_s))/cm_s - u[2*p.network_size+i]*z_max*q1*sech(q1*(p.u[3*p.network_size+i] - 0.05))^2/tau_s

        du[4*p.network_size+i] = u[4*p.network_size+i]/tau_u_d + u[p.network_size+i]*(j_d*(p.u[p.network_size+i] - l_d))/cm_d

        du[5*p.network_size+i] = u[5*p.network_size+i]/tau_w_s + u[i]*(k_s*(p.u[i] - m_s))/cm_s

        du[6*p.network_size+i] = u[6*p.network_size+i]/tau_w_d + u[p.network_size+i]*(k_d*(p.u[p.network_size+i] - m_d))/cm_d
    end

    return nothing
end

# Analytic Jacobian J = ∂f/∂u of `forward_RHS!`, written in place.
#
# With N = p.network_size and T = p.number_of_targets, J is (7N + T) × (7N + T).
# Only structurally nonzero entries are assigned. Every other entry of J is left
# untouched and is expected to be zero already (a zero-initialised dense matrix
# or a sparse `jac_prototype`) — NOTE(review): confirm for the dense case.
#
# Corrections relative to the earlier version:
# * ∂du[N+i]/∂u[i] = sigma_s_d/cm_d, because du[N+i] is divided by cm_d. This is
#   numerically identical here (cm_s == cm_d) but wrong if they ever differ.
# * du[7N+i] = (Σ_j A[i,j]*u[7N+j] + input_i - u[7N+i])/tau_s, so the target
#   block of J is (A - I)/tau_s. Previously only the -I/tau_s diagonal was set.
#   That gives Rosenbrock methods such as Rodas5 an inexact Jacobian.
function forward_jacobian!(J, u, p, t)
    @inbounds for i in 1:p.network_size
        # Rows 1:N.
        J[i, i] = (2*a_s*(u[i] - b_s)*(u[i] - 1) +
            a_s*(u[i] - b_s)^2 + c_s - j_s*u[3*p.network_size+i] - k_s*u[5*p.network_size+i])/cm_s
        J[i, p.network_size+i] = sigma_d_s/cm_s
        J[i, 3*p.network_size+i] = j_s*(l_s - u[i])/cm_s
        J[i, 5*p.network_size+i] = k_s*(m_s - u[i])/cm_s

        # Rows N+1:2N.
        J[p.network_size+i, i] = sigma_s_d/cm_d
        J[p.network_size+i, p.network_size+i] = (2*a_d*(u[p.network_size+i] - b_d)*(u[p.network_size+i] - 1) +
            a_d*(u[p.network_size+i] - b_d)^2 + c_d - j_d*u[4*p.network_size+i] - k_d*u[6*p.network_size+i])/cm_d
        J[p.network_size+i, 4*p.network_size+i] = j_d*(l_d - u[p.network_size+i])/cm_d
        J[p.network_size+i, 6*p.network_size+i] = k_d*(m_d - u[p.network_size+i])/cm_d

        # Rows 2N+1:3N (tanh-driven block).
        J[2*p.network_size+i, 2*p.network_size+i] = -1/tau_s
        J[2*p.network_size+i, 3*p.network_size+i] = z_max*q1*sech(q1*(u[3*p.network_size+i] - 0.05))^2/tau_s

        # Rows 3N+1:7N (relaxation blocks).
        J[3*p.network_size+i, i] = 3*(d_s*(u[i] - e_s)^2)/tau_u_s
        J[3*p.network_size+i, 3*p.network_size+i] = -1/tau_u_s

        J[4*p.network_size+i, p.network_size+i] = 3*(d_d*(u[p.network_size+i] - e_d)^2)/tau_u_d
        J[4*p.network_size+i, 4*p.network_size+i] = -1/tau_u_d

        J[5*p.network_size+i, i] = 2*(g_s*(u[i] - h_s))/tau_w_s
        J[5*p.network_size+i, 5*p.network_size+i] = -1/tau_w_s

        J[6*p.network_size+i, p.network_size+i] = 2*(g_d*(u[p.network_size+i] - h_d))/tau_w_d
        J[6*p.network_size+i, 6*p.network_size+i] = -1/tau_w_d

        # Recurrent coupling of rows 1:2N to u[2N+1:3N].
        @inbounds @simd for j in 1:p.network_size
            J[i, 2*p.network_size+j] = p.R_s[i, j]/cm_s
            J[p.network_size+i, 2*p.network_size+j] = p.R_d[i, j]/cm_d
        end
    end

    # Target block: ∂du[7N+i]/∂u[7N+j] = (A[i, j] - δ_ij)/tau_s.
    @inbounds for j in 1:p.number_of_targets, i in 1:p.number_of_targets
        J[7*p.network_size+i, 7*p.network_size+j] = (p.A[i, j] - (i == j))/tau_s
    end

    return nothing
end

# Analytic Jacobian of the reverse system with respect to the adjoint state,
# written in place.
#
# It reads only the forward state, which is interpolated into `p.u` first; the
# adjoint argument `u` is not used. Each entry is minus the transposed forward
# Jacobian entry, e.g. J[i, N+i] = -(∂f[N+i]/∂u[i]) = -sigma_s_d/cm_d
# (N = p.network_size). Only structurally nonzero entries are assigned.
#
# NOTE(review): reverse_RHS! uses a different du[i] expression when
# p.u[i] >= 0.45. J[i, i] here always uses the p.u[i] < 0.45 form, so the two
# disagree in that regime. This presumably only matters if an implicit or
# Rosenbrock solver is used for the reverse solve; Vern8 is explicit.
function reverse_jacobian!(J, u, p, t)
# Forward state at time t, written in place into p.u.
p.forward_solution(p.u, t)

@inbounds for i in 1:p.network_size
# Rows 1:N: minus column i of the forward Jacobian.
J[i, i] = -1*(2*a_s*(p.u[i] - b_s)*(p.u[i] - 1) +
a_s*(p.u[i] - b_s)^2 + c_s - j_s*p.u[3*p.network_size+i] - k_s*p.u[5*p.network_size+i])/cm_s
J[i, p.network_size+i] = -1*sigma_s_d/cm_d
J[i, 3*p.network_size+i] = -3*(d_s*(p.u[i] - e_s)^2)/tau_u_s
J[i, 5*p.network_size+i] = -2*(g_s*(p.u[i] - h_s))/tau_w_s

# Rows N+1:2N: minus column N+i of the forward Jacobian.
J[p.network_size+i, i] = -1*sigma_d_s/cm_s
J[p.network_size+i, p.network_size+i] = -1*(2*a_d*(p.u[p.network_size+i] - b_d)*(p.u[p.network_size+i] - 1) +
a_d*(p.u[p.network_size+i] - b_d)^2 + c_d - j_d*p.u[4*p.network_size+i] - k_d*p.u[6*p.network_size+i])/cm_d
J[p.network_size+i, 4*p.network_size+i] = -3*(d_d*(p.u[p.network_size+i] - e_d)^2)/tau_u_d
J[p.network_size+i, 6*p.network_size+i] = -2*(g_d*(p.u[p.network_size+i] - h_d))/tau_w_d

J[2*p.network_size+i, 2*p.network_size+i] = 1/tau_s

J[3*p.network_size+i, i] = j_s*(p.u[i] - l_s)/cm_s
J[3*p.network_size+i, 2*p.network_size+i] = -1*z_max*q1*sech(q1*(p.u[3*p.network_size+i] - 0.05))^2/tau_s
J[3*p.network_size+i, 3*p.network_size+i] = 1/tau_u_s

J[4*p.network_size+i, p.network_size+i] = j_d*(p.u[p.network_size+i] - l_d)/cm_d
J[4*p.network_size+i, 4*p.network_size+i] = 1/tau_u_d

J[5*p.network_size+i, i] = k_s*(p.u[i] - m_s)/cm_s
J[5*p.network_size+i, 5*p.network_size+i] = 1/tau_w_s

J[6*p.network_size+i, p.network_size+i] = k_d*(p.u[p.network_size+i] - m_d)/cm_d
J[6*p.network_size+i, 6*p.network_size+i] = 1/tau_w_d

# Row 2N+i couples to the first two adjoint blocks through the transposed
# recurrent matrices (the reverse_cache_s/_d terms of reverse_RHS!).
@inbounds for j in 1:p.network_size
J[2*p.network_size+i, j] = -1*p.R_s_T[i, j]/cm_s
J[2*p.network_size+i, p.network_size+j] = -1*p.R_d_T[i, j]/cm_d
end
end

nothing
end

# Build and solve the forward problem, then the reverse (adjoint) problem that
# interpolates the stored forward solution, and benchmark both solves.
forward_parameters = ForwardParameters(number_of_inputs = number_of_inputs, network_size = network_size, number_of_targets = number_of_targets,
    U_s = U_s, U_d = U_d, O = O,
    R_s = R_s, R_d = R_d,
    I_s = I_s, I_d = I_d,
    frequencies = frequencies, A = A)

forward_ODEFunction = ODEFunction(forward_RHS!, jac = forward_jacobian!)
forward_ODE = ODEProblem(forward_ODEFunction, zeros(7*network_size + number_of_targets), (0.0, simulation_length), forward_parameters)
forward_solution = solve(forward_ODE, AutoVern7(Rodas5()), dense = true, reltol = 1e-10, abstol = 1e-10)

reverse_parameters = ReverseParameters(number_of_inputs = number_of_inputs, network_size = network_size, number_of_targets = number_of_targets,
    U_s = U_s, U_d = U_d, O = O,
    R_s = R_s, R_d = R_d,
    I_s = I_s, I_d = I_d,
    forward_solution = forward_solution,
    frequencies = frequencies, A = A)

# The reverse problem runs backwards in time: tspan = (simulation_length, 0.0).
reverse_ODEFunction = ODEFunction(reverse_RHS!, jac = reverse_jacobian!)
reverse_ODE = ODEProblem(reverse_ODEFunction, zeros(7*network_size), (simulation_length, 0.0), reverse_parameters)
reverse_solution = solve(reverse_ODE, Vern8(), dense = true, reltol = 1e-10, abstol = 1e-10)

# `$` interpolates the (non-const) global problems into the benchmark. That way
# BenchmarkTools measures the solve itself, not untyped global lookup and the
# dynamic dispatch it causes.
@benchmark solve($forward_ODE, AutoVern7(Rodas5()), dense = true, reltol = 1e-10, abstol = 1e-10)
@benchmark solve($reverse_ODE, Vern8(), dense = true, reltol = 1e-10, abstol = 1e-10)

One thing I think you should look into is replacing the separate constants `a_s` … `m_s` with a single tuple, e.g. `const s = (-1000.0, 0.3, 0.04, 5.0, 0.3, 0.0, 2.0, ...)` (and likewise for `a_d` … `m_d`). This won’t help performance on its own, but more organized code makes it much easier to see where you can take advantage of structure in your problem.

There’s a lot that can be improved here. A big part of it is `searchsortedlast` performance, so someone could roll up their sleeves, as in the “Searchsortedlast performance” thread, and really knock that down. We’ve already done a little bit of that, but the conclusion in that thread is that you don’t always want to do bisection, so avoiding it where possible is one major improvement. Other things, like caching the last index, could also be done. It would be really good for someone to dig into this and see whether they can speed it up, because larger interpolations do indeed take a while.

Is the reverse here an adjoint equation?

2 Likes

It is indeed! My group is interested in the adjoint dynamics of this system and they’re simple enough that I’m doing it manually. But they’re also numerically unstable (which is why I’ve done a super messy hack with the if statement).

To reduce the number of `f` calls, have you specified the Jacobian sparsity pattern (e.g. via `jac_prototype`)?

Ah, I didn’t realize that was necessary when declaring the entire Jacobian as a function. It looks like doing so maybe shaved 5% off my time.

Re: your earlier post, just so I understand you, are you recommending that I write a custom interpolation function that takes advantage of methods specific for the size of my interpolant?

Just for completeness, this is an updated MWE with some typos fixed and adjoint equation a little more stabilized:

using OrdinaryDiffEq, Parameters, LinearAlgebra, Sundials, LSODA, BenchmarkTools, MKL, SparseArrays

# Problem sizes and integration horizon.
number_of_inputs = 2
network_size = 64
number_of_targets = 2
simulation_length = 500.0

# Input coupling matrices (network_size × number_of_inputs) and output
# matrix (number_of_targets × network_size), entries uniform in [-0.5, 0.5).
U_s = rand(network_size, number_of_inputs) .- 0.5
U_d = rand(network_size, number_of_inputs) .- 0.5
O = rand(number_of_targets, network_size) .- 0.5

# Recurrent coupling matrices, entries uniform in [-0.25, 0.25).
R_s = 0.5 * (rand(network_size, network_size) .- 0.5)
R_d = 0.5 * (rand(network_size, network_size) .- 0.5)

# Constant drive vectors, entries uniform in [0, 1).
I_s = rand(network_size)
I_d = rand(network_size)

# Sinusoid frequencies (row = input, column = one of the four sines summed in
# the RHS), uniform in [0, π/20).
frequencies = pi * rand(2, 4) / 20

# Linear coupling between the target variables.
A = [-0.7 0.36; -2.3 -0.1]

# Model constants, shared by the forward and reverse systems.
#
# With N = network_size, suffix `_s` marks coefficients of the equations for
# v = u[i] together with w = u[3N+i] and z = u[5N+i]; suffix `_d` marks those for
# v = u[N+i] with w = u[4N+i] and z = u[6N+i]. In forward_RHS! they enter as
#   dv = ((a*(v - b)^2 + c)*(v - 1) - j*w*(v - l) - k*z*(v - m) + drive)/cm
#   dw = (d*(v - e)^3 + f - w)/tau_u
#   dz = (g*(v - h)^2 + i - z)/tau_w
const a_s, b_s, c_s, d_s, e_s, f_s, g_s = -1000.0, 0.3, 0.04, 5.0, 0.3, 0.0, 2.0
const h_s, i_s, j_s, k_s, l_s, m_s = 0.3, 0.02, 2500.0, 100.0, 0.3, 0.0

const a_d, b_d, c_d, d_d, e_d, f_d, g_d = -200.0, 0.3, 0.08, 1.0, 0.3, 0.0, 2.0
const h_d, i_d, j_d, k_d, l_d, m_d = 0.3, 0.08, 200.0, 130.0, 0.25, 0.26

# Divisors `cm` of du[1:N] and du[N+1:2N].
const cm_s, cm_d = 25.0, 25.0

# Time constants of the cubic (tau_u) and quadratic (tau_w) relaxation blocks.
const tau_u_s, tau_w_s = 2.0, 5.0
const tau_u_d, tau_w_d = 0.6, 30.0

# Cross-coupling of the first two blocks: sigma_d_s multiplies (u[N+i] - mu_d)
# in du[i]; sigma_s_d multiplies (u[i] - mu_s) in du[N+i].
const sigma_s_d, sigma_d_s = 7.35, 8.505
const mu_s, mu_d = 0.2707, 0.2679

# Gain (q1) and amplitude (z_max) of the tanh drive of u[2N+1:3N].
const q1 = 50.0
const z_max = 50.0

# Time constant of u[2N+1:3N] and of the target variables.
const tau_s = 10.0

# Weights of the additional terms assembled in reverse_RHS!
# (r1: Rᵀ-mapped drive, r2: u[2N+1:3N] term, r3: extra output term in the loss).
const r1, r2, r3 = 0.01, 0.01, 0.0

# Parameters and preallocated scratch buffers for `forward_RHS!` and
# `forward_jacobian!`, built with the keyword constructor generated by
# Parameters.@with_kw. Buffer fields default to zero vectors sized from the
# dimension fields declared first.
@with_kw mutable struct ForwardParameters
    # Dimensions.
    number_of_inputs::Int
    network_size::Int
    number_of_targets::Int

    # Input coupling matrices and output matrix.
    U_s::Matrix{Float64}
    U_d::Matrix{Float64}
    O::Matrix{Float64}

    # Recurrent coupling matrices.
    R_s::Matrix{Float64}
    R_d::Matrix{Float64}

    # Constant drives.
    I_s::Vector{Float64}
    I_d::Vector{Float64}

    # Input sinusoid frequencies and target coupling matrix.
    frequencies::Matrix{Float64} = zeros(2, 4)
    A::Matrix{Float64} = zeros(2, 2)

    # Scratch buffers, overwritten on every forward_RHS! call.
    input_cache::Vector{Float64} = zeros(number_of_inputs)
    inputs_s::Vector{Float64} = zeros(network_size)
    recurrents_s::Vector{Float64} = zeros(network_size)
    inputs_d::Vector{Float64} = zeros(network_size)
    recurrents_d::Vector{Float64} = zeros(network_size)
    target_cache::Vector{Float64} = zeros(number_of_targets)
end

# Parameters, the stored forward solution, and preallocated scratch buffers for
# `reverse_RHS!` and `reverse_jacobian!`.
#
# The forward solution's type is the struct parameter `S`. The previous field
# annotation `OrdinaryDiffEq.ODECompositeSolution` names a non-concrete
# (UnionAll) type, so every `p.forward_solution(p.u, t)` call in the RHS and
# Jacobian was dispatched at run time. With `S` the call is inferred and
# compiled statically. The keyword constructor infers `S` from the
# `forward_solution` argument, so existing keyword calls are unchanged. Any
# solution object supporting in-place interpolation `sol(out, t)` is accepted.
#
# NOTE(review): for parametric types the positional constructor is typed, so the
# non-`S` fields must be passed as exactly the declared types. That is why the
# transposed copies use `permutedims` (which returns a `Matrix`) rather than the
# lazy `transpose`.
@with_kw mutable struct ReverseParameters{S}
    # Dimensions.
    number_of_inputs::Int
    network_size::Int
    number_of_targets::Int

    # Input coupling matrices and output matrix.
    U_s::Array{Float64,2}
    U_d::Array{Float64,2}
    O::Array{Float64,2}

    # Recurrent coupling matrices.
    R_s::Array{Float64,2}
    R_d::Array{Float64,2}

    # Constant drives.
    I_s::Vector{Float64}
    I_d::Vector{Float64}

    # Dense forward solution, interpolated at each reverse-time evaluation.
    forward_solution::S

    # Receives the interpolated forward state (7N + T entries).
    u::Vector{Float64} = zeros(7*network_size+number_of_targets)

    frequencies::Array{Float64,2} = zeros(2, 4)
    A::Array{Float64,2} = zeros(2, 2)
    input_cache::Vector{Float64} = zeros(number_of_inputs)

    inputs_s::Vector{Float64} = zeros(network_size)
    recurrents_s::Vector{Float64} = zeros(network_size)
    total_s::Vector{Float64} = zeros(network_size)

    inputs_d::Vector{Float64} = zeros(network_size)
    recurrents_d::Vector{Float64} = zeros(network_size)
    total_d::Vector{Float64} = zeros(network_size)

    outputs::Vector{Float64} = zeros(number_of_targets)

    output_cache_s::Vector{Float64} = zeros(network_size)
    output_cache_d::Vector{Float64} = zeros(network_size)

    loss_cache::Vector{Float64} = zeros(number_of_targets)
    dl_ds1::Vector{Float64} = zeros(network_size)

    # Explicit transposed copies, used as the 'N' operand of BLAS.gemv!.
    O_T::Array{Float64,2} = permutedims(O)
    R_s_T::Array{Float64,2} = permutedims(R_s)
    R_d_T::Array{Float64,2} = permutedims(R_d)

    reverse_cache_s::Vector{Float64} = zeros(network_size)
    reverse_cache_d::Vector{Float64} = zeros(network_size)
end

# functions

# In-place right-hand side du = f(u, p, t) of the forward system, in the
# `f!(du, u, p, t)` form passed to ODEFunction below.
#
# With N = p.network_size and T = p.number_of_targets, `u` and `du` have 7N + T
# entries: seven length-N blocks followed by T target entries.
#   du[1:N]        cubic *_s equation, driven by U_s*input, R_s*u[2N+1:3N], I_s
#                  and sigma_d_s*(u[N+i] - mu_d)
#   du[N+1:2N]     cubic *_d equation, driven by U_d*input, R_d*u[2N+1:3N], I_d
#                  and sigma_s_d*(u[i] - mu_s)
#   du[2N+1:3N]    (z_max*(tanh(q1*(u[3N+i] - 0.1)) + 1) - u[2N+i])/tau_s
#   du[3N+1:4N]    (d_s*(u[i] - e_s)^3 + f_s - u[3N+i])/tau_u_s
#   du[4N+1:5N]    (d_d*(u[N+i] - e_d)^3 + f_d - u[4N+i])/tau_u_d
#   du[5N+1:6N]    (g_s*(u[i] - h_s)^2 + i_s - u[5N+i])/tau_w_s
#   du[6N+1:7N]    (g_d*(u[N+i] - h_d)^2 + i_d - u[6N+i])/tau_w_d
#   du[7N+1:7N+T]  (A*u[7N+1:7N+T] + input - u[7N+1:7N+T])/tau_s
# Every entry of du is overwritten. The scratch vectors input_cache, inputs_s,
# recurrents_s, inputs_d, recurrents_d and target_cache in `p` are overwritten
# as a side effect. Returns `nothing`.
function forward_RHS!(du, u, p, t)
# Input signal i at time t: four sinusoids with alternating signs.
@inbounds for i in 1:p.number_of_inputs
p.input_cache[i] = sin(p.frequencies[i, 1]*t) - sin(p.frequencies[i, 2]*t) + sin(p.frequencies[i, 3]*t) - sin(p.frequencies[i, 4]*t)
end

# inputs_s = U_s*input, recurrents_s = R_s*u[2N+1:3N] (in-place BLAS; @views avoids copying the slice).
BLAS.gemv!('N', 1.0, p.U_s, p.input_cache, 0.0, p.inputs_s)
@views BLAS.gemv!('N', 1.0, p.R_s, u[2*p.network_size+1:3*p.network_size], 0.0, p.recurrents_s)

# Same for the *_d block.
BLAS.gemv!('N', 1.0, p.U_d, p.input_cache, 0.0, p.inputs_d)
@views BLAS.gemv!('N', 1.0, p.R_d, u[2*p.network_size+1:3*p.network_size], 0.0, p.recurrents_d)

# target_cache = A*u[7N+1:7N+T].
@views BLAS.gemv!('N', 1.0, p.A, u[7*p.network_size+1:7*p.network_size+p.number_of_targets], 0.0, p.target_cache)

@inbounds for i in 1:p.network_size
du[i] = ((a_s*(u[i] - b_s)^2 + c_s)*(u[i] - 1.0)
- j_s*u[3*p.network_size+i]*(u[i] - l_s)
- k_s*u[5*p.network_size+i]*(u[i] - m_s)
+ p.inputs_s[i] + p.recurrents_s[i] + p.I_s[i] + sigma_d_s*(u[p.network_size+i] - mu_d))/cm_s

du[p.network_size+i] = ((a_d*(u[p.network_size+i] - b_d)^2 + c_d)*(u[p.network_size+i] - 1.0)
- j_d*u[4*p.network_size+i]*(u[p.network_size+i] - l_d)
- k_d*u[6*p.network_size+i]*(u[p.network_size+i] - m_d)
+ p.inputs_d[i] + p.recurrents_d[i] + p.I_d[i] + sigma_s_d*(u[i] - mu_s))/cm_d

du[2*p.network_size+i] = (z_max*(tanh(q1*(u[3*p.network_size+i] - 0.1)) + 1.0) - u[2*p.network_size+i])/tau_s

du[3*p.network_size+i] = (d_s*(u[i] - e_s)^3 + f_s - u[3*p.network_size+i])/tau_u_s

du[4*p.network_size+i] = (d_d*(u[p.network_size+i] - e_d)^3 + f_d - u[4*p.network_size+i])/tau_u_d

du[5*p.network_size+i] = (g_s*(u[i] - h_s)^2 + i_s - u[5*p.network_size+i])/tau_w_s

du[6*p.network_size+i] = (g_d*(u[p.network_size+i] - h_d)^2 + i_d - u[6*p.network_size+i])/tau_w_d
end

# NOTE(review): input_cache is indexed with the *target* index i, so this assumes
# number_of_targets <= number_of_inputs; @inbounds hides any violation.
@inbounds for i in 1:p.number_of_targets
du[7*p.network_size+i] = (p.target_cache[i] + p.input_cache[i] - u[7*p.network_size+i])/tau_s
end

nothing
end

# In-place right-hand side of the reverse (adjoint) system. The script below
# integrates it backwards from t = simulation_length to t = 0.
#
# `u`/`du` hold 7N adjoint entries (N = p.network_size). On every call the
# forward state at time `t` is first interpolated in place into `p.u` from
# `p.forward_solution`; all coefficients are evaluated at that forward state.
#
# Structure visible in the code:
# * The part linear in `u` gathers the transposed forward partial derivatives.
#   For example, du[i] collects the u[i]-derivatives of every forward equation.
# * Each component k is multiplied by 1/(divisor of forward equation k):
#   cm_s, cm_d, tau_s, tau_u_s, tau_u_d, tau_w_s or tau_w_d.
# * du[2N+1:3N] also contains the source term dl_ds1: Oᵀ times the output error
#   (O*p.u[2N+1:3N] - p.u[7N+1:7N+T] + r3*outputs), plus r1- and r2-weighted terms.
# NOTE(review): this looks like a rescaled adjoint -D⁻¹Gᵀu, where the forward
# RHS is f = D⁻¹G(u) and D holds the divisors. Confirm the intended scaling.
#
# Overwrites every entry of du and the scratch buffers in `p`; returns nothing.
function reverse_RHS!(du, u, p, t)
# Forward state at time t, written in place into p.u.
p.forward_solution(p.u, t)

# Input signals at time t (same construction as in forward_RHS!).
@inbounds for i in 1:p.number_of_inputs
p.input_cache[i] = sin(p.frequencies[i, 1]*t) - sin(p.frequencies[i, 2]*t) + sin(p.frequencies[i, 3]*t) - sin(p.frequencies[i, 4]*t)
end

# output_cache_s = R_sᵀ*(U_s*input + R_s*p.u[2N+1:3N] + I_s); likewise for *_d.
BLAS.gemv!('N', 1.0, p.U_s, p.input_cache, 0.0, p.inputs_s)
@views BLAS.gemv!('N', 1.0, p.R_s, p.u[2*p.network_size+1:3*p.network_size], 0.0, p.recurrents_s)
@. p.total_s = p.inputs_s + p.recurrents_s + p.I_s
BLAS.gemv!('N', 1.0, p.R_s_T, p.total_s, 0.0, p.output_cache_s)

BLAS.gemv!('N', 1.0, p.U_d, p.input_cache, 0.0, p.inputs_d)
@views BLAS.gemv!('N', 1.0, p.R_d, p.u[2*p.network_size+1:3*p.network_size], 0.0, p.recurrents_d)
@. p.total_d = p.inputs_d + p.recurrents_d + p.I_d
BLAS.gemv!('N', 1.0, p.R_d_T, p.total_d, 0.0, p.output_cache_d)

# outputs = O*p.u[2N+1:3N]
@views BLAS.gemv!('N', 1.0, p.O, p.u[2*p.network_size+1:3*p.network_size], 0.0, p.outputs)

# dl_ds1 = Oᵀ*(outputs - p.u[7N+1:7N+T] + r3*outputs)
@views @. p.loss_cache = p.outputs - p.u[7*p.network_size+1:7*p.network_size+p.number_of_targets] + r3*p.outputs
BLAS.gemv!('N', 1.0, p.O_T, p.loss_cache, 0.0, p.dl_ds1)

# Transposed recurrent couplings applied to the adjoint state.
@views BLAS.gemv!('N', 1.0, p.R_s_T, u[1:p.network_size], 0.0, p.reverse_cache_s)
@views BLAS.gemv!('N', 1.0, p.R_d_T, u[p.network_size+1:2*p.network_size], 0.0, p.reverse_cache_d)

@inbounds for i in 1:p.network_size
p.dl_ds1[i] += r1*(p.output_cache_s[i] + p.output_cache_d[i]) + r2*p.u[2*p.network_size+i]
end

@inbounds for i in 1:p.network_size
du[i] = 1/cm_s*(-1*u[i]*(2*a_s*(p.u[i] - b_s)*(p.u[i] - 1) + a_s*(p.u[i] - b_s)^2 + c_s - j_s*p.u[3*p.network_size+i] - k_s*p.u[5*p.network_size+i])
- u[p.network_size+i]*sigma_s_d
- u[3*p.network_size+i]*3*(d_s*(p.u[i] - e_s)^2)
- u[5*p.network_size+i]*2*(g_s*(p.u[i] - h_s)))

du[p.network_size+i] = 1/cm_d*(-1*u[i]*sigma_d_s
- u[p.network_size+i]*(2*a_d*(p.u[p.network_size+i] - b_d)*(p.u[p.network_size+i] - 1) + a_d*(p.u[p.network_size+i] - b_d)^2 + c_d - j_d*p.u[4*p.network_size+i] - k_d*p.u[6*p.network_size+i])
- u[4*p.network_size+i]*3*(d_d*(p.u[p.network_size+i] - e_d)^2)
- u[6*p.network_size+i]*2*(g_d*(p.u[p.network_size+i] - h_d)))

du[2*p.network_size+i] = 1/tau_s*(u[2*p.network_size+i] - p.reverse_cache_s[i] - p.reverse_cache_d[i] - p.dl_ds1[i])

du[3*p.network_size+i] = 1/tau_u_s*(u[3*p.network_size+i] + u[i]*(j_s*(p.u[i] - l_s)) - u[2*p.network_size+i]*z_max*q1*sech(q1*(p.u[3*p.network_size+i] - 0.1))^2)

du[4*p.network_size+i] = 1/tau_u_d*(u[4*p.network_size+i] + u[p.network_size+i]*(j_d*(p.u[p.network_size+i] - l_d)))

du[5*p.network_size+i] = 1/tau_w_s*(u[5*p.network_size+i] + u[i]*(k_s*(p.u[i] - m_s)))

du[6*p.network_size+i] = 1/tau_w_d*(u[6*p.network_size+i] + u[p.network_size+i]*(k_d*(p.u[p.network_size+i] - m_d)))
end

nothing
end

# Analytic Jacobian J = ∂f/∂u of `forward_RHS!`, written in place.
#
# With N = p.network_size and T = p.number_of_targets, J is (7N + T) × (7N + T).
# Only structurally nonzero entries are assigned. Every other entry of J is left
# untouched and is expected to be zero already (e.g. the sparse `jac_prototype`
# built in the script below).
#
# Corrections relative to the earlier version:
# * ∂du[N+i]/∂u[i] = sigma_s_d/cm_d, because du[N+i] is divided by cm_d. This is
#   numerically identical here (cm_s == cm_d) but wrong if they ever differ.
# * du[7N+i] = (Σ_j A[i,j]*u[7N+j] + input_i - u[7N+i])/tau_s, so the target
#   block of J is (A - I)/tau_s. Previously only the -I/tau_s diagonal was set.
#   That gives Rosenbrock methods such as Rodas5 an inexact Jacobian and leaves
#   the A couplings out of the sparsity prototype.
function forward_jacobian!(J, u, p, t)
    @inbounds for i in 1:p.network_size
        # Rows 1:N.
        J[i, i] = (2*a_s*(u[i] - b_s)*(u[i] - 1)
            + a_s*(u[i] - b_s)^2 + c_s - j_s*u[3*p.network_size+i] - k_s*u[5*p.network_size+i])/cm_s
        J[i, p.network_size+i] = sigma_d_s/cm_s
        J[i, 3*p.network_size+i] = j_s*(l_s - u[i])/cm_s
        J[i, 5*p.network_size+i] = k_s*(m_s - u[i])/cm_s

        # Rows N+1:2N.
        J[p.network_size+i, i] = sigma_s_d/cm_d
        J[p.network_size+i, p.network_size+i] = (2*a_d*(u[p.network_size+i] - b_d)*(u[p.network_size+i] - 1)
            + a_d*(u[p.network_size+i] - b_d)^2 + c_d - j_d*u[4*p.network_size+i] - k_d*u[6*p.network_size+i])/cm_d
        J[p.network_size+i, 4*p.network_size+i] = j_d*(l_d - u[p.network_size+i])/cm_d
        J[p.network_size+i, 6*p.network_size+i] = k_d*(m_d - u[p.network_size+i])/cm_d

        # Rows 2N+1:3N (tanh-driven block).
        J[2*p.network_size+i, 2*p.network_size+i] = -1/tau_s
        J[2*p.network_size+i, 3*p.network_size+i] = z_max*q1*sech(q1*(u[3*p.network_size+i] - 0.1))^2/tau_s

        # Rows 3N+1:7N (relaxation blocks).
        J[3*p.network_size+i, i] = 3*(d_s*(u[i] - e_s)^2)/tau_u_s
        J[3*p.network_size+i, 3*p.network_size+i] = -1/tau_u_s

        J[4*p.network_size+i, p.network_size+i] = 3*(d_d*(u[p.network_size+i] - e_d)^2)/tau_u_d
        J[4*p.network_size+i, 4*p.network_size+i] = -1/tau_u_d

        J[5*p.network_size+i, i] = 2*(g_s*(u[i] - h_s))/tau_w_s
        J[5*p.network_size+i, 5*p.network_size+i] = -1/tau_w_s

        J[6*p.network_size+i, p.network_size+i] = 2*(g_d*(u[p.network_size+i] - h_d))/tau_w_d
        J[6*p.network_size+i, 6*p.network_size+i] = -1/tau_w_d

        # Recurrent coupling of rows 1:2N to u[2N+1:3N].
        for j in 1:p.network_size
            J[i, 2*p.network_size+j] = p.R_s[i, j]/cm_s
            J[p.network_size+i, 2*p.network_size+j] = p.R_d[i, j]/cm_d
        end
    end

    # Target block: ∂du[7N+i]/∂u[7N+j] = (A[i, j] - δ_ij)/tau_s.
    @inbounds for j in 1:p.number_of_targets, i in 1:p.number_of_targets
        J[7*p.network_size+i, 7*p.network_size+j] = (p.A[i, j] - (i == j))/tau_s
    end

    return nothing
end

# Analytic Jacobian of `reverse_RHS!` with respect to the adjoint state, written
# in place.
#
# reverse_RHS! is linear in `u`, so J does not depend on `u` (the argument is
# unused). It depends on time only through the forward state, which is first
# interpolated into `p.u`. The dl_ds1 source term does not involve `u` and so
# contributes nothing here.
#
# Entries mirror the coefficients in reverse_RHS! (N = p.network_size). For
# example, du[i] carries a 1/cm_s factor there, so every entry in rows 1:N is
# divided by cm_s here. Only structurally nonzero entries are assigned; the rest
# of J is expected to be zero already (e.g. from the sparse `jac_prototype`).
function reverse_jacobian!(J, u, p, t)
# Forward state at time t, written in place into p.u.
p.forward_solution(p.u, t)

@inbounds for i in 1:p.network_size
# Rows 1:N (factor 1/cm_s).
J[i, i] = -1*(2*a_s*(p.u[i] - b_s)*(p.u[i] - 1)
+ a_s*(p.u[i] - b_s)^2 + c_s - j_s*p.u[3*p.network_size+i] - k_s*p.u[5*p.network_size+i])/cm_s
J[i, p.network_size+i] = -1*sigma_s_d/cm_s
J[i, 3*p.network_size+i] = -3*(d_s*(p.u[i] - e_s)^2)/cm_s
J[i, 5*p.network_size+i] = -2*(g_s*(p.u[i] - h_s))/cm_s

# Rows N+1:2N (factor 1/cm_d).
J[p.network_size+i, i] = -1*sigma_d_s/cm_d
J[p.network_size+i, p.network_size+i] = -1*(2*a_d*(p.u[p.network_size+i] - b_d)*(p.u[p.network_size+i] - 1)
+ a_d*(p.u[p.network_size+i] - b_d)^2 + c_d - j_d*p.u[4*p.network_size+i] - k_d*p.u[6*p.network_size+i])/cm_d
J[p.network_size+i, 4*p.network_size+i] = -3*(d_d*(p.u[p.network_size+i] - e_d)^2)/cm_d
J[p.network_size+i, 6*p.network_size+i] = -2*(g_d*(p.u[p.network_size+i] - h_d))/cm_d

# Rows 2N+1:3N (factor 1/tau_s); the off-diagonal R terms are set below.
J[2*p.network_size+i, 2*p.network_size+i] = 1/tau_s

# Rows 3N+1:7N (factors 1/tau_u_s, 1/tau_u_d, 1/tau_w_s, 1/tau_w_d).
J[3*p.network_size+i, i] = j_s*(p.u[i] - l_s)/tau_u_s
J[3*p.network_size+i, 2*p.network_size+i] = -1*z_max*q1*sech(q1*(p.u[3*p.network_size+i] - 0.1))^2/tau_u_s
J[3*p.network_size+i, 3*p.network_size+i] = 1/tau_u_s

J[4*p.network_size+i, p.network_size+i] = j_d*(p.u[p.network_size+i] - l_d)/tau_u_d
J[4*p.network_size+i, 4*p.network_size+i] = 1/tau_u_d

J[5*p.network_size+i, i] = k_s*(p.u[i] - m_s)/tau_w_s
J[5*p.network_size+i, 5*p.network_size+i] = 1/tau_w_s

J[6*p.network_size+i, p.network_size+i] = k_d*(p.u[p.network_size+i] - m_d)/tau_w_d
J[6*p.network_size+i, 6*p.network_size+i] = 1/tau_w_d

# Row 2N+i couples to the first two adjoint blocks through the transposed
# recurrent matrices (the reverse_cache_s/_d terms of reverse_RHS!).
@inbounds for j in 1:p.network_size
J[2*p.network_size+i, j] = -1*p.R_s_T[i, j]/tau_s
J[2*p.network_size+i, p.network_size+j] = -1*p.R_d_T[i, j]/tau_s
end
end

nothing
end

# Build and solve the forward problem with a sparse analytic Jacobian, then the
# reverse (adjoint) problem that interpolates the stored forward solution.
forward_parameters = ForwardParameters(number_of_inputs = number_of_inputs, network_size = network_size, number_of_targets = number_of_targets,
    U_s = U_s, U_d = U_d, O = O,
    R_s = R_s, R_d = R_d,
    I_s = I_s, I_d = I_d,
    frequencies = frequencies, A = A)

# Sparsity prototype for the forward Jacobian. `sparse` keeps only entries that
# are nonzero at the evaluation state, so the Jacobian must be evaluated where
# every structurally nonzero entry is actually nonzero. At u = 0 the entries
# J[i, 5N+i] = k_s*(m_s - u[i])/cm_s vanish (m_s == 0.0) and would be missing
# from the pattern; later Jacobian evaluations would then write nonzeros outside
# the prototype's structure. A random state in [0, 1) avoids such accidental
# zeros. Only the pattern matters here; the solver overwrites the values via
# forward_jacobian!.
forward_J = zeros(7*network_size + number_of_targets, 7*network_size + number_of_targets)
forward_jacobian!(forward_J, rand(7*network_size + number_of_targets), forward_parameters, 0.0)

forward_ODEFunction = ODEFunction(forward_RHS!, jac = forward_jacobian!, jac_prototype = sparse(forward_J))
forward_ODE = ODEProblem(forward_ODEFunction, zeros(7*network_size + number_of_targets), (0.0, simulation_length), forward_parameters)
forward_solution = solve(forward_ODE, AutoVern7(Rodas5()), reltol = 1e-10, abstol = 1e-10)

reverse_parameters = ReverseParameters(number_of_inputs = number_of_inputs, network_size = network_size, number_of_targets = number_of_targets,
    U_s = U_s, U_d = U_d, O = O,
    R_s = R_s, R_d = R_d,
    I_s = I_s, I_d = I_d,
    forward_solution = forward_solution,
    frequencies = frequencies, A = A)

# reverse_jacobian! reads only the interpolated forward state, so the state
# argument used to build this prototype is irrelevant.
reverse_J = zeros(7*network_size, 7*network_size)
reverse_jacobian!(reverse_J, zeros(7*network_size), reverse_parameters, simulation_length)

# The reverse problem runs backwards in time: tspan = (simulation_length, 0.0).
reverse_ODEFunction = ODEFunction(reverse_RHS!, jac = reverse_jacobian!, jac_prototype = sparse(reverse_J))
reverse_ODE = ODEProblem(reverse_ODEFunction, zeros(7*network_size), (simulation_length, 0.0), reverse_parameters)
reverse_solution = solve(reverse_ODE, Vern8(), dense = true, reltol = 1e-10, abstol = 1e-10)