Problem with training a spiking neural network with Flux

Hello,

I am trying to use Flux to train a simple spiking neural network to perform input classification. The network has 100 inputs that represent incoming spike streams: in half of the trials, the first half of the inputs is "active" (i.e., it receives spikes at a rate of 10 spikes/s), while in the other half of the trials the remaining inputs are active. The network contains two output neurons, which should learn to distinguish between the two input patterns: this can be done either by having the output neurons emit spikes or by having them integrate their spiking inputs so that the mean membrane potential of each neuron codes for the type of input pattern received. Since I'm having trouble computing the gradients in the spiking case, I resorted to the latter option, effectively removing any nonlinearity from the network.

Even with this simplified network configuration, I run into an error when trying to compute the gradients of the synaptic weights. The relevant piece of code where I try to do this is the following:

using Flux
using Einsum   # provides the @einsum macro

# neuron parameters
tau_mem     = 10e-3 # [s]
tau_syn     = 5e-3  # [s]
# network parameters
n_inputs    = 100
n_outputs   = 2
# simulation parameters
dt          = 1e-3  # [s]
batch_size  = 32
# simulation constants
alpha       = exp(-dt/tau_syn)
beta        = exp(-dt/tau_mem)
# simulation variables
I_syn_curr  = zeros(batch_size, n_outputs)
Vm_curr     = zeros(size(I_syn_curr))
Vm_acc      = zeros(size(Vm_curr))

pars = Flux.Params([w])
loss,grads = Flux.withgradient(pars) do

    @einsum I_inp[a,c,d] := inputs[a,b,d] * w[b,c]
    for t in 1 : n_steps
        Vm_acc = Vm_acc + Vm_curr
        I_syn_next = alpha * I_syn_curr + I_inp[:, :, t]
        Vm_next = beta * Vm_curr + I_syn_curr
        I_syn_curr = I_syn_next
        Vm_curr = Vm_next
    end
    y_hat = Vm_acc ./ n_steps
    Flux.logitcrossentropy(y_hat, y)

end

inputs is a binary array of shape (batch_size, n_inputs, n_steps), with ones marking the time steps at which an input spike occurred; w is the n_inputs × n_outputs weight matrix; and y is a one-hot matrix encoding the desired output of the two neurons.
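
For context, the remaining variables (n_steps, inputs, w, y) are constructed roughly along these lines (a simplified sketch, not the exact code from my script; the names input_rate, rng and truth are only introduced here for illustration):

using Random

input_rate = 10.0                             # [spikes/s]
n_steps    = 200                              # 0.2 s of simulation at dt = 1e-3 s
rng        = MersenneTwister(0)

truth = rand(rng, 1:2, batch_size)            # class label of each trial
y = Flux.onehotbatch(truth, 1:2)'             # batch_size x n_outputs one-hot matrix

# binary spike trains: the first half of the inputs is active in class-1 trials,
# the second half in class-2 trials
inputs = zeros(batch_size, n_inputs, n_steps)
for b in 1:batch_size
    active = truth[b] == 1 ? (1:n_inputs÷2) : (n_inputs÷2+1:n_inputs)
    inputs[b, active, :] .= rand(rng, length(active), n_steps) .< input_rate * dt
end

w = randn(rng, n_inputs, n_outputs) ./ sqrt(n_inputs)   # initial synaptic weights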

When running the code, I get the following error:

ERROR: LoadError: MethodError: no method matching +(::IRTools.Inner.Undefined, ::IRTools.Inner.Undefined)

Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...)
   @ Base operators.jl:578
  +(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any)
   @ InitialValues ~/.julia/packages/InitialValues/OWP8V/src/InitialValues.jl:154
  +(::ChainRulesCore.Tangent{P}, ::P) where P
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:146
  ...

Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0 [inlined]
 [2] _pullback(::Zygote.Context{true}, ::typeof(+), ::IRTools.Inner.Undefined, ::IRTools.Inner.Undefined)
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:9
 [3] _pullback
   @ ~/Downloads/Classification_SNN_Julia.jl:120 [inlined]
 [4] _pullback(::Zygote.Context{true}, ::var"#8#9")
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [5] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Matrix{Float64}, Vector{Matrix{Float64}}}})
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:384
 [6] withgradient(f::Function, args::Zygote.Params{Zygote.Buffer{Matrix{Float64}, Vector{Matrix{Float64}}}})
   @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:132
 [7] top-level scope
   @ ~/Downloads/Classification_SNN_Julia.jl:116

I should also add that I can successfully train the same network (both spiking and non-spiking) using PyTorch.

Thanks to anyone who might shed some light on what I’m doing wrong.

Not very experienced with Flux here, but what that stacktrace is telling you is that there is no + (addition) method defined for two objects of type IRTools.Inner.Undefined.

You seem to be summing arrays here:

I_syn_next = alpha * I_syn_curr + I_inp[:, :, t]
Vm_next = beta * Vm_curr + I_syn_curr

Have you tried running a timestep of this outside the Flux block to replicate going through those lines yourself?
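
Something along these lines should exercise a single timestep eagerly, outside any gradient call (a sketch that assumes the packages and variables from your snippet are loaded/defined):

# assumes Einsum, inputs, w, alpha, beta, I_syn_curr and Vm_curr from your snippet
@einsum I_inp[a,c,d] := inputs[a,b,d] * w[b,c]
I_syn_next = alpha * I_syn_curr + I_inp[:, :, 1]   # one timestep (t = 1)
Vm_next    = beta * Vm_curr + I_syn_curr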

Yes, if I comment out the lines

# pars = Flux.Params([w])
# loss,grads = Flux.withgradient(pars) do
...
# end

then everything works fine.

I’m guessing it has to do with the backward part of the gradient computation (i.e., the call to _pullback in Zygote), but I can’t figure out whether I’m doing something wrong or if this is a limitation of Zygote and/or ChainRules.

Can you edit your example to define all variables?

I don’t think @einsum will be Zygote-differentiable; you could try @tullio instead.

But there are probably other problems. There are many global variables, and it’s possible these are upsetting Zygote.

"gradient computation (i.e., the call to _pullback in Zygote)"

Note that Zygote._pullback might be the forward pass, within Zygote. You can call y, back = pullback(f, x...) to run only the forward pass, and back(1.0) for the gradient.
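
For example, with a toy function (not your network):

using Zygote

f(x) = sum(abs2, x)
y, back = Zygote.pullback(f, [1.0, 2.0, 3.0])  # forward pass only; y == 14.0
grads = back(1.0)                              # backward pass; grads == ([2.0, 4.0, 6.0],)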


Is all this code running at the top level? Can you put it inside a function and see if that makes a difference (please update the MWE if so)?

OK, I think I might have found a solution. It looks like the problem did indeed have to do with @einsum: I've replaced that call with the equivalent @tullio and things seem to be working fine now.

The way I’m doing it now, I first define a function that simulates the network of neurons and returns the mean membrane potential for each trial:

function run_leaky_NN(tau_mem::Number, tau_syn::Number, dt::Number, spikes_in::Array, w::Matrix)
    alpha = exp(-dt / tau_syn)
    beta  = exp(-dt / tau_mem)
    n_batches, n_spikes_in, n_steps = size(spikes_in)
    n_outputs = size(w, 2)
    I_syn_curr  = zeros(n_batches, n_outputs)
    Vm_curr     = zeros(size(I_syn_curr))
    Vm_acc      = zeros(size(Vm_curr))
    # project the input spikes through the weights for all time steps at once
    @tullio I_inp[a,c,d] := spikes_in[a,b,d] * w[b,c]
    for t in 1 : n_steps
        # leaky integration of synaptic current and membrane potential
        I_syn_next = alpha * I_syn_curr + I_inp[:, :, t]
        Vm_next = beta * Vm_curr + I_syn_curr
        Vm_acc = Vm_acc + Vm_curr
        I_syn_curr = I_syn_next
        Vm_curr = Vm_next
    end
    Vm_acc ./ n_steps   # mean membrane potential per trial and output neuron
end;

I can then compute the gradient by running

pars = Flux.Params([w])
l,grads = Flux.withgradient(pars) do
    y_hat = run_leaky_NN(tau_mem, tau_syn, dt, inputs, w)
    Flux.logitcrossentropy(y_hat, y)
end

And tune the weights by running

pars = Flux.Params([w])
optimizer = Adam(2e-3, (0.9, 0.999)) # Adam optimizer from Flux
for epoch in 1:100
    loss,grads = Flux.withgradient(pars) do
        y_hat = run_leaky_NN(tau_mem, tau_syn, dt, inputs, w)
        Flux.logitcrossentropy(y_hat, y)
    end
    Flux.Optimise.update!(optimizer, pars, grads)
end

which indeed leads to a network that behaves as expected.
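
As a quick sanity check, I compare the network's predictions with the labels, along these lines (a sketch; truth here stands for the vector of integer class labels, 1 or 2, of the trials):

y_hat = run_leaky_NN(tau_mem, tau_syn, dt, inputs, w)
pred  = Flux.onecold(y_hat', 1:2)              # predicted class of each trial
println("Accuracy: ", sum(pred .== truth) / length(truth))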

I’m now trying to include a spiking nonlinearity in the neurons and therefore have my network classify patterns based on the number of spikes, but that does not seem to be working. I’ll post a MWE of this second scenario here in a bit.

Thanks everyone for your help!


I have modified my code to run a spiking neural network, i.e., the cells now have a threshold and reset, represented by the Heaviside function in the code below:

using Random
using Tullio
using Flux

# step function: 1 for x > thresh, 0 for x < thresh (NaN at x == thresh)
heaviside(x::Number; thresh::Number=0.) = 0.5 * (1 + (x-thresh) / sqrt((x-thresh)^2));

function run_spiking_NN(tau_mem::Number, tau_syn::Number, dt::Number, spikes_in::Array, w::Matrix, thresh::Number=1.)
    alpha = exp(-dt / tau_syn)
    beta  = exp(-dt / tau_mem)
    n_batches, n_spikes_in, n_steps = size(spikes_in)
    n_outputs = size(w, 2)
    I_syn_curr  = zeros(n_batches, n_outputs)
    Vm_curr     = zeros(size(I_syn_curr))
    spikes_out  = zeros(size(Vm_curr))
    @tullio I_inp[a,c,d] := spikes_in[a,b,d] * w[b,c]
    for t in 1 : n_steps
        reset = heaviside.(Vm_curr, thresh=thresh)
        I_syn_next = alpha * I_syn_curr + I_inp[:, :, t]
        Vm_next = beta * Vm_curr + I_syn_curr - reset
        spikes_out += reset
        I_syn_curr = I_syn_next
        Vm_curr = Vm_next
    end
    spikes_out
end;

# neuron parameters
tau_mem    = 10e-3  # [s]
tau_syn    = 5e-3   # [s]
# network parameters
input_rate = 10     # [Hz]
n_inputs   = 20
n_outputs  = 2
# simulation parameters
tend       = 0.2    # [s]
dt         = 1e-3   # [s]
n_steps    = Int(tend / dt)

batch_size = 32;
rng = MersenneTwister(100);

inputs = zeros(Int, batch_size, n_inputs, n_steps)
prob = rand(rng, Float64, size(inputs))
inputs[prob .<= input_rate*dt] .= 1
println("Total number of input spikes: ", sum(inputs))

truth = ones(Int, batch_size)
truth[rand(rng, Float64, size(truth)) .> 0.5] .= 2
classes = unique(truth)
y = Flux.onehotbatch(truth, classes)';

weight_scale = 7 * (1 - exp(-dt/tau_mem));
w = weight_scale / sqrt(n_inputs) * randn(rng, Float64, (n_inputs, n_outputs));

pars = Flux.Params([w])
l,grads = Flux.withgradient(pars) do
    y_hat = run_spiking_NN(tau_mem, tau_syn, dt, inputs, w)
    Flux.logitcrossentropy(y_hat, y)
end

# the following line prints a value around 1e-13:
println("Max abs gradient: ", maximum(abs.(grads[w])))

The problem now is that the computed gradients w.r.t. the synaptic weights are essentially zero (around 1e-13, as noted above), which of course prevents any optimization from taking place. However, in a spiking neural network with no hidden layers, this should work fine: indeed, the same example implemented in PyTorch behaves as expected.
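
The exact derivative of the hard threshold is already zero away from the threshold, as a quick check of the heaviside function on its own shows (toy example):

using Zygote
# the function is flat on either side of the threshold, so only rounding noise survives
Zygote.gradient(v -> heaviside(v; thresh=1.0), 1.5)   # ≈ (0.0,) up to rounding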

I’ve tried to follow @mcabbott’s suggestion and run y,back = pullback(pars) do ... end, but back(1.0) gives the same results as Flux.withgradient.

Any help would be much appreciated.

I think I’ve solved the problems I had with training a spiking neural network using Flux. The key is defining an appropriate adjoint for the spiking threshold function.

What I’ve come up with is the following (I’m posting it here in case someone were to run in the same kind of problem):

using ChainRulesCore

# hard threshold (spike) function
spike_fun(x::Number) = x > 0. ? 1. : 0.
spike_fun(x::AbstractArray{<:Number}) = spike_fun.(x)

# custom rrule: on the backward pass, replace the (zero almost everywhere) derivative
# of the hard threshold with a smooth surrogate gradient
function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:HasReverseMode},
        ::typeof(spike_fun), x::Union{Number,AbstractArray{<:Number}})
    retval = spike_fun(x)
    pullback_spike_fun(y) = NoTangent(), y ./ (1.0 .+ 100.0 * abs.(x)).^2
    return retval, pullback_spike_fun
end

and the for loop in the run_spiking_NN function above should be replaced with

for t in 1 : n_steps
    reset = spike_fun(Vm_curr .- thresh)
    I_syn_next = alpha * I_syn_curr + I_inp[:, :, t]
    Vm_next = beta * Vm_curr + I_syn_curr - reset
    spikes_out += reset
    I_syn_curr = I_syn_next
    Vm_curr = Vm_next
end

where the crucial point is not to use broadcasting (i.e., not spike_fun.(...), the way I was broadcasting heaviside. before); otherwise the rrule is not called and no useful gradient comes out.
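
A toy check makes the difference visible (a sketch, using the spike_fun and rrule defined above):

using Zygote
v = [0.5, 1.5]
# calling the array method goes through the rrule: nonzero surrogate gradient
Zygote.gradient(x -> sum(spike_fun(x .- 1.0)), v)[1]
# broadcasting the scalar method bypasses the rrule, so the surrogate is never used
Zygote.gradient(x -> sum(spike_fun.(x .- 1.0)), v)[1]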

Hope this helps somebody else.
