I know that Zygote.jl doesn't support array mutation. I tried to use `ScatterNNlib.gather`, but it throws an array mutation error:
```julia
using Flux
using ScatterNNlib

f(x) = gather(x, [1 3; 4 2], 2)
gradient(x -> f(x), rand(4, 4))
```
Let us look at `ScatterNNlib.gather`:
```julia
function gather(input::AbstractArray{T,N}, index::AbstractArray{<:Integer,N}, dims::Integer) where {T,N}
    @assert dims <= N "Specified dimensions must lower or equal to the rank of input matrix."
    out = similar(index, T)
    @inbounds for x = CartesianIndices(out)
        tup = collect(Tuple(x))
        tup[dims] = index[x]
        view(out, x) .= view(input, tup...)
    end
    return out
end
```
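For comparison, here is a minimal sketch of the same gather semantics written with no in-place writes at all, assuming the only behavior to preserve is that `out[I]` equals `input` at `I` with its `dims`-th coordinate replaced by `index[I]`. The output is produced by `map` over `CartesianIndices(index)` instead of being filled into `similar(index, T)`, and the source coordinates are an immutable tuple instead of a mutated `Vector`, so Zygote never has to differentiate an assignment into an array. The name `gather_nomut` is hypothetical and not part of ScatterNNlib; the sketch aims at clarity rather than speed.

```julia
using Zygote

# Hypothetical non-mutating gather (not a ScatterNNlib function).
# out[I] = input[J], where J is I with its `dims`-th coordinate replaced by index[I].
function gather_nomut(input::AbstractArray{T,N}, index::AbstractArray{<:Integer,N},
                      dims::Integer) where {T,N}
    @assert dims <= N "dims must be less than or equal to ndims(input)"
    return map(CartesianIndices(index)) do I
        # Source coordinates as an immutable tuple (no setindex! on a Vector).
        J = ntuple(d -> d == dims ? index[I] : I[d], N)
        input[J...]  # scalar read, which Zygote can differentiate
    end
end

x = rand(4, 4)
idx = [1 3; 4 2]
y = gather_nomut(x, idx, 2)   # y[i, j] == x[i, idx[i, j]]
g, = gradient(x -> sum(gather_nomut(x, idx, 2)), x)
```

Since the loss is a plain sum, `g` should hold a 1 at every position of `x` that was gathered and 0 everywhere else.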
Why does this line mutate an array?

```julia
view(out, x) .= view(input, tup...)
```

And how does `Zygote.Buffer` fit in here?
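The `.=` in that line is an in-place broadcast: Julia lowers `a .= b` to `Base.materialize!(a, ...)`, which writes into `out` through the view, and Zygote does not differentiate writes into ordinary arrays. `Zygote.Buffer` is Zygote's supported way to build an array by writing into it inside differentiated code: elements are written with `buf[i] = v`, and `copy(buf)` returns an ordinary array. Below is a sketch of the original loop rewritten on top of a `Buffer`, assuming the same semantics as `gather`; the name `gather_buffer` is hypothetical and not part of ScatterNNlib.

```julia
using Zygote

# Hypothetical Buffer-based gather (not a ScatterNNlib function): the original
# loop, but writing into a Zygote.Buffer with scalar setindex! instead of `.=`.
function gather_buffer(input::AbstractArray{T,N}, index::AbstractArray{<:Integer,N},
                       dims::Integer) where {T,N}
    @assert dims <= N "dims must be less than or equal to ndims(input)"
    buf = Zygote.Buffer(input, size(index))  # backed by similar(input, size(index))
    for I in CartesianIndices(index)
        # Source coordinates as an immutable tuple instead of a mutated Vector.
        J = ntuple(d -> d == dims ? index[I] : I[d], N)
        buf[I] = input[J...]                 # element-wise write that Zygote tracks
    end
    return copy(buf)                         # convert the buffer back to a plain array
end

x = rand(4, 4)
idx = [1 3; 4 2]
y = gather_buffer(x, idx, 2)
g, = gradient(x -> sum(gather_buffer(x, idx, 2)), x)
```

Compared with the original, the only changes are the `Buffer` in place of `similar(index, T)`, the scalar `buf[I] = ...` in place of `view(out, x) .= ...`, the tuple in place of `collect(Tuple(x))` followed by `tup[dims] = ...`, and the final `copy`.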