Hey,
I recently opened an issue on GitHub for Zygote because I can’t get Zygote to differentiate code that applies a simple function to a complex-valued CUDA array. I’m using CUDA 3.1.0, Julia 1.6.1 and Zygote 0.6.10. It also fails on Julia 1.5.4, CUDA v2.4.0 and Zygote v0.5.0.
I wanted to post it here as well, because I’m not entirely sure whether it is a Zygote or a CUDA issue. Furthermore, I can’t believe that I’m the first one to notice this.
Does anyone have an idea what could be causing the issue below?
julia> using Zygote, CUDA
julia> x = rand(ComplexF32, (2,2))
2×2 Matrix{ComplexF32}:
0.0598111+0.678913im 0.767138+0.77825im
0.548067+0.98656im 0.306103+0.166084im
julia> x_c = CuArray(x);
julia> f(x) = sum(abs2.(x))
f (generic function with 1 method)
julia> g(x) = sum(real(x .* conj.(x)))
g (generic function with 1 method)
julia> f(x) ≈ f(x_c) ≈ g(x) ≈ g(x_c)
true
julia> Zygote.gradient(f, x)
(ComplexF32[0.11962223f0 + 1.3578255f0im 1.534275f0 + 1.5565007f0im; 1.0961342f0 + 1.9731205f0im 0.61220574f0 + 0.33216715f0im],)
julia> Zygote.gradient(g, x)
(ComplexF32[0.11962223f0 + 1.3578255f0im 1.534275f0 + 1.5565007f0im; 1.0961342f0 + 1.9731205f0im 0.61220574f0 + 0.33216715f0im],)
julia> Zygote.gradient(f, x_c)
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
iterate(::Union{LinRange, StepRangeLen}) at range.jl:664
iterate(::Union{LinRange, StepRangeLen}, ::Int64) at range.jl:664
iterate(::T) where T<:Union{Base.KeySet{var"#s79", var"#s78"} where {var"#s79", var"#s78"<:Dict}, Base.ValueIterator{var"#s77"} where var"#s77"<:Dict} at dict.jl:693
...
Stacktrace:
[1] (::Zygote.var"#1209#1210"{Zygote.var"#1104#1108"})(ȳ::CuArray{Float32, 2})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/broadcast.jl:231
[2] (::Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}})(Δ::CuArray{Float32, 2})
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[3] (::Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}})(Δ::CuArray{Float32, 2})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/lib/lib.jl:194
[4] (::Zygote.var"#1689#back#182"{Zygote.var"#180#181"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#577#back#1211"{Zygote.var"#1209#1210"{Zygote.var"#1104#1108"}}}})(Δ::CuArray{Float32, 2})
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[5] Pullback
@ ./broadcast.jl:1309 [inlined]
[6] Pullback
@ ./REPL[18]:1 [inlined]
[7] (::typeof(∂(f)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[8] (::Zygote.var"#41#42"{typeof(∂(f))})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
[9] gradient(f::Function, args::CuArray{ComplexF32, 2})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
[10] top-level scope
@ REPL[24]:1
[11] top-level scope
@ ~/.julia/packages/CUDA/k52QH/src/initialization.jl:81
julia> Zygote.gradient(g, x_c)
(ComplexF32[0.11962223f0 + 1.3578255f0im 1.534275f0 + 1.5565007f0im; 1.0961342f0 + 1.9731205f0im 0.61220574f0 + 0.33216715f0im],)
I would be really grateful for any hint, because this is currently blocking my work. I don’t know how to work around it, since map also seems to fail.
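One stopgap that might be worth trying is to wrap the reduction in a helper and give Zygote a hand-written pullback via Zygote.@adjoint, so that it never has to differentiate the abs2 broadcast itself. This is only a minimal sketch and not a fix for the underlying problem: the name sum_abs2 is hypothetical, and the 2x rule is an assumption read off the CPU gradients above.

```julia
using Zygote

# Hypothetical helper (not part of Zygote or CUDA): same value as f above.
sum_abs2(x) = sum(abs2.(x))

# Hand-written pullback, so Zygote does not differentiate through the abs2 broadcast.
# Assumption: the gradient is 2x, which is what the CPU results above show.
Zygote.@adjoint sum_abs2(x) = sum_abs2(x), Δ -> (2 .* Δ .* x,)

# CPU sanity check against the expected 2x gradient.
x = rand(ComplexF32, 2, 2)
grad, = Zygote.gradient(sum_abs2, x)
println(grad ≈ 2 .* x)
```

In principle the same pullback applies to x_c, since it only needs a broadcast multiply on the array, but that is untested here.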
Thanks,
Felix