Automatic differentiation of complex functions: simple workaround

I would like to share a simple technique for automatic differentiation (AD) of complex functions. I developed it a few months ago and have been using it successfully ever since. The technique works for holomorphic functions and uses ForwardDiff.jl.

As an example, consider the following complex function f:

julia> f(z) = 1 / (z - (3+4im))
f (generic function with 1 method)

julia> f(2+2im)
-0.2 + 0.4im

The function is holomorphic everywhere in ℂ except at z = 3+4i. However, neither ForwardDiff.jl nor Zygote.jl calculates its derivative successfully:

julia> using ForwardDiff: derivative

julia> derivative(f, 2+2im)
ERROR: MethodError: no method matching derivative(::typeof(f), ::Complex{Int64})
...

julia> using Zygote

julia> gradient(f, 2+2im)
ERROR: Output is complex, so the gradient is not defined.
...

On the other hand, the hand-derived derivative evaluates without any problem:

julia> f′hand(z) = -1 / (z - (3+4im))^2
f′hand (generic function with 1 method)

julia> f′hand(2+2im)
0.12 + 0.16im

If a complex-valued function’s domain is ℝ, we can calculate its derivative by differentiating its real and imaginary parts separately with ForwardDiff.jl: the real and imaginary parts of f : ℝ → ℂ are functions ℝ → ℝ, which ForwardDiff.jl can differentiate. Here is this technique applied to the above f for real z:

julia> f′realz(z) = derivative(z->real(f(z)), z) + im * derivative(z->imag(f(z)), z)
f′realz (generic function with 1 method)

julia> f′realz(2) ≈ f′hand(2)  # works for real z
true

julia> f′realz(2+2im) ≈ f′hand(2+2im)  # doesn't work for complex z
ERROR: MethodError: no method matching derivative(::var"#3#5", ::Complex{Int64})
...

Now, we push this technique one step further for holomorphic functions. An important property of a holomorphic function is that the limit of its difference quotient (f(z+h) - f(z)) / h is the same no matter from which direction h approaches 0 in the complex plane. Therefore, if f is holomorphic, f'(z) can be calculated by measuring the function’s rate of change along the real-axis direction from z. In other words, if we define g(x) = f(z+x) for real x and a given complex z, then f'(z) = g'(0). Note that g : ℝ → ℂ can be differentiated with the trick described above. The following code implements this idea:

julia> f′compz(z) = (g(x) = f(x+z); derivative(x->real(g(x)),0) + im * derivative(x->imag(g(x)),0))
f′compz (generic function with 1 method)

julia> f′compz(2+2im) ≈ f′hand(2+2im)  # works for complex z
true
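The same idea can be wrapped in a reusable function. The following is a minimal sketch, not part of the original code: the name `holomorphic_derivative` and the test polynomial `p` are made up for illustration, and the result is only meaningful when `f` is holomorphic at `z`.

```julia
using ForwardDiff: derivative

# Hypothetical helper (name chosen for this sketch): derivative of a
# holomorphic f at a complex point z, taken along the real-axis direction.
function holomorphic_derivative(f, z)
    g(x) = f(z + x)                           # g : ℝ → ℂ, and g'(0) = f'(z)
    dre = derivative(x -> real(g(x)), 0.0)    # derivative of the real part at x = 0
    dim = derivative(x -> imag(g(x)), 0.0)    # derivative of the imaginary part at x = 0
    return complex(dre, dim)
end

f(z) = 1 / (z - (3 + 4im))
p(z) = z^3 + 2z                               # entire function, p'(z) = 3z^2 + 2

# Compare against the hand-derived derivatives
holomorphic_derivative(f, 2 + 2im) ≈ -1 / ((2 + 2im) - (3 + 4im))^2
holomorphic_derivative(p, 1 + 1im) ≈ 3 * (1 + 1im)^2 + 2
```

For a function that is not holomorphic, this helper would still return a number, but it would only be the partial derivative along the real axis, so it is up to the caller to know that `f` is holomorphic.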

I find this a simple and useful workaround for automatic differentiation of holomorphic functions with AD packages that do not support complex functions. Hopefully other people find my technique useful as well!

Why not:

julia> using ForwardDiff: Dual

julia> f(Dual(2,1) + im*Dual(2,0))
Dual{Nothing}(-0.2,0.12000000000000001) + Dual{Nothing}(0.4,0.16)*im

Could you please elaborate why that works and how you picked the dual numbers?

The dual numbers still satisfy f(a + b*ε) = f(a) + b*f'(a)*ε and so I just set b to 1.
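As a worked check of that identity for the f in this thread, here is a short derivation using only $\varepsilon^2 = 0$, with $c = 3 + 4i$ so that $f(z) = 1/(z - c)$:

$$
f(a + b\varepsilon)
= \frac{1}{(a - c) + b\varepsilon}
= \frac{(a - c) - b\varepsilon}{(a - c)^2 - b^2\varepsilon^2}
= \frac{1}{a - c} - \frac{b}{(a - c)^2}\,\varepsilon
= f(a) + b\,f'(a)\,\varepsilon .
$$

The input `Dual(2,1) + im*Dual(2,0)` is $(2 + 2i) + \varepsilon$, so $a = 2 + 2i$ and $b = 1$. Then $a - c = -1 - 2i$, which gives $f(a) = -0.2 + 0.4i$ and $f'(a) = 0.12 + 0.16i$. These are the value parts and the ε parts of the output shown above.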

Note that ForwardDiff.jl has Dual <: Real, so one makes a Complex{<:Dual}. Alternatively, you can use DualNumbers.jl, which allows Dual{<:Complex}:

julia> f(Dual(2+im,1))  # here Dual is DualNumbers.Dual, not ForwardDiff.Dual
-0.09999999999999999 + 0.3im + 0.08ɛ + 0.06imɛ

Finally you can also just auto-diff through a complex number constructor:

julia> f(xy::AbstractVector) = vcat(reim(f(complex(xy...)))...)
f (generic function with 2 methods)

julia> import ForwardDiff

julia> ForwardDiff.jacobian(f, [2.,2.])
2×2 Matrix{Float64}:
 0.12  -0.16
 0.16   0.12
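The structure of this Jacobian is what the Cauchy-Riemann equations predict for a holomorphic function, and f'(z) can be read off its first column. Here is a minimal sketch of that check; the wrapper name `F` is made up for this example so that `f` keeps a single method.

```julia
using ForwardDiff

f(z) = 1 / (z - (3 + 4im))

# Real 2-vector [x, y] ↦ [u, v], where f(x + iy) = u + iv.
# F is a hypothetical name used only in this sketch.
F(xy) = collect(reim(f(complex(xy[1], xy[2]))))

J = ForwardDiff.jacobian(F, [2.0, 2.0])   # [∂u/∂x ∂u/∂y; ∂v/∂x ∂v/∂y]

# Cauchy-Riemann equations: ∂u/∂x = ∂v/∂y and ∂u/∂y = -∂v/∂x
satisfies_cr = J[1, 1] ≈ J[2, 2] && J[1, 2] ≈ -J[2, 1]

# For a holomorphic f, the first column holds f'(z) = ∂u/∂x + i ∂v/∂x
fprime = complex(J[1, 1], J[2, 1])
```

If `satisfies_cr` came out false, the function would not be complex-differentiable at that point and `fprime` would not be a meaningful complex derivative.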

@dlfivefifty, I didn’t know about Dual! Thanks for the information. For my future reference, here is a way to get the derivative using Dual:

julia> using DualNumbers  # ForwardDiff not needed

julia> z₀ = rand(ComplexF64);  # arbitrary complex number

julia> f(Dual(z₀,1)).epsilon ≈ f′hand(z₀)
true

As a separate note, it seems that ForwardDiff.jl already has the capability to differentiate holomorphic functions via ForwardDiff.Dual, so I am not sure why ForwardDiff.derivative() does not support complex functions yet…
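To illustrate what that could look like, here is a minimal sketch built on `ForwardDiff.Dual` directly. The function name `dual_derivative` is made up for this example and is not part of ForwardDiff.jl; `Dual` and `partials` are not exported, so they are imported by name. As before, the result is only meaningful for holomorphic `f`.

```julia
using ForwardDiff: Dual, partials

# Hypothetical helper (not a ForwardDiff.jl API): seed the real part with
# perturbation 1 and the imaginary part with 0, i.e. evaluate f at z + ε.
function dual_derivative(f, z)
    x, y = reim(float(z))
    w = f(complex(Dual(x, one(x)), Dual(y, zero(y))))   # w isa Complex{<:Dual}
    # The ε parts of real(w) and imag(w) are the real and imaginary parts of f'(z)
    return complex(partials(real(w), 1), partials(imag(w), 1))
end

f(z) = 1 / (z - (3 + 4im))

dual_derivative(f, 2 + 2im) ≈ -1 / ((2 + 2im) - (3 + 4im))^2
```

This sketch assumes that `f` propagates the dual input to its output; a function that ignores its argument would return a plain complex number, and the `partials` calls would fail.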

In any case, developing the workaround gave me a great opportunity to think about complex analysis again. Hope other people enjoy it!

Is it possible to get a Hessian or a higher-order derivative of a complex function?

We had a big thread about this a while ago: Taking Complex Autodiff Seriously in ChainRules

The TL;DR is that holomorphic functions are not the only functions that exist and are widely used. They are actually a very small subset of the functions that exist and are used. It would be a really bad idea for AD packages to automatically assume that functions are holomorphic. Derivatives of complex functions are still perfectly well defined even if the function is not holomorphic.

The approach that won out was just computing pullbacks correctly, so that users could either construct the full Jacobian from the pullback, or a Wirtinger derivative (that is, $\frac{\partial}{\partial z}$) if they prefer. (Explanation of the math here: Taking Complex Autodiff Seriously in ChainRules - #52 by Mason)

If you do indeed want to just deal with $\frac{\partial}{\partial z}$ alone, you can do that like so:

using Zygote

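function ∂z(f, z)
  _, back = Zygote.pullback(f, z)
  # pull back the seeds 1 and im (sensitivities for the real and imaginary parts of the output)
  du, dv = back(1)[1], back(im)[1]
  # combine the two pullbacks into the Wirtinger derivative ∂f/∂z
  (du' + im*dv')/2
end

julia> ∂z(f, 2+2im)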
0.12000000000000002 + 0.16000000000000003im
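For reference, here is a sketch of why the combination `(du' + im*dv')/2` gives the Wirtinger derivative. It assumes, following the convention described in the Zygote docs, that for $f = u + iv$ and $z = x + iy$ the two pullbacks return the gradients of $u$ and $v$ packed as complex numbers, $du = u_x + i\,u_y$ and $dv = v_x + i\,v_y$:

$$
\frac{\partial f}{\partial z}
= \frac{1}{2}\left(\frac{\partial f}{\partial x} - i\,\frac{\partial f}{\partial y}\right)
= \frac{1}{2}\bigl[(u_x + v_y) + i\,(v_x - u_y)\bigr]
= \frac{1}{2}\bigl(\overline{du} + i\,\overline{dv}\bigr).
$$

If $f$ is holomorphic, the Cauchy-Riemann equations $u_x = v_y$ and $u_y = -v_x$ reduce this to $u_x + i\,v_x = f'(z)$, which is why the result above agrees with f′hand(2+2im).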

Zygote.jl has docs on this here: Complex Differentiation · Zygote
And ChainRules.jl has docs on it here: Complex numbers · ChainRules
