I am trying to write custom rules for Enzyme.
I looked at the examples on Enzyme’s official website, but I could not fully follow them.
As a very simple test case, I wrote a program with a custom rule for a matrix-matrix product, but it fails as well.
I would appreciate advice on what is wrong with the program below and how to fix it.
I think a few simple examples like this would help everyone understand custom rules.
using Enzyme
import .EnzymeRules: augmented_primal, reverse
using .EnzymeRules
#--------------------------------------------------------------------------------
# Function to which we want to apply custom rules
# matrix-matrix product
#--------------------------------------------------------------------------------
function g(A::Matrix{Float64}, B::Matrix{Float64})
return A * B
end
#--------------------------------------------------------------------------------
# Enzyme custom rule: augmented primal (forward pass)
#--------------------------------------------------------------------------------
function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(g)}, ::Type{<:Active}, A::Duplicated, B::Duplicated)
println("In custom augmented primal rule.")
# Compute primal
primal = func.val(A.val, B.val)
# Return an AugmentedReturn object
return AugmentedReturn(primal, A.val, B.val)
end
#--------------------------------------------------------------------------------
# Enzyme custom rule: reverse pass
#--------------------------------------------------------------------------------
function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(g)}, dC::Active, A::Duplicated, B::Duplicated)
println("In custom reverse rule.")
# dA
A.dval = dC.val * B'.val
# dB
B.dval = A'.val * dC.val
#
return (nothing, nothing)
end
#--------------------------------------------------------------------------------
# Function to be differentiated
#--------------------------------------------------------------------------------
function eval(x::Vector{Float64})
n = 3
A = Matrix{Float64}(undef, n, n)
for i in 1 : n
for j in 1 : n
A[i, j] = 2.0 * x[i] - 3.0 * x[j]
end
end
B = Matrix{Float64}(undef, n, n)
for i in 1 : n
for j in 1 : n
B[i, j] = 2.0 * i - 3.0 * j
end
end
# matrix-matrix product
C = g(A, B)
return C[1, 1] + C[2, 2]
end
#--------------------------------------------------------------------------------
# main
#--------------------------------------------------------------------------------
x = [3.0, 1.0, 2.0]
dx = [0.0, 0.0, 0.0]
# compute gradient
autodiff(Reverse, eval, Duplicated(x, dx))
@show dx
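
For reference, the gradient I expect can be worked out by hand, which gives a value to compare against once the custom rule runs. The following is a minimal sketch in plain Julia (no Enzyme), assuming the standard adjoint identities for C = A * B, namely dA = dC * B' and dB = A' * dC. The name `reference_gradient` is a hypothetical helper I made up for this check, not part of any library.

```julia
# Hand-written reverse pass for the same computation, without Enzyme.
# `reference_gradient` is a hypothetical helper name, not a library function.
function reference_gradient(x::Vector{Float64})
    n = 3
    # Forward pass: same A and B as in the program above
    A = [2.0 * x[i] - 3.0 * x[j] for i in 1:n, j in 1:n]
    B = [2.0 * i - 3.0 * j for i in 1:n, j in 1:n]

    # Cotangent of C for the output C[1, 1] + C[2, 2]
    dC = zeros(n, n)
    dC[1, 1] = 1.0
    dC[2, 2] = 1.0

    # Adjoint of the matrix product C = A * B with respect to A.
    # (dB = A' * dC is not needed here because B does not depend on x.)
    dA = dC * B'

    # Pull dA back through A[i, j] = 2 x[i] - 3 x[j]
    dx = zeros(n)
    for i in 1:n, j in 1:n
        dx[i] += 2.0 * dA[i, j]
        dx[j] -= 3.0 * dA[i, j]
    end
    return dx
end

@show reference_gradient([3.0, 1.0, 2.0])   # [21.0, -9.0, -9.0]
```

Since the output is linear in x (it expands to 21 x[1] - 9 x[2] - 9 x[3]), the gradient is the same for every x, so a working custom rule should leave dx equal to [21.0, -9.0, -9.0].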