ADAM optimizer returns parameter values resulting in a poor fit for my SDE model

Hi! I am a plant breeding master's student who recently got into SDEs. I am investigating the seasonal change in biomass of some wheat genotypes from the day they are sown until the day of harvest. I want to fit an SDE model and estimate certain parameters of interest based on a logistic growth model, in which the growth rate r is modified by a daily incident irradiance value RADN (an interpolation of irradiance data evaluated at time t), Mmax is the biomass limit (carrying capacity) of the system, and c is a conversion factor. The parameters I want to estimate for each genotype are thus r, Mmax, c and σ.
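
In equation form, writing M for biomass and W for a standard Wiener process, the model implemented below is

$$ dM = \big(r + c\,\mathrm{RADN}(t)\big)\, M \left(1 - \frac{M}{M_{\max}}\right) dt \;+\; 0.1\,\sigma\,\frac{M_{\max} - M}{M_{\max}}\left(1 - \frac{M_{\max} - M}{M_{\max}}\right) dW $$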

The model goes as follows:

using DifferentialEquations
#Drift term
function expgrowth!(du, u, p, t) 
    r,Mmax,c,σ = p;
    du[1] = (r + RADN(t) * c ) * u[1] * (1 - (u[1] / Mmax))   
end

for the drift term and:

#Diffusion term
function σ_expgrowth!(du, u, p, t)
    r,Mmax,c,σ = p;
    # noise amplitude is zero at u = 0 and at u = Mmax, and largest at u = Mmax/2
    du[1] = 0.1 * (( σ * ((Mmax - u[1])/Mmax)) * (1 - ((Mmax - u[1])/Mmax)))
end

for the diffusion term. An example of biomass data and time in days after sowing (das) could look like:

Biomass = [0.0046, 0.0057, 0.007, 0.00969, 0.01219, 0.0157, 0.0197, 0.0236, 0.0293, 0.0337, 0.0399, 0.0455, 0.0511, 0.0598, 0.0634, 0.0746, 0.0875, 0.1004, 0.1148, 0.1211, 0.1419, 0.1629, 0.18309, 0.196199, 0.222, 0.2507, 0.2899, 0.3371, 0.3876, 0.4442, 0.5074, 0.573299, 0.64579, 0.7352, 0.8101, 0.8966, 1.0083, 1.12829, 1.2568, 1.3932, 1.53720, 1.634, 1.7255, 1.81809, 1.8932, 2.0165, 2.101, 2.2249, 2.3578, 2.4878, 2.5888, 2.688, 2.8294, 2.9598, 3.0624, 3.15769, 3.2377, 3.3089, 3.372, 3.4297, 3.4974, 3.5553, 3.6081, 3.655, 3.6976, 3.746, 3.8569, 3.9536, 4.0417, 4.114, 4.1832, 4.234, 4.2788, 4.3196, 4.35619, 4.3918, 4.4245, 4.45339, 4.4773, 4.5013, 4.597, 4.6778, 4.7396, 4.7935, 4.8405, 4.8798, 4.9131, 4.9421, 4.9865, 5.0676, 5.117, 5.15339, 5.1806, 5.2021, 5.2166, 5.2306, 5.2424, 5.2542, 5.2651, 5.2733, 5.276, 5.2803, 5.2837, 5.2862, 5.2882, 5.2897, 5.2911, 5.2923, 5.29339, 5.2945, 5.2954, 5.2963, 5.2973, 5.2974, 5.2974, 5.2974]
das = [7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, 112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0]

The irradiance interpolation is obtained using the package DataInterpolations.jl:

using DataInterpolations

irradiance = [18.0, 8.0, 9.0, 17.0, 13.0, 16.0, 16.0, 13.0, 17.0, 11.0, 14.0, 11.0, 10.0, 14.0, 5.0, 15.0, 15.0, 13.0, 13.0, 5.0, 16.0, 14.0, 12.0, 7.0, 13.0, 13.0, 16.0, 17.0, 16.0, 16.0, 16.0, 15.0, 15.0, 17.0, 13.0, 14.0, 17.0, 17.0, 17.0, 17.0, 17.0, 11.0, 10.0, 10.0, 8.0, 13.0, 9.0, 13.0, 17.0, 15.0, 17.0, 16.0, 17.0, 17.0, 18.0, 18.0, 16.0, 18.0, 17.0, 17.0, 19.0, 19.0, 19.0, 18.0, 19.0, 7.0, 16.0, 14.0, 19.0, 19.0, 20.0, 20.0, 20.0, 19.0, 18.0, 16.0, 19.0, 6.0, 5.0, 5.0, 20.0, 21.0, 20.0, 21.0, 20.0, 20.0, 20.0, 18.0, 12.0, 22.0, 23.0, 22.0, 20.0, 20.0, 22.0, 23.0, 20.0, 22.0, 20.0, 21.0, 10.0, 23.0, 24.0, 24.0, 24.0, 24.0, 24.0, 24.0, 24.0, 24.0, 24.0, 24.0, 25.0, 25.0, 25.0, 26.0]

#Bspline of irradiance data on das
RADN = BSplineApprox( irradiance, das, 3, 5, :Uniform, :Uniform, extrapolate = true)
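
The interpolation object is callable, so it can be checked at any time point before it is used inside the drift term, e.g.:

#quick check of the irradiance spline
RADN(50.0)                                    #approximated irradiance at 50 das
[RADN(t) for t in das[begin]:1.0:das[end]]    #whole season, e.g. for plotting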

With this, I create an SDE problem and try to find the parameters r, Mmax, c and σ:

#tspan for sde problem
tspan = (das[begin], das[end])

#initial guess for parameters r, Mmax, c and σ
p_init = [0.15, 1.01*Biomass[end], -0.002, 2.0]

p = [0.15, 1.01*Biomass[end], -0.002, 2.0]
r, Mmax, c, σ = p

#initial condition
u0 = [Biomass[begin]]

#create the SDE problem
sde_prob = SDEProblem(expgrowth!, σ_expgrowth!, u0, tspan, p)
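
Before fitting, a small ensemble at the initial guess can be simulated and plotted against the data to see what the model produces (a quick sketch; Plots.jl is only needed for the plotting part):

#simulate a few trajectories at the initial guess and compare to the data
test_sol = solve(EnsembleProblem(sde_prob), SOSRI(), saveat = das, trajectories = 10)

using Plots
plot(test_sol)                          #simulated biomass trajectories
scatter!(das, Biomass, label = "data")  #observed biomass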

Then I define the loss function:

#loss for ADAM
function loss(theta)
    tmp_prob = remake(sde_prob; p = theta)
    ensembleprob_l = EnsembleProblem(tmp_prob)
    tmp_sol = solve(ensembleprob_l, SOSRI(), saveat = das, trajectories = 1_000)
    arrsol = Array(tmp_sol)
    #squared error between the data and the ensemble mean; the ensemble array is returned as extra output
    sum(abs2, collect(transpose(Biomass)) - mean(arrsol, dims=3)),
    arrsol
end

And finally the optimization problem:

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((p,x) -> loss(p), adtype)
optprob = Optimization.OptimizationProblem(optf, p_init)

I will use the ADAM optimizer from OptimizationOptimisers.jl, since it is used in the majority of SDE-fitting examples in Julia:

res = Optimization.solve(optprob, OptimizationOptimisers.ADAM(0.05), maxiters = 50)

#check results
res.objective
res.stats

Why does ADAM find parameter values that are not a good fit at all? I would also like to know what I can improve in my code. Since I am new to SDEs, I do not know what I am doing wrong. Is it a problem with the model?

Thanks in advance and regards!

Hi, and welcome to the Julia community!

I'm afraid I'm not able to run your code. Could you provide it in full, including in particular the imports?
(I tried
using DifferentialEquations, DataInterpolations, Optimization, OptimizationOptimisers, Zygote, SciMLSensitivity, Statistics
but got MethodErrors.)

I have no experience with SDEs, but I see you are using a learning rate much higher than I'm used to, which is more like 1e-3 or 1e-4. What happens if you use a lower learning rate? Did you try a different optimiser, like SGD?
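
For example, something along these lines (untested, and assuming OptimizationOptimisers re-exports the Optimisers.jl rules, so that Descent is plain gradient descent):

#smaller learning rate for ADAM, with more iterations to compensate
res = Optimization.solve(optprob, OptimizationOptimisers.ADAM(1e-3), maxiters = 200)

#plain gradient descent (SGD without momentum) for comparison
res_sgd = Optimization.solve(optprob, OptimizationOptimisers.Descent(1e-3), maxiters = 200)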

Yes, I will share the block of code. My apologies if I missed a package:

using Pkg

Pkg.add(["DifferentialEquations", "Plots", "DiffEqParamEstim", "RecursiveArrayTools", "OptimizationOptimJL",
"SciMLSensitivity", "OptimizationOptimisers", "StochasticDiffEq", "DataFrames", "CSV", "Zygote", "Optimization",
"Statistics","Glob","DataInterpolations","ComponentArrays"])

using DifferentialEquations, Plots, DiffEqParamEstim, RecursiveArrayTools, OptimizationOptimJL, SciMLSensitivity,
 OptimizationOptimisers, StochasticDiffEq, DataFrames, CSV, Zygote, Optimization, Statistics, Glob,
  DataInterpolations, ComponentArrays

And the rest of the code, just in case:

#Drift term
function expgrowth!(du, u, p, t) 
    r,Mmax,c,σ = p;
    du[1] = (r + RADN(t) * c ) * u[1] * (1 - (u[1] / Mmax))
end

#Diffusion term
function σ_expgrowth!(du, u, p, t)
    r,Mmax,c,σ = p;
    du[1] = 0.1 * (( σ * ((Mmax - u[1])/Mmax)) * (1 - ((Mmax - u[1])/Mmax)))
end

#Biomass data
Biomass = [0.0046, 0.0057, 0.007, 0.00969, 0.01219, 0.0157, 0.0197, 0.0236, 0.0293, 0.0337, 0.0399, 0.0455, 0.0511, 0.0598, 0.0634, 0.0746, 0.0875, 0.1004, 0.1148, 0.1211, 0.1419, 0.1629, 0.18309, 0.196199, 0.222, 0.2507, 0.2899, 0.3371, 0.3876, 0.4442, 0.5074, 0.573299, 0.64579, 0.7352, 0.8101, 0.8966, 1.0083, 1.12829, 1.2568, 1.3932, 1.53720, 1.634, 1.7255, 1.81809, 1.8932, 2.0165, 2.101, 2.2249, 2.3578, 2.4878, 2.5888, 2.688, 2.8294, 2.9598, 3.0624, 3.15769, 3.2377, 3.3089, 3.372, 3.4297, 3.4974, 3.5553, 3.6081, 3.655, 3.6976, 3.746, 3.8569, 3.9536, 4.0417, 4.114, 4.1832, 4.234, 4.2788, 4.3196, 4.35619, 4.3918, 4.4245, 4.45339, 4.4773, 4.5013, 4.597, 4.6778, 4.7396, 4.7935, 4.8405, 4.8798, 4.9131, 4.9421, 4.9865, 5.0676, 5.117, 5.15339, 5.1806, 5.2021, 5.2166, 5.2306, 5.2424, 5.2542, 5.2651, 5.2733, 5.276, 5.2803, 5.2837, 5.2862, 5.2882, 5.2897, 5.2911, 5.2923, 5.29339, 5.2945, 5.2954, 5.2963, 5.2973, 5.2974, 5.2974, 5.2974]

#Days after sowing (das)
das = [7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, 112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0]

#irradiance data
irradiance = [18.0, 8.0, 9.0, 17.0, 13.0, 16.0, 16.0, 13.0, 17.0, 11.0, 14.0, 11.0, 10.0, 14.0, 5.0, 15.0, 15.0, 13.0, 13.0, 5.0, 16.0, 14.0, 12.0, 7.0, 13.0, 13.0, 16.0, 17.0, 16.0, 16.0, 16.0, 15.0, 15.0, 17.0, 13.0, 14.0, 17.0, 17.0, 17.0, 17.0, 17.0, 11.0, 10.0, 10.0, 8.0, 13.0, 9.0, 13.0, 17.0, 15.0, 17.0, 16.0, 17.0, 17.0, 18.0, 18.0, 16.0, 18.0, 17.0, 17.0, 19.0, 19.0, 19.0, 18.0, 19.0, 7.0, 16.0, 14.0, 19.0, 19.0, 20.0, 20.0, 20.0, 19.0, 18.0, 16.0, 19.0, 6.0, 5.0, 5.0, 20.0, 21.0, 20.0, 21.0, 20.0, 20.0, 20.0, 18.0, 12.0, 22.0, 23.0, 22.0, 20.0, 20.0, 22.0, 23.0, 20.0, 22.0, 20.0, 21.0, 10.0, 23.0, 24.0, 24.0, 24.0, 24.0, 24.0, 24.0, 24.0, 24.0, 24.0, 24.0, 25.0, 25.0, 25.0, 26.0]

#Bspline of irradiance data on das
RADN = BSplineApprox( irradiance, das, 3, 5, :Uniform, :Uniform, extrapolate = true)

#tspan for sde problem
tspan = (das[begin], das[end])

#initial guess for parameters r, Mmax, c and σ
p_init = [0.15, 1.01*Biomass[end], -0.002, 2.0]

p = [0.15, 1.01*Biomass[end], -0.002, 2.0]
r, Mmax, c, σ = p

#initial condition for biomass
u0 = [Biomass[begin]]

#create the SDE problem
sde_prob = SDEProblem(expgrowth!, σ_expgrowth!, u0, tspan, p)

#loss for ADAM
function loss(theta)
    tmp_prob = remake(sde_prob; p = theta)
    ensembleprob_l = EnsembleProblem(tmp_prob)
    tmp_sol = solve(ensembleprob_l, SOSRI(), saveat = das, trajectories = 1_000)
    arrsol = Array(tmp_sol)
    sum(abs2, collect(transpose(Biomass)[:,:]) - mean(arrsol, dims=3)),
    arrsol
end

#Optimization problem
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p_init)

res = Optimization.solve(optprob, OptimizationOptimisers.ADAM(0.05), maxiters = 50)


Thanks for the suggestion to check the learning rate. I observed that for a simpler version of the model (just logistic growth, without the c * RADN(t) term), reducing the learning rate to 0.01 improves the fit. I will now try reducing the learning rate for this model as well.
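
For reference, the simplified drift is roughly the following (the parameter layout is just illustrative; the diffusion term stays the same):

#simplified drift: plain logistic growth, no irradiance term
function logisticgrowth!(du, u, p, t)
    r, Mmax, σ = p
    du[1] = r * u[1] * (1 - u[1] / Mmax)
end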


Ok, so after decreasing the learning rate, I am getting much smaller objective values. You were right. Thank you very much!


FYI: I don't believe Adam is still the best optimizer; I believe that by now Adan (not to be confused with Adam, or with AdamW) is the best one.

But it's not yet available in Julia (I might have an implementation soon).

Since the "n" in Adan stands for Nesterov, maybe Optimisers.NAdam, the Nesterov variant of the Adam optimizer, is currently the best one available in Julia/Flux.jl, which is where you get your optimizers (indirectly) from. Lion is also among the more recent ones available there; RAdam is recent as well, but less so.

I believe all the Adam variants are drop-in replacements for the original Adam, including AdamW, which is probably the most-used alternative, at least for LLMs. I'm most familiar with LLMs and am not 100% sure the rules of thumb from there translate to SDEs, though I believe all optimizers are drop-in replacements for each other, even the non-Adam ones, regardless of application area. I'd like to know if I'm wrong; at the very least you can try others and see whether they converge faster and/or find a better optimum, but don't take my word that they are guaranteed replacements.
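
As a sketch of what I mean by drop-in (I haven't tried this on your problem, and I'm assuming the NAdam/AdamW rules from Optimisers.jl are reachable through OptimizationOptimisers):

#only the optimizer rule changes in the solve call
res_nadam = Optimization.solve(optprob, OptimizationOptimisers.NAdam(1e-3), maxiters = 200)
res_adamw = Optimization.solve(optprob, OptimizationOptimisers.AdamW(1e-3), maxiters = 200)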

Note, I'm no expert on this.

Abstract: This work aims to improve upon the recently proposed and rapidly popularized optimization algorithm Adam (Kingma & Ba, 2014). Adam has two main components: a momentum component and an adaptive learning rate component. However, regular momentum can be shown conceptually and empirically to be inferior to a similar algorithm known as Nesterov's accelerated gradient (NAG). We show how to modify Adam's momentum component to take advantage of insights from NAG, and then we present preliminary evidence suggesting that making this substitution improves the speed of convergence and the quality of the learned models.

There are many more to consider, e.g. OAdam (Optimistic Adam), an optimizer that I wasn't aware of.
