Taking Complex Autodiff Seriously in ChainRules

So these conversations have been starting up again recently on the Slack #autodiff channel and in places like https://github.com/JuliaDiff/ChainRulesCore.jl/issues/159.

I think there are a lot of misconceptions out there about the derivatives of functions of complex numbers, caused in part by confusing notation and in part by many people’s education about complex numbers focusing too heavily on the holomorphic case. However, in an AD system where we want to deal with general code, we cannot limit ourselves to talking only about holomorphic functions.

I’m posting this here for now so we can discuss the ideas and use LaTeX for math, but once these ideas get discussed a bit, it’d be good to turn this conversation into a GitHub issue / PR in ChainRules.jl.

Common background

First of all, we can always treat a function of a complex input f(z) as a function of two real inputs f(x, y) where z = x + iy. We can further write f(x, y) = u(x, y) + i v(x, y) where u and v are real valued, i.e. u(x, y) and v(x, y) are the real and imaginary parts of f(z).

From this perspective, a \mathbb{C} \rightarrow \mathbb{C} function is a function which takes in two inputs (x,~ y) and gives two outputs (u,~ v). Hence, the first derivatives of f can be expressed, if you like, as a Jacobian matrix

J = \begin{pmatrix} {\partial u / \partial x} & {\partial u / \partial y} \\ {\partial v / \partial x} & {\partial v / \partial y} \end{pmatrix}

Holomorphic functions are a special class of functions for which the Cauchy-Riemann equations hold:

\frac{\partial u}{\partial x} = \frac{\partial v}{\partial y}\\ \frac{\partial u}{\partial y} = -\frac{\partial v}{\partial x}

which means that there is a symmetry (scarcity) in the above Jacobian matrix and not all elements need to be computed. This is not true of arbitrary complex functions, but the symmetry does give optimization opportunities and makes our lives easier, so I think it’s worth respecting.

Explicitly, for a holomorphic function f, the Jacobian can be written

J = {\partial u \over \partial x}\begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} + {\partial u \over \partial y}\begin{bmatrix} 0 & 1\\ -1 & 0 \end{bmatrix}
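
As a concrete check, take f(z) = z^2, so u = x^2 - y^2 and v = 2xy, giving \partial u / \partial x = 2x and \partial u / \partial y = -2y:

J = 2x\begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} - 2y\begin{bmatrix} 0 & 1\\ -1 & 0 \end{bmatrix} = \begin{bmatrix} 2x & -2y \\ 2y & 2x \end{bmatrix}

which is exactly the 2x2 real representation of f'(z) = 2z.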

Wirtinger Derivatives

People often like to define two new derivative operators (autodiff people like to call these “Wirtinger derivatives”)

{\partial \over \partial z} = \frac{1}{2}\left({\partial \over \partial x} - i {\partial \over \partial y}\right)\\ {\partial \over \partial \bar{z}} = \frac{1}{2}\left({\partial \over \partial x} + i {\partial \over \partial y}\right)

and now for a holomorphic function f, we have \partial f(z)/ \partial \bar{z} = 0 identically. This is why people sometimes say that a holomorphic function is a function with a single degree of freedom even though it appears to live in a two-dimensional space (the complex plane), and they will write non-holomorphic functions as f(z, \bar{z}) to signify that f depends not only on z, but also on its complex conjugate \bar{z}.
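
For a concrete non-holomorphic example, take f(z) = z\bar{z} = x^2 + y^2 (i.e. abs2). Then

{\partial f \over \partial z} = \frac{1}{2}\left(2x - 2iy\right) = \bar{z}, \qquad {\partial f \over \partial \bar{z}} = \frac{1}{2}\left(2x + 2iy\right) = z

so \partial f / \partial \bar{z} \neq 0 away from the origin, and one would write f(z, \bar{z}).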

In terms of these Wirtinger derivatives, our above Jacobian is written

J =Re\left({\partial f \over\partial z}\right) \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} + Im\left({\partial f \over \partial z}\right) \begin{bmatrix} 0 & -1 \\ 1 & 0 \end{bmatrix} + Re\left({\partial f \over \partial \bar{z}}\right) \begin{bmatrix} 1 & 0 \\ 0 & -1 \end{bmatrix} + Im\left({\partial f \over \partial \bar{z}}\right) \begin{bmatrix} 0 & 1 \\ 1 & 0 \end{bmatrix}

There are nice geometric interpretations of this representation, but I’ll refrain.

We can do a unitary transformation on this Jacobian matrix to get

J = \begin{bmatrix} {\partial f / \partial z} \\ {\partial f / \partial \bar{z}} \end{bmatrix}

(It’s possible we really want the transpose of this thing? I’m not sure yet, open to suggestions.) Note that this is a 2 element complex vector instead of a 2x2 real matrix, but it contains the same amount of information. At this point, you might prefer to call this a gradient instead of a Jacobian. :man_shrugging:

Currently, ChainRules.jl seems to only handle the holomorphic part ({\partial f / \partial z}) of the Jacobian, but not the anti-holomorphic part ({\partial f / \partial \bar{z}}). For instance,

julia> using ChainRules

julia> rrule(sqrt, 1 + im)[2](1)
(Zero(), 0.3884434935075093 - 0.16089856322639565im)

julia> rrule(abs2, 1 + im)[2](1)
(Zero(), 2 + 2im)

If I understand correctly, this is saying that ChainRules doesn’t provide any indication that abs2 is not a holomorphic function, and in general any derivatives that are obtained from functions involving abs2 will be incorrect.
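
As a quick sanity check, here’s what plain central finite differences (a sketch, independent of ChainRules) say about abs2 at z = 1 + im:

# abs2(x + iy) = x^2 + y^2, so at z = 1 + im we expect ∂/∂x = 2x = 2 and
# ∂/∂y = 2y = 2: two independent real numbers that a single
# holomorphic-style derivative cannot encode.
z, h = 1.0 + 1.0im, 1e-6
∂_∂x = (abs2(z + h) - abs2(z - h)) / 2h        # ≈ 2.0
∂_∂y = (abs2(z + im*h) - abs2(z - im*h)) / 2h  # ≈ 2.0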

Proposal?

I think that as nice as the Wirtinger representation can be conceptually, there are real advantages to working in terms of x,~y,~u,~v. It is closer to how Julia actually represents complex numbers internally (we store the real and imaginary parts) and makes it much more explicit that we are really talking about the Jacobian of a function with two inputs and two outputs, not some scalar derivative like the Wirtinger derivative of a holomorphic function might lead you to believe. In the x,~y,~u,~v representation, the property of being holomorphic is an example of Jacobian scarcity (@ChrisRackauckas is the person in these parts to talk to about scarcity as far as I’m aware).

However, when I started cooking up some sample code to illustrate this proposal, I realized that writing the rrule for, say, cos(z::Complex) was actually quite ugly in terms of x,~y,~u,~v. I think the answer is to provide Wirtinger-style interfaces for the rules, but not to represent the data that way.

Here’s a toy example of the interface I’m imagining. The idea is that ComplexJacobians are basically 2x2 StaticArrays, but we also define a constructor to convert a HolomorphicJacobian to a single Complex number if that’s needed.

using StaticArrays

abstract type AbstractComplexJacobian{T} <: StaticArray{Tuple{2, 2}, T, 2} end

function Base.show(io::IO, ::MIME"text/plain", cj::AbstractComplexJacobian)
    println(io, typeof(cj))
    println(io, " ∂u_∂x=$(cj.∂u_∂x)  ∂u_∂y=$(cj.∂u_∂y)")
    println(io, " ∂v_∂x=$(cj.∂v_∂x)  ∂v_∂y=$(cj.∂v_∂y)")
end

Base.size(::AbstractComplexJacobian) = (2, 2)
Base.length(::AbstractComplexJacobian) = 4

struct HolomorphicJacobian{T} <: AbstractComplexJacobian{T}
    ∂u_∂x::T
    ∂v_∂x::T
end
Complex(hj::HolomorphicJacobian{T}) where {T<:Real} = hj.∂u_∂x + im*hj.∂v_∂x

struct ComplexJacobian{T} <: AbstractComplexJacobian{T}
    ∂u_∂x::T
    ∂v_∂x::T
    ∂u_∂y::T
    ∂v_∂y::T
end

function Base.getproperty(hj::HolomorphicJacobian, s::Symbol)
    if s === :∂u_∂y
        -hj.∂v_∂x
    elseif s === :∂v_∂y
        hj.∂u_∂x
    else
        getfield(hj, s)
    end
end

function Base.getindex(cj::AbstractComplexJacobian, i::Int)
    s = (:∂u_∂x, :∂v_∂x, :∂u_∂y, :∂v_∂y)[i]
    getproperty(cj, s)
end

function Base.getindex(cj::AbstractComplexJacobian, i::Int, j::Int)
    s = (:∂u_∂x, :∂v_∂x, :∂u_∂y, :∂v_∂y)[i + 2(j-1)]
    getproperty(cj, s)
end

function wirtinger(::Type{HolomorphicJacobian})
    function (∂f_∂z)
        HolomorphicJacobian(real(∂f_∂z), imag(∂f_∂z))
    end
end 

function wirtinger(::Type{ComplexJacobian})
    function (∂f_∂z, ∂f_∂z̄)
        # Invert the Wirtinger relations (with the 1/2 convention above):
        #   ∂u/∂x = Re(∂f/∂z) + Re(∂f/∂z̄),   ∂v/∂x = Im(∂f/∂z) + Im(∂f/∂z̄)
        #   ∂u/∂y = Im(∂f/∂z̄) - Im(∂f/∂z),   ∂v/∂y = Re(∂f/∂z) - Re(∂f/∂z̄)
        ComplexJacobian(real(∂f_∂z) + real(∂f_∂z̄),
                        imag(∂f_∂z) + imag(∂f_∂z̄),
                        imag(∂f_∂z̄) - imag(∂f_∂z),
                        real(∂f_∂z) - real(∂f_∂z̄))
    end
end

julia> wirtinger(HolomorphicJacobian)(1 + im)
HolomorphicJacobian{Int64}
 ∂u_∂x=1  ∂u_∂y=-1
 ∂v_∂x=1  ∂v_∂y=1


julia> HolomorphicJacobian(1, 2)
HolomorphicJacobian{Int64}
 ∂u_∂x=1  ∂u_∂y=-2
 ∂v_∂x=2  ∂v_∂y=1


julia> ComplexJacobian(1, 2, 3, 4)
ComplexJacobian{Int64}
 ∂u_∂x=1  ∂u_∂y=3
 ∂v_∂x=2  ∂v_∂y=4

Here are some sample chain rules using this interface:

function ChainRules.rrule(::typeof(cos), z::Complex)
    sinz, cosz = sincos(z)
    ∂cos(Δz) = Δz * wirtinger(HolomorphicJacobian)(-sinz)
    return cosz, ∂cos
end
function ChainRules.rrule(::typeof(sin), z::Complex)
    sinz, cosz = sincos(z)
    ∂sin(Δz) = Δz * wirtinger(HolomorphicJacobian)(cosz)
    return sinz, ∂sin
end

function ChainRules.rrule(::typeof(abs2), z::Complex)
    # abs2(z)  = z*z̄
    # ∂abs2/∂z = z̄
    # ∂abs2/∂z̄ = z
    ∂abs2(Δz) = Δz * wirtinger(ComplexJacobian)(conj(z), z)
    return abs2(z), ∂abs2
end

One potential advantage of this representation is that for \mathbb{C}^n \rightarrow \mathbb{C}^m functions, we can easily choose at the level of the chain rule whether we want to return a Matrix{<:ComplexJacobian} or a ComplexJacobian{<:Matrix}.

Sorry for the long-winded, and yet still too naive, post. Let me know what you think! (cc @oxinabox)

It might be useful to mention a few applications where non-holomorphic / Jacobian derivatives are needed. Off the top of my head, I cannot think of any.

I feel one major problem with this kind of complex AD is that the output is awkward to use, since one cannot call BLAS/LAPACK on Matrix{<:AbstractComplexJacobian}. For instance, one cannot do Newton iteration easily.

I wonder if complex AD should error if the function is not holomorphic, and to handle the other cases, the AD package can provide convenience functions to convert a complex function to a 2N real function.

There are a few options there. For instance, this wouldn’t be a problem if you’re able to get to a structs of arrays approach, but that can be inefficient sometimes if we’re not clever.

Perhaps this is a good argument for not using the representation I showed above and instead just returning the complex number ∂f_∂z for analytic functions and

SVector{2}(∂f_∂z, ∂f_∂z̄)

for non-analytic functions? One thing I don’t particularly like about this approach is that it privileges holomorphic functions over anti-holomorphic functions (ones for which \partial f/\partial z = 0 but \partial f / \partial \bar{z} \neq 0), but I suppose few people care about such functions, and this way users at least won’t encounter weird types if they’re just taking derivatives of holomorphic functions.

Another thought though: a very common point of confusion in complex AD (and the thing that spurred me to write this post in the first place) is that reverse-mode AD tools like Zygote will just give you a complex number, but what users don’t realize is that they’ve actually gotten the adjoint of the thing they think they’re calculating.

By always returning an abstract array, we can make it very explicit that they’re getting an Adjoint object because we can wrap it as such.

I am not sure introducing the AbstractComplexJacobian type helps anything. If one tries to implement this, an immediate problem is that the type of the output container cannot be easily determined.

Wouldn’t this be using tremendous engineering effort to fix an edge case, with no speed gain? If the function is not complex differentiable, then splitting the system into reals is perhaps not too much to ask.

With the approach of using AbstractComplexJacobian and friends, when hitting a nonholomorphic function, you can either convert everything into a “larger” type like ComplexJacobian, or you can make the array heterogeneous, which is definitely not BLAS/LAPACK compatible. Both solutions make everything type unstable, and even in the best case, when the small-union optimization triggers, it will introduce many branches into the program. I don’t think some possible savings in memory justify the cost. Worse still, it doesn’t actually save computations compared to always using a 2 x 2 memory block (i.e. saving computations whenever possible, but always using the maximum space).

I agree with you that in an ideal world, one would do this with the help of a very smart and fast compiler to make things look nicer, and use generic linear algebra packages that are optimized and tuned on multiple architectures. However, I don’t think there is a way to make a “type splitting”-like approach happen practically. IMO, if we want to have a complex AD now, then we should implement the “error on nonholomorphic” or the “always use ComplexJacobian (2 x 2)” approach first.

Better infrastructure for automatically building structs of arrays would be hugely beneficial for lots of stuff outside autodiff. It’d be really useful for things like LoopVectorization.jl and people doing linear algebra with custom struct types.

Why would we end up with abstract containers? The only case where that’d happen is if you had a function where it wasn’t statically knowable if it was holomorphic or not, which sounds like a very weird case.

With tons of inference issues in Cassette (and maybe also many cases in Zygote?), I kind of think “everything that can theoretically be derived by the compiler is static” is also very weird. :stuck_out_tongue: Maybe I am very wrong and too pessimistic, of course.

If inference issues caused the output type to be inferred as AbstractComplexJacobian, then why would it be any better at figuring out that the output type was Complex?

Wouldn’t the current behaviour have the exact same problem?

The problem is that HolomorphicJacobian could be promoted to ComplexJacobian during the propagation. That not only changes the type of the “number” but also changes the container type of the Jacobian/gradient.
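
Concretely, the widening this forces would look something like the following (a hypothetical sketch, not something defined in the proposal above):

# Mixing the two Jacobian types must widen to the general one, which changes
# both the "number" type and the element type of any container holding them.
Base.promote_rule(::Type{HolomorphicJacobian{T}}, ::Type{ComplexJacobian{S}}) where {T,S} =
    ComplexJacobian{promote_type(T, S)}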

As @YingboMa hinted, this can happen if you’re trying to find extrema of a complex (non-analytic) function, say f(z) = -z'z. Current AD tools would fail to give the right answer.

There are many examples though: basically any code that does something non-analytic, e.g. asking for the real part of a complex number, could be wrong, and since we are going for ‘whole-language AD’, these should be handled.

Couldn’t you say the same about Complex if it’s interacting with other <:Numbers?

Of course.

julia> mysqrt(x) = x < 0 ? sqrt(complex(x)) : sqrt(x)
mysqrt (generic function with 1 method)

julia> foo(n) = map(mysqrt, -10:n)
foo (generic function with 1 method)

julia> foo(1)
12-element Array{Number,1}:
 0.0 + 3.1622776601683795im
 0.0 + 3.0im
 0.0 + 2.8284271247461903im
 0.0 + 2.6457513110645907im
 0.0 + 2.449489742783178im
 0.0 + 2.23606797749979im
 0.0 + 2.0im
 0.0 + 1.7320508075688772im
 0.0 + 1.4142135623730951im
 0.0 + 1.0im
     0.0
     1.0

This is analogous to functions like foo(x) = abs(x) > 2 ? conj(x) : x, whose derivative can be a HolomorphicJacobian or a ComplexJacobian depending on the input value.

Given f : \mathbb{C} \to \mathbb{C}, I do think f'(z) should be a function \mathbb{C} \to \mathbb{C} and hence undefined if f(z) is not holomorphic. This is the right thing to do e.g. for Newton’s method, Taylor series approximation, differential equations, etc.

If the Jacobian is what you want, then you can always do g(x,y) = (z = complex(x,y); [real(f(z)), imag(f(z))]) and compute the Jacobian of g. This way, no special derivative types are needed, and at least forward differentiation should work out of the box. Of course, we can provide a convenience function as_R2_fun(f) (better name suggestions welcome) to make it easier to assemble such a g.
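
For example, here’s a minimal sketch of that recipe applied to abs2 (assuming ForwardDiff for the real Jacobian):

using ForwardDiff

f(z) = abs2(z)  # non-holomorphic
g(xy) = (z = complex(xy[1], xy[2]); [real(f(z)), imag(f(z))])
ForwardDiff.jacobian(g, [1.0, 1.0])  # 2×2 real Jacobian; here [2.0 2.0; 0.0 0.0]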

What Mason proposed is to use ComplexJacobian or HolomorphicJacobian to hold the Jacobian information.

My concern is that this kind of breaks down when the function is vector-valued, compiler cannot reason about it efficiently, and the output container is hard to determine and hard to use.

Any C^N → R optimization problem. Also any code where there’s an intermediate step of separating real and complex, even if the final result is holomorphic.

I feel extremely strongly that the first duty of an AD code is “primum non nocere”: if you don’t know how to differentiate something, give up, but don’t silently return a wrong answer. I have zero knowledge of AD systems, but it would seem to me that it would be simplest to just treat complex numbers as structs (re, im) and that’s it. Interfaces can be built on top if necessary.

The set of functions resulting in non-analytic behavior is relatively small: essentially it’s conj and any function C → R (real, imag (including accessors), abs, norm, …), so for a package like ChainRules it could make sense to expose high-level functionality that assumes analyticity, and hardcode the “special cases” at the low level. Any fallback should definitely not assume analyticity though.

No, for a general function f: \mathbb{C} \rightarrow \mathbb{C} you can easily construct non-holomorphic functions, and they will have well-defined things like Taylor series so long as you keep track of your degrees of freedom.

Consider f(z) = z + 2\bar{z}. This takes in a complex number and puts out a complex number. Derivatives are well-defined if you use either the pair (\partial/\partial x, \partial/\partial y) or (\partial/\partial z, \partial/\partial \bar{z}) (or any other unitarily equivalent set); it’s just that this function can’t be treated as a function of a single variable the way holomorphic functions can.
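
Concretely, f(z) = z + 2\bar{z} = 3x - iy, so u = 3x and v = -y, and

{\partial f \over \partial z} = 1, \qquad {\partial f \over \partial \bar{z}} = 2, \qquad J = \begin{bmatrix} 3 & 0 \\ 0 & -1 \end{bmatrix}

Every entry is perfectly well-defined; there’s just no single complex number that encodes J.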

My concern is that this kind of breaks down when the function is vector-valued, compiler cannot reason about it efficiently, and the output container is hard to determine and hard to use.

I fully agree. Making this work with a generic package for e.g. Newton’s method would be highly non-trivial, as far as I can tell.

I believe that essentially means always defaulting to assuming things are non-holomorphic. This is a reasonable thing to do, but it will cost a factor of 2 in performance / memory, so it’d be good to have holomorphic support if possible.


Edit: I see now it was you in fact who posted on the Slack autodiff channel about this factor of two being not such a big deal. I guess we’re on the same page.

Since you have the experience with C^N → R optimization problems, how do you interpret and use a derivative of a function that’s not complex differentiable? What should an AD system return to be the most helpful for your use case?

Just treat the function like a vector in \mathbb{R}^2 for the most part (with some convenient products).
