Hi all,
I ran into a strange bug. To be honest, I'm relatively new to the language, so it could be a mistake on my part.
So, I'm having a real blast playing with Bayesian NNs. I'm slowly building up from the basic example given on the Turing.jl website, while grasping as much of Julia's inner workings as I can. Now I'm trying to extend the classification from binary to multiclass, so I defined the output of my BNN as a `Categorical` instead of a `Binomial`.
Logically enough, `Categorical` requires its parameter to be a probability vector, i.e. non-negative entries summing to one. From Distributions/src/utils.jl:
```julia
isprobvec(p::AbstractVector{<:Real}) =
    all(x -> x ≥ zero(x), p) && isapprox(sum(p), one(eltype(p)))
```
I convert the output of the last 3 neurons of my network to a probability vector with `Flux.softmax`. Naturally, its sum is not exactly 1.0 due to floating-point shenanigans, but `isapprox` takes care of that.
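For plain `Float64`s this really does work. Here is a package-free sketch of that step: `isprobvec` is copied from the utils.jl snippet above, while `mysoftmax` is a hypothetical stand-in for `Flux.softmax` and the logits are made up.

```julia
# Same check as Distributions' isprobvec (copied from the quoted utils.jl)
isprobvec(p::AbstractVector{<:Real}) =
    all(x -> x ≥ zero(x), p) && isapprox(sum(p), one(eltype(p)))

# Hypothetical stand-in for Flux.softmax, written out so no packages are needed
function mysoftmax(z::AbstractVector{<:Real})
    e = exp.(z .- maximum(z))   # subtract the max for numerical stability
    return e ./ sum(e)
end

p = mysoftmax([0.3, -0.6, 0.25])   # made-up logits for 3 classes
println(sum(p))        # may be a few ulps away from 1.0
println(isprobvec(p))  # true: Float64 gets a nonzero default rtol in isapprox
```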
The problem, however, is that what works with floats doesn't seem to work with `TrackedReal`s (which are instrumental to AD):
```julia
xx = ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,2,Array{Float64,2},Array{Float64,2}}}[TrackedReal<4WI>(0.42766067643153494, 0.0, 3Wk, ---), TrackedReal<FB4>(0.1639909617549507, 0.0, 3Wk, ---), TrackedReal<49W>(0.4083483618135143, 0.0, 3Wk, ---)]
sum(xx) = TrackedReal<KQ5>(0.9999999999999999, 0.0, 3Wk, ---)
one(eltype(xx)) = TrackedReal<HSG>(1.0, 0.0, ---, ---)
isapprox(sum(xx), 1.0) = true
isapprox(sum(xx), one(eltype(xx))) = false
isprobvec(r[:, i]) = false
```
So `xx`, the output of my network for a single sample, is indeed a "very-close-to-sum-to-one" vector of `TrackedReal`s (`xx = [0.427..., 0.163..., 0.408...]`). Its sum is a "very-close-to-one" `TrackedReal` (`sum = TrackedReal(0.9999999999999999, ...)`), and it is virtually equal to `Float64(1.0)` (`isapprox(sum(xx), 1.0) = true`). But it is not approximately equal to the `one` of its own type: `isapprox(sum(xx), one(eltype(xx))) = false`. Hence it is not a probability vector, and `Categorical` rejects it.
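For what it's worth, the same pattern can be reproduced without ReverseDiff. This is a sketch assuming the difference comes from the default relative tolerance `isapprox` derives from the argument types, and that what matters is that `TrackedReal` is a `Real` but not an `AbstractFloat`. `Rational{Int}` is another such type, so it stands in for `TrackedReal` here. `Base.rtoldefault` is an unexported internal, so its exact form may change between Julia versions.

```julia
# Rational{Int} is a Real but not an AbstractFloat (assumed analogous to TrackedReal)
s = 1//3 + 1//3 + 1//3 - 1//10^12   # "very close to one", but not exactly one

# Against a Float64: the default rtol comes from Float64, sqrt(eps()) ≈ 1.5e-8
println(isapprox(s, 1.0))              # true

# Against one of its own type: neither side is an AbstractFloat,
# so the default rtol is 0 and isapprox degenerates to exact equality
println(isapprox(s, one(typeof(s))))   # false

# Base's internal default rtol for each pair of arguments (atol = 0)
println(Base.rtoldefault(s, 1.0, 0))     # 1.4901161193847656e-8
println(Base.rtoldefault(s, one(s), 0))  # 0
```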
So, did I miss something? Is it normal behaviour for tracked reals to fail to compare ≈ to one, or is there a bug somewhere?
Thank you for your insights!