Performance comparison - Flux.jl's Adam vs JAX's Adam

Well, here’s a version which uses the weird repeat convention (see weirdcat below), plus the fusion above. It’s about 2.5 times faster, and I’d bet there’s another factor of 2 to be had by improving the model function itself.

One improvement not shown in this code block is to use withgradient to get the loss and the gradient in one pass, rather than running the model twice. It only buys 20% or so, but that's worth having; there's a sketch just after the timings below.

# Initial code, 10^4 epochs:   4.085793 seconds (28.90 M allocations: 4.149 GiB, 9.70% gc time)
# With this compute_total_obj: 1.915556 seconds (11.76 M allocations: 3.365 GiB, 14.05% gc time)
# withgradient too:            1.614427 seconds (10.87 M allocations: 2.682 GiB, 13.52% gc time)
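
For reference, here's a minimal sketch of that withgradient change, assuming the arguments of compute_total_obj are already defined; flat_params and opt_state are placeholder names, not from the code below:

using Flux  # re-exports Zygote's withgradient

opt_state = Flux.setup(Adam(0.05), flat_params)
loss, grads = Flux.withgradient(flat_params) do P
    compute_total_obj(P, F, Z, Zerr_Re, Zerr_Im, LB, UB, smf, func,
                      num_params, num_eis, kvals, d2m)
end
Flux.update!(opt_state, flat_params, grads[1])  # one forward + one backward pass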

function compute_total_obj(P::AbstractVector, F::AbstractArray, Z::AbstractArray,
    Zerr_Re::AbstractArray, Zerr_Im::AbstractArray,
    LB::AbstractVector, UB::AbstractVector, smf::AbstractVector, func::Function,
    num_params::Integer, num_eis::Integer, kvals::AbstractVector, d2m::AbstractMatrix)

    # Expand the flat parameter vector into a matrix, repeating the entries
    # that kvals marks as shared (see weirdcat below):
    P_log = weirdcat(P, kvals, num_params-1)

    # Squash each log-parameter into its (LB, UB) box, with the broadcasts fused:
    up = 10 .^ P_log
    P_norm = (reshape(LB, :, num_params) .+ up) ./ (1 .+ up ./ reshape(UB, :, num_params))

    # Smoothness penalty on the log-parameters; Inf entries of smf are dropped:
    smf_1 = ifelse.(isinf.(smf), 0.0, smf)
    chi_smf = sum(sum(abs2, d2m * P_log, dims=1) .* smf_1)

    # Weighted residual sum of squares for every spectrum at once:
    wrss_tot = compute_wrss(transpose(P_norm), F, Z, Zerr_Re, Zerr_Im, func)
    return sum(wrss_tot) + chi_smf
end
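
As an aside, the 10 .^ and reshape lines above implement a smooth box constraint: each log-parameter p is squashed to (LB + 10^p) / (1 + 10^p / UB), which tends to LB as p → -∞ and to UB as p → +∞. A quick check with made-up bounds:

let lb = 2.0, ub = 5.0
    squash(p) = (lb + 10.0^p) / (1 + 10.0^p / ub)
    squash(-10.0), squash(10.0)  # ≈ (2.0, 5.0), i.e. the two bounds
end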

# "vmap"-like variant of compute_wrss -- pushing this inside the model func would be better still:
function compute_wrss(p::AbstractMatrix, f::AbstractMatrix, z::AbstractMatrix,
                      zerr_re::AbstractMatrix, zerr_im::AbstractMatrix, func::F) where F
    # Stack real and imaginary parts, then evaluate the model once per column:
    z_concat = vcat(real(z), imag(z))
    sigma = vcat(zerr_re, zerr_im)
    z_model = reduce(hcat, map(func, eachcol(p), eachcol(f)))
    # One weighted sum of squares per column, i.e. per spectrum:
    vec(sum(abs2, (z_concat .- z_model) ./ sigma, dims=1))
end
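
To show the shape convention this expects -- one column of p and f per spectrum, and func returning the real and imaginary parts stacked into one vector -- here is a toy model where everything is invented for illustration:

toy_func(p, f) = begin  # parallel R-C circuit, p = [R, C], impedance R / (1 + im*2pi*f*R*C)
    zc = p[1] ./ (1 .+ im .* 2pi .* f .* (p[1] * p[2]))
    vcat(real(zc), imag(zc))
end

num_freq, num_eis = 5, 3
f = repeat(10 .^ range(0, 4; length=num_freq), 1, num_eis)
p = abs.(randn(2, num_eis))  # one [R, C] column per spectrum
stacked = reduce(hcat, map(toy_func, eachcol(p), eachcol(f)))
z = complex.(stacked[1:num_freq, :], stacked[num_freq+1:end, :])  # exact data
zerr = fill(0.1, num_freq, num_eis)

compute_wrss(p, f, z, zerr, zerr, toy_func)  # ≈ zeros(num_eis), since the data is exact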

# Function to turn a vector into a matrix by repeating entries as whole columns
function weirdcat(data::AbstractVector, indices::AbstractVector{<:Integer}, step::Integer)
  first(indices) == firstindex(data) || error("indices must start at firstindex(data)")
  last(indices) == lastindex(data)+1 || error("indices must cover the input exactly")
  slices = map(zip(indices, Iterators.drop(indices, 1))) do (lo, next)
    if (next - lo) == step
      # same as view(data, lo:next-1), but a StepRangeLen keeps both branches the same type:
      view(data, StepRangeLen(lo, 1, step))
    elseif (next - lo) == 1
      # one shared entry, repeated `step` times by a zero-step range:
      view(data, StepRangeLen(lo, 0, step))
    else
      error("the gap between consecutive indices must be 1 or step")
    end
  end
  reduce(hcat, slices)
end
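
Concretely, with the same inputs as the test at the end of this post, gaps of step between consecutive indices become columns of fresh entries, while gaps of 1 become one entry repeated down a whole column:

julia> weirdcat([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], [1, 4, 5, 8], 3)
3×3 Matrix{Float64}:
 1.0  4.0  5.0
 2.0  4.0  6.0
 3.0  4.0  7.0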

using ChainRulesCore
function ChainRulesCore.rrule(::typeof(weirdcat), data, indices, step)
  function uncat(dy)
    # Reverse the forward map column by column: full columns pass straight through,
    # while a repeated column is summed back into its one shared entry:
    pieces = map(eachcol(unthunk(dy)), zip(indices, Iterators.drop(indices, 1))) do col, (lo, next)
        if (next - lo) == step
          col
        else
          sum(col, dims=1)  # not the same type as col, unfortunately
        end
    end
    (NoTangent(), reduce(vcat, pieces)::Vector{Float64}, NoTangent(), NoTangent())
  end
  return weirdcat(data, indices, step), uncat
end
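
And a quick sanity check of the rule with Zygote, on the same made-up inputs: the cotangent of the repeated column is summed back into its one shared entry:

using Zygote
Zygote.gradient(d -> sum(weirdcat(d, [1, 4, 5, 8], 3)), ones(7))[1]
# [1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0] -- index 4 collects its three copies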

using ChainRulesTestUtils  # check_inferred only passes thanks to the ::Vector{Float64}
# annotation above; that could be done more elegantly / more generally, and I'm
# not sure how much performance depends on it.
test_rrule(weirdcat, randn(7), [1, 4, 5, 8], 3; check_inferred=true)