Custom ChainRulesCore rrule with ReverseDiff

I’m having trouble finding documentation on how to make ReverseDiff use a custom adjoint defined with `ChainRulesCore.rrule`. MWE:

```julia
import Zygote, ReverseDiff
import ChainRulesCore
function f(x)
  x'*x
end
function ChainRulesCore.rrule(::typeof(f), x)
  y = f(x)
  function f_pullback(y_bar)
    @show "Calling custom pullback"
    return ChainRulesCore.NoTangent(), y_bar*2*x
  end
  return y, f_pullback
end
```

Running `Zygote.gradient(f, ones(3))` prints the expected "Calling custom pullback" message, while `ReverseDiff.gradient(f, ones(3))` does not.
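
One way to rule out a problem with the rule itself (as opposed to how ReverseDiff picks it up) is to call `rrule` directly and seed the pullback with `1.0`; for `f(x) = x'*x` the cotangent for `x` should be `2x`. A minimal sketch, using an arbitrary test vector:

```julia
import ChainRulesCore

f(x) = x' * x

function ChainRulesCore.rrule(::typeof(f), x)
    y = f(x)
    # ybar is the cotangent of the scalar output; d(x'x)/dx = 2x
    f_pullback(ybar) = (ChainRulesCore.NoTangent(), ybar * 2 * x)
    return y, f_pullback
end

x = [1.0, 2.0, 3.0]                 # arbitrary test input
y, pb = ChainRulesCore.rrule(f, x)  # primal value and pullback closure
fbar, xbar = pb(1.0)                # seed the output cotangent with 1.0

println(y)     # 14.0
println(xbar)  # [2.0, 4.0, 6.0]
```

If this returns `2x`, the rule is fine and the issue lies in whether the AD package calls it at all.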

I must be missing something; where can I find an example?

By default, ReverseDiff ignores ChainRulesCore `rrule`s; each rule must be opted into explicitly, per method signature. Unfortunately this isn’t documented yet, but there’s an open PR that adds documentation for it.

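Here’s how to opt in (because the MWE uses `import ReverseDiff` rather than `using ReverseDiff`, the type has to be written as `ReverseDiff.TrackedArray`):

```
julia> ReverseDiff.gradient(f, ones(3))
3-element Vector{Float64}:
 2.0
 2.0
 2.0

julia> ReverseDiff.@grad_from_chainrules f(x::ReverseDiff.TrackedArray)

julia> ReverseDiff.gradient(f, ones(3))
"Calling custom pullback" = "Calling custom pullback"
3-element Vector{Float64}:
 2.0
 2.0
 2.0
```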
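
The opt-in is per method signature: a function with several array arguments needs one `@grad_from_chainrules` declaration for each combination of tracked (`ReverseDiff.TrackedArray`) and untracked arguments that should use the custom rule, and any combination left undeclared is differentiated by ReverseDiff's own tracing instead. A sketch of this, assuming a hypothetical two-argument function `g` and a hypothetical call counter `NCALLS` used only to show that the custom pullback ran:

```julia
import ChainRulesCore, ReverseDiff

g(x, y) = x' * y  # hypothetical example function

const NCALLS = Ref(0)  # hypothetical counter: how often the custom pullback runs

function ChainRulesCore.rrule(::typeof(g), x, y)
    function g_pullback(zbar)
        NCALLS[] += 1
        # d(x'y)/dx = y and d(x'y)/dy = x
        return ChainRulesCore.NoTangent(), zbar * y, zbar * x
    end
    return g(x, y), g_pullback
end

# Opt in twice: both arguments tracked, and only the first argument tracked
ReverseDiff.@grad_from_chainrules g(x::ReverseDiff.TrackedArray, y::ReverseDiff.TrackedArray)
ReverseDiff.@grad_from_chainrules g(x::ReverseDiff.TrackedArray, y::Array)

x0 = [1.0, 2.0, 3.0]
y0 = [4.0, 5.0, 6.0]

gx, gy = ReverseDiff.gradient(g, (x0, y0))         # both inputs tracked
gx_only = ReverseDiff.gradient(x -> g(x, y0), x0)  # only x tracked, y0 a plain Array

println(gx)        # [4.0, 5.0, 6.0]
println(gy)        # [1.0, 2.0, 3.0]
println(gx_only)   # [4.0, 5.0, 6.0]
println(NCALLS[])  # 2
```

The tuple form of `ReverseDiff.gradient` tracks both inputs and so dispatches to the first declaration; the closure only tracks `x`, so it relies on the second one. Without that second declaration the closure would still return the right gradient, just without calling the custom pullback.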

Brilliant. Many thanks!