Hi,
I would like to solve several ODE’s with contribution from a neural network concurrently.
I have the following (not working) example.
using StableRNGs
using ComponentArrays
using Zygote, Lux, OrdinaryDiffEq, SciMLSensitivity
using LinearAlgebra, Statistics
rng = StableRNG(1111);
rbf(x) = exp.(-(x .^ 2))
const U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
Lux.Dense(5, 2))
p, st = Lux.setup(rng, U)
const _st = st;
ps = ComponentVector{Float64}(p);
size_t = 10
# Initial Conditions
ICs = ComponentArray(x = 10.0*abs.(randn(size_t)), y = 4.0*abs.(randn(size_t)));
# parameters
param = ComponentArray(alpha = abs.(randn(size_t)),
beta = abs.(randn(size_t)),
gamma = abs.(randn(size_t)),
delta = abs.(randn(size_t)));
function ude_dynamics!(du, u, p, t, p_true)
û = U(reshape(u[:],size_t, 2)', p, _st)[1]
@. du.x = p_true.alpha * u.x + û[1, :]
@. du.y = -p_true.delta * u.y + û[2, :]
end
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, param);
tspan = (0.0, 5.0);
ts = LinRange(0.0, 5.0, 200);
prob_nn = ODEProblem(nn_dynamics!, ICs, tspan, ps);
function predict(θ, IC, T)
_prob = remake(prob_nn, u0 = IC, tspan = (T[1], T[end]))
Array(solve(_prob, Vern7(), saveat = T,
abstol = 1e-6, reltol = 1e-6,
sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
end
function loss(θ, IC, T)
X̂ = predict(θ, IC, T)
mean(abs2, 0.999*X̂ .- X̂) # Placeholder function. Just something to make it work
end
l, back = pullback(loss, ps, ICs, ts);
grads = back(l) # Problems..
When the line grads = back(l)
executes, I get the error: ERROR: type Array has no field x
which fallback to the line @. du.x = p_true.alpha * u.x + û[1, :]
How can I fix this?