How to update two NN with different optimisers and common gradient

Hi all,

I have two neural networks which I optimize at the same time. Right now it looks something like this:

using Flux
using Statistics  # provides mean, which the loss below uses
Gx = Chain(Conv((3,3), 8=>4; stride=2, pad=1),
            BatchNorm(4), 
            x->leakyrelu.(x, 0.2f0),
            Conv((1,1), 4=>1; stride = 1, pad = 0), 
            x->sigmoid.(x))
in_x = rand(Float32,12,12,8,1)

Gk = Chain(Dense(200,1000, relu), Dense(1000, 6), x->softmax(x))
in_k = rand(Float32,200)

y = randn(6)

loss(Nx, Nk, inx, ink, Y) = mean(Y.-vec(dropdims(Nx(inx);dims = (3,4))*Nk(ink)))

px = Flux.params(Gx)
pk = Flux.params(Gk)
optsx = ADAM(0.01)
optsk = ADAM(0.0001)

gx = Flux.gradient(()->loss(Gx, Gk, in_x, in_k, y), px)
gk = Flux.gradient(()->loss(Gx, Gk, in_x, in_k, y), pk)

Flux.Optimise.update!(optsx, px, gx)
Flux.Optimise.update!(optsk, pk, gk)

My original model is more complicated and runs on the GPU, and I suspect it could be faster if I computed only one gradient for both networks, e.g.

p = Flux.params(Gx, Gk)
g = Flux.gradient(()->loss(Gx, Gk, in_x, in_k, y), p)

But then I don’t know how to update the parameters of the two networks separately (with different learning rates). I found that I can use FastChain for the Gk network and work comfortably with its flat parameter vector. Looking for something similar for Gx, I found Flux.destructure, but it returns all the parameters, not just the trainable ones (presumably it also includes the running statistics of the BatchNorm layer), whereas the gradient is computed only for the trainable subset, so the sizes do not match. It looks like this:

using DiffEqFlux
Gk = FastChain(FastDense(200,1000,relu), FastDense(1000, 6), (x, p)->softmax(x))
θk = initial_params(Gk)
θx, nx = Flux.destructure(Gx)

function loss2(Nx, Nk, inx, ink, parsx, parsk, Y) 
    return mean(Y.-vec(dropdims(Nx(parsx)(inx);dims = (3,4))*Nk(ink, parsk)))
end

p = Flux.params(θx, θk)
g = Flux.gradient(()->loss2(nx, Gk, in_x, in_k, θx, θk, y), p)

Flux.Optimise.update!(optsk, θk, g.grads[θk])
Flux.Optimise.update!(optsx, θx, g.grads[θx])
ERROR: LoadError: DimensionMismatch("new dimensions (313,) must be consistent with array size 305")

Is there a way to get around this? (Or am I completely wrong about the speedup with one gradient computation?) Thanks for any help.

This should be doable without destructure (which can be slow, as it allocates O(params) memory when restructuring) or FastChain.

However, it’s not clear to me that two calls to gradient will be slower than one in this case, so if you haven’t timed that already now is a good time (no need to do any optimization). If it does turn out to be slower, you can use something like the following to update both sets of parameters independently:

p = union(px, pk)
g = Flux.gradient(()->loss(Gx, Gk, in_x, in_k, y), p)

Flux.Optimise.update!(optsx, px, g)
Flux.Optimise.update!(optsk, pk, g)

update! will only look at gradients for the parameters you pass it.
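
To see this in isolation, here is a minimal sketch with two small stand-in networks. The names netA, netB, optA and optB are made up for illustration, and it assumes a Flux version that still has the implicit-parameter API used in this thread (Flux.params, ADAM, Flux.Optimise.update!). It takes one gradient over both parameter sets and then checks that each update! call only changes the arrays in the Params it was given:

```julia
using Flux
using Statistics: mean

# Two small stand-in networks (hypothetical, not the models from this thread).
netA = Dense(4, 3, tanh)
netB = Dense(3, 1)

x = rand(Float32, 4, 5)
y = rand(Float32, 1, 5)

loss(a, b, x, y) = mean(abs2, y .- b(a(x)))

pA = Flux.params(netA)
pB = Flux.params(netB)
optA = ADAM(0.01)     # larger learning rate for netA
optB = ADAM(0.0001)   # smaller learning rate for netB

# One backward pass covering the parameters of both networks.
g = Flux.gradient(() -> loss(netA, netB, x, y), Flux.params(netA, netB))

# Copy the first array of each network (the Dense weight matrix) for comparison.
wA_before = copy(first(pA))
wB_before = copy(first(pB))

# Update netA only; netB's arrays are not in pA, so they stay as they were.
Flux.Optimise.update!(optA, pA, g)
a_changed   = first(pA) != wA_before
b_untouched = first(pB) == wB_before

# Now update netB with its own optimiser, reusing the same gradient object.
Flux.Optimise.update!(optB, pB, g)
b_changed = first(pB) != wB_before
```

Each optimiser also keeps its own internal state per array, so the two ADAM instances do not interfere with each other.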

Thanks a lot, it works great. It is also faster: it now takes only half the time, but I can’t tell why. I never really understood why gradient returns gradients with respect to basically every array that is fed into the loss function. I would assume that this way it computes fewer gradients and is therefore faster, but that is just a guess.

IIRC it computes gradients wrt. every parameter used in the loss calculation, which may be but isn’t necessarily every array. Most of these gradients are used as intermediates in backprop anyhow. Thus it’s not entirely clear to me where the performance difference comes from, perhaps setup/teardown?
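
If you want to measure the difference on a small case, something like the following rough sketch could work. It again uses made-up stand-in networks and assumes the implicit-parameter API from this thread; the helper names two_calls and one_call are hypothetical. It times two separate gradient calls against a single call over the combined parameters, after a warm-up run so that compilation is not included:

```julia
using Flux
using Statistics: mean

# Small stand-in networks and data (hypothetical, for timing only).
netA = Dense(4, 3, tanh)
netB = Dense(3, 1)
x = rand(Float32, 4, 5)
y = rand(Float32, 1, 5)

loss() = mean(abs2, y .- netB(netA(x)))

pA   = Flux.params(netA)
pB   = Flux.params(netB)
pAll = Flux.params(netA, netB)

# Variant 1: one forward and backward pass per network.
two_calls() = (Flux.gradient(loss, pA), Flux.gradient(loss, pB))
# Variant 2: a single forward and backward pass for both networks.
one_call() = Flux.gradient(loss, pAll)

# Warm-up so that compilation time is excluded from the measurement.
two_calls()
one_call()

n = 100
t_two = @elapsed for _ in 1:n
    two_calls()
end
t_one = @elapsed for _ in 1:n
    one_call()
end

println("two gradient calls: ", t_two / n, " s per iteration")
println("one gradient call:  ", t_one / n, " s per iteration")
```

The numbers on a toy model like this will not necessarily match a larger model, and on the GPU you would also need to synchronize (for example with CUDA.@sync) before reading the timer.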

Could you be more specific about the setup/teardown? My original code really isn’t that much different from the MWE I posted here.

This would all be Zygote-internal stuff, so not the fault of your code :)
