Speeding up per-sample gradients?

Mostly. Plain broadcasting won’t work as well as vmap, however, because some of the operations being broadcast are already vectorized (e.g. BLAS calls). vmap will actually rewrite those calls (using dispatch in PyTorch and trace-based function transforms in JAX) to use batched implementations whenever it encounters them.
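
To make the BLAS point concrete, here is a minimal Julia sketch (the sizes and names are made up, not from the thread): mapping a matrix-vector product over samples issues many small, already-vectorized GEMV calls, while the batched form is a single GEMM over the whole batch, which is the kind of rewrite a vmap-style transform does for you.

```julia
using LinearAlgebra

W  = rand(Float32, 128, 64)                 # one layer's weight matrix
xs = [rand(Float32, 64) for _ in 1:256]     # 256 individual samples
X  = reduce(hcat, xs)                       # the same samples as a 64×256 batch

# Mapping over samples: 256 small GEMV calls, each already vectorized
per_sample = map(x -> W * x, xs)

# Batched: a single GEMM over the whole batch (the form a vmap-like
# transform would rewrite the per-sample version into)
batched = W * X

# Same numbers either way, column for column
@assert all(per_sample[i] ≈ batched[:, i] for i in eachindex(xs))
```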

Now that said, some functions and Flux layers are flexible enough that this already works without a vmap-like treatment. See this topic posted about a month ago: Flux loss with contribution gradient is slow - #5 by Jonas208. Basically, changing your loss function to compute a loss for each sample individually and then summing should be enough for an MLP. In fact, mean_batch_grad and map_grad currently return the exact same gradients, because sum(map(x -> sum(model(x)), xs)) == sum(model(xs))! A rough sketch of both approaches is below.
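
As an illustration of that claim (a toy MLP and made-up data; these are not the mean_batch_grad/map_grad definitions from the linked thread), here is roughly what the map-over-samples and batched versions look like with explicit Flux gradients:

```julia
using Flux

# Toy model and data, purely illustrative
model = Chain(Dense(4 => 16, relu), Dense(16 => 1))
xs = [rand(Float32, 4) for _ in 1:8]   # one vector per sample
X  = reduce(hcat, xs)                  # the same samples as a 4×8 batch matrix

# Per-sample gradients: map the gradient over individual samples
per_sample_grads = map(x -> Flux.gradient(m -> sum(m(x)), model)[1], xs)

# Batched gradient: a single backward pass over the whole batch
batch_grad = Flux.gradient(m -> sum(m(X)), model)[1]

# Because the per-sample losses sum to the batch loss for an MLP,
# summing the per-sample gradients recovers the batched gradient
# (checked here on the first Dense layer's weight).
@assert sum(g.layers[1].weight for g in per_sample_grads) ≈ batch_grad.layers[1].weight
```

The map version hands you one gradient per sample, which is the whole point of per-sample gradients; the batched version is what you would use for ordinary training, and the two agree once you sum over the batch.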