Hi, I have a complex loss that mutates arrays, which Zygote doesn't support. Is it possible to use Enzyme.jl with Flux.jl?
I'm also interested in the answer. I'd like to see an example of Enzyme usage with Flux / Lux.
You can write a custom rrule and use Enzyme to implement it.
Thank you for your answer. Could you please provide a basic example of a custom rule?
There is one in the docs for the latest version: Custom rules · Enzyme.jl
using Enzyme
using ComponentArrays, Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 100)

# Small Lux model
dudt2 = Lux.Chain(x -> x .^ 3,
                  Lux.Dense(2, 50, tanh),
                  Lux.Dense(50, 2))
p, st = Lux.setup(rng, dudt2)

x = [2.0, 2.0]
bx = [0.0, 0.0]          # shadow of x; the gradient accumulates here
y = [0.0, 0.0]

# Mutating form: writes the model output into y and returns nothing
function f(x::Array{Float64}, y::Array{Float64})
    y .= dudt2(x, p, st)[1]
    return nothing
end

# Reverse mode: the shadow of y (seeded with ones) is pulled back into bx
Enzyme.autodiff(Reverse, f, Duplicated(x, bx), Duplicated(y, ones(2)))

# Cross-check against Zygote on the non-mutating form
function f2(x::Array{Float64})
    dudt2(x, p, st)[1]
end

using Zygote
bx2 = Zygote.pullback(f2, x)[2](ones(2))[1]

bx
@show bx - bx2
#=
2-element Vector{Float64}:
 -9.992007221626409e-16
 -1.7763568394002505e-15
=#
Note: this runs on Enzyme's main branch, currently unreleased.
I would caution against this coding style, however, and strongly recommend passing dudt2, p, and st explicitly as (Const) arguments to autodiff rather than capturing them type-unstably as globals. It matters significantly for performance, among other things.
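Something along these lines, reusing x, bx, y, dudt2, p, and st from the demo above (an untested sketch; f_explicit is my name for it, not anything from the docs):

# Model, parameters, and state passed explicitly instead of captured globals
function f_explicit(x, y, model, ps, st)
    y .= model(x, ps, st)[1]
    return nothing
end

Enzyme.autodiff(Reverse, f_explicit,
                Duplicated(x, bx), Duplicated(y, ones(2)),
                Const(dudt2), Const(p), Const(st))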
Agreed. It was just part of the demo to show that globals do work (dudt2 is just a function, though).
Thank you for your help!
I understand that computing the gradient vector is not a problem. What I'm missing is how to convert the gradient vector so it can be used with update!() to update the parameters.
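Is it something along these lines? An untested sketch on my part, assuming Optimisers.jl and Enzyme's make_zero, and reusing x, dudt2, p, st from the example above:

using Optimisers

# Scalar loss over the parameters
loss(x, model, ps, st) = sum(abs2, model(x, ps, st)[1])

dp = Enzyme.make_zero(p)   # zero shadow with the same structure as p
Enzyme.autodiff(Reverse, loss, Active,
                Const(x), Const(dudt2), Duplicated(p, dp), Const(st))

# dp mirrors p's structure, so it should plug into update! directly
opt_state = Optimisers.setup(Optimisers.Adam(0.01), p)
opt_state, p = Optimisers.update!(opt_state, p, dp)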
This doesn't work on GPU. Did I do something wrong?
using Enzyme
using Lux, Random, LuxCUDA

rng = Random.default_rng()
Random.seed!(rng, 100)

dudt2 = Lux.Chain(x -> x .^ 3,
                  Lux.Dense(2, 50, tanh),
                  Lux.Dense(50, 2))

gpu_dev = gpu_device()
p, st = Lux.setup(rng, dudt2) .|> gpu_dev

function f(x::T, y::T) where T
    y .= dudt2(x, p, st)[1]
    return nothing
end

x = [2.0f0, 2.0f0] |> gpu_dev
bx = [0.0f0, 0.0f0] |> gpu_dev
y = [0.0f0, 0.0f0] |> gpu_dev
ones32 = ones(Float32, 2) |> gpu_dev

Enzyme.autodiff(Reverse, f, Duplicated(x, bx), Duplicated(y, ones32))
File an issue?
Though also a forewarning: differentiating host-side code that accesses device memory (e.g. sum(CuArray)) is not yet supported, but it is in progress (see FAQ · Enzyme.jl).
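For what it's worth, differentiating inside a device kernel does work; here is a sketch adapted from the kernel-differentiation pattern in Enzyme's docs (not this thread's code, and exact signatures may vary by version):

using CUDA, Enzyme

# Squares A in place, one element per thread
function mul_kernel(A)
    i = threadIdx().x
    if i <= length(A)
        A[i] *= A[i]
    end
    return nothing
end

# Differentiate the kernel body on-device with autodiff_deferred
function grad_mul_kernel(A, dA)
    Enzyme.autodiff_deferred(Reverse, mul_kernel, Const, Duplicated(A, dA))
    return nothing
end

A = CUDA.ones(64)
dA = CUDA.ones(64)   # adjoint seed
@cuda threads=length(A) grad_mul_kernel(A, dA)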