Hi all,
I have a rather large reaction network I want to fit to data. I imported it from an .xml file, so it is a `ModelingToolkit.ODESystem`.
It has 228 states and 470 parameters.
It’s not particularly stiff (at least at the current parameter values), and solves easily:
```julia
julia> @time sol = solve(prob, Tsit5());
  0.199835 seconds (68.93 k allocations: 120.288 MiB)
```
I want to fit it to data by minimising the L2 norm between the output trajectory and the data points. I know the problem is massively underdetermined. This is my loss function:
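In symbols, the quantity being minimised is the sum of squared residuals. This is the squared L2 norm, which has the same minimiser as the L2 norm itself. Here $t_j$ runs over the save times `ts`, $u_i(t_j; u_0, p)$ is state $i$ of the solution at $t_j$, and $d_{ij}$ is `data[i, j]`. This assumes the columns of `data` line up with `ts`.

$$
L(u_0, p) = \sum_{i \in \mathrm{idxs}} \sum_{j} \bigl( u_i(t_j; u_0, p) - d_{ij} \bigr)^2
$$

In code: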
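```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote  # SciMLSensitivity was formerly DiffEqSensitivity

# x: flat vector holding the initial conditions and parameters being fitted
function loss(x)
    u0_, p_ = outmerge(x)                      # split x into u0 and p
    pprob = remake(prob; p = p_, u0 = u0_)
    sol = solve(pprob, Tsit5(); saveat = ts, sensealg = BacksolveAdjoint())
    # sum of squared residuals over the observed species `idxs`
    return sum(abs2, Array(sol)[idxs, :] .- data[idxs, :])
end
```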
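`outmerge` is not shown here. The following is a hypothetical sketch of a splitter with that role. It assumes the flat vector stacks the 228 initial conditions first, followed by the 470 parameters. `N_STATES`, `N_PARAMS` and the inverse helper `inmerge` are illustrative names only.

```julia
# Hypothetical layout (assumption): initial conditions first, then parameters.
const N_STATES = 228   # states in the imported ODESystem
const N_PARAMS = 470   # parameters in the imported ODESystem

# Split a flat vector x into (u0, p) using plain range indexing,
# which Zygote can differentiate through.
function outmerge(x::AbstractVector)
    @assert length(x) == N_STATES + N_PARAMS "unexpected length $(length(x))"
    u0 = x[1:N_STATES]
    p = x[(N_STATES + 1):end]
    return u0, p
end

# Hypothetical inverse: stack u0 and p into one flat vector (e.g. to build x0).
inmerge(u0::AbstractVector, p::AbstractVector) = vcat(u0, p)

# Round-trip check with dummy data.
x = collect(1.0:(N_STATES + N_PARAMS))
u0, p = outmerge(x)
println(length(u0), " ", length(p))  # prints "228 470"
```

Under this assumed layout, `x0` has 698 entries. For example, it could be built as `inmerge(prob.u0, prob.p)` if `prob.p` is a plain vector.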
When I try to compute the gradient with

```julia
grad = Zygote.gradient(loss, x0)
```

my computer just hangs while precompiling functions (I’ve waited up to 40 minutes). My computer isn’t terrible: it’s a 2017 MacBook Pro.
I’ve tried a few different `sensealg` choices for the differentiation. Does anybody have tips on why such a hang could occur and how to avoid it, or know of a `sensealg` that might avoid it?
Thanks!