FWIW, JAX does that by default with the @jarrett
macro, named after the one and only Jarrett.
We might as well use the true Jarrett library for this in Zygote.
Oh, that’s interesting. One problem with this approach is that it is not very efficient when the variables involved in the differentiation are far fewer than the arguments of the broadcasted function. Ideally, you’d want to mark the arguments not involved in the differentiation as constants (ChainCutters.jl is my quick-and-dirty solution to this). Do you know if/how it’s handled in JAX?
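
To make the cost concern concrete, here is a rough pure-Python sketch of forward-mode broadcasting with dual numbers. It is only an illustration under my own assumptions, not how Zygote, JAX, or ChainCutters.jl actually implement it; `Dual`, `lift`, `broadcast_with_partials`, and `diff_argnums` are all made-up names. The point is that each dual number carries one partial per differentiated argument, so passing the other arguments as plain constants keeps the partials tuple short.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """A value plus one partial derivative per differentiated input (hypothetical helper)."""
    val: float
    partials: tuple

    def __add__(self, other):
        other = lift(other, len(self.partials))
        return Dual(self.val + other.val,
                    tuple(p + q for p, q in zip(self.partials, other.partials)))

    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other, len(self.partials))
        # Product rule, applied to each partial slot independently.
        return Dual(self.val * other.val,
                    tuple(self.val * q + other.val * p
                          for p, q in zip(self.partials, other.partials)))

    __rmul__ = __mul__


def lift(x, n):
    """Treat a plain number as a constant: a Dual whose n partials are all zero."""
    return x if isinstance(x, Dual) else Dual(float(x), (0.0,) * n)


def broadcast_with_partials(f, columns, diff_argnums):
    """Apply f elementwise over `columns`, seeding duals only for `diff_argnums`.

    Arguments not listed in `diff_argnums` stay plain numbers, i.e. constants.
    """
    n = len(diff_argnums)
    values, partials = [], []
    for row in zip(*columns):
        inputs = list(row)
        for k, i in enumerate(diff_argnums):
            seed = tuple(1.0 if j == k else 0.0 for j in range(n))
            inputs[i] = Dual(float(row[i]), seed)
        out = lift(f(*inputs), n)
        values.append(out.val)
        partials.append(out.partials)
    return values, partials


def f(x, a, b):
    return a * x * x + b


xs, as_, bs = [1.0, 2.0, 3.0], [2.0, 2.0, 2.0], [1.0, 1.0, 1.0]

# Only x is differentiated: one partial per element.
vals_x, parts_x = broadcast_with_partials(f, [xs, as_, bs], diff_argnums=(0,))

# Every argument is differentiated: three partials per element, even if
# only the derivative with respect to x is needed afterwards.
vals_all, parts_all = broadcast_with_partials(f, [xs, as_, bs], diff_argnums=(0, 1, 2))
```

In this toy version the work per arithmetic operation grows with the number of seeded arguments, which is the inefficiency I mean when most arguments are not actually being differentiated.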