Follow-up on this thread. It is going to be a long one, so bear with me.
An MWE
I want to be able to differentiate something like this using Zygote:
using DataInterpolations, Zygote, BenchmarkTools

# random data creation
n = 64
x = vcat([0.], sort(rand(n-2)), [1.])
x1 = vcat([0.], sort(2*rand(10n-2)), [2.])
y = rand(n)

# interpolates the original data (y evaluated on x) onto a new grid x1
function di_spline(y, x, xn)
    spline = QuadraticSpline(y, x, extrapolate = true)
    return spline.(xn)
end

gradient(y -> sum(di_spline(y, x, x1)), y)  # does not work
Prerequisites: the Tridiagonal rrule
In order to be able to differentiate, I first had to implement an rrule for the Tridiagonal constructor.
rrule implementation and validation
using LinearAlgebra
import ChainRulesCore: rrule, NoTangent, @thunk

Zygote.@adjoint function Tridiagonal(dl, d, du)
    y = Tridiagonal(dl, d, du)
    function Tridiagonal_pullback(ȳ)
        ∂dl = @thunk(Array(diag(ȳ, -1)))
        ∂d = @thunk(Array(diag(ȳ, 0)))
        ∂du = @thunk(Array(diag(ȳ, +1)))
        return (∂dl, ∂d, ∂du)
    end
    return y, Tridiagonal_pullback
end
function rrule(::typeof(Tridiagonal), dl, d, du)
    Ω = Tridiagonal(dl, d, du)
    function Tridiagonal_pullback(ΔΩ)
        ∂dl = @thunk(Array(diag(ΔΩ, -1)))
        ∂d = @thunk(Array(diag(ΔΩ, 0)))
        ∂du = @thunk(Array(diag(ΔΩ, +1)))
        return (NoTangent(), ∂dl, ∂d, ∂du)
    end
    return Ω, Tridiagonal_pullback
end
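Why the pullback is just band extraction: Tridiagonal is linear in its three band arguments, so the cotangent of each band is the corresponding band of the output cotangent. A quick toy check (the names below are mine, just for illustration):

Ȳ = rand(5, 5)                       # a dense output cotangent for a 5×5 Tridiagonal
a, b, c = rand(4), rand(5), rand(4)  # candidate bands
# the inner product with the constructor output decomposes into one inner product per band:
dot(Ȳ, Tridiagonal(a, b, c)) ≈ dot(diag(Ȳ, -1), a) + dot(diag(Ȳ, 0), b) + dot(diag(Ȳ, 1), c)  # true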
I checked the implemented rrule against FiniteDifferences
using FiniteDifferences

d = rand(1024)
dl = rand(1023)
du = rand(1023)
FiniteDifferences.grad(central_fdm(5, 1), du -> sum(Tridiagonal(dl, d, du)), du)[1] ≈ gradient(du -> sum(Tridiagonal(dl, d, du)), du)[1]  # true
FiniteDifferences.grad(central_fdm(5, 1), dl -> sum(Tridiagonal(dl, d, du)), dl)[1] ≈ gradient(dl -> sum(Tridiagonal(dl, d, du)), dl)[1]  # true
FiniteDifferences.grad(central_fdm(5, 1), d -> sum(Tridiagonal(dl, d, du)), d)[1] ≈ gradient(d -> sum(Tridiagonal(dl, d, du)), d)[1]  # true
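A more systematic check could use ChainRulesTestUtils; I have not run it here, just noting it as an option:

using ChainRulesTestUtils
test_rrule(Tridiagonal, dl, d, du)   # compares the rrule output and its pullback against finite differences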
After implementing the Tridiagonal rrule, I am able to differentiate di_spline (and the result is consistent with FiniteDifferences). There is just a caveat: performance is terrible!
@benchmark sum(di_spline($y,$x,$x1))
BenchmarkTools.Trial: 10000 samples with 4 evaluations.
 Range (min … max):  7.237 μs … 998.080 μs  ┊ GC (min … max): 0.00% … 97.62%
 Time  (median):     8.534 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   9.113 μs ±  10.412 μs  ┊ GC (mean ± σ):  1.70% ± 2.19%

 [histogram omitted]
 7.24 μs          Histogram: log(frequency) by time          16.6 μs <

 Memory estimate: 7.98 KiB, allocs estimate: 7.
@benchmark gradient($y->sum(di_spline($y,$x,$x1)), $y)
BenchmarkTools.Trial: 133 samples with 1 evaluation.
 Range (min … max):  35.348 ms … 41.547 ms  ┊ GC (min … max): 0.00% … 10.68%
 Time  (median):     36.750 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   37.707 ms ±  1.963 ms  ┊ GC (mean ± σ):  3.73% ± 4.59%

 [histogram omitted]
 35.3 ms           Histogram: frequency by time           41.5 ms <

 Memory estimate: 18.85 MiB, allocs estimate: 319149.
Implementing rrules
In order to improve performance, I decided to implement the relevant rrules. Here comes the first issue: how to write them? Although (I think) I mostly understand how to compute and implement rrules in general, I am not sure that, for a case like this, I can implement them in an efficient and clever way.
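For every rule below I am using the standard reverse-mode pattern: for y = f(u), the pullback maps the output cotangent ȳ to J(u)' * ȳ, where J is the Jacobian. A toy example of the pattern (the function g here is made up for illustration, it is not part of the spline code):

g(u) = [2u[1] + u[2], u[2]^2]

Zygote.@adjoint function g(u)
    y = g(u)
    function g_pullback(ȳ)
        J = [2.0 1.0; 0.0 2u[2]]   # Jacobian of g at u
        return (J' * ȳ,)           # vector-Jacobian product
    end
    return y, g_pullback
end

gradient(u -> sum(g(u)), [1.0, 3.0])   # ([2.0, 7.0],)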
I hence followed the only approach I could think of: I rewrote QuadraticSpline, basically copy-pasting the original code without using structs and factoring out some utility functions (the strategy is to write down the rrule for each of the smaller functions).
My spline implementation
using FindFirstFunctions

function my_spline(u, t, new_t)
    s = length(t)
    s_new = length(new_t)
    dl = ones(eltype(t), s - 1)
    d_tmp = ones(eltype(t), s)
    du = zeros(eltype(t), s - 1)
    tA = Tridiagonal(dl, d_tmp, du)
    # zero for the element type of d, which we don't know yet
    typed_zero = zero(2 // 1 * (u[begin + 1] - u[begin]) / (t[begin + 1] - t[begin]))
    d = create_d(u, t, s, typed_zero)
    z = tA \ d
    i_list = create_i_list(t, new_t, s_new)
    Cᵢ_list = create_Cᵢ_list(u, i_list)
    σ = create_σ(z, t, i_list)
    return compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ)
end

compose(z, t, new_t, Cᵢ_list, s_new, i_list, σ) = map(i -> z[i_list[i] - 1] * (new_t[i] - t[i_list[i] - 1]) + σ[i] * (new_t[i] - t[i_list[i] - 1])^2 + Cᵢ_list[i], 1:s_new)
create_σ(z, t, i_list) = map(i -> 1 / 2 * (z[i] - z[i - 1]) / (t[i] - t[i - 1]), i_list)
create_Cᵢ_list(u, i_list) = map(i -> u[i - 1], i_list)
create_i_list(t, new_t, s_new) = map(i -> min(max(2, FindFirstFunctions.searchsortedfirstcorrelated(t, new_t[i], firstindex(t) - 1)), length(t)), 1:s_new)
create_d(u, t, s, typed_zero) = map(i -> i == 1 ? typed_zero : 2 / 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s)
I explicitly checked that it is equivalent to the original one.
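The kind of equivalence check I mean (up to floating-point noise):

my_spline(y, x, x1) ≈ di_spline(y, x, x1)   # true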
After this, I wrote the rrules (as before, I checked their correctness against FiniteDifferences).
my_spline rrules
Zygote.@adjoint function create_d(u, t, s, typed_zero)
    y = create_d(u, t, s, typed_zero)
    function create_d_pullback(ȳ)
        ∂u = Tridiagonal(zeros(eltype(typed_zero), s - 1),
                         map(i -> i == 1 ? typed_zero : 2 / (t[i] - t[i - 1]), 1:s),
                         map(i -> -2 / (t[i + 1] - t[i]), 1:s - 1)) * ȳ
        ∂t = Tridiagonal(zeros(eltype(typed_zero), s - 1),
                         map(i -> i == 1 ? typed_zero : -2 * (u[i] - u[i - 1]) / (t[i] - t[i - 1])^2, 1:s),
                         map(i -> 2 * (u[i + 1] - u[i]) / (t[i + 1] - t[i])^2, 1:s - 1)) * ȳ
        return (∂u, ∂t, NoTangent(), NoTangent())
    end
    return y, create_d_pullback
end
Zygote.@adjoint function create_σ(z, x, i_list)
    y = create_σ(z, x, i_list)
    function create_σ_pullback(ȳ)
        s = length(z)
        s1 = length(i_list)
        ∂z = zeros(s, s1)
        ∂x = zeros(s, s1)
        for j in 1:s1
            i = i_list[j]
            a = z[i] - z[i - 1]
            b = x[i] - x[i - 1]
            ∂z[i, j] += 0.5 / b
            ∂z[i - 1, j] -= 0.5 / b
            ∂x[i, j] -= 0.5 * a / b^2
            ∂x[i - 1, j] += 0.5 * a / b^2
        end
        ∂z = ∂z * ȳ
        ∂x = ∂x * ȳ
        return (∂z, ∂x, NoTangent())
    end
    return y, create_σ_pullback
end  # works, but performance can be improved a bit
Zygote.@adjoint function create_i_list(t, new_t, s_new)
    y = create_i_list(t, new_t, s_new)
    function create_i_list_pullback(ȳ)
        return (NoTangent(), NoTangent(), NoTangent())
    end
    return y, create_i_list_pullback
end  # not sure about this one; the output is a vector of integer indices, so no tangent should flow through it
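As an example of those checks, here is the one for create_σ (the test inputs below are just an illustration):

z_test = rand(n)
i_list_test = create_i_list(x, x1, length(x1))
FiniteDifferences.grad(central_fdm(5, 1), z -> sum(create_σ(z, x, i_list_test)), z_test)[1] ≈
    gradient(z -> sum(create_σ(z, x, i_list_test)), z_test)[1]   # true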
The final result
So, was it worth it? Performance definitely improved:
@benchmark gradient($y->sum(my_spline($y,$x,$x)), $y)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  46.449 μs … 708.213 μs  ┊ GC (min … max): 0.00% … 84.52%
 Time  (median):     52.767 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   58.348 μs ±  47.454 μs  ┊ GC (mean ± σ):  7.39% ± 8.16%

 [histogram omitted]
 46.4 μs          Histogram: log(frequency) by time          125 μs <

 Memory estimate: 549.17 KiB, allocs estimate: 706.
To be compared with the initial implementation:
@benchmark gradient($y->sum(di_spline($y,$x,$x)), $y)
BenchmarkTools.Trial: 1119 samples with 1 evaluation.
 Range (min … max):  3.770 ms …   8.495 ms  ┊ GC (min … max): 0.00% … 46.86%
 Time  (median):     4.324 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.466 ms ± 724.365 μs  ┊ GC (mean ± σ):  3.15% ± 8.69%

 [histogram omitted]
 3.77 ms          Histogram: log(frequency) by time          8.15 ms <

 Memory estimate: 2.28 MiB, allocs estimate: 38332.
A factor of 80 improvement! On the precision side, the two implementations are essentially equivalent (I checked explicitly that they give almost the same result, and compared against FiniteDifferences).
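The comparison I mean is roughly:

g_my = gradient(y -> sum(my_spline(y, x, x1)), y)[1]
g_di = gradient(y -> sum(di_spline(y, x, x1)), y)[1]
g_fd = FiniteDifferences.grad(central_fdm(5, 1), y -> sum(di_spline(y, x, x1)), y)[1]
g_my ≈ g_di ≈ g_fd   # true, up to floating-point and finite-difference noise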
So, what?
Now, the question is: what to do? For sure, for my project, I am happy with my spline implementation. However, I am sure that performance can still be improved (one of the rules I implemented does not take advantage of sparsity; see the sketch below), and maybe (likely?) this is not even the smartest way to implement such a thing. @ChrisRackauckas, if there is any advice on how to make it better and integrate it into DataInterpolations, I would be happy to work on it.
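For concreteness, the sparsity-aware rewrite of the create_σ pullback I have in mind would accumulate J' * ȳ directly instead of materialising the dense s×s1 matrices (untested sketch, the function name is just for illustration):

function create_σ_pullback_sparse(ȳ, z, x, i_list)
    ∂z = zeros(length(z))
    ∂x = zeros(length(x))
    for (j, i) in enumerate(i_list)
        a = z[i] - z[i - 1]
        b = x[i] - x[i - 1]
        ∂z[i]     += 0.5 / b * ȳ[j]
        ∂z[i - 1] -= 0.5 / b * ȳ[j]
        ∂x[i]     -= 0.5 * a / b^2 * ȳ[j]
        ∂x[i - 1] += 0.5 * a / b^2 * ȳ[j]
    end
    return ∂z, ∂x
end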
Maybe there is a smart implicit differentiation trick I missed (@gdalle)?
In the meantime, thanks to everyone who reads this thread!
Edit: I still have to test Enzyme; I am mostly using Zygote and I am not too familiar with Enzyme yet.
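If anyone is curious, what I would try first with Enzyme is roughly the following (untested on my side, and the exact gradient API differs a bit between Enzyme versions):

using Enzyme
# Enzyme does not see the Zygote.@adjoint rules above, so this would exercise Enzyme's own AD:
Enzyme.gradient(Reverse, y -> sum(my_spline(y, x, x1)), y)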