Are Enzyme custom rules for sparse arrays different from those for a regular Matrix?

I am writing custom rules for sparse arrays in Enzyme, an automatic differentiation package.
The following program defines a custom rule for the addition of matrices of type SparseMatrixCSC.

using SparseArrays
using Enzyme
import .EnzymeRules: augmented_primal, reverse
using .EnzymeRules
#-----------------------------------------------------------------------------------------------------------------------------------------
#
#-----------------------------------------------------------------------------------------------------------------------------------------
function EnzymeRules.augmented_primal(config, func::Const{typeof(+)}, ::Type{RT}, A::Duplicated{SparseMatrixCSC{Float64, Int64}}, B::Duplicated{SparseMatrixCSC{Float64, Int64}}) where RT

    println("use add(sparse_matrix) augmented_primal")

    # Return value of the primal function
    res = A.val + B.val
    retres = if EnzymeRules.needs_primal(config)
        res
    else
        nothing
    end

    # Argument A
    cache_A = if EnzymeRules.overwritten(config)[2]
        copy(A.val)
    else
        nothing
    end

    # Argument B
    cache_B = if EnzymeRules.overwritten(config)[3]
        copy(B.val)
    else
        nothing
    end

    # Cache of the primal result for the reverse pass
    cache_res = if EnzymeRules.needs_primal(config)
        copy(res)
    else
        res
    end

    # Shadow of the return value (same sparsity pattern as res, zero values)
    dres = if EnzymeRules.width(config) == 1
        0.0 * copy(res)
    else
        ntuple(Val(EnzymeRules.width(config))) do i
            Base.@_inline_meta
            0.0 * copy(res)
        end
    end

    display(dres)

    # Pack the cached values for the reverse pass
    cache = NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4")), Tuple{typeof(res), typeof(dres), typeof(cache_A), typeof(cache_B)}}(
        (cache_res, dres, cache_A, cache_B)
    )

    return EnzymeRules.AugmentedReturn{typeof(retres), typeof(dres), typeof(cache)}(retres, dres, cache)
end
#-----------------------------------------------------------------------------------------------------------------------------------------
#
#-----------------------------------------------------------------------------------------------------------------------------------------
function EnzymeRules.reverse(config, func::Const{typeof(+)}, ::Type{RT}, cache, A::Duplicated{<:SparseMatrixCSC{Float64, Int64}}, B::Duplicated{<:SparseMatrixCSC{Float64, Int64}}) where RT

    println("use add(sparse matrix) reverse")

    y, dys, cache_A, cache_B = cache

    if !EnzymeRules.overwritten(config)[2]
        cache_A = A.val
    end

    if !EnzymeRules.overwritten(config)[3]
        cache_B = B.val
    end

    if EnzymeRules.width(config) == 1
        dys = (dys,)
    end

    dAs = if EnzymeRules.width(config) == 1
        (A.dval,)
    else
        A.dval
    end

    dBs = if EnzymeRules.width(config) == 1
        (B.dval,)
    else
        B.dval
    end

    for (dA, dB, dy) in zip(dAs, dBs, dys)
        display(dy)
        display(dA)
        dA += dy
        dB += dy
        dy = eltype(dy)(0)
    end

    return (nothing, nothing)
end

And here is a sample program that computes the gradient using this custom rule.

using Enzyme
import .EnzymeRules: augmented_primal, reverse
using .EnzymeRules
using LinearAlgebra
using SparseArrays
include("add_sparse_matrix.jl")
#--------------------------------------------------------------------------------
# Function to be differentiated
#--------------------------------------------------------------------------------
function eval(x::Vector{Float64})
    A = sparse([1, 1, 2, 3], [1, 3, 2, 3], [2.0*x[2]^3.0, 1.0-x[1], 2.0+x[3], -1.0])
    B = sparse([1, 1, 2, 3], [1, 3, 2, 3], [2.0*x[2], 1.0-x[1], 2.0+x[3], -1.0])
    C = A + B
    return C[1, 1] + C[2, 2]
end
#--------------------------------------------------------------------------------
# main
#--------------------------------------------------------------------------------
x = rand(Float64, 3)
dx = [0.0, 0.0, 0.0]

val = eval(x)

# compute gradient
autodiff(Reverse, eval, Duplicated(x, dx))
@show dx 

df_numeric = similar(x)
for i in eachindex(x)
    x_p = copy(x)
    x_p[i] += 1.0e-05
    x_m = copy(x)
    x_m[i] -= 1.0e-05
    df_numeric[i] = (eval(x_p) - eval(x_m)) / (2.0 * 1.0e-05)
end

display(df_numeric)

err = norm(dx - df_numeric)
println("error = $(err)")

I managed to get it running without errors, but the gradient it produces is not accurate.
For ordinary dense arrays (Matrix), the same procedure computes the exact gradient (see the dense sketch below).
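
For reference, here is a dense sketch of the same test function (eval_dense is a hypothetical rewrite of eval with the entries written out as a Matrix); with it, autodiff matches the finite differences:

# Dense equivalent of eval; Enzyme differentiates this without custom rules.
function eval_dense(x::Vector{Float64})
    A = [2.0*x[2]^3.0  0.0       1.0-x[1];
         0.0           2.0+x[3]  0.0;
         0.0           0.0       -1.0]
    B = [2.0*x[2]      0.0       1.0-x[1];
         0.0           2.0+x[3]  0.0;
         0.0           0.0       -1.0]
    C = A + B
    return C[1, 1] + C[2, 2]
end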

Do I need a special setup for sparse matrices? Am I missing something in my custom rules?

You probably want to accumulate in place, so use .+= rather than +=.
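
That is, something like this in the reverse loop (a sketch; += only rebinds the loop variable, while .+= writes into the shadow, and zeroing through nonzeros keeps the sparsity pattern intact):

for (dA, dB, dy) in zip(dAs, dBs, dys)
    dA .+= dy          # in-place accumulation into the shadow
    dB .+= dy
    nonzeros(dy) .= 0  # reset the output shadow in place
end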

Relatedly, if this hits an error without your rule, it is something others would likely want as well.

Make a PR on Enzyme.jl?

I have already tried that.
Running it results in an error, which seems to be caused by a mismatched sparsity pattern:

ERROR: BoundsError: attempt to access 0-element Vector{Float64} at index [0]
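
A minimal sketch of what I suspect goes wrong: a freshly created sparse shadow starts with an empty sparsity pattern, so its value vector has no slots to accumulate into:

using SparseArrays
A  = sparse([1, 2], [1, 2], [1.0, 2.0])
dA = spzeros(size(A)...)  # shadow with an empty sparsity pattern
nonzeros(dA)              # 0-element Vector{Float64}: nothing to write into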

Without custom rules, even basic sparse-matrix operations cannot be differentiated.
Each of the following either raises a runtime error or computes an incorrect gradient.

I'm new here, but I'll try to open a PR on Enzyme.jl.

#--------------------------------------------------------------------------------
# Function to be differentiated
#--------------------------------------------------------------------------------
function eval(x::Vector{Float64})
    A = sparse([1, 1, 2, 3], [1, 3, 2, 3], [2.0*x[2]^3.0, 1.0-x[1], 2.0+x[3], -1.0])
    B = sparse([1, 1, 2, 3], [1, 3, 2, 3], [2.0*x[2], 1.0-x[1], 2.0+x[3], -1.0])
    
    # Not working
    #C = A + B
    
    # Not working
    #C = A * B

    # Not working
    #C = 2.0 * A

    #return C[1, 1] + C[2, 2]

    # Runs, but does not match the numerical derivative.
    # (Note: b is re-drawn on every call, so the finite-difference check
    # compares evaluations of different functions; see eval_fixed below.)
    b = rand(Float64, 3)
    c = b' * A * b
    return c
end
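
# One way to make the finite-difference comparison meaningful is to fix b
# outside the function (a sketch; eval_fixed and b_fixed are hypothetical names):
const b_fixed = rand(Float64, 3)

function eval_fixed(x::Vector{Float64})
    A = sparse([1, 1, 2, 3], [1, 3, 2, 3], [2.0*x[2]^3.0, 1.0-x[1], 2.0+x[3], -1.0])
    return b_fixed' * A * b_fixed  # same b on every call
end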
#--------------------------------------------------------------------------------
# main
#--------------------------------------------------------------------------------
x = rand(Float64, 3)
dx = [0.0, 0.0, 0.0]

val = eval(x)

# compute gradient
autodiff(Reverse, eval, Duplicated(x, dx))
@show dx 

df_numeric = similar(x)
for i in eachindex(x)
    x_p = copy(x)
    x_p[i] += 1.0e-05
    x_m = copy(x)
    x_m[i] -= 1.0e-05
    df_numeric[i] = (eval(x_p) - eval(x_m)) / (2.0 * 1.0e-05)
end

display(df_numeric)

err = norm(dx - df_numeric)
println("error = $(err)")