Flux.destructure provides a nice interface for creating hypernetworks, but I’m having trouble with it: it seems to calculate the wrong number of parameters for the dimensions, and it can’t take the gradient if the hypernetwork has no output non-linearity. Let’s take this toy example:
using Flux, Zygote
bsz = 64 # batch size (chosen so the numbers in the warning below match: 160 × 64 = 10240)
z = randn(Float32, 8, bsz) # input to hypernet
x = randn(Float32, 4, bsz) # input to primary network
y = randn(Float32, 10, bsz) # target for primary network
primary = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
)
θ, re = Flux.destructure(primary)
Hypernet = Chain(
  Dense(8, 32, elu),
  Dense(32, length(θ), bias=false)
)
ps = Flux.params(Hypernet)
# Create the primary network from Hypernet(z)
m = re(Hypernet(z))
# results in:
┌ Warning: Expected 160 params, got 10240
└ @ Flux C:\Users\aresf\.julia\packages\Flux\qAdFM\src\utils.jl:642
(I suspect the 10240 is because Hypernet(z) is a length(θ) × bsz matrix, and 160 × 64 = 10240, whereas re expects a flat 160-vector.)
loss, back = pullback(ps) do
    m = re(Hypernet(z))
    Flux.mse(m(x), y)
end
grad = back(1.0f0)
# results in:
ERROR: DimensionMismatch("A has dimensions (160,1) but B has dimensions (64,32)")
If I use an output non-linearity in the hypernet, taking the gradient gives the symmetric warning (Expected 10240 params, got 160). Furthermore, the results don’t make sense: they’re what you’d expect if it didn’t generate a new model m for each batch element.
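To illustrate what I mean by per-batch-element models, a per-sample loop like the one below is at least shape-correct (a minimal sketch, assuming re accepts any flat 160-vector; presumably far too slow to be the intended usage):
W = Hypernet(z)                # length(θ) × bsz matrix
total = sum(1:bsz) do i
    m_i = re(W[:, i])          # rebuild one primary network per batch column
    Flux.mse(m_i(x[:, i]), y[:, i])
end
loss = total / bsz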
I hope my problem is clear: does anyone know how to make hypernets work using Flux.destructure (or a similarly convenient way)? I’ve been training hypernets the old-fashioned way, treating every layer as a function, and that’s prone to buggy code (a rough sketch of what I mean is below).
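For reference, here is roughly what that manual version looks like (a sketch; the index ranges match the 160 parameters of the primary above, and the helper name and slicing order are mine, not destructure’s):
# manual "layers as functions" version: slice a flat parameter vector
# into per-layer weights and biases, then apply each layer by hand;
# applied per sample as primary_forward(Hypernet(z)[:, i], x[:, i])
function primary_forward(θv, x)
    W1 = reshape(θv[1:40], 10, 4)     # Dense(4, 10): 40 weights
    b1 = θv[41:50]                    # ... plus 10 biases
    W2 = reshape(θv[51:150], 10, 10)  # Dense(10, 10): 100 weights
    b2 = θv[151:160]                  # ... plus 10 biases
    h = relu.(W1 * x .+ b1)
    return relu.(W2 * h .+ b2)
end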
Thank you!!
