Hi, thanks for all the great work!
I’ve been trying to run this example: Neural Ordinary Differential Equations · DiffEqFlux.jl (I’ve attached it at the bottom for future reference).
I installed DiffEqFlux v1.53.0 and then the rest of the script’s dependencies, but the script still won’t run:
- `ComponentArray` is not part of the `Lux` namespace (sorry for the Python lingo) in the versions I have installed. Taking it from `ComponentArrays.jl` seems to work.
- I need to add `using OptimizationOptimisers` for the `ADAM` optimizer to work.
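For anyone hitting the same error, here is a minimal sketch of the `ComponentArray` workaround, assuming ComponentArrays.jl is installed alongside Lux. The single `Dense(2, 3, tanh)` layer is a hypothetical stand-in for the `dudt2` chain in the script.

```julia
using Lux, Random
using ComponentArrays  # provides ComponentArray; `Lux.ComponentArray` was not available with my Lux v0.4.58

rng = Random.default_rng()

# Hypothetical small model standing in for the 2 -> 50 -> 2 chain in the script.
model = Lux.Dense(2, 3, tanh)
ps, st = Lux.setup(rng, model)

# Flatten the nested parameter NamedTuple into one array with named views,
# so it can be passed to OptimizationProblem as a flat initial guess.
pinit = ComponentArray(ps)

println(length(pinit))  # 3×2 weights + 3 biases
```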
These are simple fixes, but am I misunderstanding how version pinning works in Julia?
- Should this example have worked, but the versions were pinned the wrong way, or
- are these typos in the docs, or
- is it common in Julia for code to move between packages, or are these just re-exports, with the real code living somewhere else?
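On the last question, this Base-only toy shows what a re-export looks like (the module names `RealPkg` and `Frontend` are hypothetical): `Frontend` makes `helper` available without defining it, and `parentmodule` reports where it actually lives. If a package drops such a re-export in a new release, code that reached the name through it breaks even though the function still exists. Whether `Lux.ComponentArray` was this kind of re-export is an assumption on my part; after `using ComponentArrays`, `parentmodule(ComponentArray)` should show the defining module.

```julia
# Hypothetical modules: `RealPkg` defines a function, `Frontend` only re-exports it.
module RealPkg
export helper
helper(x) = 2x
end

module Frontend
using ..RealPkg  # bring RealPkg's exported names into scope
export helper    # re-export: `using Frontend` now also provides `helper`
end

using .Frontend

println(helper(3))                     # callable through Frontend
println(parentmodule(helper))          # the module that actually defines it
println(isdefined(Frontend, :helper))  # Frontend exposes the name without defining it
```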
```
(@v1.9) pkg> st
Status `~/.julia/environments/v1.9/Project.toml`
⌃ [aae7a2af] DiffEqFlux v1.53.0
  [0c46a032] DifferentialEquations v7.8.0
  [7073ff75] IJulia v1.24.2
⌅ [b2108857] Lux v0.4.58
⌃ [7f7a1694] Optimization v3.14.0
⌃ [36348300] OptimizationOptimJL v0.1.8
  [91a5bcdd] Plots v1.38.17
  [9a3f8284] Random
```
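For the pinning part, a sketch of how I understand Pkg to work, assuming a recent Julia (I believe the `outdated` keyword of `Pkg.status` needs 1.8 or newer). Requesting an exact version constrains only that package; its dependencies (Lux, Optimization, ...) are still resolved within its `[compat]` bounds, so they can be newer than whatever the docs were built with. In the status output above, ⌃ marks packages with a newer version available, and ⌅ marks packages whose upgrades are blocked by compat constraints.

```julia
using Pkg

# Work in a throwaway environment so the shared v1.9 environment is untouched.
Pkg.activate(; temp = true)

# `version` fixes DiffEqFlux itself; its dependencies are still chosen by the
# resolver within DiffEqFlux's [compat] bounds.
Pkg.add(name = "DiffEqFlux", version = "1.53.0")

# Show which installed packages have newer versions available.
Pkg.status(; outdated = true)

# Keep DiffEqFlux at this version across later `Pkg.update()` calls.
Pkg.pin("DiffEqFlux")
```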
```julia
using Lux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationOptimJL, Random, Plots

rng = Random.default_rng()
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

dudt2 = Lux.Chain(x -> x.^3,
                  Lux.Dense(2, 50, tanh),
                  Lux.Dense(50, 2))
p, st = Lux.setup(rng, dudt2)
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)

function predict_neuralode(p)
    Array(prob_neuralode(u0, p, st)[1])
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end

# Do not plot by default for the documentation
# Users should change doplot=true to see the plots callbacks
callback = function (p, l, pred; doplot = false)
    println(l)
    # plot current prediction against data
    if doplot
        plt = scatter(tsteps, ode_data[1,:], label = "data")
        scatter!(plt, tsteps, pred[1,:], label = "prediction")
        display(plot(plt))
    end
    return false
end

pinit = Lux.ComponentArray(p)
callback(pinit, loss_neuralode(pinit)...; doplot=true)

# use Optimization.jl to solve the problem
adtype = Optimization.AutoZygote()

optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)

result_neuralode = Optimization.solve(optprob,
                                      ADAM(0.05),
                                      callback = callback,
                                      maxiters = 300)

optprob2 = remake(optprob,u0 = result_neuralode.u)

result_neuralode2 = Optimization.solve(optprob2,
                                       Optim.BFGS(initial_stepnorm=0.01),
                                       callback=callback,
                                       allow_f_increases = false)

callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true)
```
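The `ADAM(0.05)` call in the script is the part that needed `using OptimizationOptimisers` for me. Here is a minimal standalone sketch of that setup on a toy quadratic (the objective and target values are made up for illustration), assuming Zygote.jl is installed for `AutoZygote`. It uses `Adam`, the spelling current Optimisers.jl releases use; as far as I know, the docs' `ADAM` is the older name.

```julia
using Optimization, OptimizationOptimisers  # the latter brings in Optimisers.jl rules such as Adam
using Zygote                                # AD backend used by AutoZygote

# Toy objective: squared distance from the target stored in the parameter `p`.
loss(x, p) = sum(abs2, x .- p)

optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2), [1.0, 2.0])

# Same call shape as `Optimization.solve(optprob, ADAM(0.05), ...)` in the script.
sol = solve(prob, Adam(0.05); maxiters = 1000)
println(sol.u)
```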