Performance comparison - Flux.jl's Adam vs Jax's Adam

Well, here's a version that uses this weird repeat convention plus the fusion above. It's about 2 times faster on its own, and about 2.5 times faster together with withgradient (timings below). I'd bet there's another factor of 2 to be had by improving model.

One improvement not in this code block is to use withgradient to get the loss and the gradient in one call, rather than running the forward pass twice. It only gains 20% or so, but it's worth having.

# Initial code, 10^4 epochs:   4.085793 seconds (28.90 M allocations: 4.149 GiB, 9.70% gc time)
# With this compute_total_obj: 1.915556 seconds (11.76 M allocations: 3.365 GiB, 14.05% gc time)
# withgradient too:            1.614427 seconds (10.87 M allocations: 2.682 GiB, 13.52% gc time)

function compute_total_obj(P::AbstractVector, F::AbstractArray, Z::AbstractArray, 
    Zerr_Re::AbstractArray, Zerr_Im::AbstractArray, 
    LB::AbstractVector, UB::AbstractVector, smf::AbstractVector, func::Function, 
    num_params::Integer, num_eis::Integer, kvals::AbstractVector, d2m::AbstractMatrix)

    P_log = weirdcat(P, kvals, num_eis)  # step = num_eis: one row per spectrum, one column per parameter
    up = (10 .^ P_log)
    P_norm = (reshape(LB, :, num_params) .+ up) ./ (1 .+ up ./ reshape(UB, :, num_params))
    
    smf_1 = ifelse.(isinf.(smf), 0.0, smf)
    chi_smf = sum(vec(sum(abs2, d2m * P_log, dims = 1)) .* smf_1)  # vec: weight column j by smf_1[j] only

    wrss_tot = compute_wrss(transpose(P_norm), F, Z, Zerr_Re, Zerr_Im, func)
    return (sum(wrss_tot) + chi_smf)
end

# "vmap"-like variant of compute_wrss -- going inside model would be better still:
function compute_wrss(p::AbstractMatrix, f::AbstractMatrix, z::AbstractMatrix, zerr_re::AbstractMatrix,zerr_im::AbstractMatrix, func::F) where F
    z_concat = vcat(real(z), imag(z))
    sigma = vcat(zerr_re, zerr_im)
    z_model = reduce(hcat, map(func, eachcol(p), eachcol(f)))
    vec(sum(abs2, (z_concat .- z_model) ./ sigma, dims=1))
end

# Function to turn a vector into a matrix by repeating entries as whole columns
function weirdcat(data::AbstractVector, indices::AbstractVector{<:Integer}, step::Integer)
  first(indices) == firstindex(data) || error("indices must start at 1")
  last(indices) == lastindex(data)+1 || error("indices must cover the input exactly")
  slices = map(zip(indices, Iterators.drop(indices, 1))) do (lo, next)
    if (next - lo) == step
      # @assert StepRangeLen(lo, 1, step) == lo:next-1  # but changed for type-stability
      view(data, StepRangeLen(lo, 1, step)) 
    elseif (next - lo) == 1
      view(data, StepRangeLen(lo, 0, step))
    else
      error("step must be 1 or given integer")
    end
  end
  reduce(hcat, slices)
end

using ChainRulesCore
function ChainRulesCore.rrule(::typeof(weirdcat), data, indices, step)
  function uncat(dy)
    pieces = map(eachcol(unthunk(dy)), zip(indices, Iterators.drop(indices, 1))) do col, (lo, next)
        if (next - lo) == step
          col
        else
          sum(col, dims=1)  # not the same type as col, unfortunately
        end
    end
    (NoTangent(), reduce(vcat, pieces)::Vector{Float64}, NoTangent(), NoTangent())
  end
  return weirdcat(data, indices, step), uncat
end

using ChainRulesTestUtils  # here check_inferred needs ::Vector{Float64} annotation,
# could be done more elegantly / more generally. Not sure how much performance cares.
test_rrule(weirdcat, randn(7), [1, 4, 5, 8], 3; check_inferred=true)
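To make the repeat convention and the rrule a bit more concrete, here is a minimal Base-only sketch. It is not the code above: weirdcat_eager and weirdcat_pullback are hypothetical names for an eager (copying, no views) version of weirdcat and a stand-alone version of uncat. test_rrule checks the rrule against finite differences; since weirdcat is linear in data, a quicker sanity check is the adjoint (dot-product) identity at the end.

```julia
using LinearAlgebra: dot

# Hypothetical eager version of weirdcat: each block lo:next-1 of `data` becomes one
# column of length `step`; a block of length 1 is repeated `step` times.
function weirdcat_eager(data::AbstractVector, indices::AbstractVector{<:Integer}, step::Integer)
    cols = map(zip(indices, Iterators.drop(indices, 1))) do (lo, next)
        if next - lo == step
            data[lo:next-1]          # full block: copy it as a column
        elseif next - lo == 1
            fill(data[lo], step)     # single entry: repeat it down the column
        else
            error("each block must have length 1 or step")
        end
    end
    return reduce(hcat, cols)
end

# Pullback of the linear map above, mirroring `uncat`: full blocks pass their
# column through, a repeated entry receives the sum of its column.
function weirdcat_pullback(dy::AbstractMatrix, indices::AbstractVector{<:Integer}, step::Integer)
    pieces = map(eachcol(dy), zip(indices, Iterators.drop(indices, 1))) do col, (lo, next)
        next - lo == step ? collect(col) : [sum(col)]
    end
    return reduce(vcat, pieces)
end

x = collect(1.0:7.0)
indices = [1, 4, 5, 8]
Y = weirdcat_eager(x, indices, 3)   # columns: x[1:3], x[4] repeated, x[5:7]

# Adjoint identity for a linear map f: <f(x), dY> == <x, pullback(dY)>
dY = randn(3, 3)
@assert dot(Y, dY) ≈ dot(x, weirdcat_pullback(dY, indices, 3))
```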

Thanks @mcabbott. This is incredible; I could never have come up with this. The run time was practically cut in half and is now shorter than the JAX equivalent. I changed the line P_norm = (reshape(LB, :, num_params) .+ up) ./ (1 .+ up ./ reshape(UB, :, num_params)) to P_norm = (weirdcat(LB, kvals, num_eis) .+ up) ./ (1 .+ up ./ weirdcat(UB, kvals, num_eis)), which now works for all cases.
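For reference, elementwise, with u = 10^p for a log-parameter p, the P_norm line is the transform below. Assuming 0 < LB < UB (an assumption, not stated in the thread), it increases monotonically from LB to UB. That is why LB and UB have to be expanded with exactly the same indices as P: each parameter must be paired with its own bounds.

$$
p_{\mathrm{norm}} = \frac{\mathrm{LB} + u}{1 + u/\mathrm{UB}} = \frac{\mathrm{UB}\,(\mathrm{LB} + u)}{\mathrm{UB} + u}, \qquad u = 10^{p}
$$

$$
\frac{\partial p_{\mathrm{norm}}}{\partial u} = \frac{\mathrm{UB}\,(\mathrm{UB} - \mathrm{LB})}{(\mathrm{UB} + u)^{2}} > 0, \qquad \lim_{u \to 0} p_{\mathrm{norm}} = \mathrm{LB}, \qquad \lim_{u \to \infty} p_{\mathrm{norm}} = \mathrm{UB}
$$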

Could you please show me how you used withgradient to get the loss and the gradient at once, rather than running twice?

Now I’ve gotta love Julia.


You can replace these two lines

gs = gradient(p -> compute_total_obj(p, F, Z, ...), p)[1]
training_loss = compute_total_obj(p, F, Z, ...)  # 2nd forward pass

with one

training_loss, (gs,) = withgradient(p -> compute_total_obj(p, F, Z, ...), p)

or perhaps

res = withgradient(p -> compute_total_obj(p, F, Z, ...), p)
gs = res.grad[1]
training_loss = res.val

It's a bit unfortunate that so much tinkering is necessary, and that the style that works well for Zygote-differentiable code is not the same as the style of fast ordinary Julia code.


I can see that, indeed.

One last question, please. Could you try taking the Hessian of the function? I can't seem to get it to work: the gradient works well, but the Hessian throws an error: TypeError: in typeassert, expected Vector{Float64}, got a value of type Vector{ForwardDiff.Dual{Nothing, Float64, 12}}
This is my Hessian function:

    function get_hess(p)
        H = Flux.hessian(p -> compute_total_obj(vec(p), F, Z, Zerr_Re, Zerr_Im, lb_vec, ub_vec, smf, func, num_params, num_eis, kvals, d2m), p)
        return H
    end

I tried using vec(p) in place of p, but it didn't work.

What exactly is throwing this? The error comes about because you’re trying to put dual numbers in a vector of floats. Can you post the full stacktrace?

The error is because my rrule(::typeof(weirdcat), data, indices, step) asserts the type ::Vector{Float64} for type stability. That assertion could be removed; I'm not sure whether it matters for speed.
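A middle ground, sketched here as an assumption rather than a tested fix, is to keep a concrete assertion but take the element type from the incoming cotangent instead of hard-coding Float64. Under Zygote's Hessian (ForwardDiff over the reverse pass), that element type would be the Dual type from the error message. The function below is a hypothetical Base-only stand-in for the body of uncat; in the real rrule, T would come from eltype(unthunk(dy)).

```julia
# Hypothetical stand-in for the body of `uncat`, with ::Vector{Float64}
# replaced by ::Vector{T}, where T is the element type of the cotangent dy.
function uncat_body(dy::AbstractMatrix{T}, indices::AbstractVector{<:Integer},
                    step::Integer) where {T}
    pieces = map(eachcol(dy), zip(indices, Iterators.drop(indices, 1))) do col, (lo, next)
        # full block: pass the column through; repeated entry: sum the column
        (next - lo) == step ? col : sum(col, dims=1)
    end
    return reduce(vcat, pieces)::Vector{T}   # still a concrete type for inference
end

# Float32 input would fail a ::Vector{Float64} assertion, just like Dual numbers do.
dy32 = Float32[1 2 3; 4 5 6; 7 8 9]
r = uncat_body(dy32, [1, 4, 5, 8], 3)
```

The assertion still gives inference a concrete return type, so check_inferred in the test_rrule call should not need the Float64 annotation, though I haven't verified that.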

Stacktrace:
#   [1] (::var"#uncat#6"{Vector{Int64}, Int64})(dy::Matrix{ForwardDiff.Dual{Nothing, Float64, 12}})
#     @ Main ./In[4]:108
#   [2] (::Zygote.ZBack{var"#uncat#6"{Vector{Int64}, Int64}})(dy::Matrix{ForwardDiff.Dual{Nothing, Float64, 12}})
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:206
#   [3] Pullback
#     @ ./In[4]:69 [inlined]
#   [4] (::typeof(∂(compute_total_obj)))(Δ::ForwardDiff.Dual{Nothing, Float64, 12})
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
#   [5] Pullback
#     @ ./In[26]:71 [inlined]
#   [6] (::typeof(∂(λ)))(Δ::ForwardDiff.Dual{Nothing, Float64, 12})
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
#   [7] (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::ForwardDiff.Dual{Nothing, Float64, 12})
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:45
#   [8] gradient(f::Function, args::Vector{ForwardDiff.Dual{Nothing, Float64, 12}})
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:97
#   [9] (::Zygote.var"#102#103"{var"#187#198"{Vector{Float64}, Vector{Float64}, Matrix{Float64}, Vector{Int64}, typeof(her), Vector{Float64}, Matrix{Float64}, Int64, Int64, Matrix{ComplexF64}}})(x::Vector{ForwardDiff.Dual{Nothing, Float64, 12}})
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/lib/grad.jl:64
#  [10] forward_jacobian(f::Zygote.var"#102#103"{var"#187#198"{Vector{Float64}, Vector{Float64}, Matrix{Float64}, Vector{Int64}, typeof(her), Vector{Float64}, Matrix{Float64}, Int64, Int64, Matrix{ComplexF64}}}, x::Vector{Float64}, #unused#::Val{12})
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/lib/forward.jl:29
#  [11] forward_jacobian(f::Function, x::Vector{Float64}; chunk_threshold::Int64)
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/lib/forward.jl:44
# ...
#     @ In[27]:1
#  [18] eval
#     @ ./boot.jl:368 [inlined]
#  [19] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
#     @ Base ./loading.jl:1428

It’s good now.

I think it should be possible to simplify this to avoid the explicit rrule, but I haven't had time to test it yet. The gist is to skip kvals and construct a matrix of indices into P directly from smf and num_eis. Wherever isinf(smf[i]), that column of the index matrix can be filled with a single repeated index instead of the range j+1:j+num_eis, where j is the last index processed. That index matrix can then be passed to NNlib.gather to pull the relevant parameters from P in one fell swoop. As long as the index-generation function is also marked @non_differentiable, I believe it should be friendly to nested differentiation too. Oh, and (assuming it works) this pattern could help speed up the PyTorch and JAX implementations as well.
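A minimal sketch of that index construction, under the assumption that isinf(smf[i]) means parameter i has one shared value in P and otherwise has num_eis consecutive values (one per spectrum). make_gather_indices is a hypothetical name. The last line uses plain indexing, P[idx], which for a vector P should give the same matrix as NNlib.gather(P, idx); in real code one would also mark the helper with ChainRulesCore's @non_differentiable, as suggested above.

```julia
# Hypothetical helper: build a num_eis × num_params index matrix straight from smf.
# Assumption: isinf(smf[i]) => one shared entry in P; otherwise num_eis entries.
function make_gather_indices(smf::AbstractVector, num_eis::Integer)
    idx = Matrix{Int}(undef, num_eis, length(smf))
    j = 0  # last index of P used so far
    for (i, s) in enumerate(smf)
        if isinf(s)
            idx[:, i] .= j + 1                   # same entry repeated down the column
            j += 1
        else
            idx[:, i] .= (j + 1):(j + num_eis)   # one entry per spectrum
            j += num_eis
        end
    end
    return idx
end

smf = [1.0, Inf, 2.0]          # toy example: parameter 2 is shared
idx = make_gather_indices(smf, 3)
P = collect(1.0:7.0)
P_log = P[idx]                 # plain-indexing analogue of NNlib.gather(P, idx)
```

With this toy input, the columns of idx are 1:3, then 4 repeated, then 5:7, which corresponds to kvals = [1, 4, 5, 8] in the weirdcat version.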


Thanks @ToucheSir for your suggestion. I implemented it and it works. Here is what I did, still making use of kvals. First, within the main function, I form the gather_indices matrix:

gather_indices = transpose(reduce(hcat, [
    let istart=kvals[i], istop=kvals[i + 1] - 1, ps = istart:istop
        istop - istart == 0 ? repeat(ps, num_eis) : ps
    end
    for i = 1:num_params]))

Then, within the objective function, instead of using weirdcat I use NNlib.gather with those indices:

function compute_total_obj(P::AbstractVector, F::AbstractArray, Z::AbstractArray, 
    Zerr_Re::AbstractArray, Zerr_Im::AbstractArray, 
    LB::AbstractVector, UB::AbstractVector, smf::AbstractVector, func::Function, 
    num_params::Integer, num_eis::Integer, gather_indices::AbstractMatrix, d2m::AbstractMatrix)

    P_log = NNlib.gather(P, gather_indices)

    up = (10 .^ P_log)

    P_norm = (NNlib.gather(LB, gather_indices) .+ up) ./ (1 .+ up ./ NNlib.gather(UB, gather_indices))
    
    smf_1 = ifelse.(isinf.(smf), 0.0, smf)
    chi_smf = sum(vec(sum(abs2, d2m * transpose(P_log), dims = 1)) .* smf_1)  # vec: weight column j by smf_1[j] only

    wrss_tot = compute_wrss(P_norm, F, Z, Zerr_Re, Zerr_Im, func)
    return (sum(wrss_tot) + chi_smf)
end
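A pitfall to watch for in the chi_smf computation: unlike jnp.sum(..., axis=0) in JAX, which drops the summed axis, Julia's sum(...; dims = 1) keeps a 1 × n row. Multiplying that row by the length-n vector smf_1 broadcasts to an n × n matrix, so the total becomes (sum of column norms) × (sum of weights) rather than a weighted sum. The toy numbers below are made up purely to illustrate this, assuming the intent is one smoothing weight per parameter.

```julia
# Toy sizes: 4 spectra, 3 parameters (values are arbitrary).
d2m_P = reshape(collect(1.0:12.0), 4, 3)  # stands in for d2m * P_log (num_eis × num_params)
smf_1 = [0.5, 0.0, 2.0]                   # one weight per parameter

col_sums = sum(abs2, d2m_P, dims = 1)     # 1 × 3: dims=1 keeps a singleton dimension
size(col_sums .* smf_1)                   # (3, 3): every column paired with every weight

wrong = sum(col_sums .* smf_1)            # equals sum(col_sums) * sum(smf_1)
right = sum(vec(col_sums) .* smf_1)       # sum over j of smf_1[j] * (squared norm of column j)
```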

Since I don't generate the indices within the objective function, is there any point in using the @non_differentiable macro? If so, how do I use it?


No, hoisting out a variable like gather_indices has much the same effect. It seems that smf_1 could also be computed outside the objective function, although I'm not sure how much of a speedup that would give.

Okay, cool. The speed is pretty okay, I guess. I'll also try computing smf_1 outside.


Now that this no longer needs to be differentiable, you can also consider using mutation again:

gather_indices = Matrix{Int}(undef, num_params, num_eis)  # integer indices; similar(P, ...) would be Float64
for i in 1:num_params
  istart, istop = kvals[i], kvals[i + 1] - 1
  if istop == istart
    @views gather_indices[i, :] .= istart
  else
    @views gather_indices[i, :] .= istart:istop
  end
end
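A quick check, with a hypothetical kvals for three parameters over three spectra (parameter 2 shared), that the mutating loop builds the same matrix as the comprehension from the earlier post. Note the element type: the loop allocates Matrix{Int}, because similar(P, ...) on a Float64 parameter vector would give a Float64 matrix, which I don't believe NNlib.gather accepts as indices.

```julia
kvals = [1, 4, 5, 8]       # assumed example values
num_params, num_eis = 3, 3

# Comprehension version (no mutation), as in the earlier post
gather_indices_a = transpose(reduce(hcat, [
    let istart = kvals[i], istop = kvals[i + 1] - 1, ps = istart:istop
        istop - istart == 0 ? repeat(ps, num_eis) : ps
    end
    for i = 1:num_params]))

# Mutating version with an integer element type
gather_indices_b = Matrix{Int}(undef, num_params, num_eis)
for i in 1:num_params
    istart, istop = kvals[i], kvals[i + 1] - 1
    if istop == istart
        gather_indices_b[i, :] .= istart          # shared parameter: repeat one index
    else
        gather_indices_b[i, :] .= istart:istop    # one index per spectrum
    end
end
```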