I am creating custom rules for sparse arrays in Enzyme, an automatic differentiation package.
The following program is a custom rule for the computation of matrix adds of type SparseMatrixCSC.
using SparseArrays
#-----------------------------------------------------------------------------------------------------------------------------------------
#
#-----------------------------------------------------------------------------------------------------------------------------------------
function EnzymeRules.augmented_primal(config, func::Const{typeof(+)}, ::Type{RT}, A::Duplicated{SparseMatrixCSC{Float64, Int64}}, B::Duplicated{SparseMatrixCSC{Float64, Int64}}) where RT
println("use add(sparse_matrix) augmented_primal")
# 関数の戻り値
res = A.val + B.val
retres = if EnzymeRules.needs_primal(config)
res
else
nothing
end
# 引数A
cache_A = if EnzymeRules.overwritten(config)[2]
copy(A.val)
else
nothing
end
# 引数B
cache_B = if EnzymeRules.overwritten(config)[3]
copy(B.val)
else
nothing
end
# y_bar
cache_res = if EnzymeRules.needs_primal(config)
copy(res)
else
res
end
# dy_bar
dres = if EnzymeRules.width(config) == 1
0.0 * copy(res)
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
0.0 * copy(res)
end
end
display(dres)
# set data
cache = NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4")), Tuple{typeof(res), typeof(dres), typeof(cache_A), typeof(cache_B)}}(
(cache_res, dres, cache_A, cache_B)
)
return EnzymeRules.AugmentedReturn{typeof(retres), typeof(dres), typeof(cache)}(retres, dres, cache)
end
#-----------------------------------------------------------------------------------------------------------------------------------------
#
#-----------------------------------------------------------------------------------------------------------------------------------------
function EnzymeRules.reverse(config, func::Const{typeof(+)}, ::Type{RT}, cache, A::Duplicated{<:SparseMatrixCSC{Float64, Int64}}, B::Duplicated{<:SparseMatrixCSC{Float64, Int64}}) where RT
println("use add(sparse matrix) reverse")
y, dys, cache_A, cache_B = cache
if !EnzymeRules.overwritten(config)[2]
cache_A = A.val
end
if !EnzymeRules.overwritten(config)[3]
cache_B = B.val
end
if EnzymeRules.width(config) == 1
dys = (dys,)
end
dAs = if EnzymeRules.width(config) == 1
(A.dval,)
else
A.dval
end
dBs = if EnzymeRules.width(config) == 1
(B.dval,)
else
B.dval
end
for (dA, dB, dy) in zip(dAs, dBs, dys)
display(dy)
display(dA)
dA += dy
dB += dy
dy = eltype(dy)(0)
end
return (nothing, nothing)
end
Also, this is a sample program that calculates the derivative using this custom rule.
using Enzyme
import .EnzymeRules: augmented_primal, reverse
using .EnzymeRules
using LinearAlgebra
using SparseArrays
include("add_sparse_matrix.jl")
#--------------------------------------------------------------------------------
# Function to be differentiated
#--------------------------------------------------------------------------------
function eval(x::Vector{Float64})
A = sparse([1, 1, 2, 3], [1, 3, 2, 3], [2.0*x[2]^3.0, 1.0-x[1], 2.0+x[3], -1.0])
B = sparse([1, 1, 2, 3], [1, 3, 2, 3], [2.0*x[2], 1.0-x[1], 2.0+x[3], -1.0])
C = A + B
return C[1, 1] + C[2, 2]
end
#--------------------------------------------------------------------------------
# main
#--------------------------------------------------------------------------------
x = rand(Float64, 3)
dx = [0.0, 0.0, 0.0]
val = eval(x)
# compute gradient
autodiff(Reverse, eval, Duplicated(x, dx))
@show dx
df_numeric = similar(x)
for i in eachindex(x)
x_tmp = x[i]
x_p = copy(x)
x_p[i] += 1.0e-05
x_m = copy(x)
x_m[i] -= 1.0e-05
df_numeric[i] = (eval(x_p) - eval(x_m)) / (2.0 * 1.0e-05)
end
display(df_numeric)
error = norm(dx-df_numeric)
println("error = $(error)")
I managed to adjust it and it works without error, but I cannot get an accurate gradient.
For ordinary arrays (such as Matrix), the same procedure can be used to compute the exact gradient.
Do I need a special setup for sparse matrices? Am I missing something in my custom rules?