Trying to make my robust logistic function differentiable as part of AD, getting errors so far

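Just write your function like this:

```julia
# Threshold at which the exact logistic equals 1 - eps(T)
function LOGISTIC(T)
    log(one(T)/eps(T) - one(T))
end

function logistic(x::T) where {T}
    LOGISTIC_T = LOGISTIC(T)
    x > LOGISTIC_T && return one(x)    # saturate to 1 above the threshold
    x < -LOGISTIC_T && return zero(x)  # clamp to 0 below -threshold
    return one(x) / (one(x) + exp(-x))
end
```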

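This way, your function will work for any number type `T` that supports `eps`, `one`, `log`, and `exp` (plus basic arithmetic and comparisons). That includes the dual numbers used by forward-mode AD packages, provided they define `eps` for their type. Because Julia compiles a specialized method of `logistic` for each concrete argument type, this should be just as fast as your original implementation.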
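To see that the generic version really does differentiate, here is a self-contained sketch. `MiniDual` and `mini_derivative` are hypothetical names made up for this illustration, not part of any package; in practice you would use an AD package such as ForwardDiff.jl. The sketch defines only the operations that `LOGISTIC` and `logistic` actually use.

```julia
# Generic logistic function from the answer above
LOGISTIC(T) = log(one(T) / eps(T) - one(T))

function logistic(x::T) where {T}
    LOGISTIC_T = LOGISTIC(T)
    x > LOGISTIC_T && return one(x)
    x < -LOGISTIC_T && return zero(x)
    return one(x) / (one(x) + exp(-x))
end

# Hypothetical minimal forward-mode dual number (illustration only;
# a real AD package such as ForwardDiff.jl provides a complete one).
struct MiniDual{T<:AbstractFloat}
    val::T  # primal value
    der::T  # derivative with respect to the input
end

# Only the operations used by `LOGISTIC` and `logistic` are defined.
Base.one(::Type{MiniDual{T}}) where {T} = MiniDual(one(T), zero(T))
Base.zero(::Type{MiniDual{T}}) where {T} = MiniDual(zero(T), zero(T))
Base.one(d::MiniDual) = one(typeof(d))
Base.zero(d::MiniDual) = zero(typeof(d))
Base.eps(::Type{MiniDual{T}}) where {T} = MiniDual(eps(T), zero(T))  # a constant: zero derivative

Base.:+(a::MiniDual, b::MiniDual) = MiniDual(a.val + b.val, a.der + b.der)
Base.:-(a::MiniDual, b::MiniDual) = MiniDual(a.val - b.val, a.der - b.der)
Base.:-(a::MiniDual) = MiniDual(-a.val, -a.der)
Base.:/(a::MiniDual, b::MiniDual) =
    MiniDual(a.val / b.val, (a.der * b.val - a.val * b.der) / b.val^2)  # quotient rule
Base.exp(a::MiniDual) = (e = exp(a.val); MiniDual(e, e * a.der))
Base.log(a::MiniDual) = MiniDual(log(a.val), a.der / a.val)
Base.:<(a::MiniDual, b::MiniDual) = a.val < b.val  # `>` falls back to `<`

# Hypothetical helper: seed the input's derivative with 1 and read off the output's
mini_derivative(f, x::AbstractFloat) = f(MiniDual(x, one(x))).der

println(logistic(0.0f0))                  # 0.5 (stays Float32)
println(mini_derivative(logistic, 0.0))   # 0.25
println(mini_derivative(logistic, 50.0))  # 0.0 (saturated branch)
```

Note that in the saturated branches the derivative comes out as exactly 0, whereas the true derivative is tiny but positive (about `exp(-50) ≈ 2e-22` at `x = 50`). For most uses that difference is harmless.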
