Julia's Broadcast vs Jax's vmap

I really like being able to explicitly annotate the broadcasted dimensions. The permissiveness of Julia’s broadcast is very useful, but it’d be really cool to have an explicit mode like Jax’s vmap.

We do need a dedicated struct for eachslice. Once we have that, we could totally add a broadcast “rule” that transforms a broadcasted dot over an eachslice and a 0-dim container like Ref into a matvec. It’s quite a tiny peephole optimization, but I think it could be reasonable if that’s causing major issues.
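To make the equivalence behind that rule concrete, here is a minimal sketch. The matrix `A` and vector `x` are made-up illustration values, and no such rewrite rule is assumed to exist; the snippet only shows that the broadcasted form and the matvec compute the same thing.

```julia
using LinearAlgebra

# Hypothetical example data, not taken from the thread.
A = [1.0 2.0; 3.0 4.0; 5.0 6.0]
x = [10.0, 1.0]

# Broadcast dot over the rows of A; Ref(x) is a 0-dim container,
# so x is held fixed instead of being iterated over.
rowwise = dot.(eachslice(A; dims=1), Ref(x))

# The matvec that a peephole rule could substitute for the line above.
matvec = A * x

println(rowwise ≈ matvec)
```

The broadcasted version does one small dot product per row, while `A * x` can dispatch to a single optimized routine, which is why such a rewrite might be worth having.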

In any case, we also need to implement better broadcasting over generators, like the kind that eachslice currently returns.
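As a rough illustration of the current behavior (my reading, not necessarily the exact limitation meant above): broadcasting over a generic generator falls back to collecting it into an array first. The generator `gen` below is a made-up example.

```julia
# Hypothetical generator used only for illustration.
gen = (x^2 for x in 1:3)

# The generic fallback for iterators materializes them with collect
# before broadcasting, so this allocates an intermediate Vector.
materialized = Base.Broadcast.broadcastable(gen)

# Broadcasting still works, but only after that collection step.
result = gen .+ 1

println(typeof(materialized))
println(result)
```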

The mean performance issue looks to be exacerbated by a performance issue with mean itself, I think.

Thanks for putting this all together!
