When debugging a computation, I found that the derivative of tanh as defined in DiffRules.jl suffers from catastrophic cancellation for pretty mild values (e.g. outside 20 in absolute value; examples below).
I propose a simple fix, but would like to solicit alternative ideas before making a PR. Recall that tanh'(x) = 1 - tanh(x)^2 = sech(x)^2.
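For a concrete illustration (my own quick check, not the original's elided examples; printed values approximate):

```julia
x = 20.0
tanh(x)        # 1.0: tanh(20) ≈ 1 - 8.5e-18, which rounds to 1.0 in Float64
1 - tanh(x)^2  # 0.0: complete cancellation
sech(x)^2      # ≈ 1.7e-17, the accurate value of the derivative
```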
It's not relevant for most applications of DiffRules, but it definitely is relevant to how it's computed in ChainRules.jl, since ChainRules explicitly re-uses the answer computed for the forward pass:
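Roughly like this (paraphrasing the scalar-rule notation used there; Ω denotes the primal output, and the snippet assumes ChainRulesCore's @scalar_rule):

```julia
using ChainRulesCore
# paraphrased scalar rule: Ω is the primal output tanh(x), so the
# derivative 1 - Ω^2 reuses the value from the forward pass
@scalar_rule tanh(x) 1 - Ω ^ 2
```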
Perhaps what we should do here is add a branch like
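Something along these lines, where the exact cutoff and fallback formula are illustrative rather than a concrete proposal:

```julia
# hypothetical sketch: reuse the primal Ω = tanh(x) while 1 - Ω^2 is cheap,
# and fall back to a cancellation-free formula for large |x|
dtanh(x, Ω) = abs(x) <= 14 ? 1 - Ω^2 : sech(x)^2
```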
That's way too aggressive. By x = 4 the 1 - tanh(x)^2 formula has already lost 2 digits, by x = 6 it has lost 4 digits, and by x = 14 it has lost 10 digits.
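One way to measure the loss, checking against a BigFloat reference (my own sketch):

```julia
using Printf
for x in (4.0, 6.0, 14.0)
    exact = Float64(1 - tanh(big(x))^2)   # high-precision reference for the derivative
    naive = 1 - tanh(x)^2                 # cancellation-prone Float64 formula
    @printf("x = %4.1f   relative error ≈ %.1e\n", x, abs(naive - exact) / exact)
end
```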
Yeah, it was just an example value taken from Tamas' table.
I don't think accuracy is the only concern here though, since AD people are typically also very performance-sensitive. My assumption is that so long as values can be re-used from the primal calculation without truly catastrophic losses of accuracy, that's typically going to be preferred.
I disagree. Some applications of AD might need only low accuracy, but in general when a standard library computes a math function the programmer should generically expect it to be computed to close to machine precision.
If the programmer wants lower accuracy, they should use lower precision, or a different function name tanh_approx (or some macro @approx that rewrites tanh to tanh_approx, etc.).
Yes, debugging speed issues is much easier than debugging numerical issues. Whenever there is a trade-off, inaccurate but fast should be opt-in.
I should have expanded, but I thought we could keep CSE by also calculating tanh from exp(2*x) etc. But maybe I misunderstand how that works, or how relevant CSE is.
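Roughly what I had in mind (a rough sketch; a real version would need expm1 near zero and overflow handling for large x):

```julia
# rough sketch: derive both the primal and the derivative from one exp(2x),
# so the exponential is a common subexpression shared by both
function tanh_with_deriv(x)
    e2x = exp(2x)                 # shared work
    t = (e2x - 1) / (e2x + 1)     # tanh(x); would need expm1 near x = 0 for full accuracy
    dt = 4 * e2x / (e2x + 1)^2    # sech(x)^2 without the 1 - t^2 cancellation
    return t, dt                  # note: overflows to NaN for very large x; illustrative only
end
```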
I think the simplest solution is to define the derivative as sech(x)^2 without reusing the primal. IMO, as @stevengj suggested, if users or a package want a faster but less accurate version, they should define and use a tanh_approx function with a less accurate derivative.
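In DiffRules terms, the fix would look something like this sketch (using the package's @define_diffrule macro):

```julia
using DiffRules
# sketch: define the derivative as sech(x)^2, which stays accurate for large x,
# instead of the cancellation-prone 1 - tanh(x)^2 (no primal reuse)
DiffRules.@define_diffrule Base.tanh(x) = :(sech($x)^2)
```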
Branches discussed here would be optimized for Float64 but users would still run into the same problems with e.g. Float32 (or they might be inefficient for something like Float128).
No, in this case my suggested crossover point abs(real(x)) < 1 should be independent of precision (at least for real x … I haven't thought too much about the complex case), because that's where tanh(x)^2 is around 0.5 or less.
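Concretely, with Ω the primal output, something like (my sketch of this suggestion):

```julia
# for real abs(x) < 1 we have tanh(x)^2 <= tanh(1)^2 ≈ 0.58, so the subtraction
# 1 - Ω^2 cancels at most about one bit, for any floating-point precision
dtanh(x, Ω) = abs(real(x)) < 1 ? 1 - Ω^2 : sech(x)^2
```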
tanh(x) is computed from polynomial approximations for many values of x. Worse, as far as I can tell, tanh is not inlined by the compiler, which will prevent the compiler from CSE-ing expressions computed inside tanh with expressions outside tanh:
julia> f(x) = 1+tanh(x)
f (generic function with 1 method)
julia> @code_llvm f(0.2)
; @ REPL[42]:1 within `f`
define double @julia_f_445(double %0) #0 {
top:
%1 = call double @j_tanh_447(double %0) #0
; ┌ @ promotion.jl:410 within `+` @ float.jl:408
%2 = fadd double %1, 1.000000e+00
; └
ret double %2
}
True, I was thinking about the suggestion of a cutoff at 14. I'm not sure how much we would gain from a branch using abs(real(x)) <= 1 in practice though.