Custom Flux.jl Layer Not Updating (problem with Flux.trainable?)

I am having a problem getting a custom layer to update, and have tracked the issue to an apparent disconnect between Flux.trainable and Flux.setup. I read through related-sounding posts and issues, which generally seem to be cases of "model won't learn because the gradient is zero"; my issue, however, isn't just that the layer won't learn, but that Flux.setup explicitly warns that the layer has no trainable parameters.

Here is my minimal working example:

using Flux

# trivial layer
mutable struct MWE
	a::Float32
end
(m::MWE)(x) = m.a * x
Flux.@layer MWE

# the layer reports exactly 1 trainable parameter, `a`
testlayer = MWE(rand(Float32))
@assert Flux.trainable(testlayer) == (; a=testlayer.a)

# yet setup warns that the layer has NO trainable parameters!
Flux.setup(Flux.Adam(), testlayer)

So, trainable can see the parameter a, but setup somehow misses it. If this layer is used in a larger Chain, the network will learn, but the MWE layer’s parameter will not change.
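
To make that Chain symptom concrete, here is a sketch (the Dense layer, toy data, and training call are just illustrative):

model = Chain(Dense(1 => 1), MWE(rand(Float32)))
opt_state = Flux.setup(Flux.Adam(), model)  # no warning here, since Dense has parameters

x, y = Float32[1;;], Float32[2;;]  # toy 1×1 input and target
a_before = model[2].a
Flux.train!((m, x, y) -> Flux.mse(m(x), y), model, [(x, y)], opt_state)
@assert model[2].a == a_before  # the MWE parameter never moves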

I can work around this in my actual use case, but it is concerning and perplexing; can anyone point out what's missing to get this layer to update?

Thanks!

Classic. The MOMENT I posted the question, I realized the issue.

The parameter a needs to be stored as an array, even if it holds just a single element. That's literally it. Flux's explicit-mode optimisers update parameters in place, so setup only collects mutable numeric arrays as trainable leaves; a bare Float32 field gets skipped.
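
For anyone who finds this later, a minimal sketch of the fix (the MWEFixed name and the one-step training check are just illustrative):

using Flux

# fixed layer: `a` is a one-element array, which Flux.setup
# recognizes as a trainable parameter and can mutate in place
mutable struct MWEFixed
	a::Vector{Float32}
end
(m::MWEFixed)(x) = m.a[1] * x
Flux.@layer MWEFixed

testlayer = MWEFixed([rand(Float32)])
opt_state = Flux.setup(Flux.Adam(), testlayer)  # no warning this time

# one gradient step to confirm the parameter actually moves
a_before = testlayer.a[1]
grads = Flux.gradient(m -> (m(2f0) - 1f0)^2, testlayer)
Flux.update!(opt_state, testlayer, grads[1])
@assert testlayer.a[1] != a_before

Note that with the array field the struct no longer even needs to be mutable, since update! modifies the array's contents rather than rebinding the field.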

I’ll leave this post in the hopes that others can learn from my mild embarrassment.


I wasted a lot of time on this problem and was about to open a new post about this very same issue. Thanks for posting the solution.
Updating the documentation to state this, using your example, might be helpful. If I have time, I will open a pull request to update the Flux documentation.
