I want to exclude some functions in my model code from gradient calculations using `@ignore_derivatives` from ChainRulesCore. It works with Zygote, but not with ReverseDiff. Here’s an MWE:
```julia
using Zygote, ReverseDiff
import ChainRulesCore: @ignore_derivatives

# function to ignore in gradient calculation but not in forward pass
g = x -> x^2

# main function
function f(x)
    x = x'*x
    @ignore_derivatives x = g(x)
    return x
end

inp = [2.];
@show Zygote.gradient(f, inp)[1];
@show ReverseDiff.gradient(f, inp);
```
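This yields

```
(Zygote.gradient(f, inp))[1] = [4.0]
ReverseDiff.gradient(f, inp) = [32.0]
```

which is as intended with Zygote (only the derivative of `x'*x`, i.e. 2x = 4), but not with ReverseDiff, which differentiates through `g` as well (the derivative of (x'x)^2, i.e. 4x^3 = 32 at x = 2). How could I make this work in ReverseDiff?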
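One possible workaround, sketched here under the assumption that the intended behaviour is the one Zygote produced (forward value `g(x'*x)`, gradient of `x'*x` alone). `@ignore_derivatives ex` expands to `ignore_derivatives(() -> ex)`, which ChainRulesCore marks as non-differentiable; Zygote honours that rule, but ReverseDiff's operator-overloading tape does not consult ChainRules rules by default, so it just runs the closure and records every operation on tracked numbers, including `g`. The sketch below strips the tracking explicitly with `ReverseDiff.value` (which returns the plain underlying value, and is the identity on untracked numbers) and adds `g`'s effect back as a constant offset, a "straight-through" trick. The names `straight_through` and `f_st` are hypothetical, chosen only for illustration.

```julia
using Zygote, ReverseDiff
import ChainRulesCore: @ignore_derivatives

g = x -> x^2

# Hypothetical helper: the forward pass returns g(y), while the backward pass
# treats g as the identity (derivative 1 with respect to y).
function straight_through(g, y)
    # ReverseDiff.value removes ReverseDiff's tracking (identity on plain numbers);
    # @ignore_derivatives hides the same offset from Zygote.
    offset = @ignore_derivatives(g(ReverseDiff.value(y)) - ReverseDiff.value(y))
    return y + offset  # value: y + g(y) - y == g(y); gradient flows only through y
end

# Hypothetical rewrite of f using the helper above
function f_st(x)
    y = x' * x
    return straight_through(g, y)
end

inp = [2.0]
@show f_st(inp)
@show Zygote.gradient(f_st, inp)[1]
@show ReverseDiff.gradient(f_st, inp)
```

Because `ReverseDiff.value` hands back a plain number, the offset is a constant as far as the tape is concerned. That is fine for `ReverseDiff.gradient(f, x)`, which records a fresh tape on every call, but with a pre-recorded (and especially a compiled) `ReverseDiff.GradientTape` the offset would stay frozen at its value from recording time.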