Zygote.jl: How to get the gradient with respect to a sparse matrix

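Here is an example adapted from the ChainRules.jl `rrule` for `dot(x, A, y)`. The pullback returns lazy LazyArrays.jl expressions instead of materialized arrays. As a result, the gradient with respect to `A` (an outer product of `x` and `y`) is never stored as a dense 200×300 matrix. Its size is roughly that of `x` and `y` combined, as the session below shows.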

using LazyArrays, ChainRulesCore, LinearAlgebra, Zygote

"""
    mydot(x, A, y)

Return the generalized dot product `dot(x, A, y)` from LinearAlgebra.

Kept as a separate function, presumably so that a custom `ChainRulesCore.rrule`
can be attached to it without replacing ChainRules' own rule for `dot`.
"""
function mydot(x, A, y)
    return dot(x, A, y)
end

# Reverse-mode rule for `mydot(x, A, y)`, i.e. `dot(x, A, y)`.
#
# Returns the primal value together with a pullback. The pullback builds every
# cotangent with LazyArrays' `@~` macro instead of materializing it. In particular,
# the cotangent for `A` (`ΔΩ .* x .* adjoint(y)`, an outer product) is handed back
# in lazy form rather than as a dense matrix.
function ChainRulesCore.rrule(
    ::typeof(mydot),
    x::AbstractVector{<:Number},
    A::AbstractMatrix{<:Number},
    y::AbstractVector{<:Number},
)
    Ω = dot(x, A, y)

    # Pullback for a general output cotangent `Ω̄` (which may arrive as a thunk).
    function mydot_pullback(Ω̄)
        A_times_y = @~ A * y
        Δ = unthunk(Ω̄)
        Δconj = conj(Δ)
        x̄ = @~(Δconj .* A_times_y)
        # Adjoints are formed outside the `@~` expressions, as in the original rule.
        y_adj = adjoint(y)
        Ā = @~(Δ .* x .* y_adj)
        A_adj = adjoint(A)
        ȳ = @~(Δ .* (A_adj * x))
        return (NoTangent(), x̄, Ā, ȳ)
    end

    # A structurally-zero output cotangent yields zero cotangents for every input.
    mydot_pullback(::ZeroTangent) =
        (NoTangent(), ZeroTangent(), ZeroTangent(), ZeroTangent())

    return Ω, mydot_pullback
end
julia> x = rand(200); A = rand(200, 300); y = rand(300);

julia> Zygote.pullback(mydot, x, A, y)[2](1.0)[2] |> Base.summarysize
4184

julia> Base.summarysize(x) + Base.summarysize(y)
4080
2 Likes