Automatic differentiation of `f(x::Real) = f(float(x))` leads to StackOverflowError

Many definitions follow this pattern:

f(x::AbstractFloat) = 2x # any computation here
f(x::Real) = f(float(x))

But then:

julia> ForwardDiff.derivative(f, 1.)
ERROR: StackOverflowError:
Stacktrace:
 [1] f(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(f),Float64},Float64,1}) at ./REPL[3]:1 (repeats 80000 times)

This happens because float of a Dual returns another Dual, which is a Real but not an AbstractFloat, so the f(x::Real) method keeps calling itself. Is there a way around this?
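
A quick check makes the cause of the recursion concrete. This is only a sketch: it assumes ForwardDiff is installed, and it builds an untagged Dual by hand with ForwardDiff.Dual(1.0, 1.0) instead of the tagged one that ForwardDiff.derivative creates internally. The variable names are made up for illustration.

```julia
using ForwardDiff

# Hand-built dual number: value 1.0, one partial equal to 1.0.
# (Illustration only; ForwardDiff.derivative builds a tagged Dual itself.)
d = ForwardDiff.Dual(1.0, 1.0)

is_real          = d isa Real                     # Dual is a subtype of Real
stays_dual       = float(d) isa ForwardDiff.Dual  # float does not unwrap the Dual
is_abstractfloat = float(d) isa AbstractFloat     # so f(::AbstractFloat) never matches

println((is_real, stays_dual, is_abstractfloat))
```

Since float(d) is still a Real and never an AbstractFloat, dispatch lands on f(x::Real) again every time.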

  1. Ideally, write f using operations that already work with Dual. Then you don’t need to define separate methods. ForwardDiff already does this for *.
  2. If you do need a method specifically for Dual, define it. ForwardDiff has various macros for that.
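
As a small illustration of the first point, the sketch below multiplies a hand-built Dual by 2 and reads back the value and the derivative. It assumes ForwardDiff is installed. The untagged ForwardDiff.Dual(3.0, 1.0) is used only for demonstration, since in normal use ForwardDiff.derivative constructs the Dual for you.

```julia
using ForwardDiff

d = ForwardDiff.Dual(3.0, 1.0)   # value 3.0, seed partial 1.0
y = 2 * d                        # * already has a method for Dual

val = ForwardDiff.value(y)         # primal result of 2 * 3.0
der = ForwardDiff.partials(y, 1)   # first partial, i.e. d(2x)/dx

println((val, der))
```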

See also:

Basically, don’t follow the pattern that Base switched to in Julia 0.7; it’s just not a good idea. Something like https://github.com/JuliaLang/julia/issues/26552#issuecomment-374958654 would turn the StackOverflowError into a MethodError that tells you exactly which method you need to define, while still doing the right thing for Reals for which float(x) isa AbstractFloat.

I don’t follow. Can you point me to an example, to see what you mean?

Do you mean that I should instead define:

f(x::Real) = _f(float(x))
_f(x::AbstractFloat) = 2x

But then I get a MethodError, as you say, so I still don’t get automatic differentiation. Am I forced to explicitly define

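function f(d::ForwardDiff.Dual{T}) where {T}
    x = ForwardDiff.value(d)  # primal value carried by the dual number
    v = f(x)                  # function value, computed on the plain float
    g = 2                     # hand-written derivative of 2x
    ForwardDiff.Dual{T}(v, g * ForwardDiff.partials(d))  # chain rule
end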

? This is painful for more complicated functions, and really is just what I wanted to avoid by using AD.

Is there a better way?

If your f is somehow ‘primitive’, in the sense that it uses operations that don’t have derivative rules defined for them in ForwardDiff.jl/DiffRules.jl, then you have to define an overload for Dual, yes. But if it’s like your example case, I’d just define f(x::Real) = 2x. Could you provide a more realistic example where it’s clear why you’d need the float call?
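
A minimal sketch of that suggestion, assuming ForwardDiff is installed. Here f is the example from the thread, while h is a hypothetical function added only to show that an explicit float call is often unnecessary, because functions like sqrt already promote integer inputs themselves.

```julia
using ForwardDiff

# Generic definition: no float call and no AbstractFloat-only method.
f(x::Real) = 2x

# Hypothetical example (not from the original post): sqrt converts integer
# input to floating point on its own, and also accepts Dual numbers.
h(x::Real) = sqrt(x) + x^2

df = ForwardDiff.derivative(f, 1.0)
dh = ForwardDiff.derivative(h, 4.0)   # analytically 1/(2*sqrt(4)) + 2*4

println((df, dh, h(4)))
```

Both definitions accept Int, Float64, and Dual arguments through the same method, so no separate overload is needed.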

Not at the moment. I had an example, but I figured out a better way to do it.

And another one: https://github.com/JuliaDiff/ForwardDiff.jl/issues/363.