Thanks to both of you for your help! Two comments regarding the discussion:

- When trying to use Zygote, I get the following error on the MWE:

```
f = LLVM.Function("julia__mapreducedim__5966")
(gty, inst, v) = (LLVM.IntegerType[LLVM.IntegerType(i64)], LLVM.PHIInst(%129 = phi double addrspace(13)* [ poison, %L89.L110.loopexit_crit_edge.us125.unr-lcssa.us.1.L110.us129.us.1_crit_edge ], [ %113, %L89.L110.loopexit_crit_edge.us125.unr-lcssa.us.1.thread ], [ %119, %L93.us121.epil.us.1 ]), LLVM.PoisonValue(0x000000006a63fb20))
f = LLVM.Function("julia__mapreducedim__6969")
(gty, inst, v) = (LLVM.IntegerType[LLVM.IntegerType(i64)], LLVM.PHIInst(%129 = phi double addrspace(13)* [ poison, %L89.L110.loopexit_crit_edge.us125.unr-lcssa.us.1.L110.us129.us.1_crit_edge ], [ %113, %L89.L110.loopexit_crit_edge.us125.unr-lcssa.us.1.thread ], [ %119, %L93.us121.epil.us.1 ]), LLVM.PoisonValue(0x00000000086ed390))
┌ Warning: EnzymeVJP tried and failed in the automated AD choice algorithm with the following error. (To turn off this printing, add `verbose = false` to the `solve` call)
└ @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/Rm4xX/src/concrete_solve.jl:23
AssertionError: false
ERROR: UndefRefError: access to undefined reference
```

followed by a long stack trace.

- Unfortunately, my model is not purely linear, but the linear part is what takes up most of the memory, since it has thousands of parameters.
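On the memory side: the parameter matrix itself cannot be avoided, but the coupling term in the MWE below, `sum(p.connectivityMatrix*diffs, dims=2)` with `diffs[i,j] = u[j] - u[i]`, does not actually require building the 100×100 `diffs` matrix or the matrix-matrix product. Expanding the sum gives `sum_k (W*diffs)[i,k] = (sum_j W[i,j]) * sum(u) - n * (W*u)[i]`, i.e. only matrix-vector work. A minimal sketch in plain Julia (the function names are made up for illustration) that checks this identity on random data:

```julia
# Coupling term exactly as written in the MWE: builds two n×n intermediates
function coupling_dense(W, u)
    n = length(u)
    diffs = [u[j] - u[i] for i in 1:n, j in 1:n]   # diffs[i, j] = u[j] - u[i]
    return vec(sum(W * diffs, dims=2))
end

# Same quantity with matrix-vector products only:
# sum_k (W*diffs)[i,k] = (sum_j W[i,j]) * sum(u) - n * (W*u)[i]
function coupling_matvec(W, u)
    n = length(u)
    return vec(sum(W, dims=2)) .* sum(u) .- n .* (W * u)
end

W = rand(100, 100)
u = rand(100)
@assert coupling_dense(W, u) ≈ coupling_matvec(W, u)
```

Note that `*` in the MWE is a matrix product, so this reproduces what the MWE computes. If the intended coupling was instead the elementwise `sum_j W[i,j] * (u[j] - u[i])`, that would be `sum(W .* diffs, dims=2)`, which similarly reduces to `W*u .- vec(sum(W, dims=2)) .* u`.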

In addition, a small neural network makes predictions on subsets of the data (consecutive 4-element blocks of the state). A more representative MWE of what I am trying to do would perhaps be:

```julia
using Zygote, DifferentialEquations, SciMLSensitivity, Lux, Random, ComponentArrays

rng = Random.default_rng()
trainingData = rand(100, 4)   # 100 state variables observed at t = 1, 2, 3, 4
p0 = rand(100, 100)           # connectivity matrix: the large linear part
chain = Lux.Chain(Lux.Dense(4, 5), Lux.Dense(5, 4))
ps, st = Lux.setup(rng, chain)
p = ComponentVector(model_params = ps, connectivityMatrix = p0)

function nn!(du, u, p, t)
    # Apply the small network to each of the 25 consecutive 4-element blocks of u
    nns = reduce(vcat, [first(chain(u[((i-1)*4+1):((i-1)*4+4)], p.model_params, st)) for i in 1:25])
    # Pairwise differences: diffs[i, j] = u[j] - u[i]
    diffs = [u[j] - u[i] for i in 1:100, j in 1:100]
    # Write into du in place (`du = ...` would only rebind the local variable)
    du .= nns .+ vec(sum(p.connectivityMatrix * diffs, dims=2))
    return nothing
end

function predict(p)
    prob = ODEProblem(nn!, trainingData[:, 1], (1.0, 4.0), p)
    Array(solve(prob, saveat=1.0))   # 100×4: saved at t = 1, 2, 3, 4
end

function loss(p)
    pred = predict(p)
    sum(abs2, pred .- trainingData)
end

@time Zygote.gradient(loss, p);
```
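Since the network is applied to 25 consecutive 4-element blocks of `u`, the comprehension plus `reduce(vcat, ...)` in `nn!` could probably be replaced by a single batched call: column `i` of `reshape(u, 4, 25)` is exactly `u[((i-1)*4+1):((i-1)*4+4)]`, and Lux `Dense` layers accept a features × batch matrix. The sketch below only checks the reshape ordering in plain Julia; `mlp` is a hypothetical two-layer network standing in for the Lux chain, not the Lux API:

```julia
# Hypothetical stand-in for the Lux chain: a 4 -> 5 -> 4 dense network
W1, b1 = randn(5, 4), randn(5)
W2, b2 = randn(4, 5), randn(4)
mlp(x) = W2 * tanh.(W1 * x .+ b1) .+ b2   # accepts a 4-vector or a 4×batch matrix

u = rand(100)

# Loop form, as in the MWE: one call per 4-element block, then concatenate
looped = reduce(vcat, [mlp(u[((i-1)*4+1):((i-1)*4+4)]) for i in 1:25])

# Batched form: one call on a 4×25 matrix, flattened column by column
batched = vec(mlp(reshape(u, 4, 25)))

@assert looped ≈ batched
```

In the MWE this would presumably become `nns = vec(first(chain(reshape(u, 4, 25), p.model_params, st)))`, which avoids 25 separate small calls and allocations; how it interacts with the adjoint has not been tested here.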