Hi there,
I would like to use a custom derivative for a function: the cdf of a T distribution with respect to the data. I keep the parameter of this distribution fixed, so NoTangent for the parameter is enough for me here. Here is my current progress:
using Distributions, ChainRulesCore, ForwardDiff, ReverseDiff
import Distributions: cdf, TDist
import ChainRulesCore: rrule
function mytargetfunction(data::AbstractVector)
    function obtaingradient(ν::AbstractVector{R}) where {R<:Real}
        dist = Distributions.TDist(1.0)
        _data = data ./ sum(ν)
        data_uniform = [cdf(dist, _data[iter]) for iter in eachindex(data)] #Line of error
        return sum( logpdf(dist, data_uniform[i]) for i in eachindex(data_uniform) )
    end
end
data = randn(1000)
ν = [1., 2., 3., 4., 5.]
target = mytargetfunction(data)
target(ν)
ReverseDiff.gradient(target, ν) # ArgumentError: Converting an instance of ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Fl ....
I believe ReverseDiff does not know how to take the derivative of cdf(TDist(1.0), x) with respect to x, so I want to write a custom rule:
function ChainRulesCore.rrule(::typeof(cdf), d::T, x) where {T<:Distributions.TDist}
    ∇cdf(x) = Distributions.pdf(d, x) # Gradient w.r.t. x
    val = cdf(d, x)
    function _pullback(Δy)
        # Only return differential w.r.t. x, keep d parameter as NoTangent
        return NoTangent(), NoTangent(), Δy * ∇cdf(x)
    end
    return val, _pullback
end
target = mytargetfunction(data)
target(ν)
ReverseDiff.gradient(target, ν) # ArgumentError: Converting an instance of ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Fl ....
Unfortunately, I still get the same error, so ReverseDiff does not seem to use my custom rule. I checked the rrule on its own, and it seems to work:
_val = cdf(TDist(10.), 1.)
_deriv = pdf(TDist(10.), 1.)
_RDval, _RDpullback = ChainRulesCore.rrule(Distributions.cdf, TDist(10.), 1.0)
_, _, _RDderiv = _RDpullback(1.0)
isequal(_val, _RDval) #true
isequal(_deriv, _RDderiv) #true
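The check above compares the rule against pdf, which the rule itself calls, so it cannot catch a wrong derivative. A more independent check could compare the pullback against a central finite difference of cdf. This is only a sketch: the step size `h` and the tolerance are values picked by hand for illustration, and the rule is restated (with `x::Real`) so the snippet runs on its own.

```julia
using Distributions, ChainRulesCore

# Same rule as above, restated so this snippet is self-contained.
function ChainRulesCore.rrule(::typeof(cdf), d::TDist, x::Real)
    val = cdf(d, x)
    # Pullback: no tangent for the function or the distribution, pdf for x.
    pullback(Δy) = (NoTangent(), NoTangent(), Δy * pdf(d, x))
    return val, pullback
end

d = TDist(10.0)
x = 1.0
h = 1e-6  # finite-difference step (assumed, not tuned)

# Derivative according to the custom rule.
_, pb = ChainRulesCore.rrule(cdf, d, x)
_, _, dx = pb(1.0)

# Central finite difference of the cdf at x, computed without the rule.
fd = (cdf(d, x + h) - cdf(d, x - h)) / (2h)

isapprox(dx, fd; rtol = 1e-6)
```

If the final line evaluates to `false`, the rule and the numerical derivative disagree and the pullback is worth a second look.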
I would be grateful if someone could provide some guidance here. I have seen that one needs to opt in for ReverseDiff to use ChainRules, but unfortunately I don't know how to do that.
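For reference, ReverseDiff provides a `@grad_from_chainrules` macro for this kind of opt-in. The following is only a sketch of how it might be applied to a rule like the one above. It assumes that the argument to differentiate is annotated as `ReverseDiff.TrackedReal` in the macro signature and that the function name has to be qualified as `Distributions.cdf`. The small test function `g` is hypothetical and not part of my actual target, and I have not confirmed that this resolves the error in the full example.

```julia
using Distributions, ChainRulesCore, ReverseDiff

# Custom rule, restated so this snippet is self-contained.
function ChainRulesCore.rrule(::typeof(cdf), d::TDist, x::Real)
    val = cdf(d, x)
    pullback(Δy) = (NoTangent(), NoTangent(), Δy * pdf(d, x))
    return val, pullback
end

# Opt in: ask ReverseDiff to use the rrule whenever x is a tracked scalar.
# The distribution d is left untracked, matching the NoTangent in the rule.
ReverseDiff.@grad_from_chainrules Distributions.cdf(d::TDist, x::ReverseDiff.TrackedReal)

# Hypothetical scalar test function: the cdf evaluated at the sum of the inputs.
g(v) = cdf(TDist(1.0), sum(v))

v = [0.1, 0.2]
grad = ReverseDiff.gradient(g, v)

# By the chain rule, each component should equal pdf(TDist(1.0), sum(v)).
all(isapprox.(grad, pdf(TDist(1.0), sum(v))))
```

The macro call has to come after the rrule is defined and before the gradient is taken.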