Taking Complex Autodiff Seriously in ChainRules

Ah yes, I missed that; it is an interesting way to do it. I guess we could do the gradient the way they do it but also supply a Jacobian that gives the real deal.

I’m not sure I agree with

though, but I suppose if there’s a place to enforce holomorphic assumptions, gradient would be an okay place to do it.

Distinguishing between gradients and derivatives is a very sensible thing to do.

For illustration, consider |x|:

  • Its gradient is \nabla |x| = \tfrac{x}{|x|} \in \mathbb{C} such that an expression like x - \nabla |x| makes sense (think of this as a gradient descent step).
  • Its derivative is d_x|\cdot| (\delta x) = \tfrac{1}{|x|} (\text{Re}(x) \, \text{Re}(\delta x) + \text{Im}(x) \, \text{Im}(\delta x)) \in \mathbb{R} such that an expression like |x + \delta x| = |x| + d_x | \cdot |(\delta x) + \mathcal{O}(|\delta x|^2) makes sense (think of this as a Taylor series).
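
To make the distinction concrete, here is a small numerical sketch of the two bullets above. The helper names `grad_abs`, `dabs` and the step size `eta` are made up for this illustration; nothing here comes from ChainRules or any AD package.

```julia
# Gradient and directional derivative of abs at a complex point,
# written exactly as in the two bullets above (hypothetical helper names).
grad_abs(x) = x / abs(x)
dabs(x, dx) = (real(x) * real(dx) + imag(x) * imag(dx)) / abs(x)

x = 3.0 + 4.0im
dx = 1e-4 * (1.0 - 2.0im)

# Taylor check: the first-order model should be off by O(|dx|^2) only.
taylor_err = abs(abs(x + dx) - (abs(x) + dabs(x, dx)))

# Descent step: moving against the gradient shrinks |x| by the step length eta.
eta = 0.1
x_new = x - eta * grad_abs(x)

println("Taylor error: ", taylor_err)
println("abs before and after the step: ", abs(x), " ", abs(x_new))
```

The two objects are linked by dabs(x, dx) = Re(conj(grad_abs(x)) * dx), which is why one lives in \mathbb{C} and the other in \mathbb{R}.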

Both (co-)gradients as steepest-ascent directions and derivatives for Taylor-like expansions stem from the same thing in CR calculus, though notations vary. You can write the multidimensional “Taylor” expansion f(x+\delta) = f(x) + \bar{\delta}\cdot\nabla_{\bar{x}} f + \delta\cdot\nabla_x f + O(\delta^2), and if f is real-valued then \nabla_x f = \overline{\nabla_{\bar{x}} f} and \nabla_{\bar{x}} f = (\frac{\partial f}{\partial \bar{x}_1}, \ldots, \frac{\partial f}{\partial \bar{x}_n}) is the steepest-ascent direction in x. (If f is complex-valued then you can’t talk about “ascent” or “descent.”) In that real-f case, however, since f(x+\delta) = f(x) + 2\Re[\bar{\delta}\cdot\nabla_{\bar{x}} f] + O(\delta^2), you need to pass 2\nabla_{\bar{x}} f as the “gradient” to any code expecting the gradient in terms of the real and imaginary parts (e.g. an optimization routine that only knows about real vectors).
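
As a rough sanity check of the factor of 2, here is a sketch for the specific choice f(x) = \sum_i |x_i|^2, for which \nabla_{\bar{x}} f = x. The names `cost` and `grad_conj` and the test values are assumptions made for this example only.

```julia
using LinearAlgebra

cost(x) = sum(abs2, x)        # real-valued, non-holomorphic test function
grad_conj(x) = x              # ∇_{x̄} f for this particular f (since |xᵢ|² = xᵢ x̄ᵢ)

x = [1.0 + 2.0im, -0.5 + 0.25im]
dx = 1e-5 .* [0.3 - 0.7im, 1.1 + 0.4im]

# First-order term 2 Re[δ̄ ⋅ ∇_{x̄} f]; dot conjugates its first argument.
first_order = 2 * real(dot(dx, grad_conj(x)))
remainder = cost(x + dx) - cost(x) - first_order   # equals |δ|² for this f

# What a real-variable optimizer expects: 2∇_{x̄} f, read as (∂f/∂Re x, ∂f/∂Im x).
g_real = 2 .* grad_conj(x)
h = 1e-6
fd_re = (cost(x + [h, 0]) - cost(x - [h, 0])) / (2h)         # ∂f/∂Re(x₁)
fd_im = (cost(x + [h * im, 0]) - cost(x - [h * im, 0])) / (2h) # ∂f/∂Im(x₁)

println(remainder, " vs ", sum(abs2, dx))
println((fd_re, fd_im), " vs ", (real(g_real[1]), imag(g_real[1])))
```

The finite differences in the real and imaginary directions match the real and imaginary parts of 2\nabla_{\bar{x}} f, not of \nabla_{\bar{x}} f itself.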

Since CR calculus is widely used and fairly standardized, I don’t think we need to get into metaphysical arguments about whether gradients are “sensible” for non-holomorphic functions. They are undeniably useful, and the only question in my mind is what interface to present.

I am starting to realise that non-holomorphic functions are important in some fields (mostly electrical engineering, afaict), but so are holomorphic functions in other fields (e.g. complex analysis and approximation theory, which is where I come from). I have a feeling that the discussion so far heavily favours the former. I can imagine that this is somewhat sensible since there are many more electrical engineers than there are people working on complex analysis, but I do hope we can find a framework which supports non-holomorphic derivatives without breaking (or at least severely complicating) the holomorphic case.

The challenge then seems to be that for f : \mathbb{C} \to \mathbb{C}, we want f' : \mathbb{C} \to \mathbb{C} in the holomorphic case, but f' : \mathbb{C} \to \mathbb{C}^2 or f' :\mathbb{C} \to \mathbb{R}^{2 \times 2} or any other variation of this theme in the non-holomorphic case. I still believe that the “default” (e.g. in an expression like f'(x)) should be to assume that f(x) is holomorphic and return the holomorphic derivative. If not, I cannot write generic code for dealing with a function f(z::T) -> T where T <: Number, since then f'(z) would be of type T if T <: Real but f'(z) would be something else if T <: Complex.

Actually, Autograd’s convention, Wirtinger derivatives and holomorphic derivatives are equivalent in all cases where the various derivatives are defined. So the only thing that is needed is a system for turning on and off whether to compute and return \tfrac{\partial f}{\partial \bar z}.

As far as I can tell, the only clean way of implementing this is to have frule and rrule return \nabla_z f, and then provide separate wirtinger_frule and wirtinger_rrule functions which return (\nabla_z f, \nabla_{\bar z} f).

Apologies, I should have been more precise – I definitely wasn’t trying to engage in a metaphysical debate :slight_smile: . I was just recapitulating the point that you generally need four real numbers to talk about the derivative of a \mathbb{C} \to \mathbb{C} function f at a point, rather than two.

What I am definitely not debating is the need for non-holomorphic chain rules; the need for such rules is very clear. Moreover, I completely agree with the idea that an AD tool really has to assume that everything is non-holomorphic by default, while potentially allowing a user to specify that they believe the function they’re asking for the pushforward / pullback of is holomorphic, and that they would therefore like an answer that assumes that. (Such an assumption probably won’t be able to avoid much computation: as has been discussed previously in this thread, it’s perfectly possible for a holomorphic function to be a composition of non-holomorphic functions.)

@ChrisRackauckas made the central point here, which I’m going to rephrase slightly: it’s helpful to forget about derivatives and gradients for now, and focus on pushforwards / Jacobian-vector products and pullbacks / vector-Jacobian products. These are the fundamental operations with which an AD tool / set of rules for an AD tool (ChainRules) is concerned, as gradients / derivatives follow from these. The criteria for choosing how they should be defined should probably consider the implications for special cases of interest, such as holomorphic \mathbb{C} \to \mathbb{C} functions, and the interface that an AD tool implements for the gradient of a \mathbb{C}^N \to \mathbb{C} or \mathbb{C}^N \to \mathbb{R} function (implemented in terms of a pullback / vector-Jacobian product).

As has already been discussed, one way of tackling these is just by treating complex numbers and their tangents as length-2 vectors, and just naively deriving everything from there. Autograd / JAX don’t appear to do this completely naively, which is why I linked to their conventions. So my question here is whether or not these are sufficient / ideal, or whether there is room for improvement. Is there some tangible advantage to the choices that JAX / Autograd make? From what Chris discusses above, perhaps it doesn’t really matter :man_shrugging:.

Again, I’ve not had time to properly think through everything here, but it’s important that the discussion is centred around how the operations that ChainRules / AD tools actually implement should be defined.

In a case \mathbb{C}^m \to \mathbb{R}, intermediate functions will generically be of the form \mathbb{C}^p \to \mathbb{C}^s. The pullbacks will take \mathbb{C}^s \to \mathbb{C}^p, but these don’t need any special structure; they should do the same thing as the equivalent \mathbb{R}^{2s} \to \mathbb{R}^{2p}. In such a case, one shouldn’t need a complex Jacobian, and many reverse-mode rules can be implemented the same way for \mathbb{C}^p \to \mathbb{C}^s as for \mathbb{R}^p \to \mathbb{R}^s (see e.g. Zygote’s rules).
Perhaps it’s my bias as someone who only needs Jvps for \mathbb{R}^m \to \mathbb{C}^s \to \mathbb{R}, but it would be a shame if the result of this was that rules that are perfectly fine for such complex differentiation need to be completely rewritten to support this more general case. Especially when IIUC, one can get the complex Jacobian or Wirtinger derivatives from such rules with just two backwards passes (see e.g. https://fluxml.ai/Zygote.jl/stable/complex/).
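
A minimal sketch of the "same thing as the equivalent \mathbb{R}^{2s} \to \mathbb{R}^{2p}" claim, for the linear map y = A x with complex A. It assumes the pullback convention \bar{x} = A^\dagger \bar{y} (the same formula as in the real case, with the adjoint in place of the transpose); the matrix and cotangent values are arbitrary test data, and no AD package is used.

```julia
# Complex linear map y = A*x with p = 2 inputs and s = 3 outputs.
A = [1.0 0.5; -0.3 2.0; 0.0 -1.0] + im * [2.0 -1.0; 0.0 1.5; 1.0 -1.0]
ybar = [0.2 - 0.1im, 1.0 + 0.5im, -0.7 + 0.3im]   # incoming cotangent in C^s

# Complex rule, written like the real rule but with the adjoint.
xbar_complex = A' * ybar

# The same map seen as R^{2p} -> R^{2s}: [Re y; Im y] = M * [Re x; Im x].
M = [real(A) -imag(A); imag(A) real(A)]
xbar_real = M' * [real(ybar); imag(ybar)]          # plain real transpose pullback

println(xbar_complex)
println(xbar_real)
```

The real pullback returns the real and imaginary parts of the complex one stacked on top of each other, so for a rule like this no separate complex Jacobian is needed.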

Here’s a proposal for how to define complex chain rules:

  1. By default, frule and rrule assume f to be holomorphic and return the holomorphic derivative.
  2. Wirtinger derivatives can be requested by wrapping the corresponding argument in a Wirtinger wrapper type. If done so, then:
     • frule will accept a pair (dz_dw,dz_dwb) as the corresponding input derivative and return a pair (df_dw,df_dwb) as the corresponding output derivative, and
     • the pullback returned by rrule will accept a pair (dv_df,dv_dfb) as the corresponding input derivative and return a pair (dv_dz,dv_dzb) as the corresponding output derivative.

Throughout this post, a suffix b stands for “bar”, i.e. wb = \bar w.

For illustration, the chain rules for sin would be implemented as follows.

# Holomorphic default: a single complex derivative goes in and comes out.
frule((_,dz_dw), ::typeof(sin), z::Number) = sin(z), cos(z)*dz_dw
rrule(::typeof(sin), z::Number) = sin(z), dv_df->(NO_FIELDS, dv_df*cos(z))

# Wirtinger-wrapped argument: pairs (d/dw, d/dwb) go in and come out.
frule((_,(dz_dw,dz_dwb)), ::typeof(sin), z::Wirtinger{<:Number}) =
    sin(remove_wirtinger(z)), (Zero(), cos(remove_wirtinger(z)).*(dz_dw,dz_dwb))
rrule(::typeof(sin), z::Wirtinger{<:Number}) =
    sin(remove_wirtinger(z)), ((dv_df,dv_dfb),)->(NO_FIELDS, (dv_df*cos(remove_wirtinger(z)), dv_dfb*conj(cos(remove_wirtinger(z)))))

There are a couple of nice things about this approach:

  • Nonholomorphic derivatives are represented as pairs of Complex or Array{Complex}, i.e. no special numerical type needs to be introduced. This should avoid any compatibility issues with other libraries (most notably BLAS).
  • We can easily extend the framework to compute \frac{\partial f}{\partial x} and \frac{\partial f}{\partial y} if that is desired.
  • Complex and real derivatives can be mixed. Example:
f(a::Real,z::Complex) = real(a*z)
derivative(f, a::Real, z::Wirtinger) = (real(remove_wirtinger(z)), (a/2,a/2))
  • Holomorphic derivatives are as easy as they should be.
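
The values in the mixed real/complex example can be checked with finite differences. This is only a sketch: `wirtinger_fd` is a hypothetical helper written for this post (it is not part of ChainRules), and it uses \tfrac{\partial f}{\partial z} = \tfrac12(\tfrac{\partial f}{\partial x} - i \tfrac{\partial f}{\partial y}) and \tfrac{\partial f}{\partial \bar z} = \tfrac12(\tfrac{\partial f}{\partial x} + i \tfrac{\partial f}{\partial y}).

```julia
# Finite-difference Wirtinger derivatives of f at z (hypothetical helper).
function wirtinger_fd(f, z; h = 1e-6)
    fx = (f(z + h) - f(z - h)) / (2h)            # ∂f/∂x
    fy = (f(z + im * h) - f(z - im * h)) / (2h)  # ∂f/∂y
    return (fx - im * fy) / 2, (fx + im * fy) / 2
end

f(a::Real, z::Complex) = real(a * z)

a = 3.0
z = 1.5 - 0.5im

df_dz, df_dzb = wirtinger_fd(w -> f(a, w), z)        # expected (a/2, a/2)
df_da = (f(a + 1e-6, z) - f(a - 1e-6, z)) / 2e-6     # expected real(z)

println((df_da, (df_dz, df_dzb)))
```

Up to rounding this reproduces (real(z), (a/2, a/2)), i.e. one real derivative for a and a Wirtinger pair for z.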

I suspect that the last point is more important to me than to most other people, and that there will be some debate about whether or not to make holomorphic derivatives the default; hence I would like to point out that my proposal of making holomorphic derivatives the default is based not only on my personal preferences but also on some technical arguments.

As an alternative to the above, we could define that complex arguments automatically imply Wirtinger derivatives, but this leads to the following issue: it makes perfect sense to pass Wirtinger derivatives through \mathbb{R} \to \mathbb{R} functions (e.g. when computing the derivative of a function like \log(|z|)). Without the Wirtinger argument wrapper, every frule and rrule would thus have to be prepared to propagate either a single real derivative or a pair of Wirtinger derivatives, including all \mathbb{R} \to \mathbb{R} frules and rrules. Making Wirtinger derivatives the default would hence require that all of ChainRules is aware of Wirtinger derivatives and handles them appropriately, which I think is not reasonable.

After some reflection, I think the current Zygote behaviour is the right way to go. It supports pretending everything is holomorphic, but also defines the rules correctly such that you can get out the full Jacobian if you want / need. Following @sethaxen’s link to the Zygote docs, we see that the way it works is that if you do

using Zygote
y, back = Zygote.pullback(abs2, 1.0 + im)

julia> back(1), back(im)
((2.0 + 2.0im,), (0.0 + 0.0im,))

this is the full Jacobian. back(1) is asking for J'*[1, 0] and back(im) is asking for J'*[0, 1], i.e. the two rows of the real 2x2 Jacobian J, and we can then form the Wirtinger derivatives via

du, dv = back(1)[1], back(im)[1]
(du' + im*dv')/2, (du + im*dv)/2
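
For anyone wondering where that last formula comes from, here is a short derivation. It assumes f = u + i v with real 2x2 Jacobian J = \begin{pmatrix} u_x & u_y \\ v_x & v_y \end{pmatrix}, and that back(v) returns J^T [\Re v, \Im v] repackaged as a complex number; the subscript notation is mine.

```latex
\begin{aligned}
du &= \mathrm{back}(1) = u_x + i\,u_y, \qquad
dv = \mathrm{back}(i) = v_x + i\,v_y, \\
\tfrac12\left(\overline{du} + i\,\overline{dv}\right)
  &= \tfrac12\left[(u_x + v_y) + i\,(v_x - u_y)\right]
   = \tfrac{\partial f}{\partial z}, \\
\tfrac12\left(du + i\,dv\right)
  &= \tfrac12\left[(u_x - v_y) + i\,(v_x + u_y)\right]
   = \tfrac{\partial f}{\partial \bar z}.
\end{aligned}
```

For abs2 at 1 + i this gives \tfrac{\partial f}{\partial z} = 1 - i = \bar z and \tfrac{\partial f}{\partial \bar z} = 1 + i = z, as expected for f = z \bar z.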

To my surprise, this actually works on @oxinabox’s ChainRules branch of Zygote, even though ChainRules specifically defines

@scalar_rule abs2(x) 2x

Is there something I’m misunderstanding here or is Zygote somehow skipping that chain-rule on this branch?

Yeah, the PR skips abs and abs2 and uses Zygote’s: https://github.com/FluxML/Zygote.jl/blob/bf913a2a8ed616242e2f5378fbe598b289dd550a/src/lib/number.jl#L26-L30

I see why gradient(f,x) does the right thing for the cases 1) f holomorphic and 2) f(::Complex)::Real, but I don’t see why back(1),back(im) gives the Jacobian in the nonholomorphic case. What is the mathematical definition of pullback(f,x)?

The pullback of a function in Zygote is a function \mathcal{B}_f(v) = J_f^\dagger * v where v is a vector, J_f is the Jacobian of the function f and ^\dagger is the Hermitian conjugate.

Zygote will treat a struct with n fields as if it were a vector of length n, so the complex number 1 + 0im is treated like the vector [1, 0] and 0 + 1im is treated like the vector [0, 1] in the above Jacobian-vector product. The vector returned from this product is then converted back into a struct.

That is, in this case \mathcal{B}_f(1) = J_f^\dagger * [1, 0] and \mathcal{B}_f(i) = J_f ^\dagger * [0, 1]. The Jacobian here is 2x2, so this function returns a vector with two elements which are then interpreted as the real and imaginary parts of a complex number.
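
Here is a self-contained sketch of that definition which does not use Zygote at all: the 2x2 Jacobian is approximated by central differences, and `jacobian_2x2` and `pullback_sketch` are hypothetical names invented for this post. Unlike Zygote's back, the closure here returns a plain number rather than a tuple.

```julia
# Real 2x2 Jacobian of f at z = x + iy by central differences (a stand-in for AD).
function jacobian_2x2(f, z; h = 1e-6)
    fx = (f(z + h) - f(z - h)) / (2h)            # u_x + i v_x
    fy = (f(z + im * h) - f(z - im * h)) / (2h)  # u_y + i v_y
    return [real(fx) real(fy); imag(fx) imag(fy)]
end

# Pullback as described above: v ↦ J' * [Re v, Im v], turned back into a complex number.
# J is real here, so the Hermitian conjugate is just the transpose.
function pullback_sketch(f, z)
    J = jacobian_2x2(f, z)
    back(v) = complex((J' * [real(v), imag(v)])...)
    return f(z), back
end

y, back = pullback_sketch(abs2, 1.0 + 1.0im)
du, dv = back(1), back(im)

# Wirtinger derivatives, using the same formula as in the Zygote docs.
df_dz, df_dzb = (du' + im * dv') / 2, (du + im * dv) / 2

println((du, dv))
println((df_dz, df_dzb))
```

Up to finite-difference error this reproduces the (2.0 + 2.0im, 0.0 + 0.0im) pair from the Zygote example, and the Wirtinger pair comes out as (conj(z), z) for abs2.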

This is how Zygote deals with all structs, and it’s quite beautiful imo.

Thanks for your explanations, makes sense now.

This indeed seems to solve all issues very elegantly. The only thing I am still worried about is that holomorphic derivatives have to be implemented as

@scalar_rule f(z) conj(df(z))

It is too easy to forget the conj() and then waste hours figuring out why the resulting derivatives are not correct. I can see why doing it this way is consistent with interpreting complex derivatives as the “gradients” \tfrac{\partial u}{\partial x} + \iota \tfrac{\partial u}{\partial y}, but maybe it would nevertheless be easier to define complex derivatives as \tfrac{\partial u}{\partial x} - \iota \tfrac{\partial u}{\partial y} and avoid the conj() issue.

Yep, I couldn’t work out why it wasn’t working,
and I thought our definition was copied directly from DiffRules (it might have been),
and Zygote has special handling for abs and abs2 in its DiffRules interactions so I figured that it would also need special handling for its abs and abs2 ChainRules interactions.

Do they? For scalar rules, both pushforwards and pullbacks can be computed by multiplying the sensitivity by the derivative of f. In Zygote’s approach to pullbacks, that derivative needs to be conjugated, but in forward mode it doesn’t. If that convention was adopted by ChainRules, you’d still want to implement the un-conjugated derivative, and then the @scalar_rule macro would conjugate in the rrule it creates. The rules themselves wouldn’t need to change. Unless I’m missing something.

You are right. Here’s a detailed write-up.

It seems that Zygote’s convention is that

# Not ChainRules.jl syntax. 
# I merge the pullback call for notational convenience.
rrule(  
    z->u(z)+v(z)*im, 
    x+y*im, 
    dg_du+dg_dv*im
) -> dg_dx + dg_dy*im

represents the vector-matrix product

\begin{aligned} \begin{pmatrix} \tfrac{dg}{dx} & \tfrac{dg}{dy} \end{pmatrix} &= \begin{pmatrix} \tfrac{dg}{du} & \tfrac{dg}{dv} \end{pmatrix} \begin{pmatrix} \tfrac{du}{dx} & \tfrac{du}{dy} \\ \tfrac{dv}{dx} & \tfrac{dv}{dy} \\ \end{pmatrix} \\&= \begin{pmatrix} \tfrac{dg}{du} \tfrac{du}{dx} + \tfrac{dg}{dv} \tfrac{dv}{dx} & \tfrac{dg}{du} \tfrac{du}{dy} + \tfrac{dg}{dv} \tfrac{dv}{dy} \end{pmatrix} . \end{aligned}

The analogous definition for frule would then be that

# Again not ChainRules.jl syntax. 
# I rearranged arguments.
frule(  
    z->u(z)+v(z)*im, 
    x+y*im, 
    dx_dt+dy_dt*im
) -> du_dt + dv_dt*im

represents the matrix-vector product

\begin{aligned} \begin{pmatrix} \tfrac{du}{dt} \\ \tfrac{dv}{dt} \end{pmatrix} &= \begin{pmatrix} \tfrac{du}{dx} & \tfrac{du}{dy} \\ \tfrac{dv}{dx} & \tfrac{dv}{dy} \\ \end{pmatrix} \begin{pmatrix} \tfrac{dx}{dt} \\ \tfrac{dy}{dt} \end{pmatrix} \\&= \begin{pmatrix} \tfrac{du}{dx} \tfrac{dx}{dt} + \tfrac{du}{dy} \tfrac{dy}{dt} \\ \tfrac{dv}{dx} \tfrac{dx}{dt} + \tfrac{dv}{dy} \tfrac{dy}{dt} \end{pmatrix} . \end{aligned}

If f(x+\iota y) = u(x,y) + \iota v(x,y) is holomorphic, then we want to implement these linear-algebra products as complex multiplication with \tfrac{df}{dz} = \tfrac{du}{dx} - \iota \tfrac{du}{dy}. Using the Cauchy-Riemann equations \tfrac{dv}{dx} = -\tfrac{du}{dy} and \tfrac{dv}{dy} = \tfrac{du}{dx} in the second step of each chain, we have

\begin{aligned} \tfrac{dg}{dx} + \iota \tfrac{dg}{dy} &= \tfrac{dg}{du} \tfrac{du}{dx} + \tfrac{dg}{dv} \tfrac{dv}{dx} + \iota \left( \tfrac{dg}{du} \tfrac{du}{dy} + \tfrac{dg}{dv} \tfrac{dv}{dy} \right) \\&= \tfrac{dg}{du} \tfrac{du}{dx} - \tfrac{dg}{dv} \tfrac{du}{dy} + \iota \left( \tfrac{dg}{du} \tfrac{du}{dy} + \tfrac{dg}{dv} \tfrac{du}{dx} \right) \\&= \left( \tfrac{dg}{du} + \iota \tfrac{dg}{dv} \right) \left( \tfrac{du}{dx} + \iota \tfrac{du}{dy} \right) , \end{aligned}
\begin{aligned} \tfrac{du}{dt} + \iota \tfrac{dv}{dt} &= \tfrac{du}{dx} \tfrac{dx}{dt} + \tfrac{du}{dy} \tfrac{dy}{dt} + \iota \left( \tfrac{dv}{dx} \tfrac{dx}{dt} + \tfrac{dv}{dy} \tfrac{dy}{dt} \right) \\&= \tfrac{du}{dx} \tfrac{dx}{dt} + \tfrac{du}{dy} \tfrac{dy}{dt} + \iota \left( - \tfrac{du}{dy} \tfrac{dx}{dt} + \tfrac{du}{dx} \tfrac{dy}{dt} \right) \\&= \left( \tfrac{du}{dx} - \iota \tfrac{du}{dy} \right) \left( \tfrac{dx}{dt} + \iota \tfrac{dy}{dt} \right) . \end{aligned}

Thus rrule can indeed be implemented as multiplication with \overline{\tfrac{df}{dz}}, while frule can be implemented as multiplication with \tfrac{df}{dz}.
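
A quick numerical check of this conclusion for f = sin, as a sketch only: the real 2x2 Jacobian is built by hand from f'(z) = cos(z) via the Cauchy-Riemann equations, and `tovec`, `tocomplex` and the test values are made up for the illustration.

```julia
z = 0.7 - 0.3im
df = cos(z)                     # holomorphic derivative of sin at z

# Real Jacobian [u_x u_y; v_x v_y] with u_x = v_y = Re f' and v_x = -u_y = Im f'.
J = [real(df) -imag(df); imag(df) real(df)]

tovec(w) = [real(w), imag(w)]
tocomplex(v) = complex(v[1], v[2])

dz = 0.2 + 0.5im                # input tangent      dx/dt + i dy/dt
gbar = -1.0 + 0.4im             # output sensitivity dg/du + i dg/dv

pushforward = tocomplex(J * tovec(dz))     # matrix-vector product (frule)
pullback = tocomplex(J' * tovec(gbar))     # vector-matrix product (rrule)

println(pushforward, " vs ", df * dz)
println(pullback, " vs ", conj(df) * gbar)
```

The pushforward agrees with multiplication by f'(z) and the pullback with multiplication by conj(f'(z)), which is the conjugation a macro like @scalar_rule would have to insert in the rrule it generates.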