[ANN] MinimallyDisruptiveCurves.jl : regional (in)sensitivity analysis by finding hidden relationships between model parameters

Hi all!

What does it do?

  1. You’ve already fitted some model, by minimising a cost function on the model behaviour. Any type of model, as long as the cost function is differentiable. (I’ve focused on ODE models of biological processes in the user guide).

  2. This tool solves a differential equation to generate ‘trajectories’ in parameter space over which the cost function varies minimally.

  3. You look at these trajectories, and learn something about the model. “Oh, model behaviour is (approximately/exactly) invariant when parameter $p_1$ increases, as long as the product $p_1 p_{23}$ is preserved!”

  • In statistical terminology, this tool finds (structural /practical) unidentifiabilities in parameter space.

  • In mathematical terminology, it finds (approximate/exact) invariants of model behaviour in parameter space.

… at least, that’s the hope! Would love feedback on (success stories/failure stories/bugs/suggested improvements to the code/user guide/examples).
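To make the “invariant” idea concrete, here is a toy sketch that doesn't use the package at all. Suppose a (hypothetical) cost depends on two parameters only through their product; then moving along the curve $(p_1, p_2) \mapsto (e^s p_1, e^{-s} p_2)$ leaves the cost exactly unchanged, however far you travel:

```julia
# Hypothetical toy cost that depends on p[1], p[2] only through their product
toy_cost(p) = (p[1]*p[2] - 1.0)^2

# A curve through parameter space that preserves the product p1*p2:
# at curve parameter s, p1 is scaled by exp(s) and p2 by exp(-s)
curve(p, s) = [p[1]*exp(s), p[2]*exp(-s)]

p0 = [2.0, 0.5]          # nominal parameters; toy_cost(p0) == 0
costs = [toy_cost(curve(p0, s)) for s in -3:0.5:3]

# the cost is (numerically) invariant along the whole curve
all(c -> isapprox(c, toy_cost(p0); atol=1e-12), costs)   # → true
```

A minimally disruptive curve is this kind of object, found automatically for a cost where the invariant relationship is not known in advance.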

Concretely, in the examples (user guide), it finds ‘unnecessary’ model interactions; hidden, fast-timescale subsystems; and parameters that can go to zero/infinity without really changing model behaviour. So it's hopefully useful if you're interested in finding that type of thing.

Thanks, and hope it’s useful!

As an example, take the classic Lotka–Volterra predator–prey model. I asked the package for curves in parameter space that preserve two model features:

  1. the mean prey population over time

  2. max predator population over time

  • global sensitivity analysis tools in the above link (roughly) assign an ‘importance’ to each parameter, reflecting how much it varies the model features under consideration. They do this by sampling many different parameter combinations.

  • Our method explicitly finds relationships between parameters over which the model features vary minimally, without blind sampling. The MD curve we generated took < 5 seconds on my laptop.

  • The MD curve shows that we can concurrently increase the predator growth and death rates by over 500%, while preserving model features almost perfectly. We just need to apply particular compensations to the prey growth/predation rates.

The code:


```julia
using DiffEqSensitivity, OrdinaryDiffEq, ForwardDiff, MinimallyDisruptiveCurves, Statistics, Plots, LinearAlgebra

## Lotka–Volterra dynamics
function f(du, u, p, t)
  du[1] = p[1]*u[1] - p[2]*u[1]*u[2]  # prey
  du[2] = -p[3]*u[2] + p[4]*u[1]*u[2] # predator
end

u0 = [1.0; 1.0]
tspan = (0.0, 10.0)
t = collect(range(0, stop=10., length=200))
p = [1.5, 1.0, 3.0, 1.0]
nom_prob = ODEProblem(f, u0, tspan, p)
nom_sol = solve(nom_prob, Tsit5())

## These are the features of the model output whose sensitivity we want to quantify
function features(p)
  prob = remake(nom_prob; p=p)
  sol = solve(prob, Tsit5(); saveat=t)
  return [mean(sol[1,:]), maximum(sol[2,:])]
end
nom_features = features(p)

## Loss function: l2 difference of features vs nominal features
function loss(p)
  p_features = features(p)
  return sum(abs2, p_features - nom_features)
end

## Gradient of the loss function: writes the gradient into g, returns the loss
function lossgrad(p, g)
  g[:] = ForwardDiff.gradient(p) do p
    loss(p)
  end
  return loss(p)
end

## Package the loss and gradient into a DiffCost structure
cost = DiffCost(loss, lossgrad)

"""
We evaluate the Hessian once only, at p.
Why? To find locally insensitive directions of parameter perturbation.
The small eigenvalues of the Hessian are one way of defining these directions.
"""
hess0 = ForwardDiff.hessian(loss, p)
ev(i) = eigen(hess0).vectors[:,i]

## Set up a minimally disruptive curve, with nominal parameters p and initial direction ev(1)
init_dir = ev(1); momentum = 1.; span = (-15., 15.)
curve_prob = curveProblem(cost, p, init_dir, momentum, span)
@time mdc = evolve(curve_prob, Tsit5)

## Simulate the model at an arbitrary parameter set
function sol_at_p(p)
  prob = remake(nom_prob; p=p)
  return solve(prob, Tsit5())
end
cost_vec = [mdc.cost(el) for el in eachcol(trajectory(mdc))]

p1 = plot(mdc; pnames = ["prey_growth_rate", "prey_predation_rate", "predator_death_rate", "predator_growth_rate"])

p2 = plot(distances(mdc), log.(cost_vec), ylabel = "log(cost)", xlabel = "distance", title = "cost over MD curve");

mdc_plot = plot(p1, p2, layout=(2,1), size=(800,800))
nominal_trajectory = plot(nom_sol)
example_trajectory = plot(sol_at_p(mdc(-15.)[:states]))
traj_comparison = plot(nominal_trajectory, example_trajectory, layout=(2,1))
```
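The Hessian initialisation step can be illustrated on a hand-worked toy cost, independent of the package. For $\mathrm{loss}(p) = (p_1 p_2 - 1)^2$, the analytic Hessian at the minimum $p = [1, 1]$ is $[[2, 2], [2, 2]]$; its zero eigenvalue's eigenvector points along $[1, -1]$, i.e. increase $p_1$ while decreasing $p_2$, which preserves the product to first order:

```julia
using LinearAlgebra

# Toy cost: loss(p) = (p[1]*p[2] - 1)^2, minimised whenever p1*p2 == 1.
# Its analytic Hessian entries at p = [1, 1]:
#   d²/dp1²    = 2*p2²           = 2
#   d²/dp2²    = 2*p1²           = 2
#   d²/dp1dp2  = 2*(2*p1*p2 - 1) = 2
H = [2.0 2.0; 2.0 2.0]

vals, vecs = eigen(H)   # real symmetric => eigenvalues sorted ascending

vals[1]      # smallest eigenvalue is 0 (up to floating point):
             # a locally cost-free direction exists
vecs[:, 1]   # its eigenvector is ∝ [1, -1]/√2: raise p1, lower p2
```

This is exactly the role `ev(1)` plays in the listing above: the eigenvector of the smallest Hessian eigenvalue is the locally cheapest direction in which to start the curve.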

mdc_plot:

Top pane: at distance x along the curve, the four lines depict how much each parameter has changed relative to its nominal value.

Bottom pane: the cost (squared change in model features) of the perturbed parameters, at distance x along the curve.

traj_comparison:

The original simulation trajectory (p = [1.5, 1.0, 3.0, 1.0]), compared with a new parameter set on the MD curve, at distance = -15 (p = [0.71, 2.26, 17.01, 5.83]).

Model features are almost exactly the same in each case ([2.98, 4.57] vs [2.94, 4.57]).


GIFs of the first two minimally disruptive curves for the Lotka–Volterra model … they illustrate the point better:

I asked MinimallyDisruptiveCurves.jl to find trajectories in parameter space preserving mean(prey) and max(predator). It solves a differential equation on the parameters to generate these trajectories:
MDC 1 (animated gif)

MDC 2 (animated gif)
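To convey the flavour of “solving a differential equation on the parameters” — this is only a conceptual sketch, not the package's actual curve ODE, which also involves the momentum machinery — take the toy cost $(p_1 p_2 - 1)^2$ again. The direction $[p_1, -p_2]$ is everywhere tangent to the level set $p_1 p_2 = \mathrm{const}$, so Euler-stepping the parameters along it traces a near-cost-free trajectory:

```julia
# Conceptual sketch only: not the actual MD-curve ODE used by the package.
# Toy cost whose level sets are the hyperbolae p1*p2 = const:
toy_cost(p) = (p[1]*p[2] - 1.0)^2

# Tangent field to the level set: orthogonal to the gradient of p1*p2, [p2, p1]
direction(p) = [p[1], -p[2]]

# Forward-Euler integration of dp/ds = direction(p)
p = [1.0, 1.0]
h = 1e-3
for _ in 1:1000
    p .+= h .* direction(p)
end

p            # p1 has grown substantially, p2 has shrunk...
toy_cost(p)  # ...while the cost has stayed near zero
```

The package does the analogous thing for a black-box differentiable cost, where the cost-preserving direction has to be discovered at each step rather than written down analytically.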


I’m enjoying reading the documentation. I wish more packages had thoughtful motivation/explanation/features/caveats sections like this!

Awesome docs yeah.

Thanks both! If there's anything about them you find hard to follow, I'd love critical feedback to improve them.