ReverseDiff not using custom derivative from ChainRules

Hi there,

I would like to use a custom derivative for a function: the cdf of a T distribution with respect to the data. I keep the parameter of this distribution fixed, so returning NoTangent for the parameter is enough for me here. Here is my current progress:

using Distributions, ChainRulesCore, ForwardDiff, ReverseDiff
import Distributions: cdf, TDist
import ChainRulesCore: rrule

function mytargetfunction(data::AbstractVector)
    function obtaingradient(ν::AbstractVector{R}) where {R<:Real}
        dist = Distributions.TDist(1.0)
        _data = data ./ sum(ν)
        data_uniform = [cdf(dist, _data[iter]) for iter in eachindex(data)] #Line of error
        return sum( logpdf(dist, data_uniform[i]) for i in eachindex(data_uniform) )
    end
end
data = randn(1000)
ν = [1., 2., 3., 4., 5.]
target = mytargetfunction(data)
target(ν)
ReverseDiff.gradient(target, ν) # ArgumentError: Converting an instance of ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Fl ....

I believe ReverseDiff does not know how to take the derivative of cdf(TDist(1.0), x) with respect to x, so I want to write a custom rule:

function ChainRulesCore.rrule(::typeof(cdf), d::T, x) where {T<:Distributions.TDist}
    ∇cdf(x) = Distributions.pdf(d, x) # Gradient w.r.t. x
    val  = cdf(d, x)
    function _pullback(Δy)
        # Only return differential w.r.t. x, keep d parameter as NoTangent
        return NoTangent(), NoTangent(), Δy * ∇cdf(x)
    end
    return val, _pullback
end

target = mytargetfunction(data)
target(ν)
ReverseDiff.gradient(target, ν) # ArgumentError: Converting an instance of ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Fl ....

Unfortunately, I still get the same error, so ReverseDiff does not seem to use my custom rule. I checked the rrule on its own, and it seems to work:

_val = cdf(TDist(10.), 1.)
_deriv = pdf(TDist(10.), 1.)
_RDval, _RDpullback = ChainRulesCore.rrule(Distributions.cdf, TDist(10.), 1.0)
_, _, _RDderiv = _RDpullback(1.0)
isequal(_val, _RDval) #true
isequal(_deriv, _RDderiv) #true
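
For reference, the rule and this check both rest on the fact that the derivative of a CDF with respect to its argument is the PDF. Written out for a T distribution with fixed degrees of freedom ν (the symbols F_ν, f_ν, x̄ and ȳ are notation introduced here for illustration, not names from the code above), the pullback scales the incoming sensitivity ȳ by the density:

```latex
\frac{\partial}{\partial x} F_\nu(x) = f_\nu(x),
\qquad
\bar{x} = \bar{y}\, f_\nu(x).
% Special case \nu = 1 (standard Cauchy), as in TDist(1.0):
F_1(x) = \frac{1}{2} + \frac{\arctan x}{\pi},
\qquad
f_1(x) = \frac{1}{\pi\,(1 + x^2)}.
```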

I would be grateful if someone could provide some guidance here. I have seen that one needs to opt in for ReverseDiff to use ChainRules, but unfortunately I don't know how to do that.

You need to use ReverseDiff.@grad_from_chainrules. So adding

using ReverseDiff: TrackedReal # bring the tracked type into scope
ReverseDiff.@grad_from_chainrules cdf(d, x::TrackedReal)

should help.
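
To see the opt-in working without involving Distributions, here is a minimal sketch on a toy function. The names toycdf and loss are hypothetical and introduced only for this illustration. toycdf is the standard Cauchy CDF, which matches the TDist(1.0) case, and the Float64 conversion is there on purpose so that ReverseDiff cannot differentiate the function without the imported rule.

```julia
using ReverseDiff, ChainRulesCore
using ReverseDiff: TrackedReal

# Hypothetical stand-in for cdf(TDist(1.0), x): the standard Cauchy CDF.
# The Float64 conversion is deliberate, so that ReverseDiff cannot trace
# through this method with a TrackedReal on its own.
toycdf(x::Real) = 0.5 + atan(Float64(x)) / π

# Custom reverse rule: d/dx toycdf(x) = 1 / (π * (1 + x^2)).
function ChainRulesCore.rrule(::typeof(toycdf), x::Real)
    y = toycdf(x)
    toycdf_pullback(Δy) = (NoTangent(), Δy / (π * (1 + x^2)))
    return y, toycdf_pullback
end

# Opt in: add a toycdf(::TrackedReal) method that goes through the rrule.
ReverseDiff.@grad_from_chainrules toycdf(x::TrackedReal)

# Indexing a tracked array yields TrackedReal elements, as in the original post.
loss(v) = sum(toycdf(v[i]) for i in eachindex(v))

g = ReverseDiff.gradient(loss, [0.0, 1.0])
```

If the rule is picked up, g should agree with the analytic derivative 1 / (π * (1 + x^2)) evaluated at each input.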

Thanks a lot!

Is there a way to make this macro work with types from outside ReverseDiff as well? For example:

ReverseDiff.@grad_from_chainrules cdf(d::D, x::R) where {D<:Distributions.TDist, R<:TrackedReal} #LoadError: `@grad_from_chainrules` has to be applied to a function signature
ReverseDiff.@grad_from_chainrules cdf(d::Distributions.TDist, x::TrackedReal) #UndefVarError: Distributions not defined

I guess I could otherwise make this work with closures or structs, but this would be more convenient. Just using cdf(d, x::TrackedReal) would be ambiguous with the methods defined in Distributions itself, since TrackedReal is a subtype of Real.

I’m not exactly sure what causes the second error. Maybe the macro expands in the context of module ReverseDiff, which doesn’t have Distributions loaded. Try Main.Distributions.TDist instead and see if that works. (Of course, replace Main with whatever module you are working in.)

EDIT: If you are working in module MyModule, you could try Main.MyModule.Distributions.TDist.

EDIT: It appears this approach doesn’t work if you are creating a package (there are errors during precompilation about MyModule not being defined). I’m not sure what to do in this case.
