Hi, I have a complex loss function that mutates arrays, which Zygote does not support.
Is it possible to use Enzyme.jl with Flux.jl?
I’m also interested in the answer. I’d like to see an example of Enzyme usage with Flux / Lux.
You can write a custom rrule and use Enzyme to implement it.
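A minimal sketch of that pattern, under the assumption that the rule is registered through ChainRulesCore so Zygote picks it up; sumsq! and sumsq are hypothetical stand-ins for a mutating loss kernel:

```julia
using ChainRulesCore, Enzyme, Zygote

# Hypothetical mutating kernel that Zygote cannot differentiate directly.
function sumsq!(out::Vector{Float64}, x::Vector{Float64})
    s = 0.0
    for xi in x
        s += xi^2
    end
    out[1] = s
    return nothing
end

sumsq(x) = (out = [0.0]; sumsq!(out, x); out[1])

# Custom rrule whose pullback runs Enzyme in reverse mode over the kernel.
function ChainRulesCore.rrule(::typeof(sumsq), x::Vector{Float64})
    y = sumsq(x)
    function sumsq_pullback(ȳ)
        bx = zero(x)                    # shadow of x: gradient accumulates here
        out, bout = [0.0], [float(ȳ)]   # seed the output adjoint with ȳ
        Enzyme.autodiff(Reverse, sumsq!, Duplicated(out, bout), Duplicated(x, bx))
        return NoTangent(), bx
    end
    return y, sumsq_pullback
end

Zygote.gradient(sumsq, [1.0, 2.0])[1]   # -> [2.0, 4.0], via the Enzyme-backed rule
```

With the rule in place, any Zygote-differentiated code that calls sumsq transparently uses Enzyme for that piece.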
Thank you for your answer.
Could you please provide a basic example of a custom rule?
using Enzyme
using ComponentArrays, Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 100)

dudt2 = Lux.Chain(x -> x.^3,
                  Lux.Dense(2, 50, tanh),
                  Lux.Dense(50, 2))
p, st = Lux.setup(rng, dudt2)

x = [2.0, 2.0]
bx = [0.0, 0.0]  # shadow of x: Enzyme accumulates the gradient here
y = [0.0, 0.0]

# Mutating primal: writes the network output into y, which Zygote cannot handle.
function f(x::Array{Float64}, y::Array{Float64})
    y .= dudt2(x, p, st)[1]
    return nothing
end

# Reverse mode: the output adjoint is seeded with ones(2); bx receives the gradient.
Enzyme.autodiff(Reverse, f, Duplicated(x, bx), Duplicated(y, ones(2)))
# Non-mutating version of f for a cross-check against Zygote.
function f2(x::Array{Float64})
    dudt2(x, p, st)[1]
end

using Zygote
bx2 = Zygote.pullback(f2, x)[2](ones(2))[1]
@show bx - bx2
#=
2-element Vector{Float64}:
-9.992007221626409e-16
-1.7763568394002505e-15
=#
This is on main, currently unreleased.
I would caution against this coding style, however, and strongly recommend passing dudt2, p, and st explicitly as (Const) parameters to autodiff rather than capturing them type-unstably as globals.
This matters for performance, among other things.
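A sketch of that explicit-parameter style, assuming the same network setup as in the example above; dudt2, p, and st are passed in as Const so Enzyme treats them as non-differentiated constants:

```julia
using Enzyme, Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 100)
dudt2 = Lux.Chain(x -> x.^3, Lux.Dense(2, 50, tanh), Lux.Dense(50, 2))
p, st = Lux.setup(rng, dudt2)

# Everything the function needs is an argument: no type-unstable global capture.
function f(x::Vector{Float64}, y::Vector{Float64}, model, p, st)
    y .= model(x, p, st)[1]
    return nothing
end

x  = [2.0, 2.0]
bx = zeros(2)    # gradient with respect to x lands here
y  = zeros(2)

Enzyme.autodiff(Reverse, f,
                Duplicated(x, bx), Duplicated(y, ones(2)),
                Const(dudt2), Const(p), Const(st))
```

Because f closes over nothing, the call is type stable and Enzyme knows exactly which arguments are active (the Duplicated ones) and which are constants.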
Agreed. It was just part of the demo that globals do work (dudt2 is just a function, though).
Thank you for your help!
I understand that computing the gradient vector is not a problem. What I am missing is how to convert the gradient vector into a form I can pass to update!() for the parameter update.