Add ForwardDiff rule to allow complex arguments

Hey @weymouth, sorry I never got around to helping more with your question when you posted it on Slack. After some trial and error, I think I got this working by writing a custom method for `expintx(::Complex{<:Dual})`:

```julia
using ForwardDiff, SpecialFunctions
using ForwardDiff: value, partials, Dual

function SpecialFunctions.expintx(ϕ::Complex{<:ForwardDiff.Dual{Tag}}) where {Tag}
    # Split input into real and imaginary parts
    x, y = reim(ϕ)
    # This gives the primal ('finite') complex part of the input
    z = complex(value(x), value(y))

    # Calculate the primal part of the output
    Ω = expintx(z)
    # Split into real and imaginary parts
    u, v = reim(Ω)

    # expintx'(z) = expintx(z) - 1/z = Ω - inv(z), so split that into real and imaginary parts too
    ∂u, ∂v = reim(Ω - inv(z))

    # Now let's deal with the infinitesimals from the real and imaginary parts of ϕ
    px, py = partials(x), partials(y)
    # We have something of the form (∂u + i ∂v) (px + i py)
    # Split again into real and imaginary parts
    du = Dual{Tag}(u, ∂u*px - ∂v*py)
    dv = Dual{Tag}(v, ∂v*px + ∂u*py)

    # And combine
    complex(du, dv)
end
```
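For reference, here is where the two formulas in the comments come from. The first uses $E_1'(z) = -e^{-z}/z$ together with $\mathrm{expintx}(z) = e^{z}E_1(z)$; the second is just the product $f'(z)\,(p_x + i\,p_y)$ that a complex-differentiable $f$ applies to the dual part of $\phi$, written out in real and imaginary parts:

```math
\begin{aligned}
\frac{d}{dz}\,\mathrm{expintx}(z) &= \frac{d}{dz}\left[e^{z}E_1(z)\right] = e^{z}E_1(z) + e^{z}\left(-\frac{e^{-z}}{z}\right) = \mathrm{expintx}(z) - \frac{1}{z},\\
(\partial u + i\,\partial v)(p_x + i\,p_y) &= (\partial u\,p_x - \partial v\,p_y) + i\,(\partial v\,p_x + \partial u\,p_y).
\end{aligned}
```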

This is a lot of typing, but it's also pretty straightforward and reusable: for another complex-differentiable function, you just plug its derivative (e.g. from its `frule`) into the line where I compute `∂u, ∂v`, and that's really it. I think this could be automated with ForwardDiffChainRules if someone put a little work into it.
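Since only the `∂u, ∂v` line is specific to `expintx`, the pattern can be factored into a small helper that takes the function and its complex derivative. This is a sketch assuming `f` is complex-differentiable at the point; `holomorphic_dual` is a name I made up, not something from ForwardDiff or SpecialFunctions:

```julia
using ForwardDiff, SpecialFunctions
using ForwardDiff: Dual, value, partials

# Hypothetical helper (not part of ForwardDiff or SpecialFunctions):
# push a complex number with Dual parts through a complex-differentiable
# function `f`, given its complex derivative `df`.
function holomorphic_dual(f, df, ϕ::Complex{<:Dual{Tag}}) where {Tag}
    x, y = reim(ϕ)
    z = complex(value(x), value(y))      # primal part of the input
    u, v = reim(f(z))                    # primal part of the output
    ∂u, ∂v = reim(df(z))                 # f'(z) = ∂u + i ∂v
    px, py = partials(x), partials(y)
    # (∂u + i ∂v)(px + i py), split into real and imaginary parts
    du = Dual{Tag}(u, ∂u * px - ∂v * py)
    dv = Dual{Tag}(v, ∂v * px + ∂u * py)
    return complex(du, dv)
end

# The expintx rule from above, rewritten with the helper
SpecialFunctions.expintx(ϕ::Complex{<:Dual}) =
    holomorphic_dual(expintx, z -> expintx(z) - inv(z), ϕ)
```

Other complex-differentiable functions would then only need their own derivative passed as `df`, e.g. `holomorphic_dual(sin, cos, ϕ)`.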

For me this is about 6x faster than Zygote (with far fewer allocations) and produces the same answer:

```julia
julia> Dx_fwd(x, y) = ForwardDiff.derivative(x->D(x,y), x);

julia> f = @time Dx_fwd.(x',y);
  0.920399 seconds (48.94 k allocations: 1.214 MiB)

julia> z = @time Dx_zygote.(x',y);
  5.967287 seconds (49.71 M allocations: 3.960 GiB, 3.16% gc time)

julia> f ≈ z
true
```
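If you want a check that doesn't rely on Zygote at all, the identity the rule is built on, `expintx'(z) = expintx(z) - 1/z`, can be compared against a centered finite difference. A minimal sketch; the test point and step size are arbitrary choices, and the point is kept off the branch cut on the negative real axis:

```julia
using SpecialFunctions

# Compare the analytic derivative used in the rule against a centered
# finite difference taken along the real axis (valid because expintx is
# complex-differentiable away from its branch cut).
z = 1.3 + 0.7im      # arbitrary test point
h = 1e-6             # arbitrary step size
fd = (expintx(z + h) - expintx(z - h)) / (2h)
analytic = expintx(z) - inv(z)
println(isapprox(fd, analytic; rtol = 1e-6))
```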

The timing is essentially the same as computing the primal alone, so that's a good sign:

```julia
julia> @time D.(x', y);
  0.915461 seconds (36.71 k allocations: 669.703 KiB)
```