# Gradient of `sum(f, x)` using Zygote, when `f` is real valued and `x isa CuArray{<:Complex}`

Using Zygote, I’m able to differentiate `sum(abs2, a)` when `a isa CuArray{<:Complex}`, but the same thing fails for `sum(f, a)` with `f(x) = abs2(x)` (see below for details). I haven’t been able to figure out exactly why, or how to fix it. Any ideas (probably @maleadt)?

MWE:

``````julia
using CuArrays
using Zygote
import GPUArrays
f(x) = abs2(x) #Test function. Eventually, I'd like to be able to use an arbitrary real-valued function f...
CuArrays.@cufunc f(x) = abs2(x) #For using f with CuArrays
GPUArrays.gpu_promote_type(::typeof(f), ::Type{Complex{T}}) where {T} = T #"hack" such that sum(f, a) isa Real
a_cpu = randn(ComplexF32, 100, 100)
a_gpu = cu(a_cpu)
sum(f, a_gpu)                         #OKAY
gradient(x -> sum(f, x), a_cpu)       #OKAY on CPU with f
gradient(x -> sum(f, x), real(a_gpu)) #OKAY on GPU with f if the input is real
gradient(x -> sum(abs2, x), a_gpu)    #OKAY on GPU with abs2
gradient(x -> sum(f, x), a_gpu)       #ERROR on GPU with f
``````

``````
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
iterate(::Core.SimpleVector) at essentials.jl:600
iterate(::Core.SimpleVector, ::Any) at essentials.jl:600
iterate(::ExponentialBackOff) at error.jl:218
...
Stacktrace:
[3] (::Zygote.var"#157#158"{Zygote.var"#661#back#1733"{Zygote.var"#1728#1732"{Zygote.var"#1599#1603"}},Tuple{Tuple{Nothing,Nothing,Nothing},Tuple{}}})(::CuArray{Float32,2,Nothing}) at /home/troels/.julia/packages/Zygote/R2pFS/src/lib/lib.jl:156
[6] (::typeof(∂(#1135)))(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/compiler/interface2.jl:0
[7] (::Zygote.var"#28#29"{typeof(∂(#1135))})(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/compiler/interface.jl:38
[8] (::Zygote.var"#1136#1138"{Zygote.var"#28#29"{typeof(∂(#1135))}})(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/lib/array.jl:170
[9] #17 at ./none:1 [inlined]
[10] (::typeof(∂(#17)))(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/compiler/interface2.jl:0
[11] (::Zygote.var"#28#29"{typeof(∂(#17))})(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/compiler/interface.jl:38
[13] top-level scope at none:0
``````

``````
(v1.3) pkg> st CuArrays
Status `~/.julia/environments/v1.3/Project.toml`
[fa961155] CEnum v0.2.0
[3895d2a7] CUDAapi v2.1.0
[be33ccc6] CUDAnative v2.7.0
[3a865a2d] CuArrays v1.6.0 #master (https://github.com/JuliaGPU/CuArrays.jl.git)
[864edb3b] DataStructures v0.17.7
[0c68f7d7] GPUArrays v2.0.1 #master (https://github.com/JuliaGPU/GPUArrays.jl.git)
[872c559c] NNlib v0.6.2

(v1.3) pkg> st Zygote
Status `~/.julia/environments/v1.3/Project.toml`
[7a1cc6ca] FFTW v1.2.0
[f6369f11] ForwardDiff v0.10.8
[7869d1d1] IRTools v0.3.0
[872c559c] NNlib v0.6.2
[276daf66] SpecialFunctions v0.9.0
[e88e6eb3] Zygote v0.4.3 #master (https://github.com/FluxML/Zygote.jl.git)
[700de1a5] ZygoteRules v0.2.0
``````

(Note: I asked the same question on Slack #gpu yesterday without luck, but realized that this might be a better place for it.)

The problem is likely in the CuArrays functionality, but I don’t have enough Zygote experience to tell what’s going on.