Hi all,
I have two neural networks that I optimize at the same time. Right now it looks something like this:
using Flux
using Statistics: mean   # `mean` is used in the loss below

Gx = Chain(Conv((3,3), 8=>4; stride=2, pad=1),
           BatchNorm(4),
           x -> leakyrelu.(x, 0.2f0),
           Conv((1,1), 4=>1; stride=1, pad=0),
           x -> sigmoid.(x))
in_x = rand(Float32, 12, 12, 8, 1)          # WHCN input for Gx
Gk = Chain(Dense(200, 1000, relu), Dense(1000, 6), x -> softmax(x))
in_k = rand(Float32, 200)
y = randn(6)                                # target
# Gx(in_x) is 6×6×1×1 -> dropdims gives a 6×6 matrix; Gk(in_k) is a 6-vector, so the product is a 6-vector
loss(Nx, Nk, inx, ink, Y) = mean(Y .- vec(dropdims(Nx(inx); dims=(3, 4)) * Nk(ink)))
px = Flux.params(Gx)
pk = Flux.params(Gk)
optsx = ADAM(0.01)
optsk = ADAM(0.0001)
gx = Flux.gradient(()->loss(Gx, Gk, in_x, in_k, y), px)
gk = Flux.gradient(()->loss(Gx, Gk, in_x, in_k, y), pk)
Flux.Optimise.update!(optsx, px, gx)
Flux.Optimise.update!(optsk, pk, gk)
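Just for reference, this is how the shapes in the toy loss line up (a quick sanity check with the sizes above):
size(Gx(in_x))                                    # (6, 6, 1, 1): 12×12 halved by the stride-2 conv
size(dropdims(Gx(in_x); dims=(3, 4)))             # (6, 6)
size(Gk(in_k))                                    # (6,)
size(dropdims(Gx(in_x); dims=(3, 4)) * Gk(in_k))  # (6,), same shape as y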
My original model is more complicated and runs on a GPU, and I figured it could be faster if I used only one gradient call, e.g.
p = Flux.params(Gx, Gk)
g = Flux.gradient(()->loss(Gx, Gk, in_x, in_k, y), p)
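(Just to illustrate what the single call gives: p collects the trainable arrays of both networks, and g can be indexed by each of those arrays individually; the counts below are for the toy model, if I read the Zygote docs right.)
length(p)                       # 10 trainable arrays: 6 from Gx, 4 from Gk
sum(length, Flux.params(Gx))    # 305 trainable numbers in Gx
g[first(Flux.params(Gk))]       # gradient for Gk's first Dense weight matrix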
But like this, I don’t know how to update the parameters of the two networks separately (with different learning rates). I found out that I can use FastChain for the Gk network and work with its parameters comfortably. Looking for something similar for Gx, I found Flux.destructure, but it returns all the parameters, not just the trainable ones, so the gradient is only computed for a subset of them. It looks like this:
using DiffEqFlux
Gk = FastChain(FastDense(200,1000,relu), FastDense(1000, 6), (x, p)->softmax(x))
θk = initial_params(Gk)
θx, nx = Flux.destructure(Gx)   # nx(θ) rebuilds Gx from a flat parameter vector
function loss2(Nx, Nk, inx, ink, parsx, parsk, Y)
    return mean(Y .- vec(dropdims(Nx(parsx)(inx); dims=(3, 4)) * Nk(ink, parsk)))
end
p = Flux.params(θx, θk)
g = Flux.gradient(()->loss2(nx, Gk, in_x, in_k, θx, θk, y), p)
Flux.Optimise.update!(optsk, θk, g.grads[θk])
Flux.Optimise.update!(optsx, θx, g.grads[θx])
ERROR: LoadError: DimensionMismatch("new dimensions (313,) must be consistent with array size 305")
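I assume the 313 vs. 305 comes from the BatchNorm moving statistics (μ and σ²), which destructure flattens but which are not among the trainable parameters; counting on the toy model above:
sum(length, Flux.params(Gx))   # 305 trainable: 3*3*8*4 + 4 (Conv) + 4 + 4 (BatchNorm β, γ) + 1*1*4*1 + 1 (Conv)
length(θx)                     # 313 = 305 + 4 + 4, i.e. BatchNorm's μ and σ² are included as well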
Is there a way to get around this? (Or am I completely wrong about the speedup with one gradient computation?) Thanks for any help.