How to take the gradient of an ODE system with respect to many data points?

Hello!

I have an ODE system (dx = f(x, ξ, k)), and I would like to find optimal values of ξ so that:

\mathrm{argmin}_\xi \left( f(x, \xi, k) - dx \right)^2

where \xi \in R^{n\xi} are the parameters I want to tune, x \in R^{v \times d} are data points for the system state (v is the number of state variables and d is the number of observations), dx \in R^{v \times d} are observations of the derivative of the system, and k \in R^{nk \times d} are values of parameters that don’t need to be tuned.
I’ve implemented the equation above thusly:

using Statistics
using Zygote
using ProfileCanvas
using BenchmarkTools

loss(x, ξ, k, dx) = mean( @inbounds @. @views (-dx[3,:] + k[1,:]*ξ[31] + k[2,:]*ξ[32] + k[3,:]*ξ[33] + (k[1,:]*ξ[10] + k[2,:]*ξ[11] + k[3,:]*ξ[12])*x[1,:] + (k[1,:]*ξ[25] + k[2,:]*ξ[26] + k[3,:]*ξ[27])*x[2,:] + (-k[1,:]*ξ[34] - k[2,:]*ξ[35] - k[3,:]*ξ[36])*x[3,:] + (-k[1,:]*ξ[37] - k[2,:]*ξ[38] - k[3,:]*ξ[39])*x[3,:] + (-k[1,:]*ξ[40] - k[2,:]*ξ[41] - k[3,:]*ξ[42])*x[3,:] + (-k[1,:]*ξ[43] - k[2,:]*ξ[44] - k[3,:]*ξ[45])*x[3,:] + (k[1,:]*ξ[58] + k[2,:]*ξ[59] + k[3,:]*ξ[60])*x[4,:])^2.0 + (-dx[1,:] + k[1,:]*ξ[1] + k[2,:]*ξ[2] + k[3,:]*ξ[3] + (-k[1,:]*ξ[10] - k[2,:]*ξ[11] - k[3,:]*ξ[12])*x[1,:] + (-k[1,:]*ξ[13] - k[2,:]*ξ[14] - k[3,:]*ξ[15])*x[1,:] + (k[1,:]*ξ[22] + k[2,:]*ξ[23] + k[3,:]*ξ[24])*x[2,:] + (k[1,:]*ξ[37] + k[2,:]*ξ[38] + k[3,:]*ξ[39])*x[3,:] + (-k[1,:]*ξ[4] - k[2,:]*ξ[5] - k[3,:]*ξ[6])*x[1,:] + (k[1,:]*ξ[52] + k[2,:]*ξ[53] + k[3,:]*ξ[54])*x[4,:] + (-k[1,:]*ξ[7] - k[2,:]*ξ[8] - k[3,:]*ξ[9])*x[1,:])^2.0 + (-dx[4,:] + k[1,:]*ξ[46] + k[2,:]*ξ[47] + k[3,:]*ξ[48] + (k[1,:]*ξ[13] + k[2,:]*ξ[14] + k[3,:]*ξ[15])*x[1,:] + (k[1,:]*ξ[28] + k[2,:]*ξ[29] + k[3,:]*ξ[30])*x[2,:] + (k[1,:]*ξ[43] + k[2,:]*ξ[44] + k[3,:]*ξ[45])*x[3,:] + (-k[1,:]*ξ[49] - k[2,:]*ξ[50] - k[3,:]*ξ[51])*x[4,:] + (-k[1,:]*ξ[52] - k[2,:]*ξ[53] - k[3,:]*ξ[54])*x[4,:] + (-k[1,:]*ξ[55] - k[2,:]*ξ[56] - k[3,:]*ξ[57])*x[4,:] + (-k[1,:]*ξ[58] - k[2,:]*ξ[59] - k[3,:]*ξ[60])*x[4,:])^2.0 + (-dx[2,:] + k[1,:]*ξ[16] + k[2,:]*ξ[17] + k[3,:]*ξ[18] + (-k[1,:]*ξ[19] - k[2,:]*ξ[20] - k[3,:]*ξ[21])*x[2,:] + (-k[1,:]*ξ[22] - k[2,:]*ξ[23] - k[3,:]*ξ[24])*x[2,:] + (-k[1,:]*ξ[25] - k[2,:]*ξ[26] - k[3,:]*ξ[27])*x[2,:] + (-k[1,:]*ξ[28] - k[2,:]*ξ[29] - k[3,:]*ξ[30])*x[2,:] + (k[1,:]*ξ[40] + k[2,:]*ξ[41] + k[3,:]*ξ[42])*x[3,:] + (k[1,:]*ξ[55] + k[2,:]*ξ[56] + k[3,:]*ξ[57])*x[4,:] + (k[1,:]*ξ[7] + k[2,:]*ξ[8] + k[3,:]*ξ[9])*x[1,:])^2.0)

It looks like a mess, but basically it’s just (f(x_1, ξ, k) .- dx_1).^2 .+ (f(x_2, ξ, k) .- dx_2).^2 .+ ..., with a lot of slice and broadcast operations. (It’s a version of SINDy where the equation term library coefficients are reused in a specific pattern, which gives specific theoretical guarantees.)
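Written out, the quantity computed by the code above appears to be the following mean of squared residuals. The notation is an assumption introduced here for illustration: f_s is the s-th component of f, x_{\cdot i} and k_{\cdot i} are the i-th columns of x and k, and a_{s,i} collects the coefficients multiplying \xi in f_s at observation i. Every entry of \xi enters the code only to the first power, so each f_s is linear in \xi and the gradient has a closed form:

```latex
L(\xi) = \frac{1}{d} \sum_{i=1}^{d} \sum_{s=1}^{v}
         \bigl( f_s(x_{\cdot i}, \xi, k_{\cdot i}) - dx_{s,i} \bigr)^2 ,
\qquad
f_s(x_{\cdot i}, \xi, k_{\cdot i}) = a_{s,i}^{\top} \xi

\nabla_{\xi} L(\xi) = \frac{2}{d} \sum_{i=1}^{d} \sum_{s=1}^{v}
         a_{s,i} \bigl( a_{s,i}^{\top} \xi - dx_{s,i} \bigr)
```

This only holds for the squared loss; it says nothing about other penalties.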

If I run the function forward it seems fine:

v = 4
d = 10
nk = 3
nξ = 60
x = rand(v, d)
dx = rand(v, d)
ξ = rand(nξ)
k = rand(nk, d)

@btime loss(x, ξ, k, dx)
# 2.894 μs (2 allocations: 160 bytes)

However, when I try to take the gradient with respect to \xi it’s not really fine:

f(ξ) = loss(x, ξ, k, dx)
@btime f'(ξ)
# 42.575 ms (1227548 allocations: 51.24 MiB)

For a larger version of the same system, running the gradient function gets exponentially slower and/or gives a stack overflow error, depending on the specific implementation (I’ve tried a lot of different implementations).

So I basically have two questions, and an answer to either one of them would be great:

  1. For the function above specifically, is there a way to reformat it to make it go faster with Zygote, or an alternative AD system? (ForwardDiff works fine for small problems but when the size of \xi gets large it is too slow.)
  2. More generally, is there a best practice for taking an ODE system and calculating the average of its gradient with respect to parameters across many different values of the system state vector? (I am aware of DataDrivenDiffEq.jl but I don’t think it would work in this case owing to the reuse of parameters in multiple equation terms.)
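One observation that may help with the first question: because the residuals are linear in ξ, the squared loss can be written as a linear least-squares problem, and its gradient can be computed without any AD system. The sketch below illustrates this on a hypothetical 2-state toy model (not the 4-state system above). The names `toy_residuals`, `toy_loss`, `design`, and `grad` are made up for this example.

```julia
using LinearAlgebra

# Hypothetical toy model, NOT the system from the post:
#   dx1 = r1 - r2*x1
#   dx2 = r2*x1 - r3*x2
# where each rate r_j = k[1]*ξ[2j-1] + k[2]*ξ[2j] is linear in ξ.
function toy_residuals(ξ, x, k, dx)
    Ξ = reshape(ξ, size(k, 1), :)   # nk × (number of rates)
    R = k' * Ξ                      # d × (number of rates): rate j at observation i
    res1 = R[:, 1] .- R[:, 2] .* x[1, :] .- dx[1, :]
    res2 = R[:, 2] .* x[1, :] .- R[:, 3] .* x[2, :] .- dx[2, :]
    return vcat(res1, res2)
end

# Mean over observations of the summed squared residuals.
toy_loss(ξ, x, k, dx) = sum(abs2, toy_residuals(ξ, x, k, dx)) / size(x, 2)

# The residuals are affine in ξ, so toy_residuals(ξ) == A*ξ - b for fixed A, b.
# Probing with the zero vector and the unit vectors recovers A and b.
function design(x, k, dx, nξ)
    b = -toy_residuals(zeros(nξ), x, k, dx)
    A = zeros(length(b), nξ)
    for j in 1:nξ
        e = zeros(nξ)
        e[j] = 1.0
        A[:, j] = toy_residuals(e, x, k, dx) .+ b
    end
    return A, b
end

# Closed-form gradient of the mean squared residual.
grad(ξ, A, b, d) = (2 / d) .* (A' * (A * ξ .- b))

d, nk = 10, 2
x, dx, k = rand(2, d), rand(2, d), rand(nk, d)
ξ = rand(3nk)

A, b = design(x, k, dx, length(ξ))
g = grad(ξ, A, b, d)   # gradient at ξ
ξ_fit = A \ b          # least-squares minimizer of the same loss
```

A and b depend only on the data, so they can be built once and reused for every gradient evaluation. With the squared loss, `A \ b` gives a minimizer directly, without iterating.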

Thanks for any suggestions you can offer!

Can you reformat the function so that we can see it without scrolling, and understand its structure a little bit?
For a quick way to compare several autodiff backends, check out DifferentiationInterface.jl

Tape-compiled ReverseDiff is very likely to be your best bet here, which I think can be accessed through DifferentiationInterface.jl
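For reference, a minimal sketch of what using a compiled tape directly with ReverseDiff.jl might look like. The loss here is a small stand-in with made-up data (`A_toy`, `b_toy`, `toy_loss` are hypothetical names), not the loss from the original post:

```julia
using ReverseDiff

# Hypothetical stand-in for the real loss: fixed data, residuals linear
# in ξ, squared and averaged.
A_toy = [1.0 2.0; 3.0 4.0; 5.0 6.0]
b_toy = [1.0, 0.0, -1.0]

function toy_loss(ξ)
    r = A_toy * ξ .- b_toy
    return sum(r .* r) / length(b_toy)
end

ξ0 = [0.5, -0.25]

# Record the operations once for an input of this shape, then compile the tape.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(toy_loss, ξ0))

# Reuse the compiled tape for each gradient evaluation.
g = similar(ξ0)
ReverseDiff.gradient!(g, tape, ξ0)
```

A compiled tape replays the operations recorded at construction, so it is only appropriate when the function has no control flow that depends on the value of ξ. The data are also baked into the tape, so a new tape is needed if x, k, or dx change.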

Thanks for your responses!

Here is a formatted version of the function:

loss(x, ξ, k, dx) = mean( @inbounds @. @views abs(-dx[3,:] + k[1,:]*ξ[31] + k[2,:]*ξ[32] + k[3,:]*ξ[33] + 
		(k[1,:]*ξ[10] + k[2,:]*ξ[11] + k[3,:]*ξ[12])*x[1,:] + 
		(k[1,:]*ξ[25] + k[2,:]*ξ[26] + k[3,:]*ξ[27])*x[2,:] + 
		(-k[1,:]*ξ[34] - k[2,:]*ξ[35] - k[3,:]*ξ[36])*x[3,:] + 
		(-k[1,:]*ξ[37] - k[2,:]*ξ[38] - k[3,:]*ξ[39])*x[3,:] + 
		(-k[1,:]*ξ[40] - k[2,:]*ξ[41] - k[3,:]*ξ[42])*x[3,:] + 
		(-k[1,:]*ξ[43] - k[2,:]*ξ[44] - k[3,:]*ξ[45])*x[3,:] + 
		(k[1,:]*ξ[58] + k[2,:]*ξ[59] + k[3,:]*ξ[60])*x[4,:]) + 
	abs(-dx[1,:] + k[1,:]*ξ[1] + k[2,:]*ξ[2] + k[3,:]*ξ[3] + 
		(-k[1,:]*ξ[10] - k[2,:]*ξ[11] - k[3,:]*ξ[12])*x[1,:] + 
		(-k[1,:]*ξ[13] - k[2,:]*ξ[14] - k[3,:]*ξ[15])*x[1,:] + 
		(k[1,:]*ξ[22] + k[2,:]*ξ[23] + k[3,:]*ξ[24])*x[2,:] + 
		(k[1,:]*ξ[37] + k[2,:]*ξ[38] + k[3,:]*ξ[39])*x[3,:] + 
		(-k[1,:]*ξ[4] - k[2,:]*ξ[5] - k[3,:]*ξ[6])*x[1,:] + 
		(k[1,:]*ξ[52] + k[2,:]*ξ[53] + k[3,:]*ξ[54])*x[4,:] + 
		(-k[1,:]*ξ[7] - k[2,:]*ξ[8] - k[3,:]*ξ[9])*x[1,:]) + 
	abs(-dx[4,:] + k[1,:]*ξ[46] + k[2,:]*ξ[47] + k[3,:]*ξ[48] + 
		(k[1,:]*ξ[13] + k[2,:]*ξ[14] + k[3,:]*ξ[15])*x[1,:] + 
		(k[1,:]*ξ[28] + k[2,:]*ξ[29] + k[3,:]*ξ[30])*x[2,:] + 
		(k[1,:]*ξ[43] + k[2,:]*ξ[44] + k[3,:]*ξ[45])*x[3,:] + 
		(-k[1,:]*ξ[49] - k[2,:]*ξ[50] - k[3,:]*ξ[51])*x[4,:] + 
		(-k[1,:]*ξ[52] - k[2,:]*ξ[53] - k[3,:]*ξ[54])*x[4,:] + 
		(-k[1,:]*ξ[55] - k[2,:]*ξ[56] - k[3,:]*ξ[57])*x[4,:] + 
		(-k[1,:]*ξ[58] - k[2,:]*ξ[59] - k[3,:]*ξ[60])*x[4,:]) + 
	abs(-dx[2,:] + k[1,:]*ξ[16] + k[2,:]*ξ[17] + k[3,:]*ξ[18] + 
		(-k[1,:]*ξ[19] - k[2,:]*ξ[20] - k[3,:]*ξ[21])*x[2,:] + 
		(-k[1,:]*ξ[22] - k[2,:]*ξ[23] - k[3,:]*ξ[24])*x[2,:] + 
		(-k[1,:]*ξ[25] - k[2,:]*ξ[26] - k[3,:]*ξ[27])*x[2,:] + 
		(-k[1,:]*ξ[28] - k[2,:]*ξ[29] - k[3,:]*ξ[30])*x[2,:] + 
		(k[1,:]*ξ[40] + k[2,:]*ξ[41] + k[3,:]*ξ[42])*x[3,:] + 
		(k[1,:]*ξ[55] + k[2,:]*ξ[56] + k[3,:]*ξ[57])*x[4,:] + 
		(k[1,:]*ξ[7] + k[2,:]*ξ[8] + k[3,:]*ξ[9])*x[1,:]))

I also modified this function to use abs instead of ^2, because ^2 was causing a domain error in ReverseDiff (“DomainError with -1.8919415434611375, in call to log, in call to to ^”).
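Note that swapping ^2 for abs changes the objective from a squared-error loss to an absolute-error loss, so the two versions generally have different minimizers and gradients. The error message suggests that the floating-point exponent (^2.0) is what leads to the `log` call. On that assumption, which is not verified here, an integer exponent or `abs2` might keep the squared loss while avoiding the error. The snippet below only shows that these forms agree numerically and differ from abs. It does not test ReverseDiff itself.

```julia
r = [-1.5, 0.5, 2.0]   # example residuals, including a negative value

sq_float = r .^ 2.0    # floating-point exponent, as in the original loss
sq_int   = r .^ 2      # integer literal exponent
sq_abs2  = abs2.(r)    # abs2(x) == x*x for real x, no exponent involved
l1       = abs.(r)     # what the reformatted loss computes instead
```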

Here’s some code to test the performance of all of the backends:

begin
	using Statistics
	using Zygote
	using ProfileCanvas
	using BenchmarkTools
	using DifferentiationInterface
	using ReverseDiff
	using ChainRulesCore
	using Diffractor
	using Enzyme
	using FiniteDiff
	using FiniteDifferences
	using ForwardDiff
	using PolyesterForwardDiff
	using SparseDiffTools
	using Tracker
	using FastDifferentiation
	using Symbolics
	using Tapir
	using Plots
end

begin
	v = 4
	d = 10
	nk = 3
	nξ = 60
	x = rand(v, d)
	dx = rand(v, d)
	ξ = rand(nξ)
	k = rand(nk, d)
end

f(ξ) = loss(x, ξ, k, dx)

diffbackends = [
	#AutoChainRules()
	#AutoDiffractor()
	#AutoEnzyme(; mode=Enzyme.Forward)
	#AutoEnzyme(; mode=Enzyme.Reverse)
	AutoFiniteDiff()
	#AutoFiniteDifferences()
	AutoForwardDiff()
	#AutoPolyesterForwardDiff()
	AutoReverseDiff() # with ^2.0 in the loss: DomainError with -1.8919415434611375, in call to `log`, in call to to `^`
	AutoSparseForwardDiff()
	#AutoSparseFiniteDiff()
	#AutoTracker() # Took > 10 min to compile/run the first time
	AutoZygote()
	AutoFastDifferentiation()
	AutoSparseFastDifferentiation()
	AutoSymbolics()
	AutoSparseSymbolics()
	AutoTapir()
]

nms = String[]   # backend names
tms = Float64[]  # minimum elapsed time per backend, in seconds
for i in eachindex(diffbackends)
	push!(nms, string(Base.typename(typeof(diffbackends[i])).wrapper))
	@info nms[end]
	push!(tms, @belapsed value_and_gradient(f, diffbackends[$i], ξ))
end

bar(1:length(nms), log10.(tms), ylim=(0,length(nms)+1), yticks=(1:length(nms), nms), 
	orientation=:h, xlabel="log10(time) (s)", xlims=(-5, 2.5), label=:none)

I’ve commented out the ones that errored or took >10 minutes to compile/run the first time. Here is the resulting plot:
[Image: horizontal bar chart of log10(time) (s) for each backend]

I tried calculating the analytical gradient using Symbolics.jl and then rewriting it to handle vectors, with the result looking like this:

loss_grad(g, x, ξ, k, dx) = begin
@views begin
		begin
			g[1] = mean(#= none:1 =# @__dot__((2.0 * k[1, :]) * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])))
			g[2] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[2, :]))
			g[3] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[3, :]))
			g[4] = mean(#= none:1 =# @__dot__((-2.0 * k[1, :]) * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :]) * x[1, :]))
			g[5] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[1, :]))
			g[6] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[1, :]))
			g[7] = mean(#= none:1 =# @__dot__((-2.0 * k[1, :]) * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :]) * x[1, :] + (2.0 * k[1, :]) * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :]) * x[1, :]))
			g[8] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[1, :] + (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[1, :]))
			g[9] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[1, :] + (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[1, :]))
			g[10] = mean(#= none:1 =# @__dot__((2.0 * k[1, :]) * (-(dx[3, :]) + k[1, :] * ξ[31] + k[2, :] * ξ[32] + k[3, :] * ξ[33] + (k[1, :] * ξ[10] + k[2, :] * ξ[11] + k[3, :] * ξ[12]) * x[1, :] + (k[1, :] * ξ[25] + k[2, :] * ξ[26] + k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[34] - k[2, :] * ξ[35]) - k[3, :] * ξ[36]) * x[3, :] + ((-(k[1, :]) * ξ[37] - k[2, :] * ξ[38]) - k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[40] - k[2, :] * ξ[41]) - k[3, :] * ξ[42]) * x[3, :] + ((-(k[1, :]) * ξ[43] - k[2, :] * ξ[44]) - k[3, :] * ξ[45]) * x[3, :] + (k[1, :] * ξ[58] + k[2, :] * ξ[59] + k[3, :] * ξ[60]) * x[4, :]) * x[1, :] - (2.0 * k[1, :]) * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :]) * x[1, :]))
			g[11] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[3, :]) + k[1, :] * ξ[31] + k[2, :] * ξ[32] + k[3, :] * ξ[33] + (k[1, :] * ξ[10] + k[2, :] * ξ[11] + k[3, :] * ξ[12]) * x[1, :] + (k[1, :] * ξ[25] + k[2, :] * ξ[26] + k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[34] - k[2, :] * ξ[35]) - k[3, :] * ξ[36]) * x[3, :] + ((-(k[1, :]) * ξ[37] - k[2, :] * ξ[38]) - k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[40] - k[2, :] * ξ[41]) - k[3, :] * ξ[42]) * x[3, :] + ((-(k[1, :]) * ξ[43] - k[2, :] * ξ[44]) - k[3, :] * ξ[45]) * x[3, :] + (k[1, :] * ξ[58] + k[2, :] * ξ[59] + k[3, :] * ξ[60]) * x[4, :])) * k[2, :] * x[1, :] - (2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[1, :]))
			g[12] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[3, :]) + k[1, :] * ξ[31] + k[2, :] * ξ[32] + k[3, :] * ξ[33] + (k[1, :] * ξ[10] + k[2, :] * ξ[11] + k[3, :] * ξ[12]) * x[1, :] + (k[1, :] * ξ[25] + k[2, :] * ξ[26] + k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[34] - k[2, :] * ξ[35]) - k[3, :] * ξ[36]) * x[3, :] + ((-(k[1, :]) * ξ[37] - k[2, :] * ξ[38]) - k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[40] - k[2, :] * ξ[41]) - k[3, :] * ξ[42]) * x[3, :] + ((-(k[1, :]) * ξ[43] - k[2, :] * ξ[44]) - k[3, :] * ξ[45]) * x[3, :] + (k[1, :] * ξ[58] + k[2, :] * ξ[59] + k[3, :] * ξ[60]) * x[4, :])) * k[3, :] * x[1, :] - (2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[1, :]))
			g[13] = mean(#= none:1 =# @__dot__((-2.0 * k[1, :]) * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :]) * x[1, :] + (2.0 * k[1, :]) * (-(dx[4, :]) + k[1, :] * ξ[46] + k[2, :] * ξ[47] + k[3, :] * ξ[48] + (k[1, :] * ξ[13] + k[2, :] * ξ[14] + k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[28] + k[2, :] * ξ[29] + k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[43] + k[2, :] * ξ[44] + k[3, :] * ξ[45]) * x[3, :] + ((-(k[1, :]) * ξ[49] - k[2, :] * ξ[50]) - k[3, :] * ξ[51]) * x[4, :] + ((-(k[1, :]) * ξ[52] - k[2, :] * ξ[53]) - k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[55] - k[2, :] * ξ[56]) - k[3, :] * ξ[57]) * x[4, :] + ((-(k[1, :]) * ξ[58] - k[2, :] * ξ[59]) - k[3, :] * ξ[60]) * x[4, :]) * x[1, :]))
			g[14] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[1, :] + (2.0 * (-(dx[4, :]) + k[1, :] * ξ[46] + k[2, :] * ξ[47] + k[3, :] * ξ[48] + (k[1, :] * ξ[13] + k[2, :] * ξ[14] + k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[28] + k[2, :] * ξ[29] + k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[43] + k[2, :] * ξ[44] + k[3, :] * ξ[45]) * x[3, :] + ((-(k[1, :]) * ξ[49] - k[2, :] * ξ[50]) - k[3, :] * ξ[51]) * x[4, :] + ((-(k[1, :]) * ξ[52] - k[2, :] * ξ[53]) - k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[55] - k[2, :] * ξ[56]) - k[3, :] * ξ[57]) * x[4, :] + ((-(k[1, :]) * ξ[58] - k[2, :] * ξ[59]) - k[3, :] * ξ[60]) * x[4, :])) * k[2, :] * x[1, :]))
			g[15] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[1, :] + (2.0 * (-(dx[4, :]) + k[1, :] * ξ[46] + k[2, :] * ξ[47] + k[3, :] * ξ[48] + (k[1, :] * ξ[13] + k[2, :] * ξ[14] + k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[28] + k[2, :] * ξ[29] + k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[43] + k[2, :] * ξ[44] + k[3, :] * ξ[45]) * x[3, :] + ((-(k[1, :]) * ξ[49] - k[2, :] * ξ[50]) - k[3, :] * ξ[51]) * x[4, :] + ((-(k[1, :]) * ξ[52] - k[2, :] * ξ[53]) - k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[55] - k[2, :] * ξ[56]) - k[3, :] * ξ[57]) * x[4, :] + ((-(k[1, :]) * ξ[58] - k[2, :] * ξ[59]) - k[3, :] * ξ[60]) * x[4, :])) * k[3, :] * x[1, :]))
			g[16] = mean(#= none:1 =# @__dot__((2.0 * k[1, :]) * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])))
			g[17] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[2, :]))
			g[18] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[3, :]))
			g[19] = mean(#= none:1 =# @__dot__((-2.0 * k[1, :]) * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :]) * x[2, :]))
			g[20] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[2, :]))
			g[21] = mean(#= none:1 =# @__dot__((-2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[2, :]))
			g[22] = mean(#= none:1 =# @__dot__((2.0 * k[1, :]) * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :]) * x[2, :] - (2.0 * k[1, :]) * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :]) * x[2, :]))
			g[23] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[2, :] - (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[2, :]))
			g[24] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[1, :]) + k[1, :] * ξ[1] + k[2, :] * ξ[2] + k[3, :] * ξ[3] + ((-(k[1, :]) * ξ[10] - k[2, :] * ξ[11]) - k[3, :] * ξ[12]) * x[1, :] + ((-(k[1, :]) * ξ[13] - k[2, :] * ξ[14]) - k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[22] + k[2, :] * ξ[23] + k[3, :] * ξ[24]) * x[2, :] + (k[1, :] * ξ[37] + k[2, :] * ξ[38] + k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[4] - k[2, :] * ξ[5]) - k[3, :] * ξ[6]) * x[1, :] + (k[1, :] * ξ[52] + k[2, :] * ξ[53] + k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[7] - k[2, :] * ξ[8]) - k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[2, :] - (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[2, :]))
			g[25] = mean(#= none:1 =# @__dot__((2.0 * k[1, :]) * (-(dx[3, :]) + k[1, :] * ξ[31] + k[2, :] * ξ[32] + k[3, :] * ξ[33] + (k[1, :] * ξ[10] + k[2, :] * ξ[11] + k[3, :] * ξ[12]) * x[1, :] + (k[1, :] * ξ[25] + k[2, :] * ξ[26] + k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[34] - k[2, :] * ξ[35]) - k[3, :] * ξ[36]) * x[3, :] + ((-(k[1, :]) * ξ[37] - k[2, :] * ξ[38]) - k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[40] - k[2, :] * ξ[41]) - k[3, :] * ξ[42]) * x[3, :] + ((-(k[1, :]) * ξ[43] - k[2, :] * ξ[44]) - k[3, :] * ξ[45]) * x[3, :] + (k[1, :] * ξ[58] + k[2, :] * ξ[59] + k[3, :] * ξ[60]) * x[4, :]) * x[2, :] - (2.0 * k[1, :]) * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :]) * x[2, :]))
			g[26] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[3, :]) + k[1, :] * ξ[31] + k[2, :] * ξ[32] + k[3, :] * ξ[33] + (k[1, :] * ξ[10] + k[2, :] * ξ[11] + k[3, :] * ξ[12]) * x[1, :] + (k[1, :] * ξ[25] + k[2, :] * ξ[26] + k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[34] - k[2, :] * ξ[35]) - k[3, :] * ξ[36]) * x[3, :] + ((-(k[1, :]) * ξ[37] - k[2, :] * ξ[38]) - k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[40] - k[2, :] * ξ[41]) - k[3, :] * ξ[42]) * x[3, :] + ((-(k[1, :]) * ξ[43] - k[2, :] * ξ[44]) - k[3, :] * ξ[45]) * x[3, :] + (k[1, :] * ξ[58] + k[2, :] * ξ[59] + k[3, :] * ξ[60]) * x[4, :])) * k[2, :] * x[2, :] - (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[2, :]))
			g[27] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[3, :]) + k[1, :] * ξ[31] + k[2, :] * ξ[32] + k[3, :] * ξ[33] + (k[1, :] * ξ[10] + k[2, :] * ξ[11] + k[3, :] * ξ[12]) * x[1, :] + (k[1, :] * ξ[25] + k[2, :] * ξ[26] + k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[34] - k[2, :] * ξ[35]) - k[3, :] * ξ[36]) * x[3, :] + ((-(k[1, :]) * ξ[37] - k[2, :] * ξ[38]) - k[3, :] * ξ[39]) * x[3, :] + ((-(k[1, :]) * ξ[40] - k[2, :] * ξ[41]) - k[3, :] * ξ[42]) * x[3, :] + ((-(k[1, :]) * ξ[43] - k[2, :] * ξ[44]) - k[3, :] * ξ[45]) * x[3, :] + (k[1, :] * ξ[58] + k[2, :] * ξ[59] + k[3, :] * ξ[60]) * x[4, :])) * k[3, :] * x[2, :] - (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[2, :]))
			g[28] = mean(#= none:1 =# @__dot__((2.0 * k[1, :]) * (-(dx[4, :]) + k[1, :] * ξ[46] + k[2, :] * ξ[47] + k[3, :] * ξ[48] + (k[1, :] * ξ[13] + k[2, :] * ξ[14] + k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[28] + k[2, :] * ξ[29] + k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[43] + k[2, :] * ξ[44] + k[3, :] * ξ[45]) * x[3, :] + ((-(k[1, :]) * ξ[49] - k[2, :] * ξ[50]) - k[3, :] * ξ[51]) * x[4, :] + ((-(k[1, :]) * ξ[52] - k[2, :] * ξ[53]) - k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[55] - k[2, :] * ξ[56]) - k[3, :] * ξ[57]) * x[4, :] + ((-(k[1, :]) * ξ[58] - k[2, :] * ξ[59]) - k[3, :] * ξ[60]) * x[4, :]) * x[2, :] - (2.0 * k[1, :]) * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :]) * x[2, :]))
			g[29] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[4, :]) + k[1, :] * ξ[46] + k[2, :] * ξ[47] + k[3, :] * ξ[48] + (k[1, :] * ξ[13] + k[2, :] * ξ[14] + k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[28] + k[2, :] * ξ[29] + k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[43] + k[2, :] * ξ[44] + k[3, :] * ξ[45]) * x[3, :] + ((-(k[1, :]) * ξ[49] - k[2, :] * ξ[50]) - k[3, :] * ξ[51]) * x[4, :] + ((-(k[1, :]) * ξ[52] - k[2, :] * ξ[53]) - k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[55] - k[2, :] * ξ[56]) - k[3, :] * ξ[57]) * x[4, :] + ((-(k[1, :]) * ξ[58] - k[2, :] * ξ[59]) - k[3, :] * ξ[60]) * x[4, :])) * k[2, :] * x[2, :] - (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[2, :] * x[2, :]))
			g[30] = mean(#= none:1 =# @__dot__((2.0 * (-(dx[4, :]) + k[1, :] * ξ[46] + k[2, :] * ξ[47] + k[3, :] * ξ[48] + (k[1, :] * ξ[13] + k[2, :] * ξ[14] + k[3, :] * ξ[15]) * x[1, :] + (k[1, :] * ξ[28] + k[2, :] * ξ[29] + k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[43] + k[2, :] * ξ[44] + k[3, :] * ξ[45]) * x[3, :] + ((-(k[1, :]) * ξ[49] - k[2, :] * ξ[50]) - k[3, :] * ξ[51]) * x[4, :] + ((-(k[1, :]) * ξ[52] - k[2, :] * ξ[53]) - k[3, :] * ξ[54]) * x[4, :] + ((-(k[1, :]) * ξ[55] - k[2, :] * ξ[56]) - k[3, :] * ξ[57]) * x[4, :] + ((-(k[1, :]) * ξ[58] - k[2, :] * ξ[59]) - k[3, :] * ξ[60]) * x[4, :])) * k[3, :] * x[2, :] - (2.0 * (-(dx[2, :]) + k[1, :] * ξ[16] + k[2, :] * ξ[17] + k[3, :] * ξ[18] + ((-(k[1, :]) * ξ[19] - k[2, :] * ξ[20]) - k[3, :] * ξ[21]) * x[2, :] + ((-(k[1, :]) * ξ[22] - k[2, :] * ξ[23]) - k[3, :] * ξ[24]) * x[2, :] + ((-(k[1, :]) * ξ[25] - k[2, :] * ξ[26]) - k[3, :] * ξ[27]) * x[2, :] + ((-(k[1, :]) * ξ[28] - k[2, :] * ξ[29]) - k[3, :] * ξ[30]) * x[2, :] + (k[1, :] * ξ[40] + k[2, :] * ξ[41] + k[3, :] * ξ[42]) * x[3, :] + (k[1, :] * ξ[55] + k[2, :] * ξ[56] + k[3, :] * ξ[57]) * x[4, :] + (k[1, :] * ξ[7] + k[2, :] * ξ[8] + k[3, :] * ξ[9]) * x[1, :])) * k[3, :] * x[2, :]))
# ....Truncated owing to message size limits...
		end
		g
	end
end

g = zeros(nξ)
tms2 = [tms..., @belapsed loss_grad($g, $x, $ξ, $k, $dx)]  # interpolate every global so the benchmark does not time untyped-global access
nms2 = [nms..., "Analytical rewrite"]
bar(1:length(nms2), log10.(tms2), ylim=(0,length(nms2)+1), 
	yticks=(1:length(nms2), nms2), 
	orientation=:h, xlabel="log10(time) (s)", xlims=(-5, 2.5), label=:none)

As you can see in the plot below, this version (labeled “Analytical rewrite”) is about as fast as the fastest of the other options:
image

However, the compile time is slow, and seems to grow exponentially as x and g get larger…

Part of the explanation for your results is that many autodiff packages need a “preparation” step to be performant, and you don’t want to include that step in your benchmark. Examples include the config in ForwardDiff, the tape in ReverseDiff, the symbolic compilation in FastDifferentiation, and so on.
To make your life easier, DifferentiationInterfaceTest.jl includes benchmarking utilities that prepare the operator before benchmarking. Take a look at the tutorial in the DifferentiationInterfaceTest.jl documentation.
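To illustrate what “preparation” means, here is a minimal sketch using ForwardDiff's config object directly. The function `toy_loss` and the input `ξ0` are hypothetical stand-ins, not the loss from this thread:

```julia
using ForwardDiff

# Hypothetical stand-in for the real loss (assumption, for illustration only).
toy_loss(ξ) = sum(abs2, ξ .- 1)

ξ0 = collect(range(0.0, 1.0; length=60))
g = similar(ξ0)

# Preparation: build the config (dual-number work buffers) once, outside the hot path.
cfg = ForwardDiff.GradientConfig(toy_loss, ξ0)

# Hot path: reuse the config and the preallocated output on every call.
ForwardDiff.gradient!(g, toy_loss, ξ0, cfg)
```

When benchmarking, only the last call should be timed (for example with `@belapsed ForwardDiff.gradient!($g, $toy_loss, $ξ0, $cfg)`), so that the cost of building `cfg` is not counted.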

Also note that sparse backends won’t change anything here: they are only useful for Jacobians and Hessians, not for gradients.

Finally, your function seems extremely inefficient as written: when you profile it, most of the time is spent copying stuff / managing eltypes.


I wouldn’t expect symbolics to give a better result than forward-mode AD for gradients. The calculation is very similar, but the symbolic approach suffers from expression growth, which is a compile-time disadvantage. Sparse Jacobians are likely the only real use case where symbolics would outperform AD.

This reminds me I should probably make benchmarking robust to errors in DifferentiationInterfaceTest.jl, with something like this:

result = try
    Chairmarks.@be f()
catch e
    nothing  # placeholder: record Inf everywhere for this backend
end

plot_35

I’ll investigate why some of the backends fail.

Benchmarking code
begin
    using Statistics
    # using ProfileCanvas
    # using BenchmarkTools
    using DataFrames
    using Plots

    using DifferentiationInterface
    using DifferentiationInterfaceTest

    # using ChainRulesCore
    # using Diffractor
    using Enzyme
    using FastDifferentiation
    using FiniteDiff
    # using FiniteDifferences
    using ForwardDiff
    using PolyesterForwardDiff
    using ReverseDiff
    using SparseDiffTools
    using Symbolics
    using Tapir
    using Tracker
    using Zygote
end

function loss(x, ξ, k, dx)
    k1, k2, k3 = view(k, 1, :), view(k, 2, :), view(k, 3, :)
    x1, x2, x3, x4 = view(x, 1, :), view(x, 2, :), view(x, 3, :), view(x, 4, :)
    dx1, dx2, dx3, dx4 = view(dx, 1, :), view(dx, 2, :), view(dx, 3, :), view(dx, 4, :)
    y = @. begin
        abs(
            -dx3 +
            k1 * ξ[31] +
            k2 * ξ[32] +
            k3 * ξ[33] +
            (k1 * ξ[10] + k2 * ξ[11] + k3 * ξ[12]) * x1 +
            (k1 * ξ[25] + k2 * ξ[26] + k3 * ξ[27]) * x2 +
            (-k1 * ξ[34] - k2 * ξ[35] - k3 * ξ[36]) * x3 +
            (-k1 * ξ[37] - k2 * ξ[38] - k3 * ξ[39]) * x3 +
            (-k1 * ξ[40] - k2 * ξ[41] - k3 * ξ[42]) * x3 +
            (-k1 * ξ[43] - k2 * ξ[44] - k3 * ξ[45]) * x3 +
            (k1 * ξ[58] + k2 * ξ[59] + k3 * ξ[60]) * x4,
        ) +
        abs(
            -dx1 +
            k1 * ξ[1] +
            k2 * ξ[2] +
            k3 * ξ[3] +
            (-k1 * ξ[10] - k2 * ξ[11] - k3 * ξ[12]) * x1 +
            (-k1 * ξ[13] - k2 * ξ[14] - k3 * ξ[15]) * x1 +
            (k1 * ξ[22] + k2 * ξ[23] + k3 * ξ[24]) * x2 +
            (k1 * ξ[37] + k2 * ξ[38] + k3 * ξ[39]) * x3 +
            (-k1 * ξ[4] - k2 * ξ[5] - k3 * ξ[6]) * x1 +
            (k1 * ξ[52] + k2 * ξ[53] + k3 * ξ[54]) * x4 +
            (-k1 * ξ[7] - k2 * ξ[8] - k3 * ξ[9]) * x1,
        ) +
        abs(
            -dx4 +
            k1 * ξ[46] +
            k2 * ξ[47] +
            k3 * ξ[48] +
            (k1 * ξ[13] + k2 * ξ[14] + k3 * ξ[15]) * x1 +
            (k1 * ξ[28] + k2 * ξ[29] + k3 * ξ[30]) * x2 +
            (k1 * ξ[43] + k2 * ξ[44] + k3 * ξ[45]) * x3 +
            (-k1 * ξ[49] - k2 * ξ[50] - k3 * ξ[51]) * x4 +
            (-k1 * ξ[52] - k2 * ξ[53] - k3 * ξ[54]) * x4 +
            (-k1 * ξ[55] - k2 * ξ[56] - k3 * ξ[57]) * x4 +
            (-k1 * ξ[58] - k2 * ξ[59] - k3 * ξ[60]) * x4,
        ) +
        abs(
            -dx2 +
            k1 * ξ[16] +
            k2 * ξ[17] +
            k3 * ξ[18] +
            (-k1 * ξ[19] - k2 * ξ[20] - k3 * ξ[21]) * x2 +
            (-k1 * ξ[22] - k2 * ξ[23] - k3 * ξ[24]) * x2 +
            (-k1 * ξ[25] - k2 * ξ[26] - k3 * ξ[27]) * x2 +
            (-k1 * ξ[28] - k2 * ξ[29] - k3 * ξ[30]) * x2 +
            (k1 * ξ[40] + k2 * ξ[41] + k3 * ξ[42]) * x3 +
            (k1 * ξ[55] + k2 * ξ[56] + k3 * ξ[57]) * x4 +
            (k1 * ξ[7] + k2 * ξ[8] + k3 * ξ[9]) * x1,
        )
    end
    return mean(y)
end

begin
    v = 4
    d = 10
    nk = 3
    nξ = 60
    x = rand(v, d)
    dx = rand(v, d)
    ξ = rand(nξ)
    k = rand(nk, d)
end

f(ξ) = loss(x, ξ, k, dx)
f(ξ)

scenarios = [GradientScenario(f; x=ξ, operator=:inplace)]

backends = [
    # AutoDiffractor(),
    # AutoEnzyme(; mode=Enzyme.Forward),
    # AutoEnzyme(; mode=Enzyme.Reverse),
    AutoFastDifferentiation(),
    AutoFiniteDiff(),
    # AutoFiniteDifferences(),
    AutoForwardDiff(),
    # AutoPolyesterForwardDiff(; chunksize=8),
    # AutoTracker(),
    # AutoReverseDiff(),
    # AutoSymbolics(),
    # AutoTapir(),
    AutoZygote(),
]

result = benchmark_differentiation(backends, scenarios; logging=true)

df = DataFrame(result)

df_filtered = df[df[!, :operator] .== :gradient!, :]

plt = bar(
    df_filtered[!, :backend],
    df_filtered[!, :time],
    label=nothing,
    xlabel="backend",
    ylabel="runtime [log]",
    xrotation=10,
    yscale=:log10,
    margin=15Plots.mm
)

savefig(plt, "benchmark.png")

FastDifferentiation.jl specializes a lot on + and * (assuming associativity and commutativity) so it’s cool to see that come into play here.


@brianguenter you’ll like this

Beware that forward and reverse mode scale quite differently with the number of parameters, so you are likely to get very different results for e.g. 1000 parameters.
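As a rough back-of-the-envelope sketch of that scaling: forward mode propagates a fixed number of directional derivatives per sweep, so the number of sweeps grows with the number of parameters, while reverse mode needs one forward and one backward sweep regardless. The chunk size of 12 below is an assumption for illustration, and the counts ignore the (often large) per-sweep cost and memory overhead of reverse mode:

```julia
# Number of forward-mode sweeps needed for a gradient with n parameters when
# `chunk` directional derivatives are propagated per sweep.
forward_sweeps(n, chunk) = cld(n, chunk)

# Reverse mode: one forward sweep plus one backward sweep, whatever n is.
reverse_sweeps(n) = 2

chunk = 12  # assumed chunk size, for illustration only

for n in (60, 1000)
    println("n = $n: forward ≈ $(forward_sweeps(n, chunk)) sweeps, ",
            "reverse ≈ $(reverse_sweeps(n)) sweeps")
end
```

Under these assumptions, 60 parameters need 5 forward sweeps and 1000 parameters need 84, which is why a ranking measured at 60 parameters may not carry over.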

Thanks everyone! I’ve eventually figured out a way to turn the above operation into something closer to matrix multiplications, which I guess is probably the best solution for this kind of thing.


I’d be curious to see the updated code if you can share it; I’d love to benchmark it again. I think the reason some backends lagged forever is the complexity of the function.

Here is the code:

umask = [
	1.0 0.0 0.0 0.0 0.0; 
	0.0 1.0 0.0 0.0 0.0; 
	0.0 1.0 0.0 0.0 0.0; 
	0.0 1.0 0.0 0.0 0.0; 
	0.0 1.0 0.0 0.0 0.0; 
	1.0 0.0 0.0 0.0 0.0; 
	0.0 0.0 1.0 0.0 0.0; 
	0.0 0.0 1.0 0.0 0.0; 
	0.0 0.0 1.0 0.0 0.0; 
	0.0 0.0 1.0 0.0 0.0; 
	1.0 0.0 0.0 0.0 0.0; 
	0.0 0.0 0.0 1.0 0.0; 
	0.0 0.0 0.0 1.0 0.0; 
	0.0 0.0 0.0 1.0 0.0; 
	0.0 0.0 0.0 1.0 0.0; 
	1.0 0.0 0.0 0.0 0.0; 
	0.0 0.0 0.0 0.0 1.0; 
	0.0 0.0 0.0 0.0 1.0; 
	0.0 0.0 0.0 0.0 1.0; 
	0.0 0.0 0.0 0.0 1.0;
]

stoich = [
	1.0 -1.0 -1.0 -1.0 -1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0; 
	0.0 0.0 1.0 0.0 0.0 1.0 -1.0 -1.0 -1.0 -1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0; 
	0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 -1.0 -1.0 -1.0 -1.0 0.0 0.0 0.0 0.0 1.0; 
	0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 1.0 -1.0 -1.0 -1.0 -1.0;
]

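function dudt_mat(u, ξ, k)
	# Prepend a row of ones so that zeroth-order rate laws are multiplied by 1.
	# `u` holds one observation per column.
	oneplusu = vcat(ones(eltype(u), 1, size(u, 2)), u)
	# `nrate` is assumed to be defined elsewhere; for `ξmat * k` to be
	# conformable it must equal size(k, 1) (3 in the example above).
	ξmat = transpose(reshape(ξ, nrate, :))
	# Each row is one rate law: (rate coefficient) .* (state it depends on, or 1).
	ratelaws = (ξmat * k) .* (umask * oneplusu)
	stoich * ratelaws
end

function loss_mat(u, ξ, k, dudt_target)
	# Mean squared error between predicted and observed derivatives.
	mean((dudt_mat(u, ξ, k) .- dudt_target).^2)
end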

let
	f(ξ) = loss_mat(x, ξ, k, dx)
	@btime value_and_gradient(f, AutoZygote(), $ξ) # 41.555 ms (1227564 allocations: 51.30 MiB)
end

Thanks for your help!

Interestingly, the speed of this last version is almost identical to the first version, but the execution and compile times scale much better with increased numbers of parameters and state variables.
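
For reference, my reading of the matrix version in formula form (the symbols S for stoich, M for umask, \Xi for the reshaped \xi, K for k and U for the state observations are names introduced here for illustration, and \odot is the elementwise product):

\dot{U} = S \left[ (\Xi K) \odot \left( M \begin{bmatrix} \mathbf{1}^\top \\ U \end{bmatrix} \right) \right]

Here \Xi K gives the rate coefficient of each rate law for each observation, M picks the state variable (or the constant 1) that each rate law is proportional to, and S sums the resulting rates into the derivative of each state variable.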

I think it might scale even better if you make umask and stoich

  • sparse matrices with SparseArrays.sparse
  • constant with const
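
A minimal sketch of both suggestions on a toy system (2 state variables, 3 rate laws). The matrices `UMASK` and `STOICH`, the function `dudt_sparse`, and all numbers are hypothetical stand-ins, much smaller than the ones posted above:

```julia
using SparseArrays

# Toy stand-ins (assumptions, not the matrices from the thread).
# Rate 1 is zeroth order, rate 2 is proportional to u[1], rate 3 to u[2].
const UMASK = sparse([1.0 0.0 0.0;
                      0.0 1.0 0.0;
                      0.0 0.0 1.0])
# Rate 1 feeds state 1, rate 2 moves state 1 to state 2, rate 3 drains state 2.
const STOICH = sparse([1.0 -1.0  0.0;
                       0.0  1.0 -1.0])

function dudt_sparse(u::AbstractMatrix, ξmat::AbstractMatrix, k::AbstractMatrix)
    # Prepend a row of ones so zeroth-order rate laws multiply by 1.
    oneplusu = vcat(ones(eltype(u), 1, size(u, 2)), u)
    ratelaws = (ξmat * k) .* (UMASK * oneplusu)
    return STOICH * ratelaws
end

u = [2.0 4.0; 3.0 5.0]                   # 2 states × 2 observations
k = ones(1, 2)                           # 1 known parameter × 2 observations
ξmat = reshape([1.0, 0.5, 0.25], 3, 1)   # 3 rate laws × 1 known parameter
du = dudt_sparse(u, ξmat, k)
```

Declaring the matrices `const` lets the compiler infer their types inside the function. Whether the sparse storage actually helps depends on the problem size: for matrices as small as the 20×5 and 4×20 ones above, dense multiplication may be just as fast, so it is worth benchmarking both.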

We’ll try that, thanks! (The actual implementation is type stable, but we haven’t tried sparse matrices.)