Hey there
I would like some advice on implementing efficient code for a network of spiking neurons whose coupling depends on spike times. When a neuron in the model spikes (one of its variables, the membrane potential, crosses a threshold going up), the coupled unit receives a current that is a function of time and of the spike time, such as $(t-t_s)\,e^{1-(t-t_s)}$, with $t_s$ being the spike time.
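Written out in full, with the time constant $\tau_s$ and the gain $g_{\max}$ that appear in `conductance_alpha` and `coupling` in the example code further down, the synaptic current into neuron 2 reads (my restatement of that code):

$$
I_{\mathrm{syn}}(t) = g_{\max}\,\alpha\!\left(\frac{t-t_s}{\tau_s}\right)\bigl(E - V_2(t)\bigr),
\qquad \alpha(x) = x\,e^{1-x}.
$$

Since $\alpha'(x) = (1-x)\,e^{1-x}$, the kernel peaks at $t = t_s + \tau_s$ with $\alpha(1) = 1$, so $g_{\max}$ is the peak conductance; for $\tau_s = 1$ it reduces to the expression above. In the code, $\alpha$ is taken as $0$ while no spike has been recorded ($t_s \le 0$).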
So I need to first detect the spikes and then use the spike times in the ODE. I have implemented this with a `VectorContinuousCallback`, storing the spike times in the ODE’s parameters. But I’m unsure whether this is optimal, and I have some questions. I haven’t found any discussion of this anywhere, even though this type of coupling is very common for chemical synapses. I’d really appreciate some insight!
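To make "detect the spikes" concrete: a spike time is the moment $V$ crosses `Vth` from below. The sketch below uses a hypothetical helper, `upward_crossings` (not part of any package), to do this after the fact on a sampled trace, estimating each crossing time by linear interpolation between samples. It is only an analogue: a `VectorContinuousCallback` instead root-finds its condition on the solver's continuous interpolant during integration, which is what lets the spike time feed back into the right-hand side.

```julia
# Hypothetical helper: find upward threshold crossings in a sampled trace and
# estimate each crossing time by linear interpolation between neighbouring samples.
function upward_crossings(ts::AbstractVector{<:Real}, V::AbstractVector{<:Real}, Vth::Real)
    length(ts) == length(V) || throw(ArgumentError("ts and V must have the same length"))
    crossings = Float64[]
    for k in 1:length(V)-1
        # upward crossing: below threshold at sample k, at or above it at sample k+1
        if V[k] < Vth <= V[k+1]
            frac = (Vth - V[k]) / (V[k+1] - V[k])
            push!(crossings, ts[k] + frac * (ts[k+1] - ts[k]))
        end
    end
    return crossings
end

ts = collect(0.0:0.5:10.0)
V = sin.(ts)                          # toy trace standing in for a membrane potential
spikes = upward_crossings(ts, V, 0.5) # two upward crossings, near π/6 and 2π + π/6
```

Downward crossings are ignored, mirroring `affect_neg! = nothing` in the callback below.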
Below is a minimal working example for a system of two neurons, with coupling from neuron 1 to neuron 2. Please ignore all the parameter values haha.
```julia
using DrWatson
using DiffEqCallbacks, OrdinaryDiffEq

# Alpha-function synaptic conductance; 0 while no spike has been recorded (t_s <= 0)
function conductance_alpha(t, t_s, p)
    if t_s <= 0 return 0.0 end
    δt = (t - t_s) / p.τs
    return δt * exp(1 - δt)
end

# Synaptic current into neuron i; only neuron 2 receives input (from neuron 1)
function coupling(i, u, p, t)
    if i == 2
        t_s = p.spiketimes[1]; V = u[1, 2]
        return p.gmax * conductance_alpha(t, t_s, p) * (p.E - V)
    else
        return 0.0
    end
end

# u is a 2×N matrix: row 1 holds the membrane potentials V, row 2 the recovery variables w
function fitzhugh_nagumo!(du, u, p, t)
    @unpack a, b, c, I = p
    for i = 1:size(du, 2)
        V, w = @view u[:, i]
        du[1, i] = V*(a - V)*(V - 1) - w + I + coupling(i, u, p, t) # Vdot
        du[2, i] = b*V - c*w
    end
end

# One event condition per neuron: zero when V_i crosses Vth
function condition_spike(out, u, t, integrator)
    for i = 1:size(u, 2)
        V = u[1, i]
        out[i] = V - integrator.p.Vth
    end
end

# Called on upward crossings only (affect_neg! = nothing below)
function affect_spike!(integrator, idx)
    integrator.p.spiketimes[idx] = integrator.t # TODO: what if multiple neurons spike at the same time step??
end

mutable struct params{A, B}
    a :: A
    b :: A
    c :: A
    I :: A
    τs :: A
    gmax :: A
    E :: A
    spiketimes :: Vector{B} # last spike time of each neuron (0 = no spike yet)
    Vth :: A
end

N = 2
a = -0.5; b = 0.1; c = 0.2; I = 0.0
Vth = 0.5; gmax = 0.1; E = 4.0; τs = 1.0
u0 = rand(Float64, (2, N))
p = params(a, b, c, I, τs, gmax, E, zeros(N), Vth)
T = 100.0
cb = VectorContinuousCallback(condition_spike, affect_spike!, N; affect_neg! = nothing)
prob = ODEProblem(fitzhugh_nagumo!, u0, (0.0, T), p)
sol = solve(prob, Vern9(), callback = cb)

using CairoMakie
fig = Figure()
ax = Axis(fig[1, 1])
for i = 1:N
    lines!(ax, sol.t, sol[1, i, :])
    scatter!(ax, p.spiketimes[i], Vth) # marks only the last stored spike of neuron i
end
fig
```
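In the example, `p.spiketimes[i]` holds only the most recent spike of each neuron, so the plot marks a single spike per neuron. If the full spike history is also wanted, one possible layout is a vector of vectors. The names below (`spike_history`, `record_spike!`, `last_spike`) are hypothetical, for illustration only; they keep the `t_s <= 0` "no spike yet" convention of `conductance_alpha`.

```julia
# Hypothetical layout: one growing vector of spike times per neuron.
spike_history(N) = [Float64[] for _ in 1:N]

# What an affect! would do for event index idx at time t: append, never overwrite.
record_spike!(hist, idx, t) = push!(hist[idx], t)

# Most recent spike of neuron i, or 0.0 if it has not spiked yet
# (so conductance_alpha still returns 0 before the first spike).
last_spike(hist, i) = isempty(hist[i]) ? 0.0 : last(hist[i])

spk = spike_history(2)
record_spike!(spk, 1, 3.2)
record_spike!(spk, 1, 17.5)
record_spike!(spk, 2, 4.1)
last_spike(spk, 1)  # 17.5
```

In the example this would mean declaring `spiketimes :: Vector{Vector{B}}` in `params`, calling `record_spike!(integrator.p.spiketimes, idx, integrator.t)` in `affect_spike!`, and reading `last_spike(p.spiketimes, 1)` in `coupling`; the plot loop could then mark every spike with `scatter!(ax, p.spiketimes[i], fill(Vth, length(p.spiketimes[i])))`. The cost is occasional allocations inside the callback as the vectors grow.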
Specifically:
- Is there a more efficient implementation for this, or is saving the event times in the parameters already the best approach?
- What happens if events occur simultaneously? How does `affect_spike!(integrator, idx)` deal with what should be multiple indices?
- Does it make sense to pass the `idxs` parameter to `VectorContinuousCallback` to specify that I only need to interpolate the first variable of each unit?
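For the `idxs` question, it helps to know which components that would be. With `u` stored as a 2×N matrix in Julia's column-major order, the membrane potentials `u[1, i]` sit at linear indices 1, 3, 5, …, which the snippet below computes. How the condition function then receives these components (presumably as a length-N vector instead of the 2×N matrix, so `condition_spike` would index `u[i]` rather than `u[1, i]`) is an assumption to check against the `VectorContinuousCallback` docstring.

```julia
# Column-major layout: u[1, i] (the V of neuron i) has linear index 2i - 1.
N = 2
V_idxs = vec(LinearIndices((2, N))[1, :])  # linear indices of row 1 → [1, 3]

u0 = rand(2, N)
u0[V_idxs] == u0[1, :]  # true: linear indexing picks out the first row
```

Such a vector would then be passed as `VectorContinuousCallback(condition_spike, affect_spike!, N; idxs = V_idxs, affect_neg! = nothing)`, under the assumption above.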
Thanks a lot!