Train neural ODE with mini-batch in different initial conditions

Hi, I am trying to train a neural ODE with mini-batches of different initial conditions. The mini-batching tutorials I found mostly batch different time points of a single initial condition (for example, this post: https://diffeqflux.sciml.ai/stable/examples/minibatch/). Is it possible to mini-batch over different initial conditions instead?

Specifically, here is my code that uses a single initial condition:

```julia
using DiffEqFlux, OrdinaryDiffEq, Flux

tmp_in = rand(25)    # a single initial condition, which is a 25-dim vector
target_y = rand()    # a single output

# define neural ODE and time range
d_in, dim_hidden = 25, 16
tspan = (0., 1.)
t = range(tspan[1], tspan[2], length=100)
dudt = Chain(Dense(d_in, dim_hidden, tanh), Dense(dim_hidden, dim_hidden, tanh), Dense(dim_hidden, d_in, tanh))
n_ode = NeuralODE(dudt, tspan, Tsit5(), saveat=t, reltol=1e-7, abstol=1e-9);

# fitting
θ = n_ode.p
opt = ADAM(1e-3)
maxiters = 200
function predict(θ)
    result = n_ode(tmp_in, θ)
    ŷ = result.u[end][[1], :]    # first state component at the final time point
    return ŷ
end
function loss(θ)
    result = Flux.mse(predict(θ), target_y)
    println(result)              # one loss value per line
    return result
end
res = DiffEqFlux.sciml_train(loss, θ, opt, maxiters=maxiters)
```

This code works perfectly fine, but if I use a batch of initial conditions like this:

```julia
tmp_in = cat(rand(115, 24), rand(115, 1), dims=2)';
target_y = rand(1, 115)
# other parts are the same
```

I get the following error:

```
BoundsError: attempt to access 3988-element Vector{Float64} at index [1:25, 1:115]
```

Does anyone know what this error means and how I should fix it?

Any suggestions would be much appreciated, thanks in advance!
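One possible culprit (an assumption, not confirmed in this thread): `cat(...)'` does not create a plain 25×115 `Matrix` but a lazy `Adjoint` wrapper around the 115×25 array. Note also that 3988 = 1113 + 2875, i.e. the parameter count of the three `Dense` layers plus the number of entries in the 25×115 state, which suggests the failing index happens in code that packs state and parameters into one vector, such as the gradient (adjoint) pass. The sketch below checks these facts and materializes the input as an ordinary matrix, which is a cheap thing to try. The stand-in `u_final` is hypothetical and only illustrates the shape produced by `predict`.

```julia
using LinearAlgebra  # provides the Adjoint wrapper type

# The batch from the question: a 115×25 array transposed with `'`.
lazy_in = cat(rand(115, 24), rand(115, 1), dims=2)'
println(typeof(lazy_in))        # an Adjoint wrapper around the 115×25 Matrix, not a Matrix

# Materialize it as an ordinary 25×115 Matrix (one initial condition per column).
tmp_in = Matrix(lazy_in)        # alternatives: permutedims(...) or rand(25, 115)

# Parameter count of the Chain in the question (weights + biases of each Dense layer).
nparams = (25 * 16 + 16) + (16 * 16 + 16) + (16 * 25 + 25)
println(nparams + length(tmp_in))   # 1113 + 2875 = 3988, the vector length in the error

# The final-state indexing in `predict` yields one prediction per column:
u_final = rand(25, 115)         # hypothetical stand-in for result.u[end]
ŷ = u_final[[1], :]             # 1×115, the same shape as target_y = rand(1, 115)
```

If the error persists with a plain `Matrix`, the shapes above at least rule out a mismatch between `predict` and `target_y`.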

Just make your input 25xN
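The suggestion can be illustrated without any packages. This is a minimal sketch, assuming a hand-written fixed-step RK4 integrator and a hypothetical two-layer `tanh` network (`W1`, `b1`, `W2`, `b2`) in place of `NeuralODE` and `Tsit5`. Because the network acts on each column separately, integrating a 25×N matrix gives the same result as integrating each column on its own.

```julia
using Random

Random.seed!(0)
d, h, N = 25, 16, 8
# Hypothetical weights for a small tanh network standing in for `dudt`.
W1, b1 = 0.1 .* randn(h, d), zeros(h)
W2, b2 = 0.1 .* randn(d, h), zeros(d)

# Vector field applied to a whole matrix at once: each column is one initial condition.
f(u) = tanh.(W2 * tanh.(W1 * u .+ b1) .+ b2)

# Classic fourth-order Runge-Kutta with a fixed step; works for vectors and matrices.
function rk4(f, u0, tspan, nsteps)
    dt = (tspan[2] - tspan[1]) / nsteps
    u = copy(u0)
    for _ in 1:nsteps
        k1 = f(u)
        k2 = f(u .+ (dt / 2) .* k1)
        k3 = f(u .+ (dt / 2) .* k2)
        k4 = f(u .+ dt .* k3)
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return u
end

U0 = rand(d, N)                                    # one column per initial condition
U_batch = rk4(f, U0, (0.0, 1.0), 100)              # all columns in a single solve
U_loop = reduce(hcat, [rk4(f, U0[:, j], (0.0, 1.0), 100) for j in 1:N])  # one at a time

println(maximum(abs.(U_batch .- U_loop)))          # close to zero, up to rounding
```

With an adaptive solver the batched problem is a single ODE system, so one step size is chosen for all columns; batched and per-column solutions then agree only up to the solver tolerances rather than up to rounding.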

I usually write the training loop myself and feed in the initial conditions one at a time, for example: https://github.com/DENG-MIT/CRNN/blob/5994dc218e55749d4f54654447408bf3e96a6d89/case1/case1.jl#L193.

I'm not sure whether this is the best solution, but it works well for me.
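A minimal sketch of the batch-size-one loop described above, under loud assumptions: the neural ODE is replaced by a toy linear ODE du/dt = a·u with one scalar parameter `a`, whose exact solution u(T) = exp(a·T)·u0 lets the gradient be written by hand instead of going through DiffEqFlux and automatic differentiation. The names `predict`, `loss`, `grad`, and `train_one_at_a_time` are hypothetical, and this is not the code from the linked repository.

```julia
using Random

Random.seed!(1)
T = 1.0
a_true = -0.5
N = 115
U0 = rand(25, N) .+ 0.5                 # one initial condition per column
Y = exp(a_true * T) .* U0[1, :]         # target: first component at time T

# Toy model: exact solution of du/dt = a*u, first component at time T.
predict(a, u0) = exp(a * T) * u0[1]
loss(a, u0, y) = (predict(a, u0) - y)^2
# Hand-derived d(loss)/da = 2 * (prediction - y) * T * prediction
grad(a, u0, y) = 2 * (predict(a, u0) - y) * T * predict(a, u0)

function train_one_at_a_time(U0, Y; a = 0.0, lr = 0.1, epochs = 50)
    for _ in 1:epochs
        for j in shuffle(1:size(U0, 2))  # batch size of one: feed each initial condition in turn
            a -= lr * grad(a, U0[:, j], Y[j])
        end
    end
    return a
end

a_fit = train_one_at_a_time(U0, Y)
println(a_fit)                           # approaches a_true = -0.5
```

Each inner iteration is one optimizer step on a single initial condition; with a real neural ODE the hand-written `grad` would be replaced by an AD gradient of the loss for that one trajectory.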

Thanks for your comment! Would it be slower to run them one at a time rather than all in parallel in a single batch?

Good question. I usually use a batch size of one, so that only one initial condition is fed into the optimization at a time. If you need a larger batch size, parallelizing would be more efficient. As Chris mentioned, simply put all initial conditions into a matrix.
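For batch sizes between one and "everything", a common pattern is to draw random subsets of columns from the 25×N matrix of initial conditions. A minimal sketch with hypothetical random data (`U0`, `Y`, and `batchsize = 32` are placeholders); the batched `predict`/`loss` call is only indicated by a comment:

```julia
using Random

Random.seed!(2)
N, batchsize = 115, 32
U0 = rand(25, N)        # hypothetical initial conditions, one per column
Y = rand(1, N)          # hypothetical targets, one per column

# Split a random permutation of the column indices into mini-batches.
batches = collect(Iterators.partition(randperm(N), batchsize))

for idx in batches
    u0_batch = U0[:, idx]    # 25×B Matrix: a mini-batch of initial conditions
    y_batch = Y[:, idx]      # 1×B targets matching those columns
    # here one would evaluate the batched loss on (u0_batch, y_batch) and take an optimizer step
    println(size(u0_batch), " ", size(y_batch))
end
```

Flux's `DataLoader` provides similar shuffling and batching along the last array dimension; check the documentation of your Flux version for where it lives.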

That very much depends on the ODE solver. Batching as a matrix can be more efficient because it can use some pretty well-optimized (GPU) kernels. However, as mentioned in one of our preprints, some ODE solvers (namely the ones for stiff equations) exhibit worse-than-linear scaling: they factorize the Jacobian, which costs O(n^3) in the number of states n. Piling more initial conditions into one batch enlarges n, so compute costs can grow cubically, in which case you should split the batch.
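To make the scaling argument concrete, here is a back-of-the-envelope estimate under simplifying assumptions (not a measurement): a stiff solver factorizes a dense n×n Jacobian by LU at roughly (2/3)n³ flops. Batching N initial conditions of dimension d gives n = dN, while solving them separately costs N factorizations of size d, so the ratio is N². The values d = 25 and N = 115 are taken from the question.

```julia
# Approximate flop count of a dense LU factorization of an n×n matrix.
lu_flops(n) = 2 / 3 * n^3

d, N = 25, 115
batched  = lu_flops(d * N)    # one factorization of the (d*N)×(d*N) batched Jacobian
separate = N * lu_flops(d)    # N factorizations of d×d Jacobians, one per solve
println(batched / separate)   # ≈ N^2 = 13225, up to floating-point rounding
```

The batched Jacobian is actually block-diagonal (the columns evolve independently), so a solver that is told about that sparsity need not pay the full cubic price; the estimate only shows why naive dense batching can backfire.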
