Julian methods for batched computations without loops

Hi all,
I’m fairly new to Julia, coming from Python + JAX. As such, I’ve become accustomed to writing vectorized expressions to avoid expensive for loops; however, I understand that loops aren’t a performance concern in Julia. I do like functional programming in general, though, so I have really enjoyed using JAX’s vmap to express such computations. I have been trying to write my Julia code in a similar fashion, but occasionally I get stuck and it feels like I’m fighting Julia just to write this way, which of course I want to avoid.

This afternoon I found what I think is a nice solution to a problem I was working on. I have a batch of N vectors of dimension A, represented by a Matrix m of size (A, N). For each vector, I want to compute the softmax and then sample from a categorical distribution whose category probabilities are given by that softmax. A first attempt was something like

using Distributions
using Random

[rand(Categorical(p)) for p in eachcol(m)]

but this doesn’t work – I get an exception like

 MethodError: Cannot `convert` an object of type Vector{Float64} to an object of type SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}

Eventually I realized I could fix this with

[rand(Categorical(collect(p))) for p in eachcol(m)]

but this is fairly hard to read. Finally, I settled on the following:

map(rand ∘ Categorical ∘ collect, eachcol(m))

which I find much easier to read. However, I rarely see code like this written in Julia, and I’m concerned that there are consequences of this code that I’m unaware of.

How would you write this code, and is there anything outright wrong with the way I wrote it?
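For reference, here is a minimal sketch of the full pipeline described at the top (softmax per column, then one categorical draw per column) in the same composition style. It assumes a hand-written `softmax` and a hypothetical `sample_categorical` helper that does inverse-CDF sampling with only the standard library, standing in for `rand(Categorical(p))` so the snippet runs without Distributions.jl; `logits`, `A` and `N` are illustrative.

```julia
using Random

# Numerically stable softmax: subtracting the maximum avoids overflow in exp
# and leaves the result unchanged. Returns a new Vector, even for a view input.
function softmax(x::AbstractVector{<:Real})
    e = exp.(x .- maximum(x))
    return e ./ sum(e)
end

# Hypothetical stand-in for rand(Categorical(p)): inverse-CDF sampling,
# treating p as (possibly unnormalized) category weights.
function sample_categorical(p::AbstractVector{<:Real})
    u = rand() * sum(p)
    c = zero(float(eltype(p)))
    for (i, pk) in enumerate(p)
        c += pk
        u < c && return i
    end
    return length(p)  # fallback, e.g. for floating-point round-off
end

A, N = 4, 5
logits = randn(A, N)  # N logit vectors of dimension A, one per column

# Softmax each column, then draw one category index per column.
samples = map(sample_categorical ∘ softmax, eachcol(logits))
```

Because this `softmax` returns a freshly allocated `Vector`, no `collect` is needed before sampling. Assuming `Categorical` accepts a plain `Vector{Float64}` (as the `collect` workaround above suggests), `map(rand ∘ Categorical ∘ softmax, eachcol(logits))` should work the same way with Distributions.jl.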

I think it’s a bug that Categorical doesn’t accept the columns returned by eachcol(). They are views, but still AbstractVectors:

julia> m
3×3 Matrix{Float64}:
 0.108003  0.398417  0.308719
 0.848307  0.167795  0.367462
 0.04369   0.433788  0.323819

julia> first(eachcol(m))
3-element view(::Matrix{Float64}, :, 1) with eltype Float64:
 0.10800316713033577
 0.848306784286756
 0.04369004858290827

julia> first(eachcol(m)) isa AbstractVector
true
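The difference between a column view and its collected copy can be checked with the standard library alone. The last part is a hypothetical reproduction of how a `Cannot convert ... to ... SubArray` error can arise: indexing a view with a vector of indices returns a plain `Vector`, which can't be stored in a field whose type was fixed to the view's type. `Holder` is made up for illustration, and whether `Categorical`'s constructor fails for exactly this reason is an assumption.

```julia
m = [0.108003 0.398417 0.308719;
     0.848307 0.167795 0.367462;
     0.04369  0.433788 0.323819]

col = first(eachcol(m))
@show typeof(col)                 # a SubArray: a view into m, not a copy
@show col isa AbstractVector      # true, so generic AbstractVector code accepts it

v = collect(col)                  # copies the view into a fresh Vector{Float64}
@show typeof(v)

# Indexing a view with a vector of indices returns a plain Vector, not a view ...
reordered = col[[3, 1, 2]]
@show typeof(reordered)

# ... so storing it in a field whose type is fixed to the view's type fails
# with a MethodError (hypothetical struct, for illustration only).
struct Holder{Ps<:AbstractVector}
    ps::Ps
end
err = try
    Holder{typeof(col)}(reordered)
    nothing
catch e
    e
end
@show err isa MethodError
```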

If you don’t like jamming stuff into map, maybe try a do-block:

julia> map(eachcol(m)) do col
           vec = collect(col)
           rand(Categorical(vec))
       end
3-element Vector{Int64}:
 2
 1
 1

Performance-wise this isn’t really an issue, although at that point people would probably define their own named function instead of composing inside map().
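A sketch of the named-function approach, assuming a hypothetical `sample_categorical` helper (inverse-CDF sampling from the standard library) in place of Distributions.jl; with Distributions.jl loaded, the per-column call could instead be `rand(rng, Categorical(collect(col)))`. `sample_columns` is likewise a made-up name.

```julia
using Random

# Hypothetical stand-in for rand(rng, Categorical(p)): inverse-CDF sampling,
# treating p as (possibly unnormalized) category weights.
function sample_categorical(rng::AbstractRNG, p::AbstractVector{<:Real})
    u = rand(rng) * sum(p)
    c = zero(float(eltype(p)))
    for (i, pk) in enumerate(p)
        c += pk
        u < c && return i
    end
    return length(p)  # fallback, e.g. for floating-point round-off
end

# Hypothetical named function: one draw per column of m.
sample_columns(rng::AbstractRNG, m::AbstractMatrix) =
    map(col -> sample_categorical(rng, col), eachcol(m))
sample_columns(m::AbstractMatrix) = sample_columns(Random.default_rng(), m)

m = [0.2 0.0;
     0.8 1.0]
idx = sample_columns(MersenneTwister(42), m)
```

Passing the RNG explicitly is a design choice that makes the draws reproducible, which is convenient when testing batched code.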

@jling Thanks for the response!

I actually don’t mind jamming stuff into map; my map solution looks like a (nicer) version of how I’d use vmap in JAX.

Another attempt I tried this morning was

eachcol(m) .|> collect .|> Categorical .|> rand

which is fairly neat IMO; however, the map solution I posted appears to be almost twice as fast according to @time in the REPL.
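Both spellings should do the same per-column work: the dotted `.|>` calls fuse into a single broadcast over the columns, so no intermediate arrays are built between stages. The sketch below checks that they produce identical draws from the same seed, using a hypothetical stdlib `sample_categorical` in place of `rand ∘ Categorical`. A likely explanation for the 2x gap is that a one-off `@time` in the REPL also measures compiling the new expression; that is an assumption, not something measured here.

```julia
using Random

# Hypothetical stand-in for rand(Categorical(p)): inverse-CDF sampling.
function sample_categorical(p::AbstractVector{<:Real})
    u = rand() * sum(p)
    c = zero(float(eltype(p)))
    for (i, pk) in enumerate(p)
        c += pk
        u < c && return i
    end
    return length(p)  # fallback, e.g. for floating-point round-off
end

m = [0.1 0.5 0.3;
     0.8 0.2 0.4;
     0.1 0.3 0.3]

# Broadcast pipeline: the dotted calls fuse into one broadcast over the columns.
Random.seed!(1)
a = eachcol(m) .|> collect .|> sample_categorical

# map with composition: a single pass applying collect, then the sampler.
Random.seed!(1)
b = map(sample_categorical ∘ collect, eachcol(m))

# Same seed and same per-column call order, so the same draws.
@show a == b
```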


Then you should find anonymous functions in Julia much nicer?

map(x->rand(Categorical(collect(x))), eachcol(m))

I mean, what do you do in JAX?
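The three styles in this thread (composition, anonymous function, do-block) are interchangeable as far as `map` is concerned, since `f ∘ g` is just another callable. A deterministic sketch, with `argmax` standing in for `rand` and a hypothetical `normalize_col` standing in for softmax so the outputs can be compared exactly:

```julia
# Hypothetical deterministic stand-in for softmax: scale a column to sum to 1.
normalize_col(p) = p ./ sum(p)

m = [1.0 4.0;
     3.0 1.0;
     2.0 5.0]

# 1. Function composition: (f ∘ g)(x) == f(g(x)).
composed = map(argmax ∘ normalize_col ∘ collect, eachcol(m))

# 2. Anonymous function with nested calls.
anonymous = map(x -> argmax(normalize_col(collect(x))), eachcol(m))

# 3. do-block: the block becomes the function passed as map's first argument.
doblock = map(eachcol(m)) do col
    argmax(normalize_col(collect(col)))
end

@show composed anonymous doblock
```

Which one to use is mostly a readability choice; none of them changes what `map` iterates over.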

Use @btime from BenchmarkTools.jl, and use it correctly.
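`@time` on the first call of an expression includes JIT compilation, and timing code that refers to non-constant globals adds dynamic-dispatch overhead. BenchmarkTools.jl's `@btime via_map($m)` handles both: it warms up, repeats the call, and the `$` interpolates the global `m` as if it were a local. The stdlib-only sketch below imitates the warm-up and repetition with `@elapsed`; `via_map`, `via_broadcast` and `sample_categorical` are hypothetical names.

```julia
using Random

# Hypothetical stand-in for rand(Categorical(p)): inverse-CDF sampling.
function sample_categorical(p::AbstractVector{<:Real})
    u = rand() * sum(p)
    c = zero(float(eltype(p)))
    for (i, pk) in enumerate(p)
        c += pk
        u < c && return i
    end
    return length(p)  # fallback, e.g. for floating-point round-off
end

# Wrap each variant in a function that takes m as an argument.
via_map(m)       = map(sample_categorical ∘ collect, eachcol(m))
via_broadcast(m) = eachcol(m) .|> collect .|> sample_categorical

m = rand(10, 1000)  # 1000 columns of (unnormalized) weights

# Warm-up calls: the first call of each function includes compilation time.
via_map(m)
via_broadcast(m)

# Repeat and keep the minimum to reduce noise from the OS and garbage collector.
t_map       = minimum(_ -> @elapsed(via_map(m)), 1:100)
t_broadcast = minimum(_ -> @elapsed(via_broadcast(m)), 1:100)
println("map: ", t_map, " s   broadcast: ", t_broadcast, " s")
```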


I actually find map(rand ∘ Categorical ∘ collect, eachcol(m)) to be quite a bit nicer than what I do in JAX; sorry if I miscommunicated that. I prefer it over the anonymous-function version you suggested simply because I don’t like the nested parentheses; I find the function-composition version easier to digest. Of course this isn’t possible to do (nicely) in JAX, but in Julia it’s nicer 🙂

Thanks for the @btime tip!