Hi,
I noticed a problem with Zygote and CUDA when optimizing a loss function containing adjoints of reshaped arrays.
Preliminary observation
First, note that creating a reshaped array of a CuArray adjoint fails when scalar indexing is disabled:
```julia
using CUDA
CUDA.allowscalar(false)

A = cu(randn(Float32, 10, 5))  # a matrix, so that A' is defined
reshape(A', 10, 5)
```

```
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
```
So Base.ReshapedArray{T, N, <:Adjoint{T, <:CuArray}} is a type that Julia can construct but cannot materialize on the GPU: the construction itself succeeds (the error above is triggered when the REPL tries to display the result), and even copy raises the scalar indexing error.
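A quick way to confirm where the failure occurs (a self-contained sketch; the trailing semicolon suppresses REPL display, so only the explicit copy triggers the error):

```julia
using CUDA
CUDA.allowscalar(false)

A = cu(randn(Float32, 10, 5))
x = reshape(A', 10, 5);  # construction of the lazy wrapper succeeds
copy(x)                  # ERROR: Scalar indexing is disallowed.
```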
The problem with Zygote
Zygote silently creates this type in the backward pass. Here is a minimal example:
```julia
using CUDA, Zygote
CUDA.allowscalar(false)

A = cu(randn(Float32, 5, 2, 5))
B = cu(randn(Float32, 5, 2, 3))
f(A, B) = sum(real(reshape(A, 10, 5)' * reshape(B, 10, 3)))
Zygote.gradient(f, A, B)  # ERROR: scalar indexing
```
In the backward pass, Zygote lazily composes reshape and adjoint, producing a Base.ReshapedArray{T, N, <:Adjoint{T, <:CuArray}} tangent that cannot be used on the GPU.
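The offending tangent can be reproduced by hand (a sketch of the two pullbacks, not Zygote's actual code): the pullback of ' wraps the cotangent of the matrix product in another Adjoint, and the pullback of reshape then reshapes that lazy wrapper back to size(A):

```julia
using CUDA, LinearAlgebra
CUDA.allowscalar(false)

B  = cu(randn(Float32, 5, 2, 3))
Br = reshape(B, 10, 3)
dY = CUDA.ones(Float32, 5, 3)  # cotangent of sum(...) w.r.t. the 5×3 product

dAr = (dY * Br')'              # pullback of ': a lazy 10×5 Adjoint
dA  = reshape(dAr, 5, 2, 5)    # pullback of reshape: ReshapedArray wrapping an Adjoint
typeof(dA)                     # Base.ReshapedArray{Float32, 3, LinearAlgebra.Adjoint{…}, …}
```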
Workaround
In my code, I catch this type in a custom ChainRulesCore.ProjectTo and materialize it with:
```julia
using LinearAlgebra: Adjoint

# Copy the Adjoint parent first (CUDA.jl knows how to copy an
# Adjoint of a CuArray), then reshape the resulting dense array.
function materialize(x::Base.ReshapedArray{T, N, <:Adjoint{T, <:AbstractArray}}) where {T, N}
    adj_materialized = copy(parent(x))
    return reshape(adj_materialized, size(x))
end

materialize(x) = x  # fallback: leave all other arrays untouched
```
The key point is that copy(parent(x)) materializes the Adjoint into a plain CuArray first (CUDA.jl knows how to copy an Adjoint of a CuArray), after which reshape can be applied safely.
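For example, continuing the preliminary session (assuming the materialize methods above are defined):

```julia
using CUDA, LinearAlgebra
CUDA.allowscalar(false)

A = cu(randn(Float32, 10, 5))
x = reshape(A', 10, 5);  # the problematic ReshapedArray-of-Adjoint
y = materialize(x)       # a plain CuArray of size (10, 5)
sum(y)                   # GPU-friendly: no scalar indexing
```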
Question
Is this a known limitation of CUDA.jl with composed array wrappers, or something that could be fixed?
Thanks,
Nicolas