DataInterpolations.jl and Zygote.jl

Hi everyone, I am trying to differentiate through an interpolation (as offered by DataInterpolations.jl).
Here is an MWE:

using Zygote, DataInterpolations

x = Array(LinRange(0,10,101))
a = 1
b = 2
c = 3
p(a,b,c) = a.+b.*x.+c.*x.^2
gradient(b -> sum(p(a,b,c)), b)

function Q(a,b,c,x)
    akima = AkimaInterpolation(p(a,b,c), x)
    return akima.(x)
end

q(a,b,c) = Q(a,b,c,x)
gradient(b -> sum(q(a,b,c)), b)

The first call to gradient works; the second one does not:

Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

So, what are the possible fixes? Am I doing something wrong?
I am considering writing a small routine that implements quadratic interpolation from scratch, but maybe this can be solved another way?
Thanks to everyone,
Marco


You mean the gradient of the gradient, so the second derivative call? You didn’t include the code that errors.

Hi @ChrisRackauckas ,
thank you for your quick answer.
I was a bit imprecise.
The first call to gradient (the one differentiating the p function I defined) works as expected.
The second call to gradient (the one differentiating the q function I defined) does not work.

using Zygote, DataInterpolations
x = Array(LinRange(0,10,101))
a = 1
b = 2
c = 3
p(a,b,c) = a.+b.*x.+c.*x.^2

function Q(a,b,c,x)
    akima = AkimaInterpolation(p(a,b,c), x)
    return akima.(x)
end

q(a,b,c) = Q(a,b,c,x)
gradient(b -> sum(q(a,b,c)), b)

In this MWE, the last line raises the error (full stack trace below).

Stacktrace
julia> gradient(b -> sum(q(a,b,c)), b)
ERROR: Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/array.jl:70
  [3] (::Zygote.var"#539#540"{Vector{Float64}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/array.jl:82
  [4] (::Zygote.var"#2623#back#541"{Zygote.var"#539#540"{Vector{Float64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71
  [5] #AkimaInterpolation#5
    @ ~/.julia/packages/DataInterpolations/Pz5Mr/src/interpolation_caches.jl:159 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [7] AkimaInterpolation
    @ ~/.julia/packages/DataInterpolations/Pz5Mr/src/interpolation_caches.jl:142 [inlined]
  [8] (::Zygote.Pullback{Tuple{Type{…}, Vector{…}, Vector{…}}, Tuple{Zygote.Pullback{…}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [9] Q
    @ ./REPL[7]:2 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [11] q
    @ ./REPL[8]:1 [inlined]
 [12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [13] #1
    @ ./REPL[9]:1 [inlined]
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [15] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
 [16] gradient(::Function, ::Int64, ::Vararg{Int64})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:97
 [17] top-level scope
    @ REPL[9]:1
Some type information was truncated. Use `show(err)` to see complete types.

Thank you again!

It currently only has the adjoints to differentiate w.r.t. t


Ok!
So, two follow-up questions:

  1. Would there be interest in a PR adding adjoints w.r.t. u and A to DataInterpolations?
  2. If so, do you know where I could find the rules? I could then try to implement them myself.

Thanks,
Marco

Yes

You’d have to derive it. It’s just calculus on the interpolation functions, so it’s not hard, but someone has to do it to avoid the mutation.

Or Enzyme should work.


Ok, thank you!
I will go through the math.
I started by going through the code; as you said, it shouldn’t be that difficult.
I began with QuadraticSpline (the one I actually need).

I think we need some additional rules, e.g. for Tridiagonal:

Tridiagonal adjoint
Zygote.@adjoint function Tridiagonal(dl, d, du)
    y = Tridiagonal(dl, d, du)
    function Tridiagonal_pullback(ȳ)
        ∂dl = @thunk(Array(diag(ȳ, -1)))
        ∂d  = @thunk(Array(diag(ȳ, 0)))
        ∂du = @thunk(Array(diag(ȳ, +1)))
        return (∂dl, ∂d, ∂du)
    end
    return y, Tridiagonal_pullback
end
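Because the Tridiagonal constructor is linear in (dl, d, du), the diag calls in this pullback are exactly its transpose. A stdlib-only sketch that checks this with a dot-product (adjoint) test, independent of any AD package; tridiagonal_pullback is a hypothetical name, not part of Zygote or ChainRules:

```julia
using LinearAlgebra

# Hypothetical helper: pullback of (dl, d, du) -> Tridiagonal(dl, d, du).
# The constructor is linear, so each argument's cotangent is the matching diagonal of ybar.
tridiagonal_pullback(ybar::AbstractMatrix) = (diag(ybar, -1), diag(ybar, 0), diag(ybar, 1))

n = 6
dl, d, du = rand(n - 1), rand(n), rand(n - 1)   # tangent directions for the three diagonals
ybar = rand(n, n)                               # an arbitrary dense cotangent

# Dot-product test: <ybar, J * (dl, d, du)> must equal <J' * ybar, (dl, d, du)>.
# Since the map is linear, J * (dl, d, du) is just Tridiagonal(dl, d, du).
lhs = dot(ybar, Matrix(Tridiagonal(dl, d, du)))
bar_dl, bar_d, bar_du = tridiagonal_pullback(ybar)
rhs = dot(bar_dl, dl) + dot(bar_d, d) + dot(bar_du, du)
println(isapprox(lhs, rhs))
```

The same test works for any hand-written linear pullback, which makes it a cheap sanity check before comparing against FiniteDifferences.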

Curiously, when I tried to define the rule with ChainRules:

function rrule(::typeof(Tridiagonal), dl, d, du)
    y = Tridiagonal(dl, d, du)
    project_dl = ProjectTo(dl)
    project_d = ProjectTo(d)
    project_du = ProjectTo(du)
    
    function Tridiagonal_pullback(ȳ)
        ∂dl = (diag(ȳ, -1))
        ∂d  = (diag(ȳ, 0))
        ∂du = (diag(ȳ, +1))
        return NoTangent(), project_dl(∂dl), project_d(∂d), project_du(∂du)
    end
    return y, Tridiagonal_pullback
end

It did not work (it gave the same error as before; maybe there is a rule that takes precedence when using Zygote?)
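A likely reason the rrule above is ignored: in Julia, writing function rrule(...) without a module qualifier does not add a method to ChainRulesCore.rrule. Depending on the Julia version and on whether rrule was already referenced, it either creates a new, unrelated function in your own module or errors asking for an explicit import. Either way Zygote never sees it. Writing function ChainRulesCore.rrule(::typeof(Tridiagonal), dl, d, du) (or doing import ChainRulesCore: rrule first) avoids this. The stdlib-only sketch below shows the mechanism with a stand-in module, FakeRules, which is hypothetical and only plays the role of ChainRulesCore:

```julia
# Stand-in for ChainRulesCore: a module that owns a generic `rrule` with a fallback.
module FakeRules
rrule(f, args...) = nothing   # fallback meaning "no rule defined"
end

module WrongWay
import ..FakeRules
# No qualifier: this creates WrongWay.rrule, a new function unrelated to FakeRules.rrule.
rrule(::typeof(sin), x) = (sin(x), ybar -> (nothing, ybar * cos(x)))
end

module RightWay
import ..FakeRules
# Qualified: this adds a method to FakeRules.rrule itself.
FakeRules.rrule(::typeof(cos), x) = (cos(x), ybar -> (nothing, -ybar * sin(x)))
end

println(WrongWay.rrule === FakeRules.rrule)     # false: two different functions
println(FakeRules.rrule(sin, 1.0) === nothing)  # true: the WrongWay method is invisible
println(FakeRules.rrule(cos, 1.0) !== nothing)  # true: RightWay extended the owner
```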

@ChrisRackauckas, if you are willing to guide me a bit, how should we proceed?
I would start by working out which rules need to be created and checking them locally. Once I have something that works on my laptop, I would open a PR. Does that make sense to you?

Edit: I was likely wrong; the Tridiagonal adjoint is actually not needed.

The rrules that are needed are the ones on the constructors for the interpolation, which need a derivation of the derivative of the spline coefficients with respect to the data values.
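For the quadratic spline this derivative can be written out by hand. Assuming the coefficient system used in the my_spline rewrite later in this thread (which copies the QuadraticSpline constructor: ones on the diagonal and subdiagonal, zeros on the superdiagonal), the coefficients z solve A z = d with

$$
d_1 = 0, \qquad d_i = \frac{2\,(u_i - u_{i-1})}{t_i - t_{i-1}} \ (i \ge 2), \qquad z_1 = d_1, \qquad z_{i-1} + z_i = d_i \ (i \ge 2).
$$

Because A does not depend on u or t, the pullback needs one transposed bidiagonal solve followed by the transpose of the (bidiagonal) Jacobian of d:

$$
\frac{\partial d_i}{\partial u_i} = \frac{2}{t_i - t_{i-1}}, \qquad
\frac{\partial d_i}{\partial u_{i-1}} = -\frac{2}{t_i - t_{i-1}}, \qquad
\bar u = \Big(\frac{\partial d}{\partial u}\Big)^{\top} A^{-\top} \bar z .
$$

The knots t enter both through d and directly through the evaluation formula, so their cotangent collects contributions from both places.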

But did you check if Enzyme just works?


Hi @ChrisRackauckas ,
no, I still have to check whether Enzyme works.
Regarding Zygote, I had to add the rule for the tridiagonal matrix, after which it was able to perform the differentiation by itself. However, adding rules for other pieces of the code (e.g. the spline constructor) improved performance.

A follow-up on this thread. It will be a loooong one.

An MWE

I want to be able to differentiate something like this using Zygote:

using Zygote, DataInterpolations

# random data creation
n = 64
x = vcat([0.], sort(rand(n-2)), [1.])
x1 = vcat([0.], sort(2*rand(10n-2)), [2.])
y = rand(n)
#interpolates original data (y evaluated on x) on a new x1
function di_spline(y,x,xn)
    spline = QuadraticSpline(y,x, extrapolate = true)
    return spline.(xn)
end

gradient(y->sum(di_spline(y,x,x1)), y)#does not work

Prerequisites: Tridiagonal rrule

To be able to differentiate, I had to implement the Tridiagonal rrule:

rrule implementation and validation
using LinearAlgebra, Zygote, ChainRulesCore

Zygote.@adjoint function Tridiagonal(dl, d, du)
    y = Tridiagonal(dl, d, du)
    function Tridiagonal_pullback(ȳ)
        ∂dl = @thunk(Array(diag(ȳ, -1)))
        ∂d  = @thunk(Array(diag(ȳ, 0)))
        ∂du = @thunk(Array(diag(ȳ, +1)))
        return (∂dl, ∂d, ∂du)
    end
    return y, Tridiagonal_pullback
end

# the ChainRules version must extend ChainRulesCore.rrule, not define a new rrule
function ChainRulesCore.rrule(::typeof(Tridiagonal), dl, d, du)
    Ω = Tridiagonal(dl, d, du)
    function Tridiagonal_pullback(ΔΩ)
        Δ = unthunk(ΔΩ)  # the incoming cotangent may be lazy
        ∂dl = @thunk(Array(diag(Δ, -1)))
        ∂d  = @thunk(Array(diag(Δ, 0)))
        ∂du = @thunk(Array(diag(Δ, +1)))
        return (NoTangent(), ∂dl, ∂d, ∂du)
    end
    return Ω, Tridiagonal_pullback
end

I checked the implemented rrule against FiniteDifferences:

using FiniteDifferences

d = rand(1024)
dl = rand(1023)
du = rand(1023)
FiniteDifferences.grad(central_fdm(5, 1), du -> sum(Tridiagonal(dl, d, du)), du)[1] ≈ gradient(du -> sum(Tridiagonal(dl, d, du)), du)[1] # true
FiniteDifferences.grad(central_fdm(5, 1), dl -> sum(Tridiagonal(dl, d, du)), dl)[1] ≈ gradient(dl -> sum(Tridiagonal(dl, d, du)), dl)[1] # true
FiniteDifferences.grad(central_fdm(5, 1), d -> sum(Tridiagonal(dl, d, du)), d)[1] ≈ gradient(d -> sum(Tridiagonal(dl, d, du)), d)[1] # true

After implementing the Tridiagonal rrule, I am able to differentiate di_spline, and the result is consistent with FiniteDifferences. There is just one caveat: performance is terrible!

@benchmark sum(di_spline($y,$x,$x1))
BenchmarkTools.Trial: 10000 samples with 4 evaluations.
 Range (min … max):  7.237 μs … 998.080 μs  ┊ GC (min … max): 0.00% … 97.62%
 Time  (median):     8.534 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   9.113 μs ±  10.412 μs  ┊ GC (mean ± σ):  1.70% ±  2.19%

  ▂▃▄▃▁▂▅▇█▄▃▃▄▅▄▄▂▁    ▁                    ▁                ▂
  ████████████████████████▇▆▆▆▆▆▅▅▆▄▅▅▄▄▅▄▁▄▇███▇▆▆▅▆▆▅▅▇▅▇▇▇ █
  7.24 μs      Histogram: log(frequency) by time      16.6 μs <

 Memory estimate: 7.98 KiB, allocs estimate: 7.

@benchmark gradient($y->sum(di_spline($y,$x,$x1)), $y)
BenchmarkTools.Trial: 133 samples with 1 evaluation.
 Range (min … max):  35.348 ms … 41.547 ms  ┊ GC (min … max): 0.00% … 10.68%
 Time  (median):     36.750 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   37.707 ms ±  1.963 ms  ┊ GC (mean ± σ):  3.73% ±  4.59%

      ▂   ▄█
  ▇▅▃▅██▇▅██▇▄▅▇▆▄▄▃▃▁▃▃▁▃▁▁▃▁▃▃▁▁▃▃▄▁▁▃▁▄▄▃▄▅▇▅▁▅▄▆▃▅▄▃▅▁▄▁▄ ▃
  35.3 ms         Histogram: frequency by time        41.5 ms <

 Memory estimate: 18.85 MiB, allocs estimate: 319149.

Implementing rrules

To improve performance, I decided to implement the relevant rrules myself. Here comes the first issue: how to derive them? Although I think I mostly understand how to compute and implement them, I am not sure I can implement these particular rules in an efficient and clever way.
So I followed the only approach I could think of: I rewrote QuadraticSpline by essentially copy-pasting the original code, dropping the structs and splitting it into small utility functions (the strategy is to write the rrules for these smaller functions).

My spline implementation
function my_spline(u, t, new_t::AbstractArray)
    s = length(t)
    s_new = length(new_t)
    dl = ones(eltype(t), s - 1)
    d_tmp = ones(eltype(t), s)
    du = zeros(eltype(t), s - 1)
    tA = Tridiagonal(dl, d_tmp, du)

    # zero for element type of d, which we don't know yet
    typed_zero = zero(2 // 1 * (u[begin + 1] - u[begin]) / (t[begin + 1] - t[begin]))

    d = create_d(u, t, s, typed_zero)#map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)
    z = tA \ d
    i_list = create_i_list(t, new_t, s_new)#[min(max(2, FindFirstFunctions.searchsortedfirstcorrelated(t, new_t[i], firstindex(t) - 1)), length(t)) for i in 1:s_new]
    Cᵢ_list = create_Cᵢ_list(u, i_list)#[u[i - 1] for i in i_list]
    σ = create_σ(z, t, i_list)#[1 // 2 * (z[i] - z[i - 1]) / (t[i] - t[i - 1]) for i in i_list]
    return compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ)#[z[i_list[i] - 1] * (new_t[i] - t[i_list[i] - 1]) + σ[i] * (new_t[i] - t[i_list[i] - 1])^2 + Cᵢ_list[i] for i in 1:s_new]
end

compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ) = map(i -> z[i_list[i] - 1] * (new_t[i] - t[i_list[i] - 1]) + σ[i] * (new_t[i] - t[i_list[i] - 1])^2 + Cᵢ_list[i], 1:s_new)
create_σ(z, t, i_list) = map(i -> 1 / 2 * (z[i] - z[i - 1]) / (t[i] - t[i - 1]),  i_list)
create_Cᵢ_list(u, i_list) = map(i-> u[i - 1],  i_list)
create_i_list(t, new_t, s_new) = map(i-> min(max(2, FindFirstFunctions.searchsortedfirstcorrelated(t, new_t[i], firstindex(t) - 1)), length(t)),  1:s_new)
create_d(u, t, s, typed_zero) = map(i -> i == 1 ? typed_zero : 2 / 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)

I explicitly checked that it is equivalent to the original one.
After that, I wrote the rrules (as before, I checked that they are correct):

my_spline rrules
Zygote.@adjoint function create_d(u, t, s, typed_zero)
    y = create_d(u, t, s, typed_zero)
    function create_d_pullback(ȳ)
        ∂u = Tridiagonal(zeros(eltype(typed_zero), s-1),
               map(i -> i == 1 ? typed_zero : 2 / (t[i] - t[i - 1]), 1:s),
               map(i -> - 2 / (t[i+1] - t[i]), 1:s-1)) * ȳ
        ∂t = Tridiagonal(zeros(eltype(typed_zero), s-1),
               map(i -> i == 1 ? typed_zero : -2 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]) ^ 2, 1:s),
               map(i -> 2 * (u[i+1] - u[i]) / (t[i+1] - t[i]) ^ 2, 1:s-1)) * ȳ
        return (∂u, ∂t, NoTangent(), NoTangent())
    end
    return y, create_d_pullback
end

Zygote.@adjoint function create_σ(z, x, i_list)
    y = create_σ(z, x, i_list)
    function create_σ_pullback(ȳ)
        s = length(z)
        s1 = length(i_list)
        ∂z = zeros(s,s1)
        ∂x = zeros(s,s1)
        
        for j in 1:s1
            i = i_list[j]
            a = @views (z[i] - z[i-1])
            b = @views (x[i] - x[i-1])
            ∂z[i,j] += 0.5 / b
            ∂z[i-1,j] -= 0.5 / b
            ∂x[i,j] -= 0.5 * a / b^2
            ∂x[i-1,j] += 0.5 * a / b^2
        end
        
        ∂z = ∂z * ȳ
        ∂x = ∂x * ȳ
        return (∂z, ∂x, NoTangent())
    end
    return y, create_σ_pullback
end #works, but performance can be a bit improved

Zygote.@adjoint function create_i_list(t, new_t, s_new)
    y = create_i_list(t, new_t, s_new)
    function create_i_list_pullback(yΜ„)
        return (NoTangent(), NoTangent(), NoTangent())
    end
    return y, create_i_list_pullback
end#not sure about this
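One way to check a rule like create_d without any AD package is to compare its vector-Jacobian product against central differences. A stdlib-only sketch, using a hypothetical two-argument variant of create_d (the s and typed_zero arguments folded in) and building the u-pullback exactly like the Tridiagonal in the rule above:

```julia
using LinearAlgebra

# Two-argument variant of create_d from the thread (entry 1 is zero).
create_d(u, t) = map(i -> i == 1 ? zero(eltype(u)) : 2 * (u[i] - u[i-1]) / (t[i] - t[i-1]), eachindex(u))

# Pullback w.r.t. u, built like the Tridiagonal in the create_d rule:
# zero subdiagonal, diagonal 2/h_i (0 for i = 1), superdiagonal -2/h_{i+1}.
function create_d_pullback_u(ybar, u, t)
    s = length(u)
    diagonal = map(i -> i == 1 ? 0.0 : 2 / (t[i] - t[i-1]), 1:s)
    upper = map(i -> -2 / (t[i+1] - t[i]), 1:s-1)
    return Tridiagonal(zeros(s - 1), diagonal, upper) * ybar
end

# Vector-Jacobian product by central differences, one input coordinate at a time.
function fd_vjp(f, u, ybar; h = 1e-6)
    map(eachindex(u)) do j
        e = zeros(length(u))
        e[j] = h
        dot(ybar, (f(u .+ e) .- f(u .- e)) ./ (2h))
    end
end

t = [0.0, 0.1, 0.25, 0.5, 0.6, 0.8, 0.9, 1.0]
u = sin.(1:8)
ybar = cos.(1:8)

analytic = create_d_pullback_u(ybar, u, t)
numeric = fd_vjp(v -> create_d(v, t), u, ybar)
println(isapprox(analytic, numeric; rtol = 1e-6))
```

Regarding the create_i_list rule: returning no tangents is consistent with the math, since the bracketing indices are piecewise constant in t and new_t, so their derivative is zero almost everywhere.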

The final result

So, was it worth it? Performance definitely improved:

@benchmark gradient($y->sum(my_spline($y,$x,$x)), $y)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  46.449 μs … 708.213 μs  ┊ GC (min … max): 0.00% … 84.52%
 Time  (median):     52.767 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   58.348 μs ±  47.454 μs  ┊ GC (mean ± σ):  7.39% ±  8.16%

    ▃▆█▇▅▄▃▂▂▁                                                 ▂
  ▄▆█████████████▇███▇▆▆▅▅▄▁▄▅▄▄▄▄▁▁▁▁▁▁▁▁▁▁▁▁▁▃▅▁▄▄▅▄▃▄▃▁▁▁▁▃ █
  46.4 μs       Histogram: log(frequency) by time       125 μs <

 Memory estimate: 549.17 KiB, allocs estimate: 706.

To be compared with the initial

@benchmark gradient($y->sum(di_spline($y,$x,$x)), $y)
BenchmarkTools.Trial: 1119 samples with 1 evaluation.
 Range (min … max):  3.770 ms …   8.495 ms  ┊ GC (min … max): 0.00% … 46.86%
 Time  (median):     4.324 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.466 ms ± 724.365 μs  ┊ GC (mean ± σ):  3.15% ±  8.69%

      ▁▂▆█▆▂
  ▆▅█▆██████▇▁▄▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▄▄▁▁▁▁▁▁▁▁▁▁▄▄▁▁▁▁▄▄▄▁▁▅▆▆▄▄▅▆▆ █
  3.77 ms      Histogram: log(frequency) by time      8.15 ms <

 Memory estimate: 2.28 MiB, allocs estimate: 38332.

Roughly a factor of 80 improvement (median 4.324 ms vs 52.767 μs)! In terms of accuracy, the two implementations are essentially equivalent (I explicitly checked that they give almost the same result and compared both with FiniteDifferences).

So, what next?

Now the question is: what to do? For my project, I am certainly happy with my spline implementation. However, I am sure that performance can still be improved (one of the rules I implemented does not take advantage of sparsity), and maybe (likely?) this is not even the smartest way to implement such a thing. @ChrisRackauckas, if you have any advice on how to make it better and integrate it into DataInterpolations, I would be happy to do it.
Maybe there is a smart implicit differentiation trick I missed (@gdalle)?
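One implicit-differentiation style piece that needs no elementwise rules is the solve z = tA \ d: since tA z = d, the pullback with respect to d is a single transposed solve, d̄ = tA' \ z̄, and the Jacobian inv(tA) is never formed (the cotangent for tA itself would be -d̄ z', but tA is constant in my_spline). A minimal stdlib sketch, assuming the tA built in my_spline (ones on the diagonal and subdiagonal, zeros on the superdiagonal); solve_pullback is a hypothetical name, and the transpose is built explicitly as a Tridiagonal:

```julia
using LinearAlgebra

# Coefficient matrix as built in my_spline: z[1] = d[1] and z[i] + z[i-1] = d[i].
s = 7
tA = Tridiagonal(ones(s - 1), ones(s), zeros(s - 1))
tA_T = Tridiagonal(zeros(s - 1), ones(s), ones(s - 1))   # transpose(tA), written out

t = collect(range(0.0, 1.0; length = s))
u = sin.(1:s)
d = [0.0; 2 .* diff(u) ./ diff(t)]   # same values as create_d
z = tA \ d

# Pullback of z = tA \ d w.r.t. d: one transposed solve, no explicit inverse.
solve_pullback(zbar) = tA_T \ zbar

zbar = cos.(1:s)
dbar = solve_pullback(zbar)

# Reference: the Jacobian dz/dd is inv(tA), so its pullback is inv(tA)' * zbar.
println(isapprox(dbar, inv(Matrix(tA))' * zbar))
```

Combined with the transpose of the create_d Jacobian, this gives the cotangent of u through the coefficient solve in O(n) work.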
In the meantime, thanks to everyone who reads this thread :slight_smile:

Edit: I still have to test Enzyme; I mostly use Zygote and am not too familiar with it.

Is it not covered by Add Tridiagonal construction rule by ChrisRackauckas · Pull Request #758 · JuliaDiff/ChainRules.jl · GitHub?


Sorry, I did not notice I had to update my environment. Now it works.
There is still the issue of differentiating the other pieces of code. Any suggestions on how to proceed?

Any MWE on that? I’m losing track of what’s being asked.

Sure!

using Zygote
using ForwardDiff
using DiffRules
using BenchmarkTools
using ChainRulesCore
const RealOrComplex = Union{Real,Complex}
using DataInterpolations
using ChainRules
using LinearAlgebra
using FindFirstFunctions
using SparseArrays
using FiniteDifferences

n = 64
x = vcat([0.], sort(rand(n-2)), [1.])
x1 = vcat([0.], sort(rand(n-2)), [1.])
y = rand(n);

function di_spline(y,x,xn)
    spline = QuadraticSpline(y,x, extrapolate = true)
    return spline.(xn)
end

b1 = @benchmark sum(di_spline($y,$x,$x1))
b2 = @benchmark gradient($y->sum(di_spline($y,$x,$x1)), $y)

function my_spline(u, t, new_t::AbstractArray)
    s = length(t)
    s_new = length(new_t)
    dl = ones(eltype(t), s - 1)
    d_tmp = ones(eltype(t), s)
    du = zeros(eltype(t), s - 1)
    tA = Tridiagonal(dl, d_tmp, du)

    # zero for element type of d, which we don't know yet
    typed_zero = zero(2 // 1 * (u[begin + 1] - u[begin]) / (t[begin + 1] - t[begin]))

    d = create_d(u, t, s, typed_zero)#map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)
    z = tA \ d
    i_list = create_i_list(t, new_t, s_new)#[min(max(2, FindFirstFunctions.searchsortedfirstcorrelated(t, new_t[i], firstindex(t) - 1)), length(t)) for i in 1:s_new]
    Cᵢ_list = create_Cᵢ_list(u, i_list)#[u[i - 1] for i in i_list]
    σ = create_σ(z, t, i_list)#[1 // 2 * (z[i] - z[i - 1]) / (t[i] - t[i - 1]) for i in i_list]
    return compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ)#[z[i_list[i] - 1] * (new_t[i] - t[i_list[i] - 1]) + σ[i] * (new_t[i] - t[i_list[i] - 1])^2 + Cᵢ_list[i] for i in 1:s_new]
end

compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ) = map(i -> z[i_list[i] - 1] * (new_t[i] - t[i_list[i] - 1]) + σ[i] * (new_t[i] - t[i_list[i] - 1])^2 + Cᵢ_list[i], 1:s_new)
create_σ(z, t, i_list) = map(i -> 1 / 2 * (z[i] - z[i - 1]) / (t[i] - t[i - 1]),  i_list)
create_Cᵢ_list(u, i_list) = map(i-> u[i - 1],  i_list)
create_i_list(t, new_t, s_new) = map(i-> min(max(2, FindFirstFunctions.searchsortedfirstcorrelated(t, new_t[i], firstindex(t) - 1)), length(t)),  1:s_new)
create_d(u, t, s, typed_zero) = map(i -> i == 1 ? typed_zero : 2 / 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)

b3 = @benchmark sum(my_spline($y,$x,$x1))
b4 = @benchmark gradient($y->sum(my_spline($y,$x,$x1)), $y)

Zygote.@adjoint function create_d(u, t, s, typed_zero)
    y = create_d(u, t, s, typed_zero)
    function create_d_pullback(ȳ)
        ∂u = Tridiagonal(zeros(eltype(typed_zero), s-1),
               map(i -> i == 1 ? typed_zero : 2 / (t[i] - t[i - 1]), 1:s),
               map(i -> - 2 / (t[i+1] - t[i]), 1:s-1)) * ȳ
        ∂t = Tridiagonal(zeros(eltype(typed_zero), s-1),
               map(i -> i == 1 ? typed_zero : -2 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]) ^ 2, 1:s),
               map(i -> 2 * (u[i+1] - u[i]) / (t[i+1] - t[i]) ^ 2, 1:s-1)) * ȳ
        return (∂u, ∂t, NoTangent(), NoTangent())
    end
    return y, create_d_pullback
end

Zygote.@adjoint function create_σ(z, x, i_list)
    y = create_σ(z, x, i_list)
    function create_σ_pullback(ȳ)
        s = length(z)
        s1 = length(i_list)
        ∂z = zeros(s,s1)
        ∂x = zeros(s,s1)
        
        for j in 1:s1
            i = i_list[j]
            a = @views (z[i] - z[i-1])
            b = @views (x[i] - x[i-1])
            ∂z[i,j] = 0.5 / b
            ∂z[i-1,j] = -0.5 / b
            ∂x[i,j] = -0.5 * a / b^2
            ∂x[i-1,j] = 0.5 * a / b^2
        end
        
        ∂z = ∂z * ȳ
        ∂x = ∂x * ȳ
        return (∂z, ∂x, NoTangent())
    end
    return y, create_σ_pullback
end

Zygote.@adjoint function create_i_list(t, new_t, s_new)
    y = create_i_list(t, new_t, s_new)
    function create_i_list_pullback(yΜ„)
        return (NoTangent(), NoTangent(), NoTangent())
    end
    return y, create_i_list_pullback
end

b5 = @benchmark gradient($y->sum(my_spline($y,$x,$x1)), $y)

In my benchmarks, the DataInterpolations version without custom rules (b2) gives

BenchmarkTools.Trial: 1280 samples with 1 evaluation.
 Range (min … max):  3.600 ms …   6.411 ms  ┊ GC (min … max): 0.00% … 30.68%
 Time  (median):     3.721 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.907 ms ± 513.169 μs  ┊ GC (mean ± σ):  2.42% ±  7.02%

   ▆█▇▅▄▄▁▂
  ██████████▇▇▅▆▆▅▄▅▆▆▄▅▅▆▄▅▄▄▆▅▄▅▆▅▅▆▅▅▁▄▅▁▁▅▁▅▁▄▅▅▅▆▆▇█▆▇▆▆ █
  3.6 ms       Histogram: log(frequency) by time       5.9 ms <

 Memory estimate: 2.28 MiB, allocs estimate: 38307.

and the one with custom rules (b5)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  44.632 μs … 702.216 μs  ┊ GC (min … max): 0.00% … 75.22%
 Time  (median):     56.627 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   65.096 μs ±  34.856 μs  ┊ GC (mean ± σ):  4.53% ±  8.01%

   ▇█▂
  ▃████▆▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▁▂▁▁▁▂▁▁▁▂▁▁▂▂ ▃
  44.6 μs         Histogram: frequency by time          265 μs <

 Memory estimate: 550.83 KiB, allocs estimate: 707.

So the question is: how do I write the rules for the DataInterpolations spline without splitting its functions into a composition of smaller functions, as I did?

You can write the rule on the QuadraticSpline constructor itself, and on its calls to interpolation.


The constructor should not be a problem, but I don’t know how to derive the adjoint of

function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
    idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first)
    Cᵢ = A.u[idx - 1]
    σ = 1 // 2 * (A.z[idx] - A.z[idx - 1]) / (A.t[idx] - A.t[idx - 1])
    return A.z[idx - 1] * (t - A.t[idx - 1]) + σ * (t - A.t[idx - 1])^2 + Cᵢ, idx
end

without splitting it into smaller chunks. I know, it’s a limitation of mine :sweat_smile:
Edit: one solution could be to compose the functions I wrote before, but I am not sure this is the best approach 😅
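A possible route that avoids helper functions is to differentiate the closed-form expression in _interpolate directly. With Δ = t - t[i-1], h = t[i] - t[i-1] and σ = (z[i] - z[i-1]) / (2h), the value is y = z[i-1]·Δ + σ·Δ² + u[i-1], so ∂y/∂z[i-1] = Δ - Δ²/(2h), ∂y/∂z[i] = Δ²/(2h), ∂y/∂u[i-1] = 1 and ∂y/∂t = z[i-1] + 2σΔ. A stdlib-only sketch under these assumptions: quad_eval_pullback is a hypothetical name, searchsortedfirst clamped to 2:length(knots) stands in for get_idx, and knot cotangents are left out:

```julia
# Hypothetical helper: value and pullback of one quadratic-spline evaluation,
# mirroring the formula in _interpolate (z: spline coefficients, u: data values,
# knots: data abscissae, x: query point). Knot cotangents are omitted.
function quad_eval_pullback(z, u, knots, x)
    # Stand-in for get_idx(...; lb = 2, ...): first knot >= x, clamped to 2:length(knots).
    i = clamp(searchsortedfirst(knots, x), 2, length(knots))
    h = knots[i] - knots[i-1]
    Δ = x - knots[i-1]
    σ = (z[i] - z[i-1]) / (2h)
    y = z[i-1] * Δ + σ * Δ^2 + u[i-1]
    function pullback(ybar)
        zbar = zeros(length(z))
        ubar = zeros(length(u))
        zbar[i-1] = ybar * (Δ - Δ^2 / (2h))   # direct term plus the -z[i-1] part of σ
        zbar[i] = ybar * Δ^2 / (2h)            # enters only through σ
        ubar[i-1] = ybar                        # Cᵢ = u[i-1]
        xbar = ybar * (z[i-1] + 2σ * Δ)        # derivative w.r.t. the query point
        return zbar, ubar, xbar
    end
    return y, pullback
end

knots = [0.0, 0.2, 0.45, 0.7, 1.0]
z = cos.(1:5)
u = sin.(1:5)
x = 0.5

y, back = quad_eval_pullback(z, u, knots, x)
zbar, ubar, xbar = back(1.0)

# Central-difference reference on z (coordinate by coordinate) and on x.
δ = 1e-6
val(zz, xx) = first(quad_eval_pullback(zz, u, knots, xx))
fd_z = [(val(z .+ δ .* (1:5 .== j), x) - val(z .- δ .* (1:5 .== j), x)) / (2δ) for j in 1:5]
fd_x = (val(z, x + δ) - val(z, x - δ)) / (2δ)
println(isapprox(zbar, fd_z; atol = 1e-8), " ", isapprox(xbar, fd_x; atol = 1e-6))
```

In a full rule on the QuadraticSpline call, the z̄ produced here would still have to be pulled back through the coefficient solve to reach u.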

This probably should be an issue.

Opened a new issue.

Lil’ update.
After adding a few more adjoints, I get the following performance:

@benchmark gradient($y->sum(my_spline($y,$x,$x1)), $y)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  17.007 μs … 706.183 μs  ┊ GC (min … max): 0.00% … 87.69%
 Time  (median):     19.779 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   22.133 μs ±  25.732 μs  ┊ GC (mean ± σ):  7.35% ±  6.13%

     ▃▅▆▇██▇▇▅▄▄▃▃▂▁▁▁▁ ▁                                      ▂
  ▃▆███████████████████████▇█▇██▇█▇▇▆▆▇▆▄▆▆▆▆▆▆▅▆▅▆▅▃▄▅▄▄▄▆▄▄▃ █
  17 μs         Histogram: log(frequency) by time      36.2 μs <

 Memory estimate: 210.53 KiB, allocs estimate: 123.

Roughly a factor of three improvement (median 56.627 μs → 19.779 μs), and the number of allocations dropped significantly (707 → 123). I think the last piece is to take full advantage of the sparse structure of the problem: right now I am allocating dense matrices, but most of their elements are zero.
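Exploiting the sparsity mostly means never materializing the s × s1 Jacobian matrices: each σ[j] touches only z[i-1], z[i], x[i-1] and x[i], so the pullback can scatter-add straight into two vectors. A stdlib-only sketch for create_σ (the function names below are hypothetical, with ASCII "sigma" in place of σ), checked against the dense construction used in the create_σ rule earlier in the thread:

```julia
# Forward map from the thread, for reference.
create_sigma(z, x, i_list) = map(i -> 1 / 2 * (z[i] - z[i-1]) / (x[i] - x[i-1]), i_list)

# Sparse pullback: scatter-add each output's contribution into zbar and xbar.
function create_sigma_pullback_sparse(ybar, z, x, i_list)
    zbar = zeros(length(z))
    xbar = zeros(length(x))
    for (j, i) in enumerate(i_list)
        a = z[i] - z[i-1]
        b = x[i] - x[i-1]
        zbar[i]   += ybar[j] * 0.5 / b
        zbar[i-1] -= ybar[j] * 0.5 / b
        xbar[i]   -= ybar[j] * 0.5 * a / b^2
        xbar[i-1] += ybar[j] * 0.5 * a / b^2
    end
    return zbar, xbar
end

# Dense reference: the same Jacobian-transpose matrices as the create_σ rule.
function create_sigma_pullback_dense(ybar, z, x, i_list)
    Jz = zeros(length(z), length(i_list))
    Jx = zeros(length(x), length(i_list))
    for (j, i) in enumerate(i_list)
        a = z[i] - z[i-1]
        b = x[i] - x[i-1]
        Jz[i, j] = 0.5 / b
        Jz[i-1, j] = -0.5 / b
        Jx[i, j] = -0.5 * a / b^2
        Jx[i-1, j] = 0.5 * a / b^2
    end
    return Jz * ybar, Jx * ybar
end

x = [0.0, 0.2, 0.45, 0.7, 1.0]
z = sin.(1:5)
i_list = [2, 2, 3, 5, 5, 4]   # repeated indices must accumulate, not overwrite
ybar = cos.(1:length(i_list))

zs, xs = create_sigma_pullback_sparse(ybar, z, x, i_list)
zd, xd = create_sigma_pullback_dense(ybar, z, x, i_list)
println(isapprox(zs, zd) && isapprox(xs, xd))
```

The sparse version needs memory proportional to length(z) instead of length(z) × length(i_list), and the same pattern applies to create_Cᵢ_list (ubar[i-1] += ybar[j]).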

Implemented rules
Zygote.@adjoint function create_Cᵢ_list(u, i_list)
    y = create_Cᵢ_list(u, i_list)
    function create_Cᵢ_list_pullback(ȳ)
        s = length(u)
        s1 = length(i_list)
        ∂Cᵢ_list = zeros(s,s1)
        
        for j in 1:s1
            i = i_list[j]
            ∂Cᵢ_list[i-1,j] = 1.
        end
        ∂Cᵢ_list = ∂Cᵢ_list * ȳ
        return (∂Cᵢ_list, NoTangent())
    end
    return y, create_Cᵢ_list_pullback
end

Zygote.@adjoint function compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ)
    y = compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ)
    function compose_pullback(ȳ)
        s = length(z)
        s1 = length(i_list)
        ∂z = zeros(s,s1)
        ∂t = zeros(s,s1)
        ∂t1 = zeros(s1,s1)
        
        
        for j in 1:s1
            i = i_list[j]
            ∂z[i-1,j] = new_t[j] - t[i_list[j] - 1]
            ∂t[i-1,j] = -z[i_list[j] - 1]  - 2σ[j] * (new_t[j] - t[i_list[j] - 1])
            ∂t1[j,j] = +z[i_list[j] - 1]  + 2σ[j] * (new_t[j] - t[i_list[j] - 1])
        end
        
        ∂z = ∂z * ȳ
        ∂t = ∂t * ȳ
        ∂t1 = ∂t1 * ȳ
        ∂σ = Diagonal(map(i -> (new_t[i] - t[i_list[i] - 1])^2, 1:s_new)) * ȳ
        ∂Cᵢ_list = Diagonal(ones(s1)) * ȳ
        return (∂z, ∂t, ∂t1, ∂Cᵢ_list, NoTangent(), NoTangent(), ∂σ)
    end
    return y, compose_pullback
end
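The compose rule can drop its dense matrices in the same way: ∂z and ∂t only receive entries at i[j] - 1, the new_t cotangent is elementwise, and the two Diagonal(...) * ȳ products reduce to elementwise operations. A stdlib-only sketch (compose_pullback_sparse and dirderiv are hypothetical names; sigma and C stand for σ and Cᵢ_list), checked with directional central differences:

```julia
using LinearAlgebra

# Forward map from the thread (C plays the role of Cᵢ_list).
compose(z, t, new_t, C, s_new, i_list, sigma) =
    map(j -> z[i_list[j]-1] * (new_t[j] - t[i_list[j]-1]) + sigma[j] * (new_t[j] - t[i_list[j]-1])^2 + C[j], 1:s_new)

# Sparse pullback: accumulate into vectors instead of dense s × s1 matrices.
function compose_pullback_sparse(ybar, z, t, new_t, i_list, sigma)
    zbar = zeros(length(z)); tbar = zeros(length(t)); new_tbar = zeros(length(new_t))
    for (j, i) in enumerate(i_list)
        Δ = new_t[j] - t[i-1]
        slope = z[i-1] + 2 * sigma[j] * Δ   # ∂y_j/∂new_t[j]
        zbar[i-1] += ybar[j] * Δ
        tbar[i-1] -= ybar[j] * slope
        new_tbar[j] = ybar[j] * slope
    end
    sigmabar = ybar .* (new_t .- t[i_list .- 1]) .^ 2   # Diagonal(...) * ybar
    Cbar = copy(ybar)                                   # Diagonal(ones) * ybar
    return zbar, tbar, new_tbar, Cbar, sigmabar
end

# Directional central difference of g at v along dv.
dirderiv(g, v, dv; h = 1e-6) = (g(v .+ h .* dv) .- g(v .- h .* dv)) ./ (2h)

t = [0.0, 0.2, 0.45, 0.7, 1.0]; z = sin.(1:5)
new_t = [0.1, 0.3, 0.5, 0.9]; i_list = [2, 3, 4, 5]
sigma = cos.(1:4); C = zeros(4); s_new = 4
ybar = [1.0, -2.0, 0.5, 3.0]

zbar, tbar, new_tbar, Cbar, sigmabar = compose_pullback_sparse(ybar, z, t, new_t, i_list, sigma)

dz = cos.(1:5); dt = sin.(2:6); dnt = cos.(2:5); dsig = sin.(3:6)
ok_z  = isapprox(dot(ybar, dirderiv(v -> compose(v, t, new_t, C, s_new, i_list, sigma), z, dz)), dot(zbar, dz); atol = 1e-7)
ok_t  = isapprox(dot(ybar, dirderiv(v -> compose(z, v, new_t, C, s_new, i_list, sigma), t, dt)), dot(tbar, dt); atol = 1e-7)
ok_nt = isapprox(dot(ybar, dirderiv(v -> compose(z, t, v, C, s_new, i_list, sigma), new_t, dnt)), dot(new_tbar, dnt); atol = 1e-7)
ok_s  = isapprox(dot(ybar, dirderiv(v -> compose(z, t, new_t, C, s_new, i_list, v), sigma, dsig)), dot(sigmabar, dsig); atol = 1e-7)
println(ok_z, " ", ok_t, " ", ok_nt, " ", ok_s)
```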