Flux: `nothing` gradients with custom layer

Hi all,

I’m trying to define a custom Layer for Flux:

struct CANCell
    n::NTuple{D,Int} where {D}         # number of neurons in each dimension
    d::Int                             # number of dimensions
    I::Vector{Tuple}                   # index (i,j...) of each neuron in the lattice
    Z               # gpu(Matrix)
    W               # gpu(Matrix)
    ∂ᵢW             # gpu.(::Vector{Matrix})
    β               # gpu(::Vector{Float32})
    b₀              # bias used in the forward pass
    τ               # time constant used in the forward pass
    state0          # initial hidden state, used by Recur
end

where β is the only trainable parameter.

It’s a recurrent model so I define a forward pass:

function (can::CANCell)(h, v::AbstractVecOrMat)
    W, b₀, τ, ∂ᵢW = can.W, can.b₀, can.τ, can.∂ᵢW
    σ = Flux.NNlib.fast_act(relu, v)

    if size(v, 2) != size(h, 2)
        # broadcast h up to the batch size of v
        rs = gpu(zeros(Float32, size(h, 1), size(v, 2)))
        h = h .+ rs
    end

    dₜh = muladd(W, h, b₀)

    idx = argmax.(eachcol(h))
    for d in 1:can.d
        ∂W = can.β .* (∂ᵢW[d][:, idx] .* v[d, :]')
        dₜh = dₜh + ∂W
    end

    h = h + (σ.(dₜh) .- h) / τ

    idx = argmax.(eachcol(h))
    y = can.Z[:, idx]
    return h, reshape_cell_output(y, v)
end

and then use @functor and Recur:

@functor CANCell
Flux.trainable(c::CANCell) = (; β = c.β)  # parameters to train

CAN(args...; kwargs...) = Recur(CANCell(args...; kwargs...))

Recur(m::CANCell) = Recur(m, m.state0)
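For context, this is the standard stateful-cell pattern: a cell maps (state, input) to (new state, output), and Recur carries the state between calls. A minimal sketch of the same pattern (MinCell and all sizes are made up, not from this thread; assumes Flux's pre-0.15 Recur API):

```julia
using Flux

# Illustrative minimal cell following the Recur pattern.
struct MinCell
    W::Matrix{Float32}
    state0::Vector{Float32}
end
Flux.@functor MinCell

# A cell maps (state, input) -> (new state, output).
function (c::MinCell)(h, x::AbstractVecOrMat)
    h′ = tanh.(c.W * x .+ h)
    return h′, h′
end

cell = MinCell(randn(Float32, 4, 4), zeros(Float32, 4))
rnn  = Flux.Recur(cell, cell.state0)
y = rnn(rand(Float32, 4))   # stateful call; state is carried across calls
```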

Finally, the whole thing is a layer of a larger model:

        input = Chain(
            Dense(N, nh),
            Dense(nh, nh, tanh),
        )
        recur = Chain(
            RNN(nh, d, tanh),
            Dense(d, nh),
            Dense(nh, nh),
            Dense(nh, size(can.cell.W, 1)),
            can,   # the Recur{CANCell} layer
        )

This doesn’t train correctly though. If I print the gradients during training I get:

(layers = (input = nothing, recur = (layers = (nothing, nothing, nothing, nothing, 
(cell = (topology = nothing, n = nothing, d = nothing, I = nothing, Z = Float32[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2.6447595f-5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], W = nothing, ∂ᵢW = nothing, b₀ = nothing, τ = nothing, state0 = nothing, β = nothing), state = nothing)),)),)

so I get nothing for all layers/params except for Z, which should not even be trainable. Also, the input component's gradient is a single nothing instead of a tuple with one gradient per layer. I'm very confused.

I know that the problem is with the layer definition and not the rest of my training code because if I use Flux default layers it all trains correctly. I also used other custom layers where all numeric parameters were trainable and that was also fine.

It seems like defining trainable breaks gradient propagation but I can’t figure out why. Using @functor CANCell (β,) has the same effect. Any idea what could be the problem?

The definition of trainable is redundant if @functor is already used, but without an MWE including the loss and the gradient call it's difficult to say what, if anything, could be at issue.

Thanks for getting back to me.

Removing Flux.trainable and using only @functor CANCell (β,) doesn't change anything.
The loss function is just Flux.mse(m(x), y) and the training loop is standard too. The same code and the same data works perfectly if I don’t use my custom layer (e.g. if I just replace it with an RNN). So the problem should be with the layer definition.

Neither of those should work out of the box for RNNs because the input handling is different. That's why a runnable MWE is important, because you've surely made other changes to get them working 🙂

No I get it but the whole thing is part of a bigger code base so it’s hard to get a MWE out of it.

But this model works:

        input = Dense(N, nh),
        recur = Chain(
            Dense(nh, nh, tanh),
            Dense(nh, 64),
            RNN(64, 1),
        )

with all the gradients etc:

(layers = (input = (weight = Float32[-0.010323494 0.009741096 -0.004862191 -0.0014745682 -0.0041073635 0.00091178465 -0.006024482 0.0027308932 -0.002870893 -0.0006800644 -0.00555486 0.021792755 -0.0030689763 -0.006078959 -0.002313804 0.0011246824 0.0019230729 -0.00074779923 -0.017342443 -0.008566279 0.0034860787 0.00018054058....

But this:

        input = Dense(N, nh),
        recur = Chain(
            Dense(nh, nh, tanh),
            Dense(nh, 2),
            can,   # the Recur{CANCell} layer
        )

does not give the right gradient:

(layers = (input = nothing, recur = (layers = (nothing, nothing, (cell = (topology = nothing, n = nothing, d = nothing, I = nothing, Z = Float32[-0.009424061 0.0 0.

What I find particularly confusing is that the input component’s gradient is nothing.

Notice that the only difference is the last layer in recur. They are both essentially RNNs just that my custom layer has some non-trainable parameters.

The data is passed as batch where each x is a vector of matrices as described in the docs.
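(For reference, the layout I mean is the one from the Flux recurrence docs: a sequence of timesteps, each a features × batch matrix. Sizes below are made up:)

```julia
using Flux

# A sequence of T timesteps, each a (features × batch) matrix.
T, nfeat, batch = 5, 3, 8
xs = [rand(Float32, nfeat, batch) for _ in 1:T]

m = RNN(nfeat, 4)
Flux.reset!(m)                # reset hidden state before each sequence
ys = [m(x) for x in xs]       # one (4 × batch) output per timestep
```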

    function dobatch(b)
        l, grads = Flux.withgradient(model) do m
            length(b) == 1 ? loss(m, b) : loss(m, b...)
        end
        any(isnan.(l)) && error("NaN loss during training")
        grads = grads isa Tuple ? grads[1] : grads

        update!(opt_state, model, grads)  # model params update
    end

The total loss for a batch is the sum of Flux.mse over each sample.

Placeholder data is always an option. Usually code for a single batch’s worth of training is enough and you’ve posted maybe 70% of that, so filling in the gaps (e.g. what is can and how is it defined?) would be enough.

Note that @functor and trainable don’t actually affect what gradients you get back, so those can be ruled out immediately.
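You can check this directly: Zygote returns gradients for every struct field regardless of trainable, which only tells the optimiser which fields to update. A minimal sketch with a made-up Affine layer (not from this thread):

```julia
using Flux

# Two-field layer where only `b` is marked trainable.
struct Affine
    W::Matrix{Float32}
    b::Vector{Float32}
end
Flux.@functor Affine
Flux.trainable(a::Affine) = (; b = a.b)

(a::Affine)(x) = a.W * x .+ a.b

layer = Affine(randn(Float32, 3, 2), zeros(Float32, 3))
x = randn(Float32, 2)

# Zygote still differentiates through every field...
grads = Flux.gradient(m -> sum(m(x)), layer)[1]
@assert grads.W !== nothing && grads.b !== nothing

# ...trainable only restricts what the optimiser touches.
opt_state = Flux.setup(Adam(), layer)
```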

Some more things to try:

  • Reduce the example from a Chain with a bunch of Dense layers to a single RNN vs a single Recur{CANCell}.
  • Drop the Recur part and see if you get gradients when manually differentiating through a single timestep with a CANCell as your model. If that works, reintroduce the Recur.
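The second check might look like this (sketch only; a stock RNNCell stands in for CANCell, and all data is random):

```julia
using Flux

# Differentiate a single timestep of a bare cell, bypassing Recur.
cell = Flux.RNNCell(3, 4)
h0 = zeros(Float32, 4, 1)
x  = rand(Float32, 3, 1)
yt = rand(Float32, 4, 1)

l, grads = Flux.withgradient(cell) do c
    h, y = c(h0, x)
    Flux.mse(y, yt)
end
grads[1]   # inspect which fields come back as `nothing`
```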


After a bit more investigation, it looks like the problem was with using argmax, which is not differentiable.
Changing the forward pass of the layer to use softmax instead gives correct gradients.
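For anyone hitting the same issue, the difference can be reproduced in isolation (made-up sizes, not the actual layer):

```julia
using Flux

Z = rand(Float32, 6, 10)
h = rand(Float32, 10)

# Hard read-out: argmax is piecewise constant, so Zygote sees no
# dependence of the loss on h and returns `nothing`.
hard(h) = sum(Z[:, argmax(h)])
Flux.gradient(hard, h)[1]            # -> nothing

# Soft read-out: a softmax-weighted average is differentiable.
soft(h) = sum(Z * Flux.softmax(h))
Flux.gradient(soft, h)[1]            # -> a real gradient vector
```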
Thanks for helping out!
