I have a problem with differentiation in Zygote. Here is an example of what I need to achieve:

If we have a vector like `a = [0,2,4,1]` and `b = 2`, I would like to create a binary vector with ones at the positions of the `b` (here 2) smallest elements of `a` and zeros everywhere else, i.e. my desired output in this case is `c = [1,0,0,1]`. To this end, I figured that I need a permutation matrix `p` such that `p*a = sort(a)`, and here is my problem:
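For reference, here is a minimal plain-Julia sketch of the target computation, without Zygote involved. `lowest_mask` and `perm_matrix` are hypothetical helper names, and ties are assumed to be broken in whatever order `sortperm` returns. `perm_matrix` reorders the rows of the identity so that `p * a == sort(a)`; multiplying its transpose by `[1, 1, 0, 0]` maps the sorted positions back to the original ones, which yields the mask `c`.

```julia
using LinearAlgebra  # for the identity I

# Hypothetical helper: 1.0 at the b smallest entries of a, 0.0 elsewhere,
# built with a comprehension (no in-place mutation).
function lowest_mask(a, b)
    idx = sortperm(a)[1:b]          # positions of the b smallest entries
    return [i in idx ? 1.0 : 0.0 for i in eachindex(a)]
end

# Hypothetical helper: permutation matrix with p * a == sort(a).
# Row i is row sortperm(a)[i] of the identity, so (p * a)[i] == a[sortperm(a)[i]].
function perm_matrix(a)
    n = length(a)
    return Matrix{Float64}(I, n, n)[sortperm(a), :]
end

a = [0, 2, 4, 1]
b = 2
c = lowest_mask(a, b)                          # [1.0, 0.0, 0.0, 1.0]
p = perm_matrix(a)
p * a == sort(a)                               # true
p' * [ones(b); zeros(length(a) - b)] == c      # true: p' undoes the sort
```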
Minimal example, starting with the first thing I came up with:
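```julia
using Zygote

# Build a permutation matrix p with p * a == sort(a):
# row i has a single 1 in column sortperm(a)[i].
function create_prem(a)
    a_sort = sortperm(a)
    a_perm = zeros(4, 4)
    for i in 1:4
        a_perm[i:i, :] = vcat(zeros(a_sort[i]-1, 1), 1, zeros(4-a_sort[i], 1))'  # in-place write
    end
    return a_perm
end

function my(a, b)
    p = create_prem(a)
    return sum((p * a)[1:b])   # sum of the b smallest entries of a
end

a = [0, 2, 4, 1]
gradient(x -> my(x, 2), a)  # Error > Mutating arrays is not supported
```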
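The loop in `create_prem` is where the in-place assignment happens. As a sketch (the name `create_prem_nomut` is hypothetical), the same matrix can be built with a comprehension, which allocates a new array instead of writing into an existing one. This only removes the `setindex!` call; `sortperm` is still applied to the input, and because the entries of `p` are 0/1 values that do not vary smoothly with `a`, its derivative with respect to `a` is zero wherever it is defined.

```julia
# Hypothetical non-mutating variant of create_prem: returns the same n×n matrix,
# with a single 1.0 in row i at column sortperm(a)[i].
function create_prem_nomut(a)
    perm = sortperm(a)
    n = length(a)
    return [j == perm[i] ? 1.0 : 0.0 for i in 1:n, j in 1:n]
end

a = [0, 2, 4, 1]
p = create_prem_nomut(a)
p * a == sort(a)        # true
sum((p * a)[1:2])       # 1.0, the sum of the two smallest entries (0 + 1)
```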
The `Mutating arrays is not supported` error points to the assignment inside the `for` loop of `create_prem`. I tried to circumvent this with `Buffer()`:
```julia
using Zygote: gradient, Buffer

function create_prem_buffer(a)
    a_sort = sortperm(a)
    # Buffer(x) allocates an uninitialized array like x; every row is overwritten below
    buf_perm = Buffer(zeros(4, 4))
    for i in 1:4
        buf_perm[i:i, :] = vcat(zeros(a_sort[i]-1, 1), 1, zeros(4-a_sort[i], 1))'
    end
    return copy(buf_perm)   # freeze the buffer into a regular array
end

function my_buffer(a, b)
    p = create_prem_buffer(a)
    return sum((p * a)[1:b])
end

gradient(x -> my_buffer(x, 2), a)  # Error > Mutating arrays is not supported
```
This time the error points to the line `a_sort = sortperm(a)`, and I have no idea how to deal with it.
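One possible workaround, sketched under the assumption that the permutation may be treated as a constant: call `sortperm` outside the function passed to `gradient`. Away from ties, the sort order does not change under small perturbations of `a`, so freezing it loses no derivative information. Note also that the desired mask `c` is exactly the gradient of "sum of the `b` smallest entries", so it can be read off directly. `lowest_mask_grad` is a hypothetical name; only `gradient` and plain indexing are used. If the permutation has to be computed inside the differentiated function, ChainRulesCore's `ignore_derivatives` is intended for marking such code as non-differentiable; that variant is not shown here.

```julia
using Zygote

# Hypothetical helper: the gradient of x -> sum of the b smallest entries of x,
# evaluated at a, is the binary mask with ones at those entries.
function lowest_mask_grad(a, b)
    perm = sortperm(a)                        # computed outside gradient: treated as a constant
    g, = gradient(x -> sum(x[perm[1:b]]), a)  # no setindex! and no sortperm inside the closure
    return g
end

a = [0.0, 2.0, 4.0, 1.0]   # Float64 entries (the question uses Int)
lowest_mask_grad(a, 2)     # [1.0, 0.0, 0.0, 1.0]
```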
Thanks in advance for your help.