Well, here’s one which uses this weird repeat convention, plus the fusion above. About 2.5 times faster. I’d bet there’s another factor of 2 to be had by improving `model`.

One improvement not in this code block is to use `withgradient` to get the loss and the gradient at once, rather than running twice. Only 20% or so, but worth having.

```julia
# Initial code, 10^4 epochs: 4.085793 seconds (28.90 M allocations: 4.149 GiB, 9.70% gc time)
# With this compute_total_obj: 1.915556 seconds (11.76 M allocations: 3.365 GiB, 14.05% gc time)
# withgradient too: 1.614427 seconds (10.87 M allocations: 2.682 GiB, 13.52% gc time)
"""
    compute_total_obj(P, F, Z, Zerr_Re, Zerr_Im, LB, UB, smf, func,
                      num_params, num_eis, kvals, d2m)

Return the scalar objective: the weighted residual sum of squares from
`compute_wrss` plus a smoothness penalty built from `d2m` and `smf`.
`P` is the packed log-parameter vector, expanded to a matrix by `weirdcat`
using the column boundaries in `kvals`. (NOTE(review): `num_eis` is accepted
for interface compatibility but not read here.)
"""
function compute_total_obj(P::AbstractVector, F::AbstractArray, Z::AbstractArray,
        Zerr_Re::AbstractArray, Zerr_Im::AbstractArray,
        LB::AbstractVector, UB::AbstractVector, smf::AbstractVector, func::Function,
        num_params::Integer, num_eis::Integer, kvals::AbstractVector, d2m::AbstractMatrix)
    # Expand the packed vector into one column per spectrum-slice.
    logp = weirdcat(P, kvals, num_params - 1)
    scaled = 10 .^ logp
    # Squash each parameter into its [LB, UB] box; one fused broadcast.
    lower = reshape(LB, :, num_params)
    upper = reshape(UB, :, num_params)
    bounded = (lower .+ scaled) ./ (1 .+ scaled ./ upper)
    # An infinite smoothing factor means "no penalty" for that parameter.
    weights = ifelse.(isinf.(smf), 0.0, smf)
    smoothness = sum(sum(abs2, d2m * logp, dims = 1) .* weights)
    fit = compute_wrss(transpose(bounded), F, Z, Zerr_Re, Zerr_Im, func)
    return sum(fit) + smoothness
end
# "vmap"-like variant of compute_wrss -- going inside model would be better still:
"""
    compute_wrss(p, f, z, zerr_re, zerr_im, func)

"vmap"-like variant of `compute_wrss`: for each column pair of `p` and `f`,
evaluate `func` and accumulate the squared, error-weighted residual against
the stacked observations `vcat(real(z), imag(z))`. Returns one `Float64`
per column. (Pushing this inside `model` would be better still.)
"""
function compute_wrss(p::AbstractMatrix, f::AbstractMatrix, z::AbstractMatrix, zerr_re::AbstractMatrix,zerr_im::AbstractMatrix, func::F) where F
    columns = zip(eachcol(p), eachcol(f), eachcol(z), eachcol(zerr_re), eachcol(zerr_im))
    # Real parts stacked on top of imaginary parts, both for data and errors.
    return [sum(abs2, (vcat(real(zc), imag(zc)) .- func(pc, fc)) ./ vcat(rc, ic))
            for (pc, fc, zc, rc, ic) in columns]
end
# Function to turn a vector into a matrix by repeating entries as whole columns
"""
    weirdcat(data, indices, step)

Turn a vector into a `step`-row matrix, one column per consecutive pair in
`indices`. A pair spanning `step` entries becomes a column verbatim; a pair
spanning a single entry becomes a column of that entry repeated `step` times.
`indices` must start at `firstindex(data)` and end at `lastindex(data) + 1`.
"""
function weirdcat(data::AbstractVector, indices::AbstractVector{<:Integer}, step::Integer)
    first(indices) == firstindex(data) || error("indices must start at 1")
    last(indices) == lastindex(data)+1 || error("indices must cover the input exactly")
    columns = map(firstindex(indices):lastindex(indices) - 1) do k
        lo = indices[k]
        width = indices[k + 1] - lo
        # Both branches use StepRangeLen so the view type is identical
        # (plain ranges would be type-unstable across the two cases).
        width == step ? view(data, StepRangeLen(lo, 1, step)) :
        width == 1    ? view(data, StepRangeLen(lo, 0, step)) :
        error("step must be 1 or given integer")
    end
    return reduce(hcat, columns)
end
using ChainRulesCore
"""
    rrule(::typeof(weirdcat), data, indices, step)

Reverse rule for `weirdcat`: a column that copied a whole slice maps its
cotangent back one-to-one, while a column built by repeating a single entry
accumulates (sums) its cotangent into that entry.
"""
function ChainRulesCore.rrule(::typeof(weirdcat), data, indices, step)
    y = weirdcat(data, indices, step)
    function uncat(dy)
        cotangent = unthunk(dy)
        bounds = zip(indices, Iterators.drop(indices, 1))
        pieces = map(eachcol(cotangent), bounds) do col, (lo, next)
            # sum(col, dims=1) is not the same type as col, unfortunately,
            # hence the ::Vector{Float64} assertion below for inferrability.
            (next - lo) == step ? col : sum(col, dims=1)
        end
        dx = reduce(vcat, pieces)::Vector{Float64}
        return (NoTangent(), dx, NoTangent(), NoTangent())
    end
    return y, uncat
end
using ChainRulesTestUtils # here check_inferred needs ::Vector{Float64} annotation,
# could be done more elegantly / more generally. Not sure how much performance cares.
# Finite-difference check of the rrule: 7 entries with boundaries [1, 4, 5, 8] and
# step 3, i.e. two full slices of three plus one entry repeated as a column.
test_rrule(weirdcat, randn(7), [1, 4, 5, 8], 3; check_inferred=true)
```