I slightly modified a NeuralODE layer and added a few new weight parameters, but now
it can't fit the simple data and prints these warnings:
┌ Warning: dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable.
└ @ SciMLBase /home/solar/.julia/packages/SciMLBase/HbD6U/src/integrator_interface.jl:345
┌ Warning: dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable.
└ @ SciMLBase /home/solar/.julia/packages/SciMLBase/HbD6U/src/integrator_interface.jl:345
The code is a bit long, though:
using DiffEqFlux
using OrdinaryDiffEq, Flux, Optim, Plots
using Flux, OrdinaryDiffEq
using Zygote
using DiffEqSensitivity # for ZygteVJP?
"""
Abstract supertype for the neural differential-equation layers defined in this
script. It is declared as a subtype of `Function`.
"""
abstract type NeuralDELayer <: Function end

"""
    basic_tgrad(u, p, t)

Explicit time derivative ∂f/∂t of an ODE right-hand side that has no explicit `t`
dependence: always zero, with the same shape/type as the state `u`.
"""
function basic_tgrad(u, p, t)
    return zero(u)
end
"""
    LTC(model, tspan, τ, A, args...; p = nothing, kwargs...)

Neural-ODE layer whose right-hand side combines a Flux `model` with two extra
weight vectors, `τ` and `A`.

# Arguments
- `model`: Flux model; its weights are flattened with `Flux.destructure`.
- `tspan`: integration interval passed to the `ODEProblem`.
- `τ`: raw time-constant weights (the call method passes them through `softplus`).
- `A`: weights multiplying the network output in the right-hand side.
- `args...`: positional arguments forwarded to `solve` (e.g. the algorithm).

# Keywords
- `p`: flat model weights to use instead of those from `Flux.destructure(model)`.
- `kwargs...`: keyword arguments forwarded to `solve` (e.g. `saveat`).
"""
struct LTC{M,P,R,S,Tτ,TA,Targs,Tkw} <: NeuralDELayer
    model::M      # the Flux model
    p::P          # flat model weights only (τ and A are stored separately)
    p_len::Int    # length(p): offset where τ starts in the packed parameter vector
    re::R         # rebuilds the model from a flat weight vector
    tspan::S      # integration interval
    τ::Tτ         # raw (pre-softplus) time-constant weights
    τ_len::Int    # length(τ)
    A::TA         # weights multiplying the network output
    A_len::Int    # length(A)
    args::Targs   # positional args forwarded to `solve`
    kwargs::Tkw   # keyword args forwarded to `solve`
    function LTC(model, tspan, τ, A, args...; p = nothing, kwargs...)
        # `destructure` flattens only the model's own weights — it is NOT
        # [p; τ; A]; that packed vector is assembled when the layer is called.
        flat, rebuild = Flux.destructure(model)
        weights = isnothing(p) ? flat : p
        return new{typeof(model), typeof(weights), typeof(rebuild),
                   typeof(tspan), typeof(τ), typeof(A),
                   typeof(args), typeof(kwargs)}(
            model, weights, length(weights), rebuild, tspan,
            τ, length(τ), A, length(A), args, kwargs)
    end
end
"""
    (n::LTC)(x)

Solve the layer's ODE from initial state `x` over `n.tspan` and return the solution
produced by `solve(prob, n.args...; n.kwargs...)`.

The right-hand side is

    du/dt = -(1 ./ softplus.(τ) .+ f(u)) .* u .+ f(u) .* A

where `f` is the model rebuilt from the current weights. The model weights, `τ`
and `A` are packed into one ODE parameter vector `[n.p; n.τ; n.A]`, so the
adjoint sensitivity differentiates with respect to all of them.
"""
function (n::LTC)(x)
    # Segment boundaries inside the packed vector [model weights; τ; A].
    p_end = n.p_len
    τ_end = n.p_len + n.τ_len
    function dudt_(u, p, t)
        p_ = @view p[1:p_end]
        τ_raw = @view p[p_end+1:τ_end]
        τ_ = Flux.softplus.(τ_raw)   # softplus keeps τ > 0
        A_ = @view p[τ_end+1:end]
        # Rebuild and evaluate the network once; its output is used twice below.
        f = n.re(p_)(u)
        # NOTE(review): f is unconstrained, so (1/τ + f) can become negative and
        # the state can grow without bound — a plausible source of the
        # "dt <= dtmin" aborts during training. Consider bounding f (e.g. a
        # sigmoid output layer) — confirm against the intended LTC model.
        return -(1 ./ τ_ .+ f) .* u .+ f .* A_
    end
    ff = ODEFunction{false}(dudt_, tgrad = basic_tgrad)
    prob = ODEProblem{false}(ff, x, n.tspan, [n.p; n.τ; n.A])
    sense = InterpolatingAdjoint(autojacvec = ZygoteVJP())
    # `solve` selects the adjoint method via the `sensealg` keyword.
    return solve(prob, n.args...; sensealg = sense, n.kwargs...)
end
# Tell Flux which arrays of an LTC layer to collect in `Flux.params`:
# the flat model weights, the raw time constants τ, and A.
function Flux.trainable(m::LTC)
    return (m.p, m.τ, m.A)
end
# Example: ground-truth system used to generate training data
u0 = Float32[2.0, 0.0]
datasize = 30
tspan = (0.0f0, 3.0f0)

"""
    trueODEfunc(du, u, p, t)

In-place right-hand side of the ground-truth system: writes the transpose of
`(u .^ 3)' * M` into `du`, with the fixed 2×2 matrix `M` below. `p` and `t` are unused.
"""
function trueODEfunc(du, u, p, t)
    coeffs = [-0.1 2.0; -2.0 -0.1]
    cubed = u .^ 3
    du .= (cubed' * coeffs)'
end
# Sample the ground-truth trajectory at `datasize` evenly spaced times in `tspan`.
t = range(tspan[1], tspan[2]; length = datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = solve(prob, Tsit5(); saveat = t) |> Array
# LTC-specific weights, one entry per network output (the model maps R^2 -> R^2).
# Float32 to match u0, tspan and the Dense layers; Float64 here would promote the
# packed parameter vector [p; τ; A] to Float64.
τ = rand(Float32, 2)   # raw time constants; the layer applies softplus
A = zeros(Float32, 2)  # multiplies the network output in the layer's RHS
dudt = Chain(x -> x .^ 3,
             Dense(2, 10, tanh),
             Dense(10, 2))
# Build the layer: Tsit5() and saveat are forwarded to `solve`.
n_ode = LTC(dudt, tspan, τ, A, Tsit5(), saveat = t)
ps = Flux.params(n_ode)
pred = n_ode(u0) # Get the prediction using the correct initial condition
scatter(t, ode_data[1, :], label = "data")
scatter!(t, pred[1, :], label = "prediction")
# Prediction of the LTC layer from the fixed initial condition u0.
predict_n_ode() = n_ode(u0)

# Sum of squared errors between the sampled data and the prediction.
function loss_n_ode()
    residual = ode_data .- predict_n_ode()
    return sum(abs2, residual)
end

data = Iterators.repeated((), 1000)  # 1000 steps, no per-step data
opt = ADAM(0.01)
# Placeholder callback for Flux.train!; it currently does nothing.
cb = () -> nothing
Flux.train!(loss_n_ode, ps, data, opt; cb = cb)
Maybe @lungd can give some advice?