Flux.destructure provides a nice interface for creating hypernetworks, but I seem to be having trouble with it: it calculates the wrong number of parameters for the given dimensions, and it can't take the gradient if the hypernetwork has no output non-linearity. Let's take this toy example:
using Flux, Zygote
bsz = 64 # batch size
z = randn(Float32, 8, bsz) # input to hypernet
x = randn(Float32, 4, bsz) # input to primary network
y = randn(Float32, 10, bsz) # target for primary network
primary = Chain(
Dense(4, 10, relu),
Dense(10, 10, relu),
)
θ, re = Flux.destructure(primary)
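# θ is a flat parameter vector: length(θ) == 160, i.e. (4*10 + 10) + (10*10 + 10)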
Hypernet = Chain(
Dense(8, 32, elu),
Dense(32, length(θ), bias=false)
)
ps = Flux.params(Hypernet)
# Create the primary network from Hypernet(z)
m = re(Hypernet(z))
# results in:
┌ Warning: Expected 160 params, got 10240
└ @ Flux C:\Users\aresf\.julia\packages\Flux\qAdFM\src\utils.jl:642
loss, back = pullback(ps) do
m = re(Hypernet(z))
Flux.mse(m(x), y)
end
grad = back(1.0f0)
# results in:
ERROR: DimensionMismatch("A has dimensions (160,1) but B has dimensions (64,32)")
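For reference, Hypernet(z) is a 160 × 64 matrix here (one generated parameter column per batch element), while re apparently expects a single flat vector of length(θ) == 160. The only shape-consistent thing I've found is to map re over the columns, rebuilding one primary network per sample. A minimal sketch (I'm not sure re is meant to be differentiated this way, and it may depend on the Flux version):

# Per-sample workaround sketch: rebuild the primary network from each
# column of the hypernet output and apply it to the matching input column.
W = Hypernet(z)                  # 160 × bsz matrix of generated parameters
preds = reduce(hcat, map(1:bsz) do i
    mi = re(W[:, i])             # one primary network per batch element
    mi(x[:, i:i])                # keep the column as a 4×1 matrix
end)
loss_manual = Flux.mse(preds, y)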
If I use an output non-linearity in the hypernet, taking the gradient gives the symmetric warning (Expected 10240 params, got 160). Furthermore, the results don't make sense: they're what you'd expect if it didn't generate a new model m for each batch element.
I hope my problem is clear. Does anyone know how to make hypernets work using Flux.destructure (or a similarly convenient approach)? I've been training hypernets the old-fashioned way, treating every layer as a plain function, roughly as in the sketch below, and that's prone to buggy code.
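For concreteness, the manual version looks something like this (a hypothetical primary_forward helper; the slicing assumes destructure flattens each layer's weight and then its bias, in layer order):

# Manual "layers as functions" approach: slice one generated parameter
# column into weight/bias blocks by hand. Easy to get the offsets wrong.
function primary_forward(θcol, xcol)
    W1 = reshape(θcol[1:40], 10, 4)    # Dense(4, 10) weight (40 entries)
    b1 = θcol[41:50]                   # Dense(4, 10) bias (10 entries)
    W2 = reshape(θcol[51:150], 10, 10) # Dense(10, 10) weight (100 entries)
    b2 = θcol[151:160]                 # Dense(10, 10) bias (10 entries)
    h = relu.(W1 * xcol .+ b1)
    return relu.(W2 * h .+ b2)
end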
Thank you!!