Trouble writing custom rule in Enzyme: `AssertionError: !(overwritten[end])`


I am currently trying to use Enzyme to compute the derivative of a function that takes a multidimensional array of 3x3 matrices and returns a scalar. The caveat is that these matrices are elements of the Lie group SU(3) of special unitary matrices, so the derivative with respect to each element should lie in the corresponding Lie algebra su(3) (traceless anti-Hermitian matrices).
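A minimal sketch of what "lies in the algebra" means in code. It assumes plain `Matrix{ComplexF64}` from the standard library instead of `SMatrix`. `traceless_antihermitian` is a hypothetical definition of the usual projection M ↦ (M - M')/2 - tr((M - M')/2)/3 · I (the post uses a function of that name without showing it), and `rand_su` is a hypothetical helper that builds a random SU(3) matrix from a QR decomposition.

```julia
using LinearAlgebra, Random

# Hypothetical definition: projection onto su(n), the traceless anti-Hermitian matrices
function traceless_antihermitian(M::AbstractMatrix)
    A = (M - M') / 2                      # anti-Hermitian part of M
    return A - (tr(A) / size(A, 1)) * I   # subtract the (purely imaginary) trace part
end

# Hypothetical helper: random SU(n) matrix from the QR decomposition of a complex Gaussian matrix
function rand_su(n::Integer)
    Q = Matrix(qr(randn(ComplexF64, n, n)).Q)  # unitary, but det(Q) is an arbitrary phase
    return Q / det(Q)^(1 / n)                   # divide out an n-th root of det(Q) so det == 1
end

Random.seed!(42)
U = rand_su(3)                                          # group element
X = traceless_antihermitian(randn(ComplexF64, 3, 3))    # algebra element

is_su3 = U' * U ≈ Matrix(I, 3, 3) && det(U) ≈ 1
is_algebra = X' ≈ -X && abs(tr(X)) < 1e-12
println(is_su3, " ", is_algebra)   # true true
```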

Since an MWE would be too long to post here, I will just show roughly what the function in question looks like:

```julia
using LinearAlgebra
using StaticArrays

remultr(args...) = real(tr(*(args...)))

# First array dimension is the direction μ, the remaining four index the lattice sites
function plaquette_sum(U::Array{SMatrix{3,3,ComplexF64,9}, 5})
    p = 0.0

    for site in CartesianIndices(size(U)[2:end])
        for μ in 1:3
            for ν in μ+1:4
                p += remultr(U[μ,site], U[ν,site], U[μ,site], U[ν,site])
            end
        end
    end

    return p
end
```
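As a quick sanity check of the loop structure (not of the differentiation), the same sum can be evaluated on the unit field. This sketch swaps `SMatrix` for a plain `Matrix{ComplexF64}` so that it only needs the standard library, and assumes the 4 × 4⁴ array used further down. Every term is then tr(I) = 3, and each of the 4⁴ = 256 sites contributes the 6 pairs μ < ν, so the sum should be 3 · 6 · 256 = 4608.

```julia
using LinearAlgebra

remultr(args...) = real(tr(*(args...)))

function plaquette_sum(U::Array{<:AbstractMatrix, 5})
    p = 0.0
    for site in CartesianIndices(size(U)[2:end])
        for μ in 1:3, ν in μ+1:4
            p += remultr(U[μ, site], U[ν, site], U[μ, site], U[ν, site])
        end
    end
    return p
end

eye3 = Matrix{ComplexF64}(I, 3, 3)
U = fill(eye3, 4, 4, 4, 4, 4)   # every entry aliases the same identity matrix (fine, nothing is mutated)
p = plaquette_sum(U)
println(p)   # 3 * 6 * 4^4 = 4608.0
```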

Following the example in the docs, I was able to write a reverse-mode rule for `remultr` that works well enough for now:

```julia
using Enzyme
import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules

# Some functionality needed for the definition of the Lie derivative:
# circshift(s, args...) rotates the arguments right by s, like Base.circshift on a vector
@inline function Base.circshift(shift::Integer, args::Vararg{T, N}) where {N,T}
    j = mod1(shift, N)
    ntuple(k -> args[k-j+ifelse(k>j,0,N)], Val(N))
end

# following the example from the docs
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(remultr)},
    ::Type{<:Active}, args::Vararg{Active,N}) where {N}
    argvals = ntuple(i -> args[i].val, Val(N))
    if needs_primal(config)
        primal = func.val(argvals...)
    else
        primal = nothing
    end
    # overwritten(config) is (function, arg1, ..., argN), so check every argument.
    # A tuple of immutable SMatrix values cannot change later, so it can be stored as is.
    if any(overwritten(config)[2:end])
        tape = argvals
    else
        tape = nothing
    end
    return AugmentedReturn(primal, nothing, tape)
end

function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(remultr)},
    dret::Active, tape, args::Vararg{Active,N}) where {N}

    argvals = ntuple(i -> args[i].val, Val(N))
    dargs = ntuple(Val(N)) do i
        # Cyclic product U_i ⋯ U_N U_1 ⋯ U_{i-1}: a right shift by 1 - i is a left rotation
        # by i - 1 (assumes the left Lie derivative U_i -> exp(εX) U_i).
        # traceless_antihermitian (not shown) projects onto su(3); the result is scaled
        # by the incoming adjoint dret.val.
        dret.val * 0.5traceless_antihermitian(*(circshift(1 - i, argvals...)...))
    end
    return dargs
end

matrices = [mat1, mat2, mat3, mat4] # vector of random special unitary matrices
annotated_matrices = ntuple(i -> Active(matrices[i]), Val(length(matrices)))
der = autodiff(Reverse, remultr, Active, annotated_matrices...) # works as wanted
```
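To check which product each argument's derivative needs, here is a hedged finite-difference sketch that uses only the standard library. It assumes the left Lie derivative U_i → exp(εX) U_i with X in su(3); by cyclicity of the trace, d/dε Re tr(U_1 ⋯ exp(εX) U_i ⋯ U_n) = Re tr(X U_i ⋯ U_n U_1 ⋯ U_{i-1}), i.e. the arguments rotated left by i - 1, and only the traceless anti-Hermitian part of that product contributes. `traceless_antihermitian`, `cyclic_product` and `fd_lie_derivative` are hypothetical stand-ins, not the post's code.

```julia
using LinearAlgebra, Random

remultr(args...) = real(tr(*(args...)))

# Hypothetical stand-in for the post's traceless_antihermitian (definition not shown there)
function traceless_antihermitian(M::AbstractMatrix)
    A = (M - M') / 2
    return A - (tr(A) / size(A, 1)) * I
end

# Hypothetical helper: product U_i U_{i+1} ... U_n U_1 ... U_{i-1} (left rotation by i - 1)
cyclic_product(i, Us) = *(Us[i:end]..., Us[1:i-1]...)

# Central finite difference of remultr along U_i -> exp(s * ε * X) * U_i, s = ±1
function fd_lie_derivative(Us, X, i; ε = 1e-6)
    perturb(s) = ntuple(k -> k == i ? exp(s * ε * X) * Us[k] : Us[k], length(Us))
    return (remultr(perturb(1)...) - remultr(perturb(-1)...)) / (2ε)
end

Random.seed!(1)
# The cyclic-trace identity holds for any square matrices, so plain random ones suffice here
Us = ntuple(_ -> randn(ComplexF64, 3, 3), 4)
X = traceless_antihermitian(randn(ComplexF64, 3, 3))   # random direction in su(3)

fd   = [fd_lie_derivative(Us, X, i) for i in 1:4]
full = [real(tr(X * cyclic_product(i, Us))) for i in 1:4]
proj = [real(tr(X * traceless_antihermitian(cyclic_product(i, Us)))) for i in 1:4]
println(maximum(abs.(fd - full)), "  ", maximum(abs.(full - proj)))
```

Both printed differences should be tiny (finite-difference and round-off error only). Rotating the arguments the other way (right by i - 1) gives a different product for i ≥ 2 once there are three or more arguments, so in general it would not match the finite difference.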

But if I now try to get the gradient of `plaquette_sum`, I get an error whose source I can’t track down:

```julia
eye3 = one(SMatrix{3, 3, ComplexF64, 9})
zero3 = zero(SMatrix{3, 3, ComplexF64, 9})
U = fill(eye3, 4, 4, 4, 4, 4)   # 4 directions × 4^4 sites, every link set to the identity
dU = fill(zero3, size(U))       # shadow array that accumulates the gradient
autodiff(Reverse, plaquette_sum, Active, DuplicatedNoNeed(U, dU)) # AssertionError: !(overwritten[end])
```

This might be a bit much, but any help with either getting rid of the error or finding a different approach to my problem (e.g. using something other than Enzyme) would be greatly appreciated.
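For the "different approach" part, one option that avoids AD entirely is to write the Lie gradient by hand, since every term is a trace of a matrix product. A minimal sketch, assuming the left Lie derivative U → exp(εX) U and the simplified plaquette from the question: in Re tr(ABAB) with A = U[μ,site] and B = U[ν,site], each matrix appears twice, and cyclicity gives d/dε = 2 Re tr(X (AB)²) for A and 2 Re tr(X (BA)²) for B, so link μ collects 2 · TA((U_μ U_ν)²) from every ν ≠ μ. The code uses plain `Matrix` instead of `SMatrix`; `plaquette_gradient`, `rand_su` and `traceless_antihermitian` are hypothetical names, and one component is checked against a finite difference.

```julia
using LinearAlgebra, Random

remultr(args...) = real(tr(*(args...)))

# Hypothetical stand-ins (plain Matrix instead of SMatrix)
function traceless_antihermitian(M::AbstractMatrix)
    A = (M - M') / 2
    return A - (tr(A) / size(A, 1)) * I
end

function rand_su(n::Integer)
    Q = Matrix(qr(randn(ComplexF64, n, n)).Q)
    return Q / det(Q)^(1 / n)
end

# Same simplified sum as in the question
function plaquette_sum(U)
    p = 0.0
    for site in CartesianIndices(size(U)[2:end]), μ in 1:3, ν in μ+1:4
        p += remultr(U[μ, site], U[ν, site], U[μ, site], U[ν, site])
    end
    return p
end

# Hand-written gradient: the derivative of plaquette_sum along
# U[μ, site] -> exp(εX) * U[μ, site] equals real(tr(X * G[μ, site])) for X in su(3)
function plaquette_gradient(U)
    G = [zeros(ComplexF64, 3, 3) for _ in U]
    for site in CartesianIndices(size(U)[2:end]), μ in 1:4, ν in 1:4
        μ == ν && continue
        AB = U[μ, site] * U[ν, site]
        G[μ, site] += 2 * traceless_antihermitian(AB * AB)   # U[μ, site] occurs twice per term
    end
    return G
end

Random.seed!(7)
U = [rand_su(3) for _ in CartesianIndices((4, 2, 2, 2, 2))]
G = plaquette_gradient(U)

# Check one link against a central finite difference along a random su(3) direction
X = traceless_antihermitian(randn(ComplexF64, 3, 3))
idx = CartesianIndex(2, 1, 2, 1, 2)
ε = 1e-6
Up = copy(U); Up[idx] = exp(ε * X) * U[idx]
Um = copy(U); Um[idx] = exp(-ε * X) * U[idx]
fd = (plaquette_sum(Up) - plaquette_sum(Um)) / (2ε)
an = real(tr(X * G[idx]))
println(abs(fd - an))
```

Any overall factor (such as the 0.5 in the rule above) is a normalization convention on top of this su(3) component.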

Can you paste the whole error log and an MWE as an issue on Enzyme.jl?

Unfortunately, this isn’t enough context to understand what is happening.

My bad. I just filed an issue on the Enzyme.jl GitHub page with, hopefully, enough context: EnzymeAD/Enzyme.jl issue #1242, “Trouble writing custom rule in Enzyme: `AssertionError: !(overwritten[end])`”.

The issue has been fixed by @wsmoses in EnzymeAD/Enzyme.jl pull request #1258, “Custom rules overwritten fix”.