Not all parameters are being estimated in DiffEqFlux

I’m trying to fit a simple Von Bertalanffy curve to some data using DiffEqFlux, but one of the parameters (`init_p[3]`, the delay) does not change from one iteration to the next. I suspect it has something to do with the step function inside `prod_f()`, which I use to force a time delay before the curve begins rising. There is probably something I am not understanding about how the parameters are accessed during training. Any help is appreciated!

Here is the code:

```julia
using DifferentialEquations, Flux, DiffEqFlux, Plots

function prod_f(du,u,p,t)
    L = u[1]
    # Von Bertalanffy curve
    k,Lmax,delay = p

    if t >= delay
        d = 1
    else
        d = 0
    end

    du[1] = k*(Lmax - L)*d
end


u0 = [0.0]
tspan = (1.0,500.0)
true_p = [0.01,25.0,100.0]
prob1 = ODEProblem(prod_f,u0,tspan,true_p,saveat = 1.0)
sol1 = solve(prob1,Tsit5())
plot(sol1,label = "")

x = [14,14,15,17,17,18,18,19,19,20,20,21,22,23,24,25,26,27,28,29,30,
     32,34,35,36,38,39,40,41,42,45,46,47,48,49,50,51,52,53,55,56,58,59,62].*7

data = VectorOfArray([(sol1(x[i]) + .01randn(1)) for i in 1:length(x)])
data = convert(Array,data)
plot!(x,data',seriestype = :scatter)


init_p = [0.02,28.0,90.0]

function predict_adjoint(p)
    Array(concrete_solve(prob1,Tsit5(),u0,p,saveat=x))
end

function loss_adjoint(p)
    prediction = predict_adjoint(p)
    loss = sum(abs2,data .- prediction)
    loss,prediction
end

cb = function (p,l,pred) # callback function to observe training
    # println(p)
    # display(l)
    # using `remake` to re-create our `prob` with current parameters `p`
    # display(plot(solve(remake(prob1,p=p),Tsit5(),saveat = x)))
    return false # Tell it to not halt the optimization. If return true, then optimization stops
end

# Display the ODE with the initial parameter values.
cb(init_p,loss_adjoint(init_p)...)

res = DiffEqFlux.sciml_train(loss_adjoint, init_p, ADAM(0.05), cb = cb,
                             maxiters = 200)

plot!(solve(remake(prob1,p=res.minimizer),Tsit5(),saveat=0.0:0.1:500.0),
      label = "",linewidth = 3)
plot!(x,data',seriestype = :scatter)
```

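The problem is that you’re using `delay` in a way that does not propagate gradient information: `delay` only appears in the conditional expression `if t >= delay`. The right-hand side is piecewise constant in `delay`, so its derivative with respect to `delay` is zero wherever it exists, even though the solution itself does depend on `delay`. You would need a different formulation to get the gradient with respect to this parameter.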
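To make the point concrete, here is a minimal stdlib-only check. It assumes, as with `u0 = [0.0]` in the question, that L stays at 0 until the delay, so the model has the closed-form solution L(t) = Lmax(1 - exp(-k(t - delay))) for t ≥ delay. The helpers `rhs` and `L_exact` are hypothetical, and central finite differences stand in for AD. The derivative of the right-hand side with respect to `delay` comes out as exactly zero, while the derivative of the solution does not:

```julia
# Hypothetical helper: scalar form of prod_f's right-hand side (growth switches on at t = delay)
rhs(L, k, Lmax, delay, t) = t >= delay ? k * (Lmax - L) : 0.0

# Hypothetical helper: closed-form solution when L = 0 up to the delay
L_exact(k, Lmax, delay, t) = t >= delay ? Lmax * (1 - exp(-k * (t - delay))) : 0.0

k, Lmax, delay = 0.01, 25.0, 100.0   # true_p from the question
t, L, δ = 200.0, 10.0, 1e-4          # a time after the delay, an arbitrary state, FD step

# Central finite differences with respect to `delay`
drhs = (rhs(L, k, Lmax, delay + δ, t) - rhs(L, k, Lmax, delay - δ, t)) / (2 * δ)
dsol = (L_exact(k, Lmax, delay + δ, t) - L_exact(k, Lmax, delay - δ, t)) / (2 * δ)

println("d rhs / d delay      = ", drhs)   # 0.0: the branch gives no local gradient
println("d L(t) / d delay     = ", dsol)
println("analytic d L/d delay = ", -Lmax * k * exp(-k * (t - delay)))
```

Roughly speaking, a sensitivity method that only differentiates the right-hand side sees no dependence on `delay`; the dependence lives entirely in where the switch happens.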


Indeed, try solving it as two ODEs, one on `(1.0, delay)` and one on `(delay, 500)`. I think that might work.
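A minimal sketch of the two-ODE idea, evaluated directly at the observation times `x`. It assumes, as in the question, that L starts at 0 and has zero derivative before `delay`, so the first piece needs no integration at all. `rk4` and `predict_split` are hypothetical helpers; the fixed-step RK4 loop only stands in for `solve(prob, Tsit5())`:

```julia
# Hypothetical fixed-step RK4 integrator (a stand-in for solve(prob, Tsit5())):
# advances the scalar state L from t0 to t1 under dL/dt = f(L, t) in n steps.
function rk4(f, L, t0, t1; n = 100)
    h = (t1 - t0) / n
    t = float(t0)
    for _ in 1:n
        k1 = f(L, t)
        k2 = f(L + h / 2 * k1, t + h / 2)
        k3 = f(L + h / 2 * k2, t + h / 2)
        k4 = f(L + h * k3, t + h)
        L += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    end
    return L
end

# Hypothetical two-phase prediction at sorted observation times `ts`:
#   phase 1 (t < delay):  dL/dt = 0, so L stays at its initial value L0;
#   phase 2 (t >= delay): Von Bertalanffy growth dL/dt = k * (Lmax - L).
# The phase boundary is the parameter `delay` itself, not a branch inside f.
function predict_split(p, ts; L0 = 0.0)
    k, Lmax, delay = p
    growth(L, t) = k * (Lmax - L)
    out = zeros(length(ts))
    L, tprev = L0, delay
    for (i, t) in enumerate(ts)
        if t < delay
            out[i] = L0
        else
            L = rk4(growth, L, tprev, t)   # continue phase 2 from the previous time
            tprev = t
            out[i] = L
        end
    end
    return out
end

x = [14,14,15,17,17,18,18,19,19,20,20,21,22,23,24,25,26,27,28,29,30,
     32,34,35,36,38,39,40,41,42,45,46,47,48,49,50,51,52,53,55,56,58,59,62] .* 7
pred = predict_split([0.01, 25.0, 100.0], x)   # true_p from the question
```

Because `delay` now sets the start of the growth solve, the prediction depends on it through the integration interval rather than through a branch inside the right-hand side. This is also why the second `tspan` has to depend on `p[3]`.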


Thanks, this sounds like a good suggestion, but I’m not sure exactly how to implement it. Wouldn’t it require estimating the length of `tspan`? For example, `tspan = (1, p[3])` for one ODE and `tspan = (p[3], 500)` for the other. I tried the following, but the delay (`p[3]`) still wasn’t estimated. Perhaps you had a different approach in mind?

```julia
# Phase 1 (before the delay): no growth
function prod_f1(du,u,p,t)
    L = u[1]
    du[1] = 0.0
end

# Phase 2 (after the delay): Von Bertalanffy curve
function prod_f2(du,u,p,t)
    L = u[1]
    k = p[1]
    Lmax = p[2]
    du[1] = k*(Lmax - L)
end


function loss_adjoint(p)
    # The split point between the two solves is the delay parameter p[3]
    tspan1 = (1.0,p[3])
    tspan2 = (p[3],500)
    probf1 = ODEProblem(prod_f1,u0,tspan1,p,saveat = 1.0)
    probf2 = ODEProblem(prod_f2,u0,tspan2,p,saveat = 1.0)
    aaa = Array(concrete_solve(probf1,Tsit5(),u0,p,saveat=1))
    bbb = Array(concrete_solve(probf2,Tsit5(),u0,p,saveat=1))
    prediction = [aaa bbb]
    loss = sum(abs2,data' .- prediction[x])
    loss,prediction
end
```
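Separately from the gradient question, `prediction[x]` treats the position of an entry in `[aaa bbb]` as its time. A small sketch that mimics the two save grids (assuming `saveat = 1` saves at `tspan[1]:1:tspan[2]`, so both pieces contain the boundary time) suggests that lookup is shifted by one day when the delay is 90.0:

```julia
# Assumed save grids with saveat = 1 (times tspan[1]:1:tspan[2]) for delay = 90.0
delay = 90.0
t1 = collect(1.0:1.0:delay)     # first solve, tspan1 = (1.0, 90.0): 1, 2, ..., 90
t2 = collect(delay:1.0:500.0)   # second solve, tspan2 = (90.0, 500): 90, 91, ..., 500
tcat = vcat(t1, t2)             # time stamp of each entry of [aaa bbb]; 90 appears twice

x = [14,14,15,17,17,18,18,19,19,20,20,21,22,23,24,25,26,27,28,29,30,
     32,34,35,36,38,39,40,41,42,45,46,47,48,49,50,51,52,53,55,56,58,59,62] .* 7

println(tcat[x] == x)           # false
println(unique(tcat[x] .- x))   # [-1.0]: every lookup lands one day early
```

With a non-integer delay, the second grid would start at a fractional time, so the positions would drift by a fraction of a day as well. Asking the second solve for the observation times directly, with something along the lines of `saveat = x[x .>= p[3]]`, and filling earlier times with `u0`, would avoid this position/time bookkeeping.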

Even with forward mode that’s not estimated? That would surprise me.


Sorry for the slow reply. Using ForwardDiff indeed solved the problem! I’ve also found that tuning `ADAM()` is critical. Thanks for your help, it’s much appreciated!
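One way to see why forward mode helps once the problem is split: the start time of the second solve is `p[3]` itself, so a dual number seeded on `p[3]` carries its derivative through every step size of that solve. The sketch below uses a hand-rolled `Dual` type and a fixed-step Euler integrator, both hypothetical stand-ins for ForwardDiff.jl and `Tsit5()`, with k = 0.01 and Lmax = 25.0 from the question and a delay of 90.0 as in `init_p`:

```julia
# Hypothetical minimal forward-mode dual number: a value plus its derivative
# with respect to one seeded input. Not ForwardDiff.jl's API.
struct Dual
    val::Float64
    der::Float64
end

Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:-(a::Real, b::Dual) = Dual(a - b.val, -b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(a::Real, b::Dual) = Dual(a * b.val, a * b.der)
Base.:/(a::Dual, b::Real) = Dual(a.val / b, a.der / b)
Base.zero(::Dual) = Dual(0.0, 0.0)

# Phase 2 of the split model: integrate dL/dt = k * (Lmax - L) from t = delay to t
# with n forward-Euler steps, starting from L(delay) = 0. The step size depends on
# `delay`, so a Dual `delay` propagates its derivative through the whole solve.
function grow_from_delay(k, Lmax, delay, t; n = 10_000)
    h = (t - delay) / n
    L = zero(h)
    for _ in 1:n
        L = L + h * (k * (Lmax - L))
    end
    return L
end

k, Lmax, delay, t = 0.01, 25.0, 90.0, 200.0
L = grow_from_delay(k, Lmax, Dual(delay, 1.0), t)   # seed d(delay)/d(delay) = 1
analytic = -Lmax * k * exp(-k * (t - delay))         # d/d(delay) of Lmax*(1 - exp(-k*(t - delay)))
println("dual derivative:     ", L.der)
println("analytic derivative: ", analytic)
```

The dual derivative agrees closely with the analytic one (up to the Euler discretization), which is consistent with forward mode picking up the delay once it sets the integration interval.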