Use Enzyme with Flux

Hi, I have a complex loss involving mutating arrays, which Zygote doesn't support.
Is it possible to use Enzyme.jl with Flux.jl?

I’m also interested in the answer. I’d like to see an example of Enzyme usage with Flux / Lux.

You can write a custom rrule and use Enzyme to implement it.

Thank you for your answer.
Could you please provide a basic example of a custom rule?

There is one in the docs for the latest version: Custom rules · Enzyme.jl
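
For reference, here is a minimal sketch of that pattern: a ChainRulesCore.rrule whose pullback calls Enzyme. The mutating function sqsum below is hypothetical (not from the docs), chosen only because plain Zygote cannot differentiate the in-place assignment.

using ChainRulesCore, Enzyme, Zygote

# Hypothetical loss with array mutation, unsupported by Zygote on its own
function sqsum(x::Vector{Float64})
    buf = similar(x)
    buf .= x .^ 2    # in-place mutation
    return sum(buf)
end

function ChainRulesCore.rrule(::typeof(sqsum), x::Vector{Float64})
    y = sqsum(x)
    function sqsum_pullback(ȳ)
        dx = zero(x)
        # Enzyme accumulates d(sqsum)/dx into the shadow dx
        Enzyme.autodiff(Reverse, sqsum, Active, Duplicated(x, dx))
        return NoTangent(), ȳ .* dx
    end
    return y, sqsum_pullback
end

Zygote.gradient(sqsum, [1.0, 2.0, 3.0])  # ([2.0, 4.0, 6.0],)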

using Enzyme

x  = [2.0, 2.0]   # primal input
bx = [0.0, 0.0]   # shadow of x: Enzyme accumulates the gradient here
y  = [0.0, 0.0]   # output buffer

using ComponentArrays, Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 100)
dudt2 = Lux.Chain(x -> x .^ 3,
                  Lux.Dense(2, 50, tanh),
                  Lux.Dense(50, 2))
p, st = Lux.setup(rng, dudt2)

# Mutating form: writes the model output into y and returns nothing
function f(x::Array{Float64}, y::Array{Float64})
    y .= dudt2(x, p, st)[1]
    return nothing
end

# Reverse mode: the shadow of y seeds the pullback; bx receives the gradient
Enzyme.autodiff(Reverse, f, Duplicated(x, bx), Duplicated(y, ones(2)))

# Non-mutating form, for comparison against Zygote
function f2(x::Array{Float64})
    dudt2(x, p, st)[1]
end

using Zygote
bx2 = Zygote.pullback(f2, x)[2](ones(2))[1]
bx   # gradient computed by Enzyme

@show bx - bx2

#=
2-element Vector{Float64}:
 -9.992007221626409e-16
 -1.7763568394002505e-15
=#

This works on Enzyme main, currently unreleased.


I would caution against this coding style, however, and strongly recommend passing dudt2, p, and st explicitly as (Const) arguments to autodiff, rather than capturing them as type-unstable globals.

It matters significantly for performance, among other things.
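
Concretely, a minimal sketch of that style, reusing the setup from the example above (f_explicit is just an illustrative name):

function f_explicit(x, y, model, p, st)
    y .= model(x, p, st)[1]
    return nothing
end

# model, p, and st are wrapped in Const: they are not differentiated,
# and nothing is captured from global scope
Enzyme.autodiff(Reverse, f_explicit,
                Duplicated(x, bx), Duplicated(y, ones(2)),
                Const(dudt2), Const(p), Const(st))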


Agreed. That was just part of the demo, to show that globals do work (dudt2 is just a function, though).

Thank you for your help!

I understand that computing the gradient vector is not a problem. What I’m missing is how to convert the gradient vector so that it can be used in update!() to update the parameters.
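
One possible pattern (not from this thread, so treat it as a sketch): flatten the parameters into a ComponentArray so the gradient shadow has the same shape as the parameters, then hand both to Optimisers.update!. This assumes Optimisers.jl and the setup from the earlier example; loss and dps are illustrative names.

using Enzyme, Lux, Optimisers, ComponentArrays, Random

rng = Random.default_rng()
Random.seed!(rng, 100)
model = Lux.Chain(x -> x .^ 3, Lux.Dense(2, 50, tanh), Lux.Dense(50, 2))
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)   # flat, mutable parameter vector

x = Float32[2.0, 2.0]

# scalar loss of the parameters; everything else is passed as Const
loss(ps, model, st, x) = sum(abs2, first(model(x, ps, st)))

dps = zero(ps)            # shadow: receives d(loss)/d(ps)
Enzyme.autodiff(Reverse, loss, Active,
                Duplicated(ps, dps), Const(model), Const(st), Const(x))

opt_state = Optimisers.setup(Adam(0.01f0), ps)
opt_state, ps = Optimisers.update!(opt_state, ps, dps)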

This doesn’t work on GPU. Did I do something wrong?

using Enzyme
using Lux, Random, LuxCUDA

rng = Random.default_rng()
Random.seed!(rng, 100)
dudt2 = Lux.Chain(x -> x .^ 3,
                  Lux.Dense(2, 50, tanh),
                  Lux.Dense(50, 2))
gpu_dev = gpu_device()
p, st = Lux.setup(rng, dudt2) .|> gpu_dev

function f(x::T, y::T) where T
    y .= dudt2(x, p, st)[1]
    return nothing
end

x  = [2.0f0, 2.0f0] |> gpu_dev   # primal input on the GPU
bx = [0.0f0, 0.0f0] |> gpu_dev   # shadow of x
y  = [0.0f0, 0.0f0] |> gpu_dev   # output buffer
ones32 = ones(Float32, 2) |> gpu_dev

Enzyme.autodiff(Reverse, f, Duplicated(x, bx), Duplicated(y, ones32))

File an issue?

Though also as a forewarning: differentiating host-side code that accesses device memory (e.g. sum(CuArray)) is not yet supported, but it is in progress (see FAQ · Enzyme.jl).
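
Differentiating code that runs inside a kernel does work, though. A minimal sketch, adapted from the CUDA example in the Enzyme.jl docs: wrap the kernel in a second kernel that calls Enzyme.autodiff_deferred.

using CUDA, Enzyme

# device kernel: squares each element of A in place
function mul_kernel(A)
    i = threadIdx().x
    if i <= length(A)
        A[i] *= A[i]
    end
    return nothing
end

# wrapper kernel that differentiates mul_kernel on the device
function grad_mul_kernel(A, dA)
    Enzyme.autodiff_deferred(Reverse, mul_kernel, Const, Duplicated(A, dA))
    return nothing
end

A  = CUDA.ones(64)
dA = similar(A)
dA .= 1
@cuda threads=length(A) grad_mul_kernel(A, dA)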
