I am having a problem getting a custom layer to update, and have tracked the issue down to an apparent disconnect between `Flux.trainable` and `Flux.setup`. I read through related-sounding posts and issues, but those generally seem to be cases of “model won’t learn because the gradient is zero”; my issue isn’t just that the layer won’t learn, but that Flux explicitly warns that the layer is un-trainable.
Here is my minimal working example:
```julia
using Flux

# trivial layer with a single scalar parameter
mutable struct MWE
    a::Float32
end
(m::MWE)(x) = m.a * x
Flux.@layer MWE

# the layer reports exactly one trainable parameter, `a`
testlayer = MWE(rand(Float32))
@assert Flux.trainable(testlayer) == (; a = testlayer.a)

# ...yet this warns that the layer has NO trainable parameters!
Flux.setup(Flux.Adam(), testlayer)
```
So, `trainable` can see the parameter `a`, but `setup` somehow misses it. If this layer is used in a larger `Chain`, the rest of the network will learn, but the `MWE` layer’s parameter will not change.
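To illustrate that last claim, here is a rough sketch of the kind of training loop I mean, continuing from the definitions above (the `Dense` layer, loss, and random data are just stand-ins, not my real model):

```julia
model = Chain(MWE(rand(Float32)), Dense(1 => 1))
opt_state = Flux.setup(Flux.Adam(), model)  # Dense gets optimizer state; the MWE scalar does not

x = rand(Float32, 1, 16)
y = rand(Float32, 1, 16)

a_before = model[1].a
for _ in 1:100
    # standard explicit-gradient training step
    grads = Flux.gradient(m -> Flux.mse(m(x), y), model)[1]
    Flux.update!(opt_state, model, grads)
end
@show a_before model[1].a  # the Dense weights move, but `a` stays frozen
```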
I can work around this for my current actual use case, but the behavior is concerning and perplexing; can anyone point out what’s missing to get this layer to update?
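For reference, the flavor of workaround I mean is storing the parameter in a one-element array rather than a bare scalar, which `setup` then does pick up. A minimal sketch (`MWEVec` is a hypothetical stand-in, not my actual code):

```julia
# hypothetical variant of the layer above: the parameter lives in a 1-element vector
struct MWEVec  # no longer needs to be mutable; the vector's contents are what change
    a::Vector{Float32}
end
(m::MWEVec)(x) = m.a[1] * x
Flux.@layer MWEVec

testlayer2 = MWEVec([rand(Float32)])
Flux.setup(Flux.Adam(), testlayer2)  # no warning this time; `a` gets optimizer state
```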
Thanks!