Second-order gradients with Lux, Zygote, CUDA, Enzyme

using Lux, CUDA, cuDNN, Random, OneHotArrays, Zygote
using Functors, Optimisers, Printf

# LeNet-style CNN; with 28×28×1 inputs the feature maps go
# 28 → 24 → 12 → 8 → 4, so the flattened size is 4 * 4 * 16 = 256
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MeanPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MeanPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 2)
    )
)

dev = gpu_device(; force=true)

ps, st = Lux.setup(Random.default_rng(), model) |> dev;  # parameters & states on the GPU

x = randn(Float32, 28, 28, 1, 32) |> dev;       # input batch
δ = randn(Float32, 28, 28, 1, 32) |> dev;       # target for the input gradient ∂x
y = onehotbatch(rand((1, 2), 32), 1:2) |> dev;  # one-hot labels

const celoss = CrossEntropyLoss(; logits=true)  # classification loss on raw logits
const regloss = MSELoss()                       # penalty matching ∂x against δ

# standard cross-entropy loss of the model predictions
function loss_function(model, ps, st, x, y)
    pred, _ = model(x, ps, st)
    return celoss(pred, y)
end

# gradient-penalty-style loss: take the gradient of celoss w.r.t. the
# input x (an inner, first-order AD call) and match it against δ
function ∂xloss_function(model, ps, st, x, δ, y)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ∂x = only(Zygote.gradient(Base.Fix2(celoss, y) ∘ smodel, x))
    return regloss(∂x, δ) + loss_function(model, ps, st, x, y)
end

# differentiating ∂xloss_function w.r.t. ps is a gradient of a gradient,
# i.e. second-order AD, which Lux handles via its nested AD support
function ∂∂xloss_function(model, ps, st, x, δ, y)
    return only(Zygote.gradient(ps -> ∂xloss_function(model, ps, st, x, δ, y), ps))
end

∂∂xloss_function(model, ps, st, x, δ, y)
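Since Optimisers and Printf are already loaded, the second-order gradients can drive a training loop directly. A minimal sketch (Adam(3f-4) and the 10 iterations are illustrative choices, not part of the original snippet):

function train(ps; iters=10)
    opt_state = Optimisers.setup(Adam(3f-4), ps)
    for iter in 1:iters
        gs = ∂∂xloss_function(model, ps, st, x, δ, y)
        opt_state, ps = Optimisers.update(opt_state, ps, gs)
        @printf "Iteration %2d: loss %.8f\n" iter ∂xloss_function(model, ps, st, x, δ, y)
    end
    return ps
end

ps = train(ps)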

I have patched support for (log)softmax and MeanPool (MaxPool is a bit finicky to write the JVP for, so I will try to do that later) in feat: more nested AD rules by avik-pal · Pull Request #1151 · LuxDL/Lux.jl · GitHub. I will merge and tag it later tonight once tests pass.
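For context on why these rules are needed: the nested gradient above is executed forward-over-reverse, so every operation in the model needs a JVP (pushforward) rule. MeanPool is linear in its input, so its JVP is simply the pooling applied to the tangent. A conceptual sketch using NNlib directly (illustrative only, not the actual PR code):

using NNlib

xp = randn(Float32, 28, 28, 1, 32)  # primal input
ẋ  = randn(Float32, 28, 28, 1, 32)  # tangent
pdims = PoolDims(xp, (2, 2))

meanpool(xp, pdims)      # primal output
ẏ = meanpool(ẋ, pdims)   # JVP: linearity means just pooling the tangent

MaxPool, by contrast, has to route the tangent through the argmax locations of the primal, which is the finicky part.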

Also note that in the original example cuDNN (or LuxCUDA) wasn't loaded, so the conv and pooling calls couldn't dispatch to the correct (cuDNN-backed) algorithm implementations.
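To make the cuDNN kernels available, load one of the following before running the script (either option should work here):

using CUDA, cuDNN   # explicit, as in the script above
# or
using LuxCUDA       # meta-package that loads CUDA.jl and cuDNN together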
