# ReverseDiff not using custom derivative from ChainRules

Hi there,

I would like to use a custom derivative for a function: the cdf of a T distribution with respect to the data. I keep the parameter of this distribution fixed, so `NoTangent` for the parameter is enough for me here. Here is my current progress:

``````julia
using Distributions, ChainRulesCore, ForwardDiff, ReverseDiff
import Distributions: cdf, TDist
import ChainRulesCore: rrule

function mytargetfunction(data::AbstractVector)
    # Closure over `data`; ν is the argument that gets differentiated
    function target(ν::AbstractVector)
        dist = Distributions.TDist(1.0)
        _data = data ./ sum(ν)
        data_uniform = [cdf(dist, _data[iter]) for iter in eachindex(data)] # Line of error
        return sum(logpdf(dist, data_uniform[i]) for i in eachindex(data_uniform))
    end
end
data = randn(1000)
ν = [1., 2., 3., 4., 5.]
target = mytargetfunction(data)
target(ν)
ReverseDiff.gradient(target, ν) # ArgumentError: Converting an instance of ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Fl ....
``````

Now I believe ReverseDiff does not know how to differentiate `cdf(TDist(1.0), x)` with respect to `x`, so I want to write a custom rule:

``````julia
function ChainRulesCore.rrule(::typeof(cdf), d::T, x
) where {T<:Distributions.TDist}
    ∇cdf(x) = Distributions.pdf(d, x) # Gradient w.r.t. x
    val = cdf(d, x)
    function _pullback(Δy)
        # Only return differential w.r.t. x, keep d parameter as NoTangent
        return NoTangent(), NoTangent(), Δy * ∇cdf(x)
    end
    return val, _pullback
end

target = mytargetfunction(data)
target(ν)
ReverseDiff.gradient(target, ν) # ArgumentError: Converting an instance of ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Fl ....
``````

Unfortunately, I still get the same error, so ReverseDiff does not seem to use my custom rule. I checked the rrule on its own and it seems to work:

``````julia
_val = cdf(TDist(10.), 1.)
_deriv = pdf(TDist(10.), 1.)
_RDval, _RDpullback = ChainRulesCore.rrule(Distributions.cdf, TDist(10.), 1.0)
_, _, _RDderiv = _RDpullback(1.0)
isequal(_val, _RDval)     # true
isequal(_deriv, _RDderiv) # true
``````
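Since the check above compares the rule's output against `pdf` itself, a finite-difference comparison gives a more independent sanity check. Here is a minimal sketch, assuming the same rule as above; the names `rule_deriv`, `fd_deriv` and the step size `h` are only illustrative choices:

```julia
using Distributions, ChainRulesCore

# Same idea as the rule above: the derivative of the cdf w.r.t. x is the pdf.
function ChainRulesCore.rrule(::typeof(cdf), d::TDist, x::Real)
    val = cdf(d, x)
    cdf_pullback(Δy) = (NoTangent(), NoTangent(), Δy * pdf(d, x))
    return val, cdf_pullback
end

d, x = TDist(10.0), 1.0
_, pullback = ChainRulesCore.rrule(cdf, d, x)
_, _, rule_deriv = pullback(1.0)

# Central finite difference as a reference value (h is an arbitrary small step)
h = 1e-6
fd_deriv = (cdf(d, x + h) - cdf(d, x - h)) / (2h)

# Should agree up to finite-difference error
isapprox(rule_deriv, fd_deriv; rtol = 1e-6)
```

This only tests the rule in isolation; it says nothing about whether ReverseDiff actually picks it up.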

I would be grateful if someone could provide some guidance here. I have seen that one needs to opt in for ReverseDiff to use ChainRules, but unfortunately I don't know how to do that.

You need to use `ReverseDiff.@grad_from_chainrules`. So adding

``````julia
ReverseDiff.@grad_from_chainrules cdf(d, x::TrackedReal)
``````

should help.

Thanks a lot!

Is there a way to make this macro also work with types from outside ReverseDiff? For example:

``````julia
# LoadError: `@grad_from_chainrules` has to be applied to a function signature
ReverseDiff.@grad_from_chainrules cdf(d::D, x::R) where {D<:Distributions.TDist, R<:TrackedReal}
# UndefVarError: Distributions not defined
ReverseDiff.@grad_from_chainrules cdf(d::Distributions.TDist, x::TrackedReal)
``````

I guess I could otherwise make this work with closures or structs, but this would be more convenient. Just using `cdf(d, x::TrackedReal)` would be ambiguous with the methods defined in Distributions itself, since `TrackedReal` is a subtype of `Real`.
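For reference, the wrapper idea could look roughly like the sketch below. The names `DIST`, `tcdf` and `target` are made up for illustration, and it assumes a ReverseDiff version that provides `@grad_from_chainrules`. Because the wrapper takes only `x`, no Distributions type appears in the macro signature, and `tcdf(x::ReverseDiff.TrackedReal)` is strictly more specific than `tcdf(x::Real)`, so no ambiguity should arise:

```julia
using Distributions, ChainRulesCore, ReverseDiff

# Hypothetical wrapper: the distribution is fixed, only x is differentiated.
const DIST = TDist(1.0)
tcdf(x::Real) = cdf(DIST, x)

# Custom reverse rule for the wrapper: d/dx cdf(DIST, x) = pdf(DIST, x).
function ChainRulesCore.rrule(::typeof(tcdf), x::Real)
    y = tcdf(x)
    tcdf_pullback(Δy) = (NoTangent(), Δy * pdf(DIST, x))
    return y, tcdf_pullback
end

# Opt in: calls with a tracked scalar are routed through the rrule above.
ReverseDiff.@grad_from_chainrules tcdf(x::ReverseDiff.TrackedReal)

# Toy target, scale-invariant in ν like the normalisation in the original function
function target(ν::AbstractVector)
    x = ν ./ sum(ν)
    return sum(tcdf(x[i]) for i in eachindex(x))
end

ν = [1.0, 2.0, 3.0, 4.0, 5.0]
g = ReverseDiff.gradient(target, ν)
```

The downside is that the distribution is baked into the wrapper, so every parameter value needs its own wrapper (or a struct that carries it).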

I’m not exactly sure what causes the second error (maybe something along the lines of the macro expanding in the context of module `ReverseDiff`, which doesn’t have `Distributions` loaded). But try `Main.Distributions.TDist` instead and see if that works. (Of course, replace `Main` with whatever module you are working in.)

EDIT: If you are working in module `MyModule`, you could try `Main.MyModule.Distributions.TDist`.

EDIT: It appears this approach doesn’t work if you are creating a package (there are errors during precompilation about `MyModule` not being defined). I’m not sure what to do in this case.
