# DataInterpolations.jl and Zygote.jl

Hi everyone, I am trying to differentiate through an interpolation (as offered by `DataInterpolations.jl`).
Here is an MWE:

``````
x = Array(LinRange(0,10,101))
a = 1
b = 2
c = 3
p(a,b,c) = a.+b.*x.+c.*x.^2

function Q(a,b,c,x)
    akima = AkimaInterpolation(p(a,b,c), x)
    return akima.(x)
end

q(a,b,c) = Q(a,b,c,x)
``````

The first call to gradient works, the second one does not.

``````
Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/latest/limitations
``````

So, what are the possible fixes? Am I doing something wrong?
I am considering writing a small routine to implement a quadratic interpolation from scratch, but maybe this can be solved in another way…?
Thanks to everyone,
Marco

You mean the gradient of the gradient, so the second derivative call? You didn't include the code that errors.

Hi @ChrisRackauckas ,
I was a bit imprecise.
The first call to gradient (the one differentiating the `p` function I defined) works as expected.
The second call to gradient (the one differentiating the `q` function I defined) does not work.

``````
using Zygote, DataInterpolations
x = Array(LinRange(0,10,101))
a = 1
b = 2
c = 3
p(a,b,c) = a.+b.*x.+c.*x.^2

function Q(a,b,c,x)
    akima = AkimaInterpolation(p(a,b,c), x)
    return akima.(x)
end
end

q(a,b,c) = Q(a,b,c,x)
``````

In this MWE, the call `gradient(b -> sum(q(a,b,c)), b)` raises the error (full stack trace below).

Stacktrace
``````
julia> gradient(b -> sum(q(a,b,c)), b)
ERROR: Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] _throw_mutation_error(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/array.jl:70
[3] (::Zygote.var"#539#540"{Vector{Float64}})(::Nothing)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/array.jl:82
[4] (::Zygote.var"#2623#back#541"{Zygote.var"#539#540"{Vector{Float64}}})(Δ::Nothing)
[5] #AkimaInterpolation#5
@ ~/.julia/packages/DataInterpolations/Pz5Mr/src/interpolation_caches.jl:159 [inlined]
[6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[7] AkimaInterpolation
@ ~/.julia/packages/DataInterpolations/Pz5Mr/src/interpolation_caches.jl:142 [inlined]
[8] (::Zygote.Pullback{Tuple{Type{…}, Vector{…}, Vector{…}}, Tuple{Zygote.Pullback{…}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[9] Q
@ ./REPL[7]:2 [inlined]
[10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[11] q
@ ./REPL[8]:1 [inlined]
[12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[13] #1
@ ./REPL[9]:1 [inlined]
[14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[15] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:97
[17] top-level scope
@ REPL[9]:1
Some type information was truncated. Use `show(err)` to see complete types.

``````

Thank you again!

It currently only has the adjoints to differentiate w.r.t. t

Ok!
So, two follow-up questions:

1. Would there be interest in a PR adding adjoints wrt `u` and `A` to DataInterpolations?
2. If so, do you know where I could find the rules? I could then try to implement them myself.

Thanks,
Marco

Yes

You'd have to derive it. It's just calculus on the interpolation functions so it's not hard but someone has to do it to avoid the mutation.

Or Enzyme should work.

Ok, thank you!
I will go through the math.
I started going through the code, as you said it shouldn't be that difficult.
For instance, I started from `QuadraticSpline` (the one I actually need).

I think we need some additional rules, for stuff like `Tridiagonal`

``````
Zygote.@adjoint function Tridiagonal(dl, d, du)
    y = Tridiagonal(dl, d, du)
    function Tridiagonal_pullback(ȳ)
        ∂dl = @thunk(Array(diag(ȳ, -1)))
        ∂d  = @thunk(Array(diag(ȳ, 0)))
        ∂du = @thunk(Array(diag(ȳ, +1)))
        return (∂dl, ∂d, ∂du)
    end
    return y, Tridiagonal_pullback
end
``````

Curiously, when I tried to define the rule with `ChainRules`

``````
function rrule(::typeof(Tridiagonal), dl, d, du)
    y = Tridiagonal(dl, d, du)
    project_dl = ProjectTo(dl)
    project_d = ProjectTo(d)
    project_du = ProjectTo(du)

    function Tridiagonal_pullback(ȳ)
        ∂dl = (diag(ȳ, -1))
        ∂d  = (diag(ȳ, 0))
        ∂du = (diag(ȳ, +1))
        return NoTangent(), project_dl(∂dl), project_d(∂d), project_du(∂du)
    end
    return y, Tridiagonal_pullback
end
``````

It did not work (it gave the same error as before; maybe there is a rule that takes precedence when using `Zygote`?).
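One possible explanation, offered as a guess rather than a diagnosis: Zygote looks up methods of `ChainRulesCore.rrule`, so a bare `function rrule(...)` typed at the REPL may end up as a separate function that Zygote never consults. Below is a minimal sketch of the qualified form. Dispatching on `::Type{Tridiagonal}` (the constructor called without type parameters), skipping `ProjectTo`, and treating the cotangent as an ordinary matrix are all assumptions made for this illustration. The rule is called by hand here, not through Zygote.

```julia
using LinearAlgebra
using ChainRulesCore

# The `ChainRulesCore.` prefix adds a method to the function that AD packages
# look up, instead of creating an unrelated `rrule` in `Main`.
function ChainRulesCore.rrule(::Type{Tridiagonal}, dl::AbstractVector,
                              d::AbstractVector, du::AbstractVector)
    y = Tridiagonal(dl, d, du)
    function Tridiagonal_pullback(ybar)
        Y = unthunk(ybar)  # assumption: the cotangent is an ordinary matrix
        return NoTangent(), diag(Y, -1), diag(Y, 0), diag(Y, 1)
    end
    return y, Tridiagonal_pullback
end

# Call the rule directly; ones(4, 4) is the cotangent of sum(Tridiagonal(dl, d, du)).
dl, d, du = rand(3), rand(4), rand(3)
y, back = ChainRulesCore.rrule(Tridiagonal, dl, d, du)
_, ddl, dd, ddu = back(ones(4, 4))
```

Each band of the cotangent matrix is simply read back into the corresponding input vector, which is all the constructor's pullback needs to do.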

@ChrisRackauckas , if you were willing to guide me a bit, how should we proceed?
I think I would start by understanding which rules need to be created and do some local checks. After that, once I have something that works locally on my laptop, I would open a PR. Does that make sense to you?

Edit: I was likely wrong, the Tridiagonal adjoint is actually not needed.

The rrules that are needed are the ones on the constructors for the interpolation, which need a derivation of the derivative of the spline coefficients with respect to the data values.
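For the quadratic spline this derivation is short, if one assumes the coefficients come from the bidiagonal solve `z = tA \ d` used in the `my_spline` rewrite later in this thread (whether the package constructor matches this exactly is an assumption here). Since `tA` is constant and `d` is linear in `u`, the pullback is a transposed solve followed by a scatter. The names `spline_z` and `spline_z_pullback` are made up for this sketch.

```julia
using LinearAlgebra

# Forward pass: coefficients z of the quadratic spline from data u on knots t.
function spline_z(u, t)
    s = length(t)
    d = [i == 1 ? zero(eltype(u)) : 2 * (u[i] - u[i-1]) / (t[i] - t[i-1]) for i in 1:s]
    tA = Tridiagonal(ones(s - 1), ones(s), zeros(s - 1))
    return tA \ d
end

# Pullback w.r.t. u: transposed solve, then scatter through d's dependence on u.
function spline_z_pullback(zbar, u, t)
    s = length(t)
    tAt = Tridiagonal(zeros(s - 1), ones(s), ones(s - 1))  # transpose of tA
    dbar = tAt \ zbar
    ubar = zeros(s)
    for i in 2:s
        g = 2 * dbar[i] / (t[i] - t[i-1])
        ubar[i] += g       # d[i] depends on u[i] with weight  2 / h
        ubar[i-1] -= g     # d[i] depends on u[i-1] with weight -2 / h
    end
    return ubar
end

# Compare with central finite differences on a small, arbitrary example.
t = [0.0, 0.3, 0.7, 1.0]
u = [1.0, 2.0, 0.5, 3.0]
w = [0.2, -1.0, 0.5, 2.0]  # arbitrary cotangent
loss(v) = dot(w, spline_z(v, t))
onehot(k) = [i == k ? 1.0 : 0.0 for i in 1:length(u)]
fd = [(loss(u + 1e-6 * onehot(k)) - loss(u - 1e-6 * onehot(k))) / 2e-6 for k in 1:length(u)]
agrees = isapprox(spline_z_pullback(w, u, t), fd; rtol = 1e-6)
```

No array is mutated on a path that an AD tool has to trace, because the whole derivative is written out by hand.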

But did you check if Enzyme just works?

Hi @ChrisRackauckas ,
no, I have to check whether Enzyme will work or not.
Regarding Zygote, I had to add the rule for the tridiagonal matrix, and later it was able to perform the differentiation by itself. However, adding rules to other pieces of the code (e.g. the constructor of the spline) improved performance.

# A MWE

I want to be able to differentiate something like this using `Zygote`

``````
#random data creation
n = 64
x = vcat([0.], sort(rand(n-2)), [1.])
x1 = vcat([0.], sort(2*rand(10n-2)), [2.])
y = rand(n)
#interpolates original data (y evaluated on x) on a new x1
function di_spline(y,x,xn)
    spline = QuadraticSpline(y,x, extrapolate = true)
    return spline.(xn)
end

``````

# Prerequisites: `Tridiagonal` rrule

In order to be able to differentiate, I had to implement the `Tridiagonal` rrule

rrule implementation and validation
``````
Zygote.@adjoint function Tridiagonal(dl, d, du)
    y = Tridiagonal(dl, d, du)
    function Tridiagonal_pullback(ȳ)
        ∂dl = @thunk(Array(diag(ȳ, -1)))
        ∂d  = @thunk(Array(diag(ȳ, 0)))
        ∂du = @thunk(Array(diag(ȳ, +1)))
        return (∂dl, ∂d, ∂du)
    end
    return y, Tridiagonal_pullback
end

function rrule(::typeof(Tridiagonal), dl, d, du)
    y = Tridiagonal(dl, d, du)
    function Tridiagonal_pullback(ȳ)
        ∂dl = @thunk(Array(diag(ȳ, -1)))
        ∂d  = @thunk(Array(diag(ȳ, 0)))
        ∂du = @thunk(Array(diag(ȳ, +1)))
        return (NoTangent(), ∂dl, ∂d, ∂du)
    end
    return y, Tridiagonal_pullback
end
``````

I checked the implemented rrule against `FiniteDifferences`

``````
d = rand(1024)
dl = rand(1023)
du = rand(1023)
FiniteDifferences.grad(central_fdm(5, 1), du -> sum(Tridiagonal(dl, d, du)), du)[1] ≈ gradient(du -> sum(Tridiagonal(dl, d, du)), du)[1] # true
FiniteDifferences.grad(central_fdm(5, 1), dl -> sum(Tridiagonal(dl, d, du)), dl)[1] ≈ gradient(dl -> sum(Tridiagonal(dl, d, du)), dl)[1] # true
FiniteDifferences.grad(central_fdm(5, 1), d -> sum(Tridiagonal(dl, d, du)), d)[1] ≈ gradient(d -> sum(Tridiagonal(dl, d, du)), d)[1] # true
``````

After implementing the `Tridiagonal` rrule, I am able to differentiate `di_spline` (and the result is consistent with `FiniteDifferences`). There is just one caveat: performance is terrible!

``````
@benchmark sum(di_spline($y,$x,$x1))
BenchmarkTools.Trial: 10000 samples with 4 evaluations.
Range (min … max):  7.237 μs … 998.080 μs  ┊ GC (min … max): 0.00% … 97.62%
Time  (median):     8.534 μs               ┊ GC (median):    0.00%
Time  (mean ± σ):   9.113 μs ±  10.412 μs  ┊ GC (mean ± σ):  1.70% ±  2.19%

7.24 μs      Histogram: log(frequency) by time      16.6 μs <

Memory estimate: 7.98 KiB, allocs estimate: 7.

BenchmarkTools.Trial: 133 samples with 1 evaluation.
Range (min … max):  35.348 ms … 41.547 ms  ┊ GC (min … max): 0.00% … 10.68%
Time  (median):     36.750 ms              ┊ GC (median):    0.00%
Time  (mean ± σ):   37.707 ms ±  1.963 ms  ┊ GC (mean ± σ):  3.73% ±  4.59%

35.3 ms         Histogram: frequency by time        41.5 ms <

Memory estimate: 18.85 MiB, allocs estimate: 319149.
``````

# Implementing rrules

In order to improve performance, I decided to implement the relevant rrules. Here comes the first issue: how to evaluate them? Although (I think) I mostly understand how to compute and implement them, I am not sure that in this case I can implement the rules in an efficient and clever way.
I hence followed the only approach I could think of: I rewrote `QuadraticSpline`, basically copy-pasting the original code without using structs and implementing some utility functions (the strategy is to write the rrule for each of the smaller functions).

My spline implementation
``````
function my_spline(u, t, new_t::AbstractArray)
    s = length(t)
    s_new = length(new_t)
    dl = ones(eltype(t), s - 1)
    d_tmp = ones(eltype(t), s)
    du = zeros(eltype(t), s - 1)
    tA = Tridiagonal(dl, d_tmp, du)

    # zero for element type of d, which we don't know yet
    typed_zero = zero(2 // 1 * (u[begin + 1] - u[begin]) / (t[begin + 1] - t[begin]))

    d = create_d(u, t, s, typed_zero) #map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)
    z = tA \ d
    i_list = create_i_list(t, new_t, s_new) #[min(max(2, FindFirstFunctions.searchsortedfirstcorrelated(t, new_t[i], firstindex(t) - 1)), length(t)) for i in 1:s_new]
    Cᵢ_list = create_Cᵢ_list(u, i_list) #[u[i - 1] for i in i_list]
    σ = create_σ(z, t, i_list) #[1 // 2 * (z[i] - z[i - 1]) / (t[i] - t[i - 1]) for i in i_list]
    return compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ) #[z[i_list[i] - 1] * (new_t[i] - t[i_list[i] - 1]) + σ[i] * (new_t[i] - t[i_list[i] - 1])^2 + Cᵢ_list[i] for i in 1:s_new]
end

compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ) = map(i -> z[i_list[i] - 1] * (new_t[i] - t[i_list[i] - 1]) + σ[i] * (new_t[i] - t[i_list[i] - 1])^2 + Cᵢ_list[i], 1:s_new)
create_σ(z, t, i_list) = map(i -> 1 / 2 * (z[i] - z[i - 1]) / (t[i] - t[i - 1]),  i_list)
create_Cᵢ_list(u, i_list) = map(i-> u[i - 1],  i_list)
create_i_list(t, new_t, s_new) = map(i-> min(max(2, FindFirstFunctions.searchsortedfirstcorrelated(t, new_t[i], firstindex(t) - 1)), length(t)),  1:s_new)
create_d(u, t, s, typed_zero) = map(i -> i == 1 ? typed_zero : 2 / 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)
``````

I explicitly checked that it is equivalent to the original one.
After this I wrote the rrules (as before, I checked their correctness).

my_spline rrules
``````
Zygote.@adjoint function create_d(u, t, s, typed_zero)
    y = create_d(u, t, s, typed_zero)
    function create_d_pullback(ȳ)
        ∂u = Tridiagonal(zeros(eltype(typed_zero), s-1),
                         map(i -> i == 1 ? typed_zero : 2 / (t[i] - t[i - 1]), 1:s),
                         map(i -> - 2 / (t[i+1] - t[i]), 1:s-1)) * ȳ
        ∂t = Tridiagonal(zeros(eltype(typed_zero), s-1),
                         map(i -> i == 1 ? typed_zero : -2 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]) ^ 2, 1:s),
                         map(i -> 2 * (u[i+1] - u[i]) / (t[i+1] - t[i]) ^ 2, 1:s-1)) * ȳ
        return (∂u, ∂t, NoTangent(), NoTangent())
    end
    return y, create_d_pullback
end

Zygote.@adjoint function create_σ(z, x, i_list)
    y = create_σ(z, x, i_list)
    function create_σ_pullback(ȳ)
        s = length(z)
        s1 = length(i_list)
        ∂z = zeros(s,s1)
        ∂x = zeros(s,s1)

        for j in 1:s1
            i = i_list[j]
            a = @views (z[i] - z[i-1])
            b = @views (x[i] - x[i-1])
            ∂z[i,j] += 0.5 / b
            ∂z[i-1,j] -= 0.5 / b
            ∂x[i,j] -= 0.5 * a / b^2
            ∂x[i-1,j] += 0.5 * a / b^2
        end

        ∂z = ∂z * ȳ
        ∂x = ∂x * ȳ
        return (∂z, ∂x, NoTangent())
    end
    return y, create_σ_pullback
end #works, but performance can be a bit improved

Zygote.@adjoint function create_i_list(t, new_t, s_new)
    y = create_i_list(t, new_t, s_new)
    function create_i_list_pullback(ȳ)
        return (NoTangent(), NoTangent(), NoTangent())
    end
    return y, create_i_list_pullback
end
``````

# The final result

So, was it worth it? Performance definitely improved

``````
@benchmark gradient($y->sum(my_spline($y,$x,$x)), $y)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max):  46.449 μs … 708.213 μs  ┊ GC (min … max): 0.00% … 84.52%
Time  (median):     52.767 μs               ┊ GC (median):    0.00%
Time  (mean ± σ):   58.348 μs ±  47.454 μs  ┊ GC (mean ± σ):  7.39% ±  8.16%

46.4 μs       Histogram: log(frequency) by time       125 μs <

Memory estimate: 549.17 KiB, allocs estimate: 706.

``````

To be compared with the initial

``````
@benchmark gradient($y->sum(di_spline($y,$x,$x)), $y)
BenchmarkTools.Trial: 1119 samples with 1 evaluation.
Range (min … max):  3.770 ms …   8.495 ms  ┊ GC (min … max): 0.00% … 46.86%
Time  (median):     4.324 ms               ┊ GC (median):    0.00%
Time  (mean ± σ):   4.466 ms ± 724.365 μs  ┊ GC (mean ± σ):  3.15% ±  8.69%

3.77 ms      Histogram: log(frequency) by time      8.15 ms <

Memory estimate: 2.28 MiB, allocs estimate: 38332.
``````

A factor of 80 improvement! On the precision side, the two implementations are mostly equivalent (I checked explicitly that they give almost the same result and compared both with `FiniteDifferences`).

# So, what?

Now, the question is: what to do? For sure, for my project, I am happy with my spline implementation. However, I am sure that performance can still be improved (one of the rules I implemented does not take advantage of sparsity) and maybe (likely?) this is not even the smartest way to implement such a thing. @ChrisRackauckas , if there is any advice coming on how to make it better and integrate it into `DataInterpolations`, I would be happy to do it.
Maybe there is a smart Implicit Differentiation trick I missed (@gdalle )?

Edit : I still have to test Enzyme, as I am mostly using Zygote, and I am not too familiar with it.

Sorry, I did not notice I had to update my env. Now it works.
There is still the issue with the differentiation of the other pieces of code. Any suggestion on how to proceed?

Any MWE on that? I'm losing track of what's being asked.

Sure!

``````
using Zygote
using ForwardDiff
using DiffRules
using BenchmarkTools
using ChainRulesCore
const RealOrComplex = Union{Real,Complex}
using DataInterpolations
using ChainRules
using LinearAlgebra
using FindFirstFunctions
using SparseArrays
using FiniteDifferences

n = 64
x = vcat([0.], sort(rand(n-2)), [1.])
x1 = vcat([0.], sort(rand(n-2)), [1.])
y = rand(n);

function di_spline(y,x,xn)
    spline = QuadraticSpline(y,x, extrapolate = true)
    return spline.(xn)
end

b1 = @benchmark sum(di_spline($y,$x,$x1))

function my_spline(u, t, new_t::AbstractArray)
    s = length(t)
    s_new = length(new_t)
    dl = ones(eltype(t), s - 1)
    d_tmp = ones(eltype(t), s)
    du = zeros(eltype(t), s - 1)
    tA = Tridiagonal(dl, d_tmp, du)

    # zero for element type of d, which we don't know yet
    typed_zero = zero(2 // 1 * (u[begin + 1] - u[begin]) / (t[begin + 1] - t[begin]))

    d = create_d(u, t, s, typed_zero) #map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)
    z = tA \ d
    i_list = create_i_list(t, new_t, s_new) #[min(max(2, FindFirstFunctions.searchsortedfirstcorrelated(t, new_t[i], firstindex(t) - 1)), length(t)) for i in 1:s_new]
    Cᵢ_list = create_Cᵢ_list(u, i_list) #[u[i - 1] for i in i_list]
    σ = create_σ(z, t, i_list) #[1 // 2 * (z[i] - z[i - 1]) / (t[i] - t[i - 1]) for i in i_list]
    return compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ) #[z[i_list[i] - 1] * (new_t[i] - t[i_list[i] - 1]) + σ[i] * (new_t[i] - t[i_list[i] - 1])^2 + Cᵢ_list[i] for i in 1:s_new]
end

compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ) = map(i -> z[i_list[i] - 1] * (new_t[i] - t[i_list[i] - 1]) + σ[i] * (new_t[i] - t[i_list[i] - 1])^2 + Cᵢ_list[i], 1:s_new)
create_σ(z, t, i_list) = map(i -> 1 / 2 * (z[i] - z[i - 1]) / (t[i] - t[i - 1]),  i_list)
create_Cᵢ_list(u, i_list) = map(i-> u[i - 1],  i_list)
create_i_list(t, new_t, s_new) = map(i-> min(max(2, FindFirstFunctions.searchsortedfirstcorrelated(t, new_t[i], firstindex(t) - 1)), length(t)),  1:s_new)
create_d(u, t, s, typed_zero) = map(i -> i == 1 ? typed_zero : 2 / 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)

b3 = @benchmark sum(my_spline($y,$x,$x1))

Zygote.@adjoint function create_d(u, t, s, typed_zero)
    y = create_d(u, t, s, typed_zero)
    function create_d_pullback(ȳ)
        ∂u = Tridiagonal(zeros(eltype(typed_zero), s-1),
                         map(i -> i == 1 ? typed_zero : 2 / (t[i] - t[i - 1]), 1:s),
                         map(i -> - 2 / (t[i+1] - t[i]), 1:s-1)) * ȳ
        ∂t = Tridiagonal(zeros(eltype(typed_zero), s-1),
                         map(i -> i == 1 ? typed_zero : -2 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]) ^ 2, 1:s),
                         map(i -> 2 * (u[i+1] - u[i]) / (t[i+1] - t[i]) ^ 2, 1:s-1)) * ȳ
        return (∂u, ∂t, NoTangent(), NoTangent())
    end
    return y, create_d_pullback
end

Zygote.@adjoint function create_σ(z, x, i_list)
    y = create_σ(z, x, i_list)
    function create_σ_pullback(ȳ)
        s = length(z)
        s1 = length(i_list)
        ∂z = zeros(s,s1)
        ∂x = zeros(s,s1)

        for j in 1:s1
            i = i_list[j]
            a = @views (z[i] - z[i-1])
            b = @views (x[i] - x[i-1])
            ∂z[i,j] = 0.5 / b
            ∂z[i-1,j] = -0.5 / b
            ∂x[i,j] = -0.5 * a / b^2
            ∂x[i-1,j] = 0.5 * a / b^2
        end

        ∂z = ∂z * ȳ
        ∂x = ∂x * ȳ
        return (∂z, ∂x, NoTangent())
    end
    return y, create_σ_pullback
end

Zygote.@adjoint function create_i_list(t, new_t, s_new)
    y = create_i_list(t, new_t, s_new)
    function create_i_list_pullback(ȳ)
        return (NoTangent(), NoTangent(), NoTangent())
    end
    return y, create_i_list_pullback
end

``````

``````
BenchmarkTools.Trial: 1280 samples with 1 evaluation.
Range (min … max):  3.600 ms …   6.411 ms  ┊ GC (min … max): 0.00% … 30.68%
Time  (median):     3.721 ms               ┊ GC (median):    0.00%
Time  (mean ± σ):   3.907 ms ± 513.169 μs  ┊ GC (mean ± σ):  2.42% ±  7.02%

3.6 ms       Histogram: log(frequency) by time       5.9 ms <

Memory estimate: 2.28 MiB, allocs estimate: 38307.
``````

``````
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max):  44.632 μs … 702.216 μs  ┊ GC (min … max): 0.00% … 75.22%
Time  (median):     56.627 μs               ┊ GC (median):    0.00%
Time  (mean ± σ):   65.096 μs ±  34.856 μs  ┊ GC (mean ± σ):  4.53% ±  8.01%

44.6 μs         Histogram: frequency by time          265 μs <

Memory estimate: 550.83 KiB, allocs estimate: 707.
``````

So, the question is: how to write the rules for the `DataInterpolations` spline without splitting its functions into a composition of smaller functions, as I did?

You can write the rule on the `QuadraticSpline` constructor itself, and on its calls to interpolation.

The constructor should not be a problem, but I don't know how to evaluate the adjoint of

``````
function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
    idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first)
    Cᵢ = A.u[idx - 1]
    σ = 1 // 2 * (A.z[idx] - A.z[idx - 1]) / (A.t[idx] - A.t[idx - 1])
    return A.z[idx - 1] * (t - A.t[idx - 1]) + σ * (t - A.t[idx - 1])^2 + Cᵢ, idx
end
``````

without dividing it into smaller chunks. I know, it's a limitation of mine.
Edit: probably a solution could be to compose together the functions I wrote before, but I am not sure this is the best approach.
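One way to avoid the split might be to differentiate the closed-form expression directly: for a fixed `idx`, the value depends only on `u[idx-1]`, `z[idx-1]`, `z[idx]`, the query point and the two neighbouring knots. The sketch below does this for a scalar query with the knots held fixed. The names `eval_with_partials` and `accumulate_pullback!` are hypothetical, and `clamp(searchsortedfirst(...), 2, n)` is only a stand-in for `get_idx`, not the package's implementation.

```julia
# Value of the quadratic-spline piece at a scalar query point `tq`, together
# with its nonzero partial derivatives w.r.t. u, z and tq (knots t held fixed).
function eval_with_partials(u, z, t, tq)
    idx = clamp(searchsortedfirst(t, tq), 2, length(t))  # stand-in for get_idx
    h = t[idx] - t[idx-1]
    delta = tq - t[idx-1]
    sigma = (z[idx] - z[idx-1]) / (2 * h)
    val = z[idx-1] * delta + sigma * delta^2 + u[idx-1]
    du = (idx - 1 => 1.0,)                                  # d val / d u[idx-1]
    dz = (idx - 1 => delta - delta^2 / (2 * h),             # d val / d z[idx-1]
          idx => delta^2 / (2 * h))                         # d val / d z[idx]
    dtq = z[idx-1] + 2 * sigma * delta                      # d val / d tq
    return val, du, dz, dtq
end

# Accumulate one query point's contribution into the gradient buffers.
function accumulate_pullback!(ubar, zbar, ybar, du, dz)
    for (k, v) in du
        ubar[k] += ybar * v
    end
    for (k, v) in dz
        zbar[k] += ybar * v
    end
    return ubar, zbar
end

t = [0.0, 0.3, 0.7, 1.0]
u = [1.0, 2.0, 0.5, 3.0]
z = [0.0, 1.5, -2.0, 4.0]  # arbitrary coefficients, only for illustration
val, du, dz, dtq = eval_with_partials(u, z, t, 0.5)
ubar, zbar = accumulate_pullback!(zeros(4), zeros(4), 1.0, du, dz)
```

Because each query point touches at most three entries, a rule built this way never needs to materialise a dense Jacobian.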

This probably should be an issue.

Opened a new issue.

Lil' update.

``````
@benchmark gradient($y->sum(my_spline($y,$x,$x1)), $y)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max):  17.007 μs … 706.183 μs  ┊ GC (min … max): 0.00% … 87.69%
Time  (median):     19.779 μs               ┊ GC (median):    0.00%
Time  (mean ± σ):   22.133 μs ±  25.732 μs  ┊ GC (mean ± σ):  7.35% ±  6.13%

17 μs         Histogram: log(frequency) by time      36.2 μs <

Memory estimate: 210.53 KiB, allocs estimate: 123.
``````

A factor of three improvement, and the number of allocations also dropped significantly. I think the last piece now is taking full advantage of the sparse structure of the problem (right now I am allocating dense matrices, but most of their elements are zeros).

Implemented rules
``````
Zygote.@adjoint function create_Cᵢ_list(u, i_list)
    y = create_Cᵢ_list(u, i_list)
    function create_Cᵢ_list_pullback(ȳ)
        s = length(u)  # the gradient has one entry per element of u
        s1 = length(i_list)
        ∂Cᵢ_list = zeros(s,s1)

        for j in 1:s1
            i = i_list[j]
            ∂Cᵢ_list[i-1,j] = 1.
        end
        ∂Cᵢ_list = ∂Cᵢ_list * ȳ
        return (∂Cᵢ_list, NoTangent())
    end
    return y, create_Cᵢ_list_pullback
end

Zygote.@adjoint function compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ)
    y = compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ)
    function compose_pullback(ȳ)
        s = length(z)
        s1 = length(i_list)
        ∂z = zeros(s,s1)
        ∂t = zeros(s,s1)
        ∂t1 = zeros(s1,s1)

        for j in 1:s1
            i = i_list[j]
            ∂z[i-1,j] = new_t[j] - t[i_list[j] - 1]
            ∂t[i-1,j] = -z[i_list[j] - 1]  - 2σ[j] * (new_t[j] - t[i_list[j] - 1])
            ∂t1[j,j] = +z[i_list[j] - 1]  + 2σ[j] * (new_t[j] - t[i_list[j] - 1])
        end

        ∂z = ∂z * ȳ
        ∂t = ∂t * ȳ
        ∂t1 = ∂t1 * ȳ
        ∂σ = Diagonal(map(i -> (new_t[i] - t[i_list[i] - 1])^2, 1:s_new)) * ȳ
        ∂Cᵢ_list = Diagonal(ones(s1)) * ȳ
        return (∂z, ∂t, ∂t1, ∂Cᵢ_list, NoTangent(), NoTangent(), ∂σ)
    end
    return y, compose_pullback
end
``````
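On the remaining dense `zeros(s, s1)` matrices: each column holds at most two nonzeros, so one option is to skip the matrix and accumulate straight into the gradient vector. Below is a minimal sketch for the `z` part of the `create_σ` rule. Both function names are made up here, and the dense variant is included only so the two can be compared.

```julia
# Dense variant, mirroring the rule in the thread: build an s-by-s1 matrix, then multiply.
function sigma_pullback_dense(ybar, z, t, i_list)
    s, s1 = length(z), length(i_list)
    Jz = zeros(s, s1)
    for (j, i) in enumerate(i_list)
        b = t[i] - t[i-1]
        Jz[i, j] = 0.5 / b
        Jz[i-1, j] = -0.5 / b
    end
    return Jz * ybar
end

# Scatter variant: add the two nonzeros of each column directly into the result.
function sigma_pullback_scatter(ybar, z, t, i_list)
    zbar = zeros(length(z))
    for (j, i) in enumerate(i_list)
        g = 0.5 * ybar[j] / (t[i] - t[i-1])
        zbar[i] += g
        zbar[i-1] -= g
    end
    return zbar
end

t = [0.0, 0.3, 0.7, 1.0]
z = [0.0, 1.5, -2.0, 4.0]
i_list = [2, 2, 3, 4, 4]             # repeated indices are handled by the += accumulation
ybar = [1.0, -0.5, 2.0, 0.25, 3.0]
same = sigma_pullback_dense(ybar, z, t, i_list) ≈ sigma_pullback_scatter(ybar, z, t, i_list)
```

The scatter variant allocates a single length-`s` vector in place of an `s × s1` matrix. Whether that translates into a measurable speed-up at `n = 64` would have to be benchmarked.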