Hey @weymouth, sorry I never got around to helping more with your question when you posted it on Slack. After some trial and error, I think I was able to get this working by writing a custom method for expintx(::Complex{<:Dual}):
    using ForwardDiff, SpecialFunctions
    using ForwardDiff: value, partials, Dual, Partials

    function SpecialFunctions.expintx(ϕ::Complex{<:ForwardDiff.Dual{Tag}}) where {Tag}
        # Split input into real and imaginary parts
        x, y = reim(ϕ)
        # This gives the 'finite' complex part of the input
        z = complex(value(x), value(y))
        # Calculate the finite part of the output
        Ω = expintx(z)
        # Split into real and imaginary parts
        u, v = reim(Ω)
        # Ω - inv(z) is the value of expintx'(z), so we'll split that into real and imaginary parts too
        ∂u, ∂v = reim(Ω - inv(z))
        # Now let's deal with the infinitesimals from the real and imaginary parts of ϕ
        px, py = partials(x), partials(y)
        # We have something of the form (∂u + i ∂v) (px + i py)
        # Split again into real and imaginary parts
        du = Dual{Tag}(u, ∂u*px - ∂v*py)
        dv = Dual{Tag}(v, ∂v*px + ∂u*py)
        # And combine
        complex(du, dv)
    end
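For reference, here are the two identities the comments lean on, written out. This assumes the usual definition of expintx as the scaled exponential integral, with $E_1$ the exponential integral:

$$
\operatorname{expintx}(z) = e^{z} E_1(z), \qquad E_1'(z) = -\frac{e^{-z}}{z}
\quad\Longrightarrow\quad
\frac{d}{dz}\operatorname{expintx}(z) = e^{z} E_1(z) - \frac{1}{z} = \operatorname{expintx}(z) - \frac{1}{z}
$$

$$
(\partial u + i\,\partial v)(p_x + i\,p_y) = (\partial u\,p_x - \partial v\,p_y) + i\,(\partial v\,p_x + \partial u\,p_y)
$$

The real and imaginary parts of the second line are exactly the partials passed to du and dv.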
This method is a lot of typing, but it's also pretty straightforward and reusable. You just plug the frule into the part where I calculate ∂u, ∂v and that’s really it. I think this could be automated with ForwardDiffChainRules if someone put a little work into it.
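To make the "plug it in" step concrete, here is a minimal sketch of how the pattern might be factored out. It assumes the function is holomorphic and that you can supply its complex derivative yourself. The names holomorphic_dual, f_and_df, and sq are hypothetical, made up for illustration and not part of ForwardDiff, and the callback just stands in for whatever an frule would give you:

```julia
using ForwardDiff: ForwardDiff, Dual, value, partials

# Hypothetical helper (not part of ForwardDiff): push a Complex{<:Dual} through a
# holomorphic function. `f_and_df(z)` must return the tuple (f(z), f'(z)).
function holomorphic_dual(f_and_df, ϕ::Complex{<:Dual{Tag}}) where {Tag}
    x, y = reim(ϕ)
    z = complex(value(x), value(y))      # plain complex primal input
    Ω, dΩ = f_and_df(z)                  # primal output and complex derivative
    u, v = reim(Ω)
    ∂u, ∂v = reim(dΩ)
    px, py = partials(x), partials(y)
    # (∂u + i ∂v)(px + i py), split into real and imaginary parts
    complex(Dual{Tag}(u, ∂u*px - ∂v*py), Dual{Tag}(v, ∂v*px + ∂u*py))
end

# Toy function standing in for expintx: only defined for plain complex floats...
sq(z::Complex{<:AbstractFloat}) = z^2
# ...so the Dual method is supplied through the helper, using d/dz z^2 = 2z.
sq(ϕ::Complex{<:Dual}) = holomorphic_dual(z -> (sq(z), 2z), ϕ)

# Re((x + 2i)^2) = x^2 - 4, so the derivative with respect to x at x = 1 is 2
dre = ForwardDiff.derivative(x -> real(sq(complex(x, 2.0))), 1.0)
# Im((x + 2i)^2) = 4x, so the derivative with respect to x is 4
dim = ForwardDiff.derivative(x -> imag(sq(complex(x, 2.0))), 1.0)
```

For expintx the callback would be z -> (Ω = expintx(z); (Ω, Ω - inv(z))), which should reproduce the method above.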
For me this is about 6.5x faster than Zygote and produces the same answer:
    julia> Dx_fwd(x, y) = ForwardDiff.derivative(x->D(x,y), x);

    julia> f = @time Dx_fwd.(x',y);
      0.920399 seconds (48.94 k allocations: 1.214 MiB)

    julia> z = @time Dx_zygote.(x',y);
      5.967287 seconds (49.71 M allocations: 3.960 GiB, 3.16% gc time)

    julia> f ≈ z
    true
The speed is about the same as just calculating the primal, so that’s a good sign:
    julia> @time D.(x', y);
      0.915461 seconds (36.71 k allocations: 669.703 KiB)