[DiffEqFlux.jl] DimensionMismatch("array could not be broadcast to match destination")

I’m trying to train a NeuralODE model on the MNIST dataset. When I take the gradient of the loss function, I get a `DimensionMismatch` error.

```julia
using DiffEqFlux, OrdinaryDiffEq, Flux, NNlib, MLDataUtils, Printf
using Flux: logitcrossentropy
using Flux.Data: DataLoader
using MLDatasets
using CUDA
using Random: seed!
CUDA.allowscalar(false)

function loadmnist(batchsize = bs, train_split = 0.9)
    # Use MLDataUtils LabelEnc for natural onehot conversion
    onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw,
                                      LabelEnc.NativeLabels(collect(0:9)))
    # Load MNIST
    imgs, labels_raw = MNIST.traindata();
    # Process images into (H,W,C,BS) batches
    x_data = Float32.(reshape(imgs, size(imgs,1), size(imgs,2), 1, size(imgs,3)))
    y_data = onehot(labels_raw)
    (x_train, y_train), (x_test, y_test) = stratifiedobs((x_data, y_data),
                                                         p = train_split)
    return (
        # Use Flux's DataLoader to automatically minibatch and shuffle the data
        DataLoader(gpu.(collect.((x_train, y_train))); batchsize = batchsize,
                   shuffle = true),
        # Don't shuffle the test data
        DataLoader(gpu.(collect.((x_test, y_test))); batchsize = batchsize,
                   shuffle = false)
    )
end

# Main
const bs = 128
const train_split = 0.9
train_dataloader, test_dataloader = loadmnist(bs, train_split);
```


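```julia
# The NeuralODE saves only the final state (save_everystep = false,
# save_start = false), so drop the trailing time axis of the solution.
function DiffEqArray_to_Array(x)
    xarr = gpu(x)
    return reshape(xarr, size(xarr)[1:end-1])
end
```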
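`DiffEqArray_to_Array` relies on the solver settings: with `save_everystep = false` and `save_start = false`, only the final state is saved, so the converted solution should carry a trailing time axis of length 1. That layout is an assumption based on those keyword arguments, not something printed in the post. A plain-array sketch of what the `reshape` does, using a hypothetical stand-in array instead of a real ODE solution:

```julia
# Hypothetical stand-in for the converted ODE solution, laid out as
# (H, W, C, batch, time) with a single saved time point.
sol_arr = rand(Float32, 6, 6, 64, 1, 1)

# Same operation as DiffEqArray_to_Array: drop the trailing time axis.
dropped = reshape(sol_arr, size(sol_arr)[1:end-1])
println(size(dropped))  # (6, 6, 64, 1)

# dropdims is an equivalent, more explicit spelling for a length-1 axis.
@assert dropdims(sol_arr; dims = ndims(sol_arr)) == dropped

# If two time points were saved, the element count no longer matches and
# reshape throws a DimensionMismatch (with a different message than in the post).
two_steps = rand(Float32, 6, 6, 64, 1, 2)
failed = try
    reshape(two_steps, size(two_steps)[1:end-1])
    false
catch e
    e isa DimensionMismatch
end
@assert failed
```

Since the forward pass works, the single-time-point layout appears to hold in the forward direction; this sketch says nothing about the shapes seen during the backward pass.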

```julia
# Downsampling: 28×28×1 input -> 6×6×64 feature maps
down = Chain(Conv((3,3), 1=>64, pad=(0,0), relu),
             BatchNorm(64),
             Conv((4,4), 64=>64, stride=2, pad=1, relu),
             BatchNorm(64),
             Conv((4,4), 64=>64, stride=2, pad=1, relu)) |> gpu

# Right-hand side of the ODE: size-preserving 3×3 convolutions
convode_base = Chain(Conv((3,3), 64=>64, stride=1, pad=1, relu),
                     BatchNorm(64),
                     Conv((3,3), 64=>64, stride=1, pad=1, relu),
                     BatchNorm(64)) |> gpu

convode = NeuralODE(convode_base, (0.f0, 1.f0), Tsit5(),
                    save_everystep = false,
                    reltol = 1e-3, abstol = 1e-3,
                    save_start = false) |> gpu;

# Classifier on the flattened 6×6×64 = 2304 features
fc = Chain(Dense(2304,10)) |> gpu;
```
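The `Dense(2304, 10)` input size can be checked by hand. Assuming the standard convolution output-size formula without dilation, out = floor((n + 2p - k) / s) + 1, here is a small sketch tracing a 28×28 MNIST image through `down`. The helper `conv_out` is hypothetical and is not a Flux function.

```julia
# Hypothetical helper: output length of a convolution along one spatial axis,
# using the standard formula floor((n + 2p - k) / s) + 1 (no dilation).
conv_out(n; k, s = 1, p = 0) = fld(n + 2p - k, s) + 1

h = 28                                 # MNIST images are 28×28
h = conv_out(h; k = 3)                 # Conv((3,3), 1=>64, pad=(0,0))        -> 26
h = conv_out(h; k = 4, s = 2, p = 1)   # Conv((4,4), 64=>64, stride=2, pad=1) -> 13
h = conv_out(h; k = 4, s = 2, p = 1)   # Conv((4,4), 64=>64, stride=2, pad=1) -> 6

# The 3×3, stride-1, pad-1 convolutions inside the ODE keep the size unchanged,
# which a NeuralODE needs (its state must keep one shape over time).
@assert conv_out(h; k = 3, s = 1, p = 1) == h

features = h * h * 64                  # 64 channels after `down`
println(features)                      # 2304
```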

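```julia
model = Chain(down,
              convode,
              DiffEqArray_to_Array,
              BatchNorm(64),
              x -> reshape(x, :, size(x, 4)),
              fc,
              softmax) |> gpu;

# Note: logitcrossentropy applies logsoftmax internally and expects raw logits,
# so with the trailing softmax in `model` the scores are normalized twice.
loss(x, y) = logitcrossentropy(model(x), y)

# A single training sample (batch size 1) for reproducing the error
img, lab = train_dataloader.data[1][:, :, :, 1:1], train_dataloader.data[2][:, 1:1]
```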

When I run the following, I get the `DimensionMismatch` error:

```julia
g = gradient(() -> loss(img, lab), params(model))
```

I couldn’t find what is causing this error. The forward pass `model(img)` works fine, and the same error occurs with other optimizers such as ADAM, so I think the problem is in the gradient computation.

Full stack trace (Pastebin): `DimensionMismatch("array could not be broadcast to match destination")`
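For context, in Base Julia this exception is raised by in-place broadcasting (`dest .= src`, i.e. `broadcast!`) when `src` has an axis whose length is neither 1 nor equal to the matching axis of `dest`. A minimal sketch, independent of Flux and DiffEqFlux, that raises the same exception type; it does not identify which step of the backward pass performs the mismatched write.

```julia
dest = zeros(Float32, 3, 4)

# A length-1 axis is expanded by broadcasting, so this succeeds.
dest .= ones(Float32, 3, 1)
@assert all(dest .== 1)

# A length-5 axis cannot be matched against the length-4 destination axis.
err = try
    dest .= ones(Float32, 3, 5)
    nothing
catch e
    e
end
println(err isa DimensionMismatch)  # true
```

If the same reasoning applies here, the gradient code is writing into a preallocated buffer whose shape differs from the array being broadcast into it; checking the shapes flowing through each layer of `model` in the backward pass would be a way to narrow it down.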