Hi all,
I’m trying to define a custom Layer for Flux:
```julia
using Flux

struct CANCell
    topology::Symbol
    n::NTuple{D,Int} where {D}  # number of neurons in each dimension
    d::Int                      # number of dimensions
    I::Vector{Tuple}            # index (i,j,...) of each neuron in the lattice
    Z                           # gpu(Matrix)
    W                           # gpu(Matrix)
    ∂ᵢW                         # gpu.(::Vector{Matrix})
    b₀::Float32
    τ::Float32
    state0
    β                           # gpu(::Vector{Float32})
end
```
where `β` is the only trainable parameter.
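Just to give a sense of the field layout, here is roughly how an instance gets built. Everything below is a placeholder (the real `W`, `∂ᵢW` and `Z` come out of the lattice construction); it’s only meant to show the shapes:

```julia
# Placeholder constructor, only to illustrate the shapes of the fields.
# `n` is the lattice size, e.g. (16, 16), so N = prod(n) neurons in total.
function toy_cancell(n::NTuple{D,Int}; τ = 10.0f0, b₀ = 1.0f0) where {D}
    N = prod(n)
    I = vec([Tuple(c) for c in CartesianIndices(n)])  # lattice index of each neuron
    Z = gpu(randn(Float32, D, N))                     # placeholder readout matrix
    W = gpu(randn(Float32, N, N))                     # placeholder recurrent weights
    ∂ᵢW = [gpu(randn(Float32, N, N)) for _ in 1:D]    # one matrix per lattice dimension
    state0 = gpu(zeros(Float32, N, 1))
    β = gpu(ones(Float32, N))                         # the only parameter I want to train
    CANCell(:placeholder, n, D, I, Z, W, ∂ᵢW, b₀, τ, state0, β)
end
```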
It’s a recurrent model, so I define a forward pass:
```julia
function (can::CANCell)(h, v::AbstractVecOrMat)
    W, b₀, τ, ∂ᵢW = can.W, can.b₀, can.τ, can.∂ᵢW
    σ = Flux.NNlib.fast_act(relu, v)
    if size(v, 2) != size(h, 2)
        # reshape h to match the batch size of v
        rs = gpu(zeros(Float32, size(h, 1), size(v, 2)))
        h = h .+ rs
    end
    dₜh = muladd(W, h, b₀)
    idx = argmax.(eachcol(h))
    for d in 1:can.d
        ∂W = can.β .* (∂ᵢW[d][:, idx] .* v[d, :]')
        dₜh = dₜh + ∂W
    end
    h = h + (σ.(dₜh) .- h)/τ
    idx = argmax.(eachcol(h))
    y = can.Z[:, idx]
    return h, reshape_cell_output(y, v)
end
```
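A single step then looks like this (using the placeholder constructor above; `v` has one row per lattice dimension and one column per batch sample):

```julia
cell = toy_cancell((16, 16))    # 256 neurons on a 16×16 lattice
v = gpu(randn(Float32, 2, 8))   # d × batch input
h, y = cell(cell.state0, v)     # new 256×8 hidden state and the readout for each column
```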
Then I use `@functor` and `Recur`:
```julia
@functor CANCell
Flux.trainable(c::CANCell) = (; β = c.β)  # parameters to train
CAN(args...; kwargs...) = Recur(CANCell(args...))
Recur(m::CANCell) = Recur(m, m.state0)
```
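With that I’d expect only `β` to show up as a trainable parameter:

```julia
cell = toy_cancell((8, 8))
can  = Recur(cell)          # uses the state0 method above

Flux.trainable(can.cell)    # (β = Float32[…],), i.e. just β
Flux.params(can)            # should collect just that one β vector
```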
Finally, the whole thing is a layer of a larger model:
```julia
Chain(
    input = Chain(
        Dense(N, nh),
        Dense(nh, nh, tanh),
    ),
    recur = Chain(
        RNN(nh, d, tanh),
        Dense(d, nh),
        Dense(nh, nh),
        Dense(nh, size(can.cell.W, 1)),
        can,
    ),
)
```
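For context, this is roughly how I compute and inspect the gradients (`model` stands for the Chain above; `xs`, `ys` and the `mse` loss are stand-ins for my actual data and objective):

```julia
# Sketch of the gradient check; the real training loop is more involved.
loss(m, xs, ys) = sum(Flux.Losses.mse(m(x), y) for (x, y) in zip(xs, ys))

grads = Flux.gradient(m -> loss(m, xs, ys), model)[1]
grads.layers.recur.layers[end].cell.β   # the entry I actually care about
```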
This doesn’t train correctly though. If I print the gradients during training I get:

```
(layers = (input = nothing, recur = (layers = (nothing, nothing, nothing, nothing,
(cell = (topology = nothing, n = nothing, d = nothing, I = nothing, Z = Float32[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2.6447595f-5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], W = nothing, ∂ᵢW = nothing, b₀ = nothing, τ = nothing, state0 = nothing, β = nothing), state = nothing)),)),)
```

so `nothing` for all layers/params except for `Z`, which should not even be trainable. Also the `input` layer is a single `nothing` instead of a tuple with the gradient of each layer. I’m very confused.
I know the problem is with the layer definition and not with the rest of my training code, because if I use Flux’s default layers everything trains correctly. I have also used other custom layers in which all numeric parameters were trainable, and those trained fine as well.
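(By “Flux default layers” I mean a control like the one below, where `can` is swapped for a stand-in `Dense`; with that, every entry of the printed gradient is populated.)

```julia
# Control experiment: same recur Chain, but with the custom cell replaced by
# a built-in layer of matching input size.
recur = Chain(
    RNN(nh, d, tanh),
    Dense(d, nh),
    Dense(nh, nh),
    Dense(nh, size(can.cell.W, 1)),
    Dense(size(can.cell.W, 1), nh),   # stand-in for `can`
)
```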
It seems like defining `trainable` breaks gradient propagation, but I can’t figure out why. Using `@functor CANCell (β,)` has the same effect. Any idea what could be the problem?