Zygote Performance

FWIW, JAX does that by default with the `@jarrett` decorator, named after the one and only Jarrett.
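A rough sketch of the general idea, assuming the "Jarrett trick" here means the usual one: a broadcasted (elementwise) function has a diagonal Jacobian, so a single forward-mode pass over the array, with each element carrying its own dual number seeded with tangent 1.0, yields the whole diagonal. The VJP is then just the cotangent times that diagonal. This is plain Python with a hypothetical `Dual` class and `elementwise_vjp` helper, not JAX's or Zygote's actual implementation.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """Forward-mode dual number: a value plus a single tangent component."""
    val: float
    tan: float

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule: d(u*v) = u'*v + u*v'
        return Dual(self.val * other.val, self.tan * other.val + self.val * other.tan)

    __rmul__ = __mul__


def _lift(x):
    # Plain numbers are treated as constants (zero tangent).
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)


def dsin(x):
    # sin lifted to dual numbers (hypothetical helper for this sketch).
    return Dual(math.sin(x.val), math.cos(x.val) * x.tan)


def elementwise_vjp(f, xs, cotangent):
    """VJP of ys = [f(x) for x in xs], computed with forward mode.

    Because the Jacobian of an elementwise map is diagonal, seeding every
    element's tangent with 1.0 gives all diagonal entries in one pass; the
    VJP is the elementwise product cotangent * diag(J).
    """
    ys, grads = [], []
    for x, ct in zip(xs, cotangent):
        out = f(Dual(float(x), 1.0))
        ys.append(out.val)
        grads.append(ct * out.tan)
    return ys, grads
```

For example, with `f = lambda x: x * dsin(x)`, `elementwise_vjp(f, [0.0, 1.0], [1.0, 2.0])` returns the gradient `[0.0, 2 * (sin(1) + cos(1))]`, matching the analytic derivative `sin(x) + x*cos(x)` scaled by the cotangent.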

We might as well use the true Jarrett library for this in Zygote 🙂


Oh, that's interesting. One problem with this approach is that it is not very efficient when the variables involved in the differentiation are much fewer than the arguments of the broadcasted function. Ideally, you'd want to mark the arguments not involved in the differentiation as constants (ChainCutters.jl is my quick-and-dirty solution to this). Do you know if/how this is handled in JAX?
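To make the cost concrete, here is a small plain-Python sketch (not ChainCutters.jl's API; `MultiDual`, `seed` and `broadcast_partials` are hypothetical names) of a broadcasted `f(x, a, b, c)` differentiated only with respect to `x`. If every argument is seeded as a dual number, each element carries one partial per argument (4 here), so every arithmetic operation does 4x the tangent work. Marking `a`, `b` and `c` as constants leaves them as plain floats, and each dual carries only the single partial we actually need.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class MultiDual:
    """Dual number carrying one partial derivative per *active* input."""
    val: float
    partials: Tuple[float, ...]

    def __add__(self, other):
        o = lift(other, len(self.partials))
        return MultiDual(self.val + o.val,
                         tuple(p + q for p, q in zip(self.partials, o.partials)))

    __radd__ = __add__

    def __mul__(self, other):
        o = lift(other, len(self.partials))
        # Product rule applied to every partial: d(s*o) = s*do + o*ds
        return MultiDual(self.val * o.val,
                         tuple(self.val * dp + o.val * sp
                               for sp, dp in zip(self.partials, o.partials)))

    __rmul__ = __mul__


def lift(x, n):
    # Constants become duals with n zero partials only when combined with a dual.
    return x if isinstance(x, MultiDual) else MultiDual(float(x), (0.0,) * n)


def seed(args, active):
    """Wrap the arguments at indices in `active` as duals; leave the rest as constants."""
    n = len(active)
    wrapped = []
    for i, a in enumerate(args):
        if i in active:
            p = [0.0] * n
            p[active.index(i)] = 1.0
            wrapped.append(MultiDual(float(a), tuple(p)))
        else:
            wrapped.append(a)  # constant: carries no partials at all
    return wrapped


def broadcast_partials(f, arrays, active):
    """Apply f elementwise over the zipped arrays; return each element's partials."""
    return [f(*seed(elems, active)).partials for elems in zip(*arrays)]
```

For `f = lambda x, a, b, c: a * x * x + b * x + c`, seeding all four arguments gives partials like `(16.0, 4.0, 2.0, 1.0)` per element, three of which are thrown away, while `active=[0]` gives just `(16.0,)`, i.e. `2*a*x + b` at `x=2, a=3, b=4`.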

I opened:

https://github.com/FluxML/Zygote.jl/issues/336
