Hi,
I am currently trying to use Enzyme to get the derivative of a function that takes a multidimensional array of 3x3 matrices and outputs a scalar. The caveat is that these matrices are elements of the Lie group SU(3) of special unitary matrices, and therefore the derivative with respect to each element of the array should lie in the corresponding Lie algebra su(3) (traceless anti-Hermitian matrices).
Since an MWE would be too long to post here, here is roughly what the function in question looks like:
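To make the target concrete, here is the Lie derivative along X ∈ su(3) and its value for a real trace of products, assuming the left action U → e^{εX}U (one common convention; the overall normalization of a gradient built from it depends on the inner product chosen on su(3)):

$$
\partial_X f(U) = \left.\frac{d}{d\epsilon} f\!\left(e^{\epsilon X} U\right)\right|_{\epsilon=0}, \qquad X \in \mathfrak{su}(3)
$$

$$
f = \operatorname{Re}\operatorname{tr}(A_1 \cdots A_N):\quad
\partial_X f\big|_{A_i} = \operatorname{Re}\operatorname{tr}\!\left(X\, M_i\right) = \operatorname{Re}\operatorname{tr}\!\left(X\, P(M_i)\right),
\qquad M_i = A_i A_{i+1} \cdots A_N A_1 \cdots A_{i-1}
$$

$$
P(M) = \tfrac{1}{2}\left(M - M^\dagger\right) - \tfrac{1}{6}\operatorname{tr}\!\left(M - M^\dagger\right)\mathbb{1}
$$

The second equality holds because the Hermitian part of M_i and any multiple of the identity contribute nothing to Re tr(X ·) when X is traceless anti-Hermitian.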
```julia
using LinearAlgebra
using StaticArrays

remultr(args...) = real(tr(*(args...)))

function plaquette_sum(U::Array{SMatrix{3,3,ComplexF64,9}, 5})
    p = 0.0
    for site in CartesianIndices(size(U)[2:end])
        for μ in 1:3
            for ν in μ+1:4
                p += remultr(U[μ,site], U[ν,site], U[μ,site], U[ν,site])
            end
        end
    end
    return p
end
```
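As a quick sanity check of the loop structure (six (μ, ν) pairs per site), here is a sketch of the same loop written for any 5-dimensional array of matrices. It assumes plain `Matrix{ComplexF64}` in place of `SMatrix` so that only LinearAlgebra is needed, and `plaquette_sum_generic` is a hypothetical name. On the identity configuration every term is tr(𝟙) = 3, so a 4⁴ lattice should give 6 · 3 · 256 = 4608.

```julia
using LinearAlgebra

remultr(args...) = real(tr(*(args...)))

# Same loop as plaquette_sum, but for any 5-d array of matrices
# (first index: direction μ, remaining four: lattice site).
function plaquette_sum_generic(U::AbstractArray{<:AbstractMatrix,5})
    p = 0.0
    for site in CartesianIndices(axes(U)[2:end])
        for μ in 1:3, ν in μ+1:4          # 6 pairs (μ, ν) with μ < ν
            p += remultr(U[μ, site], U[ν, site], U[μ, site], U[ν, site])
        end
    end
    return p
end

# Identity configuration on a 4^4 lattice with 4 directions.
U = fill(Matrix{ComplexF64}(I, 3, 3), 4, 4, 4, 4, 4)
println(plaquette_sum_generic(U))   # 4608.0
```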
Following the example in the docs, I was able to write a reverse-mode rule for `remultr` that works well enough for now:
```julia
using Enzyme
import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules

# Some functionality needed for the definition of the Lie derivative:
# cyclically shift a tuple of arguments (same direction as Base.circshift,
# i.e. a positive shift rotates to the right)
@inline function Base.circshift(shift::Integer, args::Vararg{T, N}) where {N,T}
    j = mod1(shift, N)
    ntuple(k -> args[k-j+ifelse(k>j,0,N)], Val(N))
end

# following the example from the docs
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(remultr)},
                          ::Type{<:Active}, args::Vararg{Active,N}) where {N}
    argvals = ntuple(i -> args[i].val, Val(N))
    if needs_primal(config)
        primal = func.val(argvals...)
    else
        primal = nothing
    end
    if overwritten(config)[3]
        tape = copy(argvals)
    else
        tape = nothing
    end
    return AugmentedReturn(primal, nothing, tape)
end

function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(remultr)},
                             dret::Active, tape, args::Vararg{Active,N}) where {N}
    argvals = ntuple(i -> args[i].val, Val(N))
    dargs = ntuple(Val(N)) do i
        # rotate left by i-1 so the product starts at argvals[i]:
        # A_i A_{i+1} ... A_N A_1 ... A_{i-1}
        # traceless_antihermitian projects onto su(3) (definition not shown);
        # the result is scaled by the incoming adjoint dret.val (chain rule)
        0.5 * dret.val * traceless_antihermitian(*(circshift(1 - i, argvals...)...))
    end
    return dargs
end

matrices = [mat1, mat2, mat3, mat4] # vector of random special unitary matrices
annotated_matrices = ntuple(i -> Active(matrices[i]), Val(length(matrices)))
der = autodiff(Reverse, remultr, Active, annotated_matrices...) # works as wanted
```
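A rule like this can be checked independently of Enzyme by comparing it with a finite-difference Lie derivative. The sketch below assumes the left action A_i → exp(εX)·A_i and assumes `traceless_antihermitian` is the usual projection onto su(3); since its definition is not shown above, the one here is a stand-in. `rotate_to`, `lie_derivative_fd`, `lie_derivative_exact` and `random_su3` are hypothetical helpers. Plain `Matrix{ComplexF64}` and LinearAlgebra are used, so it runs without Enzyme or StaticArrays.

```julia
using LinearAlgebra

# Stand-in (assumed) projection onto su(3): traceless anti-Hermitian part of M.
function traceless_antihermitian(M::AbstractMatrix)
    A = (M - M') / 2                      # anti-Hermitian part
    return A - (tr(A) / size(A, 1)) * I   # remove the (imaginary) trace
end

remultr(args...) = real(tr(*(args...)))

# Hypothetical helper: rotate left so the tuple starts at args[i],
# i.e. (A_i, A_{i+1}, ..., A_N, A_1, ..., A_{i-1}).
rotate_to(i, args::Tuple) = ntuple(k -> args[mod1(k + i - 1, length(args))], length(args))

# Central finite difference of remultr under A_i -> exp(±εX) * A_i.
function lie_derivative_fd(i, X, args::Tuple; ε = 1e-6)
    plus  = ntuple(k -> k == i ? exp(ε * X) * args[k] : args[k], length(args))
    minus = ntuple(k -> k == i ? exp(-ε * X) * args[k] : args[k], length(args))
    return (remultr(plus...) - remultr(minus...)) / (2ε)
end

# Analytic value: Re tr(X * P(A_i A_{i+1} ... A_{i-1})).
lie_derivative_exact(i, X, args::Tuple) =
    real(tr(X * traceless_antihermitian(*(rotate_to(i, args)...))))

# exp of a traceless anti-Hermitian matrix is special unitary.
random_su3() = exp(traceless_antihermitian(randn(ComplexF64, 3, 3)))

args = ntuple(_ -> random_su3(), 4)       # four *different* SU(3) matrices
X = traceless_antihermitian(randn(ComplexF64, 3, 3))
for i in 1:4
    println(i, ": ", lie_derivative_fd(i, X, args), "  ", lie_derivative_exact(i, X, args))
end
```

Each printed pair should agree to roughly six digits. Using four different matrices matters: with a repeated pattern like (U_μ, U_ν, U_μ, U_ν), several cyclic rotations give the same product, so an argument-ordering mistake would not show up.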
However, when I try to compute the gradient of `plaquette_sum`, I get an error whose source I can't track down:
```julia
eye3 = one(SMatrix{3, 3, ComplexF64, 9})
zero3 = zero(SMatrix{3, 3, ComplexF64, 9})
U = Array{SMatrix{3, 3, ComplexF64, 9}, 5}(undef, 4, 4, 4, 4, 4)
fill!(U, eye3)
dU = similar(U)
fill!(dU, zero3)
autodiff(Reverse, plaquette_sum, Active, DuplicatedNoNeed(U, dU)) # AssertionError: !(overwritten[end])
```
This might be a bit much, but any help on either getting rid of the error or even finding a different approach to my problem (e.g. using something other than Enzyme) would be greatly appreciated.