Lux.jl vs Jax

I’ve started using Julia this semester and would like to reimplement some ML models from Jax in Lux.jl. I have four concerns:

  1. I noticed there isn’t a construct like vmap. My existing Jax code has quite a few vmap calls, which are crucial for taking advantage of the GPU.
  2. How fast is Lux.jl compared to an equivalent Jax implementation for moderately beefy models?
  3. I have a hard time designing structs that are to be vectorized. For example, I could have
struct Foo
  x::Float32
end

xs = Foo[...]

or

struct Foo
  x::Vector{Float32}
end

I prefer the first one since it is easier to reason about, but I believe the second is necessary for good GPU performance. Do I need to make this trade-off, or am I missing something?

  4. GPU samplers. The Distributions.jl package seems to implement distributions on the CPU, but I don’t see anything about GPU sampling.

Thank you!

1 Like

Welcome! Answers to 2 of your questions:

  1. I noticed there isn’t a construct like vmap. My existing Jax code has quite a few vmap calls, which are crucial for taking advantage of the GPU.

There’s no vmap in Julia currently, unfortunately. Instead, try putting your loop at the outermost level with a normal map rather than pushing it into the innermost loop the way vmap does. As long as your innermost operations already saturate the GPU, it won’t really matter (if they don’t, you do have to think harder than the easy solution vmap offers).
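As a minimal sketch of what that can look like (assuming CUDA.jl; f and the batch sizes are just made up for illustration):

using CUDA

# Hypothetical per-sample computation: takes one feature vector and reduces it.
# Internally it already runs as GPU kernels (a broadcast plus a reduction).
f(x) = sum(2f0 .* x .+ 1f0)

# A batch of 64 samples, each a 100_000-element feature vector on the GPU.
batch = [CUDA.rand(Float32, 100_000) for _ in 1:64]

# Outermost-level loop with a plain map instead of vmap: each call to f launches
# its own kernels. If each per-sample call already saturates the GPU, this is
# roughly as fast as a fused/batched version would be.
ys = map(f, batch)

If the per-sample work is too small to keep the GPU busy, the usual fix is to stack the batch into one array (e.g. one sample per column) and write the operation over the whole batch, so a single kernel covers every sample.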

  3. I have a hard time designing structs that are to be vectorized. For example, I could have

A unique thing Julia can do that Jax can’t is put arrays of custom structs on the GPU, as long as the structs are concrete (all field types known at compile time). So, e.g., this is totally valid and performant:

julia> using CUDA

julia> struct Foo
          x::Float32
       end

julia> cu(Foo.([1,2,3]))
3-element CuArray{Foo, 1, CUDA.Mem.DeviceBuffer}:
 Foo(1.0f0)
 Foo(2.0f0)
 Foo(3.0f0)

So if you find the code is easier to write and reason about that way, go for it.
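To give a flavor of what you can then do with such an array (a small sketch, assuming CUDA.jl; the field access and doubling are just made-up operations):

using CUDA

struct Foo
    x::Float32
end

foos = cu(Foo.([1, 2, 3]))             # CuArray{Foo}: isbits structs live on the GPU

# Functions over the struct elements compile to GPU kernels, so field access and
# construction inside map/broadcast stay on the device.
xs      = map(f -> f.x, foos)          # CuArray{Float32}
doubled = map(f -> Foo(2f0 * f.x), foos)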

4 Likes

Thank you! I didn’t know you could do that with structs. Could you elaborate on Q1, though, maybe with a simple example?

On Q1, what do you want to do? Give the NN a batch of input arrays simultaneously?

1 Like

I like StructArrays.jl, which gives you the “struct of arrays” arrangement.
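A minimal sketch of what that looks like (assuming StructArrays.jl, plus CUDA.jl for the GPU step; Foo is the same toy struct as above):

using StructArrays, CUDA

struct Foo
    x::Float32
end

# Looks and indexes like a Vector{Foo}, but stores one contiguous Float32
# array per field ("struct of arrays").
foos = StructArray(Foo.(rand(Float32, 10)))

foos[1]    # still a Foo
foos.x     # the underlying Vector{Float32} backing the x field

# The per-field storage can be swapped for GPU arrays.
gpu_foos = StructArrays.replace_storage(CuArray, foos)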

For Q1, I think the closest thing we’ll have in Julia is Reactant.jl, but it’s still a very early prototype. @avikpal and @wsmoses would be able to tell us more.

2 Likes

Generally, beefy models on CUDA are mostly just cuDNN calls internally, so the performance should be comparable. It is a bit hard to say for sure without looking at the model, but if there is a noticeable slowdown, open an issue or make a post here or on GitHub and we can take a look.

On the CPU it would be slower if you have conv routines (smaller models are actually faster than PyTorch, but that is not what you are interested in), and for ROCm it again depends, because we don’t have all the bindings. Metal and oneAPI are experimental, and performance there is quite bad at the moment.

But overall, the eventual goal is to be able to compile via Reactant, and if you take a look at the Reactant and Lux repos, you’ll see we are actively working on an easy way to take a Lux model and make it faster.

2 Likes