Automatic differentiation of `cispi`

The `rrule` for `cispi` seems to be missing. Maybe it should be included in the basic set of rules?

using Zygote

function loss2(data)
    sum(abs.(cispi.(data)))
end
data = ones(2,2)
loss2(data)
gradient(loss2, data)

yields:

julia> gradient(loss2, data)
ERROR: Non-differentiable function Core.Intrinsics.copysign_float
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] macro expansion
    @ ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0 [inlined]
  [3] (::typeof(∂(λ)))(Δ::Float64)
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:43
  [4] Pullback
    @ .\floatfuncs.jl:5 [inlined]
  [5] (::typeof(∂(copysign)))(Δ::Float64)
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
  [6] Pullback
    @ .\special\trig.jl:885 [inlined]
  [7] (::typeof(∂(sincospi)))(Δ::Tuple{Float64, Float64})
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
  [8] Pullback
    @ .\complex.jl:544 [inlined]
  [9] #1124
    @ ~\.julia\packages\Zygote\TaBlo\src\lib\broadcast.jl:192 [inlined]
 [10] #4
    @ .\generator.jl:36 [inlined]
 [11] iterate
    @ .\generator.jl:47 [inlined]
 [12] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Matrix{Tuple{ComplexF64, typeof(∂(cispi))}}, Matrix{ComplexF64}}}, Base.var"#4#5"{Zygote.var"#1124#1130"}})
    @ Base .\array.jl:678
 [13] map
    @ .\abstractarray.jl:2383 [inlined]
 [14] (::Zygote.var"#∇broadcasted#1129"{Tuple{Matrix{Float64}}, Matrix{Tuple{ComplexF64, typeof(∂(cispi))}}, Val{2}})(ȳ::Matrix{ComplexF64})
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\lib\broadcast.jl:192
 [15] #4008#back
    @ ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59 [inlined]
 [16] #209
    @ ~\.julia\packages\Zygote\TaBlo\src\lib\lib.jl:203 [inlined]
 [17] #1746#back
    @ ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59 [inlined]
 [18] Pullback
    @ .\broadcast.jl:1309 [inlined]
 [19] (::typeof(∂(broadcasted)))(Δ::Matrix{ComplexF64})
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
 [20] Pullback
    @ .\REPL[52]:2 [inlined]
 [21] (::typeof(∂(loss2)))(Δ::Float64)
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
 [22] (::Zygote.var"#46#47"{typeof(∂(loss2))})(Δ::Float64)
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface.jl:41
 [23] gradient(f::Function, args::Matrix{Float64})
    @ Zygote ~\.julia\packages\Zygote\TaBlo\src\compiler\interface.jl:76
 [24] top-level scope
    @ REPL[54]:1

A suggestion could be:

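using ChainRulesCore

function ChainRulesCore.rrule(::typeof(cispi), x::Number)
    Y = cispi(x)
    project_x = ProjectTo(x)  # maps the cotangent back to x's type (real part for real x)
    function cispi_pullback(ȳ)
        # d/dx cispi(x) = iπ cispi(x); the reverse rule multiplies by its conjugate
        return (NoTangent(), project_x(unthunk(ȳ) * conj(im * π * Y)))
    end
    return Y, cispi_pullback
end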

but I am not sure whether this is efficient with respect to broadcasting when applied pointwise to large arrays.
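As a sanity check on the math behind such a rule, here is a Base-only sketch (no Zygote or ChainRules needed). It assumes the ChainRules convention for a real input `x` and complex output `y = cispi(x)`, where the cotangent is `x̄ = real(conj(∂y/∂x) * ȳ)`. The names `cispi_pullback`, `loss2_grad` and `fd` are hypothetical helpers for illustration only.

```julia
# For real x: d/dx cispi(x) = iπ·cispi(x).
# Reverse rule for a real input and complex output (ChainRules convention):
# x̄ = real(conj(∂y/∂x) * ȳ).
cispi_pullback(x::Real, ȳ::Complex) = real(conj(im * π * cispi(x)) * ȳ)

# Hypothetical helper: gradient of loss2(data) = sum(abs.(cispi.(data)))
# by the chain rule. For y = cispi(x), the cotangent of abs(y) is y / abs(y).
function loss2_grad(data::AbstractArray{<:Real})
    map(data) do x
        y = cispi(x)
        cispi_pullback(x, y / abs(y))
    end
end

# Central finite difference of a real-valued scalar function f at x.
fd(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)
```

Because `abs(cispi(x)) == 1` for every real `x`, `loss2` is constant, so `loss2_grad(ones(2, 2))` should be a 2×2 matrix of zeros up to rounding. Comparing `cispi_pullback(0.3, 1.0 + 0im)` with `fd(x -> real(cispi(x)), 0.3)` checks the rule against a numerical derivative.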


Indeed, there is no rule. I think in ChainRules this wants to be a `@scalar_rule`; it would be good to have.
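A sketch of what the `@scalar_rule` form could look like, assuming ChainRulesCore is installed. Inside `@scalar_rule`, `Ω` names the primal output, and the generated `rrule` multiplies the incoming cotangent by the conjugate of the partial. Defining this in user code is type piracy (the rule belongs in ChainRules.jl), so treat it as an illustration, not necessarily the form ChainRules.jl uses. This sketch adds no `ProjectTo` step, so for a real `x` the returned tangent may be complex.

```julia
using ChainRulesCore

# Illustration only: scalar rule for cispi(x) = exp(iπx).
# The partial derivative is iπ·cispi(x), written in terms of the output Ω.
# Defining it outside ChainRules.jl is type piracy.
@scalar_rule cispi(x) (im * π * Ω)
```

With this in place, `rrule(cispi, 0.5)` returns the primal `cispi(0.5)` (which is `im`) together with a pullback; feeding it a cotangent of `1.0` gives a tangent whose real part is the derivative of `cospi` at `0.5`, namely `-π`.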

It does seem to work with ForwardDiff, which Zygote presently uses for (most) real-valued broadcasting, but not for complex-valued broadcasting:

julia> Zygote.gradient(x -> sum((real∘cispi).(x)), [0 0.5; 1 1.5])
([-0.0 -3.141592653589793; -0.0 3.141592653589793],)

julia> ForwardDiff.gradient(x -> sum(real.(cispi.(x))), [0 0.5 1 1.5])
1×4 Matrix{Float64}:
 0.0  -3.14159  0.0  3.14159
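The numbers above can be checked by hand: for real `x`, `real(cispi(x)) == cospi(x)`, so the gradient of `sum(real.(cispi.(x)))` is `-π .* sinpi.(x)` elementwise. A Base-only check (`expected_grad` is a hypothetical name):

```julia
# For real x, real(cispi(x)) == cospi(x), and d/dx cospi(x) = -π sinpi(x),
# so the gradient of sum(real.(cispi.(x))) is -π * sinpi(x) elementwise.
expected_grad(x) = -π .* sinpi.(x)

expected_grad([0 0.5; 1 1.5])
```

Both the Zygote and ForwardDiff results shown above agree with this formula, up to the sign of zero.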

The failure can be traced back to the missing `sincospi` rule, which I have added (with Felix’s help) to ChainRules.jl. See this pull request.
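For reference, the calculus such a rule encodes: `sincospi(x)` returns `(sinpi(x), cospi(x))`, whose derivatives are `π * cospi(x)` and `-π * sinpi(x)`. Below is a minimal Base-only sketch of the resulting pullback for a real input; it is not the code from the pull request, and `sincospi_pullback` is a hypothetical name.

```julia
# sincospi(x) == (sinpi(x), cospi(x)).
# d/dx sinpi(x) = π cospi(x) and d/dx cospi(x) = -π sinpi(x), so a cotangent
# (s̄, c̄) on the output tuple pulls back to x̄ = π (s̄ cospi(x) - c̄ sinpi(x)).
function sincospi_pullback(x::Real, s̄::Real, c̄::Real)
    s, c = sincospi(x)
    return π * (s̄ * c - c̄ * s)
end
```

For example, `sincospi_pullback(0.2, 2.0, 3.0)` should match the derivative of `2sinpi(x) + 3cospi(x)` at `x = 0.2`.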


It would also be good to open an issue about a rule for `copysign`, since that shouldn't be failing here either.
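What a `copysign` rule would need to encode, as a hypothetical Base-only sketch (not an existing ChainRules API): `copysign(x, y)` is `abs(x)` carrying the sign bit of `y`, so away from `x == 0` its derivative in `x` is `+1` or `-1`, and its derivative in `y` is zero because only the sign bit of `y` is read. The value used at `x == 0` is a convention chosen here for illustration.

```julia
# copysign(x, y) == abs(x) with the sign bit of y.
# Away from x == 0: ∂/∂x = +1 if x and y share a sign bit, -1 otherwise.
# ∂/∂y = 0, since only y's sign bit affects the result.
# Hypothetical helper; the choice at x == 0 is an arbitrary convention.
function copysign_pullback(x::Real, y::Real, Δ::Real)
    ∂x = signbit(x) == signbit(y) ? one(Δ) : -one(Δ)
    return (Δ * ∂x, zero(Δ))
end
```

For instance, `copysign_pullback(0.7, -2.0, 1.0)` gives `(-1.0, 0.0)`, matching a finite difference of `x -> copysign(x, -2.0)` at `0.7`.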