Hi all,
I have two neural networks that I optimize at the same time. Right now it looks something like this:
using Flux, Statistics
Gx = Chain(Conv((3,3), 8=>4; stride=2, pad=1),
BatchNorm(4),
x->leakyrelu.(x, 0.2f0),
Conv((1,1), 4=>1; stride = 1, pad = 0),
x->sigmoid.(x))
in_x = rand(Float32,12,12,8,1)
Gk = Chain(Dense(200,1000, relu), Dense(1000, 6), x->softmax(x))
in_k = rand(Float32,200)
y = randn(6)
loss(Nx, Nk, inx, ink, Y) = mean(Y.-vec(dropdims(Nx(inx);dims = (3,4))*Nk(ink)))
px = Flux.params(Gx)
pk = Flux.params(Gk)
optsx = ADAM(0.01)
optsk = ADAM(0.0001)
gx = Flux.gradient(()->loss(Gx, Gk, in_x, in_k, y), px)
gk = Flux.gradient(()->loss(Gx, Gk, in_x, in_k, y), pk)
Flux.Optimise.update!(optsx, px, gx)
Flux.Optimise.update!(optsk, pk, gk)
My original model is more complicated and runs on the GPU, and I figured it could be faster if I computed only one gradient for both networks, e.g.
p = Flux.params(Gx, Gk)
g = Flux.gradient(()->loss(Gx, Gk, in_x, in_k, y), p)
But this way, I don't know how to update the parameters of the two networks separately (with different learning rates). I found out that I can use FastChain for the Gk network and comfortably work with its parameters. Looking for something similar for Gx, I found Flux.destructure, but it returns all the parameters, not just the trainable ones, so the gradient is computed only for a subset of them. It looks like this:
using DiffEqFlux
Gk = FastChain(FastDense(200,1000,relu), FastDense(1000, 6), (x, p)->softmax(x))
θk = initial_params(Gk)
θx, nx = Flux.destructure(Gx)
function loss2(Nx, Nk, inx, ink, parsx, parsk, Y)
return mean(Y.-vec(dropdims(Nx(parsx)(inx);dims = (3,4))*Nk(ink, parsk)))
end
p = Flux.params(θx, θk)
g = Flux.gradient(()->loss2(nx, Gk, in_x, in_k, θx, θk, y), p)
Flux.Optimise.update!(optsk, θk, g.grads[θk])
Flux.Optimise.update!(optsx, θx, g.grads[θx])
This fails with:
ERROR: LoadError: DimensionMismatch("new dimensions (313,) must be consistent with array size 305")
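The two numbers in the error seem to match a hand count of the parameters of Gx. This assumes that BatchNorm(4) stores two trainable vectors (β, γ) and two running statistics (μ, σ²), each of length 4. The arithmetic below is only a sanity check, and the variable names are made up for illustration:

```julia
# Hand count of the parameters in Gx (all names here are illustrative only).
conv1        = 3 * 3 * 8 * 4 + 4   # Conv((3,3), 8=>4): weights + bias
bn_trainable = 2 * 4               # BatchNorm(4): β and γ
bn_stats     = 2 * 4               # BatchNorm(4): running μ and σ² (assumed non-trainable)
conv2        = 1 * 1 * 4 * 1 + 1   # Conv((1,1), 4=>1): weights + bias

n_trainable = conv1 + bn_trainable + conv2   # 305, the size of the gradient
n_all       = n_trainable + bn_stats         # 313, the length of the destructured vector
println((n_trainable, n_all))
```

So the 8 extra entries would be the BatchNorm running statistics, which are included in the flat vector but receive no gradient.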
Is there a way to get around this? (Or am I completely wrong about the speedup with one gradient computation?) Thanks for any help.
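A possible way around this is to keep the implicit-parameter style and reuse one gradient object for both updates. This is a sketch rather than a verified solution. It assumes a Flux version that still provides Flux.params, ADAM and Flux.Optimise.update!(opt, ps, gs), as in the code above. It also assumes that the gradient object is keyed by the parameter arrays themselves, so px and pk can each pick out their own entries:

```julia
using Flux, Statistics

# Same two networks and inputs as in the post.
Gx = Chain(Conv((3,3), 8=>4; stride=2, pad=1),
           BatchNorm(4),
           x -> leakyrelu.(x, 0.2f0),
           Conv((1,1), 4=>1; stride=1, pad=0),
           x -> sigmoid.(x))
Gk = Chain(Dense(200, 1000, relu), Dense(1000, 6), x -> softmax(x))

in_x = rand(Float32, 12, 12, 8, 1)
in_k = rand(Float32, 200)
y = randn(Float32, 6)

loss(Nx, Nk, inx, ink, Y) = mean(Y .- vec(dropdims(Nx(inx); dims=(3,4)) * Nk(ink)))

# Separate parameter collections and optimisers, one per network.
px = Flux.params(Gx)
pk = Flux.params(Gk)
optsx = ADAM(0.01)
optsk = ADAM(0.0001)

# A single backward pass over the union of both parameter sets.
g = Flux.gradient(() -> loss(Gx, Gk, in_x, in_k, y), Flux.params(Gx, Gk))

# Each update! call walks only its own parameter collection and looks up
# the matching gradients in the shared g, so the learning rates stay separate.
Flux.Optimise.update!(optsx, px, g)
Flux.Optimise.update!(optsk, pk, g)
```

Since Flux.params collects only trainable arrays, the BatchNorm running statistics never enter the update, which sidesteps the size mismatch from Flux.destructure. Whether the single backward pass is actually faster on the GPU would still need to be measured.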