Hello, I’ve encountered some very bizarre behavior with the equality (`.==`) and inequality (`.<=`) broadcast operators when using Flux.jl on GPU. Below is a small example:
```julia
using CUDA
using Flux

device = gpu_device()
# device = cpu_device()

y = [0, 1, 2] |> device
println(y)
println(typeof(y))
println(y .== 1)
println(y .<= 1)

f = x -> begin
    y = [0, 1, 2] |> device
    println(y)
    println(typeof(y))
    println(y .== 1)
    println(y .<= 1)
    println(y .>= 1)
    return sum(x)
end

x = Float32[1, 2] |> device
grad = Flux.gradient(f, x)
```
This correctly outputs
```
[0, 1, 2]
Vector{Int64}
Bool[0, 1, 0]
Bool[1, 1, 0]
[0, 1, 2]
Vector{Int64}
Bool[0, 1, 0]
Bool[1, 1, 0]
Bool[0, 1, 1]
```
on CPU.
However, when tested on both CUDA 12.5 and Metal, the output becomes
```
[0, 1, 2]
CuArray{Int64, 1, CUDA.DeviceMemory}
Bool[0, 1, 0]
Bool[1, 1, 0]
[0, 1, 2]
CuArray{Int64, 1, CUDA.DeviceMemory}
Bool[0, 0, 0]
Bool[1, 0, 0]
Bool[0, 1, 1]
```
Based on the last three lines, `.==` and `.<=` give wrong results, while `.>=` seems to be fine. Also, this problem only happens inside `Zygote.gradient()`: as the first few lines of output show, the results are correct outside of the gradient computation.
I get the same outputs after changing all the integers to `Float32`.
Versions:
- Julia: v1.11.3
- Flux.jl: v0.16.3
- Zygote.jl: v0.7.6
- CUDA.jl: v5.7.2
I wonder whether this is a bug, or if I am doing something wrong? Thank you!