Here’s a proposal for how to define complex chain rules:
- By default, `frule` and `rrule` assume `f` to be holomorphic and return the holomorphic derivative.
- Wirtinger derivatives can be requested by wrapping the corresponding argument in a `Wirtinger` wrapper type. In that case:
  - `frule` will accept a pair `(dz_dw,dz_dwb)` as the corresponding input derivative and return a pair `(df_dw,df_dwb)` as the corresponding output derivative, and
  - the pullback returned by `rrule` will accept a pair `(dv_df,dv_dfb)` as the corresponding input derivative and return a pair `(dv_dz,dv_dzb)` as the corresponding output derivative.
Throughout this post, a suffix `b` stands for “bar”, i.e. `wb` = $\bar w$.
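For reference, the pairs above are assumed to combine according to the standard Wirtinger chain rule. This is written out here as a sketch rather than as part of the proposal itself; it uses $\partial \bar z / \partial w = \overline{\partial z / \partial \bar w}$, and the same identity for $f$.

$$
\begin{aligned}
\frac{\partial f}{\partial w} &= \frac{\partial f}{\partial z}\,\frac{\partial z}{\partial w} + \frac{\partial f}{\partial \bar z}\,\overline{\frac{\partial z}{\partial \bar w}}, &
\frac{\partial f}{\partial \bar w} &= \frac{\partial f}{\partial z}\,\frac{\partial z}{\partial \bar w} + \frac{\partial f}{\partial \bar z}\,\overline{\frac{\partial z}{\partial w}}, \\
\frac{\partial v}{\partial z} &= \frac{\partial v}{\partial f}\,\frac{\partial f}{\partial z} + \frac{\partial v}{\partial \bar f}\,\overline{\frac{\partial f}{\partial \bar z}}, &
\frac{\partial v}{\partial \bar z} &= \frac{\partial v}{\partial f}\,\frac{\partial f}{\partial \bar z} + \frac{\partial v}{\partial \bar f}\,\overline{\frac{\partial f}{\partial z}}.
\end{aligned}
$$

For holomorphic $f$ we have $\partial f / \partial \bar z = 0$, so the forward rule multiplies both entries of the input pair by $f'(z)$, while the pullback multiplies the first entry by $f'(z)$ and the second by $\overline{f'(z)}$.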
For illustration, the chain rules for `sin` would be implemented as follows.
frule((_,dz_dw), ::typeof(sin), z::Number) = sin(z), cos(z)*dz_dw
rrule(::typeof(sin), z::Number) = sin(z), dv_df->(NO_FIELDS, dv_df*cos(z))
frule((_,(dz_dw,dz_dwb)), ::typeof(sin), z::Wirtinger{<:Number}) =
    sin(remove_wirtinger(z)), (Zero(), cos(remove_wirtinger(z)).*(dz_dw,dz_dwb))
rrule(::typeof(sin), z::Wirtinger{<:Number}) =
    sin(remove_wirtinger(z)), ((dv_df,dv_dfb),)->(NO_FIELDS, (dv_df*cos(remove_wirtinger(z)), dv_dfb*conj(cos(remove_wirtinger(z)))))
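To make the rules for `sin` easy to try out, here is a minimal self-contained sketch that runs without any package. The `Wirtinger` struct, `remove_wirtinger`, and the `NO_FIELDS` placeholder are hypothetical stand-ins defined locally, not an existing API. The Wirtinger `frule` here returns just the pair `(df_dw,df_dwb)`, following the description in the list above. `wirtinger_fd` is an illustrative finite-difference helper used only as a cross-check.

```julia
# Hypothetical stand-ins so that the sketch runs without any package.
struct Wirtinger{T}
    value::T
end
remove_wirtinger(z::Wirtinger) = z.value
const NO_FIELDS = nothing  # placeholder for "no derivative w.r.t. the function object"

# Holomorphic default: one derivative in, one derivative out.
frule((_, dz_dw), ::typeof(sin), z::Number) = (sin(z), cos(z) * dz_dw)
rrule(::typeof(sin), z::Number) = (sin(z), dv_df -> (NO_FIELDS, dv_df * cos(z)))

# Wirtinger variant: pairs in, pairs out.
function frule(tangents, ::typeof(sin), z::Wirtinger{<:Number})
    _, (dz_dw, dz_dwb) = tangents
    zv = remove_wirtinger(z)
    return sin(zv), (cos(zv) * dz_dw, cos(zv) * dz_dwb)
end
function rrule(::typeof(sin), z::Wirtinger{<:Number})
    zv = remove_wirtinger(z)
    pullback((dv_df, dv_dfb)) = (NO_FIELDS, (dv_df * cos(zv), dv_dfb * conj(cos(zv))))
    return sin(zv), pullback
end

# Central finite differences for (dg/dw, dg/dwb); only used as a cross-check.
function wirtinger_fd(g, w; h = 1e-6)
    dx = (g(w + h) - g(w - h)) / (2 * h)
    dy = (g(w + im * h) - g(w - im * h)) / (2 * h)
    return (dx - im * dy) / 2, (dx + im * dy) / 2
end

# Forward mode: z(w) = w^2 + 2*conj(w) is not holomorphic in w,
# so dz_dw = 2w and dz_dwb = 2.
w = 0.3 + 0.4im
z = w^2 + 2 * conj(w)
f, (df_dw, df_dwb) = frule((nothing, (2 * w, 2.0 + 0.0im)), sin, Wirtinger(z))
fd_forward = wirtinger_fd(u -> sin(u^2 + 2 * conj(u)), w)

# Reverse mode: v = |f|^2, so dv_df = conj(f) and dv_dfb = f.
_, pullback = rrule(sin, Wirtinger(z))
_, (dv_dz, dv_dzb) = pullback((conj(f), f))
fd_reverse = wirtinger_fd(u -> abs2(sin(u)), z)
```

Comparing `(df_dw, df_dwb)` with `fd_forward` and `(dv_dz, dv_dzb)` with `fd_reverse` gives a quick numerical sanity check of both rules.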
There are a couple of nice things about this approach:
- Nonholomorphic derivatives are represented as pairs of `Complex` or `Array{Complex}`, i.e. no special numerical type needs to be introduced. This should avoid any compatibility issues with other libraries (most notably BLAS).
- We can easily extend the framework to compute $\frac{\partial f}{\partial x}$ and $\frac{\partial f}{\partial y}$ if that is desired.
- Complex and real derivatives can be mixed. Example:
f(a::Real,z::Complex) = real(a*z)
derivative(f, a::Real, z::Wirtinger) = (real(remove_wirtinger(z)), (a/2,a/2))
- Holomorphic derivatives are as easy as they should be.
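
As a sketch of the second and third points, the snippet below evaluates the mixed real/complex example and converts the resulting Wirtinger pair into derivatives with respect to `x = real(z)` and `y = imag(z)`. It uses the identities $\partial_x = \partial_z + \partial_{\bar z}$ and $\partial_y = i(\partial_z - \partial_{\bar z})$. The `Wirtinger` struct, `remove_wirtinger`, and the helper `to_xy` are hypothetical names defined locally for illustration.

```julia
# Hypothetical stand-ins for the proposed wrapper.
struct Wirtinger{T}
    value::T
end
remove_wirtinger(z::Wirtinger) = z.value

f(a::Real, z::Complex) = real(a * z)

# A real derivative for `a`, and a pair (df_dz, df_dzb) for the wrapped `z`.
derivative(::typeof(f), a::Real, z::Wirtinger) =
    (real(remove_wirtinger(z)), (a / 2, a / 2))

# Convert a Wirtinger pair into derivatives w.r.t. x = real(z) and y = imag(z).
to_xy((df_dz, df_dzb)) = (df_dz + df_dzb, im * (df_dz - df_dzb))

a, z = 2.0, 1.0 + 3.0im
df_da, df_dz_pair = derivative(f, a, Wirtinger(z))
df_dx, df_dy = to_xy(df_dz_pair)

# Finite-difference cross-check of the x and y derivatives.
h = 1e-6
fd_dx = (f(a, z + h) - f(a, z - h)) / (2 * h)
fd_dy = (f(a, z + im * h) - f(a, z - im * h)) / (2 * h)
```

Since `f(a, z) = a * real(z)`, the derivative with respect to `x` is `a` and the derivative with respect to `y` vanishes, which is what `to_xy` recovers from the pair `(a/2, a/2)`.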
I suspect that the last point is more important to me than to most other people, and that there will be some debate about whether to make holomorphic derivatives the default. I would therefore like to point out that my proposal to make holomorphic derivatives the default is based not only on my personal preferences but also on some technical arguments.
As an alternative to the above, we could define that complex arguments automatically imply Wirtinger derivatives, but this leads to the following issue: it makes perfect sense to pass Wirtinger derivatives through $\mathbb{R} \to \mathbb{R}$ functions (e.g. when computing the derivative of a function like $\log(|z|)$). Without the `Wirtinger` argument wrapper, every `frule` and `rrule` would thus have to be prepared to propagate either a single real derivative or a pair of Wirtinger derivatives, including all $\mathbb{R} \to \mathbb{R}$ `frule`s and `rrule`s. Making Wirtinger derivatives the default would hence require that all of `ChainRules` is aware of Wirtinger derivatives and handles them appropriately, which I think is not reasonable.
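
The $\log(|z|)$ case can be sketched as follows, under the assumption that the wrapper is also applied to the real intermediate value. The ordinary real `frule` for `log` is untouched; only the additional `Wirtinger` methods know about pairs. All names here (`Wirtinger`, `remove_wirtinger`, and the locally defined `frule` methods) are hypothetical stand-ins, not an existing API.

```julia
# Hypothetical stand-ins for the proposed wrapper.
struct Wirtinger{T}
    value::T
end
remove_wirtinger(z::Wirtinger) = z.value

# Ordinary real rule: knows nothing about Wirtinger derivatives.
frule((_, dr), ::typeof(log), r::Real) = (log(r), dr / r)

# Wirtinger-aware variant: the real derivative 1/r scales both entries of the pair.
function frule(tangents, ::typeof(log), r::Wirtinger{<:Real})
    _, (dr_dw, dr_dwb) = tangents
    rv = remove_wirtinger(r)
    return log(rv), (dr_dw / rv, dr_dwb / rv)
end

# abs is not holomorphic: d|z|/dz = conj(z)/(2|z|) and d|z|/dzb = z/(2|z|).
function frule(tangents, ::typeof(abs), z::Wirtinger{<:Complex})
    _, (dz_dw, dz_dwb) = tangents
    zv = remove_wirtinger(z)
    r = abs(zv)
    da_dz, da_dzb = conj(zv) / (2 * r), zv / (2 * r)
    # General Wirtinger chain rule; dzb_dw = conj(dz_dwb) and dzb_dwb = conj(dz_dw).
    return r, (da_dz * dz_dw + da_dzb * conj(dz_dwb),
               da_dz * dz_dwb + da_dzb * conj(dz_dw))
end

# Differentiate log(|z|) w.r.t. z itself: seed with dz_dz = 1, dz_dzb = 0.
z = 1.0 + 2.0im
r, dr = frule((nothing, (1.0 + 0.0im, 0.0 + 0.0im)), abs, Wirtinger(z))
y, (dy_dz, dy_dzb) = frule((nothing, dr), log, Wirtinger(r))
```

Analytically, $\partial \log|z| / \partial z = 1/(2z)$ and $\partial \log|z| / \partial \bar z = 1/(2\bar z)$, which `dy_dz` and `dy_dzb` should match up to rounding.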