Another issue: when I try to parallelize the loop over the controls with Threads.@threads, like this:
cd(@__DIR__)
using Pkg
Pkg.activate(".")
using Lux
using ComponentArrays
using Zygote
using SciMLSensitivity
using ForwardDiff
using DifferentialEquations
using Optimization
using OptimizationOptimJL
using OptimizationOptimisers
using Random
using CairoMakie
using SimpleChains
using Interpolations
using BenchmarkTools
# Neural ODE which accepts control as an argument
struct ControlledODE{M <: Lux.AbstractExplicitLayer, So, Se, T, Sa} <:
       Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
    solver::So
    sensealg::Se
    tspan::T
    saveat::Sa
end

function ControlledODE(model::Lux.AbstractExplicitLayer;
                       solver=Tsit5(),
                       sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
                       tspan=(0.0f0, 1.0f0),
                       saveat=[])
    return ControlledODE(model, solver, sensealg, tspan, saveat)
end

function (n::ControlledODE)(u0, control, ps, st; tspan=n.tspan, saveat=n.saveat)
    function n_ode(u, p, t)
        c = Float32.([ctrl(t) for ctrl in control])
        du, _ = n.model([u; c], p, st)
        return du
    end
    prob = ODEProblem(ODEFunction(n_ode), u0, tspan, ps)
    return solve(prob, n.solver; sensealg=n.sensealg, saveat=saveat)
end
tspan = (0.0f0, 10.0f0)
Nt = 100
saveat = LinRange{Float32}(tspan[1], tspan[2], Nt)
# number of controlled trajectories
Nc = 1
rng = MersenneTwister(1111)
disc_controls = randn(rng, Nc, Nt)
controls = [linear_interpolation(saveat, disc_controls[i, :]) for i = 1:Nc]
# Define Neural ODE with controls applied in a for-loop
hidden_nodes = 10
weight_init_mag = 0.1
f = Lux.Chain(Lux.Dense(3, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, hidden_nodes, selu; init_weight=Lux.glorot_uniform(gain=weight_init_mag)),
              Lux.Dense(hidden_nodes, 2; init_weight=Lux.glorot_uniform(gain=weight_init_mag)))
rng = MersenneTwister(1111)
ps, st = Lux.setup(rng, f)
ps = ComponentArray(ps)
# Sensitivity algorithm for AD
#sensealg = ForwardDiffSensitivity()
#sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())
#sensealg = BacksolveAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
#sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP(true))
model = ControlledODE(f; solver=Tsit5(), sensealg=sensealg, tspan=tspan, saveat=saveat)
function cost(ps, u0)
    loss = 0.0f0
    Threads.@threads for i = 1:Nc
        pred = model(u0, [controls[i]], ps, st)
        loss += sum(abs2, pred) / Nt
    end
    return loss
end
u0 = zeros(2)
@benchmark begin
    l, back = pullback(p -> cost(p, u0), ps)
    gs = back(one(l))[1]
end
I get the following error:
ERROR: Compiling Tuple{typeof(Base._wait), Task}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:101 [inlined]
[2] _pullback(ctx::Zygote.Context{false}, f::typeof(Base._wait), args::Task)
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:101
[3] _pullback
@ ./threadingconstructs.jl:115 [inlined]
[4] _pullback(::Zygote.Context{false}, ::typeof(Base.Threads.threading_run), ::var"#185#threadsfor_fun#37"{var"#185#threadsfor_fun#36#38"{ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:40, Axis(weight = ViewAxis(1:30, ShapedAxis((10, 3), NamedTuple())), bias = ViewAxis(31:40, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(41:150, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(151:260, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_4 = ViewAxis(261:282, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10), NamedTuple())), bias = ViewAxis(21:22, ShapedAxis((2, 1), NamedTuple())))))}}}, Vector{Float64}, UnitRange{Int64}}}, ::Bool)
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[5] macro expansion
@ ./threadingconstructs.jl:168 [inlined]
[6] _pullback
@ ~/Research/constitutive_history/optimization/batch_control_mre.jl:88 [inlined]
[7] _pullback(::Zygote.Context{false}, ::typeof(cost), ::ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:40, Axis(weight = ViewAxis(1:30, ShapedAxis((10, 3), NamedTuple())), bias = ViewAxis(31:40, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(41:150, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(151:260, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_4 = ViewAxis(261:282, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10), NamedTuple())), bias = ViewAxis(21:22, ShapedAxis((2, 1), NamedTuple())))))}}}, ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[8] _pullback
@ ~/Research/constitutive_history/optimization/batch_control_mre.jl:97 [inlined]
[9] _pullback(ctx::Zygote.Context{false}, f::var"#39#40", args::ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:40, Axis(weight = ViewAxis(1:30, ShapedAxis((10, 3), NamedTuple())), bias = ViewAxis(31:40, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(41:150, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(151:260, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_4 = ViewAxis(261:282, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10), NamedTuple())), bias = ViewAxis(21:22, ShapedAxis((2, 1), NamedTuple())))))}}})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface2.jl:0
[10] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:40, Axis(weight = ViewAxis(1:30, ShapedAxis((10, 3), NamedTuple())), bias = ViewAxis(31:40, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(41:150, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(151:260, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_4 = ViewAxis(261:282, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10), NamedTuple())), bias = ViewAxis(21:22, ShapedAxis((2, 1), NamedTuple())))))}}})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:44
[11] pullback(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:40, Axis(weight = ViewAxis(1:30, ShapedAxis((10, 3), NamedTuple())), bias = ViewAxis(31:40, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis(41:150, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_3 = ViewAxis(151:260, Axis(weight = ViewAxis(1:100, ShapedAxis((10, 10), NamedTuple())), bias = ViewAxis(101:110, ShapedAxis((10, 1), NamedTuple())))), layer_4 = ViewAxis(261:282, Axis(weight = ViewAxis(1:20, ShapedAxis((2, 10), NamedTuple())), bias = ViewAxis(21:22, ShapedAxis((2, 1), NamedTuple())))))}}})
@ Zygote ~/.julia/packages/Zygote/JeHtr/src/compiler/interface.jl:42
[12] macro expansion
@ ~/Research/constitutive_history/optimization/batch_control_mre.jl:97 [inlined]
[13] var"##core#1694"()
@ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:489
[14] var"##sample#1695"(::Tuple{}, __params::BenchmarkTools.Parameters)
@ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:495
[15] _run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; verbose::Bool, pad::String, kwargs::Base.Pairs{Symbol, Integer, NTuple{4, Symbol}, NamedTuple{(:samples, :evals, :gctrial, :gcsample), Tuple{Int64, Int64, Bool, Bool}}})
@ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:99
[16] #invokelatest#2
@ ./essentials.jl:818 [inlined]
[17] invokelatest
@ ./essentials.jl:813 [inlined]
[18] #run_result#45
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]
[19] run_result
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]
[20] run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, kwargs::Base.Pairs{Symbol, Integer, NTuple{5, Symbol}, NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample), Tuple{Bool, Int64, Int64, Bool, Bool}}})
@ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117
[21] run (repeats 2 times)
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117 [inlined]
[22] #warmup#54
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:169 [inlined]
[23] warmup(item::BenchmarkTools.Benchmark)
@ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:168
[24] top-level scope
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:393
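For context, the error appears to come from Zygote trying to differentiate through the task machinery that Threads.@threads generates (Base.Threads.threading_run / Base._wait contain try/catch, which the error says Zygote's compiler does not support). As a point of comparison, here is a serial variant of the cost that keeps the same per-trajectory loss but accumulates with a plain sum, so no threading code ends up inside the differentiated region. This is just a sketch for reference, cost_serial is a name I'm introducing here, and it of course gives up the parallelism I'm after:

# Serial version of the cost: same per-trajectory loss, but the reduction is a
# plain sum over trajectories, so Zygote never sees Threads.@threads internals.
function cost_serial(ps, u0)
    return sum(1:Nc) do i
        pred = model(u0, [controls[i]], ps, st)
        sum(abs2, pred) / Nt
    end
end

l, back = pullback(p -> cost_serial(p, u0), ps)
gs = back(one(l))[1]

What I'd like to know is how to keep the independent solves for each control running on separate threads while still getting gradients, or what the recommended pattern is for this kind of batched, controlled Neural ODE.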