I had to test this locally, but I think I understand now. JAX lets you do the equivalent of array[:] = [x], so len(P[self.kvals[i]:self.kvals[i + 1]]) may also be 1.
Your Julia implementation looks pretty reasonable. Here is my attempt at cleaning it up a bit:
# Build P_log: one column per parameter, tiling shared (length-1) parameter
# slices to num_eis rows so every column has equal length before hcat.
P_log = reduce(hcat, [
    let istart = kvals[i], istop = kvals[i + 1] - 1, ps = @view P[istart:istop]
        # A length-1 slice is a single shared parameter; repeat it so the
        # column matches the per-spectrum columns of length num_eis.
        # (Original compared `istop - istart == 1`, which tests length 2 —
        # length 1 means istop == istart because istop = kvals[i+1] - 1.)
        length(ps) == 1 ? repeat(ps, num_eis) : ps
    end
    for i in 1:num_params
])
# Alternative way of doing this using map
# Not sure which is faster, feel free to try both
P_norm = reduce(hcat, map(1:num_params) do i
    istart, istop = kvals[i], kvals[i + 1] - 1
    # @view and @views reduce memory allocations when you're not
    # differentiating this function with Zygote
    ps = @view P[istart:istop]
    # Undo the log10 transform and squash each value into (LB, UB).
    ps_norm = @views (LB[istart:istop] .+ (10 .^ ps)) ./ (1 .+ (10 .^ ps) ./ UB[istart:istop])
    # Fixes vs. the draft: `ub`/`lb` were undefined here, the branch
    # returned the raw `ps` (leaving ps_norm dead), and the length-1 test
    # was off by one. Tile shared parameters to num_eis like P_log above.
    length(ps_norm) == 1 ? repeat(ps_norm, num_eis) : ps_norm
end)
# Infinite smoothing factors would poison the penalty; zero them out.
smf_finite = map(x -> isinf(x) ? 0.0 : x, smf)
# Smoothness penalty: squared second differences of each parameter track,
# summed column-wise and weighted by the (finite) smoothing factors.
curvature = d2m * transpose(P_log)
chi_smf = sum(sum(curvature .^ 2, dims = 1) .* smf_finite)
# Weighted residual sum of squares for every spectrum, one per column.
wrss_tot = compute_wrss.(
    eachcol(P_norm), eachcol(F), eachcol(Z),
    eachcol(Zerr_Re), eachcol(Zerr_Im), func,
)
# Total objective = data misfit + smoothness penalty.
return sum(wrss_tot) + chi_smf