Gradient of a loss function: struggling to avoid array mutation

The more natural way to do this would be to use reduce:

```julia
function loss_2(p_p)
    # Solve the ODE once per row of initial conditions, then compare to the data
    diff_pred_obs = map(p_u0 -> ode_p_test_2(p_p, p_u0), eachrow(tab_u0)) - tab_uT_2
    return sum(reduce(vcat, diff_pred_obs) .^ 2)
end
```

That avoids the problems associated with splatting, which is something you rarely want to do on a big array. Also, sum(abs2, reduce(vcat, diff_pred_obs)) is slightly nicer style than broadcasting the square and summing. Finally, instead of using a map you can use parallel ensembles to multithread the solves. This is demonstrated in the SDE parameter estimation tutorial:

https://diffeqflux.sciml.ai/dev/examples/optimization_sde/
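To make the ensemble idea concrete, here is a minimal sketch of multithreading the solves with an EnsembleProblem. The ODE, parameter values, and the names tab_u0 / tab_uT_2 are stand-ins for the original poster's setup, not their actual code:

```julia
using DifferentialEquations

# Toy stand-in for the original ODE; p_p would be the parameters being fit.
f(u, p, t) = p[1] .* u
prob = ODEProblem(f, [1.0], (0.0, 1.0), [0.5])

tab_u0 = rand(4, 1)   # one initial condition per row (assumed layout)
tab_uT_2 = rand(4)    # observed final values (assumed data)

# Swap in the i-th initial condition for each trajectory.
prob_func(prob, i, repeat) = remake(prob, u0 = tab_u0[i, :])
ensemble = EnsembleProblem(prob, prob_func = prob_func)

# EnsembleThreads() runs the independent solves on multiple threads.
sols = solve(ensemble, Tsit5(), EnsembleThreads(),
             trajectories = size(tab_u0, 1), save_everystep = false)

# Loss against the observed final states, in the sum(abs2, ...) style.
loss = sum(abs2, [s[end][1] for s in sols] .- tab_uT_2)
```

The key point is that each trajectory is independent, so the map over eachrow(tab_u0) becomes trajectories of an ensemble and parallelism comes for free.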

So in total, all you need to do to make your code work is use reduce instead of splatting (@dhairyagandhi96 it might be good to figure out what’s up with that adjoint anyway), but using the ensembles has some additional advantages and makes sum(abs2, sol - data) work directly. Cheers!
