GridInterpolations.jl is not differentiable

I’m attempting to train two neural ODEs simultaneously that are coupled through the loss function. My training data is a 2D array. The goal is for one ODE to describe a “path” through the 2D data array (from “left” to “right”) and for the other ODE to describe the evolution of the data along that path.

My problem is that the loss function compares interpolated data (a 2D gridded interpolation evaluated at points given by the first ODE) with the output of the second ODE. Evaluating the loss gives the familiar error, pointing back to the interpolation:

```
Mutating arrays is not supported
```

As a workaround I tried Zygote’s `Buffer`, but that does not seem to work for me either.
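A minimal sketch of why `Buffer` does not help here, assuming Zygote is installed (the helper names are hypothetical). `Buffer` makes writes into the buffer itself differentiable, but if the value written comes from a function that mutates its own scratch arrays, the gradient still fails. That is presumably what happens inside GridInterpolations.jl's `interpolate`, since the error points there.

```julia
using Zygote
using Zygote: Buffer

# Works: the only mutation is into the Buffer, and x[i]^2 is differentiable.
function squares_with_buffer(x)
    buf = Buffer(x)
    for i in eachindex(x)
        buf[i] = x[i]^2
    end
    return sum(copy(buf))   # copy turns the Buffer back into an ordinary array
end

# Hypothetical stand-in for a library routine that writes into a scratch array.
function mutating_square(v)
    tmp = zeros(1)
    tmp[1] = v^2            # setindex! on a plain Array: not differentiable by Zygote
    return tmp[1]
end

# Fails: the Buffer wraps the outer loop, but the mutation happens inside the callee.
function squares_with_inner_mutation(x)
    buf = Buffer(x)
    for i in eachindex(x)
        buf[i] = mutating_square(x[i])
    end
    return sum(copy(buf))
end

x = [1.0, 2.0, 3.0]
g_ok = Zygote.gradient(squares_with_buffer, x)[1]

# Capture the error instead of letting it propagate.
err = try
    Zygote.gradient(squares_with_inner_mutation, x)
    nothing
catch e
    e
end
```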

My MWE is as follows (dummy data included):

```julia
using DiffEqFlux, DifferentialEquations, Optim, Flux, GridInterpolations
using Zygote: Buffer

## Input data (dummy): sol[i, k] is the value at (xSim[i], tSim[k])
sol = rand(101, 101)

xSim = collect(range(0.0, 10.0, length=101))
tSim = collect(range(0.0, 10.0, length=101))

grid = RectangleGrid(xSim, tSim)
gridData = vec(sol)   # column-major flattening; same result as [(sol...)...]

## Define ODEs, NN, and set up solutions
tSpan = (0.0, 5.0)
tSteps = collect(range(tSpan[1], tSpan[2], length=6))

L = FastChain(FastDense(1, 3, sigmoid), FastDense(3, 1))
p = [initial_params(L); 0.0]   # NN weights, then one scalar rate used by dudt

# ODE describing the evolution of the path
function dxdt(u, p, t)
    z = L(u, p[1:end-1])
    return [z[1]]
end

# ODE describing the (linear) dynamics along the path
function dudt(u, p, t)
    return p[end] * u
end

## Define function to interpolate data according to path:
function getTrainingTrajectory(s)
    buf = Buffer(s)
    for j in 1:length(s)
        # The "Mutating arrays is not supported" error points at this call
        buf[j] = interpolate(grid, gridData, [s[j], tSteps[j]])
    end
    return copy(buf)
end

## Define loss function
function lossFun(p)
    # Generic IC index:
    j = 40

    # Get the path:
    evolvePath = ODEProblem(dxdt, [xSim[j]], tSpan, p)
    s = Array(solve(evolvePath, Tsit5(), saveat=tSteps))

    # Evolve data IC
    evolveU = ODEProblem(dudt, [sol[j, 1]], tSpan, p)
    u = Array(solve(evolveU, Tsit5(), saveat=tSteps))

    # Get the data along the new path:
    trainingTrajectory = getTrainingTrajectory(s)

    loss = sum(abs2, u .- trainingTrajectory)
    return loss, u, s
end

callback = function (p, loss, u, s; doplot = false)
    display(loss)
    return false
end

## Train:
result_neuralode = DiffEqFlux.sciml_train(lossFun,
                                          p,
                                          ADAM(0.05),
                                          cb = callback,
                                          maxiters = 100)
```
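The flattening used for `gridData` can be checked in isolation. This sketch uses a small hypothetical array in place of `sol` and shows that `[(sol...)...]` and `vec(sol)` give the same column-major ordering (first index varies fastest); `linear_index` is a hypothetical helper giving the position of `sol[i, k]`. I am assuming, not quoting the docs, that GridInterpolations.jl enumerates the points of `RectangleGrid(xSim, tSim)` with the first coordinate varying fastest, which is what makes this ordering line up with the grid. `vec` also avoids splatting a 10,201-element tuple, which can be slow to compile.

```julia
# Small stand-in for `sol`: sol[i, k] plays the role of the value at (xSim[i], tSim[k]).
sol = [10i + k for i in 1:3, k in 1:4]

# The splat from the MWE and `vec` give the same column-major flattening.
flat_splat = [(sol...)...]
flat_vec = vec(sol)

# Hypothetical helper: in column-major order the first index varies fastest,
# so sol[i, k] sits at position i + (k - 1) * size(sol, 1) of the flattened vector.
linear_index(i, k, nx) = i + (k - 1) * nx

println(flat_vec[1:4])                                # prints [11, 21, 31, 12]
println(flat_vec[linear_index(2, 3, size(sol, 1))])   # prints 23
```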

Apart from my naive attempt to build the interpolated path with a buffered array, is there another way I should be formulating this system or loss function?
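One way to sidestep the library, sketched here under the assumption that plain bilinear interpolation on the `(xSim, tSim)` grid is acceptable, is to write the interpolation without any array mutation and build the trajectory with `map` instead of a `Buffer`. `bilinear` and `trajectory_along_path` are hypothetical names, not GridInterpolations.jl API. Since nothing is mutated, Zygote should be able to differentiate through them with respect to both the path `s` and the data.

```julia
# Hypothetical helper: bilinear interpolation of vals[i, k], the value at
# (xs[i], ts[k]), on a rectangular grid with sorted axes xs and ts.
# No arrays are mutated, so reverse-mode AD can trace through it.
function bilinear(xs, ts, vals, x, t)
    # Index of the cell containing (x, t); clamped so that points on or beyond
    # the last grid line use the border cell (points outside the grid are then
    # linearly extrapolated, which may differ from GridInterpolations.jl).
    i = clamp(searchsortedlast(xs, x), 1, length(xs) - 1)
    k = clamp(searchsortedlast(ts, t), 1, length(ts) - 1)
    # Fractional position of (x, t) inside that cell.
    wx = (x - xs[i]) / (xs[i+1] - xs[i])
    wt = (t - ts[k]) / (ts[k+1] - ts[k])
    return (1 - wx) * (1 - wt) * vals[i, k] +
           wx * (1 - wt) * vals[i+1, k] +
           (1 - wx) * wt * vals[i, k+1] +
           wx * wt * vals[i+1, k+1]
end

# Hypothetical replacement for getTrainingTrajectory: `map` builds the values
# without writing into a preallocated array, and `reshape` restores the 1×N
# shape of the ODE solution `s`.
trajectory_along_path(s, xs, ts, vals, tsteps) =
    reshape(map((sj, tj) -> bilinear(xs, ts, vals, sj, tj), vec(s), tsteps), size(s))

# Usage with a linear test field, which bilinear interpolation reproduces exactly.
xs = collect(range(0.0, 10.0, length=11))
ts = collect(range(0.0, 10.0, length=11))
vals = [2x + 3t for x in xs, t in ts]
s = [1.5 2.5 3.5]
traj = trajectory_along_path(s, xs, ts, vals, [0.0, 1.0, 2.0])
```

In the MWE, `getTrainingTrajectory(s)` would become `trajectory_along_path(s, xSim, tSim, sol, tSteps)`. Reshaping to `size(s)` keeps `u .- trainingTrajectory` a 1×6 difference instead of broadcasting to 6×6. The interpolant is piecewise linear, so its gradient with respect to the path jumps at cell boundaries.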

I renamed the topic to make it clearer. The issue is simply that GridInterpolations.jl is not differentiable. An issue should be opened on that repo so that the package can fix its primitives.
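Until the package provides differentiable primitives, a user-side workaround is to wrap the non-differentiable call and supply its derivative through a ChainRulesCore `rrule`, which Zygote picks up. This sketch is hypothetical (`OpaqueInterp` is not part of any library, and ChainRulesCore is assumed to be installed) and swaps in central finite differences for the missing derivative, so the step `h` is an assumption that should be small relative to the grid spacing.

```julia
using ChainRulesCore

# Hypothetical wrapper around a 2D interpolation routine f(x, t) that Zygote
# cannot differentiate (e.g. because it mutates arrays internally).
struct OpaqueInterp{F}
    f::F
    h::Float64   # finite-difference step (assumption: small relative to the grid spacing)
end

(g::OpaqueInterp)(x, t) = g.f(x, t)

function ChainRulesCore.rrule(g::OpaqueInterp, x, t)
    # The primal is computed directly, outside of AD, so internal mutation is harmless.
    y = g.f(x, t)
    function opaque_interp_pullback(dy)
        dy = unthunk(dy)
        # Central differences stand in for the derivative the library does not expose.
        dfdx = (g.f(x + g.h, t) - g.f(x - g.h, t)) / (2 * g.h)
        dfdt = (g.f(x, t + g.h) - g.f(x, t - g.h)) / (2 * g.h)
        # One tangent per argument: the wrapper itself, x, and t.
        return NoTangent(), dy * dfdx, dy * dfdt
    end
    return y, opaque_interp_pullback
end

# Check with a function whose partial derivatives are known: df/dx = 2, df/dt = 3.
interp = OpaqueInterp((x, t) -> 2x + 3t, 1e-4)
y, pullback = ChainRulesCore.rrule(interp, 1.0, 2.0)
_, dx, dt = pullback(1.0)
```

For the MWE, the wrapped function would be something like `(x, t) -> interpolate(grid, gridData, [x, t])`, and the trajectory can then be built with `map` rather than a `Buffer`. Near cell boundaries the finite difference averages the slopes of the neighbouring cells.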
