NeuralODE with covariates

Hi. I am quite new to Julia and I am stuck on something. Maybe you can help me (sorry if I am trying something stupid).

I am working with Neural ODEs using DiffEqFlux. I want to model the differential equation with a neural network defined in Lux and then fit the Neural ODE to some time series data. I have managed to do so without much trouble using some of the tutorials found online.
However, I am now interested in including covariates in my ODE. Something like this:

Imagine that I want to model the ODE dx/dt = f(x | a), where a is a numerical covariate (suppose, for example, that I am modeling a time series for a person, where a is their weight).

I have a dataset with these columns: t, x, a (i.e. an independent time series for each value of a).
What I have been doing to use DiffEqFlux is a little hack: I define a system of ODEs like this:

function F!(du,u,p,t)
    x, a = u
    du[1] = dx = model(u, p)[1]
    du[2] = da = 0.0
end

Where model is a simple feed-forward neural network. By setting da/dt = 0, I force the covariate a to stay constant throughout the time series.
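To make the trick concrete without any packages, here is a minimal sketch of the same structure as F! above. The names toy_net, F_toy! and euler_solve are hypothetical: toy_net is a stand-in for the neural network, and the fixed-step Euler loop is only a placeholder for the real ODE solver, not what DiffEqFlux does internally.

```julia
# Hypothetical stand-in for the neural network: any function of (u, p) works here.
toy_net(u, p) = [p[1] * u[1] + p[2] * u[2]]

# Same structure as F!: the second state carries the covariate and never changes.
function F_toy!(du, u, p, t)
    du[1] = toy_net(u, p)[1]  # dx/dt comes from the model
    du[2] = 0.0               # da/dt = 0 keeps the covariate constant
    return nothing
end

# Fixed-step explicit Euler, used only to keep the sketch dependency-free.
function euler_solve(f!, u0, p, tspan, nsteps)
    u = float.(collect(u0))
    du = similar(u)
    t = float(tspan[1])
    dt = (tspan[2] - tspan[1]) / nsteps
    for _ in 1:nsteps
        f!(du, u, p, t)
        u .+= dt .* du
        t += dt
    end
    return u
end

# Initial state is [x0, a]; here a = 70.0 plays the role of the weight.
u_end = euler_solve(F_toy!, [1.0, 70.0], [-0.5, 0.01], (0.0, 1.0), 100)
```

After the solve, the second component of u_end is still the covariate value that was put into the initial state, while the first component has evolved according to the model.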

This approach works when I train with only one value of a. However, I need to train the model with different values of a (and thus different time series data), and that is where I am having some problems with Zygote.
What I have tried is to define my loss function like this:

function loss_function_multi_patients(p, a_list)
    errors = Zygote.Buffer(zeros(length(a_list)))
    for (ix, a) in enumerate(a_list)
        Zygote.ignore() do
            u0, t, x = get_data(a) # This function returns the timeseries (x vs t for some value a)
            prob = ODEProblem(F!, u0, (minimum(t), maximum(t)), p)
        end

        pred = predict(p, t)'[:,1] # This function solves the ODE (forward pass)
        errors[ix] = mse(pred,x) # For each value of a, I calculate the MSE error
    end
    
    return sum(errors)/length(a_list) # I return the average MSE for all the values of a
end

I know that taking the average of MSEs is not the best idea, but for this specific case I need it this way.
For training I am using the Optimization and Zygote packages, as recommended in the DiffEqFlux documentation.
Now, if I train the model using this loss function for only one value of a (passing an a_list with a single value), it works perfectly fine. The problem is when I try to train with several values of a. If I do so, I get a Zygote error (AssertionError: x === y) and I have not been able to find out the reason.

I think that my problem might be in this line, as errors is mutated in place and Zygote does not support array mutation. I have already tried using Zygote.Buffer without any luck.

errors[ix] = mse(pred,x)

Any ideas to help me? Thanks!

Just handle the covariate function before the ODE and let autodiff take care of it?
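One way to read that suggestion, as a rough sketch rather than a tested fix: build the right-hand side per covariate value with a closure, so a never has to be a state, and compute the loss with sum instead of writing into an array. Everything named here (toy_net, make_rhs, euler_path, loss_multi, the layout of data) is hypothetical, and the Euler loop is only a package-free stand-in for the real solve. It has not been run through Zygote; it just avoids the in-place writes the question points at.

```julia
# Hypothetical stand-in for the neural network; it takes x and the covariate a.
toy_net(x, a, p) = p[1] * x + p[2] * a

# Build the right-hand side for one covariate value; a is captured, not a state.
make_rhs(a) = (x, p, t) -> toy_net(x, a, p)

# Out-of-place fixed-step Euler, a stand-in for the real ODE solve.
# The trajectory is grown with vcat, so no existing array is written into.
function euler_path(rhs, x0, p, ts)
    xs = [x0]
    for i in 2:length(ts)
        x_prev = xs[end]
        xs = vcat(xs, x_prev + (ts[i] - ts[i-1]) * rhs(x_prev, p, ts[i-1]))
    end
    return xs
end

mse(pred, obs) = sum(abs2, pred .- obs) / length(obs)

# Average of the per-series MSEs, computed without an errors array.
function loss_multi(p, data)
    total = sum(d -> mse(euler_path(make_rhs(d.a), d.x0, p, d.ts), d.obs), data)
    return total / length(data)
end

# Synthetic data for two covariate values, generated from known parameters.
p_true = [-0.5, 0.01]
ts = collect(0.0:0.1:1.0)
data = [(a = a, x0 = 1.0, ts = ts, obs = euler_path(make_rhs(a), 1.0, p_true, ts))
        for a in (60.0, 80.0)]

loss_at_truth = loss_multi(p_true, data)
loss_elsewhere = loss_multi([0.0, 0.0], data)
```

With this layout, adding more values of a only means adding more entries to data; the loss function itself does not change.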