Hypernetworks using Flux.destructure?

Flux.destructure provides a nice interface for creating hypernetworks, but I seem to be having trouble with it: it warns that it got the wrong number of parameters for the given dimensions, and taking the gradient fails if the hypernetwork has no output non-linearity. Let's take this toy example:

using Flux, Zygote
z = randn(Float32, 8, bsz) # input to hypernet
x = randn(Float32, 4, bsz) # input to primary network
y = randn(Float32, 10, bsz) # target for primary network

primary = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
)
θ, re = Flux.destructure(primary)

Hypernet = Chain(
  Dense(8, 32, elu),
  Dense(32, length(θ), bias=false)
)
ps = Flux.params(Hypernet)

# Create the primary network from the hypernet output
m = re(Hypernet(z))
# results in:
┌ Warning: Expected 160 params, got 10240
└ @ Flux C:\Users\aresf\.julia\packages\Flux\qAdFM\src\utils.jl:642

loss, back = pullback(ps) do
    m = re(Hypernet(z))
    Flux.mse(m(x), y)
end

grad = back(1.0f0)
# results in 
ERROR: DimensionMismatch("A has dimensions (160,1) but B has dimensions (64,32)")

If I add an output non-linearity to the hypernet, taking the gradient gives the symmetric warning (Expected 10240 params, got 160). Furthermore, the results don't make sense: they're what you'd expect if it didn't generate a new model m for each batch element.

I hope my problem is clear - does anyone know how to make hypernets work using Flux.destructure (or a similar convenient way)? I’ve been training hypernets the old-fashioned way by treating every layer as a function and that’s prone to buggy code :smiling_face_with_tear:

Thank you!!

Is your batch size 64? I noticed that 10240 / 160 == 64. re is not designed to take in a batch of weights, but rather a single vector describing a single model.
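For example (a quick sketch reusing the variables from your snippet), reconstructing from a single column works as intended:

θs = Hypernet(z)    # (length(θ), batch) = (160, bsz)
m1 = re(θs[:, 1])   # primary network for the first batch element
m1(x[:, 1])         # 10-element output, no warning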

Oh oops, I forgot to show that. So I guess I've been trying to use it wrong? Is it possible to use destructure to generate batch-size-many (e.g. 64) models and backpropagate through the hypernet correctly?

You could call re batch-size times, once per slice of the hypernet output, to create batch-size networks. I'm not sure how well that'd work, but it's worth a try.
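Something like this (rough sketch, untested, reusing the names from the original snippet):

θs = Hypernet(z)   # (160, bsz)
ŷ = reduce(hcat, [re(θi)(xi) for (θi, xi) in zip(eachcol(θs), eachcol(x))])
loss = Flux.mse(ŷ, y)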

I might be doing something wrong, but doing that seems to be pretty slow compared to slicing the parameter vector up and using the slices in a custom struct. My problem is that I don't know how to nicely parse an arbitrary Flux.Chain() and create functions that do the correct (batch-wise) computations for it. Here's an example (I'm not sure how to shorten it):

using Flux, Zygote
x = randn(Float32, 784, 64) |> gpu # input
y = Flux.onehotbatch(rand(0:9, 64), 0:9) |> gpu # target

# primary network
p = Dense(784, 10)

θ, re = Flux.destructure(p)

#Hypernetworks 
H1 =  Dense(32, length(θ), bias=false) |> gpu
H2 =  Dense(32, length(θ), bias=false) |> gpu

ps1 = Flux.params(H1)
ps2 = Flux.params(H2)

## === 
"reconstruct one network per column of θs and apply it to the matching (784, 1) slice of x"
eval_net(re, θs, x::AbstractVector) = hcat(map((f, xi) -> f(xi), re.(eachcol(θs)), x)...)

# hypernet-Dense module: weight is (in, out, batch), bias is (out, batch)
struct HyDense
    weight
    bias
end

# batched xᵀW: (in, batch) × (in, out, batch) -> (out, batch)
bmul(a, b) = dropdims(batched_mul(unsqueeze(a, 1), b), dims=1)

(m::HyDense)(x) = bmul(x, m.weight) + m.bias

## =====
"split a data batch x into a vector of batch-size (784, 1) arrays"
function split_(x)
    last_dim = ndims(x)
    map(xi -> unsqueeze(xi, last_dim), eachslice(x, dims=last_dim))
end

x_ = split_(x)

"loss applying re() batch-wise"
function loss1(x_, y)
    θs = H1(z)
    ŷ = eval_net(re, θs, x_)
    Flux.logitcrossentropy(ŷ, y)
end

"slice the (7850, 64) parameter batch θ into a (784, 10, 64) weight and a (10, 64) bias"
function split_weights(θ)
    w = θ[1:end-10, :]
    w = reshape(w, 784, 10, 64)
    b = θ[end-9:end, :]
    w, b
end

"loss for weight slicing"
function loss2(x, y)
    θ2 = H2(z)
    w, b = split_weights(θ2)
    m = HyDense(w, b)
    ŷ = m(x)
    Flux.logitcrossentropy(ŷ, y)
end

"update with batch-wise re()"
function update1()
    loss_1, grad1 = withgradient(ps1) do
        loss1(x_, y)
    end
    Flux.update!(opt1, ps1, grad1)
    loss_1
end

"update with param slicing"
function update2()
    loss_2, grad2 = withgradient(ps2) do
        loss2(x, y)
    end
    Flux.update!(opt2, ps2, grad2)
    loss_2
end

## === test
z = randn(Float32, 32, 64) |> gpu # hypernet input (32-dim code per batch element)
opt1, opt2 = ADAM(1e-4), ADAM(1e-4)

@time update1()
  0.195762 seconds (113.19 k allocations: 5.460 MiB, 92.60% compilation time)

@time update2()
  0.143003 seconds (223.28 k allocations: 10.895 MiB, 97.33% compilation time)

# 2nd pass
@time update1()
  0.016299 seconds (25.98 k allocations: 1.106 MiB)

@time update2()
  0.008000 seconds (1.47 k allocations: 94.875 KiB)

I'm not surprised, given that re is doing strictly more work than split_weights. It needs to do more work because it has to handle arbitrarily complex models like nested Chains. Also, batched_mul should be more efficient than mapping ordinary * over array slices. For the former, it may be possible to hide some of the less AD-friendly computation inside an rrule, but for the latter your best bet could be defining custom Hy[Layer] types to de/restructure from.
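For what it's worth, here's a tiny standalone comparison of the two approaches (a sketch with made-up shapes, not tied to your example):

using NNlib: batched_mul

W = randn(Float32, 10, 784, 64)   # one (out, in) weight matrix per batch element
x = randn(Float32, 784, 64)       # one input column per batch element

# mapping ordinary * over slices: 64 separate mat-vec products
y1 = reduce(hcat, [Wi * xi for (Wi, xi) in zip(eachslice(W, dims=3), eachcol(x))])

# a single fused batched call
y2 = dropdims(batched_mul(W, reshape(x, 784, 1, 64)), dims=2)

y1 ≈ y2   # true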

Is it possible to write a function or macro that looks inside a pre-defined Flux module (e.g. Dense), and creates a new one (e.g. HyDense) that replaces * with batched_mul etc.?

You can replace the whole layer (Dense -> HyDense), but not anything more granular than that with vanilla Flux. Theoretically it would be possible to use a tracing approach like Ghost.jl (github.com/dfdx/Ghost.jl) and rewrite operations on the tape afterwards, but I'd only recommend going down that complex path if you really, really need the flexibility it affords.

Thanks for your help! This has been confusing to think about. For now I'll just create custom structs that mirror the Flux models and use Optimisers.destructure() to reconstruct them from hypernet outputs. To do that correctly, the parameter batch θs = H(z) generated by the hypernetwork needs to be sliced so that each module in the primary gets θs[module_inds, :]. For an arbitrary Flux.Chain this would need to be done recursively, maybe using fmap? I've tried to rework something from destructure.jl in FluxML/Optimisers.jl on GitHub, but I think I'm out of my depth here.

Let’s say we have a primary network

p = Chain(
    Parallel(
        vcat,
        Chain(
          Dense(32, 64), # mapped to HyDense later
          LayerNorm(64, elu),
        ),
        Chain(
          Dense(32, 64),
          LayerNorm(64, elu),
        )
     ),
    Dense(128, 64),
    LayerNorm(64, elu),
    Dense(64, 10, bias=false),
)

Using destructure:

θ, re = Flux.destructure(p)
julia> offs = re.offsets
(layers = ((connection = (), layers = ((layers = ((weight = 0, bias = 2048, σ = ()), (λ = (), diag = (scale = 2112, bias = 2176, σ = ()), ϵ = (), size = ((),), affine = ())),), (layers = ((weight = 2240, bias = 4288, σ = ()), (λ = (), diag = (scale = 4352, bias = 4416, σ = ()), ϵ = (), size = ((),), affine = ())),))), (weight = 4480, bias = 12672, σ = ()), (λ = (), diag = (scale = 12736, bias = 12800, σ = ()), ϵ = (), size = ((),), affine = ()), (weight = 12864, bias = (), σ = ())),)

(I think) all that needs to be done now is to slice θs = H(z) according to the indices in offs, i.e. the HyDense that replaces the first Dense in the Parallel module would get weights w = θs[1:2048, :] and bias b = θs[2049:2112, :], etc. How can you gather all the indices in an arbitrarily nested tuple?
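Something like recursively collecting the integer leaves might work (rough sketch, untested; gather_offsets is just a made-up helper, not a Flux function):

gather_offsets(t::Integer) = [t]
gather_offsets(t::Union{Tuple,NamedTuple}) = reduce(vcat, [gather_offsets(v) for v in t]; init=Int[])
gather_offsets(t) = Int[]   # activation functions, ϵ, empty tuples, etc. carry no offsets

sort(gather_offsets(offs))  # 0, 2048, 2112, 2176, 2240, ...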

Also, I'm not sure how to incorporate activity normalization layers like LayerNorm and BatchNorm into a restructured Chain, since I don't think(?) their parameters should be produced by the hypernet H. Adding them into the Chain post-hoc, e.g. LayerNorm(64, elu) |> gpu, makes Zygote unhappy.

^ Looks like this issue is still up for debate

OK, I think I found a really hacky way to do it, though I'm not sure it works correctly. To iterate through the primary model:

function get_module_sizes(m::Flux.Chain)
    sizes, modules = [], []
    function get_module_sizes_(m, sizes, modules)
        for l in m.layers
            if hasfield(typeof(l), :layers) # nested container (Chain, Parallel, ...)
                get_module_sizes_(l, sizes, modules)
            elseif hasfield(typeof(l), :weight) # Dense-like layer
                wsz = size(l.weight)
                b_ = l.bias
                b_sz = b_ == false ? 0 : size(b_)
                push!(sizes, (wsz, b_sz))
                push!(modules, typeof(l))
            elseif isempty(Flux.params(l)) # parameterless layer (flatten, activations, ...)
                nothing
            elseif Flux.hasaffine(l) # activity normalization with affine params
                psz = size.(Flux.params(l))
                push!(sizes, (1,)) # placeholder entry for norm layers
                push!(modules, typeof(l))
            end
        end
        return modules, sizes
    end
    modules, sizes = get_module_sizes_(m, sizes, modules)
end

function get_params_length(p)
    ms, szs = get_module_sizes(p)
    sizes_ = map(x -> prod.(x), szs)
    map(sum, sizes_)
end


"slice θs row-wise at the module offsets and concatenate the vectorized slices into one parameter vector"
function split_weights(θs, offsets)
    msz, bsz = size(θs)
    offsets_full = [0; offsets]
    ws = [θs[offsets_full[i]+1:offsets_full[i+1], :][:] for i in 1:length(offsets)]
    ws_flat = vcat(ws...)
    @assert length(ws_flat) == msz * bsz
    return ws_flat
end

So let's say we want to create a little implicit image function (pixel coordinates in, intensity out) for MNIST. We can hackily do:

p = Chain(
    Dense(2, 32),
    LayerNorm(32, elu),
    Dense(32, 32),
    LayerNorm(32, elu),
    Dense(32, 32),
    LayerNorm(32, elu),
    Dense(32, 1, relu, bias=false),
    flatten,
)

# "hyper" version of p, with each Dense swapped for a HyDense
p2 = let
    p_list = []
    for m in p
        if isa(m, Dense)
            wsz = size(m.weight)
            push!(p_list, HyDense(reverse(wsz)..., args[:bsz], m.σ, bias=isa(m.bias, AbstractArray) ? true : false))
        else
            push!(p_list, m)
        end
    end
    Flux.Chain(p_list...)
end

θ, re = Flux.destructure(p)
θ2, re2 = Flux.destructure(p2) # restructure for actual net we'll be using
param_sizes = get_params_length(p)
offsets, len_θ = cumsum(param_sizes), sum(param_sizes)

Encoder = Chain(
    x -> unsqueeze(x, 3),
    Conv((5, 5), 1 => 32, stride=(2, 2)),
    BatchNorm(32, relu),
    Conv((5, 5), 32 => 8, stride=(2, 2)),
    BatchNorm(8, relu),
    flatten,
    Dense(128, 64),
    BatchNorm(64, relu),
    Dense(64, 32),
    BatchNorm(32, relu),
) |> gpu

H = Chain(
    LayerNorm(32),
    Dense(32, 64),
    LayerNorm(64, elu),
    Dense(64, 64),
    LayerNorm(64, elu),
    Dense(64, len_θ, bias=false),
) |> gpu

function model_loss(x)
    z = Encoder(x)
    θs = H(z)
    ws = split_weights(θs, offsets)
    m = re2(ws)
    # xy is a grid of 2 x 784 x batch-size spanning from (-1,1)
    Flux.mse(m(xy), flatten(x))
end

## ====

A few epochs into training, you start getting MNIST-y things


My recommendation would be to replace layers with their hypernetwork equivalents before destructuring. That way the offsets will be correct from the get-go, and you don’t need a bunch of extra code to manipulate them. It is also relatively safe to do since the network is just used as a template (i.e. you immediately throw away θ).
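For instance, something along these lines (rough sketch, untested, reusing the hypothetical HyDense(in, out, batch, σ; bias) constructor and the args dict from the posts above):

# swap every Dense for its hypernetwork equivalent, leave all other layers untouched
swap(l::Dense) = HyDense(size(l.weight, 2), size(l.weight, 1), args[:bsz], l.σ,
                         bias=isa(l.bias, AbstractArray))
swap(l) = l

p_hyper = Flux.fmap(swap, p; exclude = l -> l isa Dense)
θ, re = Flux.destructure(p_hyper)   # offsets now line up with the hypernet output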