I have a function that I’m trying to differentiate with reverse-mode AD; at some point it needs to calculate the square root of a complex number. I have been able to define a custom rule for `Forward` mode autodiff in the following way:
# Forward-mode rule for `sqrt` on a `Duplicated` complex input.
# Pushes the tangent `x.dval` through d/dz sqrt(z) = 1 / (2 sqrt(z)).
function forward(func::Const{typeof(sqrt)}, ::Type{<:Duplicated}, x::Duplicated{Complex{T}}) where {T<:Real}
    root = func.val(x.val)
    tangent = 1 / (2 * root) * x.dval
    return Duplicated(root, tangent)
end
but I have been unable to do the same for `Reverse` mode. Following the custom rules tutorial, I have defined the `augmented_primal` and `reverse` functions like this:
# Augmented forward pass for `sqrt` on a complex input (first attempt, `Duplicated` input).
# Computes the primal only when Enzyme asks for it, and tapes a copy of `x` only if it
# may be overwritten before the reverse pass.
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, ::Type{<:Active}, x::Duplicated{Complex{T}}) where {T<:Real}
    println("In custom augmented primal rule.")
    if needs_primal(config)
        primal = func.val(x.val)
    else
        primal = nothing
    end
    # Save x in tape if x will be overwritten.
    # Per the EnzymeRules docs, `overwritten(config)` has one flag per argument with the
    # function itself first, so for `sqrt(x)` the flag for `x` is at index 2
    # (index 3 would be out of range); the rule that eventually worked also uses [2].
    if overwritten(config)[2]
        tape = copy(x.val)
    else
        tape = nothing
    end
    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, tape)
end
# Reverse pass for the first attempt (`Duplicated` complex input).
# NOTE(review): this version cannot work as written:
#   - `func` is the `Const` wrapper (the augmented rule calls the primal as
#     `func.val(...)`), so `func(xval)` presumably throws a MethodError.
#   - `overwritten(config)[3]` looks like the wrong slot for a one-argument call;
#     the rule that eventually worked reads `overwritten(config)[2]`.
#   - `x.dval += ...` assigns to a field of the argument wrapper; for a scalar shadow
#     this presumably fails. TODO: confirm whether `Duplicated` is immutable.
#   - the working rule returns one entry per argument (`(dx,)`), whereas
#     `(nothing, nothing)` has two entries for a single argument.
#     TODO: confirm against the EnzymeRules docs.
function reverse(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, dret::Active, tape, x::Duplicated{Complex{T}}) where {T<:Real}
println("In custom reverse rule.")
# retrieve x value, either from original x or from tape if x may have been overwritten.
xval = overwritten(config)[3] ? tape : x.val
x.dval += inv(2func(xval)) * dret.val
return (nothing, nothing)
end
However, I receive the error:
ERROR: Duplicated Returns not yet handled
when trying to execute the following test function:
# Magnitude of the complex square root of `η`; adding `0im` promotes a real `η` to complex.
test(η) = abs(sqrt(η + 0im))
# NOTE(review): this is the failing call. A `Duplicated` return activity produces the
# "Duplicated Returns not yet handled" error quoted above; in reverse mode, scalar
# (float) returns and arguments should be annotated `Active`.
autodiff(Enzyme.Reverse, test, Duplicated, Duplicated(2.0,1.0))
You probably want an active return here.
I tried changing the call signature to:
# NOTE(review): the return is now `Active`, but the Float64 argument is still passed as
# `Duplicated`; reverse mode expects floats to be passed as `Active`.
autodiff(Enzyme.Reverse, test, Active, Duplicated(2.0,1.0))
But Enzyme still seemed to differentiate `sqrt` without using my rule. I assume I’ve set up the method signatures incorrectly somehow.
Reverse mode requires floats to be passed in as `Active`, not `Duplicated`.
Thanks. Apologies for my ignorance. I’ve tried defining the rules this way:
# Augmented forward pass for `sqrt` on an `Active` complex input (second attempt).
# Computes the primal only when Enzyme asks for it, and tapes a copy of `x` only if it
# may be overwritten before the reverse pass.
# NOTE(review): per the discussion, rules written with `Active{Complex{T}}` were not
# picked up; the version that worked used `Active{<:Complex{T}}`.
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, ::Type{<:Active}, x::Active{Complex{T}}) where {T<:Real}
    println("In custom augmented primal rule.")
    if needs_primal(config)
        primal = func.val(x.val)
    else
        primal = nothing
    end
    # Save x in tape if x will be overwritten.
    # Per the EnzymeRules docs, `overwritten(config)` has one flag per argument with the
    # function itself first, so the flag for `x` is at index 2 (index 3 would be out of
    # range for `sqrt(x)`); the rule that eventually worked also uses [2].
    if overwritten(config)[2]
        tape = copy(x.val)
    else
        tape = nothing
    end
    # Return an AugmentedReturn object with shadow = nothing
    return AugmentedReturn(primal, nothing, tape)
end
# Reverse pass for `sqrt` on an `Active` complex input (second attempt).
# Returns the adjoint for `x`: the conjugate of d/dz sqrt(z) = 1 / (2 sqrt(z)),
# scaled by the incoming adjoint `dret.val` (the same form the working rule uses).
# NOTE(review): per the discussion, rules written with `Active{Complex{T}}` were not
# picked up; the version that worked used `Active{<:Complex{T}}`.
function reverse(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, dret::Active, tape, x::Active{Complex{T}}) where {T<:Real}
    println("In custom reverse rule.")
    # Retrieve x, either from the original argument or from the tape if it may have been
    # overwritten; `x` is slot 2 of `overwritten(config)` (slot 1 is the function).
    xval = overwritten(config)[2] ? tape : x.val
    # `Active` arguments carry no shadow to accumulate into, so the gradient is returned
    # (one entry per argument) rather than added to `x.dval`. `func` is the `Const`
    # wrapper, so the primal function is `func.val`.
    dx = conj(inv(2 * func.val(xval))) * dret.val
    return (dx,)
end
But Enzyme is still ignoring the square-root rule on evaluation:
# Magnitude of the complex square root of `η`; adding `0im` promotes a real `η` to complex.
test(η) = abs(sqrt(η + 0im))
# Reverse-mode call with an `Active` return and an `Active` Float64 input; according to
# the report above, the custom `sqrt` rule was not used for this call.
autodiff(Enzyme.Reverse, test, Active, Active(2.0))
Stacktrace:
[1] |
@ ./int.jl:372
[2] ldexp
@ ./math.jl:964
[3] sqrt
@ ./complex.jl:541
[4] test
@ ~/Software/Krang.jl/examples/mwe.jl:7
Stacktrace:
[1] throwerr(cstr::Cstring)
@ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:1289
[2] |
@ ./int.jl:372 [inlined]
[3] ldexp
@ ./math.jl:964 [inlined]
[4] sqrt
@ ./complex.jl:541 [inlined]
[5] test
@ ~/Software/Krang.jl/examples/mwe.jl:7 [inlined]
[6] diffejulia_test_4396wrap
@ ~/Software/Krang.jl/examples/mwe.jl:0
[7] macro expansion
@ ~/.julia/dev/Enzyme/src/compiler.jl:5440 [inlined]
[8] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Active{…}, ::Float64)
@ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5118
[9] (::Enzyme.Compiler.CombinedAdjointThunk{…})(::Const{…}, ::Active{…}, ::Vararg{…})
@ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5000
[10] autodiff
@ ~/.julia/dev/Enzyme/src/Enzyme.jl:0 [inlined]
[11] autodiff(mode::ReverseMode{false, FFIABI, false}, f::typeof(test), ::Type{Active}, args::Active{Float64})
@ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:287
Apparently the error was because I was not subtyping on `Complex` (i.e. the rules need `Active{<:Complex{T}}` rather than `Active{Complex{T}}`). Here is a solution that works:
# Augmented forward pass for `sqrt` on an `Active` complex input.
# The primal is computed only when Enzyme requests it. A copy of `x` is taped only when
# `overwritten(config)[2]` reports it may change before the reverse pass. No shadow is
# returned because the result is `Active`.
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, ::Type{<:Active}, x::Active{<:Complex{T}}) where {T<:Real}
    println("In custom augmented primal rule.")
    primal = needs_primal(config) ? func.val(x.val) : nothing
    tape = overwritten(config)[2] ? copy(x.val) : nothing
    return AugmentedReturn(primal, nothing, tape)
end
# Reverse pass for `sqrt` on an `Active` complex input.
# Returns a one-element tuple holding the adjoint for `x`: the conjugate of
# d/dz sqrt(z) = 1 / (2 sqrt(z)), scaled by the incoming adjoint `dret.val`.
function reverse(config::ConfigWidth{1}, ::Const{typeof(sqrt)}, dret::Active, tape, x::Active{<:Complex{T}}) where {T<:Real}
    println("In custom reverse rule.")
    # Read x from the tape when it may have been overwritten since the forward pass.
    z = overwritten(config)[2] ? tape : x.val
    local_adjoint = conj(inv(2 * sqrt(z)))
    return (local_adjoint * dret.val,)
end
# Magnitude of the square root of `η` rotated by the unit phase e^{iπ/4}.
function test(η)
    rotation = exp((π/4)*1im)
    rotated = η * rotation
    return abs(sqrt(rotated))
end
# Reverse-mode call with a complex `Active` input; the printed output below shows both
# custom rules being invoked.
autodiff(Enzyme.Reverse, test, Active, Active(2.0 + 0im))
with output:
In custom augmented primal rule.
In custom reverse rule.
((0.35355339059327373 - 1.1496735851465466e-17im,),)
2 Likes
Separately, we’ve just added a first-class complex `sqrt` rule on the main branch, so you shouldn’t need a custom rule for this case anymore.
2 Likes