I am interested in writing a generic function that operates on an AbstractArray and produces some output. In doing so, it needs to create an intermediate array. The function should work with both CPU and GPU arrays, and it needs to be differentiable with Zygote (so no mutation is allowed).
Let me illustrate with an example:
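function f1(x::AbstractArray)
    # Intermediate array built from the indices of x (lambda argument renamed to i so it no longer shadows x)
    y = map(i -> i^2 + 1, eachindex(x))
    sum(x .* y)
end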
This will work with most arrays and with Zygote, but will fail if x is a GPU array, as y won’t be on the GPU and x .* y can’t deal with a mix of on-GPU and on-CPU arrays.
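To make the cause concrete, here is a small CPU-only sketch with a made-up input vector. It only shows that the result of map is built from the index range and never sees the array type of x; I am assuming a GPU array type would hand map the same kind of index range, which is why y ends up as an ordinary CPU array.

```julia
# CPU-only illustration; the input values are arbitrary.
x = Float32[1, 2, 3]

# eachindex(x) is an index range, so map builds its result from the range alone
# and has no way of knowing what kind of array x is.
y = map(i -> i^2 + 1, eachindex(x))

println(typeof(y))    # a plain CPU Vector of integers
println(y)            # [2, 5, 10]
println(sum(x .* y))  # 42.0
```

On the CPU the broadcast in the last line is fine; the trouble only starts once x lives on a different device than y.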
This version:
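function f2(x::AbstractArray)
    # similar(x) allocates an array of the same type as x, so y lives on the same device
    y = similar(x)
    # In-place fill: this is the mutation that Zygote cannot differentiate through
    map!(i -> i^2 + 1, y, eachindex(x))
    sum(x .* y)
end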
will work just fine with x being either a GPU or a CPU array, but will fail with Zygote due to mutation.
Is there a good, idiomatic way to write this generically without mutation?
P.S. This particular example could have been solved with dot(), but I’m asking a more general question about computations where I need y to be similar to x in terms of the location (CPU/GPU) of the array.