Taking Complex Autodiff Seriously in ChainRules

I think it's only a factor of 2 in memory, since ChainRules.jl doesn't need to do extra computation on the off-diagonal terms if the function is analytic.

Ah, I see your point now. You’re saying that I could write code that is holomorphically unstable. One branch of the function could be holomorphic and another could be non-holomorphic due to a call to conj or abs or something.

Yes, that’s an interesting and concerning issue.

That makes me lean even further toward the point of view that we should just do non-holomorphic storage for everything. It's better to be a bit slower than to get the wrong answer by assuming things are holomorphic, or to not support non-holomorphic code at all.

If people are confident their code is holomorphic they can always extract \partial f / \partial z at the end of the calculation using jvp, but you can't go the other way.

I think a complex AD could have two APIs:

  1. Error on non-holomorphic code, to save memory.
  2. Assume non-holomorphic: return an isholomorphic flag together with the output, plus a convenient "compactify" function that turns the output into a complex Jacobian matrix when possible (a rough sketch is below).
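
Something like the following hypothetical sketch (WirtingerResult, isholomorphic, and compactify are made-up names for illustration, not an existing ChainRules.jl API):

```julia
# Hypothetical sketch of API 2: always carry both Wirtinger components,
# report whether the ∂f/∂z̄ block vanished, and offer a "compactify" step.
struct WirtingerResult{T}
    dz::Matrix{T}      # ∂f/∂z  block
    dzbar::Matrix{T}   # ∂f/∂z̄ block
end

isholomorphic(r::WirtingerResult; tol = 0) = all(x -> abs(x) <= tol, r.dzbar)

# Collapse to an ordinary complex Jacobian when the function turned out to
# be holomorphic; refuse otherwise.
function compactify(r::WirtingerResult)
    isholomorphic(r) || error("not holomorphic; keep both Wirtinger blocks")
    return r.dz
end
```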

The tricky part here is that ComplexJacobian is not commutative, which makes it kind of confusing to think about (as the output Jacobian is then a matrix of matrices), and quite prone to programming errors.

It’s only non-commutative if f was not holomorphic, but yes, I can see how that could be problematic.

Just complex numbers as structs is fine. For a C^N → R function, the gradient (simply defined as the gradient with respect to each of the 2N components) is the interesting thing; e.g. gradient descent is simply x_{n+1} = x_n - \alpha \nabla f(x_n) with this definition. If you want to do Newton, in the general case the gradient will not be a holomorphic function, so you need to do the full 2N real Newton.
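
A minimal concrete instance of that definition (my own sketch): for f(z) = |z|², the gradient with respect to the two real components, packed back into a complex number, is 2x + 2iy = 2z, and plain gradient descent drives z to the minimizer 0:

```julia
# f : C -> R with f(z) = |z|^2 = x^2 + y^2.  The "gradient" is the ordinary
# real gradient (∂f/∂x, ∂f/∂y) stored as the complex number ∂f/∂x + im*∂f/∂y.
f(z) = abs2(z)
gradf(z) = 2z                  # (2x) + im*(2y) = 2z

function descend(z; α = 0.1, iters = 200)
    for _ in 1:iters
        z -= α * gradf(z)      # x_{n+1} = x_n - α ∇f(x_n)
    end
    return z
end

descend(3.0 - 4.0im)           # ≈ 0 + 0im, the minimizer of |z|^2
```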

Complex derivatives are an area where there is no clear consensus on definitions and many communities do different things, so any external API is fine as long as it's clearly documented. Then, when you know what you're doing and what you want, you can fish it out of this API.

Such a function can never be analytic. Isn't it better to interpret this as an R^{2N} → R optimization problem?

Sure, but it's still convenient to do everything that can be done in complex. For instance, the gradient of (1/2) x* A x - Re(b* x) is Ax - b if A is Hermitian.
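
A quick numerical sanity check of that formula (a sketch; variable names are mine): with A Hermitian positive definite, descending along Az - b converges to the solution of Az = b:

```julia
using LinearAlgebra

# f(z) = (1/2) z' A z - Re(b' z) with A Hermitian positive definite.
# In the "2N real components packed as a complex vector" convention,
# the gradient is A z - b, so gradient descent converges to A \ b.
function solve_by_descent(A, b; iters = 5_000)
    z = zero(b)
    α = 1 / opnorm(Matrix(A))        # conservative step size
    for _ in 1:iters
        z -= α * (A * z - b)
    end
    return z
end

n = 4
M = randn(ComplexF64, n, n)
A = Hermitian(M'M + I)               # Hermitian positive definite
b = randn(ComplexF64, n)
norm(solve_by_descent(A, b) - A \ b) # ≈ 0
```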

For example, there are lots of physical systems where it is most natural to write the problem in terms of complex numbers z, even though you are computing a real function (requiring both z and its conjugate z*) at the end.

In any linear time-invariant problem it is common to work with Fourier transforms F of physical quantities and be interested in the amplitude |F| or phase arg(F). In electrical engineering, these complex amplitudes are also called phasors, and Re[F*G]/2 can be interpreted as a time average of (Re F) (Re G). For example, if you are solving Maxwell’s equations in the frequency domain, then the time-averaged power flux is Re[E*×H]/2 in terms of the electric and magnetic fields (which solve a complex linear system), and this is a common target for optimization.
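
For reference, with the usual e^{-iωt} phasor convention the time-average identity works out as follows (the e^{∓2iωt} cross terms average to zero over a period):

$$
\left\langle \operatorname{Re}\!\left(F e^{-i\omega t}\right)\operatorname{Re}\!\left(G e^{-i\omega t}\right)\right\rangle_t
= \tfrac{1}{4}\left\langle \left(F e^{-i\omega t}+\bar F e^{i\omega t}\right)\left(G e^{-i\omega t}+\bar G e^{i\omega t}\right)\right\rangle_t
= \tfrac{1}{4}\left(F\bar G+\bar F G\right)
= \tfrac{1}{2}\operatorname{Re}\!\left(F^{*}G\right).
$$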

In quantum mechanics, you have a complex wavefunction ψ, but you often want "observable" functions like ψ*Oψ, where O is some Hermitian operator.

If you have set up a solver, e.g. a finite-element or spectral discretization, in the “natural” complex variables for a physical problem, it is quite awkward and inconvenient to reformulate it in terms of real and imaginary parts. Moreover, the formulation of a complex matrix in terms of an equivalent real matrix of twice the size is not unique and the wrong choice can greatly slow convergence of iterative solvers, and requires special handling of preconditioners.

Another important optimization problem that involves real functions (error norms) of “naturally” complex quantities (Fourier transforms) is filter design.

So, from my perspective it is highly desirable to support Wirtinger derivatives for functions of complex variables.

A few things to note.

One thing to note is that, if the differentiation method never forms Jacobians and the output is real, then these issues are (somewhat) all mitigated. If you only do vector-Jacobian or Jacobian-vector products, then you just need the correct adjoint on the Complex constructor, and autodiff will internally be treating the numbers as real+imaginary parts throughout the user's code, and it'll all magically work out in the end. Wirtinger vs. not is then just a choice of a standard basis [[1,0],[0,1]] vs. [[1,1],[1,-1]] internal to the autodiff. In the end, the choice there doesn't matter for correctness, but it can make it easier to get compiler optimizations, because for all analytic functions the second term in the basis is zero, so you can stick a hard constant zero there for lots of functions that people use, and then autodiff will do less work. Technically, with enough smarts the compiler could get to the same spot without that hard zero, but it's not guaranteed, and in some cases it's difficult because it relies on mathematical assumptions that are not true in floating point (Cauchy-Riemann cannot be collapsed to zero in LLVM except in specific cases and with @fastmath). So in the sense of a reverse-mode AD, it should really just think about implementing complex->real via vjps and jvps, and Wirtingers are a good way to write the rules if you want to do less work.
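
To make the "just a choice of basis" point concrete, here is a small finite-difference sketch (my own illustration, not ChainRules code) for the non-holomorphic f(z) = z² + conj(z): the Wirtinger pair (∂f/∂z, ∂f/∂z̄) = (2z, 1) carries exactly the same information as the 2×2 real Jacobian in the (x, y) basis, via ∂f/∂z = (∂f/∂x - i ∂f/∂y)/2 and ∂f/∂z̄ = (∂f/∂x + i ∂f/∂y)/2:

```julia
# Non-holomorphic example: f(z) = z^2 + conj(z), so ∂f/∂z = 2z and ∂f/∂z̄ = 1.
f(z) = z^2 + conj(z)

# Directional derivatives in the (x, y) basis via central finite differences.
dfdx(z; h = 1e-6) = (f(z + h)    - f(z - h))    / 2h
dfdy(z; h = 1e-6) = (f(z + im*h) - f(z - im*h)) / 2h

# Change of basis to the Wirtinger pair.
wirtinger(z) = ((dfdx(z) - im*dfdy(z)) / 2,    # ∂f/∂z  ≈ 2z
                (dfdx(z) + im*dfdy(z)) / 2)    # ∂f/∂z̄ ≈ 1

wirtinger(0.3 + 0.7im)   # ≈ (0.6 + 1.4im, 1.0): same info as the 2×2 real Jacobian
```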

However, there is a subtle extra question of what to do about Jacobians. This comes up for me a lot because stiff ODE solvers over complex numbers do need to handle it correctly. Always using the larger Jacobian is quite wasteful, and it would be nice for NLsolve.jl to be smart here. If you write your autodiff in Wirtinger form, then the nice thing is that if your function is composed only of analytic functions, it should be able to constant-propagate that zero all the way through, in which case you could prove that some functions are analytic. In other cases you might need symbolic help, and it might still be impossible to prove it directly from the program, so it would be good to have a flag from the user. If you know the function is analytic, then the Jacobian is the small form; otherwise it's the large form and you split to the Wirtinger basis in the Newton method. But as @YingboMa mentions, using ComplexJacobian would make it difficult to use BLAS, so you'd need a special BLAS (MaBLAS!), which is why we've avoided this solution in DiffEq.
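
A minimal sketch of that large-form fallback (hypothetical helper names; it splits into real and imaginary parts rather than the Wirtinger pair, but carries the same information): when F : C^n → C^n can't be shown to be analytic, stack real and imaginary parts, build the 2n×2n real Jacobian, and take an ordinary real Newton step:

```julia
using LinearAlgebra

# Treat F : C^n -> C^n as a map R^{2n} -> R^{2n} by stacking [real; imag].
splitreim(z) = vcat(real.(z), imag.(z))
joinreim(v)  = complex.(v[1:end÷2], v[end÷2+1:end])

# 2n×2n real Jacobian by central finite differences (an AD jvp would slot in
# here instead of the finite differences).
function real_jacobian(F, z; h = 1e-7)
    v = splitreim(z)
    m = length(v)
    J = zeros(m, m)
    for j in 1:m
        e = zeros(m); e[j] = h
        J[:, j] = (splitreim(F(joinreim(v + e))) - splitreim(F(joinreim(v - e)))) / 2h
    end
    return J
end

# One Newton step on the split (real) system.
function newton_step(F, z)
    v = splitreim(z) - real_jacobian(F, z) \ splitreim(F(z))
    return joinreim(v)
end

# Example: F(z) = [z[1]^2 + conj(z[1]) - 1] has no complex Jacobian, but the
# split Newton step still applies.
newton_step(z -> [z[1]^2 + conj(z[1]) - 1], [1.0 + 0.5im])
```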

So I think that these Jacobian functions will need to be closely tied to how they're used in solving nonlinear systems, and might as well be left out of the autodiff libraries and built from the vjp/jvp primitives (sparsity and scarcity also just need the same primitives). This would be a mighty fun GSoC project: if ChainRules+Zygote is all ready to go, such a complex nonlinear solver library would be quite accessible and a unique thing to build.

I understand why people would like to use Wirtinger derivatives for complex functions. However, I am not sure that Wirtinger is the preferred basis to do complex AD in, since Zero() doesn't help the compiler any more than HolomorphicJacobian does in type inference. If one relies on type inference to automatically save another factor of 2 in memory, one runs into problems for functions like foo(x) = abs(x) > 2 ? conj(x) : x, as I mentioned earlier.
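
(To make the type-inference concern concrete, here is a hypothetical sketch; Wirtinger and Zero below are illustrative stand-ins, not the actual ChainRules.jl types.)

```julia
# Illustrative stand-ins (not the real ChainRules.jl types):
struct Zero end
struct Wirtinger{P,C}
    dz::P       # ∂f/∂z
    dzbar::C    # ∂f/∂z̄
end

foo(x) = abs(x) > 2 ? conj(x) : x

# A derivative rule for foo has to return a different "shape" per branch:
function dfoo(x)
    if abs(x) > 2
        return Wirtinger(Zero(), one(x))   # conj branch:     ∂f/∂z = 0, ∂f/∂z̄ = 1
    else
        return Wirtinger(one(x), Zero())   # identity branch: ∂f/∂z = 1, ∂f/∂z̄ = 0
    end
end

dfoo(1.0 + 0.0im)   # Wirtinger{ComplexF64, Zero}
dfoo(3.0 + 0.0im)   # Wirtinger{Zero, ComplexF64}
# The two branches give different concrete types, so inference only sees a
# Union, which is the same problem a HolomorphicJacobian flag would have.
```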

The HolomorphicJacobian case can break in many ways, so I wouldn't focus on that. But in theory Wirtinger is still useful in the C^n → R case where you never have to build the Jacobian and just do vjps/jvps. This is because if you have a lot of analytic functions in the middle, then in Wirtinger form the vjp rules would have a lot of hardcoded Zero()s (correct me if I'm wrong, but the jvp/vjp against an analytic function can be written out in the four components; if you write the jvp/vjp in that basis you have zeros in the Jacobian, so the jvp/vjp will have a zero part in its output, which could then carry through), so in theory that could cut down on the amount of calculation that has to be done. It could turn out to be a small optimization though, especially if most of the intermediate functions you're doing AD on are not analytic.

If ChainRules.jl doesn't do extra computation when the system is holomorphic, won't that save the same amount of computation in any basis? I still cannot see why Zero is special. Also, using a Zero type alongside other Number types introduces type instability when the function is like ... ? holomorphic : general.

Yes, storing and operating on \partial f / \partial z is just as expensive as dealing with \partial u / \partial x and \partial v / \partial x separately, though representing \partial f / \partial z as just a complex number can have compatibility advantages, especially with BLAS.

Oh, I guess so. It's just the representation of the input/output, but the rule itself would be specializing internally on analyticity, so the Zero wouldn't do much and we should just rely on differentiation over structures.

Hello,

I am probably stating the obvious here, but the only holomorphic functions with zero imaginary part (or real part, for that matter) are the constant functions. Since the field of complex numbers is not ordered (so an objective has to be real-valued), and constant functions are not very hard to optimize :stuck_out_tongue: , the main target should be non-holomorphic functions.

Cheers!

Oh boy, there are plenty of applications in signal processing: filter design, antenna design, you name it. And at a very fundamental level, electrical (or EM) power is a real quantity, and deriving it from a complex variable implies a non-holomorphic mapping. See this nice review:

Candan, Cagatay. “Properly Handling Complex Differentiation in Optimization and Approximation Problems.” IEEE Signal Processing Magazine 36.2 (2019): 117-124.

Having Wirtinger derivatives supported would allow the whole Julia optimization ecosystem to be used for these applications.

I don't have time to properly go through this right now, but it's worth noting that JAX / Autograd discuss how they do complex numbers here and here respectively. I would be interested to hear people's thoughts on whether their conventions make sense or not.

I got pretty excited at first when I read

> JAX is great at complex numbers and differentiation. To support both holomorphic and non-holomorphic differentiation, JAX follows Autograd's convention for encoding complex derivatives.

but then they go on to explain that they just throw out the imaginary part of f entirely when calculating gradients, so their answer is only correct if f is either holomorphic or \mathbb{C} \rightarrow \mathbb{R}.

In my opinion, this is completely unsatisfactory and is not what I’d call being “great at complex numbers and differentiation”.

Yeah, but I'm assuming that's just for the gradient, as the gradient only has a sensible definition when the function is holomorphic. My interpretation of their explanation is that their pushforwards / pullbacks (they call them jvps and vjps respectively) don't assume things to be holomorphic, meaning that they should be able to do all of the things that you would expect. I've not verified that, though.
