Hi this is a follow up to an earlier post. I’m trying to use multiple_shooting and I am just trying to run the basic example script shown:
Can anyone confirm that the example code on that page is both syntactically and semantically correct? I’m getting an error on the second to last line. As I write in the comments, the error message is different depending on whether I run it on Julia 1.6.7 on x86 Linux or Julia 1.10.8 on Apple Silicon.
My code is right here:
using ComponentArrays
using Lux
using DiffEqFlux
using Optimization
using OptimizationPolyalgorithms
using OrdinaryDiffEq
using Plots
using DiffEqFlux: group_ranges
using Random
#THIS CODE IS COPIED VERBATIM FROM:
# docs.sciml.ai/DiffEqFlux/dev/examples/multiple_shooting
#
rng = Random.default_rng()
#rng = Xoshiro(0)
#Define initial conditions and time steps
datasize = 30
u0 = Float32[2.0,0.0]
tspan = (0.0f0, 5.0f0)
tsteps = range(tspan[1], tspan[2]; length = datasize) #will be Float32
#Get the data
function trueODEfunc(du, u, p, t)
true_A = [-0.1 2.0; -2.0 -0.1]
du .= ((u .^3)'true_A)'
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps))
#define the neural network
nn = Lux.Chain(x -> x .^3, Lux.Dense(2, 16, tanh), Lux.Dense(16, 2))
#p_init, st = Lux.setup(rng, nn)
p_init, st = Lux.setup(rng, nn)
#p_init = Float32.(p64) #broadcast convert... THIS TRIGGERS AN ERROR
ps = ComponentArray(p_init) #is this Float32?
pd, pax = getdata(ps), getaxes(ps)
neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps)
prob_node = ODEProblem((u, p, t) -> nn(u, p, st)[1], u0, tspan, ps)
#define parameters for multiple multiple shooting
group_size = 3
continuity_term = 200
function loss_function(data, pred)
return sum(abs2, data - pred)
end
l1, preds = multiple_shoot(
ps,
ode_data,
tsteps,
prob_node,
loss_function,
Tsit5(),
group_size;
continuity_term
)
function loss_multiple_shooting(p)
ps = ComponentArray(p, pax)
loss, currpred = multiple_shoot(
ps,
ode_data,
tsteps,
prob_node,
loss_function,
Tsit5(),
group_size;
continuity_term
)
global preds = currpred
return loss
end
function plot_multiple_shoot(plt, preds, group_size)
step = group_size - 1
ranges = group_ranges(datasize, group_size)
for (i, ig) in enumerate(range)
plot!(plt, tsteps[rg], preds[i][1,:]; markershape = :circle, label = "Group $(i)")
end
end
anim = Plots.Animation()
iter = 0
function calback(state, l; doplot = true, prob_node = prob_node)
display(l)
global iter
iter += 1
if doplot && iter % 1 == 0
#plot the original data
plt = scatter(tsteps, ode_data[1,:]; label = "Data")
#plot the different predictions for individual shoot
l1, preds = multiple_shoot(
ComponentArray(state.u, pax),
ode_data,
tsteps,
prob_node,
loss_function,
Tsit5(),
group_size;
continuity_term
)
plot_multiple_shoot(plt, preds, group_size)
frame(anim)
display(plot(plt))
end
return false
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss_multiple_shooting(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pd)
res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 300)
#ERROR OCCURS:
#On my MacBook running Apple Silicon with Julia 1.10.8, the error is
#Method Error: no matching iterate(::typeof(range))
#On my x86 running Ubuntu 24.04 with Julia 1.6.7, the error is:
#Type array has no field u. I assume that refers to the line
#mutiple_shoot(ComponentArray(state.u, pax), ode_data,....)
gif(anim, "multiple_shooting.gif"; fps = 15)
I suspect somehow I’m mixing Float32 and Float64 values though I’m trying to only use Float32
I’d appreciate any assistance