Flux question: best way to write a function that reuses some terms

I have a question about the best (most computationally efficient) way to write functions that reuse some terms, so that Flux can differentiate them automatically.

The particular example I give is motivated by the paper Jauch et al., “Random orthogonal matrices and the Cayley transform”, but I’m interested in the question more generally.

The above paper parameterises the Stiefel manifold as Q = [Q1' Q2']' (i.e. Q1 stacked on top of Q2), where
Q1 = (I - A'*A + (B - B')) / (I + A'*A - (B - B'))
Q2 = 2A / (I + A'*A - (B - B'))
and the matrices A and B are the "free" parameters with respect to which derivatives are to be calculated.

I can just write the Julia code as above, or I could "signpost" more clearly that some terms are reused, e.g.

X = A'*A - (B - B')
Y = I + X
Q1 = (I - X) / Y
Q2 = 2A / Y
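For reference, here is a minimal runnable sketch of both versions, using only the LinearAlgebra standard library. The function names `cayley_direct` and `cayley_shared` and the matrix sizes are hypothetical choices for illustration. Julia's `/` is right division, so `(I - X) / Y` computes `(I - X) * inv(Y)` via a solve rather than by forming the inverse.

```julia
using LinearAlgebra

# Transcription of the formulas as first written: A'*A, B - B' and the
# denominator matrix are each built twice.
function cayley_direct(A, B)
    Q1 = (I - A'*A + (B - B')) / (I + A'*A - (B - B'))
    Q2 = 2A / (I + A'*A - (B - B'))
    return [Q1; Q2]                  # stack Q1 on top of Q2
end

# Same map with the shared terms bound to names, so X and Y are built once.
function cayley_shared(A, B)
    X = A'*A - (B - B')
    Y = I + X
    Q1 = (I - X) / Y
    Q2 = 2A / Y
    return [Q1; Q2]
end

A = randn(6, 3)                      # (n - p) × p, here n = 9, p = 3
B = randn(3, 3)                      # p × p
Q = cayley_shared(A, B)

@assert Q ≈ cayley_direct(A, B)      # both versions give the same matrix
@assert Q' * Q ≈ Matrix(I, 3, 3)     # orthonormal columns, i.e. a point on the Stiefel manifold
```

The orthonormality check should pass for any A and B: the symmetric part of Y is I + A'A, which is positive definite, so Y is always invertible.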

My question is: would the latter, more verbose, representation "help" Flux compute derivatives more efficiently, or would it make no difference?

Thanks in advance!

Flux itself is unconcerned with how derivatives are generated, so this depends on the AD framework. To my knowledge, the default one (Zygote) doesn’t do much in the way of rewriting array operations or eliminating common subexpressions: it differentiates the code roughly as written, so a term that appears twice is computed, and differentiated, twice. As such, you’re probably better off extracting X manually.

One easy way to test empirically which is better is to @benchmark (from BenchmarkTools.jl) the two approaches. My hunch is that the second one will be faster, at least on the forward pass.
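A sketch of such a benchmark, assuming the Zygote and BenchmarkTools packages are installed. It times the forward pass and the gradient of a scalar loss, `sum(Q)`, which is a hypothetical loss chosen only so that `Zygote.gradient` has a scalar to differentiate; the helper names and sizes are also illustrative. No timings are shown because they depend on the machine and the matrix sizes.

```julia
using LinearAlgebra
using Zygote            # the AD backend Flux uses by default
using BenchmarkTools    # provides @btime / @benchmark

# Formulas as written: shared terms are rebuilt (and differentiated) twice.
function cayley_direct(A, B)
    Q1 = (I - A'*A + (B - B')) / (I + A'*A - (B - B'))
    Q2 = 2A / (I + A'*A - (B - B'))
    return vcat(Q1, Q2)
end

# Shared terms bound once, so forward and backward passes reuse them.
function cayley_shared(A, B)
    X = A'*A - (B - B')
    Y = I + X
    return vcat((I - X) / Y, 2A / Y)
end

# Hypothetical scalar losses used only for benchmarking.
loss_direct(A, B) = sum(cayley_direct(A, B))
loss_shared(A, B) = sum(cayley_shared(A, B))

A = randn(50, 10)
B = randn(10, 10)

# Both versions should yield the same gradients with respect to A and B.
gA1, gB1 = Zygote.gradient(loss_direct, A, B)
gA2, gB2 = Zygote.gradient(loss_shared, A, B)
@assert gA1 ≈ gA2 && gB1 ≈ gB2

# Forward pass only ($ interpolates globals so they are not timed as untyped).
@btime cayley_direct($A, $B);
@btime cayley_shared($A, $B);

# Forward plus backward pass.
@btime Zygote.gradient(loss_direct, $A, $B);
@btime Zygote.gradient(loss_shared, $A, $B);
```

Even in `cayley_shared` the two right divisions factorize Y separately; an untested further refinement would be a single solve, `vcat(I - X, 2A) / Y`, which yields the same stacked Q because right division acts row by row.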
