While I’m able to differentiate sum(abs2, a) (using Zygote) when a isa CuArray{<:Complex}, this fails for sum(f, a) with f(x) = abs2(x) (see below for details). I haven’t been able to figure out exactly why this happens or how to fix it. Does anyone (probably @maleadt) have an idea?
MWE:
using CuArrays
using Zygote
import GPUArrays
f(x) = abs2(x) # Test function; eventually I'd like to use an arbitrary real-valued function f
CuArrays.@cufunc f(x) = abs2(x) # Needed for using f with CuArrays
GPUArrays.gpu_promote_type(::typeof(f), ::Type{Complex{T}}) where {T} = T # "Hack" so that sum(f, a) isa Real
a_cpu = randn(ComplexF32, 100, 100)
a_gpu = cu(a_cpu)
sum(f, a_gpu) # OK
gradient(x -> sum(f, x), a_cpu) # OK on CPU with f
gradient(x -> sum(f, x), real(a_gpu)) # OK on GPU with f if the input is real
gradient(x -> sum(abs2, x), a_gpu) # OK on GPU with abs2
gradient(x -> sum(f, x), a_gpu) # ERROR on GPU with f
The error message reads
ERROR: MethodError: no method matching iterate(::Nothing)
Closest candidates are:
iterate(::Core.SimpleVector) at essentials.jl:600
iterate(::Core.SimpleVector, ::Any) at essentials.jl:600
iterate(::ExponentialBackOff) at error.jl:218
...
Stacktrace:
[1] (::Zygote.var"#1728#1732"{Zygote.var"#1599#1603"})(::CuArray{Float32,2,Nothing}) at /home/troels/.julia/packages/Zygote/R2pFS/src/lib/broadcast.jl:179
[2] (::Zygote.var"#661#back#1733"{Zygote.var"#1728#1732"{Zygote.var"#1599#1603"}})(::CuArray{Float32,2,Nothing}) at /home/troels/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[3] (::Zygote.var"#157#158"{Zygote.var"#661#back#1733"{Zygote.var"#1728#1732"{Zygote.var"#1599#1603"}},Tuple{Tuple{Nothing,Nothing,Nothing},Tuple{}}})(::CuArray{Float32,2,Nothing}) at /home/troels/.julia/packages/Zygote/R2pFS/src/lib/lib.jl:156
[4] (::Zygote.var"#297#back#159"{Zygote.var"#157#158"{Zygote.var"#661#back#1733"{Zygote.var"#1728#1732"{Zygote.var"#1599#1603"}},Tuple{Tuple{Nothing,Nothing,Nothing},Tuple{}}}})(::CuArray{Float32,2,Nothing}) at /home/troels/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[5] #1135 at ./broadcast.jl:1231 [inlined]
[6] (::typeof(∂(#1135)))(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/compiler/interface2.jl:0
[7] (::Zygote.var"#28#29"{typeof(∂(#1135))})(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/compiler/interface.jl:38
[8] (::Zygote.var"#1136#1138"{Zygote.var"#28#29"{typeof(∂(#1135))}})(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/lib/array.jl:170
[9] #17 at ./none:1 [inlined]
[10] (::typeof(∂(#17)))(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/compiler/interface2.jl:0
[11] (::Zygote.var"#28#29"{typeof(∂(#17))})(::Float32) at /home/troels/.julia/packages/Zygote/R2pFS/src/compiler/interface.jl:38
[12] gradient(::Function, ::CuArray{Complex{Float32},2,Nothing}) at /home/troels/.julia/packages/Zygote/R2pFS/src/compiler/interface.jl:47
[13] top-level scope at none:0
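For reference, here is a minimal CPU-only sketch of the result I'd expect the failing GPU call to reproduce. It assumes Zygote's convention for real-valued functions of complex inputs, under which the gradient of abs2 at z is 2z. The small array size and the name g_cpu are only for illustration.

```julia
using Zygote

# Same test function as in the MWE above
f(x) = abs2(x)

# Small complex array on the CPU (size chosen only for illustration)
a_cpu = randn(ComplexF32, 4, 4)

# Gradient of the real-valued reduction with respect to the complex input
g_cpu = first(gradient(x -> sum(f, x), a_cpu))

# Assumed convention: the gradient of abs2 at z is 2z, so elementwise g_cpu ≈ 2 .* a_cpu
@assert g_cpu ≈ 2 .* a_cpu
```

If the GPU path worked, copying its gradient back with Array(...) and comparing against g_cpu for the same input should agree up to Float32 rounding.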
Additional info:
(v1.3) pkg> st CuArrays
Status `~/.julia/environments/v1.3/Project.toml`
[79e6a3ab] Adapt v1.0.0
[fa961155] CEnum v0.2.0
[3895d2a7] CUDAapi v2.1.0
[c5f51814] CUDAdrv v5.0.1
[be33ccc6] CUDAnative v2.7.0
[3a865a2d] CuArrays v1.6.0 #master (https://github.com/JuliaGPU/CuArrays.jl.git)
[864edb3b] DataStructures v0.17.7
[0c68f7d7] GPUArrays v2.0.1 #master (https://github.com/JuliaGPU/GPUArrays.jl.git)
[872c559c] NNlib v0.6.2
(v1.3) pkg> st Zygote
Status `~/.julia/environments/v1.3/Project.toml`
[7a1cc6ca] FFTW v1.2.0
[f6369f11] ForwardDiff v0.10.8
[7869d1d1] IRTools v0.3.0
[872c559c] NNlib v0.6.2
[276daf66] SpecialFunctions v0.9.0
[e88e6eb3] Zygote v0.4.3 #master (https://github.com/FluxML/Zygote.jl.git)
[700de1a5] ZygoteRules v0.2.0
(Note: I asked the same question on Slack #gpu yesterday without luck, but realized that this might be a better place to ask.)