Julia's Broadcast vs Jax's vmap

I think this line of questioning is getting a bit off topic, since the issues the OP brought up around vmap/Julia broadcasting don't relate to memory pressure.

But in any case, what you’re looking for can be done in JAX. Something like this should do the trick:

import jax.numpy as jnp
from jax import vmap

really_big_dataset = ...

aggregation = 0
for chunk in jnp.split(really_big_dataset, 50):
  # vmap your_func over just this chunk, so only one chunk's worth
  # of intermediate results is materialized at a time.
  chunk_result = vmap(your_func)(chunk)
  aggregation += some_kind_of_aggregation(chunk_result)

And you can swap that Python for loop out for a jax.lax.fori_loop for some extra speed if you'd like! This pattern is actually used in some of the JAX example code for training loops, etc. Or you could use jax.pmap to run vmap-batched computations in parallel across multiple devices and then aggregate them together. The possibilities are endless!
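To make the fori_loop variant concrete, here's a minimal runnable sketch. The function `your_func` and the dataset shape are hypothetical stand-ins (the original post elides them); I stack the chunks so the loop body can index into them dynamically:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example function; substitute your own.
def your_func(x):
    return jnp.sum(x ** 2)

# Hypothetical dataset: 100 examples of length 8.
really_big_dataset = jnp.arange(100 * 8, dtype=jnp.float32).reshape(100, 8)

# Split into 50 chunks and stack them, giving shape (50, 2, 8),
# so the loop body can index chunk i with a traced index.
chunks = jnp.stack(jnp.split(really_big_dataset, 50))

def body(i, aggregation):
    # vmap your_func over the examples in chunk i, then aggregate
    # (a plain sum here, standing in for some_kind_of_aggregation).
    chunk_result = jax.vmap(your_func)(chunks[i])
    return aggregation + jnp.sum(chunk_result)

total = jax.lax.fori_loop(0, chunks.shape[0], body, 0.0)
```

Unlike the Python for loop, fori_loop traces the body once rather than unrolling it 50 times, which keeps compile times down when the whole thing is wrapped in jit.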
