Here’s a proposal for how to define complex chain rules:
- By default, frule and rrule assume f to be holomorphic and return the holomorphic derivative.
- Wirtinger derivatives can be requested by wrapping the corresponding argument in a Wirtinger wrapper type. In that case:
  - frule will accept a pair (dz_dw,dz_dwb) as the corresponding input derivative and return a pair (df_dw,df_dwb) as the corresponding output derivative, and
  - the pullback returned by rrule will accept a pair (dv_df,dv_dfb) as the corresponding input derivative and return a pair (dv_dz,dv_dzb) as the corresponding output derivative.
Throughout this post, a suffix b stands for “bar”, i.e. wb denotes $\bar w$.
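For readers less familiar with the notation, the pairs above are the usual Wirtinger derivatives, and the two bullet points amount to the chain rules below. This is standard material restated in this post's naming (with $z = x + iy$), not anything specific to ChainRules:

```latex
\begin{aligned}
\frac{\partial f}{\partial z} &= \tfrac12\left(\frac{\partial f}{\partial x} - i\,\frac{\partial f}{\partial y}\right),
&
\frac{\partial f}{\partial \bar z} &= \tfrac12\left(\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y}\right),
\\[4pt]
% forward mode: z depends on w
\frac{\partial f}{\partial w} &= \frac{\partial f}{\partial z}\,\frac{\partial z}{\partial w}
  + \frac{\partial f}{\partial \bar z}\,\overline{\frac{\partial z}{\partial \bar w}},
&
\frac{\partial f}{\partial \bar w} &= \frac{\partial f}{\partial z}\,\frac{\partial z}{\partial \bar w}
  + \frac{\partial f}{\partial \bar z}\,\overline{\frac{\partial z}{\partial w}},
\\[4pt]
% reverse mode: v depends on f
\frac{\partial v}{\partial z} &= \frac{\partial v}{\partial f}\,\frac{\partial f}{\partial z}
  + \frac{\partial v}{\partial \bar f}\,\overline{\frac{\partial f}{\partial \bar z}},
&
\frac{\partial v}{\partial \bar z} &= \frac{\partial v}{\partial f}\,\frac{\partial f}{\partial \bar z}
  + \frac{\partial v}{\partial \bar f}\,\overline{\frac{\partial f}{\partial z}}.
\end{aligned}
```

For a holomorphic f the term $\partial f / \partial \bar z$ vanishes, so the forward rule multiplies both components by $f'(z)$, while the reverse rule multiplies the first component by $f'(z)$ and the second by $\overline{f'(z)}$.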
For illustration, the chain rules for sin would be implemented as follows.
frule((_,dz_dw), ::typeof(sin), z::Number) = sin(z), cos(z)*dz_dw
rrule(::typeof(sin), z::Number) = sin(z), dv_df->(NO_FIELDS, dv_df*cos(z))
frule((_,(dz_dw,dz_dwb)), ::typeof(sin), z::Wirtinger{<:Number}) =
    sin(remove_wirtinger(z)), (Zero(), cos(remove_wirtinger(z)).*(dz_dw,dz_dwb))
rrule(::typeof(sin), z::Wirtinger{<:Number}) =
    sin(remove_wirtinger(z)), ((dv_df,dv_dfb),)->(NO_FIELDS, (dv_df*cos(remove_wirtinger(z)), dv_dfb*conj(cos(remove_wirtinger(z)))))
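As a sanity check, the Wirtinger rules for sin can be compared against finite differences. The sketch below is self-contained and does not use ChainRules: the `Wirtinger` struct, `remove_wirtinger`, `sin_frule`, `sin_rrule` and `wirtinger_fd` are hypothetical stand-ins written only for this illustration, and the forward rule returns the pair (df_dw,df_dwb) directly, as described in the list at the top of the post.

```julia
# Hypothetical stand-ins for the wrapper proposed in this post (not the ChainRules API).
struct Wirtinger{T}
    value::T
end
remove_wirtinger(z::Wirtinger) = z.value

# Forward rule: takes the pair (dz_dw, dz_dwb), returns sin(z) and (df_dw, df_dwb).
function sin_frule((dz_dw, dz_dwb), z::Wirtinger{<:Number})
    x = remove_wirtinger(z)
    return sin(x), cos(x) .* (dz_dw, dz_dwb)
end

# Reverse rule: the pullback maps (dv_df, dv_dfb) to (dv_dz, dv_dzb).
function sin_rrule(z::Wirtinger{<:Number})
    x = remove_wirtinger(z)
    pullback((dv_df, dv_dfb)) = (dv_df * cos(x), dv_dfb * conj(cos(x)))
    return sin(x), pullback
end

# Central finite differences for the Wirtinger pair (dg/dz, dg/dzb) at z.
function wirtinger_fd(g, z; h = 1e-6)
    dx = (g(z + h) - g(z - h)) / (2h)
    dy = (g(z + im * h) - g(z - im * h)) / (2h)
    return (dx - im * dy) / 2, (dx + im * dy) / 2
end

z0 = 0.3 + 0.7im

# Forward mode with z = w, i.e. dz_dw = 1 and dz_dwb = 0.
y, (df_dw, df_dwb) = sin_frule((1.0 + 0im, 0.0 + 0im), Wirtinger(z0))

# Reverse mode with v = abs2(f) = f * fb, so dv_df = conj(f) and dv_dfb = f.
f0, pb = sin_rrule(Wirtinger(z0))
dv_dz, dv_dzb = pb((conj(f0), f0))
fd_dz, fd_dzb = wirtinger_fd(z -> abs2(sin(z)), z0)
```

The reverse-mode check uses the nonholomorphic outer function abs2 on purpose, since that is the case where the second component of the pair actually matters.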
There are a couple of nice things about this approach:
- Nonholomorphic derivatives are represented as pairs of Complex or Array{Complex}, i.e. no special numerical type needs to be introduced. This should avoid any compatibility issues with other libraries (most notably BLAS).
- We can easily extend the framework to compute $\frac{\partial f}{\partial x}$ and $\frac{\partial f}{\partial y}$ if that is desired.
- Complex and real derivatives can be mixed. Example:
f(a::Real,z::Complex) = real(a*z)
derivative(f, a::Real, z::Wirtinger) = (real(remove_wirtinger(z)), (a/2,a/2))
- Holomorphic derivatives are as easy as they should be.
I suspect that the last point is more important to me than to most other people, and that there will be some debate about whether or not to make holomorphic derivatives the default. Hence I would like to point out that my proposal of making holomorphic derivatives the default is based not only on my personal preferences but also on some technical arguments.
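Coming back to the mixed real/complex example in the list above: since f(a,z) = a*(z + zb)/2, the derivative with respect to a is real(z) and both Wirtinger derivatives with respect to z are a/2. A small finite-difference sketch confirms this; `wirtinger_fd` is a hypothetical helper written for this check, not part of any library.

```julia
f(a::Real, z::Complex) = real(a * z)

# Central finite differences for the Wirtinger pair (dg/dz, dg/dzb) at z.
function wirtinger_fd(g, z; h = 1e-6)
    dx = (g(z + h) - g(z - h)) / (2h)
    dy = (g(z + im * h) - g(z - im * h)) / (2h)
    return (dx - im * dy) / 2, (dx + im * dy) / 2
end

a0, z0 = 1.5, 0.3 + 0.7im

# Ordinary real derivative with respect to a.
df_da = (f(a0 + 1e-6, z0) - f(a0 - 1e-6, z0)) / 2e-6

# Wirtinger pair with respect to z; the claim is that both entries equal a0/2.
df_dz, df_dzb = wirtinger_fd(z -> f(a0, z), z0)
```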
As an alternative to the above, we could define that complex arguments automatically imply Wirtinger derivatives, but this leads to the following issue: it makes perfect sense to pass Wirtinger derivatives through $\mathbb{R} \to \mathbb{R}$ functions (e.g. when computing the derivative of a function like $\log(|z|)$). Without the Wirtinger argument wrapper, every frule and rrule would thus have to be prepared to propagate either a single real derivative or a pair of Wirtinger derivatives, including all $\mathbb{R} \to \mathbb{R}$ frules and rrules. Making Wirtinger derivatives the default would hence require that all of ChainRules is aware of Wirtinger derivatives and handles them appropriately, which I think is not reasonable.
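To make the log(|z|) case concrete, here is a minimal sketch of how a real-to-real rule would pick up the Wirtinger pair only when its argument is wrapped. Everything in it (`Wirtinger`, `remove_wirtinger`, `abs_wirtinger`, `log_frule`) is a hypothetical stand-in for illustration and not the ChainRules API; the point is only that the plain real method stays untouched and the pair-propagating method is selected by dispatch on the wrapper.

```julia
# Hypothetical wrapper, as in the proposal above.
struct Wirtinger{T}
    value::T
end
remove_wirtinger(x::Wirtinger) = x.value

# abs maps complex to real; its Wirtinger pair is (conj(z), z) / (2|z|).
abs_wirtinger(z::Complex) = abs(z), (conj(z) / (2 * abs(z)), z / (2 * abs(z)))

# Ordinary real rule: a single input derivative, a single output derivative.
log_frule(dr, r::Real) = log(r), dr / r

# Wrapped rule: the same real function, but a pair goes in and a pair comes out.
function log_frule((dr_dw, dr_dwb), r::Wirtinger{<:Real})
    x = remove_wirtinger(r)
    return log(x), (dr_dw / x, dr_dwb / x)
end

z0 = 0.3 + 0.7im
r, dr = abs_wirtinger(z0)
y, (dy_dz, dy_dzb) = log_frule(dr, Wirtinger(r))

# Plain real use of the same function is unaffected by the wrapper method.
y_real, dy_real = log_frule(1.0, 2.0)
```

Analytically, log(|z|) = log(z*zb)/2, so the expected pair is (1/(2z), 1/(2zb)), which is what the composed rules produce.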