Taking gradients of large differential equations

Hi all,

I have a rather large reaction network I want to fit to data. I imported it from an .xml file so it is in the form of ModelingToolkit.ODESystem.

It has 228 states and 470 parameters.

It’s not particularly stiff (at least at the current parameter values), and solves easily:

@time sol = solve(prob, Tsit5());
0.199835 seconds (68.93 k allocations: 120.288 MiB)

I want to fit it to data by minimising the L2 norm between the output trajectory and the datapoints. I know it's a massively underdetermined system. This is my loss function:

function loss(in)
    u0_, p_ = outmerge(in)
    pprob = remake(prob, p=p_, u0 = u0_)
    sol = solve(pprob, Tsit5(), saveat = ts, sensealg=BacksolveAdjoint())
    return sum(abs2, Array(sol)[idxs,:] .- data[idxs,:])
end

I try to run
grad = Zygote.gradient(loss, x0)
and my computer just hangs while precompiling functions (I've waited up to 40 minutes). My computer isn't terrible: it's a 2017 MacBook Pro.

I’ve tried a few different sensealgs for the differentiation. Does anybody have tips on why such a hang could occur and how to avoid it, or a choice of sensealg() that might avoid this hang?


Share an MWE?

Sure. I don’t know what an MWE would be in this case, given that the problem is particular to the odesystem, and I can’t debug any errors as the command just hangs. But the code itself isn’t very long, and I can send over the .xml file if you like.
The main script is at the bottom.

using SbmlInterface, OrdinaryDiffEq, ModelingToolkit, Zygote, BSON

### auxiliary functions

# fetch data from bson. some rows are NaN.
function make_data()
    data = BSON.load("concentration.bson")
    return Float64.(data[:data])
end

# import xml file and return ODESystem and ODEProblem
function fetch_model()
    model = getmodel("ImmunoMetabolV5.xml")
    p = getparameters(model)
    u0 = getinitialconditions(model)
    rxs = getreactions(model)
    rs  = ReactionSystem(rxs, t, [item.first for item in u0], [item.first for item in p])
    odesys = convert(ODESystem, rs)
    prob = ODEProblem(odesys,u0,(0.0,30.0),p; jac=true)
    return odesys, prob, p, u0
end

# find numeric rows of data (ie filter out NaNs)
function get_numeric_idxs(arr)
    # use the argument `arr` rather than the global `data`
    setdiff(collect(1:size(arr,1)), findall(isnan.(arr[:,1])))
end

# for setting u0 as the first column of data, for the non-NaN data rows.
function remake_u0(u0, new_vals, idxs)
    u0_vals = last.(u0)
    u0_vals[idxs] .= new_vals[idxs]
    u0 = first.(u0) .=> u0_vals
    return u0
end

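# note: these two lines need `data` and `u0`, so they must run after fetch_model() and make_data()
# in the main script below, and before remake_u0 is called
idxs = get_numeric_idxs(data)
nidxs = setdiff(1:length(u0), idxs)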

# turn concatenation of parameters and unspecified initial conditions into their component elements
function outmerge(x)
    u = last.(u0)
    u[nidxs] = x[1:length(nidxs)]
    p = x[length(nidxs)+1:end]
    return u, p
end

function loss(in)
    u0_, p_ = outmerge(in)
    pprob = remake(prob, p=p_, u0 = u0_)
    sol = solve(pprob, Tsit5(), saveat = [0., 24.], sensealg=BacksolveAdjoint())
    return sum(abs2, Array(sol)[idxs,:] .- data[idxs,:])
end

### main script
odesys, prob, p, u0 = fetch_model()
data = make_data()

u0 = remake_u0(u0, data[:,1], idxs)
prob = remake(prob, u0=last.(u0))
ts = [0., 24.]
p0 = last.(p)
inmerge(u,p) = vcat(u[nidxs], p)
x0 = inmerge(last.(u0), p0)

### this hangs for 40mins+ until I force quit Julia
grad = Zygote.gradient(loss, x0)

Without the files I can’t do much. BacksolveAdjoint is probably a good idea. Also, are you sure the adjoint pass isn’t stiff? I’d try AutoTsit5(TRBDF2()) with InterpolatingAdjoint and see what you get. Then, instrument the RHS of the ODE with some printing of t to see how it’s progressing.
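
For anyone landing here later, the suggestion above could look roughly like the sketch below. It is not the original model: a small Lotka-Volterra system stands in for the 228-state network, and the names rhs!, ncalls, and the choice to print every 100 calls are made up for illustration. It also assumes a current package setup where the sensitivity algorithms live in SciMLSensitivity (at the time of this thread the package was called DiffEqSensitivity), and it picks ReverseDiffVJP explicitly so that the printing inside the RHS runs as ordinary Julia code; that choice is an assumption, not something recommended in the thread.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Counter for RHS evaluations (hypothetical helper, used only for progress output).
const ncalls = Ref(0)

# Toy stand-in for the real model: Lotka-Volterra, written in place.
function rhs!(du, u, p, t)
    ncalls[] += 1
    # Print the current time every 100 RHS calls to see whether the solver is advancing.
    ncalls[] % 100 == 0 && println("RHS call ", ncalls[], " at t = ", t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end

u0 = [1.0, 1.0]
p0 = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(rhs!, u0, (0.0, 10.0), p0)
ts = 0.0:1.0:10.0

function loss(p)
    # AutoTsit5(TRBDF2()) starts with Tsit5 and switches to TRBDF2 if stiffness is detected.
    sol = solve(remake(prob, p = p), AutoTsit5(TRBDF2()); saveat = ts,
                sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP()))
    return sum(abs2, Array(sol))
end

grad = Zygote.gradient(loss, p0)[1]
```

If the printed times stop advancing or the step count explodes during the gradient call, that points at the solve itself (for example a stiff adjoint pass) rather than at compilation. With ReverseDiffVJP the printed t may show up as a tracked number rather than a plain float.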

Thanks for the tips. I stopped trying… I was doing it for a friend, but on inspection the higher-level problem to be solved wasn't really well defined, so we stopped.