I have a problem with differentiation in Zygote. Let me start with an example of what I need to achieve:
Suppose we have a vector `a = [0,2,4,1]` and `b = 2`. I would like to create a binary vector with ones at the positions of the `b` (here 2) smallest elements of `a` and zeros everywhere else, i.e. in this case my desired output is `c = [1,0,0,1]`. To this end, I figured I need a permutation matrix `p` such that `p*a = sort(a)`, and this is where I got stuck.
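For reference, here is a minimal plain-Julia sketch (no Zygote involved) of the target `c` and of one permutation matrix with `p*a == sort(a)`. It builds `p` by permuting the rows of the identity instead of filling it in a loop; the variable names are only illustrative.

```julia
using LinearAlgebra  # provides the identity `I`

a = [0, 2, 4, 1]
b = 2
n = length(a)

perm = sortperm(a)                      # [1, 4, 2, 3]
p = Matrix{Float64}(I, n, n)[perm, :]   # row i is the unit row vector e_{perm[i]}'
@assert p * a == sort(a)

# Binary vector c: ones at the positions of the b smallest entries of a.
c = zeros(Int, n)
c[perm[1:b]] .= 1                       # c == [1, 0, 0, 1]
```

If `a` contains ties, `sortperm` keeps equal elements in ascending index order, so `c` then marks the earliest of the tied entries.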
The first thing I came up with:
```julia
using Zygote

function create_prem(a)
    n = length(a)
    a_sort = sortperm(a)
    a_perm = zeros(n, n)
    for i in 1:n
        # row i gets a single 1 in column a_sort[i], so (a_perm * a)[i] == sort(a)[i]
        a_perm[i:i, :] = vcat(zeros(a_sort[i] - 1, 1), 1, zeros(n - a_sort[i], 1))'
    end
    return a_perm
end

function my(a, b)
    p = create_prem(a)
    return sum((p * a)[1:b])  # sum of the b smallest elements of a
end

a = [0, 2, 4, 1]
gradient(x -> my(x, 2), a)
# Error: Mutating arrays is not supported
```

This error points to the line inside the `for` loop. I could circumvent it with `Buffer`:

```julia
function create_prem_buffer(a)
    n = length(a)
    a_sort = sortperm(a)
    buf_perm = Zygote.Buffer(zeros(n, n))  # Buffer allows setindex! under Zygote
    for i in 1:n
        buf_perm[i:i, :] = vcat(zeros(a_sort[i] - 1, 1), 1, zeros(n - a_sort[i], 1))'
    end
    return copy(buf_perm)  # copy turns the Buffer back into a regular array
end

function my_buffer(a, b)
    p = create_prem_buffer(a)
    return sum((p * a)[1:b])
end

gradient(x -> my_buffer(x, 2), a)
# Error: Mutating arrays is not supported
```

This time the message points to the line `sortperm(a)`, and I have no idea how to deal with it.
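One way around both errors, sketched under assumptions: since `p` only changes when the ordering of `a` changes, its derivative is zero wherever it exists, so it can be computed outside of differentiation. The sketch below assumes `ChainRulesCore` is installed in the environment and is recent enough to provide `ignore_derivatives`; `create_perm_nomut` and `my_nomut` are hypothetical names, not part of Zygote.

```julia
using Zygote
using ChainRulesCore: ignore_derivatives

# Non-mutating variant of create_prem: the permutation matrix is built with a
# comprehension instead of setindex!, so no array is mutated.
function create_perm_nomut(a)
    n = length(a)
    perm = sortperm(a)
    return [j == perm[i] ? 1.0 : 0.0 for i in 1:n, j in 1:n]
end

function my_nomut(a, b)
    # ignore_derivatives runs the closure normally but keeps Zygote from
    # tracing into it (and therefore into sortperm); p is treated as a constant.
    p = ignore_derivatives(() -> create_perm_nomut(a))
    return sum((p * a)[1:b])   # sum of the b smallest entries of a
end

a = [0.0, 2.0, 4.0, 1.0]
g = gradient(x -> my_nomut(x, 2), a)[1]   # [1.0, 0.0, 0.0, 1.0]
```

Away from ties, the gradient of "sum of the `b` smallest entries" with respect to `a` is exactly the indicator vector `c` described above. The matrix is not strictly needed either: `sum(a[perm[1:b]])` with `perm` obtained through `ignore_derivatives(() -> sortperm(a))` gives the same value and gradient.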
Thanks in advance for your help.