I’ve started using Julia this semester and would like to port some ML models from Jax to Lux.jl. I have three concerns:
I noticed there aren’t constructs like vmap. My existing Jax code has quite a few vmap calls, and they’re crucial for taking advantage of the GPU.
How fast is Lux.jl compared to a Jax implementation for moderately beefy models?
I have a hard time designing structs that need to be vectorized. For example, I could have
```julia
struct Foo
    x::Float32
end

xs = Foo[...]
```

or

```julia
struct Foo
    x::Vector{Float32}
end
```
I prefer the first one since it is easier to reason about, but I believe the second is necessary for good GPU performance. Do I need to make this trade-off, or am I missing something?
GPU Samplers. The Distributions.jl package seems to implement distributions on the CPU, but I don’t see anything about GPU sampling.
> I noticed there aren’t constructs like vmap. My existing Jax code has quite a few vmap calls, and they’re crucial for taking advantage of the GPU.
There’s no vmap in Julia currently, unfortunately. Try putting your loop at the outermost level with a normal map, rather than pushing it down to the innermost loop the way vmap does. As long as your innermost operations already saturate the GPU, it won’t really matter (if they don’t, though, you do have to think a bit more than the easy solution vmap offers).
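As a rough illustration of batching at the outermost level (this is my own sketch, not from the reply; it assumes CUDA.jl, and the toy per-sample function and sizes are made up):

```julia
using CUDA

# Per-sample computation: a matrix-vector product followed by a nonlinearity.
f(W, x) = tanh.(W * x)

W = CUDA.rand(Float32, 128, 64)
X = CUDA.rand(Float32, 64, 1024)   # 1024 samples stored as columns

# In Jax you might vmap `f` over the samples. Batched at the outermost level,
# the whole thing becomes one matrix-matrix product plus one broadcast,
# which already keeps the GPU busy.
Y = tanh.(W * X)                   # 128 × 1024
```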
> I have a hard time designing structs that need to be vectorized. For example, I could have
A unique thing Julia can do that Jax can’t is put arrays of custom structs on the GPU, as long as the structs are concrete (all types known at compile time). So e.g. something like the sketch below is totally valid and performant.
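A minimal sketch of that idea, assuming CUDA.jl is installed and a GPU is available (the mapped function is just an illustration):

```julia
using CUDA

struct Foo
    x::Float32
end

# Foo is concrete and isbits, so an array of Foo can live directly on the GPU.
xs = CuArray([Foo(rand(Float32)) for _ in 1:1024])

# Plain Julia functions over the struct compile down to a GPU kernel.
ys = map(foo -> foo.x^2 + 1f0, xs)   # CuArray{Float32}
```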
For Q1, I think the closest thing we’ll have in Julia is Reactant.jl, but it’s still a very, very early prototype. @avikpal and @wsmoses would be able to tell us more.
Generally, beefy models on CUDA are just cuDNN calls internally, so performance should be comparable. It is a bit hard to say for sure without looking at the model, but if there is a noticeable slowdown, open an issue or make a post here or on GitHub and we can take a look.
For the CPU it would be slower if you have conv routines (smaller models are actually faster than PyTorch, but that is not what you are interested in), and for ROCm it again depends, because we don’t have all the bindings. Metal and oneAPI are experimental, and performance there is quite bad at the moment.
But overall, the eventual goal is to be able to compile via Reactant, and if you take a look at the Reactant and Lux repos, you’ll see we are actively working on an easy way to take a Lux model and make it faster.