using Lux, CUDA, cuDNN, Random, OneHotArrays, Zygote  # cuDNN is required for the correct conv/pool algorithms (see note below)
using Functors, Optimisers, Printf
# LeNet-style CNN: two conv + mean-pool stages, then an MLP head
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MeanPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MeanPool((2, 2)),
    FlattenLayer(3),  # flatten the leading 3 dims: 4×4×16 -> 256 features
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 2)
    )
)
dev = gpu_device(; force=true)  # error if no functional GPU backend is available
ps, st = Lux.setup(Random.default_rng(), model) |> dev;
x = randn(Float32, 28, 28, 1, 32) |> dev;  # batch of 32 single-channel 28×28 inputs
δ = randn(Float32, 28, 28, 1, 32) |> dev;  # target for the input-gradient penalty
y = onehotbatch(rand((1, 2), 32), 1:2) |> dev;  # random one-hot labels for 2 classes
const celoss = CrossEntropyLoss(; logits=true)  # operates on raw logits
const regloss = MSELoss()

function loss_function(model, ps, st, x, y)
    pred, _ = model(x, ps, st)
    return celoss(pred, y)
end
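# (Added sanity-check sketch, not in the original post: the plain first-order
# parameter gradient should already work before any nesting is attempted.)
∂ps = only(Zygote.gradient(p -> loss_function(model, p, st, x, y), ps))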
# First-order part: gradient of the loss w.r.t. the input x, penalized against
# the target δ. StatefulLuxLayer closes over ps/st so the model becomes a plain
# x -> pred function that Zygote can differentiate.
function ∂xloss_function(model, ps, st, x, δ, y)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ∂x = only(Zygote.gradient(Base.Fix2(celoss, y) ∘ smodel, x))
    return regloss(∂x, δ) + loss_function(model, ps, st, x, y)
end

# Second-order part: differentiate the gradient-penalized loss w.r.t. the
# parameters (gradient-of-gradient, i.e. nested AD)
function ∂∂xloss_function(model, ps, st, x, δ, y)
    return only(Zygote.gradient(ps -> ∂xloss_function(model, ps, st, x, δ, y), ps))
end

∂∂xloss_function(model, ps, st, x, δ, y)
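For completeness, here is a minimal training-loop sketch that uses the Optimisers.jl dependency loaded above. The `train` helper, `Adam(0.0003f0)`, and the step count are illustrative choices, not part of the original snippet:

function train(model, ps, st, x, δ, y; steps=10)
    # Illustrative optimizer choice; any Optimisers.jl rule works here
    opt_state = Optimisers.setup(Adam(0.0003f0), ps)
    for step in 1:steps
        gs = ∂∂xloss_function(model, ps, st, x, δ, y)  # second-order gradient from above
        opt_state, ps = Optimisers.update(opt_state, ps, gs)
        @printf "step %3d  loss %.6f\n" step ∂xloss_function(model, ps, st, x, δ, y)
    end
    return ps
end

ps = train(model, ps, st, x, δ, y)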
I have patched the support for (log)softmax and MeanPool (MaxPool is a bit finicky to write the JVP for, so I will try to do that later) in LuxDL/Lux.jl#1151 ("feat: more nested AD rules"). I will merge and tag it later tonight once tests pass.
Also note that in the original example cuDNN (or LuxCUDA) wasn't loaded, so the correct versions of the algorithms couldn't be used.
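For reference, a minimal sketch of the fix is to load either package combination before constructing the device, so the cuDNN-backed algorithms are available:

using Lux, LuxCUDA  # LuxCUDA pulls in CUDA.jl and cuDNN together
# or equivalently, as in the snippet above:
# using Lux, CUDA, cuDNN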