ODE parameter estimation, parameter not updating

Hi, I have this code to estimate the parameters of a ODE system. I do not know why it is not updating the parameters inside the optimization. I print them and they do not change. This is a test sample of the code:

import Pkg

using DifferentialEquations,  DataFrames,  CSV, Plots, LinearAlgebra, ODE, DataInterpolations,
 DiffEqParamEstim, Optimization,  Statistics, Dates,ForwardDiff, OptimizationOptimJL, OptimizationBBO, OrdinaryDiffEq,
 OptimizationPolyalgorithms, SciMLSensitivity, Zygote

# Constants
const N = 5
const m_c = 0.001 # probability of mosquito in a car

# Create fake data
const end_ind = 5

# Create fake mobility matrix
const eta =  [0.0       7744.15   5240.14     8.57707     5.21319;
        6.9175    0.0       7.23771    45.2212     16.8025;
        61731.9   7.426     0.0        17.7236      54998.8;
        8.40241  44.21     20.1597      0.0      5309.46;
        5830.9   17.5283    5.36841  5226.36        0.0]

# Choose number of patches and IC
pop_init = zeros(end_ind)
pop_init[1] = 1
pop_init[3] = 1

# Non autnomous model ----------------------------------------------------------------------
function fun_na!(du, u, p, t)
  for i in 1:N
      du[i]= sum(j -> p[1] * m_c * eta[i,j] * u[j], 1:N) * (1 - u[i]) -
      p[2]*u[i]
  end
  # println("du:$du[1:5]")
end

# Set initial parameters ------------------------------------------------------------------
t_obs = [ 592, 957,1323,1688,2053,2418,2784]
t0= t_obs[1] - 10
tf= t_obs[end] + 10
tspan = (t0, tf)
t_vect=1:tf
u0 = pop_init
p = [0.001,0.00000001]

# Create the model
prob = DifferentialEquations.ODEProblem(fun_na!, u0, tspan, p)

# Create the cost function -----------------------------------------------------------------
vec_year = [2007,2008,2009,2010,2011,2012,2013] # Vector of summer times per year

# Function to extract times related to summer per year
function summer_times_by_year(R_M)
    summer_times = Dict{Int, Vector{Int}}()
    for y in vec_year
        summer_times[y] = filter(row -> year(row.date) == y && month(row.date) in 6:8, R_M).time
    end
    return summer_times
end

# Run through RM
summer_t_obs_by_year = Dict(
  2013 => collect(3074:3165),
  2012 => collect(2709:2800),
  2011 => collect(2343:2434),
  2010 => collect(1978:2069),
  2009 => collect(1613:1704),
  2008 => collect(1248:1339),
  2007 => collect(882:973)
)


# Function to compute average probability of occupancy in summer per year and comarca
function average_summer_solution_by_year(sol)
    sol_mat = zeros(Float64,N,length(vec_year))
    j = 1
    for years in vec_year
        times = summer_t_obs_by_year[years]
        vec_sol = Vector{Float64}(undef,N)
        for i in 1:N
            vec_sol[i] = mean(sol(times, idxs = i))
        end  
        sol_mat[:,j] .= vec_sol
        j = j+1 
    end
    return sol_mat
end

# Update p in the ODE
function hanski_prediction(p)
	_prob = remake(prob, p = p)
	sol = solve(_prob, alg_hints=[:stiff]; abstol=1e-8, reltol=1e-8)
  return(sol)
end

# Create matrix observations -------------------------------------------------------------
obs = hanski_prediction(Float64[0.004,0.0000001])
matrix_obs = hcat(obs(t_obs)...)
plot(obs)

# Loss function
function summer_loss_by_year(p)
  println("p:", ForwardDiff.value.(p))
  sol = hanski_prediction(ForwardDiff.value.(p))
  loss = 0.0
  if any((!SciMLBase.successful_retcode(s.retcode) for s in sol)) # Test for not succesfull integration
      print("Loss function INF\n")
      flush(stdout)
      loss = Inf
  else
    summer_avg_by_year = average_summer_solution_by_year(sol)
    loss = sum((summer_avg_by_year .- matrix_obs).^2) 
    println("loss:", ForwardDiff.value.(loss))
    flush(stdout)
  end
  # println("loss:", ForwardDiff.value.(loss))
  # flush(stdout)
  loss
end

# Parameters bounds
low_bound = [0.0,0.0]
up_bound = [2,2]

# Optimizatin -------------------------------------------------------------------------------
adtype = Optimization.AutoForwardDiff() #AutoFiniteDiff()
optf = Optimization.OptimizationFunction((p,x) -> summer_loss_by_year(p), adtype) 

# Loop through different initial parameters parallely --------------------------------------------
init_param = Float64[0.004,0.0000001]
init_param[1] = init_param[1] + (rand(1)/100)[1]
init_param[2] = init_param[2] + (rand(1)/100)[1]

println("init_param:", init_param)

# Create the optimization problem with the new parameters
optprob = Optimization.OptimizationProblem(optf, init_param, lb = low_bound, ub = up_bound) # Create the optimization problem

# Do a try catch to not break when integration problem
try
  optsol = Optimization.solve(optprob,OptimizationOptimJL.LBFGS(), maxiters = 1000)
  print("Finish optimization\n")
catch e
  println("ERROR: $e")
end

This is the output:

p:[0.005178516823542502, 0.0015114340264134889]
loss:1.640687614977965
p:[0.005178516823542502, 0.0015114340264134889]
loss:1.640687614977965
p:[0.005178516823542502, 0.0015114340264134889]
loss:1.640687614977965
retcode: Success
u: 2-element Vector{Float64}:
0.005178516823542502
0.0015114340264134889

The parameter p is always the same. But I do not understand why this is happening.

Don’t manually remove gradient information and then wonder why you get zero gradients. Using the undocumented internals of the AD libraries should only be done with caution because they aren’t supposed to be used by users.

sol = hanski_prediction(ForwardDiff.value.(p))

That’s a very bad idea. If you get rid of derivatives… then the derivative is zero. So the optimization exits.

function summer_loss_by_year(p)
  println("p:", ForwardDiff.value.(p))
  sol = hanski_prediction(p)
  if any((!SciMLBase.successful_retcode(s.retcode) for s in sol)) # Test for not succesfull integration
      print("Loss function INF\n")
      flush(stdout)
      return Inf
  else
    summer_avg_by_year = average_summer_solution_by_year(sol)
    loss = sum((summer_avg_by_year .- matrix_obs).^2) 
    println("loss:", ForwardDiff.value.(loss))
    flush(stdout)
    return loss
  end
end

Then you shouldn’t just assume Float64.

function average_summer_solution_by_year(sol)    
    stack([[mean(sol(summer_t_obs_by_year[years], idxs = i)) for i in 1:N] for years in vec_year])
end

Those two changes then it’s fine.

1 Like

Now it works perfectly. I did not know I was removing gradient information. I did it after the problem with float64 since I thought this was the problem. Thanks a lot for the solution I really appreciate it :slight_smile: