Accurate derivative for tanh

When debugging a computation, I found that the derivative of tanh as defined in DiffRules.jl suffers from catastrophic cancellation even for fairly mild arguments (e.g. for |x| > 20; examples below).

I propose a simple fix, but would like to solicit alternative ideas before making a PR. Recall that

\tanh(x) = \frac{e^{2x} - 1}{e^{2x} + 1}

and its derivative is

\tanh'(x) = \frac{4 e^{2x}}{(e^{2x} + 1)^2}
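
(For reference, the familiar forms are algebraically identical to this expression:

1 - \tanh(x)^2 = \mathrm{sech}(x)^2 = \frac{4 e^{2x}}{(e^{2x} + 1)^2}

so any differences between the formulations below are purely floating-point effects.)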
dth1(x) = 1 - abs2(tanh(x))      # what we now have in DiffRules
dth0(x) = oftype(float(x), dth1(BigFloat(x))) # more precise calculation
function dth2(x)                              # proposed fix
    z = 2*x
    ez = exp(z)
    abs(z) > 0.5 ? 4 / (ez + 2 + exp(-z)) : 4 * ez / abs2(1 + ez)
end

using PrettyTables              # examples
tab = [(; x, d0 = dth0(x), d1 = dth1(x), d2 = dth2(x)) for x in 0:40];
pretty_table(IOContext(stdout, :limit => false), tab)

┌───────┬─────────────┬─────────────┬─────────────┐
β”‚     x β”‚          d0 β”‚          d1 β”‚          d2 β”‚
β”‚ Int64 β”‚     Float64 β”‚     Float64 β”‚     Float64 β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚     0 β”‚         1.0 β”‚         1.0 β”‚         1.0 β”‚
β”‚     1 β”‚    0.419974 β”‚    0.419974 β”‚    0.419974 β”‚
β”‚     2 β”‚   0.0706508 β”‚   0.0706508 β”‚   0.0706508 β”‚
β”‚     3 β”‚  0.00986604 β”‚  0.00986604 β”‚  0.00986604 β”‚
β”‚     4 β”‚  0.00134095 β”‚  0.00134095 β”‚  0.00134095 β”‚
β”‚     5 β”‚ 0.000181583 β”‚ 0.000181583 β”‚ 0.000181583 β”‚
β”‚     6 β”‚  2.45765e-5 β”‚  2.45765e-5 β”‚  2.45765e-5 β”‚
β”‚     7 β”‚  3.32611e-6 β”‚  3.32611e-6 β”‚  3.32611e-6 β”‚
β”‚     8 β”‚  4.50141e-7 β”‚  4.50141e-7 β”‚  4.50141e-7 β”‚
β”‚     9 β”‚  6.09199e-8 β”‚  6.09199e-8 β”‚  6.09199e-8 β”‚
β”‚    10 β”‚  8.24461e-9 β”‚  8.24461e-9 β”‚  8.24461e-9 β”‚
β”‚    11 β”‚  1.11579e-9 β”‚  1.11579e-9 β”‚  1.11579e-9 β”‚
β”‚    12 β”‚ 1.51005e-10 β”‚ 1.51005e-10 β”‚ 1.51005e-10 β”‚
β”‚    13 β”‚ 2.04364e-11 β”‚ 2.04363e-11 β”‚ 2.04364e-11 β”‚
β”‚    14 β”‚ 2.76576e-12 β”‚ 2.76579e-12 β”‚ 2.76576e-12 β”‚
β”‚    15 β”‚ 3.74305e-13 β”‚ 3.74367e-13 β”‚ 3.74305e-13 β”‚
β”‚    16 β”‚ 5.06567e-14 β”‚ 5.06262e-14 β”‚ 5.06567e-14 β”‚
β”‚    17 β”‚ 6.85563e-15 β”‚ 6.88338e-15 β”‚ 6.85563e-15 β”‚
β”‚    18 β”‚ 9.27809e-16 β”‚ 8.88178e-16 β”‚ 9.27809e-16 β”‚
β”‚    19 β”‚ 1.25565e-16 β”‚ 2.22045e-16 β”‚ 1.25565e-16 β”‚
β”‚    20 β”‚ 1.69934e-17 β”‚         0.0 β”‚ 1.69934e-17 β”‚
β”‚    21 β”‚ 2.29981e-18 β”‚         0.0 β”‚ 2.29981e-18 β”‚
β”‚    22 β”‚ 3.11245e-19 β”‚         0.0 β”‚ 3.11245e-19 β”‚
β”‚    23 β”‚ 4.21225e-20 β”‚         0.0 β”‚ 4.21225e-20 β”‚
β”‚    24 β”‚ 5.70066e-21 β”‚         0.0 β”‚ 5.70066e-21 β”‚
β”‚    25 β”‚   7.715e-22 β”‚         0.0 β”‚   7.715e-22 β”‚
β”‚    26 β”‚ 1.04411e-22 β”‚         0.0 β”‚ 1.04411e-22 β”‚
β”‚    27 β”‚ 1.41305e-23 β”‚         0.0 β”‚ 1.41305e-23 β”‚
β”‚    28 β”‚ 1.91236e-24 β”‚         0.0 β”‚ 1.91236e-24 β”‚
β”‚    29 β”‚ 2.58809e-25 β”‚         0.0 β”‚ 2.58809e-25 β”‚
β”‚    30 β”‚  3.5026e-26 β”‚         0.0 β”‚  3.5026e-26 β”‚
β”‚    31 β”‚ 4.74026e-27 β”‚         0.0 β”‚ 4.74026e-27 β”‚
β”‚    32 β”‚ 6.41524e-28 β”‚         0.0 β”‚ 6.41524e-28 β”‚
β”‚    33 β”‚ 8.68209e-29 β”‚         0.0 β”‚ 8.68209e-29 β”‚
β”‚    34 β”‚ 1.17499e-29 β”‚         0.0 β”‚ 1.17499e-29 β”‚
β”‚    35 β”‚ 1.59018e-30 β”‚         0.0 β”‚ 1.59018e-30 β”‚
β”‚    36 β”‚ 2.15207e-31 β”‚         0.0 β”‚ 2.15207e-31 β”‚
β”‚    37 β”‚ 2.91252e-32 β”‚         0.0 β”‚ 2.91252e-32 β”‚
β”‚    38 β”‚ 3.94166e-33 β”‚         0.0 β”‚ 3.94166e-33 β”‚
β”‚    39 β”‚ 5.33446e-34 β”‚         0.0 β”‚ 5.33446e-34 β”‚
β”‚    40 β”‚ 7.21941e-35 β”‚         0.0 β”‚ 7.21941e-35 β”‚
└───────┴─────────────┴─────────────┴─────────────┘

Why not just use \tanh'(x) = \mathrm{sech}(x)^2?

PS. Note that using abs2(z) here is wrong for complex arguments, and for real arguments has no advantage over z^2 anyway.
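
As a quick sanity check (a sketch; dref is just an illustrative helper, not part of any package), sech(x)^2 stays in agreement with a BigFloat reference well past the point where 1 - tanh(x)^2 has degraded:

dref(x) = Float64(1 - tanh(BigFloat(x))^2)   # high-precision reference value

for x in (1.0, 20.0, 40.0)
    println((x, sech(x)^2, 1 - tanh(x)^2, dref(x)))
end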


That’s how it was before

but it was changed for CSE. I don’t know how relevant that is though at the moment.

It’s not relevant for most applications of DiffRules, but it definitely is relevant to how it’s computed in ChainRules.jl since ChainRules does explicitly re-use the answer computed for the forwards pass:

Perhaps what we should do here is add a branch like

abs(real(x)) <= 14 ? (1 - Ω^2) : sech(x)^2

And then maybe remove that branch only for the @fastmath version?

With your new formulation you’ll lose any CSE too, so you might as well go back to using sech.


That’s way too aggressive. By x = 4 the 1 - tanh(x)^2 formula has already lost 2 digits, by x = 6 it has lost 4 digits, and by x = 14 it has lost 10 digits.

Doing a few quick experiments, the crossover point where sech(x)^2 becomes more accurate seems to be around x = 1, so you could do abs(real(x)) <= 1 ? (1 - Ω^2) : sech(x)^2 (or real(x)^2 <= 1 … not sure which is faster).
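
In code, that branch might look something like this (a sketch with an illustrative name; Ω denotes the already-computed primal tanh(x), following the ChainRules convention):

function dtanh(x, Ω = tanh(x))
    # reuse the primal only in the regime where 1 - Ω^2 is accurate
    abs(real(x)) <= 1 ? 1 - Ω^2 : sech(x)^2
end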

Yeah, it was just an example value taken from Tamas’ table.

I don’t think accuracy is the only concern here though, since AD people are typically also very performance-sensitive. My assumption is that so long as values can be re-used from the primal calculation without truly catastrophic losses of accuracy, that’s typically going to be preferred.

ForwardDiff applies CSE to the rules generated from the DiffRules expressions.

Yes, the issue also affects ChainRules: rule for tanh has catastrophic cancellation for |x| > 20 · Issue #102 · JuliaDiff/DiffRules.jl · GitHub

I disagree. Some applications of AD might need only low accuracy, but in general, when a standard library computes a math function, the programmer should expect it to be computed to close to machine precision.

If the programmer wants lower accuracy, they should use lower precision, or a different function name like tanh_approx (or some macro @approx that rewrites tanh to tanh_approx, etc.).


Yes, debugging speed issues is much easier than debugging numerical issues. Whenever there is a trade-off, inaccurate but fast should be opt-in.

I should have expanded on this, but I thought we could keep CSE by also calculating tanh from exp(2*x) etc. But maybe I misunderstand how that works, or how relevant CSE is.
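
Something like the following is what I had in mind (a sketch; it would still need a guard for large |x|, where exp(2*x) overflows to Inf and the primal ratio becomes NaN):

function tanh_and_deriv(x)
    ez = exp(2 * x)              # shared subexpression for primal and derivative
    t  = (ez - 1) / (ez + 1)     # tanh(x)
    dt = 4 / (ez + 2 + 1 / ez)   # tanh'(x), no cancellation
    return t, dt
end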

I think the simplest solution is to define the derivative as sech(x)^2 without reusing the primal. IMO, as @stevengj suggested, if users or a package want a faster but less accurate version, they should define and use a tanh_approx function with a correspondingly less accurate derivative.

Branches discussed here would be optimized for Float64 but users would still run into the same problems with e.g. Float32 (or they might be inefficient for something like Float128).

No, in this case my suggested crossover point abs(real(x)) < 1 should be independent of precision (at least for real x … I haven’t thought too much about the complex case), because that’s where tanh(x)^2 is around 0.5 or less.
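
A quick way to see why the crossover is precision-independent (sketch):

for T in (Float32, Float64)
    x = one(T)
    # tanh(1)^2 β‰ˆ 0.58 in any precision, so 1 - tanh(x)^2 cancels at most
    # about one bit at the proposed crossover, regardless of the type T
    println((T, tanh(x)^2, 1 - tanh(x)^2))
end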

tanh(x) is computed from polynomial approximations for many values of x. Worse, as far as I can tell, tanh is not inlined by the compiler, which will prevent the compiler from CSE-ing expressions computed inside tanh with expressions outside tanh:

julia> f(x) = 1+tanh(x)
f (generic function with 1 method)

julia> @code_llvm f(0.2)
;  @ REPL[42]:1 within `f`
define double @julia_f_445(double %0) #0 {
top:
  %1 = call double @j_tanh_447(double %0) #0
; ┌ @ promotion.jl:410 within `+` @ float.jl:408
   %2 = fadd double %1, 1.000000e+00
; β””
  ret double %2
}

True, I was thinking about the suggestion of a cutoff at 14. I’m not sure how much we would gain from a branch using abs(real(x)) <= 1 in practice though.