I’m having trouble finding documentation to get ReverseDiff to use a custom adjoint defined using ChainRulesCore.rrule. MWE:
import Zygote, ReverseDiff
import ChainRulesCore
function f(x)
    x' * x
end
function ChainRulesCore.rrule(::typeof(f), x)
    y = f(x)
    function f_pullback(y_bar)
        println("Calling custom pullback")
        return ChainRulesCore.NoTangent(), y_bar * 2 * x
    end
    return y, f_pullback
end
Running Zygote.gradient(f, ones(3)) prints the expected "Calling custom pullback" message, while ReverseDiff.gradient(f, ones(3)) does not.
I must be missing something; where can I find an example?
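A minimal sketch of what I would expect to work, assuming a ReverseDiff release that ships the `ReverseDiff.@grad_from_chainrules` macro (ReverseDiff does not pick up `rrule`s automatically; the macro opts a specific tracked signature in). `PULLBACK_CALLS` is a hypothetical counter that replaces the printed message so the check is easy to assert on.

```julia
import ReverseDiff, ChainRulesCore

# Hypothetical counter: incremented each time the custom pullback runs.
const PULLBACK_CALLS = Ref(0)

f(x) = x' * x

function ChainRulesCore.rrule(::typeof(f), x)
    y = f(x)
    function f_pullback(y_bar)
        PULLBACK_CALLS[] += 1
        return ChainRulesCore.NoTangent(), y_bar * 2 * x
    end
    return y, f_pullback
end

# Assumption: this macro exists in the installed ReverseDiff version.
# It adds a method f(::TrackedArray) that calls the rrule above and records
# its pullback on the tape, so the reverse pass uses the custom adjoint.
ReverseDiff.@grad_from_chainrules f(x::ReverseDiff.TrackedArray)

g = ReverseDiff.gradient(f, ones(3))  # gradient of x'x is 2x
```

The macro only covers the argument types listed in the signature, so a call such as `f(x::TrackedArray, y)` with more arguments would need its own line.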