How to implement a sparse decoder layer in Flux.jl?

I’m new to Julia and Flux.jl and I’m trying to implement a Variational Autoencoder (VAE). My decoder has a single layer mapping to the output, but I need to apply different weights to each example.

Instead of a dense layer, I want to interact the weights with a matrix that indicates which weights should be non-zero for each example. Essentially, I’m looking for a way to implement a sparse decoder layer.
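One way to write this down, assuming a 0/1 mask $M$ with one entry per weight and per example ($i$ indexes outputs, $k$ latent dimensions, $j$ examples), compared with an ordinary dense layer:

$$
\text{dense:}\quad y_{ij} = \sigma\Big(\sum_k W_{ik}\, z_{kj} + b_i\Big)
\qquad\qquad
\text{masked:}\quad y_{ij} = \sigma\Big(\sum_k W_{ik}\, M_{ikj}\, z_{kj} + b_i\Big)
$$

With $M_{ikj} = 1$ everywhere, the masked form reduces to the dense layer. The weights $W$ are shared, and only the mask changes from example to example.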

Is this possible in Flux? Could someone provide an example or point me to relevant documentation?

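```julia
using Flux

# Define encoder (built once, so its weights persist between calls)
encoder = Chain(
    Dense(40 => 20, relu),
    Dense(20 => 20),
)

# Decoder: 1 layer, mapping the 20-dim latent code to 7 outputs
decoder = Dense(20 => 7, sigmoid)
```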

Maybe you want this, although I’m not sure I’d recommend it:

```julia
julia> using SparseArrays

julia> Dense(sprand(5,3,0.2))
Dense(3 => 5)       # 20 parameters  <-- this is a lie

julia> ans.weight
5×3 SparseMatrixCSC{Float64, Int64} with 3 stored entries:
  ⋅         ⋅    ⋅ 
  ⋅         ⋅   0.357995
  ⋅         ⋅   0.353468
 0.552263   ⋅    ⋅ 
  ⋅         ⋅    ⋅ 
```
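If the same sparsity pattern should apply to every example, an alternative is to keep the weight dense and multiply it by a fixed 0/1 mask in the forward pass. Entries outside the mask then get exactly zero gradient and never affect the output, whatever the optimiser does to them. The following is a minimal sketch under that assumption. `MaskedDense` is a hypothetical name, not part of Flux. It is registered with `Flux.@layer`, the macro used in recent Flux versions; older releases used `Flux.@functor` instead.

```julia
using Flux

# Hypothetical layer (not part of Flux): a dense layer whose weight is multiplied
# by a fixed 0/1 mask, assuming one sparsity pattern shared by all examples.
struct MaskedDense{W<:AbstractMatrix, B<:AbstractVector, M<:AbstractMatrix, F}
    weight::W
    bias::B
    mask::M    # fixed pattern, excluded from training below
    σ::F
end

# Register with Flux so that only `weight` and `bias` are trainable.
Flux.@layer MaskedDense trainable=(weight, bias)

# Convenience constructor: mask is out × in, Glorot-uniform weights, zero bias.
function MaskedDense(mask::AbstractMatrix{Bool}, σ = identity)
    n_out, n_in = size(mask)
    return MaskedDense(Flux.glorot_uniform(n_out, n_in), zeros(Float32, n_out), mask, σ)
end

# Forward pass: weights where the mask is false never contribute to the output.
(m::MaskedDense)(x) = m.σ.((m.weight .* m.mask) * x .+ m.bias)

# Example: a 20 => 7 decoder with a random (hypothetical) fixed pattern.
masked_decoder = MaskedDense(rand(Bool, 7, 20), sigmoid)
y = masked_decoder(randn(Float32, 20, 4))   # 7×4 output for a batch of 4
```

Unlike a `SparseMatrixCSC` weight, this form stores and updates all entries, so it saves no memory. It does, however, keep the zero pattern intact during training.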

But if not, can you explain what you mean here? For instance, write out code which uses some weights and some “examples”, even if it’s not fast / Zygote-friendly.
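Below is one possible reading of the question, written out as suggested. It is a sketch that assumes the mask `M` is a 0/1 array of size `(out, in, batch)`, so that example `j` is decoded with `W .* M[:, :, j]`. The function names are hypothetical. The loop is the naive reference version. The second function computes the same thing with broadcasting, which Zygote can differentiate, at the cost of materialising an `out × in × batch` array.

```julia
using Flux

# Assumption: M is a 0/1 array of size (out, in, batch), and example j is
# decoded with its own masked weights W .* M[:, :, j]. Names are hypothetical.

# Naive reference: loop over examples (easy to read, not fast).
function masked_decoder_loop(W, b, M, z)
    cols = [sigmoid.((W .* M[:, :, j]) * z[:, j] .+ b) for j in axes(z, 2)]
    return reduce(hcat, cols)
end

# Broadcast version of y[i, j] = σ(Σ_k W[i, k] * M[i, k, j] * z[k, j] + b[i]).
# No mutation or explicit loops, so Zygote can take gradients through it.
function masked_decoder(W, b, M, z)
    zr = reshape(z, 1, size(z)...)                        # 1 × in × batch
    s = dropdims(sum(W .* M .* zr; dims = 2); dims = 2)   # out × batch
    return sigmoid.(s .+ b)
end

# Example with the shapes from the question: 20-dim latent code, 7 outputs.
W, b = randn(Float32, 7, 20), zeros(Float32, 7)
M = rand(Bool, 7, 20, 4)          # hypothetical per-example masks
z = randn(Float32, 20, 4)         # a batch of 4 latent codes
y = masked_decoder(W, b, M, z)    # 7×4
```

If the mask only says which latent dimensions are active per example (an `in × batch` matrix), the same code applies with `M` reshaped to `1 × in × batch`. Broadcasting then shares that mask across all outputs.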