I would like to share a simple technique for automatic differentiation (AD) of complex functions. I developed it a few months ago and have been using it successfully ever since. The technique works for holomorphic functions, and it uses ForwardDiff.jl.
As an example, consider the following complex function f:
julia> f(z) = 1 / (z - (3+4im))
f (generic function with 1 method)
julia> f(2+2im)
-0.2 + 0.4im
The function is holomorphic in ℂ except at z = 3+4i. However, neither ForwardDiff.jl nor Zygote.jl calculates its derivative successfully:
julia> using ForwardDiff: derivative
julia> derivative(f, 2+2im)
ERROR: MethodError: no method matching derivative(::typeof(f), ::Complex{Int64})
...
julia> using Zygote
julia> gradient(f, 2+2im)
ERROR: Output is complex, so the gradient is not defined.
...
On the other hand, the hand-derived derivative evaluates without any problem:
julia> f′hand(z) = -1 / (z - (3+4im))^2
f′hand (generic function with 1 method)
julia> f′hand(2+2im)
0.12 + 0.16im
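As a sanity check on the hand-derived formula, here is a minimal sketch that compares it with central difference quotients taken along a few directions in the complex plane. The helper name diffquot and the step size h = 1e-6 are my own choices for illustration, not part of any package, and the comparison tolerance is a loose one picked for finite differences.

```julia
f(z) = 1 / (z - (3 + 4im))
f′hand(z) = -1 / (z - (3 + 4im))^2

# Hypothetical helper: central difference quotient of f at z along the unit
# direction d (|d| = 1). Dividing by the complex step 2h*d makes the result
# an estimate of f'(z) itself, not of a real-valued directional slope.
diffquot(f, z, d; h = 1e-6) = (f(z + h * d) - f(z - h * d)) / (2h * d)

z0 = 2 + 2im
along_real = diffquot(f, z0, 1)            # step along the real axis
along_imag = diffquot(f, z0, im)           # step along the imaginary axis
along_diag = diffquot(f, z0, cis(π / 4))   # step along a 45° diagonal

# All three estimates should agree with the hand-derived value up to
# finite-difference error.
for est in (along_real, along_imag, along_diag)
    println(isapprox(est, f′hand(z0); atol = 1e-6))
end
```

This only uses Base Julia, so it is independent of any AD package. It is a numerical check rather than a proof, and the accuracy depends on the step size.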
If the domain of a complex-valued function is ℝ, we can calculate its derivative by differentiating its real and imaginary parts separately with ForwardDiff.jl, because the real and imaginary parts of f : ℝ → ℂ are both functions ℝ → ℝ, which ForwardDiff.jl can differentiate. Here is this technique applied to the above f for real z:
julia> f′realz(z) = derivative(z->real(f(z)), z) + im * derivative(z->imag(f(z)), z)
f′realz (generic function with 1 method)
julia> f′realz(2) ≈ f′hand(2) # works for real z
true
julia> f′realz(2+2im) ≈ f′hand(2+2im) # doesn't work for complex z
ERROR: MethodError: no method matching derivative(::var"#3#5", ::Complex{Int64})
...
Now we push this technique one step further for holomorphic functions. An important property of a holomorphic function is that the limit of its difference quotient (f(z+h) - f(z)) / h is the same no matter from which direction the complex step h approaches 0. Therefore, if f is holomorphic, f'(z) can be calculated by measuring the function’s rate of change along the real-axis direction from z. In other words, if we define g(x) = f(z+x) for real x and a given complex z, then f'(z) = g'(0). Note that g : ℝ → ℂ can be differentiated using the trick described above. The following code implements this idea:
julia> f′compz(z) = (g(x) = f(x+z); derivative(x->real(g(x)),0) + im * derivative(x->imag(g(x)),0))
f′compz (generic function with 1 method)
julia> f′compz(2+2im) ≈ f′hand(2+2im) # works for complex z
true
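The same idea can be wrapped in a helper that takes the function as an argument, so it is not tied to one particular f. The following is a minimal sketch under that assumption. The name holomorphic_derivative and the test polynomial p are hypothetical, introduced here only for illustration, and the only ForwardDiff.jl call used is derivative, as above.

```julia
using ForwardDiff: derivative

# Hypothetical helper (not part of ForwardDiff.jl): derivative of a
# holomorphic function f at the complex point z, measured along the
# real-axis direction from z.
function holomorphic_derivative(f, z::Number)
    g(x) = f(z + x)                              # g : ℝ → ℂ, with g'(0) = f'(z)
    re_part = derivative(x -> real(g(x)), 0.0)   # derivative of Re g at x = 0
    im_part = derivative(x -> imag(g(x)), 0.0)   # derivative of Im g at x = 0
    return complex(re_part, im_part)
end

f(z) = 1 / (z - (3 + 4im))
f′hand(z) = -1 / (z - (3 + 4im))^2

p(z) = z^3 + 2z            # another holomorphic test function
p′hand(z) = 3z^2 + 2

println(holomorphic_derivative(f, 2 + 2im) ≈ f′hand(2 + 2im))
println(holomorphic_derivative(p, 1 + 1im) ≈ p′hand(1 + 1im))
```

Note that the helper does not check whether f is actually holomorphic. For a non-holomorphic function such as conj, it still returns a number, but that number is only the rate of change along the real axis and not a complex derivative, so the caller has to know that f is holomorphic at z.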
I find this a simple and useful workaround for automatic differentiation of holomorphic functions with AD packages that do not support complex functions. Hopefully other people find my technique useful as well!