ODE terms dependent on time of events

Hey there

I would like some advice on implementing efficient code for a spiking neuron network with coupling that depends on spike times. When a neuron in the model spikes (one of its variables, the membrane potential, crosses a threshold going up), the coupled unit receives a current that is a function of time and the spike time, such as (t-t_s) exp(1 - (t-t_s)), with t_s being the spike time.
So I need to first detect the spikes and then use the spike times in the ODE. I have implemented this through a VectorContinuousCallback, with the spike times stored in the ODE's parameters. But I'm unsure whether this is optimal, and I have some questions. I haven't found any discussion of this anywhere, even though this type of coupling is very common for chemical synapses. I'd really appreciate some insight!

A minimal working example is shown below for a system with two neurons, with a coupling from neuron 1 to neuron 2. Please ignore all the parameters haha.

using DrWatson
using DiffEqCallbacks, OrdinaryDiffEq

function conductance_alpha(t, t_s, p)
    if t_s <= 0 return 0.0 end  # spiketimes start at 0.0, meaning "no spike yet"
    δt = (t-t_s)/p.τs
    return δt * exp(1 - δt)
end

function coupling(i, u, p, t)
    if i == 2
        t_s = p.spiketimes[1]; V = u[1,2];
        return p.gmax * conductance_alpha(t, t_s, p) * (p.E - V)
    else
        return 0.0
    end
end

function fitzhugh_nagumo!(du, u, p, t)
    @unpack a, b, c, I = p
    for i = 1:size(du,2)
        V, w = @view u[:, i]
        du[1, i] = V*(a-V)*(V-1) -w + I + coupling(i, u, p, t) #Vdot
        du[2, i] = b*V - c*w
    end
end

function condition_spike(out, u, t, integrator)
    for i = 1:size(u,2)
        V = u[1, i];
        out[i] = V - integrator.p.Vth
    end
end

function affect_spike!(integrator, idx)
    integrator.p.spiketimes[idx] = integrator.t #TODO: what if multiple neurons spike at the same time-step??
end

mutable struct params{A, B}
    a :: A
    b :: A
    c :: A
    I :: A
    τs :: A
    gmax :: A
    E :: A
    spiketimes :: Vector{B}
    Vth :: A
end

N = 2;
a = -0.5; b = 0.1; c = 0.2; I = 0.0;
Vth = 0.5; gmax=0.1; E=4.0; τs = 1.0;
u0 = rand(Float64, (2, N))
p = params(a, b, c, I, τs, gmax, E, zeros(N), Vth)

T = 100.0;
cb = VectorContinuousCallback(condition_spike,affect_spike!, N; affect_neg! = nothing);
prob = ODEProblem(fitzhugh_nagumo!, u0, (0.0, T), p);
sol = solve(prob, Vern9(), callback=cb)

using CairoMakie
fig = Figure()
ax = Axis(fig[1,1])
for i=1:N
    lines!(ax, sol.t, sol[1, i, :])
    scatter!(ax, p.spiketimes[i], Vth)
end
fig

Specifically:

  1. Is there a more efficient implementation for this? Or is saving the event times in the parameters best already?
  2. What happens if events occur simultaneously? How does affect_spike!(integrator, idx) deal with what should be multiple indices?
  3. Does it make sense to pass the idxs parameter to VectorContinuousCallback to specify that I only need to interpolate the first variable of each unit?

Thanks a lot!

Hi, it is a bit hard to tell which parts of your code are implementation details and which parts are essential to the model. Could you write down the mathematical problem you would like to solve?

Sure! There are quite a few parameters, sorry. The system has two units, each with variables V and w. Variable V is the main one I'm interested in. Over time, this variable spikes like in the figure below. I consider a spike to occur when V crosses a threshold V_{th} upwards. When unit 1 spikes, at a certain time t_s, it sends a current I_{coup} to unit 2, which gets added to unit 2's \dot{V}. The current I_{coup} has a time dependence of the type (t-t_s) \exp(1-(t-t_s)).

\begin{align}
\dot{V_1} &= V_1(a-V_1)(V_1-1) - w_1 + I \\
\dot{w_1} &= bV_1 - cw_1 \\
\dot{V_2} &= V_2(a-V_2)(V_2-1) - w_2 + I + I_{coup} \\
\dot{w_2} &= bV_2 - c w_2 \\
I_{coup} &= g_{max} \frac{t-t_s}{\tau_s} \exp \left(1-\frac{t-t_s}{\tau_s}\right) (E-V_2)
\end{align}
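
As a quick sanity check on the kernel in I_{coup}: with s = t - t_s, the alpha function \frac{s}{\tau_s}\exp(1-\frac{s}{\tau_s}) rises from 0, peaks at exactly 1 when s = \tau_s, and then decays, so g_{max} is the peak conductance. A minimal Base-Julia sketch (the name alpha_kernel is hypothetical, not from the original code) that evaluates it on a grid and locates the peak:

```julia
# Alpha-function kernel from I_coup, written in terms of the time since the
# spike s = t - t_s (zero before the spike). The name is illustrative only.
alpha_kernel(s, τs) = s < 0 ? 0.0 : (s / τs) * exp(1 - s / τs)

τs = 1.0
s_grid = range(0.0, 10.0; length = 10_001)   # step 0.001
vals = alpha_kernel.(s_grid, τs)
imax = argmax(vals)
println("peak at s = ", s_grid[imax], ", value = ", vals[imax])
```

Changing τs only stretches the kernel in time: the peak moves to s = τs, and its height stays 1.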

My implementation works, as seen in the figure. The lines are the units' potentials V, and the dots mark the last spike time of each unit.
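
Because p.spiketimes only keeps the most recent spike of each unit, all upward crossings can be recovered afterwards from the saved trajectory. One option is a hypothetical helper like the one below, called as something like upward_crossings(sol.t, sol[1, i, :], Vth). This is only a sketch. It interpolates linearly between saved points, so its accuracy depends on how densely the solution is saved.

```julia
# Hypothetical post-processing helper: estimate every upward crossing of Vth
# in a sampled trace (ts, Vs) by linear interpolation between neighbours.
function upward_crossings(ts::AbstractVector, Vs::AbstractVector, Vth::Real)
    crossings = Float64[]
    for k in firstindex(ts):lastindex(ts)-1
        if Vs[k] < Vth <= Vs[k+1]
            # fraction of the interval [ts[k], ts[k+1]] where the line hits Vth
            θ = (Vth - Vs[k]) / (Vs[k+1] - Vs[k])
            push!(crossings, ts[k] + θ * (ts[k+1] - ts[k]))
        end
    end
    return crossings
end

# Toy trace: sin(t) crosses 0.5 upwards at π/6 and 2π + π/6 within [0, 10].
ts = collect(0.0:0.01:10.0)
Vs = sin.(ts)
println(upward_crossings(ts, Vs, 0.5))
```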

The implementation details I'm concerned with are how to store the spike times and how to use them efficiently in the RHS function. The spike time enters as t_s in coupling and conductance_alpha.

It might be useful to include the spike condition in I_\text{coup} using the Dirac-\delta and Heaviside-\theta functions:

I_\text{coup}(t) =\ldots \int_0^t\delta\bigl(V_1(t')-V_\text{th}\bigr)\,\theta\bigl(\dot V_1(t')\bigr)\,\dot V_1(t')\,\frac{t-t'}{\tau_s}\exp\left(1-\frac{t-t'}{\tau_s}\right)\,\mathrm{d}t'
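
For reference, the composition rule for the Dirac delta turns such an integral into a sum over the spike times t_j, which are the upward crossings of V_\text{th} by V_1. This assumes every crossing is transversal, \dot V_1(t_j) \neq 0:

\delta\bigl(V_1(t')-V_\text{th}\bigr) = \sum_j \frac{\delta(t'-t_j)}{\lvert\dot V_1(t_j)\rvert}

Here the sum runs over all crossings, upward and downward. The \theta factor keeps only the upward ones, where \dot V_1 = \lvert\dot V_1\rvert, so

\int_0^t \delta\bigl(V_1(t')-V_\text{th}\bigr)\,\theta\bigl(\dot V_1(t')\bigr)\,\dot V_1(t')\,\alpha(t-t')\,\mathrm{d}t' = \sum_{j:\,t_j<t}\alpha(t-t_j), \qquad \alpha(s) = \frac{s}{\tau_s}\exp\left(1-\frac{s}{\tau_s}\right)

Without the factor \dot V_1(t') in the integrand, each spike would instead be weighted by 1/\dot V_1(t_j).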

Written in this convolutional form, it becomes clear that this is an integro-differential equation. It might be possible to rewrite it as a purely differential system starting from this expression…
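
One way to carry this out, sketched here under the assumption that the compounding (sum over all spikes) version is acceptable, relies on the fact that the alpha kernel is the response of two linear first-order equations to a kick. Introduce auxiliary synaptic variables x and g, with e being Euler's number:

\begin{align}
\tau_s \dot x &= -x, & x(t_j^+) &= x(t_j^-) + e, \\
\tau_s \dot g &= -g + x, & I_\text{coup} &= g_{max}\, g\,(E - V_2).
\end{align}

For a single spike at t_j with x = g = 0 beforehand, x(t) = \exp\left(1 - \frac{t-t_j}{\tau_s}\right), and substituting into the second equation gives g(t) = \frac{t-t_j}{\tau_s}\exp\left(1-\frac{t-t_j}{\tau_s}\right), which can be checked by differentiation. Because the equations are linear, contributions from successive spikes add, reproducing the sum in the convolution. The spike times then only enter through the jump in x, which a callback can apply.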

(Note that here the spikes compound. You might need extra terms if you only want the latest spike to contribute, although if the inter-spike interval is much larger than \tau_s the difference should be negligible.)
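
To illustrate the compounding, here is a small Base-Julia sketch (all names hypothetical). It computes the summed alpha conductance directly from a list of spike times. It then compares that value with a fixed-step RK4 integration of the auxiliary linear system \tau_s\dot x = -x, \tau_s\dot g = -g + x, in which x jumps by e at every spike. For spikes placed on the time grid, the two results should agree to integration accuracy.

```julia
# Compare two ways of computing the compounded alpha conductance g(t):
# (1) summing the kernel over all past spikes, and
# (2) integrating τs*x' = -x, τs*g' = -g + x with x += ℯ at each spike.
# Fixed-step RK4 keeps the example Base-only.

alpha_kernel(s, τs) = s < 0 ? 0.0 : (s / τs) * exp(1 - s / τs)

# (1) Direct form: every past spike contributes one kernel (spikes compound).
g_direct(t, spikes, τs) = sum(alpha_kernel(t - ts, τs) for ts in spikes; init = 0.0)

# (2) Differential form, integrated from t = 0 to t = T with step dt.
function g_ode(T, spikes, τs; dt = 1e-3)
    rhs(x, g) = (-x / τs, (-g + x) / τs)
    x, g, t = 0.0, 0.0, 0.0
    pending = sort(collect(spikes))
    for n in 1:round(Int, T / dt)
        # Kick x for every spike that falls (within dt/2) at the start of this step.
        while !isempty(pending) && first(pending) <= t + dt / 2
            x += ℯ
            popfirst!(pending)
        end
        k1 = rhs(x, g)
        k2 = rhs(x + dt / 2 * k1[1], g + dt / 2 * k1[2])
        k3 = rhs(x + dt / 2 * k2[1], g + dt / 2 * k2[2])
        k4 = rhs(x + dt * k3[1], g + dt * k3[2])
        x += dt / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
        g += dt / 6 * (k1[2] + 2 * k2[2] + 2 * k3[2] + k4[2])
        t = n * dt
    end
    return g
end

τs = 1.0
spikes = [2.0, 3.5]   # two spikes close enough that their kernels overlap
T = 6.0
println("direct: ", g_direct(T, spikes, τs), "   ODE: ", g_ode(T, spikes, τs))
```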

Good point, indeed the spikes compound :slight_smile: Thanks!

But I don't see how that helps with the code, though.
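
Here is one way the differential form can help with the code, sketched under the assumption that compounding spikes are acceptable. The alpha kernel \frac{s}{\tau_s}\exp(1-\frac{s}{\tau_s}) is exactly the response of \tau_s\dot x = -x, \tau_s\dot g = -g + x to a jump x \to x + e at s = 0. If x and g are added as state variables, the callback only has to kick x, and the right-hand side no longer needs any spike times. The example reuses the parameter values and the VectorContinuousCallback call from the first post. The row layout (V, w, x, g) of u and the function names are illustrative choices, not an established API. The (; a, b, ...) = p destructuring needs Julia 1.7 or later.

```julia
using OrdinaryDiffEq

# Rows of u (one column per unit): V, w, and the synaptic variables x, g.
# A spike of unit i kicks x[i] by ℯ; g[i] then follows the alpha shape
# (t - t_s)/τs * exp(1 - (t - t_s)/τs), and successive spikes add up.
function fhn_alpha!(du, u, p, t)
    (; a, b, c, I, τs, gmax, E) = p
    for i in axes(u, 2)
        V, w, x, g = u[1, i], u[2, i], u[3, i], u[4, i]
        # Coupling from unit 1 to unit 2 only: unit 2 receives unit 1's g.
        Icoup = i == 2 ? gmax * u[4, 1] * (E - V) : 0.0
        du[1, i] = V * (a - V) * (V - 1) - w + I + Icoup
        du[2, i] = b * V - c * w
        du[3, i] = -x / τs
        du[4, i] = (-g + x) / τs
    end
    return nothing
end

# Same event as in the first post: V of each unit crossing Vth.
function condition_spike(out, u, t, integrator)
    out .= @view(u[1, :]) .- integrator.p.Vth
end

# The affect only kicks a state variable; no spike times are stored.
function affect_spike!(integrator, idx)
    integrator.u[3, idx] += ℯ
end

N = 2
p = (a = -0.5, b = 0.1, c = 0.2, I = 0.0, τs = 1.0, gmax = 0.1, E = 4.0, Vth = 0.5)
u0 = vcat(rand(2, N), zeros(2, N))   # synaptic variables start at rest
cb = VectorContinuousCallback(condition_spike, affect_spike!, N; affect_neg! = nothing)
prob = ODEProblem(fhn_alpha!, u0, (0.0, 100.0), p)
sol = solve(prob, Tsit5(); callback = cb)
```

A side benefit is that p is no longer mutated during the solve, so the same problem can be solved repeatedly. This slightly changes the model compared with the original code, which keeps only the latest spike.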

Rather than treating this as an ordinary differential equation with callbacks, you could think of it as a delay differential equation:

https://diffeq.sciml.ai/stable/types/dde_types/