Diffeq ContinuousCallback "not callable" errors

My use of ContinuousCallback from DifferentialEquations.jl is throwing an error and I don’t know why. Any help would be appreciated. The short version of the error is:

LoadError: MethodError: objects of type SciMLBase.ContinuousCallback{typeof(Main.TestCB.always_condition), typeof(Main.TestCB.print_affect2!), typeof(Main.TestCB.print_affect2!), typeof(SciMLBase.INITIALIZE_DEFAULT), typeof(SciMLBase.FINALIZE_DEFAULT), Float64, Int64, Rational{Int64}, Nothing, Int64} are not callable

I might have the syntax wrong, but I am trying to follow the examples in the DifferentialEquations.jl online documentation and on Discourse. My code lives inside a module, which might require some syntax changes to get ContinuousCallback working, but note that the same error occurs if I drop the module (and move the code from run() into the main script).

I provide a runnable example below that reproduces the error. It is a simple example of using a neural net to control a parameter of an ODE problem. Here the ODE state is u = [position, velocity] of, say, a car, and the network output sets the change in velocity. The loss function is meaningless (try to maintain a squared velocity of 10). I want to use multiple callbacks: one to print results and one to terminate the ODE solve once a condition is met, with others to follow later. My understanding is that I should do this with two ContinuousCallbacks joined via a CallbackSet. However, I cannot get the single ContinuousCallback for printing to work, much less the CallbackSet.

Following online examples, I can successfully use a plain function as a callback, with arguments (p_hat, loss, pred); that is my function print_affect1(). I cannot find where this kind of callback and its signature are documented; I only see documentation for ContinuousCallback and related callbacks that receive an integrator. Is this simple type of callback defined somewhere in the docs?
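
For context, the pattern I am using looks roughly like the sketch below. My (possibly wrong) understanding is that Optimization.solve calls this callback with the current parameters followed by everything the loss function returns, and that returning true halts the optimization.

# sketch of the plain optimizer callback (not a ContinuousCallback); it receives
# the current parameters plus whatever loss() returns, and returning false
# tells the optimizer to continue
function simple_cb(p_hat, loss_val, pred)
    println("loss = ", loss_val)
    return false
end

# used as: Optimization.solve(optprob, DiffEqFlux.ADAM(0.01), callback = simple_cb, maxiters = 100)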

The example code is below.

module TestCB

import DiffEqFlux
import DifferentialEquations
import Lux
import Optimization
import Random
import Zygote

using Formatting
using Infiltrator
using Revise

Difeq = DifferentialEquations 

# --------------------------------------------------------------------------------------------------
function dudt(du, u, p, t)
    # function for dudt, e.g., u = [position, velocity] and p = delta velocity
    s, v = u
    dv = p
    du[1] = v
    du[2] = dv
end

# --------------------------------------------------------------------------------------------------
# system dynamics derivative with the controller included
function dudt_controlled(du, u, p, t)
    dv = NNet([u[1], u[2]], p, st)[1]

    # plug force into system dynamics
    dudt(du, u, dv[1], t)
end

# --------------------------------------------------------------------------------------------------
# predict trajectory given some estimated parameters
function predict(p_hat)
    ans = Difeq.solve(
        prob, 
        Difeq.Tsit5(), 
        p = p_hat, 
        saveat=tsteps,
        )
    return Array(ans)  
end

# --------------------------------------------------------------------------------------------------
# loss function to minimize
function loss(p_hat)
    pred = predict(p_hat)
    s, v = pred
    loss = sum( v.^2 .- 10)    
    return loss, pred
end

# --------------------------------------------------------------------------------------------------
# first test callback function; works when passed directly (not wrapped in a ContinuousCallback)
function print_affect1(p_hat, loss, pred)
    # print every iteration
    printfmtln("loss= {}", loss)
    return false
end

# --------------------------------------------------------------------------------------------------
# second test callback function, for use with ContinuousCallback
function print_affect2!(integrator)
    # get loss from integrator?
    # printfmtln("loss= {}", loss)
    println("working")
end

# --------------------------------------------------------------------------------------------------
function always_condition(u, t, integrator)
    true
end

# --------------------------------------------------------------------------------------------------
function terminate_condition(u, t, integrator)
    # stop solution at t>= 10
    10 - t
end

# --------------------------------------------------------------------------------------------------
# run problem
function run()
    
    #setup NN and ODE problem
    global NNet = Lux.Chain(
        Lux.Dense(2, 8, tanh), 
        Lux.Dense(8, 1),
        (x) -> Lux.softplus.(x)
        )
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    x = randn(rng, Float32, 2, 8)
    global ps, st = Lux.setup(rng, NNet)
    ps = Lux.ComponentArray(ps)


    u0 = [0.0, 0.0]
    N = 50
    global tspan = (0.0, 10.0)
    global tsteps = range(tspan[1], length = N, tspan[2])

    # set up ODE problem
    global prob = Difeq.ODEProblem((du, u, p, t) -> dudt_controlled(du, u, p, t), u0, tspan, ps)
    

   # setup ContinuousCallbacks
    terminate_affect!(integrator) = Difeq.terminate!(integrator)
    cb_terminate = Difeq.ContinuousCallback(terminate_condition, terminate_affect!)

    # print callback
    cb_print = Difeq.ContinuousCallback(always_condition, print_affect2!)
    
    # cb set
    cb_set = Difeq.CallbackSet(cb_print, cb_terminate)
        
    # do optimization
    adtype = Optimization.AutoZygote()
    optf = Optimization.OptimizationFunction((x, ps) -> loss(x), adtype)

    optprob = Optimization.OptimizationProblem(optf, ps)
    
    # this works
    #res1_uode = Optimization.solve(optprob, DiffEqFlux.ADAM(0.01), callback=print_affect1, maxiters = 100)
    
    # this fails
    res1_uode = Optimization.solve(optprob, DiffEqFlux.ADAM(0.01), callback=cb_print, maxiters = 100)
    
    # this fails
    #res1_uode = Optimization.solve(optprob, DiffEqFlux.ADAM(0.01), callback=cb_set, maxiters = 100)

    # this fails
    #res1_uode = Optimization.solve(optprob, DiffEqFlux.ADAM(0.01), callback=cb_terminate, maxiters = 100)

end  # -- end run()


# module-level variables, initialized here and set in run()
NNet = nothing
ps = nothing
st = nothing
tsteps = nothing
prob = nothing


end  # end module -----------------------------------------------


# call run() to run the problem
TestCB.run() 

The full error when callback=cb_print is used is below. Similar errors occur for other callbacks.

ERROR: LoadError: MethodError: objects of type SciMLBase.ContinuousCallback{typeof(Main.TestCB.always_condition), typeof(Main.TestCB.print_affect2!), typeof(Main.TestCB.print_affect2!), typeof(SciMLBase.INITIALIZE_DEFAULT), typeof(SciMLBase.FINALIZE_DEFAULT), Float64, Int64, Rational{Int64}, Nothing, Int64} are not callable
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/OptimizationFlux/cpWyO/src/OptimizationFlux.jl:34 [inlined]
  [2] macro expansion
    @ ~/.julia/packages/Optimization/NmJMd/src/utils.jl:37 [inlined]
  [3] __solve(prob::SciMLBase.OptimizationProblem{true, SciMLBase.OptimizationFunction{true, Optimization.AutoZygote, Main.TestCB.var"#3#7", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:16, ShapedAxis((8, 2), NamedTuple())), bias = ViewAxis(17:24, ShapedAxis((8, 1), NamedTuple())))), layer_2 = ViewAxis(25:33, Axis(weight = ViewAxis(1:8, ShapedAxis((1, 8), NamedTuple())), bias = ViewAxis(9:9, ShapedAxis((1, 1), NamedTuple())))), layer_3 = 34:33)}}}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::Flux.Optimise.Adam, data::Base.Iterators.Cycle{Tuple{Optimization.NullData}}; maxiters::Int64, callback::SciMLBase.ContinuousCallback{typeof(Main.TestCB.always_condition), typeof(Main.TestCB.print_affect2!), typeof(Main.TestCB.print_affect2!), typeof(SciMLBase.INITIALIZE_DEFAULT), typeof(SciMLBase.FINALIZE_DEFAULT), Float64, Int64, Rational{Int64}, Nothing, Int64}, progress::Bool, save_best::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ OptimizationFlux ~/.julia/packages/OptimizationFlux/cpWyO/src/OptimizationFlux.jl:30
  [4] #solve#489
    @ ~/.julia/packages/SciMLBase/xWByK/src/solve.jl:71 [inlined]
  [5] run()
    @ Main.TestCB ~/Devel/EVCO/evco/evco/EvcoJulia/optCntrl/test_cb2.jl:132
  [6] top-level scope
    @ ~/Devel/EVCO/evco/evco/EvcoJulia/optCntrl/test_cb2.jl:155
  [7] include(fname::String)
    @ Base.MainInclude ./client.jl:451
  [8] top-level scope
    @ REPL[3]:1
  [9] macro expansion
    @ ~/.julia/packages/Infiltrator/qHXTS/src/Infiltrator.jl:680 [inlined]
 [10] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
in expression starting at /home/jboik/Devel/EVCO/evco/evco/EvcoJulia/optCntrl/test_cb2.jl:155

It looks like you stuck a Difeq.ContinuousCallback differential equation callback function into an optimization solve? The differential equation callback should go into the differential equation solve in the predict function.
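
In other words, something like the sketch below (reusing the names from the code above): the ContinuousCallback moves into the inner Difeq.solve inside predict, while Optimization.solve keeps the plain-function callback.

# sketch of the suggested fix: the ODE callback goes on the ODE solve
function predict(p_hat)
    sol = Difeq.solve(
        prob,
        Difeq.Tsit5(),
        p = p_hat,
        saveat = tsteps,
        callback = cb_terminate,   # ContinuousCallback belongs here
        )
    return Array(sol)
end

# while the optimizer keeps the plain callback:
# Optimization.solve(optprob, DiffEqFlux.ADAM(0.01), callback = print_affect1, maxiters = 100)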

Thanks Chris. Ugh, of course. With this fix my ContinuousCallback is now called properly. But there is another problem: the outer DiffEqFlux.ADAM optimization fails if the inner ODE solve is terminated by a callback.

The brief error is:
LoadError: AssertionError: haskey(kwargs, :callback) && length(sol.t) < 3

Any ideas on how to fix that? I added save_positions=(true,true) to the ContinuousCallback constructor, but that did not fix the error. Also, the ODE problem solves successfully when run on its own, outside the outer DiffEqFlux.ADAM optimization.

The revised code is as follows. It prints the line specified in the terminate_condition function many times (until u[1] - 2 crosses zero), and then the failure occurs.

module TestCB

import DiffEqFlux
import DifferentialEquations
import Lux
import Optimization
import Random
import Zygote

using Formatting
using Infiltrator
using Revise

Difeq = DifferentialEquations 

# --------------------------------------------------------------------------------------------------
function dudt(du, u, p, t)
    # function for dudt, e.g., u = [position, velocity] and p = delta velocity
    s, v = u
    dv = p
    du[1] = v
    du[2] = dv
end

# --------------------------------------------------------------------------------------------------
# system dynamics derivative with the controller included
function dudt_controlled(du, u, p, t)
    dv = NNet([u[1], u[2]], p, st)[1]

    # plug force into system dynamics
    dudt(du, u, dv[1], t)
end

# --------------------------------------------------------------------------------------------------
function terminate_condition(u, t, integrator)
    printfmtln("terminate_condition, t= {}, u[1]= {}", t, u[1])  # many results printed prior to fail
    u[1] - 2.0
end

terminate_affect!(integrator) = Difeq.terminate!(integrator)
cb_terminate = Difeq.ContinuousCallback(
    terminate_condition, 
    terminate_affect!,
    save_positions=(true,true))

# --------------------------------------------------------------------------------------------------
# predict trajectory given some estimated parameters
function predict(p_hat)
    ans = Difeq.solve(
        prob, 
        Difeq.Tsit5(), 
        p = p_hat, 
        saveat=tsteps,
        callback=cb_terminate
        )
    return Array(ans)  
end

# --------------------------------------------------------------------------------------------------
# loss function to minimize
function loss(p_hat)
    pred = predict(p_hat)
    s, v = pred
    loss = sum( v.^2 .- 10)    
    return loss, pred
end

# --------------------------------------------------------------------------------------------------
# optimizer callback; a plain function passed directly to Optimization.solve (not a ContinuousCallback)
function cb_optimization(p_hat, loss, pred)
    # print every iteration
    printfmtln("loss= {}", loss)
    return false
end

# --------------------------------------------------------------------------------------------------
# run problem
function run()
    
    #setup NN and ODE problem
    global NNet = Lux.Chain(
        Lux.Dense(2, 8, tanh), 
        Lux.Dense(8, 1),
        (x) -> Lux.softplus.(x)
        )
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    x = randn(rng, Float32, 2, 8)
    global ps, st = Lux.setup(rng, NNet)
    ps = Lux.ComponentArray(ps)


    u0 = [0.0, 0.0]
    N = 50
    global tspan = (0.0, 10.0)
    global tsteps = range(tspan[1], length = N, tspan[2])

    # set up ODE problem
    global prob = Difeq.ODEProblem((du, u, p, t) -> dudt_controlled(du, u, p, t), u0, tspan, ps)
    
   # do optimization
    adtype = Optimization.AutoZygote()
    optf = Optimization.OptimizationFunction((x, ps) -> loss(x), adtype)

    optprob = Optimization.OptimizationProblem(optf, ps)
    
    
    res1_uode = Optimization.solve(optprob, DiffEqFlux.ADAM(0.01), callback=cb_optimization, maxiters = 100)

end  # -- end run()

# module-level variables, initialized here and set in run()
NNet = nothing
ps = nothing
st = nothing
tsteps = nothing
prob = nothing

end  # end module -----------------------------------------------

# call run() to run the problem
TestCB.run()

The full error is:

ERROR: LoadError: AssertionError: haskey(kwargs, :callback) && length(sol.t) < 3
Stacktrace:
  [1] _concrete_solve_adjoint(::SciMLBase.ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:16, ShapedAxis((8, 2), NamedTuple())), bias = ViewAxis(17:24, ShapedAxis((8, 1), NamedTuple())))), layer_2 = ViewAxis(25:33, Axis(weight = ViewAxis(1:8, ShapedAxis((1, 8), NamedTuple())), bias = ViewAxis(9:9, ShapedAxis((1, 1), NamedTuple())))), layer_3 = 34:33)}}}, SciMLBase.ODEFunction{true, SciMLBase.AutoSpecialize, Main.TestCB.var"#2#5", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::SciMLSensitivity.ForwardDiffSensitivity{0, nothing}, ::Vector{Float64}, ::ComponentVector{Float32}, ::SciMLBase.ChainRulesOriginator; saveat::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, kwargs::Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol}, NamedTuple{(:verbose, :callback), Tuple{Bool, SciMLBase.ContinuousCallback{typeof(Main.TestCB.terminate_condition), typeof(Main.TestCB.terminate_affect!), typeof(Main.TestCB.terminate_affect!), typeof(SciMLBase.INITIALIZE_DEFAULT), typeof(SciMLBase.FINALIZE_DEFAULT), Float64, Int64, Rational{Int64}, Nothing, Int64}}}})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/G9VHq/src/concrete_solve.jl:634
  [2] _concrete_solve_adjoint(::SciMLBase.ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:16, ShapedAxis((8, 2), NamedTuple())), bias = ViewAxis(17:24, ShapedAxis((8, 1), NamedTuple())))), layer_2 = ViewAxis(25:33, Axis(weight = ViewAxis(1:8, ShapedAxis((1, 8), NamedTuple())), bias = ViewAxis(9:9, ShapedAxis((1, 1), NamedTuple())))), layer_3 = 34:33)}}}, SciMLBase.ODEFunction{true, SciMLBase.AutoSpecialize, Main.TestCB.var"#2#5", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Nothing, ::Vector{Float64}, ::ComponentVector{Float32}, ::SciMLBase.ChainRulesOriginator; verbose::Bool, kwargs::Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol}, NamedTuple{(:saveat, :callback), Tuple{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, SciMLBase.ContinuousCallback{typeof(Main.TestCB.terminate_condition), typeof(Main.TestCB.terminate_affect!), typeof(Main.TestCB.terminate_affect!), typeof(SciMLBase.INITIALIZE_DEFAULT), typeof(SciMLBase.FINALIZE_DEFAULT), Float64, Int64, Rational{Int64}, Nothing, Int64}}}})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/G9VHq/src/concrete_solve.jl:163
  [3] #_solve_adjoint#55
    @ ~/.julia/packages/DiffEqBase/5rKYk/src/solve.jl:1237 [inlined]
  [4] #rrule#53
    @ ~/.julia/packages/DiffEqBase/5rKYk/src/solve.jl:1190 [inlined]
  [5] chain_rrule_kw
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:230 [inlined]
  [6] macro expansion
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0 [inlined]
  [7] _pullback
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:9 [inlined]
  [8] _apply
    @ ./boot.jl:814 [inlined]
  [9] adjoint
    @ ~/.julia/packages/Zygote/dABKa/src/lib/lib.jl:203 [inlined]
 [10] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [11] _pullback
    @ ~/.julia/packages/DiffEqBase/5rKYk/src/solve.jl:801 [inlined]
 [12] _pullback(::Zygote.Context{false}, ::DiffEqBase.var"##solve#31", ::Nothing, ::Nothing, ::ComponentVector{Float32}, ::Val{true}, ::Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol}, NamedTuple{(:saveat, :callback), Tuple{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, SciMLBase.ContinuousCallback{typeof(Main.TestCB.terminate_condition), typeof(Main.TestCB.terminate_affect!), typeof(Main.TestCB.terminate_affect!), typeof(SciMLBase.INITIALIZE_DEFAULT), typeof(SciMLBase.FINALIZE_DEFAULT), Float64, Int64, Rational{Int64}, Nothing, Int64}}}}, ::typeof(CommonSolve.solve), ::SciMLBase.ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:16, ShapedAxis((8, 2), NamedTuple())), bias = ViewAxis(17:24, ShapedAxis((8, 1), NamedTuple())))), layer_2 = ViewAxis(25:33, Axis(weight = ViewAxis(1:8, ShapedAxis((1, 8), NamedTuple())), bias = ViewAxis(9:9, ShapedAxis((1, 1), NamedTuple())))), layer_3 = 34:33)}}}, SciMLBase.ODEFunction{true, SciMLBase.AutoSpecialize, Main.TestCB.var"#2#5", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [13] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [14] adjoint
    @ ~/.julia/packages/Zygote/dABKa/src/lib/lib.jl:203 [inlined]
 [15] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [16] _pullback
    @ ~/.julia/packages/DiffEqBase/5rKYk/src/solve.jl:793 [inlined]
 [17] _pullback(::Zygote.Context{false}, ::CommonSolve.var"#solve##kw", ::NamedTuple{(:p, :saveat, :callback), Tuple{ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:16, ShapedAxis((8, 2), NamedTuple())), bias = ViewAxis(17:24, ShapedAxis((8, 1), NamedTuple())))), layer_2 = ViewAxis(25:33, Axis(weight = ViewAxis(1:8, ShapedAxis((1, 8), NamedTuple())), bias = ViewAxis(9:9, ShapedAxis((1, 1), NamedTuple())))), layer_3 = 34:33)}}}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, SciMLBase.ContinuousCallback{typeof(Main.TestCB.terminate_condition), typeof(Main.TestCB.terminate_affect!), typeof(Main.TestCB.terminate_affect!), typeof(SciMLBase.INITIALIZE_DEFAULT), typeof(SciMLBase.FINALIZE_DEFAULT), Float64, Int64, Rational{Int64}, Nothing, Int64}}}, ::typeof(CommonSolve.solve), ::SciMLBase.ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:16, ShapedAxis((8, 2), NamedTuple())), bias = ViewAxis(17:24, ShapedAxis((8, 1), NamedTuple())))), layer_2 = ViewAxis(25:33, Axis(weight = ViewAxis(1:8, ShapedAxis((1, 8), NamedTuple())), bias = ViewAxis(9:9, ShapedAxis((1, 1), NamedTuple())))), layer_3 = 34:33)}}}, SciMLBase.ODEFunction{true, SciMLBase.AutoSpecialize, Main.TestCB.var"#2#5", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [18] _pullback
    @ ~/Devel/EVCO/evco/evco/EvcoJulia/optCntrl/test_cb.jl:50 [inlined]
 [19] _pullback(ctx::Zygote.Context{false}, f::typeof(Main.TestCB.predict), args::ComponentVector{Float32})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [20] _pullback
    @ ~/Devel/EVCO/evco/evco/EvcoJulia/optCntrl/test_cb.jl:63 [inlined]
 [21] _pullback(ctx::Zygote.Context{false}, f::typeof(Main.TestCB.loss), args::ComponentVector{Float32})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [22] _pullback
    @ ~/Devel/EVCO/evco/evco/EvcoJulia/optCntrl/test_cb.jl:105 [inlined]
 [23] _pullback(::Zygote.Context{false}, ::Main.TestCB.var"#3#6", ::ComponentVector{Float32}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [24] _apply
    @ ./boot.jl:814 [inlined]
 [25] adjoint
    @ ~/.julia/packages/Zygote/dABKa/src/lib/lib.jl:203 [inlined]
 [26] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [27] _pullback
    @ ~/.julia/packages/SciMLBase/xWByK/src/scimlfunctions.jl:3289 [inlined]
 [28] _pullback(::Zygote.Context{false}, ::SciMLBase.OptimizationFunction{true, Optimization.AutoZygote, Main.TestCB.var"#3#6", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::ComponentVector{Float32}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [29] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [30] adjoint
    @ ~/.julia/packages/Zygote/dABKa/src/lib/lib.jl:203 [inlined]
 [31] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [32] _pullback
    @ ~/.julia/packages/Optimization/NmJMd/src/function/zygote.jl:30 [inlined]
 [33] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#146#158"{SciMLBase.OptimizationFunction{true, Optimization.AutoZygote, Main.TestCB.var"#3#6", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}, args::ComponentVector{Float32})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [34] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [35] adjoint
    @ ~/.julia/packages/Zygote/dABKa/src/lib/lib.jl:203 [inlined]
 [36] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [37] _pullback
    @ ~/.julia/packages/Optimization/NmJMd/src/function/zygote.jl:37 [inlined]
 [38] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#149#161"{Tuple{}, Optimization.var"#146#158"{SciMLBase.OptimizationFunction{true, Optimization.AutoZygote, Main.TestCB.var"#3#6", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}}, args::ComponentVector{Float32})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [39] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float32})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:44
 [40] pullback
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:42 [inlined]
 [41] gradient(f::Function, args::ComponentVector{Float32})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:96
 [42] (::Optimization.var"#147#159"{Optimization.var"#146#158"{SciMLBase.OptimizationFunction{true, Optimization.AutoZygote, Main.TestCB.var"#3#6", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}})(::ComponentVector{Float32}, ::ComponentVector{Float32})
    @ Optimization ~/.julia/packages/Optimization/NmJMd/src/function/zygote.jl:32
 [43] macro expansion
    @ ~/.julia/packages/OptimizationFlux/cpWyO/src/OptimizationFlux.jl:32 [inlined]
 [44] macro expansion
    @ ~/.julia/packages/Optimization/NmJMd/src/utils.jl:37 [inlined]
 [45] __solve(prob::SciMLBase.OptimizationProblem{true, SciMLBase.OptimizationFunction{true, Optimization.AutoZygote, Main.TestCB.var"#3#6", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:16, ShapedAxis((8, 2), NamedTuple())), bias = ViewAxis(17:24, ShapedAxis((8, 1), NamedTuple())))), layer_2 = ViewAxis(25:33, Axis(weight = ViewAxis(1:8, ShapedAxis((1, 8), NamedTuple())), bias = ViewAxis(9:9, ShapedAxis((1, 1), NamedTuple())))), layer_3 = 34:33)}}}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::Flux.Optimise.Adam, data::Base.Iterators.Cycle{Tuple{Optimization.NullData}}; maxiters::Int64, callback::Function, progress::Bool, save_best::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ OptimizationFlux ~/.julia/packages/OptimizationFlux/cpWyO/src/OptimizationFlux.jl:30
 [46] #solve#489
    @ ~/.julia/packages/SciMLBase/xWByK/src/solve.jl:71 [inlined]
 [47] run()
    @ Main.TestCB ~/Devel/EVCO/evco/evco/EvcoJulia/optCntrl/test_cb.jl:110
 [48] top-level scope
    @ ~/Devel/EVCO/evco/evco/EvcoJulia/optCntrl/test_cb.jl:124
 [49] include(fname::String)
    @ Base.MainInclude ./client.jl:451
 [50] top-level scope
    @ REPL[1]:1
 [51] macro expansion
    @ ~/.julia/packages/Infiltrator/qHXTS/src/Infiltrator.jl:680 [inlined]
in expression starting at /home/jboik/Devel/EVCO/evco/evco/EvcoJulia/optCntrl/test_cb.jl:124

Hmm I’m not sure about this. I think @frankschae was looking at fixing some issues with terminate! support in adjoints here:

It would be good to add your case to that. retcode == :terminated has some special support but I’m not entirely sure why it would require no time series before that: that seems only for steady state handling. I’ll need to chat with Frank about that.

I added this assertion in Handle terminating callbacks by frankschae · Pull Request #690 · SciML/SciMLSensitivity.jl · GitHub
The reason was: if one saves at other times, there can be more or fewer time points in the primal solution
SciMLSensitivity.jl/concrete_solve.jl at 0cc379240098d4cc5d964b9425d6811e9ec3df49 · SciML/SciMLSensitivity.jl · GitHub
than in the solution using dual numbers:
SciMLSensitivity.jl/concrete_solve.jl at 0cc379240098d4cc5d964b9425d6811e9ec3df49 · SciML/SciMLSensitivity.jl · GitHub
because the dual components also enter the error estimate in the adaptive solver, and thus, different time stepping might occur.

Therefore, if one saves several time steps using ForwardDiffSensitivity, there can be bounds errors in SciMLSensitivity.jl/concrete_solve.jl at master · SciML/SciMLSensitivity.jl · GitHub (or incorrect sensitivities).

I don’t think that’s quite the right solution though. It’s the right solution only if save_everystep=false. If time series are saved, then it needs to be filtered to match the t of the forward, which would be the fix for this case.

I agree it’s a very conservative solution. How can we make

filtering to match the t of the forward

work if the terminating callback with the dual numbers is triggered earlier?

You need to cut off any t after the termination point.
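
As a rough illustration of that cut-off (names are purely illustrative, not the actual SciMLSensitivity code): given the termination time of the forward solution, any requested save points past it would be dropped before the adjoint uses them.

# illustrative only: keep requested save times that the (terminated) forward
# solution actually reached, and drop anything after the termination point
ts_requested = collect(tsteps)          # e.g. the saveat grid
t_end = sol.t[end]                      # last time of the terminated forward solve
ts_kept = filter(t -> t <= t_end, ts_requested)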


Using the same code (for an MWE NN-ODE control problem), I see another issue. If I run the optimization for many iterations I get an out-of-memory error (on my Linux system) and Julia crashes. Streamlined code that shows the problem is below. Every few hundred iterations of the outer optimization uses up another GB or so of RAM until all RAM is exhausted.

Any ideas on how to address this memory issue? Calling GC.gc() within the predict() function reduces the memory growth rate, but it does not eliminate the problem and it slows down the code. Things that make no difference include defining the ODE problem within predict(), eliminating the callback, explicitly setting iip to {true} for the optimization function, and moving the code from run() up to module level and removing all global statements.
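
For reference, the gc() workaround mentioned above looks roughly like this; it only slows the memory growth and adds noticeable overhead.

# workaround attempt (only slows the leak): force a garbage collection on
# every call to predict(); GC.gc(false) requests a cheaper incremental collection
function predict(p_hat)
    sol = Difeq.solve(prob, Difeq.Tsit5(), p = p_hat, saveat = tsteps)
    GC.gc()
    return Array(sol)
end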

If I skip the outer optimization and just run the ODE problem predict(ps) several thousand times, the memory problem does not seem to occur. So the problem must be in the outer optimization.

Regarding the previous terminating callback issue, if it will take a while to fix on your end, is there a hack I can do on my end to get the terminating callback working in the inner ODE problem?

module TestCB

import DiffEqFlux
import DifferentialEquations
import Lux
import Optimization
import Random
import Zygote

using Formatting
using Infiltrator
using Revise

Difeq = DifferentialEquations 

# --------------------------------------------------------------------------------------------------
function dudt!(du, u, p, t)
    # function for dudt, e.g., u = [position, velocity] and p = delta velocity
    s, v = u
    dv = p
    du[1] = v
    du[2] = dv
end

# --------------------------------------------------------------------------------------------------
# system dynamics derivative with the controller included
function dudt_controlled!(du, u, p, t)
    dv = NNet([u[1], u[2]], p, st)[1]
    dudt!(du, u, dv[1], t)
end

# --------------------------------------------------------------------------------------------------
# predict trajectory given some estimated parameters
function predict(p_hat)
    ans = Difeq.solve(
        prob, 
        Difeq.Tsit5(), 
        p = p_hat, 
        saveat=tsteps,
        )
    return Array(ans)  
end

# --------------------------------------------------------------------------------------------------
# loss function to minimize
function loss(p_hat)
    pred = predict(p_hat)
    s, v = pred
    loss = sum( v.^2 .- 10)    
    return loss, pred
end

# --------------------------------------------------------------------------------------------------
function cb(p_hat, loss, pred)
    # print every iteration
    global ii
    ii += 1
    if ii % 100 == 0
        printfmtln("ii= {}, loss= {}", ii, loss)
    end
    return false
end

# --------------------------------------------------------------------------------------------------
# run problem
function run()
    
    #setup NN and ODE problem
    global NNet = Lux.Chain(
        Lux.Dense(2, 80, tanh), 
        Lux.Dense(80, 1, tanh),
        )
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    x = randn(rng, Float32, 2, 8)
    global ps, st = Lux.setup(rng, NNet)
    ps = Lux.ComponentArray(ps)


    u0 = [0.0, 0.0]
    N = 50
    global tspan = (0.0, 10.0)
    global tsteps = range(tspan[1], length = N, tspan[2])

    # set up ODE problem
    global prob = Difeq.ODEProblem((du, u, p, t) -> dudt_controlled!(du, u, p, t), u0, tspan, ps)
    
   # do optimization
    adtype = Optimization.AutoZygote()
    optf = Optimization.OptimizationFunction{true}((x, ps) -> loss(x), adtype)

    optprob = Optimization.OptimizationProblem(optf, ps)
    res1_uode = Optimization.solve(optprob, DiffEqFlux.ADAM(0.01), callback=cb, maxiters = 100000)


end  # -- end run()

# module-level variables, initialized here and set in run()
NNet = nothing
ps = nothing
st = nothing
tsteps = nothing
prob = nothing
ii = 0
end  # end module -----------------------------------------------

# call run() to run the problem
TestCB.run()

It should work with a different sensealg; see the tests in
https://github.com/SciML/SciMLSensitivity.jl/blob/master/test/steady_state.jl#L500-L521.
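
In code, the sensealg is just a keyword argument on the inner ODE solve (using cb_terminate from the earlier version). The particular algorithm below is only a placeholder; the linked tests should be checked for the choices that are actually exercised with terminating callbacks.

import SciMLSensitivity

# sketch: select the sensitivity algorithm explicitly on the inner solve;
# BacksolveAdjoint() here is a placeholder, not a recommendation
function predict(p_hat)
    sol = Difeq.solve(
        prob,
        Difeq.Tsit5(),
        p = p_hat,
        saveat = tsteps,
        callback = cb_terminate,
        sensealg = SciMLSensitivity.BacksolveAdjoint(),
        )
    return Array(sol)
end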