Performance comparison - Flux.jl's Adam vs Jax's Adam

Hello guys, I have been wondering about Julia’s performance, and would like to know whether it is really true that Julia is faster than Python, as is always stated, because I have not yet seen any performance gain since I ported my code from Python to Julia. In fact, the code runs considerably faster using Jax and torch than using Julia. I have only compared a few libraries (Julia’s Optim vs Jaxopt’s scipy minimize and pytorch-minimize; Flux’s Adam vs Jax’s Adam). In the attached GitHub gist I have run code for multidimensional optimization using Jax’s implementation of Adam vs that of Flux.jl. While Jax takes just about 50 seconds to run, Julia’s Flux takes about 90 seconds (not counting compile time).

Maybe it is the way I have written the code; I don’t know. But I hope someone can look at the two notebooks and help out. Thanks.

Jax Version:

Julia Version:

EDIT: “Weight must be either Nothing, 1, 2, or a matrix”: I’m not sure that’s good programming in Julia (or Python?), nor is the long if. I hope it’s not speed-critical; best to profile.

I didn’t notice an obvious speed trap, though possibly @inbounds is missing. You can try running with:

julia --check-bounds=no

to check; if it is faster, then @inbounds (or some other change) is missing. Also, if you get lots of allocations, that could be a hint.
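If it helps, here is a minimal sketch of checking the allocation count with `@timed`; the function `f` below is just a placeholder, not from the notebooks:

```julia
# Placeholder workload -- substitute your own hot function here.
f(x) = sum(abs2, x .- 1.0)

x = rand(1000)
f(x)                 # warm-up call, so the timing excludes compilation
stats = @timed f(x)
stats.time           # runtime in seconds
stats.bytes          # bytes allocated; a large number here is the hint above
```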

Yes, in general, though possibly not for all ML/deep learning. It is very much faster for SciML, and e.g. (a seemingly untypical case):

I thought about the best way to do this but couldn’t come up with one. The idea is that I want to set a default which uses modulus weighting; the other weighting types are “unit”, “proportional”, and a 2D array of weights. There was no way to accomplish this with multiple dispatch. Do you have any ideas? One way might be to write a function which accepts a 2D array and another which accepts a string, I guess.
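That last idea can be sketched with dispatch. The names and formulas below are made up for illustration, not from the real code: one method takes an explicit weight matrix, the other a `Symbol` with a default.

```julia
# Hypothetical sketch: pick a weighting scheme by dispatch.
apply_weights(Z::AbstractMatrix, W::AbstractMatrix) = Z .* W  # explicit 2D weights

function apply_weights(Z::AbstractMatrix, kind::Symbol = :modulus)
    if kind === :modulus
        Z ./ abs.(Z)       # illustrative modulus weighting
    elseif kind === :unit
        copy(Z)            # unit weights leave Z unchanged
    elseif kind === :proportional
        Z ./ abs2.(Z)      # illustrative proportional weighting
    else
        error("unknown weighting: $kind")
    end
end
```

Dispatch picks the matrix method when a 2D array is passed and the `Symbol` branch otherwise, so the choice happens once per call, not per element.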

Before you (or I) worry, let’s see what the bottleneck is, i.e. profile (since this works). Or do you think this might be it? See the Profile module and ProfileView.jl.

Python itself is slow (at, for example, running a loop) in a way that Julia isn’t. But using Python (or bash) to launch code written in something else (like C) can be quick. Your examples are in the more complicated space between, where possibly you should think of Jax as a different language, with a compiler aimed at array manipulations, which is accessed through Python.

Anyway that’s a lot of code, most of it not performance-critical. If you can supply dummy input for a function like this (or some simplified version) then perhaps people can time it and suggest ways to go faster. This seems to make a lot of slices, and other temporary arrays via repeat, which is likely to hurt performance.


I will check ProfileView.jl. The truth is I just started Julia a couple of weeks ago, so I am still trying to find my way around it.

This was written this way to make it autodifferentiable using Zygote. You will notice that the Jax version was written in a simpler way.

Maybe. Again, the Jax example is almost a thousand lines of code. It would be much easier to discuss 10-line functions, with random data of roughly the right sizes & types.

Let me try to make a shorter example.

The profiler and flame graphs are very good; they can be your first try, no MWE needed. Also, the performance section of the manual is excellent on what not to do and how to detect problems. And there’s JET.jl, which I’ve not tried much; maybe it (or other tools) should be mentioned in the docs if not already.
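A minimal sketch of that workflow; `slow` here is a stand-in workload, not the real objective:

```julia
using Profile

slow(n) = sum(inv(rand(10, 10)) for _ in 1:n)  # stand-in for the real objective

slow(10)          # run once first, so the profile shows runtime, not compilation
Profile.clear()
@profile slow(1_000)
Profile.print(maxdepth = 12)  # text view; ProfileView.@profview slow(1_000) draws the flame graph
```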

Thanks. I will go ahead and try them.

This gets just over a factor of 2 improvement:

# Initial code, 10^4 epochs:   4.085793 seconds (28.90 M allocations: 4.149 GiB, 9.70% gc time)
# With this compute_total_obj: 1.887204 seconds (11.76 M allocations: 3.348 GiB, 14.07% gc time)

function compute_total_obj(P::AbstractVector, F::AbstractArray, Z::AbstractArray, 
    Zerr_Re::AbstractArray, Zerr_Im::AbstractArray, 
    LB::AbstractVector, UB::AbstractVector, smf::AbstractVector, func::Function, 
    num_params::Integer, num_eis::Integer, kvals::AbstractVector, d2m::AbstractMatrix)
    # num_params = 6
    # summary(P) == "30-element Vector{Float64}"

    # Exploiting kvals = [1, 6, 11, 16, 21, 26, 31] i.e. evenly spaced, none co-incide, to simplify
    P_log = reshape(P, :, num_params)
    # Indexing like ps = @view P[istart:istop] costs a whole copy(P) in every gradient, sadly,
    # and @view only saves on forward pass. But eachcol is more efficient:
    # tmp = map(eachcol(P_log), eachcol(reshape(LB, :, num_params)), eachcol(reshape(UB, :, num_params))) do ps, L, U
    #     up = (10 .^ ps)  # do this once
    #     ps_norm = (L .+ up) ./ (1 .+ up ./ U)
    # end
    # P_norm = reduce(hcat, tmp) # With this change alone, 2.515051 seconds

    # Even better, avoid the slices completely (worth about 0.2 sec):
    up = (10 .^ P_log)
    P_norm = (reshape(LB, :, num_params) .+ up) ./ (1 .+ up ./ reshape(UB, :, num_params))
    smf_1 = ifelse.(isinf.(smf), 0.0, smf)
    chi_smf = sum(sum(abs2, d2m * P_log, dims = 1) .* smf_1)

    # This is perfect for vmap, as the function applied to slices here only involves broadcasting etc,
    # which could more efficiently be done on whole matrices, not slices.
    # wrss_tot = compute_wrss.(eachcol(transpose(P_norm)), eachcol(F), eachcol(Z), eachcol(Zerr_Re), eachcol(Zerr_Im), func)
    # Here done incompletely, by hand, below:
    wrss_tot = compute_wrss(transpose(P_norm), F, Z, Zerr_Re, Zerr_Im, func)
    return (sum(wrss_tot) + chi_smf)
end

function compute_wrss(p::AbstractVector, f::AbstractVector, z::AbstractVector,
                      zerr_re::AbstractVector, zerr_im::AbstractVector, func::F) where F
    z_concat = vcat(real(z), imag(z))
    sigma = vcat(zerr_re, zerr_im)
    z_model = func(p, f)
    sum(abs2, (z_concat .- z_model) ./ sigma)  # norm^2, without sqrt-then-square
end
# "vmap"-like variant -- going inside model would be better still:
function compute_wrss(p::AbstractMatrix, f::AbstractMatrix, z::AbstractMatrix,
                      zerr_re::AbstractMatrix, zerr_im::AbstractMatrix, func::F) where F
    z_concat = vcat(real(z), imag(z))
    sigma = vcat(zerr_re, zerr_im)
    z_model = reduce(hcat, map(func, eachcol(p), eachcol(f)))
    vec(sum(abs2, (z_concat .- z_model) ./ sigma, dims=1))
end

Here P_norm relies on kvals being evenly spaced, to use eachcol to avoid indexing. It would not be hard to write some kind of split_at function with an efficient gradient.

The functions being mapped over slices (for wrss_tot) are all quite simple things like broadcasting which could in principle apply to the whole array. My attempt at a vmap doesn’t handle this case at present. But performing the transformation by hand on compute_wrss isn’t so hard.

Doing it for model would probably help a lot, but would be more work. It might be possible to pass it tuples of length num_params instead of arrays, but I didn’t try.

Indeed I can see the improvement. However, kvals is not always evenly spaced. For instance, if you set the first value of the smf vector to Inf, kvals becomes [1, 2, 7, 12, 17, 22, 27] and thus a jagged array is produced (hence the repeat function, since Julia does not broadcast a scalar to a vector). The big picture is that once I set any element of smf to Inf, that value is kept constant during optimization. How best to handle this?

(Edit: Here I had a function for generic “kvals”, not needed for the MWE above. Nor, apparently, for the real problem – “kvals is not always evenly spaced” does not mean they are generic indices.)

Something like this would allow for unequal breaks; maybe this already exists somewhere? May need refinement… (the need for a gradient rule is essentially a hack to get around the inefficiency of Zygote’s gradient accumulation, which could mutate one array but instead makes many new ones):

function splitat(data::AbstractVector, indices::AbstractVector{<:Integer})
  first(indices) == firstindex(data) || error("indices must start at 1")
  last(indices) == lastindex(data)+1 || error("indices must cover the input exactly")
  [view(data, lo:next-1) for (lo, next) in zip(indices, Iterators.drop(indices, 1))]
end
splitat('a':'z', [1, 3, 3, 13, 27])  # one empty array

using ChainRulesCore
function ChainRulesCore.rrule(::typeof(splitat), data, indices)
  unsplit(dy) = (NoTangent(), reduce(vcat, unthunk(dy)), NoTangent())
  return splitat(data, indices), unsplit
end

using ChainRulesTestUtils
test_rrule(splitat, randn(10), [1, 3, 4, 11])

Not sure I follow exactly what repeat(ps, num_eis) is doing.

repeat(ps, num_eis) is similar to P_log[i, :] = ps, where P_log has shape (num_params, num_eis), such that when ps is a scalar it fills the column. In Jax I just do P_log = P_log.at[i, :].set(ps), but Zygote does not allow array mutation, so it throws an error; hence the repeat with hcat. Consider when kvals is uneven, e.g. [1, 2, 7, 12, 17, 22, 27]: then istart = kvals[i], istop = kvals[i + 1] - 1 is 1, 1 when i == 1, and ps = @view P[istart:istop] then returns a scalar since P is a vector.
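For reference, a self-contained sketch of that pattern (the function name is illustrative): build each column as a slice, `repeat` the one-element case, and `reduce(hcat, ...)` the columns instead of mutating, which keeps Zygote happy.

```julia
# Non-mutating equivalent of filling P_log column by column.
function build_P_log(P::AbstractVector, kvals::AbstractVector{<:Integer}, num_eis::Integer)
    cols = map(1:length(kvals)-1) do i
        ps = P[kvals[i]:kvals[i+1]-1]
        # A held-constant parameter gives a 1-element slice: repeat it to fill the column.
        length(ps) == 1 ? repeat(ps, num_eis) : ps
    end
    reduce(hcat, cols)
end

build_P_log(collect(1.0:30.0), [1, 6, 11, 16, 21, 26, 31], 5)  # even case: 5x6 matrix
build_P_log(collect(1.0:26.0), [1, 2, 7, 12, 17, 22, 27], 5)   # uneven: first column all 1.0
```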

Exactly, it means we have a scalar (I meant a 1-element vector), so repeat that to fill the column. I guess you understand what I mean, right? Basically, the values returned by kvals depend on the vector smf. If I set all the elements in smf to Inf, e.g. smf = [Inf, Inf, Inf, Inf, Inf, Inf], kvals becomes [1, 2, 3, 4, 5, 6, 7].

I thought you got an empty column, but I was mistaken. In fact the special case istop - istart == 0 means you get a one-element vector: rand(5)[3:3] isa Vector, not a scalar.

So it’s wrong to think that kvals is generic, and my function above is a waste of time. Instead, you allow only partitions into length-1 and length-5 pieces: diff([1, 2, 7, 12, 17, 22, 27]) == [1, 5, 5, 5, 5, 5]. But the length-1 case is garbage anyway, a fact which is also independently encoded in smf. If this case is common, then maybe figuring out how not to do that work at all would be better. If you really must make a column to throw away later, why not make zeros(n)?

Edit: Or maybe it’s not garbage, but why not make things the right length immediately? Surely every language will prefer regular strides… this dual encoding of Inf and step-1, which needs all this branching, still seems pretty odd. I’m sure it can be better optimised but might be clearer just to avoid.

But, most of all, it’s important that the MWE actually exercise details you care about!

The function get_kvals was made to produce a different kvals depending on which parameters I choose to keep constant. The model could have 10 to 12 or more parameters, and one can choose to keep certain parameters constant by setting the corresponding values in smf to Inf.

function get_kvals(smf::AbstractVector, num_eis::Integer)
    kvals = cumsum(insert!(ifelse.(isinf.(smf), 1, num_eis), 1, 1))
    return kvals
end
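To illustrate what this produces (repeating the function so the snippet is self-contained), with num_params = 6 and num_eis = 5:

```julia
function get_kvals(smf::AbstractVector, num_eis::Integer)
    # An Inf in smf contributes a step of 1 (held parameter), otherwise num_eis.
    cumsum(insert!(ifelse.(isinf.(smf), 1, num_eis), 1, 1))
end

get_kvals([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 5)  # [1, 6, 11, 16, 21, 26, 31]: evenly spaced
get_kvals([Inf, 1.0, 1.0, 1.0, 1.0, 1.0], 5)  # [1, 2, 7, 12, 17, 22, 27]: first param held
get_kvals([Inf, Inf, Inf, Inf, Inf, Inf], 5)  # [1, 2, 3, 4, 5, 6, 7]: all params held
```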

The problem was how to broadcast a one-element vector to fill up the column of P_log (in such cases) in a way that is both autodifferentiable and efficient. Of course in Jax and PyTorch I had no issues and it was fast enough. Why they are both faster than Julia was my worry.

Yeah, I’m still cracking my brain on how best to work around it.

Actually no column is thrown away. We do a simultaneous minimization such that the parameters vary smoothly. The total number of parameters minimized is num_params * num_eis - (num_eis - 1) * sum(isinf.(smf)). For instance, if num_params = 6 and num_eis = 5 as in the example, then the total number of params is 30. If the first value in smf is set to Inf, meaning you want to keep the first parameter constant, then total_params = 26, and this is where kvals helps: it is used to keep track of the parameters. So in this case we minimize 26 parameters but compute the objective based on 30. The parameter kept constant becomes a one-element vector which is broadcast. It is a simple trick that helps.
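That count can be checked directly from the formula in the post (a small sketch; `total_params` is just a throwaway name):

```julia
total_params(num_params, num_eis, smf) =
    num_params * num_eis - (num_eis - 1) * sum(isinf.(smf))

total_params(6, 5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0])  # 30: nothing held constant
total_params(6, 5, [Inf, 1.0, 1.0, 1.0, 1.0, 1.0])  # 26: first parameter held
```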