This gets just over a factor of 2 improvement:

```julia
# Initial code, 10^4 epochs: 4.085793 seconds (28.90 M allocations: 4.149 GiB, 9.70% gc time)
# With this compute_total_obj: 1.887204 seconds (11.76 M allocations: 3.348 GiB, 14.07% gc time)
function compute_total_obj(P::AbstractVector, F::AbstractArray, Z::AbstractArray,
Zerr_Re::AbstractArray, Zerr_Im::AbstractArray,
LB::AbstractVector, UB::AbstractVector, smf::AbstractVector, func::Function,
num_params::Integer, num_eis::Integer, kvals::AbstractVector, d2m::AbstractMatrix)
# num_params = 6
# summary(P) == "30-element Vector{Float64}"
# Exploiting kvals = [1, 6, 11, 16, 21, 26, 31] i.e. evenly spaced, none co-incide, to simplify
P_log = reshape(P, :, num_params)
# Indexing like ps = @view P[istart:istop] costs a whole copy(P) in every gradient, sadly,
# and @view only saves on forward pass. But eachcol is more efficient:
# tmp = map(eachcol(P_log), eachcol(reshape(LB, :, num_params)), eachcol(reshape(UB, :, num_params))) do ps, L, U
# up = (10 .^ ps) # do this once
# ps_norm = (L .+ up) ./ (1 .+ up ./ U)
# end
# P_norm = reduce(hcat, tmp) # With this change alone, 2.515051 seconds
# Even better, avoid the slices completely (worth about 0.2 sec):
up = (10 .^ P_log)
P_norm = (reshape(LB, :, num_params) .+ up) ./ (1 .+ up ./ reshape(UB, :, num_params))
smf_1 = ifelse.(isinf.(smf), 0.0, smf)
chi_smf = sum(sum(abs2, d2m * P_log, dims = 1) .* smf_1)
# This is perfect for vmap, as the function applied to slices here only involves broadcasting etc,
# which could more efficiently be done on whole matrices, not slices.
# wrss_tot = compute_wrss.(eachcol(transpose(P_norm)), eachcol(F), eachcol(Z), eachcol(Zerr_Re), eachcol(Zerr_Im), func)
# Here done incompletely, by hand, below:
wrss_tot = compute_wrss(transpose(P_norm), F, Z, Zerr_Re, Zerr_Im, func)
return (sum(wrss_tot) + chi_smf)
end
function compute_wrss(p::AbstractVector, f::AbstractVector, z::AbstractVector, zerr_re::AbstractVector, zerr_im::AbstractVector, func::F) where F
z_concat = vcat(real(z), imag(z))
sigma = vcat(zerr_re, zerr_im)
z_model = func(p, f)
sum(abs2, (z_concat .- z_model) ./ sigma) # norm^2, without sqrt-then-square
end
# "vmap"-like variant -- going inside model would be better still:
function compute_wrss(p::AbstractMatrix, f::AbstractMatrix, z::AbstractMatrix, zerr_re::AbstractMatrix, zerr_im::AbstractMatrix, func::F) where F
z_concat = vcat(real(z), imag(z))
sigma = vcat(zerr_re, zerr_im)
z_model = reduce(hcat, map(func, eachcol(p), eachcol(f)))
vec(sum(abs2, (z_concat .- z_model) ./ sigma, dims=1))
end
```
Here `P_norm` relies on `kvals` being evenly spaced, to use `eachcol` to avoid indexing. It would not be hard to write some kind of `split_at` function with an efficient gradient.
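For instance, a minimal sketch of such a `split_at` (the name and the `ChainRulesCore.rrule` below are my own illustration, not from the code above) could re-use the `kvals` boundary convention, where the last entry is `length(x) + 1`:

```julia
using ChainRulesCore

# Split `x` into pieces x[ks[i]:ks[i+1]-1]; `ks` follows the kvals convention,
# i.e. it lists the start of each piece plus a final length(x) + 1.
split_at(x::AbstractVector, ks::AbstractVector{<:Integer}) =
    [x[ks[i]:ks[i+1]-1] for i in 1:length(ks)-1]

function ChainRulesCore.rrule(::typeof(split_at), x::AbstractVector, ks::AbstractVector{<:Integer})
    y = split_at(x, ks)
    function split_at_pullback(dy)
        # One zero(x)-sized accumulator per call, instead of one copy of x per slice:
        dx = zero(x)
        for (i, dyi) in enumerate(unthunk(dy))
            (dyi === nothing || dyi isa AbstractZero) && continue
            dx[ks[i]:ks[i+1]-1] .+= unthunk(dyi)
        end
        return (NoTangent(), dx, NoTangent())
    end
    return y, split_at_pullback
end
```

With this, `split_at(P, kvals)` would return the same pieces as the columns of `P_log` above, while the pullback allocates only one `zero(P)` per gradient call.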
The functions being mapped over slices (for `wrss_tot`) are all quite simple things like broadcasting which could in principle apply to the whole array. My attempt at a `vmap` doesn’t handle this case at present. But performing the transformation by hand on `compute_wrss` isn’t so hard.
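As a sanity check of that by-hand transformation, the two `compute_wrss` methods should agree column-for-column. Here is a throwaway comparison, where `dummy_func` and all the shapes are purely illustrative stand-ins for the real model and data:

```julia
# dummy_func stands in for the real model: it must return 2*length(f) values,
# matching vcat(real(z), imag(z)).
dummy_func(p, f) = vcat(p[1] .* f .+ p[2], p[3] .* f .+ p[4])

p = rand(6, 5); f = rand(20, 5)              # 6 parameters, 5 spectra, 20 frequencies each
z = rand(ComplexF64, 20, 5)
zerr_re = rand(20, 5) .+ 0.1; zerr_im = rand(20, 5) .+ 0.1

per_column  = compute_wrss.(eachcol(p), eachcol(f), eachcol(z), eachcol(zerr_re), eachcol(zerr_im), dummy_func)
all_at_once = compute_wrss(p, f, z, zerr_re, zerr_im, dummy_func)
per_column ≈ all_at_once   # true: same per-spectrum values, but the residuals are broadcast over whole matrices
```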
Doing it for `model` would probably help a lot, but would be more work. It might be possible to pass it tuples of length `num_params` instead of arrays, but I didn’t try.
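An untested sketch of that tuple idea, assuming the model has (or is given) a method accepting an `NTuple` of scalars, and hard-coding `num_params = 6` via `Val(6)`, might look like:

```julia
# Untested: convert each parameter column to a tuple before calling the model.
# This requires `func` to have a method that accepts an NTuple instead of a vector.
function compute_wrss_tuple(p::AbstractVector, f::AbstractVector, z::AbstractVector,
                            zerr_re::AbstractVector, zerr_im::AbstractVector, func::F) where F
    pt = ntuple(i -> p[i], Val(6))   # Val(6) == Val(num_params), so the length is known to the compiler
    z_concat = vcat(real(z), imag(z))
    sigma = vcat(zerr_re, zerr_im)
    z_model = func(pt, f)
    sum(abs2, (z_concat .- z_model) ./ sigma)
end
```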