How to take the gradient of an ODE system with respect to many data points?

Hello!

I have an ODE system (dx = f(x, ξ, k)), and I would like to find the optimal values of ξ, i.e.:

\mathrm{argmin}_\xi \left( f(x, \xi, k) - dx \right)^2

where \xi \in R^{n\xi} are the parameters I want to tune, x \in R^{v \times d} are data points for the system state (v is the number of state variables and d is the number of observations), dx \in R^{v \times d} are observations of the derivative of the system state, and k \in R^{nk \times d} are values of parameters that don’t need to be tuned.
I’ve implemented the equation above thusly:

using Statistics
using Zygote
using ProfileCanvas
using BenchmarkTools

loss(x, ξ, k, dx) = mean( @inbounds @. @views (-dx[3,:] + k[1,:]*ξ[31] + k[2,:]*ξ[32] + k[3,:]*ξ[33] + (k[1,:]*ξ[10] + k[2,:]*ξ[11] + k[3,:]*ξ[12])*x[1,:] + (k[1,:]*ξ[25] + k[2,:]*ξ[26] + k[3,:]*ξ[27])*x[2,:] + (-k[1,:]*ξ[34] - k[2,:]*ξ[35] - k[3,:]*ξ[36])*x[3,:] + (-k[1,:]*ξ[37] - k[2,:]*ξ[38] - k[3,:]*ξ[39])*x[3,:] + (-k[1,:]*ξ[40] - k[2,:]*ξ[41] - k[3,:]*ξ[42])*x[3,:] + (-k[1,:]*ξ[43] - k[2,:]*ξ[44] - k[3,:]*ξ[45])*x[3,:] + (k[1,:]*ξ[58] + k[2,:]*ξ[59] + k[3,:]*ξ[60])*x[4,:])^2.0 + (-dx[1,:] + k[1,:]*ξ[1] + k[2,:]*ξ[2] + k[3,:]*ξ[3] + (-k[1,:]*ξ[10] - k[2,:]*ξ[11] - k[3,:]*ξ[12])*x[1,:] + (-k[1,:]*ξ[13] - k[2,:]*ξ[14] - k[3,:]*ξ[15])*x[1,:] + (k[1,:]*ξ[22] + k[2,:]*ξ[23] + k[3,:]*ξ[24])*x[2,:] + (k[1,:]*ξ[37] + k[2,:]*ξ[38] + k[3,:]*ξ[39])*x[3,:] + (-k[1,:]*ξ[4] - k[2,:]*ξ[5] - k[3,:]*ξ[6])*x[1,:] + (k[1,:]*ξ[52] + k[2,:]*ξ[53] + k[3,:]*ξ[54])*x[4,:] + (-k[1,:]*ξ[7] - k[2,:]*ξ[8] - k[3,:]*ξ[9])*x[1,:])^2.0 + (-dx[4,:] + k[1,:]*ξ[46] + k[2,:]*ξ[47] + k[3,:]*ξ[48] + (k[1,:]*ξ[13] + k[2,:]*ξ[14] + k[3,:]*ξ[15])*x[1,:] + (k[1,:]*ξ[28] + k[2,:]*ξ[29] + k[3,:]*ξ[30])*x[2,:] + (k[1,:]*ξ[43] + k[2,:]*ξ[44] + k[3,:]*ξ[45])*x[3,:] + (-k[1,:]*ξ[49] - k[2,:]*ξ[50] - k[3,:]*ξ[51])*x[4,:] + (-k[1,:]*ξ[52] - k[2,:]*ξ[53] - k[3,:]*ξ[54])*x[4,:] + (-k[1,:]*ξ[55] - k[2,:]*ξ[56] - k[3,:]*ξ[57])*x[4,:] + (-k[1,:]*ξ[58] - k[2,:]*ξ[59] - k[3,:]*ξ[60])*x[4,:])^2.0 + (-dx[2,:] + k[1,:]*ξ[16] + k[2,:]*ξ[17] + k[3,:]*ξ[18] + (-k[1,:]*ξ[19] - k[2,:]*ξ[20] - k[3,:]*ξ[21])*x[2,:] + (-k[1,:]*ξ[22] - k[2,:]*ξ[23] - k[3,:]*ξ[24])*x[2,:] + (-k[1,:]*ξ[25] - k[2,:]*ξ[26] - k[3,:]*ξ[27])*x[2,:] + (-k[1,:]*ξ[28] - k[2,:]*ξ[29] - k[3,:]*ξ[30])*x[2,:] + (k[1,:]*ξ[40] + k[2,:]*ξ[41] + k[3,:]*ξ[42])*x[3,:] + (k[1,:]*ξ[55] + k[2,:]*ξ[56] + k[3,:]*ξ[57])*x[4,:] + (k[1,:]*ξ[7] + k[2,:]*ξ[8] + k[3,:]*ξ[9])*x[1,:])^2.0)


It looks like a mess, but basically it’s just (f(x_1, ξ, k) .- dx_1).^2 .+ (f(x_2, ξ, k) .- dx_2).^2 .+ ..., with a lot of slice and broadcast operations. (Basically it’s a version of SINDy where the equation term library coefficients are reused in a specific pattern, which gives specific theoretical guarantees.)
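Because every residual in a loss of this shape is linear in ξ, the objective is a linear least-squares problem, so its gradient has a closed form. The following is a minimal sketch on a made-up two-state toy system, not the 60-parameter system above. The names `toy_design`, `toy_loss`, and `toy_grad` are hypothetical, and only Julia's standard library is assumed.

```julia
using LinearAlgebra, Random

# Hypothetical toy system, linear in ξ (NOT the system from the post):
#   dx1 = (k1*ξ1 + k2*ξ2) - (k1*ξ3 + k2*ξ4) * x1
#   dx2 = (k1*ξ3 + k2*ξ4) * x1 - (k1*ξ5 + k2*ξ6) * x2
# ξ3 and ξ4 are deliberately reused in both equations.
function toy_design(x, k)
    d = size(x, 2)
    A = zeros(2d, 6)                 # one row per (equation, observation)
    for j in 1:d
        r1, r2 = 2j - 1, 2j          # matches the column-major order of vec(dx)
        A[r1, 1:2] .= k[:, j]
        A[r1, 3:4] .= -k[:, j] .* x[1, j]
        A[r2, 3:4] .= k[:, j] .* x[1, j]
        A[r2, 5:6] .= -k[:, j] .* x[2, j]
    end
    return A
end

# Squared residuals summed over equations and averaged over d observations
toy_loss(A, b, ξ, d) = sum(abs2, A * ξ - b) / d

# Closed-form gradient of the loss above: (2/d) * A' * (A*ξ - b)
toy_grad(A, b, ξ, d) = (2 / d) .* (A' * (A * ξ - b))

Random.seed!(1)
d = 10
x, dx, k, ξ = rand(2, d), rand(2, d), rand(2, d), rand(6)
A, b = toy_design(x, k), vec(dx)
g = toy_grad(A, b, ξ, d)

# Cross-check against central finite differences
h = 1e-6
g_fd = [(toy_loss(A, b, ξ + h * e, d) - toy_loss(A, b, ξ - h * e, d)) / (2h)
        for e in eachcol(Matrix(1.0I, 6, 6))]
```

If the real system can be written this way, the design matrix only needs to be built once per data set, after which each gradient evaluation is two matrix-vector products; whether that is practical for the full system is an assumption that would need checking.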

If I run the function forward it seems fine:

v = 4
d = 10
nk = 3
nξ = 60
x = rand(v, d)
dx = rand(v, d)
ξ = rand(nξ)
k = rand(nk, d)

@btime loss(x, ξ, k, dx)
# 2.894 μs (2 allocations: 160 bytes)


However, when I take the gradient with respect to \xi, it is roughly four orders of magnitude slower and allocates heavily:

f(ξ) = loss(x, ξ, k, dx)
@btime f'(ξ)
# 42.575 ms (1227548 allocations: 51.24 MiB)


For a larger version of the same system, running the gradient function gets exponentially slower and/or gives a stack overflow error, depending on the specific implementation (I’ve tried a lot of different implementations).

So I basically have two questions, and an answer to either one of them would be great:

1. For the function above specifically, is there a way to reformat it to make it go faster with Zygote, or an alternative AD system? (ForwardDiff works fine for small problems but when the size of \xi gets large it is too slow.)
2. More generally, is there a best practice for taking an ODE system and calculating the average of its gradient with respect to parameters across many different values of the system state vector? (I am aware of DataDrivenDiffEq.jl but I don’t think it would work in this case owing to the reuse of parameters in multiple equation terms.)

Thanks for any suggestions you can offer!

Can you reformat the function so that we can see it without scrolling, and understand its structure a little bit?
For a quick way to compare several autodiff backends, check out DifferentiationInterface.jl

Tape-compiled ReverseDiff is very likely to be your best bet here, which I think can be accessed through DifferentiationInterface.jl
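As a rough illustration of the tape-compiled approach, here is a minimal sketch that calls ReverseDiff directly rather than going through DifferentiationInterface.jl. The loss `toy_f` and the data `A`, `b` are hypothetical stand-ins; in this thread the function would be `f(ξ) = loss(x, ξ, k, dx)`.

```julia
using ReverseDiff

# Hypothetical stand-in for the real loss (assumption: any loss without
# ξ-dependent branching can be treated the same way).
A = rand(20, 6)
b = rand(20)
toy_f(ξ) = sum(abs2.(A * ξ .- b)) / 10

ξ0 = rand(6)
tape = ReverseDiff.GradientTape(toy_f, ξ0)   # record the operations once
ctape = ReverseDiff.compile(tape)            # compile the recorded tape
g = similar(ξ0)
ReverseDiff.gradient!(g, ctape, ξ0)          # replay for any ξ of the same size
```

A compiled tape replays the operations recorded for the example input, so it should only be reused when the function's control flow does not depend on the values in ξ.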

Here is a formatted version of the function:

loss(x, ξ, k, dx) = mean( @inbounds @. @views abs(-dx[3,:] + k[1,:]*ξ[31] + k[2,:]*ξ[32] + k[3,:]*ξ[33] +
(k[1,:]*ξ[10] + k[2,:]*ξ[11] + k[3,:]*ξ[12])*x[1,:] +
(k[1,:]*ξ[25] + k[2,:]*ξ[26] + k[3,:]*ξ[27])*x[2,:] +
(-k[1,:]*ξ[34] - k[2,:]*ξ[35] - k[3,:]*ξ[36])*x[3,:] +
(-k[1,:]*ξ[37] - k[2,:]*ξ[38] - k[3,:]*ξ[39])*x[3,:] +
(-k[1,:]*ξ[40] - k[2,:]*ξ[41] - k[3,:]*ξ[42])*x[3,:] +
(-k[1,:]*ξ[43] - k[2,:]*ξ[44] - k[3,:]*ξ[45])*x[3,:] +
(k[1,:]*ξ[58] + k[2,:]*ξ[59] + k[3,:]*ξ[60])*x[4,:]) +
abs(-dx[1,:] + k[1,:]*ξ[1] + k[2,:]*ξ[2] + k[3,:]*ξ[3] +
(-k[1,:]*ξ[10] - k[2,:]*ξ[11] - k[3,:]*ξ[12])*x[1,:] +
(-k[1,:]*ξ[13] - k[2,:]*ξ[14] - k[3,:]*ξ[15])*x[1,:] +
(k[1,:]*ξ[22] + k[2,:]*ξ[23] + k[3,:]*ξ[24])*x[2,:] +
(k[1,:]*ξ[37] + k[2,:]*ξ[38] + k[3,:]*ξ[39])*x[3,:] +
(-k[1,:]*ξ[4] - k[2,:]*ξ[5] - k[3,:]*ξ[6])*x[1,:] +
(k[1,:]*ξ[52] + k[2,:]*ξ[53] + k[3,:]*ξ[54])*x[4,:] +
(-k[1,:]*ξ[7] - k[2,:]*ξ[8] - k[3,:]*ξ[9])*x[1,:]) +
abs(-dx[4,:] + k[1,:]*ξ[46] + k[2,:]*ξ[47] + k[3,:]*ξ[48] +
(k[1,:]*ξ[13] + k[2,:]*ξ[14] + k[3,:]*ξ[15])*x[1,:] +
(k[1,:]*ξ[28] + k[2,:]*ξ[29] + k[3,:]*ξ[30])*x[2,:] +
(k[1,:]*ξ[43] + k[2,:]*ξ[44] + k[3,:]*ξ[45])*x[3,:] +
(-k[1,:]*ξ[49] - k[2,:]*ξ[50] - k[3,:]*ξ[51])*x[4,:] +
(-k[1,:]*ξ[52] - k[2,:]*ξ[53] - k[3,:]*ξ[54])*x[4,:] +
(-k[1,:]*ξ[55] - k[2,:]*ξ[56] - k[3,:]*ξ[57])*x[4,:] +
(-k[1,:]*ξ[58] - k[2,:]*ξ[59] - k[3,:]*ξ[60])*x[4,:]) +
abs(-dx[2,:] + k[1,:]*ξ[16] + k[2,:]*ξ[17] + k[3,:]*ξ[18] +
(-k[1,:]*ξ[19] - k[2,:]*ξ[20] - k[3,:]*ξ[21])*x[2,:] +
(-k[1,:]*ξ[22] - k[2,:]*ξ[23] - k[3,:]*ξ[24])*x[2,:] +
(-k[1,:]*ξ[25] - k[2,:]*ξ[26] - k[3,:]*ξ[27])*x[2,:] +
(-k[1,:]*ξ[28] - k[2,:]*ξ[29] - k[3,:]*ξ[30])*x[2,:] +
(k[1,:]*ξ[40] + k[2,:]*ξ[41] + k[3,:]*ξ[42])*x[3,:] +
(k[1,:]*ξ[55] + k[2,:]*ξ[56] + k[3,:]*ξ[57])*x[4,:] +
(k[1,:]*ξ[7] + k[2,:]*ξ[8] + k[3,:]*ξ[9])*x[1,:]))


I also modified this function to use abs instead of ^2, because ^2 was causing a domain error in ReverseDiff (“DomainError with -1.8919415434611375, in call to log, in call to to ^”).
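One possible explanation for the error, offered as an assumption rather than something verified against ReverseDiff's internals: with a floating-point exponent such as `^2.0`, the derivative rule for `^` may also evaluate the partial derivative with respect to the exponent, x^p * log(x), which has no real value when the base is negative. Note also that `abs` turns the squared-error loss into an absolute-error loss. The sketch below uses only Base Julia to show the failing `log` call and that `abs2` keeps the squared value.

```julia
r = -1.8919415434611375          # the negative residual from the error message

# log of a negative real throws DomainError in Julia
domain_error_seen = try
    log(r)
    false
catch err
    err isa DomainError
end

# abs2 gives the same value as squaring, without a floating-point exponent
squared_matches = r^2.0 ≈ abs2(r)

# abs is a different function: an L1-style penalty, not a squared one
abs_differs = !(abs(r) ≈ abs2(r))
```

Under that assumption, `abs2(...)` would be a candidate replacement for `(...)^2.0` that keeps the original objective.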

Here’s some code to test the performance of all of the backends:

begin
using Statistics
using Zygote
using ProfileCanvas
using BenchmarkTools
using DifferentiationInterface
using ReverseDiff
using ChainRulesCore
using Diffractor
using Enzyme
using FiniteDiff
using FiniteDifferences
using ForwardDiff
using PolyesterForwardDiff
using SparseDiffTools
using Tracker
using FastDifferentiation
using Symbolics
using Tapir
using Plots
end

begin
v = 4
d = 10
nk = 3
nξ = 60
x = rand(v, d)
dx = rand(v, d)
ξ = rand(nξ)
k = rand(nk, d)
end

f(ξ) = loss(x, ξ, k, dx)

diffbackends = [
#AutoChainRules()
#AutoDiffractor()
#AutoEnzyme(; mode=Enzyme.Forward)
#AutoEnzyme(; mode=Enzyme.Reverse)
AutoFiniteDiff()
#AutoFiniteDifferences()
AutoForwardDiff()
#AutoPolyesterForwardDiff()
AutoReverseDiff() # DomainError with -1.8919415434611375, in call to log, in call to to ^
AutoSparseForwardDiff()
#AutoSparseFiniteDiff()
#AutoTracker() # Took > 10 min to compile/run the first time
AutoZygote()
AutoFastDifferentiation()
AutoSparseFastDifferentiation()
AutoSymbolics()
AutoSparseSymbolics()
AutoTapir()
]

nms = []
tms = []
for i in 1:length(diffbackends)
push!(nms, string(Base.typename(typeof(diffbackends[i])).wrapper))
@info nms[end]
push!(tms, @belapsed value_and_gradient(f, diffbackends[$i], ξ))
end

bar(1:length(nms), log10.(tms), ylim=(0,length(nms)+1), yticks=(1:length(nms), nms), orientation=:h, xlabel="log10(time) (s)", xlims=(-5, 2.5), label=:none)

I’ve commented out the ones that errored or took >10 minutes to compile/run the first time. Here is the resulting plot:

[Plot: horizontal bar chart of log10(time) (s) for each backend]

I tried calculating the analytical gradient using Symbolics.jl and then rewriting it to handle vectors, with the result looking like this:

loss_grad(g, x, ξ, k, dx) = begin
@views begin
begin
g[1] = mean(#= none:1 =# @__dot__((2.0 * k[1, :]) * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])))
g[2] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[2, :]))
g[3] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + 
k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[3, :])) g[4] = mean(#= none:1 =# @__dot__((-2.0 * k[1, :]) * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :]) * x[1, :])) g[5] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[1, :])) g[6] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * 
ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[1, :])) g[7] = mean(#= none:1 =# @__dot__((-2.0 * k[1, :]) * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :]) * x[1, :] + (2.0 * k[1, :]) * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :]) * x[1, :])) g[8] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[1, :] + (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] 
* ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[1, :])) g[9] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[1, :] + (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[1, :])) g[10] = mean(#= none:1 =# @__dot__((2.0 * k[1, :]) * (-(dx[3, :]) + k[1, :] * ξ[31] + k[2, :] * ξ[32] + k[3, :] * ξ[33] + (k[1, :] * ξ[10] + k[2, :] * ξ[11] + k[3, :] * ξ[12]) * x[1, :] + (k[1, :] * ξ[25] + k[2, :] * ξ[26] + k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[34] - k[2, :] * ξ[35]) - k[3, :] * 
ξ[36]) * x[3, :] + ((-(k[1, :]) * ξ[37] - k[2, :] * ξ[38]) - k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[40] - k[2, :] * ξ[41]) - k[3, :] * ξ[42]) * x[3, :] + ((-(k[1, :]) * ξ[43] - k[2, :] * ξ[44]) - k[3, :] * ξ[45]) * x[3, :] + (k[1, :] * ξ[58] + k[2, :] * ξ[59] + k[3, :] * ξ[60]) * x[4, :]) * x[1, :] - (2.0 * k[1, :]) * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :]) * x[1, :])) g[11] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[3, :]) + k[1, :] * ξ[31] + k[2, :] * ξ[32] + k[3, :] * ξ[33] + (k[1, :] * ξ[10] + k[2, :] * ξ[11] + k[3, :] * ξ[12]) * x[1, :] + (k[1, :] * ξ[25] + k[2, :] * ξ[26] + k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[34] - k[2, :] * ξ[35]) - k[3, :] * ξ[36]) * x[3, :] + ((-(k[1, :]) * ξ[37] - k[2, :] * ξ[38]) - k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[40] - k[2, :] * ξ[41]) - k[3, :] * ξ[42]) * x[3, :] + ((-(k[1, :]) * ξ[43] - k[2, :] * ξ[44]) - k[3, :] * ξ[45]) * x[3, :] + (k[1, :] * ξ[58] + k[2, :] * ξ[59] + k[3, :] * ξ[60]) * x[4, :])) * k[2, :] * x[1, :] - (2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - 
k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[1, :])) g[12] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[3, :]) + k[1, :] * ξ[31] + k[2, :] * ξ[32] + k[3, :] * ξ[33] + (k[1, :] * ξ[10] + k[2, :] * ξ[11] + k[3, :] * ξ[12]) * x[1, :] + (k[1, :] * ξ[25] + k[2, :] * ξ[26] + k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[34] - k[2, :] * ξ[35]) - k[3, :] * ξ[36]) * x[3, :] + ((-(k[1, :]) * ξ[37] - k[2, :] * ξ[38]) - k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[40] - k[2, :] * ξ[41]) - k[3, :] * ξ[42]) * x[3, :] + ((-(k[1, :]) * ξ[43] - k[2, :] * ξ[44]) - k[3, :] * ξ[45]) * x[3, :] + (k[1, :] * ξ[58] + k[2, :] * ξ[59] + k[3, :] * ξ[60]) * x[4, :])) * k[3, :] * x[1, :] - (2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[1, :])) g[13] = mean(#= none:1 =# @__dot__((-2.0 * k[1, :]) * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :]) * x[1, :] + (2.0 * k[1, :]) * (-(dx[4, :]) + k[1, :] * ξ[46] + k[2, :] * ξ[47] + k[3, :] * ξ[48] + (k[1, :] * ξ[13] + k[2, :] * ξ[14] + k[3, :] * 
ξ[15]) * x[1, :] + (k[1, :] * ξ[28] + k[2, :] * ξ[29] + k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[43] + k[2, :] * ξ[44] + k[3, :] * ξ[45]) * x[3, :] + ((-(k[1, :]) * ξ[49] - k[2, :] * ξ[50]) - k[3, :] * ξ[51]) * x[4, :] + ((-(k[1, :]) * ξ[52] - k[2, :] * ξ[53]) - k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[55] - k[2, :] * ξ[56]) - k[3, :] * ξ[57]) * x[4, :] + ((-(k[1, :]) * ξ[58] - k[2, :] * ξ[59]) - k[3, :] * ξ[60]) * x[4, :]) * x[1, :])) g[14] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[1, :] + (2.0 * (-(dx[4, :]) + k[1, :] * ξ[46] + k[2, :] * ξ[47] + k[3, :] * ξ[48] + (k[1, :] * ξ[13] + k[2, :] * ξ[14] + k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[28] + k[2, :] * ξ[29] + k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[43] + k[2, :] * ξ[44] + k[3, :] * ξ[45]) * x[3, :] + ((-(k[1, :]) * ξ[49] - k[2, :] * ξ[50]) - k[3, :] * ξ[51]) * x[4, :] + ((-(k[1, :]) * ξ[52] - k[2, :] * ξ[53]) - k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[55] - k[2, :] * ξ[56]) - k[3, :] * ξ[57]) * x[4, :] + ((-(k[1, :]) * ξ[58] - k[2, :] * ξ[59]) - k[3, :] * ξ[60]) * x[4, :])) * k[2, :] * x[1, :])) g[15] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * 
ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[1, :] + (2.0 * (-(dx[4, :]) + k[1, :] * ξ[46] + k[2, :] * ξ[47] + k[3, :] * ξ[48] + (k[1, :] * ξ[13] + k[2, :] * ξ[14] + k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[28] + k[2, :] * ξ[29] + k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[43] + k[2, :] * ξ[44] + k[3, :] * ξ[45]) * x[3, :] + ((-(k[1, :]) * ξ[49] - k[2, :] * ξ[50]) - k[3, :] * ξ[51]) * x[4, :] + ((-(k[1, :]) * ξ[52] - k[2, :] * ξ[53]) - k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[55] - k[2, :] * ξ[56]) - k[3, :] * ξ[57]) * x[4, :] + ((-(k[1, :]) * ξ[58] - k[2, :] * ξ[59]) - k[3, :] * ξ[60]) * x[4, :])) * k[3, :] * x[1, :])) g[16] = mean(#= none:1 =# @__dot__((2.0 * k[1, :]) * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :]))) g[17] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * 
ξ[9]) * x[1, :])) * k[2, :])) g[18] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[3, :])) g[19] = mean(#= none:1 =# @__dot__((-2.0 * k[1, :]) * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :]) * x[2, :])) g[20] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[2, :])) g[21] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, 
:] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[2, :])) g[22] = mean(#= none:1 =# @__dot__((2.0 * k[1, :]) * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :]) * x[2, :] - (2.0 * k[1, :]) * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :]) * x[2, :])) g[23] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] 
* ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[2, :] - (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[2, :])) g[24] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[2, :] - (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, 
:])) * k[3, :] * x[2, :])) g[25] = mean(#= none:1 =# @__dot__((2.0 * k[1, :]) * (-(dx[3, :]) + k[1, :] * ξ[31] + k[2, :] * ξ[32] + k[3, :] * ξ[33] + (k[1, :] * ξ[10] + k[2, :] * ξ[11] + k[3, :] * ξ[12]) * x[1, :] + (k[1, :] * ξ[25] + k[2, :] * ξ[26] + k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[34] - k[2, :] * ξ[35]) - k[3, :] * ξ[36]) * x[3, :] + ((-(k[1, :]) * ξ[37] - k[2, :] * ξ[38]) - k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[40] - k[2, :] * ξ[41]) - k[3, :] * ξ[42]) * x[3, :] + ((-(k[1, :]) * ξ[43] - k[2, :] * ξ[44]) - k[3, :] * ξ[45]) * x[3, :] + (k[1, :] * ξ[58] + k[2, :] * ξ[59] + k[3, :] * ξ[60]) * x[4, :]) * x[2, :] - (2.0 * k[1, :]) * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :]) * x[2, :])) g[26] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[3, :]) + k[1, :] * ξ[31] + k[2, :] * ξ[32] + k[3, :] * ξ[33] + (k[1, :] * ξ[10] + k[2, :] * ξ[11] + k[3, :] * ξ[12]) * x[1, :] + (k[1, :] * ξ[25] + k[2, :] * ξ[26] + k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[34] - k[2, :] * ξ[35]) - k[3, :] * ξ[36]) * x[3, :] + ((-(k[1, :]) * ξ[37] - k[2, :] * ξ[38]) - k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[40] - k[2, :] * ξ[41]) - k[3, :] * ξ[42]) * x[3, :] + ((-(k[1, :]) * ξ[43] - k[2, :] * ξ[44]) - k[3, :] * ξ[45]) * x[3, :] + (k[1, :] * ξ[58] + k[2, :] * ξ[59] + k[3, :] * ξ[60]) * x[4, :])) * k[2, :] * x[2, :] - (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * 
ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[2, :])) g[27] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[3, :]) + k[1, :] * ξ[31] + k[2, :] * ξ[32] + k[3, :] * ξ[33] + (k[1, :] * ξ[10] + k[2, :] * ξ[11] + k[3, :] * ξ[12]) * x[1, :] + (k[1, :] * ξ[25] + k[2, :] * ξ[26] + k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[34] - k[2, :] * ξ[35]) - k[3, :] * ξ[36]) * x[3, :] + ((-(k[1, :]) * ξ[37] - k[2, :] * ξ[38]) - k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[40] - k[2, :] * ξ[41]) - k[3, :] * ξ[42]) * x[3, :] + ((-(k[1, :]) * ξ[43] - k[2, :] * ξ[44]) - k[3, :] * ξ[45]) * x[3, :] + (k[1, :] * ξ[58] + k[2, :] * ξ[59] + k[3, :] * ξ[60]) * x[4, :])) * k[3, :] * x[2, :] - (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[2, :])) g[28] = mean(#= none:1 =# @__dot__((2.0 * k[1, :]) * (-(dx[4, :]) + k[1, :] * ξ[46] + k[2, :] * ξ[47] + k[3, :] * ξ[48] + (k[1, :] * ξ[13] + k[2, :] * ξ[14] + k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[28] + k[2, :] * ξ[29] + k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[43] + k[2, :] * ξ[44] + k[3, :] * ξ[45]) * x[3, :] + ((-(k[1, :]) * ξ[49] - k[2, :] * ξ[50]) - k[3, :] * ξ[51]) * x[4, :] + 
((-(k[1, :]) * ξ[52] - k[2, :] * ξ[53]) - k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[55] - k[2, :] * ξ[56]) - k[3, :] * ξ[57]) * x[4, :] + ((-(k[1, :]) * ξ[58] - k[2, :] * ξ[59]) - k[3, :] * ξ[60]) * x[4, :]) * x[2, :] - (2.0 * k[1, :]) * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :]) * x[2, :])) g[29] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[4, :]) + k[1, :] * ξ[46] + k[2, :] * ξ[47] + k[3, :] * ξ[48] + (k[1, :] * ξ[13] + k[2, :] * ξ[14] + k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[28] + k[2, :] * ξ[29] + k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[43] + k[2, :] * ξ[44] + k[3, :] * ξ[45]) * x[3, :] + ((-(k[1, :]) * ξ[49] - k[2, :] * ξ[50]) - k[3, :] * ξ[51]) * x[4, :] + ((-(k[1, :]) * ξ[52] - k[2, :] * ξ[53]) - k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[55] - k[2, :] * ξ[56]) - k[3, :] * ξ[57]) * x[4, :] + ((-(k[1, :]) * ξ[58] - k[2, :] * ξ[59]) - k[3, :] * ξ[60]) * x[4, :])) * k[2, :] * x[2, :] - (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[2, :])) g[30] 
= mean(#= none:1 =# @__dot__((2.0 * (-(dx[4, :]) + k[1, :] * ξ[46] + k[2, :] * ξ[47] + k[3, :] * ξ[48] + (k[1, :] * ξ[13] + k[2, :] * ξ[14] + k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[28] + k[2, :] * ξ[29] + k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[43] + k[2, :] * ξ[44] + k[3, :] * ξ[45]) * x[3, :] + ((-(k[1, :]) * ξ[49] - k[2, :] * ξ[50]) - k[3, :] * ξ[51]) * x[4, :] + ((-(k[1, :]) * ξ[52] - k[2, :] * ξ[53]) - k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[55] - k[2, :] * ξ[56]) - k[3, :] * ξ[57]) * x[4, :] + ((-(k[1, :]) * ξ[58] - k[2, :] * ξ[59]) - k[3, :] * ξ[60]) * x[4, :])) * k[3, :] * x[2, :] - (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[2, :]))
# ....Truncated owing to message size limits...
end
g
end
end

g = zeros(nξ)
tms2 = [tms..., @belapsed loss_grad($g, x, $ξ, k, dx)]
nms2 = [nms..., "Analytical rewrite"]
bar(1:length(nms2), log10.(tms2), ylim=(0,length(nms2)+1), yticks=(1:length(nms2), nms2), orientation=:h, xlabel="log10(time) (s)", xlims=(-5, 2.5), label=:none)

As you can see in the plot below, this version (labeled “Analytical rewrite”) is similar to the fastest other options:

However, the compile time is slow, and seems to grow exponentially as x and g get larger…

Part of the explanation for your results is that many autodiff packages need “preparation” to be performant, and you don’t want to include that preparation step in your benchmark.
Examples include the config in ForwardDiff, the tape in ReverseDiff, the symbolic compilation in FastDifferentiation and so on. To make your life easier, DifferentiationInterfaceTest.jl includes benchmarking utilities that prepare the operator before benchmarking. Take a look at the tutorial: Tutorial · DifferentiationInterfaceTest.jl

Also note that sparse backends won’t change anything here: they are only useful for jacobians and hessians, not for gradients. Finally, your function seems extremely inefficient as written: when you profile it, most of the time is spent copying stuff / managing eltypes.

I wouldn’t expect symbolics to give a better result than forward-mode AD for gradients. It’s a very similar calculation for that, but has expression growth for a compile time disadvantage. Only sparse Jacobians would likely have a real use case where symbolics would outperform AD.

This reminds me I should probably make benchmarking robust to errors in DifferentiationInterfaceTest.jl, with something like this

try
    result = Chairmarks.@be f()
catch e
    result = # Inf everywhere
end

I’ll investigate why some of the backends fail

Benchmarking code

begin
    using Statistics
    # using ProfileCanvas
    # using BenchmarkTools
    using DataFrames
    using Plots
    using DifferentiationInterface
    using DifferentiationInterfaceTest
    # using ChainRulesCore
    # using Diffractor
    using Enzyme
    using FastDifferentiation
    using FiniteDiff
    # using FiniteDifferences
    using ForwardDiff
    using PolyesterForwardDiff
    using ReverseDiff
    using SparseDiffTools
    using Symbolics
    using Tapir
    using Tracker
    using Zygote
end

function loss(x, ξ, k, dx)
    k1, k2, k3 = view(k, 1, :), view(k, 2, :), view(k, 3, :)
    x1, x2, x3, x4 = view(x, 1, :), view(x, 2, :), view(x, 3, :), view(x, 4, :)
    dx1, dx2, dx3, dx4 = view(dx, 1, :), view(dx, 2, :), view(dx, 3, :), view(dx, 4, :)
    y = @. begin
        abs(
            -dx3 +
            k1 * ξ[31] +
            k2 * ξ[32] +
            k3 * ξ[33] +
            (k1 * ξ[10] + k2 * ξ[11] + k3 * ξ[12]) * x1 +
            (k1 * ξ[25] + k2 * ξ[26] + k3 * ξ[27]) * x2 +
            (-k1 * ξ[34] - k2 * ξ[35] - k3 * ξ[36]) * x3 +
            (-k1 * ξ[37] - k2 * ξ[38] - k3 * ξ[39]) * x3 +
            (-k1 * ξ[40] - k2 * ξ[41] - k3 * ξ[42]) * x3 +
            (-k1 * ξ[43] - k2 * ξ[44] - k3 * ξ[45]) * x3 +
            (k1 * ξ[58] + k2 * ξ[59] + k3 * ξ[60]) * x4,
        ) + abs(
            -dx1 +
            k1 * ξ[1] +
            k2 * ξ[2] +
            k3 * ξ[3] +
            (-k1 * ξ[10] - k2 * ξ[11] - k3 * ξ[12]) * x1 +
            (-k1 * ξ[13] - k2 * ξ[14] - k3 * ξ[15]) * x1 +
            (k1 * ξ[22] + k2 * ξ[23] + k3 * ξ[24]) * x2 +
            (k1 * ξ[37] + k2 * ξ[38] + k3 * ξ[39]) * x3 +
            (-k1 * ξ[4] - k2 * ξ[5] - k3 * ξ[6]) * x1 +
            (k1 * ξ[52] + k2 * ξ[53] + k3 * ξ[54]) * x4 +
            (-k1 * ξ[7] - k2 * ξ[8] - k3 * ξ[9]) * x1,
        ) + abs(
            -dx4 +
            k1 * ξ[46] +
            k2 * ξ[47] +
            k3 * ξ[48] +
            (k1 * ξ[13] + k2 * ξ[14] + k3 * ξ[15]) * x1 +
            (k1 * ξ[28] + k2 * ξ[29] + k3 * ξ[30]) * x2 +
            (k1 * ξ[43] + k2 * ξ[44] + k3 * ξ[45]) * x3 +
            (-k1 * ξ[49] - k2 * ξ[50] - k3 * ξ[51]) * x4 +
            (-k1 * ξ[52] - k2 * ξ[53] - k3 * ξ[54]) * x4 +
            (-k1 * ξ[55] - k2 * ξ[56] - k3 * ξ[57]) * x4 +
            (-k1 * ξ[58] - k2 * ξ[59] - k3 * ξ[60]) * x4,
        ) + abs(
            -dx2 +
            k1 * ξ[16] +
            k2 * ξ[17] +
            k3 * ξ[18] +
            (-k1 * ξ[19] - k2 * ξ[20] - k3 * ξ[21]) * x2 +
            (-k1 * ξ[22] - k2 * ξ[23] - k3 * ξ[24]) * x2 +
            (-k1 * ξ[25] - k2 * ξ[26] - k3 * ξ[27]) * x2 +
            (-k1 * ξ[28] - k2 * ξ[29] - k3 * ξ[30]) * x2 +
            (k1 * ξ[40] + k2 * ξ[41] + k3 * ξ[42]) * x3 +
            (k1 * ξ[55] + k2 * ξ[56] + k3 * ξ[57]) * x4 +
            (k1 * ξ[7] + k2 * ξ[8] + k3 * ξ[9]) * x1,
        )
    end
    return mean(y)
end

begin
    v = 4
    d = 10
    nk = 3
    nξ = 60
    x = rand(v, d)
    dx = rand(v, d)
    ξ = rand(nξ)
    k = rand(nk, d)
end

f(ξ) = loss(x, ξ, k, dx)
f(ξ)

scenarios = [GradientScenario(f; x=ξ, operator=:inplace)]

backends = [
    # AutoDiffractor(),
    # AutoEnzyme(; mode=Enzyme.Forward),
    # AutoEnzyme(; mode=Enzyme.Reverse),
    AutoFastDifferentiation(),
    AutoFiniteDiff(),
    # AutoFiniteDifferences(),
    AutoForwardDiff(),
    # AutoPolyesterForwardDiff(; chunksize=8),
    # AutoTracker(),
    # AutoReverseDiff(),
    # AutoSymbolics(),
    # AutoTapir(),
    AutoZygote(),
]

result = benchmark_differentiation(backends, scenarios; logging=true)

df = DataFrame(result)
df_filtered = df[df[!, :operator] .== :gradient!, :]

plt = bar(
    df_filtered[!, :backend],
    df_filtered[!, :time],
    label=nothing,
    xlabel="backend",
    ylabel="runtime [log]",
    xrotation=10,
    yscale=:log10,
    margin=15Plots.mm
)
savefig(plt, "benchmark.png")

FastDifferentiation.jl specializes a lot on + and * (assuming associativity and commutativity) so it’s cool to see that come into play here.

@brianguenter you’ll like this

Beware that forward and reverse mode scale quite differently with the number of parameters, so you are likely to get very different results for e.g. 1000 parameters.

Thanks everyone! I’ve eventually figured out a way to turn the above operation into something closer to matrix multiplications, which I guess is probably the best solution for this kind of thing.

I’d be curious if you could share the updated code, I’d love to benchmark it again.
I think the reason why some backends lagged forever is the complexity of the function.

Here is the code:

# Each row selects which entry of [1; u] multiplies the corresponding rate law.
umask = [
    1.0 0.0 0.0 0.0 0.0;
    0.0 1.0 0.0 0.0 0.0;
    0.0 1.0 0.0 0.0 0.0;
    0.0 1.0 0.0 0.0 0.0;
    0.0 1.0 0.0 0.0 0.0;
    1.0 0.0 0.0 0.0 0.0;
    0.0 0.0 1.0 0.0 0.0;
    0.0 0.0 1.0 0.0 0.0;
    0.0 0.0 1.0 0.0 0.0;
    0.0 0.0 1.0 0.0 0.0;
    1.0 0.0 0.0 0.0 0.0;
    0.0 0.0 0.0 1.0 0.0;
    0.0 0.0 0.0 1.0 0.0;
    0.0 0.0 0.0 1.0 0.0;
    0.0 0.0 0.0 1.0 0.0;
    1.0 0.0 0.0 0.0 0.0;
    0.0 0.0 0.0 0.0 1.0;
    0.0 0.0 0.0 0.0 1.0;
    0.0 0.0 0.0 0.0 1.0;
    0.0 0.0 0.0 0.0 1.0;
]

# Each column says how one rate law contributes to the four state derivatives.
stoich = [
    1.0 -1.0 -1.0 -1.0 -1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0;
    0.0 0.0 1.0 0.0 0.0 1.0 -1.0 -1.0 -1.0 -1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0;
    0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 -1.0 -1.0 -1.0 -1.0 0.0 0.0 0.0 0.0 1.0;
    0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 1.0 -1.0 -1.0 -1.0 -1.0;
]

function dudt_mat(u, ξ, k)
    oneplusu = [1; u]
    # nrate is not defined in this snippet; for ξmat * k to be well-formed
    # it has to equal size(k, 1).
    ξmat = transpose(reshape(ξ, nrate, :))
    ratelaws = (ξmat * k) .* (umask * oneplusu)
    stoich * ratelaws
end

function loss_mat(u, ξ, k, dudt_target)
    mean((dudt_mat(u, ξ, k) .- dudt_target).^2)
end

let
    f(ξ) = loss_mat(x, ξ, k, dx)
    @btime value_and_gradient(f, AutoZygote(), $ξ) # 41.555 ms (1227564 allocations: 51.30 MiB)
end


I think it might scale even better if you make umask and stoich
• sparse matrices with SparseArrays.sparse
• constant with const
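
To make the two suggestions concrete, here is a minimal sketch on a toy system (2 states, 3 rate laws), not the 4-state, 20-rate system from the thread. All names ending in `_DEMO` or `_demo` are hypothetical and only for illustration. The row of ones replaces `[1; u]` so that a v × d matrix of observations can be passed; that is my assumption about the intended shapes. Whether sparsity actually pays off for matrices this small is something to benchmark rather than assume.

```julia
using SparseArrays
using Statistics

# Hypothetical toy masks: each row of the mask picks one entry of [1; u]
# (the constant, u1 or u2) for the corresponding rate law.
const UMASK_DEMO = sparse([1.0 0.0 0.0;
                           0.0 1.0 0.0;
                           0.0 0.0 1.0])

# Each column says how one rate law changes the two states.
const STOICH_DEMO = sparse([1.0 -1.0  1.0;
                            0.0  1.0 -1.0])

# Number of fixed parameters per observation (rows of k) in this toy setup.
const NK_DEMO = 2

function dudt_demo(u, ξ, k)
    # Prepend a row of ones so the constant term exists for every observation.
    oneplusu = vcat(ones(eltype(u), 1, size(u, 2)), u)
    # One row of coefficients per rate law, one column per fixed parameter.
    ξmat = transpose(reshape(ξ, NK_DEMO, :))
    ratelaws = (ξmat * k) .* (UMASK_DEMO * oneplusu)
    return STOICH_DEMO * ratelaws
end

# Mean squared residual between modelled and observed derivatives.
loss_demo(u, ξ, k, dudt_target) = mean(abs2, dudt_demo(u, ξ, k) .- dudt_target)
```

Declaring the matrices `const` lets the compiler infer their types when they are used as globals inside the function; passing them as arguments would achieve the same thing.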