I am new to Zygote and have a question about defining my own adjoint involving getindex
with custom indexing.
I’ve defined a type
import Base: getindex
struct ProbTensor{N}
    elements::Array{Float64,N}
end
function getindex(e::ProbTensor, idx...)
    sum(view(e.elements, (i==0 ? Colon() : (i+3)÷2 for i in idx)...))
end
to represent a probability distribution over N Boolean values. I store the false case at position 1 and the true case at position 2, and in other parts of the code I represent these values as false=-1, true=1.
I’ve overloaded getindex to convert the -1/+1 indices into the array positions 1/2, and an index of 0 marginalizes over the true and false cases. E.g. (p::ProbTensor{3})[0,0,0] should sum over all elements and return 1.0.
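To make the indexing convention concrete, here is a minimal sketch of the same type on a 2×2 example. The helper name slot and the probability values are only for illustration and are not part of my actual code.

```julia
struct ProbTensor{N}
    elements::Array{Float64,N}
end

# Hypothetical helper: map a label to an array index.
# 0 -> Colon() (marginalize), -1 -> 1 (false), +1 -> 2 (true)
slot(i::Int) = i == 0 ? Colon() : (i + 3) ÷ 2

# Same behaviour as the getindex overload above, written with map
Base.getindex(e::ProbTensor, idx...) = sum(view(e.elements, map(slot, idx)...))

# Illustrative joint distribution over two Boolean variables
p = ProbTensor([0.1 0.2; 0.3 0.4])

p[-1, -1]  # probability of (false, false), stored at elements[1, 1]
p[0, 1]    # marginal: second variable true, summed over the first
p[0, 0]    # sum over everything, the normalization
```

Here p[0, 1] adds up elements[:, 2], so it is approximately 0.6, and p[0, 0] is approximately 1.0.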
Now, I’m trying to define an adjoint for derivatives of the output of getindex with respect to the stored probabilities. My first attempt at the gradient of the normalization
using Zygote: gradient, @adjoint
probs = rand(2,2); probs = probs / sum(probs)
gradient(p->p[0,0], ones(2,2))
fails by complaining about accessing element [0,0].
So, I tried to define my own adjoint
covers(idx1::Int, idx2::Int) = idx1 == 0 || (idx1==-1 && idx2==1) || (idx1==1 && idx2==2)
covers(idxs1, idxs2::Tuple) = all(covers.(idxs1,idxs2))
covers(idxs1, idxs2::CartesianIndex) = covers(idxs1, Tuple(idxs2))
@adjoint getindex(e::ProbTensor, idx...) =
    (getindex(p, idx...), dp -> (begin
        arr = zeros(eltype(dp), size(dp))
        for i in CartesianIndices(dp)
            covers(idx,i) && (arr[i] = dp[i])
        end
        arr
    end,))
(covers checks for the non-zero derivative elements)
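To show what I intend covers to select, here is a small sketch that evaluates it over every position of a 2×2 array. The name mask and the choice idx = (0, 1) are just for illustration.

```julia
# Same definitions as above, repeated so the snippet runs on its own
covers(idx1::Int, idx2::Int) = idx1 == 0 || (idx1 == -1 && idx2 == 1) || (idx1 == 1 && idx2 == 2)
covers(idxs1, idxs2::Tuple) = all(covers.(idxs1, idxs2))
covers(idxs1, idxs2::CartesianIndex) = covers(idxs1, Tuple(idxs2))

# Query: first variable marginalized (0), second variable true (+1)
idx = (0, 1)

# true wherever the array position contributes to p[0, 1]
mask = [covers(idx, i) for i in CartesianIndices((2, 2))]
```

For this query only the second column (positions [1,2] and [2,2]) is covered, which are the elements I expect to have a non-zero derivative.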
But I still get the same error. Can someone please help set me on the correct path?