Zygote errors on simple operations with complex CUDA arrays

Hey,

I recently opened an issue on Zygote’s GitHub because I can’t get Zygote to differentiate code that applies a simple function to a complex CUDA array. I’m using CUDA v3.1.0, Julia 1.6.1 and Zygote v0.6.10. It also fails with Julia 1.5.4, CUDA v2.4.0 and Zygote v0.5.0.

I wanted to post it here as well, because I’m not sure whether it is a Zygote or a CUDA issue. Also, I find it hard to believe that I’m the first one to run into this.

Does anyone have an idea what could be causing the issue below?

julia> using Zygote, CUDA

julia> x = rand(ComplexF32, (2,2))
2×2 Matrix{ComplexF32}:
 0.0598111+0.678913im  0.767138+0.77825im
  0.548067+0.98656im   0.306103+0.166084im

julia> x_c = CuArray(x);

julia> f(x) = sum(abs2.(x))
f (generic function with 1 method)

julia> g(x) = sum(real(x .* conj.(x)))
g (generic function with 1 method)

julia> f(x) ≈ f(x_c) ≈ g(x) ≈ g(x_c)
true

julia> Zygote.gradient(f, x)
(ComplexF32[0.11962223f0 + 1.3578255f0im 1.534275f0 + 1.5565007f0im; 1.0961342f0 + 1.9731205f0im 0.61220574f0 + 0.33216715f0im],)

julia> Zygote.gradient(g, x)
(ComplexF32[0.11962223f0 + 1.3578255f0im 1.534275f0 + 1.5565007f0im; 1.0961342f0 + 1.9731205f0im 0.61220574f0 + 0.33216715f0im],)

julia> Zygote.gradient(f, x_c)
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at range.jl:664
  iterate(::Union{LinRange, StepRangeLen}, ::Int64) at range.jl:664
  iterate(::T) where T<:Union{Base.KeySet{var"#s79", var"#s78"} where {var"#s79", var"#s78"<:Dict}, Base.ValueIterator{var"#s77"} where var"#s77"<:Dict} at dict.jl:693
  ...
Stacktrace:
  [1] (::Zygote.var"#1209#1210"{Zygote.var"#1104#1108"})(ȳ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/broadcast.jl:231
  [2] (::Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [3] (::Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:194
  [4] (::Zygote.var"#1689#back#182"{Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [5] Pullback
    @ ./broadcast.jl:1309 [inlined]
  [6] Pullback
    @ ./REPL[18]:1 [inlined]
  [7] (::typeof(∂(f)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
  [8] (::Zygote.var"#41#42"{typeof(∂(f))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
  [9] gradient(f::Function, args::CuArray{ComplexF32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [10] top-level scope
    @ REPL[24]:1
 [11] top-level scope
    @ ~/.julia/packages/CUDA/k52QH/src/initialization.jl:81

julia> Zygote.gradient(g, x_c)
(ComplexF32[0.11962223f0 + 1.3578255f0im 1.534275f0 + 1.5565007f0im; 1.0961342f0 + 1.9731205f0im 0.61220574f0 + 0.33216715f0im],)
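For reference, this is why I believe the CPU gradients (and the GPU gradient of g) are correct. It assumes Zygote’s convention that, for a real-valued function of a complex argument z = a + ib, the gradient is reported as ∂f/∂a + i ∂f/∂b:

$$
f(z) = \sum_k |z_k|^2 = \sum_k \left(a_k^2 + b_k^2\right),
\qquad
\frac{\partial f}{\partial a_k} + i\,\frac{\partial f}{\partial b_k} = 2a_k + 2i\,b_k = 2z_k
$$

So the expected gradient is simply 2 .* x. That matches the output above entry by entry, e.g. 2 × (0.0598111 + 0.678913im) ≈ 0.1196222 + 1.357826im for the first element.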

I’d be grateful for any hint, because this is currently blocking my work. I have no idea how to work around it, since map also seems to fail.

Thanks,

Felix