Zygote: "Mutating arrays is not supported" error

I have a problem with differentiation in Zygote. Let me start with an example of what I need to achieve.
Given a vector such as a = [0,2,4,1] and b = 2, I would like to build a binary vector with ones at the positions of the b smallest elements of a and zeros elsewhere, i.e. my desired output in this case is c = [1,0,0,1]. To this end, I figured I need a permutation matrix p such that p*a = sort(a), and this is where my problem lies.
Minimal example:

The first thing I came up with:

using Zygote

function create_prem(a)
    a_sort = sortperm(a)
    a_perm = zeros(4, 4)
    for i in 1:4
        a_perm[i:i,:] = vcat(zeros(a_sort[i]-1, 1),1,zeros(4-a_sort[i], 1))'
    end
    return a_perm
end

function my(a,b)
    p = create_prem(a)
    return sum((p * a)[1:b])  # p*a = sort(a), so this sums the b smallest elements
end
a = [0,2,4,1]
gradient(x -> my(x, 2), a) # ERROR: Mutating arrays is not supported
# The error points to the assignment inside the for loop. I could work around that with Buffer():

function create_prem_buffer(a)
    a_sort = sortperm(a)
    buf_perm = Buffer(zeros(4, 4))
    for i in 1:4
        buf_perm[i:i,:] = vcat(zeros(a_sort[i]-1, 1),1,zeros(4-a_sort[i], 1))'
    end
    return copy(buf_perm)
end

function my_buffer(a,b)
    p = create_prem_buffer(a)
    return sum((p * a)[1:b])  # p*a = sort(a), so this sums the b smallest elements
end

gradient(x -> my_buffer(x, 2), a) # ERROR: Mutating arrays is not supported
# This time the error points to the sortperm(a) line, and I have no idea how to deal with it.

I really appreciate your help in advance.

Here’s one way:

julia> using Zygote

julia> function keepn(a::AbstractVector, b::Int)
         which = sortperm_b(a, b)
         [i in which ? a[i] : zero(eltype(a)) for i in axes(a,1)]
       end;

julia> sortperm_b(a, b::Int) = sortperm(a)[1:b];

julia> Zygote.@nograd sortperm_b

julia> a = [0,2,4,1];

julia> gradient(a -> sum(keepn(a,2)), a)
([1, 0, 0, 1],)

I think Zygote ought to declare @nograd sortperm globally, since a permutation can’t have a sensible gradient. Then there would be no need to pull this out into its own function like I did.
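If you want the 0/1 vector c itself rather than the masked values, a rough sketch along the same lines is below. It assumes ChainRulesCore is installed alongside Zygote and uses its ignore_derivatives to hide the index computation from the AD; the names lowest_mask and sum_lowest are made up for illustration. Treat it as one possible variant, not the canonical way to do this.

```julia
using Zygote
using ChainRulesCore: ignore_derivatives

# Hypothetical helper: 0/1 mask marking the b smallest entries of a.
function lowest_mask(a::AbstractVector, b::Int)
    idx = sortperm(a)[1:b]              # positions of the b smallest elements
    mask = zeros(eltype(a), length(a))
    mask[idx] .= one(eltype(a))         # in-place write; only run outside the AD trace
    return mask
end

# Hypothetical objective: sum of the b smallest entries of a.
function sum_lowest(a::AbstractVector, b::Int)
    # The mask is treated as a constant, so Zygote never differentiates
    # through sortperm or the mutation inside lowest_mask.
    c = ignore_derivatives() do
        lowest_mask(a, b)
    end
    return sum(c .* a)
end

a = [0.0, 2.0, 4.0, 1.0]
c = lowest_mask(a, 2)                   # c == [1.0, 0.0, 0.0, 1.0]
g = gradient(x -> sum_lowest(x, 2), a)[1]
```

Since c is held constant, the gradient g of sum(c .* a) with respect to a should simply be the mask again. Note that this makes the selection itself non-differentiable: the gradient says nothing about how the result changes when the ordering of a changes.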


Thanks a lot. This seems to be working.