I am trying to train a model in Flux whose loss contains a nested gradient. I know I should avoid dynamic function invocation on the GPU, but I am unsure how to do that here. The following code works on the CPU but fails on CUDA/GPU:
using Flux
using Zygote
using CUDA   # needed when device = gpu on recent Flux versions
device = cpu # or gpu
λ  = 0.1f0                            # weight of the gradient-matching term
X  = rand(Float32, 30, 20) |> device  # inputs
y  = rand(Float32, 10, 20) |> device  # targets
dx = rand(Float32, 30, 20) |> device  # target for the input gradient
dy = ones(Float32, 10, 20) |> device  # cotangent (seed) passed to the pullback in case 2
model = Chain(
    Dense(30 => 10, relu),
    Dense(10 => 10, relu),
    Dense(10 => 10)
) |> device
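# For reference, a plain, non-nested loss like the one below differentiates
# without trouble; my assumption is that only the nested gradient further down
# is what breaks on the GPU.
baseline_loss, baseline_grad = Zygote.withgradient(m -> Flux.mse(m(X), y), model)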
# 1) If dy = ∂sum(ŷ)/∂ŷ (i.e. dy is all ones)
loss, grad = Zygote.withgradient(model) do model
    ret = Zygote.withgradient(X) do X
        ŷ = model(X)
        return sum(ŷ)
    end
    return Flux.mse(model(X), y) + λ * Flux.mse(dx, ret.grad[1])
end
# 2) General dy
loss_general, grad_general = Zygote.withgradient(model) do model
    ŷ, pb = Zygote.pullback(X) do X
        model(X)
    end
    return Flux.mse(ŷ, y) + λ * Flux.mse(dx, pb(dy)[1])
end
It would be great to get both 1) and 2) working on CUDA, but I would be happy with a solution for just one of them.
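For context, here is roughly how I plan to use the returned gradient once either version works. The optimiser and learning rate are placeholders, and it assumes Flux's explicit-gradient API (Flux.setup / Flux.update!), so treat it as a sketch rather than my actual training loop:

opt_state = Flux.setup(Adam(1f-3), model)   # placeholder optimiser and learning rate
for epoch in 1:10                           # placeholder number of epochs
    loss, grad = Zygote.withgradient(model) do model
        ŷ, pb = Zygote.pullback(X) do X
            model(X)
        end
        Flux.mse(ŷ, y) + λ * Flux.mse(dx, pb(dy)[1])
    end
    Flux.update!(opt_state, model, grad[1])
end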