Neural ODE works for small networks, but throws an error for larger networks (getindex() method)?

Hello all:

I’m running into strange behavior with neural ODEs defined with composite functions and FastChain definitions. In short, the code below works without issue for small neural networks, but fails for larger networks with a getindex() MethodError. See this minimal (non)working toy example:

using Random, DiffEqFlux, DifferentialEquations, Flux, Optim
Random.seed!(0)

# Width <= 15 works fine:
width = 15
# Width >= 16 fails because of getindex method failure:
#width = 16

NN = FastChain(FastDense(1,width,swish), FastDense(width,1))
pNN = initial_params(NN)

p = [pNN;1.0]

function neural_ode(u, p, t)
    pNN = p[1:end-1]
    m = p[end]

    dudt = NN(u,pNN)[] - m
    return dudt
end

u0 = rand(1)[1]
tspan = (0.0,10.0)
t = Array(range(0,10,100))
prob_neuralode = ODEProblem(neural_ode, u0, tspan, p)

function loss_neuralode(p)
    trial = Array(solve(prob_neuralode,AutoTsit5(Rosenbrock23()),u0=u0,p=p,saveat=t,abstol = 1e-6,reltol = 1e-6))
    loss = sum(abs2, trial)
    return loss, trial
end

callback = function (p, l, pred; doplot = true)
    display(l)
    return false
end

result_neuralode = DiffEqFlux.sciml_train(loss_neuralode,
                                            p,
                                            ADAM(0.1),
                                            cb = callback,
                                            maxiters = 10)

And the beginning of the stacktrace:

ERROR: MethodError: no method matching getindex(::Float64, ::UnitRange{Int64})
Closest candidates are:
getindex(::Number) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/number.jl:95
getindex(::Union{AbstractChar, Number}, ::CartesianIndex{0}) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/multidimensional.jl:831
getindex(::Number, ::Integer) at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/number.jl:96

I’m running Julia 1.7 for this example.

Has anyone experienced this behavior? Any idea what I might be doing wrong?
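The only thing that changes between the working and failing runs is the hidden-layer width, which changes the length of p. A quick sketch, reusing the FastChain/FastDense calls from the post, makes the sizes concrete: FastDense(1, w) holds w weights plus w biases and FastDense(w, 1) holds w weights plus 1 bias, so p = [pNN; 1.0] has 3w + 2 entries. That some automatic differentiation or sensitivity choice depends on problem size is only a guess at why the width matters; the helper total_params is hypothetical.

```julia
using DiffEqFlux, Flux

# Hypothetical helper: length of p = [pNN; m] for the network in the post.
# FastDense(1, w): w weights + w biases; FastDense(w, 1): w weights + 1 bias.
function total_params(width)
    NN = FastChain(FastDense(1, width, swish), FastDense(width, 1))
    return length(initial_params(NN)) + 1   # +1 for the extra scalar m = p[end]
end

for w in (15, 16)
    println("width = $w => length(p) = ", total_params(w))
end
```

This prints 47 for width 15 and 50 for width 16.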

Did you get a warning during the solve that the solver diverged? My guess is that you did, and then of course you’d get a getindex error, because the solve didn’t run all the way to the end.
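One way to check whether a solve was cut short, sketched here on a stand-in problem rather than the network from the post: compare the saved time points with the requested saveat grid and print the return code. The exponential-decay ODE and the helper solved_to_end are hypothetical, chosen only for illustration.

```julia
using DifferentialEquations

# Hypothetical helper: true if the solution reached every requested save time,
# i.e. the integration was not stopped early.
function solved_to_end(sol, saveat)
    return length(sol.t) == length(saveat) && sol.t[end] ≈ saveat[end]
end

# Stand-in problem for illustration only: exponential decay du/dt = -p*u.
decay(u, p, t) = -p * u
prob = ODEProblem(decay, 1.0, (0.0, 10.0), 0.5)
t = collect(range(0, 10, length = 100))
sol = solve(prob, AutoTsit5(Rosenbrock23()), saveat = t, abstol = 1e-6, reltol = 1e-6)

# Shown as :Success on older releases and ReturnCode.Success on newer ones.
println("retcode: ", sol.retcode)
println("solved to end: ", solved_to_end(sol, t))
```

With the post's loss, the same check would be applied to the solve(...) result before it is wrapped in Array.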

No solver divergence warning. To simplify the MWE even further, I’ve made the RHS of the neural ODE something trivial that is stable by construction (neural_ode always returns zero, regardless of the network architecture):

using Random, DiffEqFlux, DifferentialEquations, Flux, Optim
Random.seed!(0)

# Width <= 15 works fine:
width = 15
# Width >= 16 fails because of getindex method failure:
width = 16

NN = FastChain(FastDense(1,width,swish), FastDense(width,1))
pNN = initial_params(NN)

p = [zeros(length(pNN));1.0]

function neural_ode(u, p, t)
    pNN = p[1:end-1]
    m = p[end]

    dudt = zeros(length(NN(u,pNN)))[]
    return dudt
end

u0 = rand(1)[1]
tspan = (0.0,10.0)
t = Array(range(0,.10,100))
prob_neuralode = ODEProblem(neural_ode, u0, tspan, p)

function loss_neuralode(p)
    trial = Array(solve(prob_neuralode,AutoTsit5(Rosenbrock23()),u0=u0,p=p,saveat=t,abstol = 1e-6,reltol = 1e-6))
    loss = sum(abs2, trial.-zeros(length(t)))
    return loss, trial
end

callback = function (p, l, pred; doplot = true)
    display(l)
    return false
end

result_neuralode = DiffEqFlux.sciml_train(loss_neuralode,
                                            p,
                                            ADAM(0.1),
                                            cb = callback,
                                            maxiters = 10)

This example behaves the same way as the one in my original post.

ForwardDiffSensitivity wasn’t compatible with u0 as a scalar (we always used arrays!). Fixed that here:

Other adjoints will throw a clear error message, but rather than just improving that message here, it was easy enough to fix this dispatch.
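Since the problem was specific to a scalar u0, a plausible workaround on versions without the fix is to use a one-element state vector, matching the array-based problems the sensitivity code was written for. The sketch below is the original MWE rewritten under that assumption; apart from u0 and the shape of the RHS output, it keeps the calls from the post.

```julia
using Random, DiffEqFlux, DifferentialEquations, Flux, Optim
Random.seed!(0)

width = 16   # the width that failed with a scalar u0
NN = FastChain(FastDense(1, width, swish), FastDense(width, 1))
pNN = initial_params(NN)
p = [pNN; 1.0]

# Out-of-place RHS on a one-element state vector instead of a scalar.
function neural_ode(u, p, t)
    pNN = p[1:end-1]
    m = p[end]
    return NN(u, pNN) .- m   # 1-element vector, same shape as u
end

u0 = rand(1)                 # Vector{Float64} of length 1 (not rand(1)[1])
tspan = (0.0, 10.0)
t = Array(range(0, 10, length = 100))
prob_neuralode = ODEProblem(neural_ode, u0, tspan, p)

function loss_neuralode(p)
    sol = solve(prob_neuralode, AutoTsit5(Rosenbrock23()), u0 = u0, p = p,
                saveat = t, abstol = 1e-6, reltol = 1e-6)
    trial = Array(sol)       # 1×length(t) matrix for a one-element state
    return sum(abs2, trial), trial
end

callback = function (p, l, pred)
    display(l)
    return false
end

result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, p, ADAM(0.1),
                                          cb = callback, maxiters = 10)
```

The only visible change in the loss is that trial becomes a 1×100 matrix instead of a vector, which sum(abs2, ...) handles unchanged.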

Ah! Thanks, Chris. Code is working now!