I'm trying to train two neural ODEs simultaneously; they interact only through the definition of the loss function. My training data is a 2D array. The goal is for one ODE to describe a "path" through the 2D data array (from "left" to "right") and for the other ODE to describe the evolution of the data along that path.
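Written out, the two ODEs in the MWE below are as follows. The symbols are introduced here for illustration only: $N_\theta$ is the network `L` with weights `p[1:end-1]`, $a$ is the scalar `p[end]`, $x_j$ is `xSim[j]`, and $D(x, t)$ is the data viewed as a function of the two grid coordinates, so that $D(x_j, 0)$ is `sol[j,1]`:

$$
\frac{dx}{dt} = N_\theta(x), \quad x(0) = x_j, \qquad \frac{du}{dt} = a\,u, \quad u(0) = D(x_j, 0).
$$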
My problem is that the loss function compares interpolated data (a 2D gridded interpolation evaluated along the path produced by the first ODE) with the output of the second ODE. This triggers a familiar error, pointing back to the interpolation step inside the loss function:
```
Mutating arrays is not supported
```
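For reference, the loss computed in `lossFun` below is, with $x(t)$ the path from the first ODE, $u(t)$ the solution of the second ODE, $D$ the multilinear interpolation of the data on `grid`, and $t_k$ the six entries of `tSteps` (notation introduced here for illustration):

$$
\mathcal{L}(\theta, a) = \sum_{k=1}^{6} \bigl( u(t_k) - D(x(t_k), t_k) \bigr)^2, \qquad t_k \in \{0, 1, 2, 3, 4, 5\}.
$$

Since $u$ depends only on $a$, the network weights $\theta$ enter the loss only through $D(x(t_k), t_k)$, so the gradient with respect to $\theta$ has to be propagated through the interpolation.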
As a workaround I tried Zygote's `Buffer`, but that does not seem to work either.
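A minimal, self-contained illustration of what `Buffer` does cover (the function names are illustrative only): it makes writes done through the buffer itself differentiable. It does not change what happens inside other functions, so if `interpolate` mutates arrays internally (an assumption worth checking in the GridInterpolations source), wrapping its results in a `Buffer` would not remove the error.

```julia
using Zygote
using Zygote: Buffer

# Writes into an ordinary array: Zygote refuses to differentiate setindex!
function doubled_sum_mutating(x)
    y = zeros(length(x))
    for i in eachindex(x)
        y[i] = 2x[i]
    end
    return sum(y)
end

# The same computation through a Buffer, which Zygote does support
function doubled_sum_buffer(x)
    buf = Buffer(x)
    for i in eachindex(x)
        buf[i] = 2x[i]
    end
    return sum(copy(buf))  # copy turns the Buffer back into an ordinary array
end
```

`Zygote.gradient(doubled_sum_buffer, [1.0, 2.0])` returns `([2.0, 2.0],)`, while the same call on `doubled_sum_mutating` raises the "Mutating arrays is not supported" error.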
My MWE is as follows (dummy data included):
```julia
using DiffEqFlux, DifferentialEquations, Optim, Flux, GridInterpolations
using Zygote: Buffer

## Let's get the input data:
sol = rand(101, 101)
xSim = Array(range(0.0, 10.0, length=101))
tSim = Array(range(0.0, 10.0, length=101))
grid = RectangleGrid(xSim, tSim)
gridData = vec(sol)  # flatten the data into a vector (column-major order)

## Define ODEs, NN, and set up solutions
tSpan = (0.0, 5.0)
tSteps = Array(range(tSpan[1], tSpan[2], length=6))
L = FastChain(FastDense(1, 3, sigmoid), FastDense(3, 1))
p = [initial_params(L); 0.0]  # NN weights, plus one scalar rate used by dudt

# ODE describing the evolution of the path
function dxdt(u, p, t)
    z = L(u, p[1:end-1])
    return [z[1]]
end

# ODE describing the (linear) dynamics along the path
function dudt(u, p, t)
    return p[end] * u
end

## Define function to interpolate data according to path:
function getTrainingTrajectory(s)
    buf = Buffer(s)
    for j in 1:length(s)
        buf[j] = interpolate(grid, gridData, [s[j], tSteps[j]])
    end
    return copy(buf)
end

## Define loss function
function lossFun(p)
    # Generic IC index:
    j = 40
    # Get the path:
    evolvePath = ODEProblem(dxdt, [xSim[j]], tSpan, p)
    s = Array(solve(evolvePath, Tsit5(), saveat=tSteps))
    # Evolve data IC
    evolveU = ODEProblem(dudt, [sol[j, 1]], tSpan, p)
    u = Array(solve(evolveU, Tsit5(), saveat=tSteps))
    # Get the data along the new path:
    trainingTrajectory = getTrainingTrajectory(s)
    loss = sum(abs2, u .- trainingTrajectory)
    return loss, u, s
end

callback = function (p, loss, u, s; doplot = false)
    display(loss)
    return false
end

## Train:
result_neuralode = DiffEqFlux.sciml_train(lossFun,
                                          p,
                                          ADAM(0.05),
                                          cb = callback,
                                          maxiters = 100)
```
Beyond my naive attempt to build the interpolated path with a buffered array, is there another way I should be formulating this system or its loss function?
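One possible direction, sketched under stated assumptions: if the error originates inside GridInterpolations' `interpolate` (which `Buffer` cannot reach), the interpolation can be written by hand for this uniform grid without any array mutation. `bilinear` and `trajectory` are hypothetical helpers, not part of any package. They assume `data[i, k]` is the value at `(xs[i], ts[k])` with uniformly spaced `xs` and `ts`, and they use `ChainRulesCore.ignore_derivatives` to keep the piecewise-constant cell indices out of the gradient.

```julia
using Zygote
using ChainRulesCore: ignore_derivatives

# Hypothetical helper: bilinear interpolation on a uniform grid, no array mutation.
# Assumes data[i, k] is the value at (xs[i], ts[k]) and xs, ts are evenly spaced.
function bilinear(data, xs, ts, x, t)
    dx = xs[2] - xs[1]
    dt = ts[2] - ts[1]
    # Lower-left cell indices; they are piecewise constant in (x, t), so they are
    # excluded from AD. Clamping keeps i+1 and k+1 in bounds (linear extrapolation
    # outside the grid).
    i = ignore_derivatives(() -> clamp(floor(Int, (x - xs[1]) / dx) + 1, 1, length(xs) - 1))
    k = ignore_derivatives(() -> clamp(floor(Int, (t - ts[1]) / dt) + 1, 1, length(ts) - 1))
    # Fractional position inside the cell; the gradient flows through these
    α = (x - xs[i]) / dx
    β = (t - ts[k]) / dt
    return (1 - α) * (1 - β) * data[i, k] + α * (1 - β) * data[i + 1, k] +
           (1 - α) * β * data[i, k + 1] + α * β * data[i + 1, k + 1]
end

# Hypothetical non-mutating counterpart of getTrainingTrajectory: map builds a new array
trajectory(s, tSteps, data, xs, ts) =
    map((x, t) -> bilinear(data, xs, ts, x, t), vec(s), tSteps)
```

In `lossFun`, `getTrainingTrajectory(s)` could then be replaced by `trajectory(s, tSteps, sol, xSim, tSim)`. This is a sketch checked on a toy gradient, not against the full `sciml_train` setup above.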