- the mean prey population over time
- the max predator population over time
- The global sensitivity analysis tools in the above link (roughly speaking) assign an 'importance' to each parameter, quantifying how much varying it changes the considered model features. They do this by sampling lots of different parameter combinations.
- Our method instead explicitly finds relationships between parameters along which the model features change minimally, without blind sampling. Generating the MD curve below took < 5 seconds on my laptop.
- The MD curve shows that we can concurrently increase the predator growth and death rates by over 500% while preserving the model features almost perfectly; we just need to apply particular compensating changes to the prey growth and predation rates. (The model equations are sketched just below, for reference.)
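For reference, the model is the classic Lotka-Volterra predator-prey system, exactly as implemented in the function f in the code below:

$$\frac{du_1}{dt} = p_1 u_1 - p_2 u_1 u_2 \ \text{(prey)}, \qquad \frac{du_2}{dt} = -p_3 u_2 + p_4 u_1 u_2 \ \text{(predator)}$$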
The code:
using DiffEqSensitivity, OrdinaryDiffEq, ForwardDiff, MinimallyDisruptiveCurves, Statistics, Plots, LinearAlgebra
function f(du, u, p, t)
    du[1] =  p[1]*u[1] - p[2]*u[1]*u[2]  # prey
    du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]  # predator
end
u0 = [1.0;1.0]
tspan = (0.0,10.0)
t = collect(range(0, stop=10., length=200))
p = [1.5,1.0,3.0,1.0]
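# parameter order: [prey growth, prey predation, predator death, predator growth]
# (these are the pnames used for plotting further down)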
nom_prob = ODEProblem(f,u0,tspan,p)
nom_sol = solve(nom_prob, Tsit5())
## These are the features of the model output that we want to quantify the sensitivity of
function features(p)
    prob = remake(nom_prob; p = p)
    sol = solve(prob, Tsit5(); saveat = t)
    return [mean(sol[1, :]), maximum(sol[2, :])]
end
nom_features = features(p)
## loss function: the squared l2 difference between the features of p and the nominal features (so the loss is zero at the nominal parameters)
function loss(p)
    p_features = features(p)
    return sum(abs2, p_features - nom_features)
end
## gradient of loss function: writes the gradient into g in-place and returns the loss value
function lossgrad(p, g)
    g[:] = ForwardDiff.gradient(p) do p
        loss(p)
    end
    return loss(p)
end
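## (optional sanity check: at the nominal p the loss is zero, so the gradient should vanish there)
# g0 = similar(p); lossgrad(p, g0)  # returns ≈ 0, with g0 .≈ 0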
## package the loss and gradient into a DiffCost structure
cost = DiffCost(loss, lossgrad)
"""
We evaluate the hessian once only, at p.
Why? to find locally insensitive directions of parameter perturbation
The small eigenvalues of the Hessian are one way of defining these directions
"""
hess0 = ForwardDiff.hessian(loss,p)
ev(i) = eigen(hess0).vectors[:,i]  # i-th eigenvector; eigenvalues come sorted in ascending order
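## (optional check: eigen(hess0).values should contain at least one near-zero
## eigenvalue, i.e. a locally insensitive direction for the curve to follow)
# eigen(hess0).values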
## Now we set up a minimally disruptive curve, with nominal parameters p and initial direction ev(1)
init_dir = ev(1); momentum = 1.; span = (-15.,15.)
curve_prob = curveProblem(cost, p, init_dir, momentum, span)
@time mdc = evolve(curve_prob, Tsit5)
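# mdc parametrizes a curve through parameter space by (signed) arc length:
# trajectory(mdc) returns the parameters along it, distances(mdc) the arc lengths,
# and mdc(s)[:states] the parameter vector at distance s (all used below)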
function sol_at_p(p)
    prob = remake(nom_prob; p = p)
    return solve(prob, Tsit5())
end
cost_vec = [mdc.cost(el) for el in eachcol(trajectory(mdc))]
p1 = plot(mdc; pnames = ["prey_growth_rate", "prey_predation_rate", "predator_death_rate", "predator_growth_rate"])
p2 = plot(distances(mdc), log.(cost_vec), ylabel = "log(cost)", xlabel = "distance", title = "cost over MD curve");
mdc_plot = plot(p1,p2, layout=(2,1), size = (800,800))
nominal_trajectory = plot(nom_sol)
example_trajectory = plot(sol_at_p(mdc(-15.)[:states]))
traj_comparison = plot(nominal_trajectory, example_trajectory, layout = (2,1))
mdc_plot:
- Top pane: at distance x along the curve, the four lines depict how much each parameter has changed relative to its nominal value.
- Bottom pane: at distance x along the curve, the cost (squared change in model features) incurred by the perturbed parameters.
traj_comparison:
- The original simulation trajectory (p = [1.5, 1.0, 3.0, 1.0]), compared with the trajectory under a new parameter set on the MD curve, at distance = -15 (p = [0.71, 2.26, 17.01, 5.83]).
- The model features are almost exactly the same in each case: [2.98, 4.57] vs [2.94, 4.57].
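You can reproduce this comparison with the features function defined above:

features(mdc(-15.)[:states])  # ≈ [2.94, 4.57]; compare with nom_features ≈ [2.98, 4.57]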