.== and .<= inside Zygote.gradient() are inaccurate on GPU

Hello, I’ve encountered some very bizarre behavior with equality .== and inequality .<= when using Flux.jl on the GPU. Below is a small example:

using CUDA
using Flux

device = gpu_device()
# device = cpu_device()

y = [0, 1, 2] |> device
println(y)
println(typeof(y))
println(y .== 1)
println(y .<= 1)

f = x -> begin  # same comparisons, but now inside the function being differentiated
    y = [0, 1, 2] |> device
    println(y)
    println(typeof(y))
    println(y .== 1)
    println(y .<= 1)
    println(y .>= 1)
    return sum(x)
end

x = Float32[1,2] |> device
grad = Flux.gradient(f, x)

This correctly outputs

[0, 1, 2]
Vector{Int64}
Bool[0, 1, 0]
Bool[1, 1, 0]
[0, 1, 2]
Vector{Int64}
Bool[0, 1, 0]
Bool[1, 1, 0]
Bool[0, 1, 1]

on CPU.

However, when tested on both CUDA 12.5 and Metal, the outputs become

[0, 1, 2]
CuArray{Int64, 1, CUDA.DeviceMemory}
Bool[0, 1, 0]
Bool[1, 1, 0]
[0, 1, 2]
CuArray{Int64, 1, CUDA.DeviceMemory}
Bool[0, 0, 0]
Bool[1, 0, 0]
Bool[0, 1, 1]

Based on the last three lines, .== and .<= are inaccurate, while .>= seems to be fine. Also, this problem only happens within Zygote.gradient(): the first few lines of output show that outside of the gradient computation the results are correct.

I got the same outputs by changing all integers to Float32.

Versions:

  • Julia: v1.11.3
  • Flux.jl: v0.16.3
  • Zygote.jl: v0.7.6
  • CUDA.jl: v5.7.2

I wonder if this is a bug, or am I doing something wrong? Thank you!

What’s going on is this: Zygote computes the gradient of broadcasting using ForwardDiff, at least for GPU arrays (and with some exceptions for functions it knows). That means it ends up broadcasting things like Dual.(y) .<= 1:

julia> y .<= 1
3-element MtlVector{Bool, Metal.PrivateStorage}:
 1
 1
 0

julia> using ForwardDiff: Dual

julia> println(Dual.(y, 1) .<= 1)
Bool[1, 0, 0]

(@v1.11) pkg> st ForwardDiff
Status `~/.julia/environments/v1.11/Project.toml`
  [f6369f11] ForwardDiff v1.0.1

The middle 0 (false) is because Dual(1, 1) is regarded as a small perturbation 1.0 + ε, hence above 1. This behaviour is quite new; versions before ForwardDiff 1.0 behaved differently:

julia> println(Dual.(y, 1) .<= 1)
Bool[1, 1, 0]

(jl_c7B1O4) pkg> st ForwardDiff
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_c7B1O4/Project.toml`
⌃ [f6369f11] ForwardDiff v0.10.38
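
If you need a comparison that ignores the perturbation, one option is to strip the dual part before comparing. This is just a sketch of the idea using ForwardDiff.value, not something the code above does:

using ForwardDiff: Dual, value

d = Dual(1, 1)     # primal value 1 with a unit perturbation, i.e. 1 + ε

d <= 1             # false under ForwardDiff ≥ 1.0: 1 + ε is treated as above 1
value(d) <= 1      # true: compares only the primal value, like pre-1.0 versions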

Whether your example is a bug I’m not sure. Zygote has no reason to be differentiating something inside println, as it doesn’t contribute to the answer. In fact, broadcasting .<= always produces an array of Bool, which isn’t differentiable, so that’s a second reason for Zygote not to differentiate it. Nevertheless, Zygote tends to get its fingers into everything. You can tell it not to with @ignore:

julia> let
           f = x -> begin
               Zygote.@ignore println(x .== 1)
               Zygote.@ignore println(x .<= 1)
               return sum(x)
           end

           x = Float32[0,1,2] |> device
           grad = Flux.gradient(f, x)
       end;
Bool[0, 1, 0]
Bool[1, 1, 0]

But whether you should or not depends on what you are doing. If you are optimising some x, you are going to perturb it, and code which depends strongly on the difference between 0.99999994f0 and 1.0000001f0 is then usually a bad idea.
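
For instance, a tolerance-based mask is usually more robust than exact equality if x is being optimised. This is just a suggestion to illustrate the point, not something from the code above:

x = Float32[0.99999994, 1.0000001, 2]

x .== 1               # Bool[0, 0, 0]: both near-1 values miss the exact comparison
abs.(x .- 1) .< 1f-4  # Bool[1, 1, 0]: survives small perturbations of x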

I see, that is very interesting! But it doesn’t seem to explain why the problem only happens on GPUs, and not on CPU.

I’m using elementwise comparison of integers to mask data as part of my Flux model, so Zygote.@ignore should be appropriate.

In addition, I tried the following

CUDA.allowscalar(true)
println(y[2] == 1) # true
println(y[2] <= 1)  # true

The paths taken for GPU and CPU arrays are different. The most generic broadcasting uses Zygote itself on every element, and that is too complex (too mutable?) to compile to a GPU kernel.

The CPU path here has an explicit branch on the eltype: if T == Bool, it knows there is nothing for AD to do. The GPU path here could perhaps have a similar branch… or perhaps it can just be deleted, as the comment there suggests?
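
Schematically, the kind of eltype branch described above would look something like this. This is a hypothetical sketch of the idea, not Zygote’s actual source:

# Hypothetical sketch: a broadcast pullback that bails out for Bool-valued results.
function broadcast_pullback_sketch(f, args...)
    y = broadcast(f, args...)
    if eltype(y) == Bool
        # A Bool result carries no derivative information, so there is nothing
        # for AD to do: the pullback just returns `nothing` for every argument.
        bool_pullback(Δ) = map(_ -> nothing, args)
        return y, bool_pullback
    end
    error("non-Bool eltype: this is where the dual-number (ForwardDiff) path would go")
end

y, back = broadcast_pullback_sketch(<=, [0, 1, 2], 1)  # y == Bool[1, 1, 0]; back(y) == (nothing, nothing)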

Yes. Without broadcasting, there is no ForwardDiff, and Zygote doesn’t do this “perturbation destroys ==” thing.

Nor does most AD. It’s a bit of a delicate issue. If prod2 below is taken as an implementation of prod (with an early-exit optimisation), then most systems get gradient(prod2, [1,2,0,4,0,6.]) wrong:

function prod2(xs::Vector)
    p = one(eltype(xs))
    for x in xs
        p = p * x
        p == 0 && break  # exit early once you know the answer
    end
    p
end
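
To make the failure concrete (my worked example, using prod2 above): with xs = [1, 2, 0, 4, 0, 6.] the loop exits after the third element, so reverse-mode AD only follows p = xs[1] * xs[2] * xs[3] and reports a nonzero gradient for xs[3], while the true gradient of prod is zero in every entry, because the product of the remaining elements always contains a zero.

using Zygote

xs = [1, 2, 0, 4, 0, 6.]

# The early exit means AD sees only p = xs[1] * xs[2] * xs[3],
# so ∂p/∂xs[3] = xs[1] * xs[2] = 2 and the rest are zero:
Zygote.gradient(prod2, xs)[1]   # expected: [0.0, 0.0, 2.0, 0.0, 0.0, 0.0]

# Whereas the true gradient of prod is zero everywhere:
Zygote.gradient(prod, xs)[1]    # expected: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]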

Actually, I found that this doesn’t work. I can use

ChainRulesCore.@ignore_derivatives mask = y .== 1

but mask is necessary to mask some other data which needs to be differentiated through. ~Because of this, y or mask cannot be ignore’d, so I have to deal with the aforementioned ForwardDiff problem.~ [EDIT: figured out how to do it the right way, see below]

Obviously there are ways to produce the correct y .== 1 mask with some tricks, but I think that shouldn’t have been necessary in the first place. My feeling is that this should be regarded as a bug.

Can you show a bit more detail? Like, what’s the MWE that gets a wrong answer without @ignore? Maybe that’s a bug. And why can’t you use @ignore? That part I don’t follow.

(But show a computation, not a println; the println case might be a separate, trivial bug.)

I think I did it the wrong way earlier. The following works:

f = x -> begin
    y = [0, 1, 2] |> device
    mask = begin
        Flux.@ignore_derivatives y .== 1
    end
    return sum(x[mask])
end

x = Float32[1, 2, 3] |> device
grad = Flux.gradient(f, x) # [0.0, 1.0, 0.0]

Ok. Ideally this would just work, without @ignore. Mind making a Zygote issue? (Or even better a PR?)

Sure, I’ll create an issue first. Thanks!
