Flux assumes only trainable parameters should be moved to the GPU

And that’s not always a good assumption! I am working on a masked dense layer, which looks something like this:

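```julia
using Flux

struct MaskedDense{F, M <: AbstractMatrix, B <: AbstractVector}
  W::M                      # weights
  b::B                      # bias
  M::AbstractMatrix{Bool}   # fixed connectivity mask (should not be trained)
  σ::F                      # activation
end

# Zero out masked connections, then apply an ordinary dense layer
(d::MaskedDense)(x::AbstractArray) = d.σ.(d.b .+ (d.W .* d.M) * x)
@Flux.functor MaskedDense (W, b)
```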
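To check what the mask actually does, here is a minimal CPU sketch, assuming Flux 0.13/0.14 (where `@functor` is still the usual way to register a layer). It is the same layer with one change of my own: the mask gets its own type parameter `K` (an assumption, not from the post) so the field is concretely typed. Weights at masked-out positions have no effect on the output.

```julia
using Flux

# Same masked dense layer; K is an assumed extra type parameter for the mask
struct MaskedDense{F, T <: AbstractMatrix, B <: AbstractVector, K <: AbstractMatrix{Bool}}
  W::T
  b::B
  M::K
  σ::F
end

# W .* M zeroes the masked-out connections before the matrix product
(d::MaskedDense)(x::AbstractArray) = d.σ.(d.b .+ (d.W .* d.M) * x)

Flux.@functor MaskedDense (W, b)

W = Float32[1 2; 3 4]
b = Float32[0.5, -0.5]
mask = [true false; false true]
x = Float32[1 2 3; 4 5 6]

y = MaskedDense(W, b, mask, identity)(x)
# y == Float32[1.5 2.5 3.5; 15.5 19.5 23.5]

# Changing a masked-out weight leaves the output untouched
W2 = copy(W)
W2[1, 2] = 100f0
y2 = MaskedDense(W2, b, mask, identity)(x)
# y2 == y
```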

If we try to move it to the GPU, only the fields listed in the functor declaration (the trainable parameters `W` and `b`) are moved:

```julia
julia> m = MaskedDense(randn(2, 2), randn(2), [true false; false true], Flux.sigmoid);

julia> typeof(gpu(m))
Main.MaskedDense{typeof(NNlib.σ), CUDA.CuArray{Float32, 2}, CUDA.CuArray{Float32, 1}}
```
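Why only `W` and `b` end up on the GPU: as far as I can tell, `gpu` walks the model with `Functors.fmap`, which rebuilds only the fields exposed by the functor declaration and passes every other field to the constructor unchanged. The CPU-only sketch below makes this visible without CUDA, using `copy` as a purely illustrative stand-in for `cu` and assuming Flux 0.13/0.14.

```julia
using Flux

struct MaskedDense{F, T <: AbstractMatrix, B <: AbstractVector, K <: AbstractMatrix{Bool}}
  W::T
  b::B
  M::K
  σ::F
end
Flux.@functor MaskedDense (W, b)   # only W and b are functor children

m = MaskedDense(randn(Float32, 2, 2), randn(Float32, 2), [true false; false true], Flux.sigmoid)

# Stand-in for `cu`: copy every array leaf that fmap visits
moved = Flux.Functors.fmap(x -> x isa AbstractArray ? copy(x) : x, m)

moved.W === m.W   # false: W is a child, so it was rebuilt ("moved")
moved.M === m.M   # true: M is not a child, so the original array is reused
```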

but the mask stays on the CPU! If we include the mask in the functor declaration, `@Flux.functor MaskedDense (W, b, M)`, it is indeed moved, but it also becomes a trainable parameter and gets updated during training, which is not what we want. The mask should of course be moved to the GPU, but not trained.
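The training side of the problem can be checked directly. A sketch, again assuming Flux 0.13/0.14 and the implicit `Flux.params` API: with every field declared as a functor child, the `Bool` mask (since `Bool <: Number`) is collected as a parameter alongside `W` and `b`, so an optimiser would try to update it.

```julia
using Flux

struct MaskedDense{F, T <: AbstractMatrix, B <: AbstractVector, K <: AbstractMatrix{Bool}}
  W::T
  b::B
  M::K
  σ::F
end
Flux.@functor MaskedDense   # every field, including the mask, is now a child

m = MaskedDense(Float32[1 2; 3 4], Float32[0.5, -0.5], [true false; false true], identity)

ps = Flux.params(m)
length(ps)                # 3: W, b and the mask
any(p -> p === m.M, ps)   # true: the mask is among the parameters
```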

I tried to hijack the `gpu` function like so:

```julia
Flux.gpu(d::MaskedDense) = MaskedDense(cu(d.W), cu(d.b), cu(d.M), d.σ)
```

but that causes my Julia session to explode in errors like `[4] _setindex! at ./abstractarray.jl:1290` when training, so I assume that’s not the way to go.

Has anyone here implemented something similar and know how to tackle this? Thanks!

I think you can specify which parameters are trainable separately from which fields are the layer’s “children”. E.g. for `BatchNorm`: https://github.com/FluxML/Flux.jl/blob/29a96b961badea84a3bb323257c1374ffebca2e7/src/layers/normalise.jl#L265.
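The `BatchNorm` split can be inspected directly. A sketch assuming Flux 0.13/0.14 (the field name `μ` for the running mean is taken from those versions and may differ elsewhere): every field is a functor child, so `gpu` moves the running statistics, but `trainable` returns only the affine parameters `β` and `γ`.

```julia
using Flux

bn = BatchNorm(3)

# All fields are functor children, so gpu/fmap also reach the running mean μ
ch = Flux.Functors.functor(bn)[1]
haskey(ch, :μ)               # true

# ...but only the affine parameters β and γ are trainable
length(Flux.trainable(bn))   # 2
length(Flux.params(bn))      # 2
```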


Thanks! That does indeed solve the problem!

For completeness, in case someone stumbles upon this in the future:

```julia
@Flux.functor MaskedDense                    # all fields are children, so gpu moves the mask too
Flux.trainable(d::MaskedDense) = (d.W, d.b)  # but only W and b are trained
```
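Putting the pieces together, a minimal end-to-end sketch, assuming Flux 0.13/0.14 where both `Flux.params` and the explicit `Flux.setup`/`Flux.update!` API are available. One deviation from the post: `trainable` returns a NamedTuple instead of a plain tuple, because as far as I know the Optimisers.jl-based explicit API expects names matching the fields; the implicit `params` path accepts it too. After one gradient step, the mask and the masked-out weight are unchanged while an active weight moves.

```julia
using Flux

struct MaskedDense{F, T <: AbstractMatrix, B <: AbstractVector, K <: AbstractMatrix{Bool}}
  W::T
  b::B
  M::K
  σ::F
end
(d::MaskedDense)(x::AbstractArray) = d.σ.(d.b .+ (d.W .* d.M) * x)

# All fields are children, so gpu/fmap move the mask as well...
Flux.@functor MaskedDense
# ...but only W and b are trainable (NamedTuple form, see above)
Flux.trainable(d::MaskedDense) = (W = d.W, b = d.b)

mask = [true false; false true]
m = MaskedDense(Float32[1 2; 3 4], Float32[0.5, -0.5], copy(mask), Flux.sigmoid)
x = ones(Float32, 2, 3)
W0 = copy(m.W)

n_params = length(Flux.params(m))   # only W and b

# One explicit-style gradient descent step
opt_state = Flux.setup(Descent(0.1), m)
grads = Flux.gradient(model -> sum(model(x)), m)
Flux.update!(opt_state, m, grads[1])   # updates W and b in place, skips M
```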

Yup, this is 100% the intended use of `trainable`. The only reason we haven’t documented it thoroughly is that we’re in the process of moving away from implicit params and are still thinking about how best to represent all parameters vs. trainable parameters (and any other subsets of params you can think of).