Enzyme Reverse Diff rules for complex sqrt

I have a function that I’m trying to differentiate with reverse mode, and at some point it needs to compute the square root of a complex number. I have been able to define a custom rule for forward-mode autodiff as follows:

# Forward-mode rule for complex sqrt: the primal is sqrt(z) and the tangent is
# dz / (2 sqrt(z)), i.e. the derivative 1/(2 sqrt(z)) applied to the input tangent.
function forward(func::Const{typeof(sqrt)}, ::Type{<:Duplicated}, x::Duplicated{Complex{T}}) where {T<:Real}
    root = func.val(x.val)
    tangent = 1 / (2root) * x.dval
    return Duplicated(root, tangent)
end

but I have been unable to do the same for reverse mode. Following the custom rules tutorial, I defined the `augmented_primal` and `reverse` functions as follows:

# Augmented forward pass for sqrt on a Duplicated complex argument (first attempt).
# Returns the primal (only if Enzyme asks for it), no shadow, and a tape holding
# a copy of x when x may be overwritten before the reverse pass.
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, ::Type{<:Active}, x::Duplicated{Complex{T}}) where {T<:Real}
    println("In custom augmented primal rule.")
    if needs_primal(config)
        primal = func.val(x.val)
    else
        primal = nothing
    end

    # Save x in tape if x will be overwritten.
    # overwritten(config) has the function itself first, so the single
    # argument x is entry 2 (the working rule later in the thread uses [2]).
    if overwritten(config)[2]
        tape = copy(x.val)
    else
        tape = nothing
    end

    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, tape)
end

# Reverse pass for sqrt on a Duplicated complex argument (first attempt).
# NOTE(review): independent of the "Duplicated Returns not yet handled" error
# reported below, this body has several problems:
#   * the working rule later in the thread reads overwritten(config)[2] for x,
#     not [3];
#   * `func` is a Const wrapper; elsewhere the wrapped function is called as
#     func.val(...), so func(xval) presumably fails — confirm;
#   * the reply below notes reverse mode wants floats passed as Active, not
#     Duplicated; accumulating into x.dval of an immutable Complex shadow is
#     unlikely to work — the working rule returns the derivative instead;
#   * the returned tuple has two entries for a single argument; the working
#     rule returns one entry per argument, e.g. (dx,).
function reverse(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, dret::Active, tape, x::Duplicated{Complex{T}}) where {T<:Real}  
    println("In custom reverse rule.")
    # retrieve x value, either from original x or from tape if x may have been overwritten.
    xval = overwritten(config)[3] ? tape : x.val
    x.dval += inv(2func(xval)) * dret.val
    return (nothing, nothing)
end

However, I receive the error
`ERROR: Duplicated Returns not yet handled`
when trying to execute the following test function:

# Magnitude of the complex square root of η. Adding `0im` promotes η to a
# complex number, so negative real inputs do not throw a DomainError.
test(η) = abs(sqrt(η + 0im))
# Passing `Duplicated` for both the return annotation and the Float64 argument
# is what the replies below address: an Active return, and floats as Active.
autodiff(Enzyme.Reverse, test, Duplicated, Duplicated(2.0,1.0))

You probably want an active return here.

I tried changing the signature to

# Return is now Active, but the Float64 argument is still passed as Duplicated.
autodiff(Enzyme.Reverse, test, Active, Duplicated(2.0,1.0))

But Enzyme still seemed to differentiate `sqrt` without using my rule. I assume I’ve set up the method signatures incorrectly somehow.

Reverse mode requires floats to be passed in as `Active`, not `Duplicated`.

Thanks. Apologies for my ignorance. I’ve tried defining the rules this way:


# Augmented forward pass for sqrt on an Active complex argument (second attempt).
# Returns the primal (only if Enzyme asks for it), no shadow, and a tape holding
# a copy of x when x may be overwritten before the reverse pass.
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, ::Type{<:Active}, x::Active{Complex{T}}) where {T<:Real}
    println("In custom augmented primal rule.")
    if needs_primal(config)
        primal = func.val(x.val)
    else
        primal = nothing
    end

    # Save x in tape if x will be overwritten.
    # overwritten(config) has the function itself first, so the single
    # argument x is entry 2 (the working rule later in the thread uses [2]).
    if overwritten(config)[2]
        tape = copy(x.val)
    else
        tape = nothing
    end

    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, tape)
end

# Reverse pass for sqrt on an Active complex argument (second attempt).
# Active arguments have no shadow to accumulate into, so the derivative with
# respect to x is returned, one tuple entry per argument.
function reverse(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, dret::Active, tape, x::Active{Complex{T}}) where {T<:Real}
    println("In custom reverse rule.")
    # retrieve x value, either from original x or from tape if x may have been overwritten.
    # x is entry 2 of overwritten(config); entry 1 is the function itself.
    xval = overwritten(config)[2] ? tape : x.val
    # `func` is a Const wrapper, so call the wrapped sqrt via func.val.
    # d/dz sqrt(z) = 1 / (2 sqrt(z)); the reverse pass propagates its conjugate
    # (adjoint `'`), matching the working rule later in the thread.
    dx = inv(2 * func.val(xval))' * dret.val
    return (dx,)
end

but Enzyme still ignores the square-root rule on evaluation:

# Magnitude of the complex square root of η. Adding `0im` promotes η to a
# complex number before sqrt is applied.
test(η) = abs(sqrt(η + 0im))
# Active return and Active Float64 argument. Per the stacktrace that follows,
# the error is raised inside Base's sqrt (complex.jl) rather than the custom rule.
autodiff(Enzyme.Reverse, test, Active, Active(2.0))



Stacktrace:
 [1] |
   @ ./int.jl:372
 [2] ldexp
   @ ./math.jl:964
 [3] sqrt
   @ ./complex.jl:541
 [4] test
   @ ~/Software/Krang.jl/examples/mwe.jl:7


Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:1289
  [2] |
    @ ./int.jl:372 [inlined]
  [3] ldexp
    @ ./math.jl:964 [inlined]
  [4] sqrt
    @ ./complex.jl:541 [inlined]
  [5] test
    @ ~/Software/Krang.jl/examples/mwe.jl:7 [inlined]
  [6] diffejulia_test_4396wrap
    @ ~/Software/Krang.jl/examples/mwe.jl:0
  [7] macro expansion
    @ ~/.julia/dev/Enzyme/src/compiler.jl:5440 [inlined]
  [8] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Active{…}, ::Float64)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5118
  [9] (::Enzyme.Compiler.CombinedAdjointThunk{…})(::Const{…}, ::Active{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5000
 [10] autodiff
    @ ~/.julia/dev/Enzyme/src/Enzyme.jl:0 [inlined]
 [11] autodiff(mode::ReverseMode{false, FFIABI, false}, f::typeof(test), ::Type{Active}, args::Active{Float64})
    @ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:287

Apparently the problem was that I was not subtyping on `Complex`, i.e. writing `Active{Complex{T}}` instead of `Active{<:Complex{T}}`. Here is a solution that works:

# Augmented forward pass for sqrt on an Active complex argument.
# Produces the primal only when Enzyme requests it, never a shadow (the return
# is Active), and tapes a copy of x only if x may be overwritten before the
# reverse pass. Entry 2 of overwritten(config) corresponds to x.
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)},
                          ::Type{<:Active}, x::Active{<:Complex{T}}) where {T<:Real}
    println("In custom augmented primal rule.")
    primal = needs_primal(config) ? func.val(x.val) : nothing
    tape = overwritten(config)[2] ? copy(x.val) : nothing
    return AugmentedReturn(primal, nothing, tape)
end

# Reverse pass for sqrt on an Active complex argument.
# Uses the taped x when it may have been overwritten, otherwise the live value,
# and returns the conjugated derivative 1/(2 sqrt(z)) scaled by the incoming
# adjoint, one tuple entry per argument.
function reverse(config::ConfigWidth{1}, ::Const{typeof(sqrt)}, dret::Active, tape,
                 x::Active{<:Complex{T}}) where {T<:Real}
    println("In custom reverse rule.")
    z = overwritten(config)[2] ? tape : x.val
    dx = conj(inv(2 * sqrt(z))) * dret.val
    return (dx,)
end
# Magnitude of sqrt(η · e^{iπ/4}); the complex factor keeps the sqrt argument complex.
test(η) = abs(sqrt(η * exp((π/4) * 1im)))
# Differentiate test at η = 2 + 0im with an Active return; output is shown below.
autodiff(Enzyme.Reverse, test, Active, Active(2.0 + 0im))

with output:

In custom augmented primal rule.
In custom reverse rule.
((0.35355339059327373 - 1.1496735851465466e-17im,),)
2 Likes

Separately, we’ve just added a first-class complex `sqrt` rule on the main branch, so you shouldn’t need a custom rule for this case anymore.

2 Likes