ChainRules: Replacing DiffRules in the Julia AD world

If a C→C differentiation rule exists, then the function is holomorphic, so this information would be available in either case.

A @holomorphic annotation might be useful for functions that lack a differentiation rule but are still known to be holomorphic, though…

Hmm…that doesn’t sound right to me, but maybe I misunderstand you? For example, IIUC, the forward C → C rule for the non-holomorphic function conj can currently be written:

forward_rule(::@domain({C → C}), ::typeof(conj), x) = conj(x), ẋ -> (0, 1)

(ignore the churn on the domain signature syntax, I’ve been playing with it a bit)

Note that we’re using Wirtinger derivatives, since they have a bunch of nice properties for defining complex primitives.
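
For reference, with $z = x + iy$ the Wirtinger derivatives are

$$
\frac{\partial}{\partial z} = \frac{1}{2}\left(\frac{\partial}{\partial x} - i\,\frac{\partial}{\partial y}\right), \qquad
\frac{\partial}{\partial \bar{z}} = \frac{1}{2}\left(\frac{\partial}{\partial x} + i\,\frac{\partial}{\partial y}\right),
$$

and $f$ is holomorphic exactly when $\partial f/\partial \bar{z} = 0$, i.e. when the Cauchy–Riemann equations hold. For conj, $\partial \bar{z}/\partial z = 0$ and $\partial \bar{z}/\partial \bar{z} = 1$, which is the (0, 1) tuple in the rule above.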

OK. I guess I was using a very narrow definition of “derivative”. Yes, it definitely makes sense to define differentiation rules for non-holomorphic complex-to-complex functions like conj(x) as well. In that case the derivative must be a two-tuple of complex numbers, like in your example.

Still, if a rule like

@forward_rule(C → C, sin(x), cos(x))

is given (where the derivative is not a two-tuple) then the function could be assumed to be holomorphic (and the remaining elements of the derivative given by the Cauchy–Riemann equations).

And that’s where the proposed @holomorphic macro would come in handy.
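
To make that concrete, here’s a rough sketch of the expansion such a rule could perform (names purely illustrative, not actual ChainRules API):

# Given a single derivative df for a function assumed holomorphic, fill in
# the Wirtinger pair: the ∂/∂z̄ component is zero by Cauchy–Riemann.
holomorphic_wirtinger(df) = x -> (df(x), 0)   # x -> (∂f/∂z, ∂f/∂z̄)

dsin = holomorphic_wirtinger(cos)
dsin(1.0 + 2.0im)   # (cos(1 + 2im), 0)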

In general, I guess one difficulty when dealing with packages like ChainRules.jl is that in math every real variable is also a complex one. But in Julia, Real is not a subtype of Complex, so interfaces built for Complex types do not directly extend to Real.


At the risk of being mildly annoying: what’s the situation regarding documentation? Might we hope to see some before Christmas?

Reasonable question! Documentation definitely will arrive before Christmas.

After playing around with ChainRules in practice a bit for my own use cases, I ended up tweaking the design in a few substantial ways (some of which were based on the discussion here).

I’ve been working off of a dev branch for that, but I’d like to make a PR to master either today or tomorrow. Then I’ll start turning a lot of the code comments into docstrings. I’ll be at NeurIPS next week, but I’ll try to make some progress regardless.


For those following along: refactor chaining code with special propagator algebra by jrevels · Pull Request #2 · JuliaDiff/ChainRules.jl · GitHub


Nice. In the PR you mention maybe wanting to flip and make immediate materialisation the default. Could you elaborate a bit on why this might be a good move?

Both “eager” materialization and “delayed” materialization could be useful in different circumstances, so both will still be available. It’s just a question of which behavior serves as the more convenient default for downstream usage.

AFAICT, the main benefit to “delayed” materialization is that it enables natural fusion “across” multiple chain rule invocations. How often that optimization will be relevant/possible in downstream tools, though, remains to be seen.

Eager materialization, OTOH, is both conceptually and practically simpler API-wise. For exploratory/interactive/prototyping use, it’s annoying to type materialize all the time 😛
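
For a flavor of the fusion argument, here’s a minimal sketch using plain Base.Broadcast machinery (my own illustration of the general idea, not ChainRules code):

using Base.Broadcast: broadcasted, materialize

x = rand(1000)
d1 = broadcasted(cos, x)       # delayed: no result array allocated yet
d2 = broadcasted(*, d1, 2.0)   # a later operation fuses “across” the first
y  = materialize(d2)           # one fused loop, one allocation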


AFAICT, the main benefit to “delayed” materialization is that it enables natural fusion “across” multiple chain rule invocations.

The other reason to like delayed materialisation is that it makes it possible to completely avoid unnecessary computation. E.g. given a reverse rule for a binary function, if the adjoint w.r.t. only one of the arguments is required, you only ever materialise the required adjoint, thereby avoiding any computation associated with the other. This is a real win for “fat” nodes, e.g. matrix multiplication, backsolving, etc.
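
As a concrete sketch of what I mean (hypothetical names and signatures, not the real ChainRules API):

using LinearAlgebra

# Minimal stand-in for a delayed computation; not the real ChainRules.Thunk.
struct Thunk{F}
    f::F
end
materialize(t::Thunk) = t.f()

# Hypothetical pullback for C = A * B: each adjoint is wrapped in a thunk,
# so the matmul for an argument you never ask for is never performed.
function mul_pullback(A, B, C̄)
    Ā = Thunk(() -> C̄ * B')   # runs only if materialized
    B̄ = Thunk(() -> A' * C̄)   # runs only if materialized
    return Ā, B̄
end

A, B = rand(100, 100), rand(100, 100)
Ā, B̄ = mul_pullback(A, B, ones(100, 100))
materialize(Ā)   # one matmul; the one for B̄ never runs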

Eager materialization, OTOH, is both conceptually and practically simpler API-wise. For exploratory/interactive/prototyping use…

Agreed, but I imagine that the majority of the usage of this package will be from people who are writing AD packages, so it’s not clear to me how much of a win this convenience would be. I’m more than happy to be proven incorrect though…


On some occasions I’ve defined show to perform some computation; in this case, that trick could correspond to materializing on display, so you avoid having to type materialize for interactive use. Hopefully I’m not misunderstanding your motivation…

Sorry, I think my earlier post wasn’t clear enough - ChainRules avoids this kind of unnecessary computation regardless of the delayed vs. eager materialization choice I’m referring to. For example:

julia> using ChainRules: rrule, Zero, materialize, MaterializeInto

julia> using LinearAlgebra

julia> x, y = rand(3), rand(3);

julia> dotxy, (dx, dy) = rrule(LinearAlgebra.dot, x, y)
(1.2487877937850473, (getfield(ChainRules, Symbol("##720#727")){Array{Float64,1}}([0.163892, 0.817936, 0.427959]), getfield(ChainRules, Symbol("##722#729")){Array{Float64,1}}([0.694594, 0.881439, 0.967356])))

julia> dx(Zero(), 1.0)
ChainRules.Thunk{getfield(ChainRules, Symbol("##25#26")){ChainRules.Thunk{getfield(ChainRules, Symbol("##721#728")){Array{Float64,1}}},Float64}}(getfield(ChainRules, Symbol("##25#26")){ChainRules.Thunk{getfield(ChainRules, Symbol("##721#728")){Array{Float64,1}}},Float64}(ChainRules.Thunk{getfield(ChainRules, Symbol("##721#728")){Array{Float64,1}}}(getfield(ChainRules, Symbol("##721#728")){Array{Float64,1}}([0.163892, 0.817936, 0.427959])), 1.0))

# still going back and forth on the `Bundle` stuff; 
# this might just return e.g. the array in the future
julia> materialize(ans)
ChainRules.Bundle{Array{Float64,1}}([0.163892, 0.817936, 0.427959])

julia> x̄ = zeros(3)
3-element Array{Float64,1}:
 0.0
 0.0
 0.0

julia> dx(MaterializeInto(x̄), 1.0);

julia> x̄
3-element Array{Float64,1}:
 0.1638921314771622
 0.8179363645195881
 0.42795894697385894

Note that I’m never calling dy here, so I’m never “computing” the gradient w.r.t. y (in hindsight, dot was maybe a silly function for this example, but you get the point). Furthermore, since the closure dy is not being used, its lifetime can be as short as the caller needs it to be, thus allowing it to be GC’d/free’d basically instantly if the caller doesn’t need to keep it around. All of this is the case regardless of when we choose to materialize results from dx.

Maybe there are some downstream effects like that for delayed materialization as well, but I haven’t run into any such effects yet in a way that doesn’t already naturally arise from ChainRules’s closure-based API.


Ohhhhh yeah, you’re totally right. Thanks!

Will there be a way to test for linear functions with this? (That is: test whether the derivative is a constant, or alternatively that the second derivative is zero everywhere.)

In particular, I’d like a test that does not say that abs is a linear function, even though I can evaluate its second derivative in a billion random points and get zero every time.

Bonus points if it can detect that x^2 / x is linear, but false negatives are really less of a problem.

For this kind of thing you might want to check out IntervalArithmetic.jl, which calculates with sets instead of single numbers.


Here’s an example:

julia> using IntervalArithmetic, ForwardDiff

julia> f(x) = 3x
f (generic function with 1 method)

julia> xx = 3..4   # make an interval
[3, 4]

julia> f(xx)
[9, 12]

julia> ForwardDiff.derivative(f, xx)
[3, 3]

This gives an interval that is guaranteed to enclose the set of possible values of the derivative over the interval.

Sometimes it works:

julia> g(x) = 3x - 2x
g (generic function with 1 method)

julia> ForwardDiff.derivative(g, xx)
[1, 1]

But

julia> h(x) = x^2 / x
h (generic function with 1 method)

julia> h(xx)
[2.25, 5.33334]

julia> ForwardDiff.derivative(h, xx)
[-0.277778, 2.10417]

Here, it overestimated the range of the derivative.

Also, abs does not currently work properly (although this is easily fixed).


Also: sign, round, etc…

Yes, indeed. But there it is not even clear what the answer should be for an Interval that contains a point of discontinuity.

Also, it would be possible to extend the “decoration” mechanism in IntervalArithmetic.jl to prove that a function is, for example, differentiable over an interval.
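
For instance, decorated intervals already track definedness and continuity (my example; the decoration names follow the IEEE 1788 standard, so treat the exact API as an assumption to check against the IntervalArithmetic.jl docs):

using IntervalArithmetic

# A decorated interval carries a flag recording what is known about the
# function on that interval, from `com` (defined, continuous, bounded)
# down to `trv` (nothing guaranteed).
x = DecoratedInterval(3, 4)
decoration(sqrt(x))                         # com: sqrt is defined and continuous on [3, 4]
decoration(sqrt(DecoratedInterval(-1, 1)))  # trv: sqrt is not defined on all of [-1, 1]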

That sounds like a good idea, and it would make it possible to detect e.g. linear functions.

Here’s an idea:

One could introduce a new function ImpulseTrain(F, x) to represent derivatives of discontinuous functions. When called as a function, it would simply return zero(x), but for quadrature purposes it would keep track of its primitive function F.

There would be a bunch of differentiation rules like @rule(sign(x), ImpulseTrain(sign, x)), and for higher order derivatives, the rule would be something like

@rule(ImpulseTrain(F, x), (DNE(), ImpulseTrain(y -> ImpulseTrain(F, y), x)))

For interval arithmetic, one would check if the upper and lower bounds on F(x) are identical. If so, then ImpulseTrain(F, x) is zero everywhere on the interval. Otherwise the interval contains an impulse, which could maybe be reflected by returning NaN as bounds.
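
A rough sketch of that idea (all names hypothetical; the endpoint check assumes the primitive is monotone, as it is for e.g. sign):

# Represents the derivative of a discontinuous function: pointwise zero,
# but remembering the primitive F for quadrature/interval queries.
struct ImpulseTrain{T}
    F::T
end
(it::ImpulseTrain)(x) = zero(x)

# On [lo, hi]: if F does not change, the impulse train is identically zero
# there; otherwise the interval contains an impulse, signalled here by NaN.
function interval_bounds(it::ImpulseTrain, lo, hi)
    it.F(lo) == it.F(hi) ? (0.0, 0.0) : (NaN, NaN)
end

d_sign = ImpulseTrain(sign)        # derivative of sign
interval_bounds(d_sign, 1.0, 2.0)  # (0.0, 0.0): no discontinuity in [1, 2]
interval_bounds(d_sign, -1.0, 1.0) # (NaN, NaN): impulse at 0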