Issues with AD of Hypergeometric function in narrow parameter band

Hello—

Using AD with the Gauss 2F1 function has come up in some work I am doing. I am incredibly impressed by the performance and functionality of HypergeometricFunctions.jl, but I’m having a strange issue: for only a somewhat narrow band of values of the second argument, I’m getting a stack overflow error. I’m having a lot of trouble narrowing down where in the source I should be looking, and I’m hoping somebody can provide advice.

Here is an MWE, minus the unicode, which I can’t get to work here:

using HypergeometricFunctions, ForwardDiff
fun(arg, v) = _\_2 F \_1 (0.5, 0.5*(v+1), 1.5, -(arg^2)/v)
ForwardDiff.derivative(v->fun(2.4, v), 1.825) # error
ForwardDiff.derivative(v->fun(2.4, v), 1.8)    # works fine

The error is a stack overflow on unsafe_gamma on a line which defines this:

unsafe_gamma(x::Real) = unsafe_gamma(float(x))

and then this function gets called from this one:

G(z::Number, ϵ::Number) = ϵ == 0 ? digamma(z)/unsafe_gamma(z) : (inv(unsafe_gamma(z))-inv(unsafe_gamma(z+ϵ)))/ϵ

From some print debugging, it is the call to inv(unsafe_gamma(z+ϵ)) that gets caught in an infinite loop.

Here’s where I’m stuck: unsafe_gamma is defined for duals, using DualNumbers. And when I print ϵ, I get a dual. But that branch of the dispatch doesn’t seem to be getting called, and isreal(ϵ) even seems to return true, which is what I think is causing the issue. I’m pretty stumped about what could be causing that. Can somebody help me?

Thank you so much in advance for reading.

Best to post an issue on HypergeometricFunctions.jl so Mikael Slevinsky sees it. Please include the full error message.

Note that Dual is not a subtype of Real. Are you sure unsafe_gamma is defined in DualNumbers.jl?

Oof, right, that would make more sense. Thanks for the quick response and suggestion.

And no, sorry: unsafe_gamma is defined with several dispatch branches in HypergeometricFunctions. It looks like it depends on DualNumbers for compat, which I presume is why there is a dispatch branch for unsafe_gamma(z::Dual). The isreal returning true on ϵ definitely has me scratching my head, but my print debugging had a check of the form "if isreal(z) && isreal(z+ϵ)" followed by "@show ...", and it did trigger, so something is afoot here.

I’ll prepare an issue for the github repo. Thanks again for the immediate response!

For posterity, here’s the github issue: https://github.com/JuliaMath/HypergeometricFunctions.jl/issues/27.

I think they are:

julia> ForwardDiff.Dual <: Real
true

I am unfamiliar with the _\_ construct, but I can’t run your MWE:

julia> using HypergeometricFunctions, ForwardDiff

julia> fun(arg, v) = _\_2 F \_1 (0.5, 0.5*(v+1), 1.5, -(arg^2)/v)
ERROR: syntax: extra token "F" after end of expression
Stacktrace:
 [1] top-level scope at none:1

I was talking about DualNumbers.Dual, ForwardDiff.Dual is different
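
To make the distinction concrete, here is a small sketch, assuming both DualNumbers.jl and ForwardDiff.jl are installed (the variable names are only for illustration). It checks which of the two Dual types is a Real, and whether float changes the type of a ForwardDiff dual. If it does not, a fallback of the form unsafe_gamma(x::Real) = unsafe_gamma(float(x)) would keep dispatching to itself, which would be consistent with the reported stack overflow.

```julia
import DualNumbers, ForwardDiff

# The two packages define unrelated types that happen to share the name `Dual`.
fd = ForwardDiff.Dual(1.5, 1.0)   # value 1.5 with a single partial of 1.0
dn = DualNumbers.Dual(1.5, 1.0)   # value 1.5 with dual part 1.0

fd_is_real = fd isa Real          # is a ForwardDiff dual a Real?
dn_is_real = dn isa Real          # is a DualNumbers dual a Real?

# If float() hands back the same type, a method written as
# f(x::Real) = f(float(x)) would call itself again with no progress.
same_type_after_float = typeof(float(fd)) == typeof(fd)

@show fd_is_real dn_is_real isreal(fd) same_type_after_float
```

A method written for DualNumbers.Dual is never selected for a ForwardDiff.Dual, so the ForwardDiff dual would fall through to whatever Real method exists.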

Hey Tamas—thanks for taking a look at this.

Apologies for the broken minimal example. The _\_ construct is me not figuring out how to get the unicode inputs to work. But it occurs to me that I can copy-paste from the github source, so here it is correctly rendered:

using HypergeometricFunctions, ForwardDiff
fun(arg, v) = _₂F₁(0.5, 0.5*(v+1), 1.5, -(arg^2)/v)
ForwardDiff.derivative(v->fun(2.4, v), 1.825) # error
ForwardDiff.derivative(v->fun(2.4, v), 1.8)   # works fine

Thanks to both of you for the comments about DualNumbers vs. ForwardDiff. If I add a dispatch branch unsafe_gamma(x::ForwardDiff.Dual)=... I can confirm that it gets triggered and the stack overflow goes away. I am now having trouble finding the ForwardDiff equivalent of dualpart, but this is definitely a large step towards a fix.

If anybody happens to understand ForwardDiff, could you check this for correctness? The original code did this:

unsafe_gamma(z::Dual) = (r = realpart(z);w = unsafe_gamma(r); dual(w, w*digamma(r)*dualpart(z)))

which I think I have translated correctly to this:

function unsafe_gamma(z::ForwardDiff.Dual{T,V,N}) where{T,V,N}
  r  = z.value
  du = ForwardDiff.partials(T, z, 1)
  w  = unsafe_gamma(r)
  ForwardDiff.Dual{T}(w, w*digamma(r)*du)
end
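
One thing that may be worth checking: ForwardDiff.partials(T, z, 1) picks out only the first partial. That should be fine under ForwardDiff.derivative, but it would presumably drop components when a dual carries several partials, for example under ForwardDiff.gradient. Below is a minimal sketch that scales the whole partials object instead. Note that gamma_sketch is a hypothetical stand-in built on SpecialFunctions.gamma, not the package-internal unsafe_gamma, and the test function is made up for illustration.

```julia
using ForwardDiff
using SpecialFunctions: gamma, digamma

# Hypothetical stand-in for the package-internal unsafe_gamma.
gamma_sketch(x::AbstractFloat) = gamma(x)

function gamma_sketch(z::ForwardDiff.Dual{T}) where {T}
    r = ForwardDiff.value(z)          # primal value
    w = gamma_sketch(r)               # Γ(r)
    # Chain rule: dΓ(r) = Γ(r) ψ(r) dr, applied to every partial at once.
    return ForwardDiff.Dual{T}(w, (w * digamma(r)) * ForwardDiff.partials(z))
end

# Made-up check with two partials: gradient of Γ(x₁ + 2x₂) at (1.3, 0.4).
g = ForwardDiff.gradient(x -> gamma_sketch(x[1] + 2x[2]), [1.3, 0.4])
@show g
```

By the chain rule, the two gradient entries should be Γ(2.1)ψ(2.1) times 1 and 2 respectively, which gives a quick sanity check against the hand-written rule.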