# Problem with training a spiking neural network with Flux

Hello,

I am trying to use Flux to train a simple spiking neural network to perform input classification. The network has 100 inputs that represent incoming spike streams. On half of the trials, one half of the inputs is “active” (i.e., receives spikes at a rate of 10 spikes/s); on the remaining trials, the other half of the inputs is active. The network contains two output neurons, which should learn to distinguish between the two input patterns. This can be done either by having the output neurons emit spikes or by having them integrate their spiking inputs so that the mean membrane potential of each neuron codes for the type of input pattern received. Since I’m having trouble computing the gradients in the spiking case, I resorted to the latter option, effectively removing any nonlinearity from the network.
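
To make the input protocol concrete, here is a minimal sketch of how such trials could be generated. The helper `make_inputs` and its keyword names are hypothetical (they do not appear in the original post), and spikes are drawn with a per-step Bernoulli approximation of a Poisson process, which is an assumption about how the spike trains are produced.

```julia
using Random

# Hypothetical helper, not from the original post: returns spike trains with
# shape (batch_size, n_inputs, n_steps) and the pattern label of each trial.
function make_inputs(rng, batch_size, n_inputs, n_steps; rate=10.0, dt=1e-3)
    half = n_inputs ÷ 2
    labels = rand(rng, 1:2, batch_size)              # pattern shown on each trial
    inputs = zeros(Int, batch_size, n_inputs, n_steps)
    for trial in 1:batch_size
        # pattern 1 drives inputs 1:half, pattern 2 drives the remaining ones
        active = labels[trial] == 1 ? (1:half) : (half+1:n_inputs)
        # Bernoulli approximation: probability of a spike per step is rate * dt
        inputs[trial, active, :] .= rand(rng, length(active), n_steps) .< rate * dt
    end
    return inputs, labels
end

inputs, labels = make_inputs(MersenneTwister(1), 32, 100, 200)
```

With this layout the inactive half of the inputs stays silent on every trial, so the label can be recovered from which half carries spikes.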

Even with this simplified network configuration, I run into an error when trying to compute the gradients with respect to the synaptic weights. The relevant piece of code is the following:

```julia
# neuron parameters
tau_mem     = 10e-3 # [s]
tau_syn     = 5e-3  # [s]
# network parameters
n_inputs    = 100
n_outputs   = 2
# simulation parameters
dt          = 1e-3  # [s]
batch_size  = 32
# simulation constants
alpha       = exp(-dt/tau_syn)
beta        = exp(-dt/tau_mem)
# simulation variables
I_syn_curr  = zeros(batch_size, n_outputs)
Vm_curr     = zeros(size(I_syn_curr))
Vm_acc      = zeros(size(Vm_curr))

pars = Flux.Params([w])
loss, grads = Flux.withgradient(pars) do
    # synaptic input for every trial, output neuron and time step
    @einsum I_inp[a,c,d] := inputs[a,b,d] * w[b,c]
    for t in 1 : n_steps
        Vm_acc = Vm_acc + Vm_curr
        I_syn_next = alpha * I_syn_curr + I_inp[:, :, t]
        Vm_next = beta * Vm_curr + I_syn_curr
        I_syn_curr = I_syn_next
        Vm_curr = Vm_next
    end
    # mean membrane potential is used as the network output
    y_hat = Vm_acc ./ n_steps
    Flux.logitcrossentropy(y_hat, y)
end
```

`inputs` is a binary array with shape (`batch_size`, `n_inputs`, `n_steps`), with ones marking the time steps at which an input spike occurred; `w` is the `n_inputs` x `n_outputs` weight matrix; and `y` is a one-hot matrix encoding the desired output of the two neurons.
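
For reference, the loop in the snippet corresponds to the following discrete-time updates, read directly off the code. The symbols $x$, $I$, $V$ and $T$ are shorthand introduced here for `inputs`, `I_syn_curr`, `Vm_curr` and `n_steps`; the indices $a$, $b$, $c$ run over trials, inputs and output neurons, and $I[1] = V[1] = 0$.

$$
\begin{aligned}
\alpha &= e^{-\Delta t/\tau_{\mathrm{syn}}}, \qquad \beta = e^{-\Delta t/\tau_{\mathrm{mem}}},\\
I_{a,c}[t+1] &= \alpha\, I_{a,c}[t] + \sum_{b} x_{a,b}[t]\, w_{b,c},\\
V_{a,c}[t+1] &= \beta\, V_{a,c}[t] + I_{a,c}[t],\\
\hat{y}_{a,c} &= \frac{1}{T}\sum_{t=1}^{T} V_{a,c}[t].
\end{aligned}
$$

Every step is linear in $w$, so $\hat{y}$ is a linear function of the weights in this non-spiking version.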

When running the code, I get the following error:

```
ERROR: LoadError: MethodError: no method matching +(::IRTools.Inner.Undefined, ::IRTools.Inner.Undefined)

Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...)
   @ Base operators.jl:578
  +(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any)
   @ InitialValues ~/.julia/packages/InitialValues/OWP8V/src/InitialValues.jl:154
  +(::ChainRulesCore.Tangent{P}, ::P) where P
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:146
  ...

Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0 [inlined]
 [2] _pullback(::Zygote.Context{true}, ::typeof(+), ::IRTools.Inner.Undefined, ::IRTools.Inner.Undefined)
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:9
 [3] _pullback
 [4] _pullback(::Zygote.Context{true}, ::var"#8#9")
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [5] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Matrix{Float64}, Vector{Matrix{Float64}}}})
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:384
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:132
 [7] top-level scope
```

I should also add that I can successfully train the same network (spiking and non-spiking) using PyTorch.

Thanks to anyone who might shed some light on what I’m doing wrong.

Not very experienced with Flux here, but what that stack trace is telling you is that there is no `+` method defined for two objects of type `IRTools.Inner.Undefined`.

You seem to be summing arrays here:

```julia
I_syn_next = alpha * I_syn_curr + I_inp[:, :, t]
Vm_next = beta * Vm_curr + I_syn_curr
```

Have you tried running a timestep of this outside the Flux block to replicate going through those lines yourself?

Yes, if I comment out the lines

```julia
# pars = Flux.Params([w])
# loss, grads = Flux.withgradient(pars) do
...
# end
```

then everything works fine.

I’m guessing it has to do with the backward part of the gradient computation (i.e., the call to `_pullback` in Zygote), but I can’t figure out whether I’m doing something wrong or if this is a limitation of Zygote and/or ChainRules.

Can you edit your example to define all variables?

I don’t think `@einsum` will be Zygote-differentiable; you could try `@tullio` instead.

But there are probably other problems. There are many global variables, and it’s possible these are upsetting Zygote.
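
As a hedged illustration of how global variables can interact with a `do` block or any other closure, the sketch below shows plain Julia scoping when run at the top level of a script. It is not a confirmed diagnosis of the error above, and `acc`, `broken` and `accumulate_steps` are names made up for this example.

```julia
acc = 0.0                        # a global, playing the role of Vm_acc

broken = () -> begin
    for t in 1:3
        acc = acc + t            # assignment makes `acc` a new local, read before it is set
    end
end

# Calling the closure fails because the local `acc` has no value yet
threw = try
    broken()
    false
catch err
    err isa UndefVarError
end

# Passing the state in as an argument and returning it avoids globals altogether
function accumulate_steps(acc, n)
    for t in 1:n
        acc = acc + t            # here `acc` is the function argument
    end
    return acc
end

total = accumulate_steps(0.0, 3)
```

Wrapping the simulation in a function that takes its state as arguments, as suggested in this thread, sidesteps this class of problem.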

> gradient computation (i.e., the call to `_pullback` in Zygote),

Note that `Zygote._pullback` might be the forward pass, within Zygote. You can call `y, back = pullback(f, x...)` to run only the forward pass, and `back(1.0)` for the gradient.
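
A minimal sketch of that two-step usage, with a toy function `f` standing in for the actual loss (the function and the input values are assumptions made for illustration):

```julia
using Zygote

f(x) = sum(abs2, x)               # toy stand-in for the loss

x = [1.0, 2.0, 3.0]
y, back = Zygote.pullback(f, x)   # runs the forward pass and returns a closure
grads = back(1.0)                 # seeds the output with 1.0 and runs the reverse pass
```

If the error already appears in the `pullback` call, it is raised during the forward pass rather than in `back`.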


Is all this code running at the top level? Can you put it inside a function and see if that makes a difference (please update the MWE if so)?

Ok, I think I might have found a solution. It looks like the problem did indeed have to do with `@einsum`: I’ve replaced that call with the equivalent `@tullio` and things seem to be working fine now.
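
For readers who want to check the contraction in isolation, here is a small sketch on toy data. The array sizes and the wrapper name `contract` are assumptions made for illustration; only the index expression comes from the posts in this thread.

```julia
using Tullio, Zygote

# Toy sizes chosen for illustration only
spikes = Float64.(rand(4, 6, 5) .< 0.3)   # (batch, n_inputs, n_steps)
w = randn(6, 2)                           # (n_inputs, n_outputs)

# Contract over the input index b, keeping batch a, output c and time d
contract(spikes, w) = @tullio I_inp[a, c, d] := spikes[a, b, d] * w[b, c]

I_inp = contract(spikes, w)

# Each time slice is an ordinary matrix product
slice_ok = I_inp[:, :, 3] ≈ spikes[:, :, 3] * w

# Reverse-mode gradient of a scalar summary with respect to the weights
g, = Zygote.gradient(w -> sum(contract(spikes, w)), w)
```

For this particular summary, each entry of `g` should equal the total number of spikes on the corresponding input, which makes the result easy to verify by hand.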

The way I’m doing it now, I first define a function that simulates the network of neurons and returns the mean membrane potential for each trial:

```julia
using Tullio

function run_leaky_NN(tau_mem::Number, tau_syn::Number, dt::Number, spikes_in::Array, w::Matrix)
    alpha = exp(-dt / tau_syn)
    beta  = exp(-dt / tau_mem)
    n_batches, n_spikes_in, n_steps = size(spikes_in)
    n_outputs = size(w, 2)
    # state is sized from the input rather than from the global batch_size
    I_syn_curr = zeros(n_batches, n_outputs)
    Vm_curr    = zeros(size(I_syn_curr))
    Vm_acc     = zeros(size(Vm_curr))
    # contract over the input index b: I_inp has shape (batch, output, time)
    @tullio I_inp[a,c,d] := spikes_in[a,b,d] * w[b,c]
    for t in 1 : n_steps
        I_syn_next = alpha * I_syn_curr + I_inp[:, :, t]
        Vm_next = beta * Vm_curr + I_syn_curr
        Vm_acc = Vm_acc + Vm_curr
        I_syn_curr = I_syn_next
        Vm_curr = Vm_next
    end
    Vm_acc ./ n_steps  # mean membrane potential per trial and output neuron
end;
```

I can then compute the gradient by running

```julia
pars = Flux.Params([w])
loss, grads = Flux.withgradient(pars) do
    y_hat = run_leaky_NN(tau_mem, tau_syn, dt, inputs, w)
    Flux.logitcrossentropy(y_hat, y)
end
```

And tune the weights by running

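```julia
pars = Flux.Params([w])
for epoch in 1:100
    loss, grads = Flux.withgradient(pars) do
        y_hat = run_leaky_NN(tau_mem, tau_syn, dt, inputs, w)
        Flux.logitcrossentropy(y_hat, y)
    end
    # assumed optimiser step; `opt` (e.g. Flux.Adam()) must be defined beforehand
    Flux.update!(opt, pars, grads)
end
```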

which indeed leads to a network that behaves as expected.

I’m now trying to include a spiking nonlinearity in the neurons, so that the network classifies patterns based on the number of output spikes, but that does not seem to be working. I’ll post an MWE of this second scenario here in a bit.


I have modified my code to run a spiking neural network, i.e., the cells now have a threshold and reset, represented by the Heaviside function in the code below:

```julia
using Random
using Tullio
using Flux

# smooth-looking step: 0 below thresh, 1 above it
heaviside(x::Number; thresh::Number=0.) = 0.5 * (1 + (x-thresh) / sqrt((x-thresh)^2));

function run_spiking_NN(tau_mem::Number, tau_syn::Number, dt::Number, spikes_in::Array, w::Matrix, thresh::Number=1.)
    alpha = exp(-dt / tau_syn)
    beta  = exp(-dt / tau_mem)
    n_batches, n_spikes_in, n_steps = size(spikes_in)
    n_outputs = size(w, 2)
    # state is sized from the input rather than from the global batch_size
    I_syn_curr = zeros(n_batches, n_outputs)
    Vm_curr    = zeros(size(I_syn_curr))
    spikes_out = zeros(size(Vm_curr))
    @tullio I_inp[a,c,d] := spikes_in[a,b,d] * w[b,c]
    for t in 1 : n_steps
        reset = heaviside.(Vm_curr, thresh=thresh)  # 1 where the neuron spikes
        I_syn_next = alpha * I_syn_curr + I_inp[:, :, t]
        Vm_next = beta * Vm_curr + I_syn_curr - reset
        spikes_out += reset
        I_syn_curr = I_syn_next
        Vm_curr = Vm_next
    end
    spikes_out  # spike count per trial and output neuron
end;

# neuron parameters
tau_mem    = 10e-3
tau_syn    = 5e-3
# network parameters
input_rate = 10 # [Hz]
n_inputs   = 20
n_outputs  = 2
# simulation parameters
tend       = 0.2
dt         = 1e-3
n_steps    = round(Int, tend / dt)  # round guards against floating-point error in the division

batch_size = 32;
rng = MersenneTwister(100);

inputs = zeros(Int, batch_size, n_inputs, n_steps)
prob = rand(rng, Float64, size(inputs))
inputs[prob .<= input_rate*dt] .= 1
println("Total number of input spikes: ", sum(inputs))

truth = ones(Int, batch_size)
truth[rand(rng, Float64, size(truth)) .> 0.5] .= 2
classes = unique(truth)
y = Flux.onehotbatch(truth, classes)';

weight_scale = 7 * (1 - exp(-dt/tau_mem));
w = weight_scale / sqrt(n_inputs) * randn(rng, Float64, (n_inputs, n_outputs));

pars = Flux.Params([w])
loss, grads = Flux.withgradient(pars) do
    y_hat = run_spiking_NN(tau_mem, tau_syn, dt, inputs, w)
    Flux.logitcrossentropy(y_hat, y)
end

# the following line prints a value around 1e-13:
println(maximum(abs.(grads[w])))
```

The problem now is that the computed gradients w.r.t. the synaptic weights are all zero, which of course prevents any optimization from taking place. However, in a spiking neural network with no hidden layers, this should work fine: indeed, the same example implemented in PyTorch behaves as expected.
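
One plausible reading of the vanishing gradients, offered as a sketch rather than a confirmed diagnosis: the expression `(x-thresh) / sqrt((x-thresh)^2)` is the sign function, which is flat everywhere except at the threshold itself, so its slope is zero wherever it is defined and values around 1e-13 would be consistent with floating-point round-off. The helper `slope` below is a hypothetical finite-difference estimator introduced only for this check.

```julia
# Same smooth-looking step function as in the post above
heaviside(x::Number; thresh::Number=0.0) = 0.5 * (1 + (x - thresh) / sqrt((x - thresh)^2))

# Central finite difference, used here as an AD-free estimate of the slope
slope(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

for v in (0.2, 0.9, 1.1, 3.0)
    println("V = ", v,
            "  spike = ", heaviside(v; thresh=1.0),
            "  slope ≈ ", slope(u -> heaviside(u; thresh=1.0), v))
end
```

Since the loss depends on the weights only through this step function, a zero slope at every membrane potential leaves nothing for backpropagation to carry back to `w`.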

I’ve tried to follow @mcabbott’s suggestion and run `y,back = pullback(pars) do ... end`, but `back(1.0)` gives the same results as `Flux.withgradient`.

Any help would be much appreciated.

I think I’ve solved the problems I had with training a spiking neural network using Flux. The key is defining an appropriate adjoint for the spiking threshold function.

What I’ve come up with is the following (I’m posting it here in case someone runs into the same kind of problem):

```julia
using ChainRulesCore

# forward pass: hard threshold at zero
spike_fun(x::Number) = x > 0. ? 1. : 0.
spike_fun(x::AbstractArray{<:Number}) = spike_fun.(x)

# backward pass: replace the zero derivative with a smooth surrogate
function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:HasReverseMode},
                              ::typeof(spike_fun), x::Union{Number,AbstractArray{<:Number}})
    retval = spike_fun(x)
    pullback_spike_fun(y) = NoTangent(), y ./ (1.0 .+ 100.0 * abs.(x)).^2
    return retval, pullback_spike_fun
end
```

and the `for` loop in the `run_spiking_NN` function above should be replaced with

```julia
for t in 1 : n_steps
    reset = spike_fun(Vm_curr .- thresh)  # called on the whole array, not broadcast
    I_syn_next = alpha * I_syn_curr + I_inp[:, :, t]
    Vm_next = beta * Vm_curr + I_syn_curr - reset
    spikes_out += reset
    I_syn_curr = I_syn_next
    Vm_curr = Vm_next
end
```

where the crucial point is not to use broadcasting (i.e., `heaviside.` as I was doing before), otherwise the `rrule` is not called and the gradient cannot be computed.
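
As a quick sanity check that the custom rule is actually picked up, the sketch below differentiates a sum of spikes with Zygote. It repeats the definitions from the post so that it runs on its own; the test vector `v` and the comparison value `expected` are assumptions made for illustration.

```julia
using ChainRulesCore, Zygote

spike_fun(x::Number) = x > 0. ? 1. : 0.
spike_fun(x::AbstractArray{<:Number}) = spike_fun.(x)

function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:HasReverseMode},
                              ::typeof(spike_fun), x::Union{Number,AbstractArray{<:Number}})
    retval = spike_fun(x)
    pullback_spike_fun(y) = NoTangent(), y ./ (1.0 .+ 100.0 * abs.(x)).^2
    return retval, pullback_spike_fun
end

v = [-0.5, -0.01, 0.0, 0.01, 0.5]        # distances from threshold

# spike_fun is called on the whole array, so the rrule above applies
g, = Zygote.gradient(v -> sum(spike_fun(v)), v)

# surrogate derivative that the pullback is meant to return
expected = 1 ./ (1 .+ 100 .* abs.(v)) .^ 2
```

The surrogate is largest at the threshold and decays quickly on either side, so neurons close to firing receive most of the learning signal.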

Hope this helps somebody else.
