Mostly. Just broadcasting won't work as well as `vmap`, however, because some of the operations being broadcasted are already vectorized (e.g. BLAS). `vmap` will actually modify those calls (using dispatch in PyTorch and source code transforms in JAX) to use batched implementations whenever it encounters them.
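To make the difference concrete, here is a minimal sketch (the sizes and names are just illustrative): broadcasting a dense layer over a vector of samples issues one small matrix-vector product per sample, while a single batched call lets BLAS do one big matrix-matrix product. Plain broadcasting can't rewrite the per-sample calls the way `vmap` does.

```julia
using Flux

layer = Dense(128 => 64)
samples = [rand(Float32, 128) for _ in 1:256]  # 256 individual inputs
X = reduce(hcat, samples)                      # the same data as a 128×256 batch

ys_broadcast = layer.(samples)  # 256 separate matrix-vector (gemv-sized) calls
Y_batched = layer(X)            # one matrix-matrix (gemm) call over the whole batch

# The results agree column-for-column; only the batched call exploits the
# vectorized BLAS implementation.
@assert all(ys_broadcast[i] ≈ Y_batched[:, i] for i in eachindex(samples))
```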
Now that said, some functions and Flux layers are flexible enough to already work for this without a `vmap`-like treatment. See this topic posted about a month ago: Flux loss with contribution gradient is slow - #5 by Jonas208. Basically, changing your loss function to compute a loss for each sample individually and then summing should be enough for an MLP. In fact, `mean_batch_grad` and `map_grad` currently return the exact same gradients because `sum(map(x -> sum(model(x)), xs)) == sum(model(xs))`!
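A hedged sketch of that last point (`mean_batch_grad`/`map_grad` refer to the linked thread; the helpers below are just stand-ins I made up for illustration): for an MLP, summing per-sample losses gives the same value, and therefore the same gradients, as taking the loss over the whole batch at once.

```julia
using Flux

model = Chain(Dense(4 => 8, relu), Dense(8 => 1))
xs = [rand(Float32, 4) for _ in 1:32]  # individual samples
X = reduce(hcat, xs)                   # the same samples as one 4×32 batch

per_sample_loss(m) = sum(map(x -> sum(m(x)), xs))  # loss per sample, then summed
batched_loss(m) = sum(m(X))                        # loss over the whole batch

@assert per_sample_loss(model) ≈ batched_loss(model)

# The gradients agree as well, so the per-sample formulation is a drop-in change.
g_per_sample = Flux.gradient(per_sample_loss, model)
g_batched = Flux.gradient(batched_loss, model)
```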