I would like to use Zygote to differentiate through some code where I would usually use `setindex!`, e.g. `A[i, j] += v` where `v` is a scalar and `A` a matrix. My current approach is to use a custom `OneHotMatrix`:
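```julia
using ChainRulesCore

# Matrix that is zero everywhere except for the value v at position (i, j)
struct OneHotMatrix <: AbstractMatrix{Float64}
    v::Float64
    i::Int
    j::Int
    size::Tuple{Int, Int}
end

Base.getindex(A::OneHotMatrix, i::Int, j::Int) = i == A.i && j == A.j ? A.v : 0.0
Base.size(A::OneHotMatrix) = A.size

# Only v is differentiable; the constructor, the indices and the size get no tangent
function ChainRulesCore.rrule(::Type{<:OneHotMatrix}, v, i, j, size)
    Y = OneHotMatrix(v, i, j, size)
    function OneHotMatrix_pullback(x̄)
        return NoTangent(), unthunk(x̄)[i, j], NoTangent(), NoTangent(), NoTangent()
    end
    return Y, OneHotMatrix_pullback
end
```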
and then write `A + OneHotMatrix(v, i, j, size(A))`. This works, but I suspect it is overly complicated. Are there simpler approaches for this situation?
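One possibly simpler alternative, sketched here under the assumption that Zygote can differentiate plain broadcasting (it does not support mutation, which is why `setindex!` fails), is to build the one-hot update out of a Bool mask instead of a custom matrix type. The helper name `addat` is hypothetical, and this sketch has not been checked against a particular Zygote version:

```julia
# Hypothetical helper: returns a new matrix equal to A with v added at (i, j).
# Nothing is mutated, so reverse-mode AD only has to handle broadcasts.
function addat(A::AbstractMatrix, v::Real, i::Integer, j::Integer)
    # Bool mask that is true only at position (i, j)
    mask = (axes(A, 1) .== i) .& (axes(A, 2)' .== j)
    # Scale the mask by v and add it to A elementwise
    return A .+ v .* mask
end

A = zeros(3, 4)
B = addat(A, 2.5, 2, 3)
println(B[2, 3])
```

With Zygote, something like `gradient((A, v) -> sum(addat(A, v, 2, 3)), A, 2.5)` should then work without a hand-written `rrule`. The trade-off is that each call allocates and touches a full `m × n` matrix instead of a single entry, which may matter if the update sits inside a loop.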