Add ForwardDiff rule to allow complex arguments

I’m having trouble getting the derivative of a function which uses a complex variable internally. The key function is Di(x,y,T) = real(expintx(complex(-(1+T^2),(x+y*T)*√(1+T^2)))) which Zygote is happy to deal with, but it’s too slow for my application.

ForwardDiff should be faster, but I don’t see how to work around the complex type inside the exponential integral.

Note that the derivative is simple, expintx'(z) = Ω - inv(z) with Ω = expintx(z), but I don’t know how to set up a rule for this to bypass the problem.
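For reference, here is a short sketch of where that derivative comes from, assuming the usual definition of the scaled exponential integral, expintx(z) = e^z E_1(z), with E_1 the exponential integral:

```latex
\begin{aligned}
E_1(z) &= \int_z^{\infty} \frac{e^{-t}}{t}\,dt
  \quad\Longrightarrow\quad E_1'(z) = -\frac{e^{-z}}{z}, \\
\frac{d}{dz}\Bigl[e^{z} E_1(z)\Bigr]
  &= e^{z} E_1(z) + e^{z}\Bigl(-\frac{e^{-z}}{z}\Bigr)
   = \operatorname{expintx}(z) - \frac{1}{z}.
\end{aligned}
```

So the derivative reuses the primal value Ω and only adds the cost of one inv(z).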

using FastGaussQuadrature,SpecialFunctions
Tgl65, wgl65 = gausslegendre(65);
Di(x,y,T) = real(expintx(complex(-(1+T^2),(x+y*T)*√(1+T^2))))
D(x,y,w=wgl65,T=Tgl65) = sum(w[i]*Di(x,y,T[i]) for i in eachindex(w,T))

D(0.1,0.1) # working
x,y=-10:0.1:5,-4:0.1:4
@time @. D(x',y) # working 

using Zygote
Dx_zygote(x,y) = gradient(x->D(x,y),x)[1]
Dx_zygote(0.1,0.1) # working
@time @. Dx_zygote(x',y) # working, but ~10x slower than D (due to 4GiB memory overhead?)

using ForwardDiff
Dx_fd(x,y) = ForwardDiff.derivative(x->D(x,y),x)
Dx_fd(0.1,0.1) #ERROR: MethodError: no method matching expintx(::Complex{ForwardDiff.Dual...
ForwardDiff.derivative(expintx, 1.0) # ERROR: MethodError: no method matching expintx(::ForwardDiff.Dual...

using ForwardDiffChainRules
@ForwardDiff_frule SpecialFunctions.expintx(x::ForwardDiff.Dual) # expintx'(z) = Ω-inv(z)
ForwardDiff.derivative(expintx, 1.0) # working
ForwardDiff.derivative(expintx, 1.0+1.0im) # ERROR: DimensionMismatch: derivative(f, x) expects that x is a real number (does not support Wirtinger derivatives)
ForwardDiff.derivative(x->real(expintx(complex(-1,x))), 1.0) # ERROR: MethodError: no method matching expintx(::Complex{ForwardDiff.Dual...

Hey @weymouth sorry I never got around to helping more with your question when you posted it on Slack. So after some trial and error, I think I was able to get this working by writing a custom method for expintx(::Complex{<:Dual}):

using ForwardDiff: value, partials, Dual, Partials

function SpecialFunctions.expintx(ϕ::Complex{<:ForwardDiff.Dual{Tag}}) where {Tag}
    # Split input into real and imaginary parts
    x, y = reim(ϕ)
    # This gives the 'finite' complex part of the input
    z = complex(value(x), value(y))

    # Calculate the finite part of the output
    Ω = expintx(z)
    # split into real and imaginary parts
    u, v = reim(Ω)

    # Ω - inv(z) is the value of expintx'(z), so we'll split that into real and imaginary parts too
    ∂u, ∂v = reim(Ω - inv(z))

    # Now let's deal with the infinitesimals from the real and imaginary parts of ϕ
    px, py = partials(x), partials(y)
    # We have something of the form (∂u + i ∂v) (px + i py)
    # Split again into real and imaginary parts
    du = Dual{Tag}(u, ∂u*px - ∂v*py)
    dv = Dual{Tag}(v, ∂v*px + ∂u*py)

    # And combine
    complex(du, dv)
end

This is a lot of typing, but is also pretty straightforward and re-usable. You just plug the frule into the part where I calculate ∂u, ∂v and that’s really it. I think this could be automated with ForwardDiffChainRules if someone put a little work into it.
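To illustrate the "re-usable" part, here is a minimal sketch of the same pattern factored into a helper. The names `holomorphic_dual`, `myexp` and `g` are hypothetical (they are not part of ForwardDiff or ForwardDiffChainRules), and the sketch assumes `f` is holomorphic with a known complex derivative `df`, so that the output perturbation is f'(z)·(px + i·py):

```julia
using ForwardDiff
using ForwardDiff: Dual, value, partials

# Hypothetical helper: push a Complex{<:Dual} through a holomorphic function `f`
# whose complex derivative `df` is known in closed form.
function holomorphic_dual(f, df, ϕ::Complex{<:Dual{Tag}}) where {Tag}
    x, y = reim(ϕ)                      # Dual real and imaginary parts
    z = complex(value(x), value(y))     # plain complex primal input
    u, v = reim(f(z))                   # primal output
    ∂u, ∂v = reim(df(z))                # f'(z) split into real/imaginary parts
    px, py = partials(x), partials(y)   # input perturbations
    # (∂u + i ∂v)(px + i py) = (∂u*px - ∂v*py) + i (∂v*px + ∂u*py)
    complex(Dual{Tag}(u, ∂u*px - ∂v*py), Dual{Tag}(v, ∂v*px + ∂u*py))
end

# Sanity check on a function with a known derivative: exp'(z) = exp(z).
# A separate name is used so no Base method is overwritten.
myexp(ϕ::Complex{<:Dual}) = holomorphic_dual(exp, exp, ϕ)
g(t) = real(myexp(complex(-1, t)))      # g(t) = exp(-1) * cos(t)

println(ForwardDiff.derivative(g, 1.0)) # should match the line below
println(-exp(-1) * sin(1.0))            # analytic derivative of g at t = 1
```

With such a helper, the `expintx` method above would reduce to passing `expintx` and `z -> expintx(z) - inv(z)`. Evaluating the primal twice that way is wasteful, which is why the hand-written version reuses Ω.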

This gives something for me that’s 6x faster than Zygote and produces the same answer:

julia> Dx_fwd(x, y) = ForwardDiff.derivative(x->D(x,y), x);

julia> f = @time Dx_fwd.(x',y);
  0.920399 seconds (48.94 k allocations: 1.214 MiB)

julia> z = @time Dx_zygote.(x',y);
  5.967287 seconds (49.71 M allocations: 3.960 GiB, 3.16% gc time)

julia> f ≈ z
true

The speed is almost exactly the same as just calculating the primal, so that’s a good sign:

julia> @time D.(x', y);
  0.915461 seconds (36.71 k allocations: 669.703 KiB)

Out of curiosity, why doesn’t ForwardDiffChainRules work natively here?

It doesn’t work because when you deal with complex numbers you get a Complex{<:Dual}, not a Dual{<:Complex}.
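A small sketch of what that nesting looks like in practice (the variable names are only for illustration). Since a `Dual` is a subtype of `Real`, `complex` wraps two duals rather than the other way around, so a method defined for `expintx(::Dual)` is never selected for this argument type:

```julia
using ForwardDiff: Dual

# A dual number is itself a Real, so Complex can wrap it
a = Dual(1.0, 1.0)               # value 1.0 with a single partial equal to 1.0
z = complex(a, Dual(2.0, 0.0))   # same construction as complex(-1, x) inside Di

println(typeof(z))               # a Complex whose real/imag parts are Duals
println(z isa Complex{<:Dual})   # the type the MethodError above complains about
println(Dual <: Real)
```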

That’s awesome. I also see virtually the same time for D and D'. Thanks so much for the help.

I am pulling this together for a new open-source class on some old school numerical methods, and I am looking forward to showcasing some modern programming features like this.

One other thing you can try is Diffractor.jl. Your example appears to not work out of the box with Diffractor.jl, but only due to an issue in ChainRules that I found: "`frule` for `sum` doesn't work for `Generator`" (Issue #768, JuliaDiff/ChainRules.jl on GitHub).

I was able to get your example working by just changing

D(x,y,w=wgl65,T=Tgl65) = sum(w[i]*Di(x,y,T[i]) for i in eachindex(w,T))

to

D(x,y,w=wgl65,T=Tgl65) = sum(i -> w[i]*Di(x,y,T[i]), eachindex(w,T))

and then everything works for me:

using Diffractor: DiffractorForwardBackend
using AbstractDifferentiation: derivative, jacobian

D(x,y,w=wgl65,T=Tgl65) = sum(i -> w[i]*Di(x,y,T[i]), eachindex(w,T));
Dx_diffractor(x, y) = derivative(DiffractorForwardBackend(), x -> D(x, y), x)[1];
julia> @time Dx_diffractor.(x',y);
  0.914197 seconds (244.63 k allocations: 5.320 MiB)

In general, Diffractor may not be as mature or stable as ForwardDiff, but it has the advantage of natively supporting ChainRules and not relying on Dual types.


It seems I can’t even get Diffractor to load…

using Diffractor: DiffractorForwardBackend  # ERROR: UndefVarError: `DiffractorForwardBackend` not defined

What version of Julia do you have, and what version of Diffractor? You should use v1.10 of Julia.

I think Diffractor might be downgraded by AbstractDifferentiation v0.6, so if that’s the case then do an explicit ] add Diffractor@v0.2.3

Actually it’s the other way around: Diffractor is still on AbstractDifferentiation v0.5.

Yes, that’s the problem I was pointing out. If you add Diffractor and AbstractDifferentiation, the resolver might decide to give you Diffractor@0.2.0 and AbstractDifferentiation@0.6.0, but what you actually want is Diffractor@0.2.3 and AbstractDifferentiation@0.5.
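One way to request that combination explicitly is sketched below, using the version numbers quoted in this thread (they may well be outdated by the time you read this). Adding both packages in a single call lets the resolver see both constraints at once:

```julia
using Pkg

# Ask for both versions together so the resolver cannot pick
# Diffractor 0.2.0 alongside AbstractDifferentiation 0.6
Pkg.add([
    Pkg.PackageSpec(name = "Diffractor", version = "0.2.3"),
    Pkg.PackageSpec(name = "AbstractDifferentiation", version = "0.5"),
])

Pkg.status()  # check which versions were actually installed
```

Doing this in a fresh environment (for example after `Pkg.activate(temp = true)`) avoids conflicts with whatever else is installed.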
