# Enzyme Reverse Diff rules for complex sqrt

I have a function that I’m trying to differentiate in reverse mode, and at some point it needs to calculate the square root of a complex number. I have been able to define a custom rule for `Forward` mode autodiff in the following way:

``````julia
using Enzyme
import Enzyme.EnzymeRules: forward

function forward(func::Const{typeof(sqrt)}, ::Type{<:Duplicated}, x::Duplicated{Complex{T}}) where {T<:Real}
    ret = func.val(x.val)
    return Duplicated(ret, 1/(2ret) * x.dval)
end
``````
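Because `sqrt` is holomorphic away from its branch cut on the negative real axis, the forward rule above just multiplies the incoming tangent by `sqrt'(z) = 1/(2sqrt(z))`. As a quick sanity check that does not involve Enzyme at all, here is a minimal sketch comparing that formula with a central finite difference along a complex direction `dz`. The helper names `sqrt_tangent` and `fd_tangent`, and the step size `h`, are hypothetical choices for illustration only.

``````julia
# Tangent of sqrt at z in direction dz, using the same formula as the forward rule above
sqrt_tangent(z::Complex, dz::Complex) = 1 / (2sqrt(z)) * dz

# Central finite difference of sqrt along the complex direction dz
function fd_tangent(z::Complex, dz::Complex; h = 1e-6)
    return (sqrt(z + h * dz) - sqrt(z - h * dz)) / (2h)
end

z  = 2.0 + 1.0im   # chosen well away from the branch cut
dz = 0.3 - 0.7im
println(sqrt_tangent(z, dz))
println(fd_tangent(z, dz))
``````

The two printed values should agree to roughly eight digits; points on or near the negative real axis would break the agreement because `sqrt` is not differentiable across its branch cut.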

but I have been unable to do the same for `Reverse` mode. Following the custom rules tutorial, I defined the `augmented_primal` and `reverse` functions like this:

``````julia
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, ::Type{<:Active}, x::Duplicated{Complex{T}}) where {T<:Real}
    println("In custom augmented primal rule.")
    if needs_primal(config)
        primal = func.val(x.val)
    else
        primal = nothing
    end

    # Save x in tape if x will be overwritten
    if overwritten(config)[3]
        tape = copy(x.val)
    else
        tape = nothing
    end

    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, tape)
end

function reverse(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, dret::Active, tape, x::Duplicated{Complex{T}}) where {T<:Real}
    println("In custom reverse rule.")
    # retrieve x value, either from original x or from tape if x may have been overwritten.
    xval = overwritten(config)[3] ? tape : x.val
    x.dval += inv(2func(xval)) * dret.val
    return (nothing, nothing)
end
``````

However, I receive the error `ERROR: Duplicated Returns not yet handled` when trying to execute the following test function:

``````julia
function test(η)
    ans = sqrt(η + 0im)
    return abs(ans)
end
autodiff(Enzyme.Reverse, test, Duplicated, Duplicated(2.0, 1.0))
``````

You probably want an active return here.

I tried changing the call to

``````julia
autodiff(Enzyme.Reverse, test, Active, Duplicated(2.0, 1.0))
``````

But Enzyme still seemed to differentiate `sqrt` without using my rule. I assume I’ve set up the method signatures incorrectly somehow.

Reverse mode requires floats to be passed in as `Active`, not `Duplicated`.
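A minimal sketch of the two calling conventions, assuming a recent Enzyme.jl release: an immutable scalar such as a `Float64` is passed as `Active` and its derivative comes back in the return value of `autodiff`, while a mutable container is passed as `Duplicated` together with a zero-initialised shadow that Enzyme accumulates the derivative into. The functions `square` and `square_first` are hypothetical examples, not part of the original code.

``````julia
using Enzyme

# Scalar input: annotate as Active; the derivative is returned
square(x) = x^2
res = autodiff(Reverse, square, Active, Active(3.0))
# res[1] holds one entry per argument; here there is a single Active argument
dx_scalar = res[1][1]

# Mutable input: annotate as Duplicated; the derivative is accumulated into the shadow
square_first(v) = v[1]^2
v  = [3.0]
dv = [0.0]   # shadow buffer, must start at zero
autodiff(Reverse, square_first, Active, Duplicated(v, dv))

println(dx_scalar)  # 6.0
println(dv)         # [6.0]
``````

In both calls the return annotation is `Active` because the functions return a scalar float, which is the "active return" suggested above.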

Thanks, and apologies for my ignorance. I’ve tried defining the rules this way:

``````julia
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, ::Type{<:Active}, x::Active{Complex{T}}) where {T<:Real}
    println("In custom augmented primal rule.")
    if needs_primal(config)
        primal = func.val(x.val)
    else
        primal = nothing
    end

    # Save x in tape if x will be overwritten
    if overwritten(config)[3]
        tape = copy(x.val)
    else
        tape = nothing
    end

    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, tape)
end

function reverse(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, dret::Active, tape, x::Active{Complex{T}}) where {T<:Real}
    println("In custom reverse rule.")
    # retrieve x value, either from original x or from tape if x may have been overwritten.
    xval = overwritten(config)[3] ? tape : x.val
    x.dval += inv(2func(xval)) * dret.val
    return (nothing, )
end
``````

but Enzyme is still ignoring the square root rule on evaluation:

``````julia
function test(η)
    ans = sqrt(η + 0im)
    return abs(ans)
end
autodiff(Enzyme.Reverse, test, Active, Active(2.0))
``````

which produces the following stack trace:

``````
Stacktrace:
 [1] |
   @ ./int.jl:372
 [2] ldexp
   @ ./math.jl:964
 [3] sqrt
   @ ./complex.jl:541
 [4] test
   @ ~/Software/Krang.jl/examples/mwe.jl:7

Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:1289
  [2] |
    @ ./int.jl:372 [inlined]
  [3] ldexp
    @ ./math.jl:964 [inlined]
  [4] sqrt
    @ ./complex.jl:541 [inlined]
  [5] test
    @ ~/Software/Krang.jl/examples/mwe.jl:7 [inlined]
  [6] diffejulia_test_4396wrap
    @ ~/Software/Krang.jl/examples/mwe.jl:0
  [7] macro expansion
    @ ~/.julia/dev/Enzyme/src/compiler.jl:5440 [inlined]
  [8] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Active{…}, ::Float64)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5118
  [9] (::Enzyme.Compiler.CombinedAdjointThunk{…})(::Const{…}, ::Active{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5000
 [10] autodiff
    @ ~/.julia/dev/Enzyme/src/Enzyme.jl:0 [inlined]
 [11] autodiff(mode::ReverseMode{false, FFIABI, false}, f::typeof(test), ::Type{Active}, args::Active{Float64})
    @ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:287
``````

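Apparently the error was because I was not sub-typing on `Complex` (`Active{<:Complex{T}}` instead of `Active{Complex{T}}`). This version also indexes `overwritten(config)` at 2 rather than 3 (entry 1 corresponds to the function itself, entry 2 to `x`), and `reverse` now returns the conjugated derivative instead of updating `x.dval`. Here is a solution that works: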

``````julia
using Enzyme
import Enzyme.EnzymeRules: augmented_primal, reverse
# brings ConfigWidth, AugmentedReturn, needs_primal and overwritten into scope
using Enzyme.EnzymeRules

function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, ::Type{<:Active}, x::Active{<:Complex{T}}) where {T<:Real}
    println("In custom augmented primal rule.")
    if needs_primal(config)
        primal = func.val(x.val)
    else
        primal = nothing
    end

    # Save x in tape if x will be overwritten
    if overwritten(config)[2]
        tape = copy(x.val)
    else
        tape = nothing
    end

    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, tape)
end

function reverse(config::ConfigWidth{1}, ::Const{typeof(sqrt)}, dret::Active, tape, x::Active{<:Complex{T}}) where {T<:Real}
    println("In custom reverse rule.")
    # retrieve x value, either from original x or from tape if x may have been overwritten.
    xval = overwritten(config)[2] ? tape : x.val
    # ' is the adjoint, i.e. the complex conjugate of the derivative 1/(2sqrt(x))
    dx = inv(2*sqrt(xval))' * dret.val
    return (dx, )
end
``````
``````julia
function test(η)
    ans = sqrt(η*exp((π/4)*1im))
    return abs(ans)
end
autodiff(Enzyme.Reverse, test, Active, Active(2.0 + 0im))
``````

with output:

``````
In custom augmented primal rule.
In custom reverse rule.
((0.35355339059327373 - 1.1496735851465466e-17im,),)
``````
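The conjugate in the reverse rule follows from the convention that, for a real loss `L`, the adjoint of a complex input `x + iy` is `∂L/∂x + i ∂L/∂y`; for a holomorphic `g`, the Cauchy-Riemann equations then give `z̄ = conj(g'(z)) * w̄` for `w = g(z)`. As a sketch that is independent of Enzyme, the reverse pass of `test` can be unrolled by hand and compared with finite differences in the real and imaginary parts of `η`; both should land on the 0.3535... value printed above. The helpers `manual_grad` and `fd_grad` are hypothetical names used only for this check.

``````julia
# Same function as in the post
test(η) = abs(sqrt(η * exp((π/4) * 1im)))

# Reverse pass unrolled by hand, using the adjoint convention ∂L/∂x + i*∂L/∂y
function manual_grad(η::Complex)
    c = exp((π/4) * 1im)
    z = η * c
    w = sqrt(z)
    wbar   = w / abs(w)                  # adjoint of abs(w)
    zbar   = conj(1 / (2sqrt(z))) * wbar # sqrt rule: conj(sqrt'(z)) * wbar
    etabar = conj(c) * zbar              # adjoint of z = η * c
    return etabar
end

# Central finite differences in the real and imaginary directions
function fd_grad(f, η::Complex; h = 1e-6)
    dre = (f(η + h) - f(η - h)) / (2h)
    dim = (f(η + h * im) - f(η - h * im)) / (2h)
    return complex(dre, dim)
end

η = 2.0 + 0.0im
println(manual_grad(η))
println(fd_grad(test, η))
``````

Dropping the `conj` on the `sqrt'(z)` factor makes the hand-unrolled value disagree with the finite differences, which is why the rule uses `'` (for a number, `adjoint` is `conj`).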

Separately, we’ve just added a first-class complex `sqrt` rule on the main branch, so you shouldn’t need a custom rule for this case anymore.
