My use of ContinuousCallback from DifferentialEquations.jl is throwing an error, and I don't know why. Any help would be appreciated. The short version of the error is:

```
LoadError: MethodError: objects of type SciMLBase.ContinuousCallback{typeof(Main.TestCB.always_condition), typeof(Main.TestCB.print_affect2!), typeof(Main.TestCB.print_affect2!), typeof(SciMLBase.INITIALIZE_DEFAULT), typeof(SciMLBase.FINALIZE_DEFAULT), Float64, Int64, Rational{Int64}, Nothing, Int64} are not callable
```
I might have the syntax wrong, but I am trying to follow examples in the DifferentialEquations.jl online documentation and on Discourse. My code is inside a module, which might require some syntax changes to get ContinuousCallback to work, but the same error occurs without the module (with the body of `run()` moved into the main script).
I provide a runnable example below that reproduces the error. It is a simple example of using a neural net to control a parameter of an ODE problem. Here the ODE state is [position, velocity] of, for example, a car. The loss function is only a placeholder (it nominally tries to keep the squared velocity at 10). I want to use multiple callbacks: one to print results and one to terminate the ODE solve once a condition is met. Later I will add others. My understanding is that I should do this with two ContinuousCallback objects joined via a CallbackSet. I cannot get even the single ContinuousCallback for printing to work, much less the CallbackSet.
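For comparison, here is a minimal sketch of how ContinuousCallback and CallbackSet are usually wired up: they go to the ODE `solve` through its `callback` keyword, not to `Optimization.solve`. It assumes a recent DifferentialEquations.jl and replaces the neural-net controller with a hypothetical constant acceleration of 1.0 so it runs on its own. Note that a "print every step" callback whose condition simply returns `true` follows the `DiscreteCallback` convention; a `ContinuousCallback` condition is expected to be a continuous function whose zero crossing marks the event (here, position reaching 2.0).

```julia
using DifferentialEquations

# Same state layout as the post: u = [position, velocity].
# Assumption: a constant acceleration p replaces the neural-net controller.
function dudt_const!(du, u, p, t)
    du[1] = u[2]   # d(position)/dt = velocity
    du[2] = p      # d(velocity)/dt = acceleration
end

# Print after every accepted step: a DiscreteCallback whose condition is always true.
nsteps = Ref(0)
print_condition(u, t, integrator) = true
function print_affect!(integrator)
    nsteps[] += 1
    println("t = ", integrator.t, ", u = ", integrator.u)
end
cb_print = DiscreteCallback(print_condition, print_affect!)

# Terminate when position reaches 2.0: this condition crosses zero at the event.
terminate_condition(u, t, integrator) = u[1] - 2.0
cb_terminate = ContinuousCallback(terminate_condition, terminate!)

cb_set = CallbackSet(cb_print, cb_terminate)

prob = ODEProblem(dudt_const!, [0.0, 0.0], (0.0, 10.0), 1.0)
sol = solve(prob, Tsit5(); callback = cb_set)   # callbacks go to the ODE solve
println("stopped at t = ", sol.t[end])          # position = t^2/2, so t is about 2
```

In the post's setup, the analogous change would be adding `callback = cb_set` to the `Difeq.solve(...)` call inside `predict`. Differentiating through a solve that contains callbacks depends on the sensitivity algorithm in use, so that part may need extra configuration.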
Following online examples, I can successfully use a plain function as a callback, taking the arguments (p_hat, loss, pred); this is my function `print_affect1()`. I can't find anywhere in the documentation where this kind of callback and its signature are defined. I only see documentation for ContinuousCallback and related types, whose functions take the `integrator` as an argument. Is this simple type of callback defined somewhere in the docs?
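As far as I can tell, this is the `callback` keyword of `Optimization.solve`, which is separate from the DifferentialEquations.jl callback types. In the Optimization.jl releases that accept `print_affect1`, it is a plain function called as `callback(p, l, extra...)`, where `extra...` are any additional values returned by the loss function (here `pred`), and returning `true` stops the optimization early. Newer Optimization.jl releases pass a state object as the first argument instead, so the exact signature depends on the installed version. The sketch below uses only Base Julia and a hypothetical factory, `make_print_and_stop`, to illustrate that convention with an early-stop threshold.

```julia
using Printf

# Hypothetical helper: builds a callback with the (p, loss, pred) signature that
# print_affect1 already uses. It prints the loss and returns true (halt) once the
# loss drops below `threshold`.
function make_print_and_stop(threshold)
    iter = Ref(0)   # counts how many times the optimizer has called the callback
    return function (p, l, pred)
        iter[] += 1
        @printf("iter %d: loss = %.4f\n", iter[], l)
        return l < threshold
    end
end

cb = make_print_and_stop(1.0)
# Simulate two optimizer iterations; p and pred are unused in this sketch.
keep_going = cb(nothing, 5.0, nothing)   # prints "iter 1: loss = 5.0000", returns false
halt       = cb(nothing, 0.5, nothing)   # prints "iter 2: loss = 0.5000", returns true
```

In the post's code, such a function would be passed as `callback = make_print_and_stop(1.0)` to `Optimization.solve`, in the same place where `print_affect1` works.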
The example code is below.
```julia
module TestCB

import DiffEqFlux
import DifferentialEquations
import Lux
import Optimization
import Random
import Zygote
using Formatting
using Infiltrator
using Revise

const Difeq = DifferentialEquations

# --------------------------------------------------------------------------------------------------
function dudt(du, u, p, t)
    # function for dudt, e.g., u = [position, velocity], and p = delta velocity
    s, v = u
    dv = p
    du[1] = v
    du[2] = dv
end
# --------------------------------------------------------------------------------------------------
# system dynamics derivative with the controller included
function dudt_controlled(du, u, p, t)
    dv = NNet([u[1], u[2]], p, st)[1]
    # plug force into system dynamics
    dudt(du, u, dv[1], t)
end
# --------------------------------------------------------------------------------------------------
# predict trajectory given some estimated parameters
function predict(p_hat)
    sol = Difeq.solve(
        prob,
        Difeq.Tsit5(),
        p = p_hat,
        saveat = tsteps,
    )
    return Array(sol)
end
# --------------------------------------------------------------------------------------------------
# loss function to minimize
function loss(p_hat)
    pred = predict(p_hat)
    # pred is a 2×N matrix; `s, v = pred` would only grab its first two elements,
    # so take the rows explicitly (row 1 = position, row 2 = velocity)
    s = pred[1, :]
    v = pred[2, :]
    l = sum(v .^ 2 .- 10)
    return l, pred
end
# --------------------------------------------------------------------------------------------------
# first test callback function, works when used without ContinuousCallback
function print_affect1(p_hat, loss, pred)
    # print every iteration
    printfmtln("loss= {}", loss)
    return false
end
# --------------------------------------------------------------------------------------------------
# second test callback function, for use with ContinuousCallback
function print_affect2!(integrator)
    # get loss from integrator?
    #printfmtln("loss= {}", loss)
    println("working")
end
# --------------------------------------------------------------------------------------------------
function always_condition(u, t, integrator)
    true
end
# --------------------------------------------------------------------------------------------------
function terminate_condition(u, t, integrator)
    # stop solution at t >= 10
    10 - t
end
# --------------------------------------------------------------------------------------------------
# run problem
function run()
    # set up NN and ODE problem
    global NNet = Lux.Chain(
        Lux.Dense(2, 8, tanh),
        Lux.Dense(8, 1),
        (x) -> Lux.softplus.(x)
    )
    rng = Random.default_rng()
    Random.seed!(rng, 0)
    x = randn(rng, Float32, 2, 8)
    global ps, st = Lux.setup(rng, NNet)
    ps = Lux.ComponentArray(ps)
    u0 = [0.0, 0.0]
    N = 50
    global tspan = (0.0, 10.0)
    global tsteps = range(tspan[1], tspan[2], length = N)
    # set up ODE problem
    global prob = Difeq.ODEProblem((du, u, p, t) -> dudt_controlled(du, u, p, t), u0, tspan, ps)
    # set up ContinuousCallbacks
    terminate_affect!(integrator) = Difeq.terminate!(integrator)
    cb_terminate = Difeq.ContinuousCallback(terminate_condition, terminate_affect!)
    # print callback
    cb_print = Difeq.ContinuousCallback(always_condition, print_affect2!)
    # cb set
    cb_set = Difeq.CallbackSet(cb_print, cb_terminate)
    # do optimization
    adtype = Optimization.AutoZygote()
    optf = Optimization.OptimizationFunction((x, ps) -> loss(x), adtype)
    optprob = Optimization.OptimizationProblem(optf, ps)
    # this works
    #res1_uode = Optimization.solve(optprob, DiffEqFlux.ADAM(0.01), callback=print_affect1, maxiters = 100)
    # this fails
    res1_uode = Optimization.solve(optprob, DiffEqFlux.ADAM(0.01), callback=cb_print, maxiters = 100)
    # this fails
    #res1_uode = Optimization.solve(optprob, DiffEqFlux.ADAM(0.01), callback=cb_set, maxiters = 100)
    # this fails
    #res1_uode = Optimization.solve(optprob, DiffEqFlux.ADAM(0.01), callback=cb_terminate, maxiters = 100)
end # -- end run()

# outside of module functions, initialize some variables
NNet = nothing
ps = nothing
st = nothing
tsteps = nothing
prob = nothing

end # end module -----------------------------------------------

# call run() to run the problem
TestCB.run()
```
The full error when `callback=cb_print` is used is below. Similar errors occur for the other callbacks.

```
ERROR: LoadError: MethodError: objects of type SciMLBase.ContinuousCallback{typeof(Main.TestCB.always_condition), typeof(Main.TestCB.print_affect2!), typeof(Main.TestCB.print_affect2!), typeof(SciMLBase.INITIALIZE_DEFAULT), typeof(SciMLBase.FINALIZE_DEFAULT), Float64, Int64, Rational{Int64}, Nothing, Int64} are not callable
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/OptimizationFlux/cpWyO/src/OptimizationFlux.jl:34 [inlined]
  [2] macro expansion
    @ ~/.julia/packages/Optimization/NmJMd/src/utils.jl:37 [inlined]
  [3] __solve(prob::SciMLBase.OptimizationProblem{true, SciMLBase.OptimizationFunction{true, Optimization.AutoZygote, Main.TestCB.var"#3#7", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:16, ShapedAxis((8, 2), NamedTuple())), bias = ViewAxis(17:24, ShapedAxis((8, 1), NamedTuple())))), layer_2 = ViewAxis(25:33, Axis(weight = ViewAxis(1:8, ShapedAxis((1, 8), NamedTuple())), bias = ViewAxis(9:9, ShapedAxis((1, 1), NamedTuple())))), layer_3 = 34:33)}}}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::Flux.Optimise.Adam, data::Base.Iterators.Cycle{Tuple{Optimization.NullData}}; maxiters::Int64, callback::SciMLBase.ContinuousCallback{typeof(Main.TestCB.always_condition), typeof(Main.TestCB.print_affect2!), typeof(Main.TestCB.print_affect2!), typeof(SciMLBase.INITIALIZE_DEFAULT), typeof(SciMLBase.FINALIZE_DEFAULT), Float64, Int64, Rational{Int64}, Nothing, Int64}, progress::Bool, save_best::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ OptimizationFlux ~/.julia/packages/OptimizationFlux/cpWyO/src/OptimizationFlux.jl:30
  [4] #solve#489
    @ ~/.julia/packages/SciMLBase/xWByK/src/solve.jl:71 [inlined]
  [5] run()
    @ Main.TestCB ~/Devel/EVCO/evco/evco/EvcoJulia/optCntrl/test_cb2.jl:132
  [6] top-level scope
    @ ~/Devel/EVCO/evco/evco/EvcoJulia/optCntrl/test_cb2.jl:155
  [7] include(fname::String)
    @ Base.MainInclude ./client.jl:451
  [8] top-level scope
    @ REPL[3]:1
  [9] macro expansion
    @ ~/.julia/packages/Infiltrator/qHXTS/src/Infiltrator.jl:680 [inlined]
 [10] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
in expression starting at /home/jboik/Devel/EVCO/evco/evco/EvcoJulia/optCntrl/test_cb2.jl:155
```
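The message itself is ordinary Julia behavior. Per the stacktrace, `Optimization.solve` (through OptimizationFlux) calls whatever was passed as `callback` as if it were a function, and a `ContinuousCallback` is a struct that stores the condition and affect functions but does not define a call method. The sketch below reproduces the same kind of `MethodError` using only Base Julia and a hypothetical stand-in struct, `CallbackContainer`, which is not part of any library.

```julia
# Hypothetical stand-in for a callback container such as ContinuousCallback:
# it stores functions but is not itself callable.
struct CallbackContainer{C,A}
    condition::C
    affect!::A
end

container = CallbackContainer((u, t, integrator) -> true, integrator -> nothing)

# Mimic an optimizer invoking `callback(p, loss, pred)` on the object.
err = try
    container([0.0], 1.0, nothing)
    nothing
catch e
    e
end

println(err isa MethodError)        # true
println(sprint(showerror, err))     # a "not callable" MethodError like the one above
```

This suggests that `Optimization.solve` needs a plain function (like `print_affect1`) as its `callback`, while ContinuousCallback and CallbackSet objects belong in the ODE `solve` call.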