Gradient of `sum(f, x)` using Zygote, when `f` is real valued and `x isa CuArray{<:Complex}`

I can differentiate `sum(abs2, a)` with Zygote when `a isa CuArray{<:Complex}`, but the same fails for `sum(f, a)` with the wrapper `f(x) = abs2(x)` (details below). I haven't been able to figure out why, or how to fix it. Any ideas (probably @maleadt :slight_smile:)?

MWE:

```julia
using CuArrays
using Zygote
import GPUArrays
f(x) = abs2(x) # Test function. Eventually, I'd like to be able to use an arbitrary real-valued function f...
CuArrays.@cufunc f(x) = abs2(x) # For using f with CuArrays
GPUArrays.gpu_promote_type(::typeof(f), ::Type{Complex{T}}) where {T} = T # "hack" such that sum(f, a) isa Real
a_cpu = randn(ComplexF32, 100, 100)
a_gpu = cu(a_cpu)
sum(f, a_gpu)                         # OKAY
gradient(x -> sum(f, x), a_cpu)       # OKAY on CPU with f
gradient(x -> sum(f, x), real(a_gpu)) # OKAY on GPU with f if the input is real
gradient(x -> sum(abs2, x), a_gpu)    # OKAY on GPU with abs2
gradient(x -> sum(f, x), a_gpu)       # ERROR on GPU with f
```
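For reference, the result the failing GPU line should reproduce follows from Zygote's convention for real-valued functions of a complex variable $z = x + iy$: the returned gradient is $\partial f/\partial x + i\,\partial f/\partial y$. For `f = abs2` this gives:

$$
f(z) = |z|^2 = x^2 + y^2
\quad\Longrightarrow\quad
\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y} = 2x + 2iy = 2z
$$

so `gradient(x -> sum(f, x), a)` is expected to return `2 .* a`, on the CPU and on the GPU alike.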

The error message reads:

```
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
  iterate(::Core.SimpleVector) at essentials.jl:600
  iterate(::Core.SimpleVector, ::Any) at essentials.jl:600
  iterate(::ExponentialBackOff) at error.jl:218
  ...
Stacktrace:
 [1] (::Zygote.var"#1728#1732"{Zygote.var"#1599#1603"})(::CuArray{Float32,2,Nothing}) at /home/troels/.julia/packages/Zygote/R2pFS/src/lib/broadcast.jl:179
 [2] (::Zygote.var"#661#back#1733"{Zygote.var"#1728#1732"{Zygote.var"#1599#1603"}})(::CuArray{Float32,2,Nothing}) at /home/troels/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [3] (::Zygote.var"#157#158"{Zygote.var"#661#back#1733"{Zygote.var"#1728#1732"{Zygote.var"#1599#1603"}},Tuple{Tuple{Nothing,Nothing,Nothing},Tuple{}}})(::CuArray{Float32,2,Nothing}) at /home/troels/.julia/packages/Zygote/R2pFS/src/lib/lib.jl:156
 [4] (::Zygote.var"#297#back#159"{Zygote.var"#157#158"{Zygote.var"#661#back#1733"{Zygote.var"#1728#1732"{Zygote.var"#1599#1603"}},Tuple{Tuple{Nothing,Nothing,Nothing},Tuple{}}}})(::CuArray{Float32,2,Nothing}) at /home/troels/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [5] #1135 at ./broadcast.jl:1231 [inlined]
 [6] (::typeof(∂(#1135)))(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/compiler/interface2.jl:0
 [7] (::Zygote.var"#28#29"{typeof(∂(#1135))})(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/compiler/interface.jl:38
 [8] (::Zygote.var"#1136#1138"{Zygote.var"#28#29"{typeof(∂(#1135))}})(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/lib/array.jl:170
 [9] #17 at ./none:1 [inlined]
 [10] (::typeof(∂(#17)))(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/compiler/interface2.jl:0
 [11] (::Zygote.var"#28#29"{typeof(∂(#17))})(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/compiler/interface.jl:38
 [12] gradient(::Function, ::CuArray{Complex{Float32},2,Nothing}) at /home/troels/.julia/packages/Zygote/R2pFS/src/compiler/interface.jl:47
 [13] top-level scope at none:0
```
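The stack trace ends in Zygote's `src/lib/broadcast.jl`, i.e. in the pullback of a broadcast over the complex `CuArray`, while the real-input GPU case in the MWE works. That suggests one possible workaround, sketched here under assumptions: differentiate with respect to `real(a)` and `imag(a)` separately, fusing the complex reconstruction into a single scalar function so that the only broadcast Zygote sees has real arguments, and then reassemble the result as `∂f/∂x + i ∂f/∂y`. The helper `complex_sum_gradient` is hypothetical (not part of Zygote or CuArrays), and the sketch assumes a current Zygote on the CPU; it has not been tested on a `CuArray` or with the package versions listed in this post.

```julia
using Zygote

# Hypothetical helper (not part of Zygote or CuArrays): gradient of sum(f, a)
# for a real-valued `f` and complex `a`, obtained by differentiating with
# respect to the real and imaginary parts separately.
function complex_sum_gradient(f, a::AbstractArray{<:Complex})
    # Fuse the complex reconstruction into one scalar function, so that no
    # intermediate complex array is broadcast over; only real arrays are.
    h(r, i) = f(complex(r, i))
    gre, gim = gradient((re, im) -> sum(h.(re, im)), real(a), imag(a))
    # Reassemble using Zygote's convention ∂f/∂x + i ∂f/∂y.
    return complex.(gre, gim)
end

f(x) = abs2(x)
a_cpu = randn(ComplexF32, 100, 100)
g_split = complex_sum_gradient(f, a_cpu)
g_ref, = gradient(x -> sum(f, x), a_cpu)
g_split ≈ g_ref  # the two approaches are expected to agree on the CPU
```

On the GPU the call would be `complex_sum_gradient(f, a_gpu)`; whether the broadcast `h.(re, im)` compiles and differentiates for `CuArray`s with the versions above is an open assumption.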

Additional info:

```
(v1.3) pkg> st CuArrays
    Status `~/.julia/environments/v1.3/Project.toml`
  [79e6a3ab] Adapt v1.0.0
  [fa961155] CEnum v0.2.0
  [3895d2a7] CUDAapi v2.1.0
  [c5f51814] CUDAdrv v5.0.1
  [be33ccc6] CUDAnative v2.7.0
  [3a865a2d] CuArrays v1.6.0 #master (https://github.com/JuliaGPU/CuArrays.jl.git)
  [864edb3b] DataStructures v0.17.7
  [0c68f7d7] GPUArrays v2.0.1 #master (https://github.com/JuliaGPU/GPUArrays.jl.git)
  [872c559c] NNlib v0.6.2

(v1.3) pkg> st Zygote
    Status `~/.julia/environments/v1.3/Project.toml`
  [7a1cc6ca] FFTW v1.2.0
  [f6369f11] ForwardDiff v0.10.8
  [7869d1d1] IRTools v0.3.0
  [872c559c] NNlib v0.6.2
  [276daf66] SpecialFunctions v0.9.0
  [e88e6eb3] Zygote v0.4.3 #master (https://github.com/FluxML/Zygote.jl.git)
  [700de1a5] ZygoteRules v0.2.0
```

(Note: I asked the same question on Slack #gpu yesterday without luck, but realized that this might be a better place for it.)

Although there’s likely something wrong with CuArrays functionality here, I don’t have enough Zygote experience to know what’s up.