Apparently, the error occurred because I was not subtyping on `Complex` (the argument is now annotated `x::Active{<:Complex{T}}`). Here is a solution that works:
# Custom augmented-primal (forward-sweep) rule for `sqrt` applied to an `Active`
# complex scalar with an `Active` return.
#
# - config: rule configuration; queried for whether the primal result is needed and
#   which arguments may be overwritten before the reverse sweep.
# - func:   the constant `sqrt` being differentiated.
# - x:      the active complex input.
#
# Returns an `AugmentedReturn(primal, shadow, tape)`. The shadow is always `nothing`,
# and the tape holds a copy of `x.val` only when `x` is flagged as possibly overwritten.
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, ::Type{<:Active}, x::Active{<:Complex{T}}) where {T<:Real}
    println("In custom augmented primal rule.")

    # Evaluate sqrt(x) only when the caller requested the primal value.
    primal = needs_primal(config) ? func.val(x.val) : nothing

    # Index 2 of `overwritten` corresponds to `x` (index 1 is the function itself).
    # Keep a copy so the reverse rule can still see the original input.
    tape = overwritten(config)[2] ? copy(x.val) : nothing

    # Active return => no shadow to hand back.
    return AugmentedReturn(primal, nothing, tape)
end
# Custom reverse-sweep rule for `sqrt` applied to an `Active` complex scalar.
#
# - config: rule configuration; used to decide whether `x` must be read from the tape.
# - dret:   incoming cotangent of the `sqrt` output (`dret.val`).
# - tape:   whatever the augmented-primal rule stored (a copy of `x.val`, or `nothing`).
# - x:      the active complex input.
#
# Returns a 1-tuple holding the cotangent for `x`:
# conj(1 / (2 * sqrt(z))) * dret.val, where z is the (possibly taped) input value.
# NOTE(review): the conjugation is intended to match Enzyme's reverse-mode convention
# for complex numbers; confirm against the Enzyme complex-number docs.
function reverse(config::ConfigWidth{1}, ::Const{typeof(sqrt)}, dret::Active, tape, x::Active{<:Complex{T}}) where {T<:Real}
    println("In custom reverse rule.")

    # Read the input from the tape if it may have been overwritten since the forward sweep.
    z = if overwritten(config)[2]
        tape
    else
        x.val
    end

    # d/dz sqrt(z) = 1 / (2 sqrt(z)); `conj` is what the adjoint operator `'` does on a Number.
    local_grad = conj(inv(2 * sqrt(z)))
    return (local_grad * dret.val,)
end
"""
    test(η)

Return `abs(sqrt(η * cis(π/4)))`: the modulus of the principal square root of `η`
rotated by 45° in the complex plane. `cis(π/4)` yields exactly the same value as
`exp((π/4)*1im)`. This is the function differentiated by the `autodiff` call.
"""
test(η) = abs(sqrt(η * cis(π / 4)))
# Reverse-mode differentiate `test` with an `Active` return, w.r.t. the active complex
# input 2.0 + 0im. The custom `sqrt` rules print their messages during this call
# (see the output shown below).
autodiff(Enzyme.Reverse, test, Active, Active(2.0 + 0im))
This produces the following output:
In custom augmented primal rule.
In custom reverse rule.
((0.35355339059327373 - 1.1496735851465466e-17im,),)