# Performance comparison - Flux.jl's Adam vs Jax's Adam

Well, here’s one which uses this weird repeat convention, plus the fusion above. About 2.5 times faster. I’d bet there’s another factor of 2 to be had by improving `model`.

One improvement not in this code block is to use `withgradient` to get the loss and the gradient at once, rather than running twice. That only gains 20% or so, but it's worth having.

``````
# Initial code, 10^4 epochs:   4.085793 seconds (28.90 M allocations: 4.149 GiB, 9.70% gc time)
# With this compute_total_obj: 1.915556 seconds (11.76 M allocations: 3.365 GiB, 14.05% gc time)
# withgradient too:            1.614427 seconds (10.87 M allocations: 2.682 GiB, 13.52% gc time)

function compute_total_obj(P::AbstractVector, F::AbstractArray, Z::AbstractArray,
        Zerr_Re::AbstractArray, Zerr_Im::AbstractArray,
        LB::AbstractVector, UB::AbstractVector, smf::AbstractVector, func::Function,
        num_params::Integer, num_eis::Integer, kvals::AbstractVector, d2m::AbstractMatrix)

    P_log = weirdcat(P, kvals, num_params-1)
    up = (10 .^ P_log)
    P_norm = (reshape(LB, :, num_params) .+ up) ./ (1 .+ up ./ reshape(UB, :, num_params))

    smf_1 = ifelse.(isinf.(smf), 0.0, smf)
    chi_smf = sum(sum(abs2, d2m * P_log, dims = 1) .* smf_1)

    wrss_tot = compute_wrss(transpose(P_norm), F, Z, Zerr_Re, Zerr_Im, func)
    return (sum(wrss_tot) + chi_smf)
end

# "vmap"-like variant of compute_wrss -- going inside model would be better still:
function compute_wrss(p::AbstractMatrix, f::AbstractMatrix, z::AbstractMatrix,
        zerr_re::AbstractMatrix, zerr_im::AbstractMatrix, func::F) where F
    z_concat = vcat(real(z), imag(z))
    sigma = vcat(zerr_re, zerr_im)
    z_model = reduce(hcat, map(func, eachcol(p), eachcol(f)))
    vec(sum(abs2, (z_concat .- z_model) ./ sigma, dims=1))
end

# Function to turn a vector into a matrix by repeating entries as whole columns
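# For example, with step = 3 and block lengths 3, 1, 3:
#   weirdcat(collect(1:7), [1, 4, 5, 8], 3) == [1 4 5; 2 4 6; 3 4 7]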
function weirdcat(data::AbstractVector, indices::AbstractVector{<:Integer}, step::Integer)
    first(indices) == firstindex(data) || error("indices must start at 1")
    last(indices) == lastindex(data)+1 || error("indices must cover the input exactly")
    slices = map(zip(indices, Iterators.drop(indices, 1))) do (lo, next)
        if (next - lo) == step
            # @assert StepRangeLen(lo, 1, step) == lo:next-1  # but changed for type-stability
            view(data, StepRangeLen(lo, 1, step))
        elseif (next - lo) == 1
            view(data, StepRangeLen(lo, 0, step))
        else
            error("step must be 1 or given integer")
        end
    end
    reduce(hcat, slices)
end

using ChainRulesCore
function ChainRulesCore.rrule(::typeof(weirdcat), data, indices, step)
    function uncat(dy)
        pieces = map(eachcol(unthunk(dy)), zip(indices, Iterators.drop(indices, 1))) do col, (lo, next)
            if (next - lo) == step
                col
            else
                sum(col, dims=1)  # not the same type as col, unfortunately
            end
        end
        (NoTangent(), reduce(vcat, pieces)::Vector{Float64}, NoTangent(), NoTangent())
    end
    return weirdcat(data, indices, step), uncat
end

using ChainRulesTestUtils  # here check_inferred needs ::Vector{Float64} annotation,
# could be done more elegantly / more generally. Not sure how much performance cares.
test_rrule(weirdcat, randn(7), [1, 4, 5, 8], 3; check_inferred=true)
``````

Thanks @mcabbott. This is incredible, I could never have thought of this. The run time was practically cut in half and is now shorter than the Jax equivalent. I changed the line `P_norm = (reshape(LB, :, num_params) .+ up) ./ (1 .+ up ./ reshape(UB, :, num_params))` to `P_norm = (weirdcat(LB, kvals, num_eis) .+ up) ./ (1 .+ up ./ weirdcat(UB, kvals, num_eis))`, which now works for all cases.

Could you please show me how you used `withgradient` to get the loss and the gradient at once, rather than running twice?

Now I’ve gotta love Julia.


You can replace these two lines

``````
gs = gradient(p -> compute_total_obj(p, F, Z, ...), p)[1]
training_loss = compute_total_obj(p, F, Z, ...)  # 2nd forward pass
``````

with one

``````
training_loss, (gs,) = withgradient(p -> compute_total_obj(p, F, Z, ...), p)
``````

or perhaps

``````
res = withgradient(p -> compute_total_obj(p, F, Z, ...), p)
training_loss = res.val
gs = res.grad[1]  # gradient with respect to p
``````

It’s a bit unfortunate that so much tinkering is necessary, and that what works well for Zygote-able code is not the same style as fast ordinary Julia code.


I can see that, indeed.

One last question, please: could you try taking the Hessian of the function? I seem to be unable to. The gradient works well, but the Hessian throws an error: `TypeError: in typeassert, expected Vector{Float64}, got a value of type Vector{ForwardDiff.Dual{Nothing, Float64, 12}}`
This is my Hessian function:

``````
function get_hess(p)
    H = Flux.hessian(p -> compute_total_obj(vec(p), F, Z, Zerr_Re, Zerr_Im, lb_vec, ub_vec, smf, func, num_params, num_eis, kvals, d2m), p)
    return H
end
``````

I tried using `vec(p)` in place of `p`, but it didn't work.

What exactly is throwing this? The error comes about because you’re trying to put dual numbers in a vector of floats. Can you post the full stacktrace?

The error is because my `rrule(::typeof(weirdcat))` asserts a type `::Vector{Float64}` for stability. That could be removed. Not sure whether it matters for speed.
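
For example, here is the same pullback with the assertion dropped (an untested sketch, not checked for type-stability or speed):

``````
# Same rrule as above, minus the ::Vector{Float64} assertion, so that Dual-valued
# cotangents from Zygote's forward-over-reverse hessian can pass through:
function ChainRulesCore.rrule(::typeof(weirdcat), data, indices, step)
    function uncat(dy)
        pieces = map(eachcol(unthunk(dy)), zip(indices, Iterators.drop(indices, 1))) do col, (lo, next)
            (next - lo) == step ? col : sum(col, dims=1)
        end
        (NoTangent(), reduce(vcat, pieces), NoTangent(), NoTangent())
    end
    return weirdcat(data, indices, step), uncat
end
``````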

Here is the full stacktrace:

``````
Stacktrace:
#   [1] (::var"#uncat#6"{Vector{Int64}, Int64})(dy::Matrix{ForwardDiff.Dual{Nothing, Float64, 12}})
#     @ Main ./In[4]:108
#   [2] (::Zygote.ZBack{var"#uncat#6"{Vector{Int64}, Int64}})(dy::Matrix{ForwardDiff.Dual{Nothing, Float64, 12}})
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:206
#   [3] Pullback
#     @ ./In[4]:69 [inlined]
#   [4] (::typeof(∂(compute_total_obj)))(Δ::ForwardDiff.Dual{Nothing, Float64, 12})
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
#   [5] Pullback
#     @ ./In[26]:71 [inlined]
#   [6] (::typeof(∂(λ)))(Δ::ForwardDiff.Dual{Nothing, Float64, 12})
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
#   [7] (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::ForwardDiff.Dual{Nothing, Float64, 12})
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:45
#   [8] gradient(f::Function, args::Vector{ForwardDiff.Dual{Nothing, Float64, 12}})
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:97
#   [9] (::Zygote.var"#102#103"{var"#187#198"{Vector{Float64}, Vector{Float64}, Matrix{Float64}, Vector{Int64}, typeof(her), Vector{Float64}, Matrix{Float64}, Int64, Int64, Matrix{ComplexF64}}})(x::Vector{ForwardDiff.Dual{Nothing, Float64, 12}})
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/lib/grad.jl:64
#  [10] forward_jacobian(f::Zygote.var"#102#103"{var"#187#198"{Vector{Float64}, Vector{Float64}, Matrix{Float64}, Vector{Int64}, typeof(her), Vector{Float64}, Matrix{Float64}, Int64, Int64, Matrix{ComplexF64}}}, x::Vector{Float64}, #unused#::Val{12})
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/lib/forward.jl:29
#  [11] forward_jacobian(f::Function, x::Vector{Float64}; chunk_threshold::Int64)
#     @ Zygote ~/.julia/packages/Zygote/dABKa/src/lib/forward.jl:44
# ...
#     @ In[27]:1
#  [18] eval
#     @ ./boot.jl:368 [inlined]
#  [19] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
``````

It’s good now.

I think it should be possible to simplify this to avoid the explicit `rrule`, but I haven't had the time to test yet. The gist is to skip using `kvals` and construct a matrix of indices into `P` using `smf` and `num_eis` directly. Wherever `isinf(smf[i])`, that column in the mask can be filled with a single index instead of `j:j+num_eis`, where `j` is the last index processed. That index matrix can then be passed to `NNlib.gather` to pull the relevant parameters from `P` in one fell swoop. As long as the index-generation function is marked `@non_differentiable` as well, I believe it should be nested-diff friendly too. Oh, and (assuming it works) this pattern could help speed up the PyTorch and JAX implementations as well.
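
Roughly what I have in mind, as an untested sketch (the helper name `make_gather_indices` and the `num_params × num_eis` orientation are just for illustration):

``````
# Sketch: build the index matrix from smf alone, assuming P stores a single
# entry for each parameter with isinf(smf[i]) and num_eis entries otherwise.
function make_gather_indices(smf::AbstractVector, num_eis::Integer)
    rows = Vector{Vector{Int}}(undef, length(smf))
    j = 1  # next unused position in P
    for i in eachindex(smf)
        if isinf(smf[i])
            rows[i] = fill(j, num_eis)        # one shared entry, repeated
            j += 1
        else
            rows[i] = collect(j:j+num_eis-1)  # one entry per spectrum
            j += num_eis
        end
    end
    return permutedims(reduce(hcat, rows))    # num_params × num_eis
end
``````

Then `NNlib.gather(P, make_gather_indices(smf, num_eis))` would play the role of `weirdcat`.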


Thanks @ToucheSir for your suggestion. I implemented it and it works. Here is what I did, still making use of `kvals`. First, within the main function, I form the `gather_indices` matrix:

``````
gather_indices = transpose(reduce(hcat, [
    let istart = kvals[i], istop = kvals[i + 1] - 1, ps = istart:istop
        istop - istart == 0 ? repeat(ps, num_eis) : ps
    end
    for i = 1:num_params]))
``````

Then, within the objective function, instead of using `weirdcat`, I use `NNlib.gather` together with the indices:

``````
function compute_total_obj(P::AbstractVector, F::AbstractArray, Z::AbstractArray,
        Zerr_Re::AbstractArray, Zerr_Im::AbstractArray,
        LB::AbstractVector, UB::AbstractVector, smf::AbstractVector, func::Function,
        num_params::Integer, num_eis::Integer, gather_indices::AbstractMatrix, d2m::AbstractMatrix)

    P_log = NNlib.gather(P, gather_indices)
    up = (10 .^ P_log)
    P_norm = (NNlib.gather(LB, gather_indices) .+ up) ./ (1 .+ up ./ NNlib.gather(UB, gather_indices))

    smf_1 = ifelse.(isinf.(smf), 0.0, smf)
    chi_smf = sum(sum(abs2, d2m * transpose(P_log), dims = 1) .* smf_1)

    wrss_tot = compute_wrss(P_norm, F, Z, Zerr_Re, Zerr_Im, func)
    return (sum(wrss_tot) + chi_smf)
end
``````

Since I don't generate the indices within the objective function, is there any point in using the `@non_differentiable` macro? If yes, how do I use it?


No, hoisting out a variable like `gather_indices` has much the same effect. It seems that `smf_1` could also be calculated outside of the objective function, although I'm not sure how much of a speedup you'd get from that.
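
For reference, if you ever did build the indices inside the differentiated function, marking the builder as non-differentiable would look roughly like this (a sketch wrapping your own comprehension in a hypothetical helper):

``````
using ChainRulesCore

# Hypothetical wrapper around the kvals-based index construction shown above:
function build_gather_indices(kvals, num_params, num_eis)
    transpose(reduce(hcat, [
        let istart = kvals[i], istop = kvals[i + 1] - 1, ps = istart:istop
            istop - istart == 0 ? repeat(ps, num_eis) : ps
        end
        for i = 1:num_params]))
end

# Treat the index construction as a constant during AD:
ChainRulesCore.@non_differentiable build_gather_indices(::Any, ::Any, ::Any)
``````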

Okay, cool. The speed is pretty okay, I guess. I will also try to form `smf_1` outside.


Now that this no longer needs to be differentiable, you can also consider using mutation again:

``````
gather_indices = similar(P, Int, num_params, num_eis)  # Int eltype, since these are indices into P
for i in 1:num_params
    istart, istop = kvals[i], kvals[i + 1] - 1
    if istop == istart
        @views gather_indices[i, :] .= istart
    else
        @views gather_indices[i, :] .= istart:istop
    end
end
``````