Hypernetworks using Flux.destructure?

Flux.destructure provides a nice interface for creating hypernetworks, but I seem to be having trouble with it: it reports the wrong number of parameters for the given dimensions, and it can’t take the gradient if the hypernetwork has no output non-linearity. Here is a toy example:

using Flux, Zygote
bsz = 64 # batch size
z = randn(Float32, 8, bsz) # input to hypernet
x = randn(Float32, 4, bsz) # input to primary network
y = randn(Float32, 10, bsz) # target for primary network

primary = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
)
θ, re = Flux.destructure(primary)

Hypernet = Chain(
  Dense(8, 32, elu),
  Dense(32, length(θ), bias=false)
)
ps = Flux.params(Hypernet)

# Create primary Hypernet(z)
m = re(Hypernet(z))
# results in:
┌ Warning: Expected 160 params, got 10240
└ @ Flux C:\Users\aresf\.julia\packages\Flux\qAdFM\src\utils.jl:642

loss, back = pullback(ps) do
    m = re(Hypernet(z))
    Flux.mse(m(x), y)
end

grad = back(1.0f0)
# results in 
ERROR: DimensionMismatch("A has dimensions (160,1) but B has dimensions (64,32)")

If I use an output non-linearity in the hypernet, taking the gradient gives the symmetric warning (Expected 10240 params, got 160). Furthermore, the results don’t make sense: they’re what you’d expect if it didn’t generate a new model m for each batch element.

I hope my problem is clear - does anyone know how to make hypernets work using Flux.destructure (or in a similarly convenient way)? I’ve been training hypernets the old-fashioned way, treating every layer as a function, and that’s prone to buggy code :smiling_face_with_tear:

Thank you!!

Is your batch size 64? I noticed that 10240 / 160 == 64. re is not designed to take in a batch of weights, but rather a single vector describing a single model.
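To illustrate (a minimal sketch, not your exact setup): re rebuilds one model from one flat vector whose length matches length(θ), which here is 160 (Dense(4, 10): 4*10 + 10 = 50, plus Dense(10, 10): 10*10 + 10 = 110).

```julia
using Flux

primary = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
)
θ, re = Flux.destructure(primary)

# re expects a single flat parameter vector of length(θ) == 160,
# not a (160, 64) matrix holding one weight vector per batch element.
v = randn(Float32, length(θ))
m = re(v)                       # rebuilds one Chain with these weights
m(randn(Float32, 4, 1))         # a (10, 1) output, as expected
```

Passing the whole (160, 64) hypernet output at once is what triggers the “Expected 160 params, got 10240” warning.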

Oh oops, I forgot to show that. So I guess I’ve been trying to use it wrong? Is it possible to use destructure to generate batch-size-many (e.g. 64) models and backpropagate through the hypernet correctly?

You could call re batch-size times, once on each slice of the hypernet’s output, to create batch-size-many networks. I’m not sure how well that’d work, but it’s worth a try.
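Something along these lines (a sketch only, assuming batch size 64 and the toy shapes from the original post; I haven’t checked how it performs, and the per-column loop will be slow for large batches):

```julia
using Flux, Zygote

bsz = 64
z = randn(Float32, 8, bsz)   # input to hypernet
x = randn(Float32, 4, bsz)   # input to primary network
y = randn(Float32, 10, bsz)  # target for primary network

primary = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
)
θ, re = Flux.destructure(primary)

hypernet = Chain(
    Dense(8, 32, elu),
    Dense(32, length(θ), bias=false),
)
ps = Flux.params(hypernet)

function loss(z, x, y)
    W = hypernet(z)  # (length(θ), bsz): one weight vector per column
    # rebuild one primary network per batch element and apply it
    # to the matching input column, then stack the outputs
    ŷ = reduce(hcat, [re(W[:, i])(x[:, i]) for i in 1:size(W, 2)])
    Flux.mse(ŷ, y)
end

grads = gradient(() -> loss(z, x, y), ps)
```

This avoids feeding re a whole batch at once, so the “Expected 160 params” warning goes away, and the gradient flows back through each rebuilt model into the hypernet’s parameters.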