Hello,

I am trying to use Flux to train a simple spiking neural network to perform input classification. The network has 100 inputs representing incoming spike streams: in half of the trials, one half of the inputs is “active” (i.e., it receives spikes at a rate of 10 spikes/s), while in the other half of the trials the remaining inputs are active. The network has two output neurons that should learn to distinguish between the two input patterns. This can be done either by having the output neurons emit spikes, or by having them integrate their spiking inputs so that the mean membrane potential of each neuron codes for the type of input pattern received. Since I’m having trouble computing the gradients in the spiking case, I resorted to the latter option, which effectively removes any nonlinearity from the network.
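For concreteness, the input spike trains can be generated along these lines (a simplified sketch of my setup; the choice of which 50 inputs are active per trial, the trial length `n_steps`, and the Bernoulli sampling per time bin are illustrative, not the exact code from my script):

```julia
n_inputs = 100
n_outputs = 2
n_steps = 200          # e.g., 200 ms of simulation at dt = 1 ms
dt = 1e-3              # [s]
rate = 10.0            # [spikes/s] for the active inputs
batch_size = 32

# probability of a spike in each time bin for a Poisson-like input
p_spike = rate * dt

inputs = zeros(batch_size, n_inputs, n_steps)
labels = Vector{Int}(undef, batch_size)
for trial in 1:batch_size
    # hypothetically: odd trials activate the first 50 inputs, even trials the last 50
    labels[trial] = isodd(trial) ? 1 : 2
    active = labels[trial] == 1 ? (1:50) : (51:100)
    inputs[trial, active, :] .= rand(length(active), n_steps) .< p_spike
end
```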

Even with this simplified network configuration, I run into an error when trying to compute the gradients of the synaptic weights. The relevant piece of code where I try to do this is the following:

```julia
using Flux
using Einsum   # provides @einsum

# neuron parameters
tau_mem = 10e-3   # membrane time constant [s]
tau_syn = 5e-3    # synaptic time constant [s]
# network parameters
n_inputs = 100
n_outputs = 2
# simulation parameters
dt = 1e-3         # time step [s]
batch_size = 32
# simulation constants
alpha = exp(-dt / tau_syn)
beta = exp(-dt / tau_mem)
# simulation variables
I_syn_curr = zeros(batch_size, n_outputs)
Vm_curr = zeros(size(I_syn_curr))
Vm_acc = zeros(size(Vm_curr))

# `inputs`, `w`, `y`, and `n_steps` are defined earlier in the script (see below)
pars = Flux.Params([w])
loss, grads = Flux.withgradient(pars) do
    # project the input spikes through the weights: (batch, outputs, steps)
    @einsum I_inp[a, c, d] := inputs[a, b, d] * w[b, c]
    for t in 1:n_steps
        Vm_acc = Vm_acc + Vm_curr
        I_syn_next = alpha * I_syn_curr + I_inp[:, :, t]
        Vm_next = beta * Vm_curr + I_syn_curr
        I_syn_curr = I_syn_next
        Vm_curr = Vm_next
    end
    y_hat = Vm_acc ./ n_steps
    Flux.logitcrossentropy(y_hat, y)
end
```
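To be explicit, the loop implements the standard discretized current-based leaky-integrator dynamics (with no spiking or reset, as explained above):

$$
\begin{aligned}
I_{\text{syn}}[t+1] &= \alpha\, I_{\text{syn}}[t] + I_{\text{inp}}[t], \qquad \alpha = e^{-\Delta t/\tau_{\text{syn}}},\\
V_m[t+1] &= \beta\, V_m[t] + I_{\text{syn}}[t], \qquad \beta = e^{-\Delta t/\tau_{\text{mem}}},
\end{aligned}
$$

and the prediction is the membrane potential averaged over the `n_steps` time steps.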

`inputs` is a binary array of shape (`batch_size`, `n_inputs`, `n_steps`), with ones marking the time steps at which an input spike occurred; `w` is the `n_inputs` × `n_outputs` matrix of synaptic weights; and `y` is a one-hot matrix encoding the desired output of the two neurons.
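For completeness, `y` is built along these lines (a sketch; the alternating trial labels are a stand-in for the actual trial structure described above):

```julia
batch_size = 32
n_outputs = 2
labels = [isodd(trial) ? 1 : 2 for trial in 1:batch_size]  # hypothetical per-trial labels

# one-hot encode: y[trial, class] = 1 when `class` is the target for that trial,
# so that y has the same (batch_size, n_outputs) layout as y_hat
y = zeros(batch_size, n_outputs)
for (trial, label) in enumerate(labels)
    y[trial, label] = 1.0
end
```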

When running the code, I get the following error:

```
ERROR: LoadError: MethodError: no method matching +(::IRTools.Inner.Undefined, ::IRTools.Inner.Undefined)
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...)
@ Base operators.jl:578
+(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any)
@ InitialValues ~/.julia/packages/InitialValues/OWP8V/src/InitialValues.jl:154
+(::ChainRulesCore.Tangent{P}, ::P) where P
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:146
...
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0 [inlined]
[2] _pullback(::Zygote.Context{true}, ::typeof(+), ::IRTools.Inner.Undefined, ::IRTools.Inner.Undefined)
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:9
[3] _pullback
@ ~/Downloads/Classification_SNN_Julia.jl:120 [inlined]
[4] _pullback(::Zygote.Context{true}, ::var"#8#9")
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
[5] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Matrix{Float64}, Vector{Matrix{Float64}}}})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:384
[6] withgradient(f::Function, args::Zygote.Params{Zygote.Buffer{Matrix{Float64}, Vector{Matrix{Float64}}}})
@ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:132
[7] top-level scope
@ ~/Downloads/Classification_SNN_Julia.jl:116
```

I should also add that I can successfully train the same network (both the spiking and non-spiking versions) in PyTorch.

Thanks to anyone who might shed some light on what I’m doing wrong.