For loop with too many allocations

You can simplify the loss function by reshaping your input data so that each (point, case) pair becomes one column:

"""
    pack_samples(xyz, uparams, wa, pp, n_pts, n_cases) -> (X1, X2)

Flatten per-(point, case) samples into two column-major matrices with one column
per (point, case) pair; the point index `ipt` varies fastest.

Column `(ics - 1) * n_pts + ipt` of `X1` holds `[xyz[ipt, ics, :]; uparams[ics, :]; wa[ics]]`,
and the same column of `X2` holds `pp[ipt, ics, :]`. Row counts are derived from
`size(xyz, 3)`, `size(uparams, 2)` and `size(pp, 3)` instead of being hard-coded.
"""
function pack_samples(xyz, uparams, wa, pp, n_pts, n_cases)
    nx = size(xyz, 3)
    nu = size(uparams, 2)
    X1 = zeros(nx + nu + 1, n_pts * n_cases)
    X2 = zeros(size(pp, 3), n_pts * n_cases)
    for ics in 1:n_cases, ipt in 1:n_pts
        # Compute the column directly instead of mutating a global counter:
        # `i += 1` on a global inside a top-level loop errors in script files.
        col = (ics - 1) * n_pts + ipt
        # Copy through views so no per-iteration slice or vcat temporaries are made.
        @views begin
            X1[1:nx, col] .= xyz[ipt, ics, :]
            X1[(nx + 1):(nx + nu), col] .= uparams[ics, :]
            X2[:, col] .= pp[ipt, ics, :]
        end
        X1[end, col] = wa[ics]
    end
    return X1, X2
end

# Run the packing inside a function (function barrier) rather than at global scope.
X1, X2 = pack_samples(xyz, uparams, wa, pp, n_pts, n_cases)

"""
    loss_new(X1, X2, NN2)

Return the sum of squared residuals between the model output `NN2(X1)` and the
targets `X2`.
"""
function loss_new(X1, X2, NN2)
    residual = NN2(X1) - X2
    return sum(abs2, residual)
end
julia> @time loss_new(X, Y, NN2);
  0.348413 seconds (15 allocations: 296.159 MiB, 4.08% gc time)
1 Like