Custom NeuralODE layer trains inefficiently

I slightly modified a NeuralODE layer and added a few new weight parameters, but it
can't fit simple data and shows the following warning:

┌ Warning: dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable.
└ @ SciMLBase /home/solar/.julia/packages/SciMLBase/HbD6U/src/integrator_interface.jl:345
┌ Warning: dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable.
└ @ SciMLBase /home/solar/.julia/packages/SciMLBase/HbD6U/src/integrator_interface.jl:345

The code is a bit long, though:

using DiffEqFlux
using OrdinaryDiffEq, Flux, Optim, Plots
using Flux, OrdinaryDiffEq
using Zygote 
using DiffEqSensitivity # for ZygteVJP?


"Supertype for layers that embed a neural network inside a differential equation."
abstract type NeuralDELayer <: Function end

"""
    basic_tgrad(u, p, t)

Return a zero time-gradient with the same shape and element type as `u`; suitable as
the `tgrad` of a right-hand side that has no explicit dependence on `t`.
"""
function basic_tgrad(u, p, t)
    return zero(u)
end

"""
    LTC(model, tspan, τ, A, args...; p = nothing, kwargs...)

Liquid-time-constant style neural ODE layer wrapping the Flux `model`.

`Flux.destructure(model)` returns only the model's own flattened weights (plus a
closure that rebuilds the model from such a vector). It does NOT include `τ` or
`A`; those live in their own fields, and the layer's call method concatenates all
three as `[p; τ; A]` before handing them to the ODE solver. The stored lengths let
the right-hand side slice that vector apart again.

# Arguments
- `model`: Flux model evaluated inside the ODE right-hand side.
- `tspan`: integration interval.
- `τ`: raw time-constant weights (softplus-transformed in the right-hand side).
- `A`: weights multiplied elementwise with the model output in the right-hand side.
- `args...`: positional arguments forwarded to `solve` (e.g. the solver algorithm).

# Keywords
- `p`: initial flat model weights; defaults to the vector from `Flux.destructure(model)`.
  NOTE(review): a user-supplied `p` is not checked against the model's parameter count.
- `kwargs...`: keyword arguments forwarded to `solve`.
"""
struct LTC{M,P,RE,T,TA,AB,AR,K} <: NeuralDELayer
    model::M
    p::P          # flat model weights
    p_len::Int    # length(p): end of the model slice inside [p; τ; A]
    re::RE        # rebuilds `model` from a flat weight vector
    tspan::T
    τ::TA         # time-constant weights
    τ_len::Int
    A::AB         # weights multiplying the model output
    A_len::Int
    args::AR      # positional args forwarded to `solve`
    kwargs::K     # keyword args forwarded to `solve`

    function LTC(model, tspan, τ, A, args...; p = nothing, kwargs...)
        flat, rebuild = Flux.destructure(model)
        weights = isnothing(p) ? flat : p
        return new{typeof(model), typeof(weights), typeof(rebuild),
                   typeof(tspan), typeof(τ), typeof(A),
                   typeof(args), typeof(kwargs)}(
            model, weights, length(weights), rebuild, tspan,
            τ, length(τ), A, length(A), args, kwargs)
    end
end

"""
    (n::LTC)(x)

Solve the LTC ODE starting from the initial state `x` over `n.tspan`.

The right-hand side is

    du/dt = -(1 ./ softplus(τ) .+ f(u)) .* u .+ f(u) .* A

where `f = n.re(p_model)` is the wrapped Flux model rebuilt from its flat weights.
All trainable values are passed to the solver as one vector `[n.p; n.τ; n.A]`;
`dudt_` slices it back into its three parts using the lengths stored on `n`.

Returns whatever `solve(prob, n.args...; ...)` returns.
"""
function (n::LTC)(x)
    p_len, τ_len = n.p_len, n.τ_len
    function dudt_(u, p, t)
        p_ = @view p[1:p_len]
        τ_ = Flux.softplus.(@view p[p_len+1:p_len+τ_len])  # softplus keeps τ > 0
        A_ = @view p[p_len+τ_len+1:end]
        # Rebuild and evaluate the model once per RHS call (it was done twice before).
        f = n.re(p_)(u)
        return -(1 ./ τ_ .+ f) .* u .+ f .* A_
    end
    ff = ODEFunction{false}(dudt_; tgrad = basic_tgrad)
    prob = ODEProblem{false}(ff, x, n.tspan, [n.p; n.τ; n.A])
    adjoint = InterpolatingAdjoint(autojacvec = ZygoteVJP())
    # The sensitivity algorithm must be passed as `sensealg`; an unrecognized keyword
    # such as `sense` is not used to select the adjoint method.
    return solve(prob, n.args...; sensealg = adjoint, n.kwargs...)
end

# Only the model weights, the time constants and `A` are optimised; the rebuild
# closure, tspan, stored lengths and solver arguments are fixed configuration.
function Flux.trainable(layer::LTC)
    return (layer.p, layer.τ, layer.A)
end


# Example 
u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0, 3.0f0)

"""
    trueODEfunc(du, u, p, t)

In-place reference dynamics used to generate training data:
`du = transpose(true_A) * u.^3`. `p` and `t` are ignored. Returns `du`.
"""
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= transpose(true_A) * (u .^ 3)
    return du
end
t = range(tspan[1], tspan[2], length = datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat = t))

# Raw time constants and A weights, one per state variable. Float32 so that the
# concatenated `[p; τ; A]` keeps the Float32 element type of u0 and the Dense weights
# instead of promoting everything to Float64.
τ = rand(Float32, 2)
A = zeros(Float32, 2)

# `f(u)` in the LTC right-hand side. The sigmoid output keeps f in (0, 1), so the
# effective decay rate 1/softplus(τ) + f(u) stays positive. With an unbounded output
# that rate can turn negative and make the solution grow without bound, which is a
# likely cause of the `dt <= dtmin` abort.
dudt = Chain(x -> x .^ 3,
             Dense(2, 10, tanh),
             Dense(10, 2, σ))

# Construct the layer that is trained below; `saveat = t` makes its output line up
# column-for-column with `ode_data`.
n_ode = LTC(dudt, tspan, τ, A, Tsit5(); saveat = t)
ps = Flux.params(n_ode)

pred = n_ode(u0) # Get the prediction using the correct initial condition
scatter(t, ode_data[1, :], label = "data")
scatter!(t, pred[1, :], label = "prediction")

"Solve the global LTC layer `n_ode` from the initial condition `u0`."
predict_n_ode() = n_ode(u0)
"""
    loss_n_ode()

Sum of squared differences between `ode_data` and the current prediction.
"""
function loss_n_ode()
    residual = ode_data .- predict_n_ode()
    return sum(abs2, residual)
end


# 1000 optimisation steps. Each element of `data` is the empty tuple, so the
# zero-argument `loss_n_ode` is called once per step.
data = Iterators.repeated((), 1000)
opt = ADAM(0.01)
# Training callback; currently a no-op (e.g. print `loss_n_ode()` here to monitor).
cb = () -> nothing
Flux.train!(loss_n_ode, ps, data, opt; cb = cb)

Maybe @lungd can give some advice?

Was there an issue on this one?

Nope, I created the issue