Optimisers.jl cannot be used with Zygote.jl's implicit gradients

I’m having the following issue with using Zygote with Optimisers.jl. Here is an MRE:

using Flux, Zygote

m = Chain(Dense(10, 5, relu), Dense(5, 2))

x = rand(Float32, 10)

gs = gradient(() -> sum(m(x)), params(m))

opt = ADAM();

Flux.Optimise.update!(opt, params(m), gs)

which returns the error

ERROR: Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] base(dx::Zygote.Grads)
   @ Flux ~/.julia/packages/Flux/n3cOc/src/Flux.jl:20
 [3] (::Optimisers.var"#13#15"{Params{Zygote.Buffer{Any, Vector{Any}}}})(x̄::Zygote.Grads)
   @ Optimisers ~/.julia/packages/Optimisers/1x8gl/src/interface.jl:112
 [4] map
   @ ./tuple.jl:273 [inlined]
 [5] _grads!(dict::IdDict{Optimisers.Leaf, Any}, tree::Optimisers.Adam{Float32}, x::Params{Zygote.Buffer{Any, Vector{Any}}}, x̄s::Zygote.Grads)
   @ Optimisers ~/.julia/packages/Optimisers/1x8gl/src/interface.jl:112
 [6] update!(::Optimisers.Adam{Float32}, ::Params{Zygote.Buffer{Any, Vector{Any}}}, ::Zygote.Grads)
   @ Optimisers ~/.julia/packages/Optimisers/1x8gl/src/interface.jl:70
 [7] top-level scope
   @ REPL[106]:1

Optimisers.jl no longer supports implicit gradients (`Params`/`Grads`), as the error says. The Flux docs show how to use explicit gradients instead. For your example, the following should work:

gs, _ = gradient((model, inp) -> sum(model(inp)), m, x)

opt = Optimisers.Adam()
opt_state = Optimisers.setup(opt, m)
# Do an update step; update returns the new state and the updated model
opt_state, m = Optimisers.update(opt_state, m, gs)
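To make the difference concrete, here is a minimal training-loop sketch in the explicit style (toy data is made up; I use the in-place `Optimisers.update!`, which mutates the state and the model's arrays, to avoid reassigning globals inside the loop):

```julia
using Flux, Optimisers

m = Chain(Dense(10, 5, relu), Dense(5, 2))

# Set up the optimiser state once, outside the loop
opt_state = Optimisers.setup(Optimisers.Adam(), m)

# Hypothetical toy inputs; replace with your own data
data = [rand(Float32, 10) for _ in 1:16]

for x in data
    # Explicit gradient with respect to the model itself
    gs, _ = Flux.gradient((model, inp) -> sum(model(inp)), m, x)
    # In-place update: mutates opt_state and m's parameter arrays
    Optimisers.update!(opt_state, m, gs)
end
```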

I see, then why are there so many examples of implicit gradients in the Flux documentation if optimisation with such gradients is no longer supported?


I don’t know when the docs will be updated. On the other hand, the optimisers from Flux itself (i.e. without loading Optimisers) still work for me with implicit gradients:

gs = Zygote.gradient(() -> sum(m(x)), Flux.params(m))
opt = Flux.Optimise.ADAM();
Flux.Optimise.update!(opt, Flux.params(m), gs)
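Note that the implicit-params API is deprecated, so this may stop working in a future Flux release. If I'm not mistaken, newer Flux versions (0.13.9 and later) also provide `Flux.setup`, which wraps the Optimisers.jl rules so you can stay within Flux while using explicit gradients:

```julia
using Flux  # Flux ≥ 0.13.9, where Flux.setup was introduced

m = Chain(Dense(10, 5, relu), Dense(5, 2))
x = rand(Float32, 10)

# Flux.setup wraps an Optimisers.jl rule into per-model state
opt_state = Flux.setup(Adam(), m)

# Explicit gradient w.r.t. the model, then an in-place update
gs = Flux.gradient(model -> sum(model(x)), m)[1]
Flux.update!(opt_state, m, gs)
```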