Enzyme Reverse Diff rules for complex sqrt

Apparently the error occurred because I was not constraining the argument type to a subtype of `Complex` (i.e. `x::Active{<:Complex{T}}`). Here is a solution that works:

"""
    augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, ::Type{<:Active}, x::Active{<:Complex})

Forward pass of the custom Enzyme reverse-mode rule for `sqrt` applied to an
active complex argument (batch width 1).

The primal `sqrt(x.val)` is computed only when `needs_primal(config)` is true.
A copy of `x.val` is stored on the tape only when `overwritten(config)[2]` (the
flag for `x`) is true. No shadow is returned, because the result activity is
`Active`.
"""
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)},
                          ::Type{<:Active}, x::Active{<:Complex{T}}) where {T<:Real}
    println("In custom augmented primal rule.")
    # Skip evaluating sqrt entirely when the caller does not want the primal result.
    primal = needs_primal(config) ? func.val(x.val) : nothing
    # Keep x's value for the reverse pass only if it may be clobbered in between.
    tape = overwritten(config)[2] ? copy(x.val) : nothing
    # Arguments: primal result, shadow (none for Active returns), tape.
    return AugmentedReturn(primal, nothing, tape)
end

"""
    reverse(config::ConfigWidth{1}, ::Const{typeof(sqrt)}, dret::Active, tape, x::Active{<:Complex})

Reverse pass of the custom Enzyme rule for `sqrt` on an active complex argument.

Since d/dz sqrt(z) = 1 / (2 sqrt(z)), the adjoint for `x` is computed as
`conj(1 / (2 sqrt(x))) * dret.val`. The conjugate of the holomorphic derivative
scales the incoming adjoint. Returns a 1-tuple holding the derivative for `x`.
"""
function reverse(config::ConfigWidth{1}, ::Const{typeof(sqrt)}, dret::Active, tape,
                 x::Active{<:Complex{T}}) where {T<:Real}
    println("In custom reverse rule.")
    # If x may have been overwritten, the forward pass saved its original value on the tape.
    if overwritten(config)[2]
        xval = tape
    else
        xval = x.val
    end
    dsqrt = inv(2 * sqrt(xval))
    dx = conj(dsqrt) * dret.val
    return (dx,)
end
"""
    test(η)

Return `abs(sqrt(η * exp((π/4)*1im)))`. This is the modulus of the principal
square root of `η` after rotating it by π/4 in the complex plane. It serves as
the scalar objective passed to `autodiff` in this example.
"""
function test(η)
    rotation = exp((π/4)*1im)
    root = sqrt(η * rotation)
    return abs(root)
end
# Reverse-mode derivative of `test` at the complex input 2.0 + 0im; both the return
# value and the input are marked Active. The custom sqrt rules print a message when
# Enzyme invokes them, which confirms they are being used.
autodiff(Enzyme.Reverse, test, Active, Active(2.0 + 0im))

with output:

In custom augmented primal rule.
In custom reverse rule.
((0.35355339059327373 - 1.1496735851465466e-17im,),)
2 Likes