I’m having trouble getting the derivative of a function which uses a complex variable internally. The key function is Di(x,y,T) = real(expintx(complex(-(1+T^2),(x+y*T)*√(1+T^2))))
which Zygote is happy to deal with, but it’s too slow for my application.
ForwardDiff should be faster, but I don’t see how to work around the complex type inside the exponential integral.
Note that the derivative is simple, expintx'(z) = Ω - inv(z) with Ω = expintx(z), but I don’t know how to set up a rule for this to bypass the problem.
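For completeness, here is where that derivative comes from and what it implies for Di. This assumes the usual definition of the scaled exponential integral, expintx(z) = exp(z) E_1(z), and stays away from the branch cut on the negative real axis. The symbols a and b are just shorthand introduced here for the real and imaginary parts of the argument.

$$
\frac{d}{dz}\operatorname{expintx}(z)
= e^{z}E_1(z) + e^{z}\left(-\frac{e^{-z}}{z}\right)
= \Omega - \frac{1}{z},
\qquad \Omega = \operatorname{expintx}(z)
$$

$$
z = a + i\,b(x),\quad a = -(1+T^2),\quad b = (x+yT)\sqrt{1+T^2}
\;\Longrightarrow\;
\frac{\partial D_i}{\partial x}
= \operatorname{Re}\!\left[\left(\Omega - \frac{1}{z}\right) i\,\frac{\partial b}{\partial x}\right]
= -\sqrt{1+T^2}\;\operatorname{Im}\!\left(\Omega - \frac{1}{z}\right)
$$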
using FastGaussQuadrature,SpecialFunctions
Tgl65, wgl65 = gausslegendre(65);
Di(x,y,T) = real(expintx(complex(-(1+T^2),(x+y*T)*√(1+T^2))))
D(x,y,w=wgl65,T=Tgl65) = sum(w[i]*Di(x,y,T[i]) for i in eachindex(w,T))
D(0.1,0.1) # working
x,y=-10:0.1:5,-4:0.1:4
@time @. D(x',y) # working
using Zygote
Dx_zygote(x,y) = gradient(x->D(x,y),x)[1]
Dx_zygote(0.1,0.1) # working
@time @. Dx_zygote(x',y) # working, but ~10x slower than D (due to 4GiB memory overhead?)
using ForwardDiff
Dx_fd(x,y) = ForwardDiff.derivative(x->D(x,y),x)
Dx_fd(0.1,0.1) # ERROR: MethodError: no method matching expintx(::Complex{ForwardDiff.Dual...
ForwardDiff.derivative(expintx, 1.0) # ERROR: MethodError: no method matching expintx(::ForwardDiff.Dual...
using ForwardDiffChainRules
@ForwardDiff_frule SpecialFunctions.expintx(x::ForwardDiff.Dual) # expintx'(z) = Ω-inv(z)
ForwardDiff.derivative(expintx, 1.0) # working
ForwardDiff.derivative(expintx, 1.0+1.0im) # ERROR: DimensionMismatch: derivative(f, x) expects that x is a real number (does not support Wirtinger derivatives)
ForwardDiff.derivative(x->real(expintx(complex(-1,x))), 1.0) # ERROR: MethodError: no method matching expintx(::Complex{ForwardDiff.Dual...
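One possible way to express such a rule, as a sketch rather than a vetted solution, is to overload expintx for Complex{<:Dual} by hand and propagate the partials with the holomorphic chain rule. The assumptions here are that expintx'(z) = Ω - inv(z) holds at the evaluation point (i.e. away from the branch cut) and that adding a method to SpecialFunctions.expintx from user code is acceptable, even though it is type piracy.

```julia
using ForwardDiff, SpecialFunctions
using ForwardDiff: Dual, value, partials

# Sketch: expintx for a complex number whose real and imaginary parts are Duals.
# Assumes expintx'(z) = expintx(z) - inv(z), valid away from the branch cut.
function SpecialFunctions.expintx(z::Complex{<:Dual{T}}) where {T}
    x, y = reim(z)
    zv = complex(value(x), value(y))   # plain complex value
    Ω  = expintx(zv)                   # primal result
    dΩ = Ω - inv(zv)                   # complex derivative at zv
    px, py = partials(x), partials(y)  # perturbations of Re(z) and Im(z)
    # Holomorphic chain rule: dΩ * (px + im*py), split into real and imaginary parts
    re = Dual{T}(real(Ω), real(dΩ) * px - imag(dΩ) * py)
    im = Dual{T}(imag(Ω), imag(dΩ) * px + real(dΩ) * py)
    return complex(re, im)
end

# The call that failed above should now dispatch to the method defined here
ForwardDiff.derivative(x -> real(expintx(complex(-1, x))), 1.0)
```

With this method in place, Dx_fd from above should hit the same code path, since complex(a, b) with a Dual imaginary part promotes to Complex{<:Dual}. Whether it is actually faster than Zygote for the full quadrature sum would still need to be timed.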