Trying to make my robust logistic function differentiable as part of AD, getting errors so far

I found that StatsFuns.logistic occasionally had issues with really extreme values, so I defined this more robust version of the logistic function:

const LOGISTIC_64 = log(Float64(1)/eps(Float64) - Float64(1))
const LOGISTIC_32 = log(Float32(1)/eps(Float32) - Float32(1))
"""
Return the logistic function computed in a numerically stable way:
``logistic(x) = 1/(1+exp(-x))``
"""
function logistic(x::Float64)
	x > LOGISTIC_64  && return one(x)
	x < -LOGISTIC_64 && return zero(x)
	return one(x) / (one(x) + exp(-x))
end
function logistic(x::Float32)
	x > LOGISTIC_32  && return one(x)
	x < -LOGISTIC_32 && return zero(x)
	return one(x) / (one(x) + exp(-x))
end
logistic(x) = logistic(float(x))
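As a quick sanity check, the clamping behaviour can be exercised like this (a minimal sketch restating the Float64 definition from above so it runs standalone):

```julia
const LOGISTIC_64 = log(1 / eps(Float64) - 1)  # ≈ 36.04

function logistic(x::Float64)
    x > LOGISTIC_64  && return one(x)
    x < -LOGISTIC_64 && return zero(x)
    return one(x) / (one(x) + exp(-x))
end

logistic(0.0)     # 0.5 exactly
logistic(1000.0)  # clamps to 1.0 instead of relying on exp underflow
logistic(-1000.0) # clamps to 0.0
```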

I am using it as part of some densities for MCMC sampling and Bayesian inference. I’d like to be able to apply AD techniques to some functions that involve the logistic function as implemented above, but when I do, I end up getting the following error:

ERROR: StackOverflowError:
Stacktrace:
 [1] logistic(::ForwardDiff.Dual{ForwardDiff.Tag{var"#57#58",Float64},Float64,4}) at /Users/harrisonwilde/Library/Mobile Documents/com~apple~CloudDocs/PhD/Synthetic Data/src/creditcard/distrib_utils.jl:21 (repeats 79984 times)

Presumably it repeatedly calls the bottom line over and over because float(x), where x is of type ForwardDiff.Dual{ForwardDiff.Tag{var"#57#58",Float64},Float64,4}, just returns the Dual unchanged rather than producing a Float64, so the fallback method keeps calling itself. How can I write my functions so they are “differentiable” in this setting? Any guidance would be much appreciated!
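To illustrate the recursion in miniature (a sketch using BigFloat as a stand-in for the Dual type, since float leaves both unchanged):

```julia
# float(x) is a no-op for values that are already floating-point-like,
# so a fallback like f(x) = f(float(x)) never reaches the Float64 method
# and recurses until the stack overflows.
f(x::Float64) = x
f(x) = f(float(x))  # infinite recursion for e.g. BigFloat (or a Dual)

float(big(1.0)) isa BigFloat  # true: no conversion to Float64 happens
```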

Just write your function like this:

function LOGISTIC(T) 
    log(one(T)/eps(T) - one(T))
end

function logistic(x::T) where {T}
    LOGISTIC_T = LOGISTIC(T)
    x > LOGISTIC_T && return one(x)
    x < -LOGISTIC_T && return zero(x)
    return one(x) / (one(x) + exp(-x))
end

This way, your function will work for any type supporting eps, one, and exp. Note that due to specialization, this should be just as fast as your implementation.
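For example (a sketch restating the generic definition; Float16 and BigFloat stand in here for the Dual type, which supports the same operations):

```julia
LOGISTIC(T) = log(one(T) / eps(T) - one(T))

function logistic(x::T) where {T}
    LOGISTIC_T = LOGISTIC(T)
    x > LOGISTIC_T && return one(x)
    x < -LOGISTIC_T && return zero(x)
    return one(x) / (one(x) + exp(-x))
end

logistic(Float16(0))  # Float16(0.5)
logistic(big"-100.0") # a tiny but nonzero BigFloat, since
                      # the BigFloat threshold is much wider (≈ ±177)
```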


Thank you, that makes a lot of sense. I think you forgot to add a where {T} on the function definition though, right?

Fixed, good point!

This should be fixed in StatsFuns.jl directly. I’ll try to submit a patch tonight since there’s a few different issues worth fixing here:

  1. It’s better to use exp(x) / (exp(x) + 1) here to handle subnormals properly.

Compare and contrast the implementations in the region that generates subnormal outputs:

let x = -740.0
    exp(x) / (exp(x) + 1)
end
# 4.2e-322

let x = -740.0
    1 / (1 + exp(-x))
end
# 0.0
  2. You might want to use ifelse instead of && since minimizing branches increases the chances SIMD will work. It may not be sufficient on its own to get SIMD to kick in, but I would guess it will work out better.

  3. A potentially bigger issue is that, if you’re ever computing log(logistic(x)), you should use the specialized technique from Equation 10 on page 7 of this paper rather than computing logistic(x) and then applying log to the result.
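A sketch of what that specialization might look like, via a stable log1p(exp(x)) helper. The identity log(logistic(x)) = -log1p(exp(-x)) is standard; the numeric cutoffs below are common Float64 choices and should be treated as illustrative rather than the paper's exact constants:

```julia
# Numerically stable log1p(exp(x)); cutoffs are illustrative Float64 choices
function log1pexp(x::Float64)
    x <= -37.0 && return exp(x)          # exp(x) underflow region
    x <= 18.0  && return log1p(exp(x))   # the straightforward formula
    x <= 33.3  && return x + exp(-x)     # log1p would lose the correction
    return x                             # correction below eps
end

# log(logistic(x)) == -log1p(exp(-x)) == -log1pexp(-x)
loglogistic(x) = -log1pexp(-x)

loglogistic(-800.0)        # -800.0, exact
log(1 / (1 + exp(800.0)))  # -Inf: the naive route underflows to log(0)
```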


I’m not sure I agree with this example; can you clarify why you think the first one is preferable? Also, the thresholding here makes sure that the answer would be zero anyway, but in a predictable way, since the threshold for Float64 is around ±36.

The OP’s formula doesn’t deal with subnormals properly, but it does have the benefit of producing 1.0 instead of NaN for large positive numbers. Possibly the best of both worlds would be to switch which formula you use based on the sign of x.
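That switch might look like this (a sketch, not a claim about what StatsFuns does):

```julia
# Pick the form whose exp argument is non-positive, so exp never overflows:
#   x >= 0: exp(-x) lies in [0, 1]; result rounds to 1.0 for large x, no NaN
#   x <  0: exp(x) lies in (0, 1]; subnormal outputs are preserved
logistic_switch(x) = x >= 0 ? 1 / (1 + exp(-x)) : exp(x) / (1 + exp(x))

logistic_switch(1000.0)  # 1.0 (no NaN)
logistic_switch(-740.0)  # ≈ 4.2e-322 (subnormal preserved)
```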

Run the calculation in higher precision and then truncate to Float64:

Float64(1 / (1 + exp(-big(-740.0))))
# 4.2e-322

Over all intervals I’ve tested, the first implementation is closer to the higher precision result – which is to say that it seems the first strictly dominates the second in precision.

I think it’s best not to switch based on the sign of x since you can simply guard against NaN, but the first implementation is more precise for large positive inputs in the pre-overflow region. See the PR I made for what I think is the proper way to do this: https://github.com/JuliaStats/StatsFuns.jl/pull/94

I think my question was more to do with “what would you do with that 4.2e-322?” There’s a point where it makes sense to truncate results to avoid nasty propagation of errors. In the context of special functions, the use of big, etc., I could see users wanting your version, but usually this function is used in ML/stats, where it makes little sense to keep track of these values, just as it would make little sense (IMO) to train a neural net in big precision.

Anyway, I think it’s just a matter of perspective. I agree with your point about big precision; my perspective is that if you truncate (which I believe can often make sense), the point is irrelevant.


Tangentially related: subnormals can be really slow to compute with (https://en.wikipedia.org/wiki/Denormal_number#Performance_issues reports a 50x speed difference in gemv depending on the values in the vector; see also https://docs.julialang.org/en/v1/manual/performance-tips/#Treat-Subnormal-Numbers-as-Zeros-1).
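In Julia you can opt into flushing subnormals to zero on the current thread, which sidesteps the slowdown at the cost of precision (a sketch; set_zero_subnormals returns false if the hardware doesn’t support it):

```julia
x = 5.0e-320       # a subnormal Float64
issubnormal(x)     # true

ok = set_zero_subnormals(true)  # enable flush-to-zero for this thread
y = x * 1.0                     # 0.0 if supported, else unchanged
set_zero_subnormals(false)      # restore the default
```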


That’s fair: there’s a tradeoff here between precision and performance.

I don’t really have any opinion on the numerical issues, but there seems to be a lot of unnecessary typing here. Due to type promotion you can just as well write

const LOGISTIC_64 = log(1 / eps(Float64) - 1) 
const LOGISTIC_32 = log(1 / eps(Float32) - 1)
...
1 / (1 + exp(-x))

If you really need literal Float64 or Float32 you can write

1.0  # Float64
1e0  # also Float64
1f0  # Float32

Integers are really nice: you can use them in most places and they will promote to the correct precision. This, too, is redundant:

logistic(x) = logistic(float(x))

since the division inside the logistic function will promote the values to appropriate floats.
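A quick check of that promotion (a sketch using the plain formula, without the thresholding):

```julia
f(x) = 1 / (1 + exp(-x))

f(2)            # exp(-2::Int) already returns a Float64, so f(2) isa Float64
f(2) == f(2.0)  # true: the integer path computes the identical value
```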
