I spend a lot of time with JAX, where I often have to use the “double where trick” to keep NaNs from propagating in the backward pass.
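For context, this is roughly what I mean by the double where trick in JAX. It is a minimal sketch: `naive_log` and `safe_log` are just illustrative names, and `log` stands in for any function whose derivative blows up on the inputs that get masked out.

```python
import jax
import jax.numpy as jnp

def naive_log(x):
    # Forward pass is fine: the masked-out branch is replaced by 0.0.
    # But the backward pass still differentiates log at x = 0, giving 0 * inf = NaN.
    return jnp.where(x > 0, jnp.log(x), 0.0)

def safe_log(x):
    # Inner where: swap in a harmless value (1.0) so log never sees x <= 0.
    safe_x = jnp.where(x > 0, x, 1.0)
    # Outer where: select the real result or the fallback, as before.
    return jnp.where(x > 0, jnp.log(safe_x), 0.0)

print(jax.grad(naive_log)(0.0))  # gradient is NaN
print(jax.grad(safe_log)(0.0))   # gradient is 0.0
print(jax.grad(safe_log)(2.0))   # gradient is 0.5, i.e. 1/x as expected
```

The single `where` only masks the forward value; the cotangent that reaches `log` is zero, but `log`'s derivative at 0 is infinite, and zero times infinity is NaN. The inner `where` avoids evaluating the bad derivative at all.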
The Zygote docs mention that arrays should not be mutated. So if I want to enforce some limiting values (because I expect some NaNs to pop up), I assume I need to use array-based logic, just as I do with JAX, rather than looping over each index and updating the value based on some conditional.
Are there any idiomatic tricks with Zygote in this context? Should I also implement a “double where” to keep the backward pass clean? Or is the compiler smart enough to handle this on its own?
Thanks!