Hi everyone, I am trying to differentiate through an interpolation (as offered by DataInterpolations.jl).
Here is an MWE:
using Zygote, DataInterpolations
x = Array(LinRange(0,10,101))
a = 1
b = 2
c = 3
p(a,b,c) = a.+b.*x.+c.*x.^2
gradient(b -> sum(p(a,b,c)), b)
function Q(a,b,c,x)
    akima = AkimaInterpolation(p(a,b,c), x)
    return akima.(x)
end
q(a,b,c) = Q(a,b,c,x)
gradient(b -> sum(q(a,b,c)), b)
The first call to gradient works; the second one does not and fails with the following error:
Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)
Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/latest/limitations
So, what are the possible fixes? Am I doing something wrong?
I am considering writing a small routine to implement a quadratic interpolation from scratch, but maybe this can be solved in another way…?
Thanks to everyone,
Marco
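For anyone considering the from-scratch route mentioned above, here is a minimal sketch of a piecewise quadratic interpolation written without any in-place array writes, which is the pattern the Zygote error message recommends. It swaps in a local three-point Lagrange polynomial rather than a true quadratic spline, so the result is continuous but its derivative can jump at the nodes. The names `lagrange3` and `quadratic_interp` are hypothetical helpers, not part of DataInterpolations.jl, and whether Zygote differentiates this out of the box has not been checked here.

```julia
# Quadratic Lagrange polynomial through three points, evaluated at xq.
function lagrange3(xq, x0, x1, x2, y0, y1, y2)
    l0 = (xq - x1) * (xq - x2) / ((x0 - x1) * (x0 - x2))
    l1 = (xq - x0) * (xq - x2) / ((x1 - x0) * (x1 - x2))
    l2 = (xq - x0) * (xq - x1) / ((x2 - x0) * (x2 - x1))
    return y0 * l0 + y1 * l1 + y2 * l2
end

# Hypothetical helper: piecewise quadratic interpolation of (xs, ys) at xq.
# xs must be sorted and have at least three entries. No array is written to.
function quadratic_interp(xs, ys, xq)
    n = length(xs)
    # Index of the last node <= xq, clamped so that i-1, i, i+1 are all valid.
    i = clamp(searchsortedlast(xs, xq), 2, n - 1)
    return lagrange3(xq, xs[i-1], xs[i], xs[i+1], ys[i-1], ys[i], ys[i+1])
end

# Same grid and coefficients as in the MWE.
xs = collect(LinRange(0, 10, 101))
ys = 1 .+ 2 .* xs .+ 3 .* xs .^ 2
println(quadratic_interp(xs, ys, 3.14159))
```

Because the sampled data in the MWE is itself a quadratic, this sketch reproduces it up to rounding at any query point, which makes it easy to sanity-check.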
Hi @ChrisRackauckas ,
thank you for your quick answer.
I was a bit imprecise.
The first call to gradient (the one differentiating the p function I defined) works as expected.
The second call to gradient (the one differentiating the q function I defined) does not work.
using Zygote, DataInterpolations
x = Array(LinRange(0,10,101))
a = 1
b = 2
c = 3
p(a,b,c) = a.+b.*x.+c.*x.^2
function Q(a,b,c,x)
    akima = AkimaInterpolation(p(a,b,c), x)
    return akima.(x)
end
q(a,b,c) = Q(a,b,c,x)
gradient(b -> sum(q(a,b,c)), b)
In this MWE, the last line raises the error (the full stack trace is below).
Stacktrace
julia> gradient(b -> sum(q(a,b,c)), b)
ERROR: Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)
Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] _throw_mutation_error(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/array.jl:70
[3] (::Zygote.var"#539#540"{Vector{Float64}})(::Nothing)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/array.jl:82
[4] (::Zygote.var"#2623#back#541"{Zygote.var"#539#540"{Vector{Float64}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71
[5] #AkimaInterpolation#5
@ ~/.julia/packages/DataInterpolations/Pz5Mr/src/interpolation_caches.jl:159 [inlined]
[6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[7] AkimaInterpolation
@ ~/.julia/packages/DataInterpolations/Pz5Mr/src/interpolation_caches.jl:142 [inlined]
[8] (::Zygote.Pullback{Tuple{Type{…}, Vector{…}, Vector{…}}, Tuple{Zygote.Pullback{…}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[9] Q
@ ./REPL[7]:2 [inlined]
[10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[11] q
@ ./REPL[8]:1 [inlined]
[12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[13] #1
@ ./REPL[9]:1 [inlined]
[14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[15] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
[16] gradient(::Function, ::Int64, ::Vararg{Int64})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:97
[17] top-level scope
@ REPL[9]:1
Some type information was truncated. Use `show(err)` to see complete types.
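For reference, the value the failing call should return can be worked out without any AD package. An interpolant evaluated at its own nodes returns the data values, so `sum(q(a,b,c))` should equal `sum(p(a,b,c))`, and its derivative with respect to `b` is `sum(x)`, which is 505 for this grid. The sketch below checks that with a central finite difference; the name `loss` and the step size `h` are choices made for this illustration.

```julia
# Same grid and coefficients as in the MWE above.
x = collect(LinRange(0, 10, 101))
a, b, c = 1.0, 2.0, 3.0

# Sum of the sampled quadratic; the interpolant evaluated at the nodes
# should give back exactly these values.
loss(b) = sum(a .+ b .* x .+ c .* x .^ 2)

# Central finite difference in b (the step size h is an arbitrary choice).
h = 1e-6
fd_grad = (loss(b + h) - loss(b - h)) / (2h)

# Analytic value: d/db sum(a + b*x + c*x^2) = sum(x).
exact_grad = sum(x)

println("finite difference: ", fd_grad)
println("analytic sum(x):   ", exact_grad)
```

Any rule added for the interpolation constructor can be compared against this number.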
Ok, thank you!
I will go through the math.
I started going through the code, as you said it shouldn’t be that difficult.
For instance, I started from QuadraticSpline (the one I actually need).
I think we need some additional rules for things like Tridiagonal.
Tridiagonal adjoint
using LinearAlgebra, Zygote
using ChainRulesCore: @thunk

Zygote.@adjoint function Tridiagonal(dl, d, du)
    y = Tridiagonal(dl, d, du)
    function Tridiagonal_pullback(ȳ)
        # Each input receives the matching diagonal of the cotangent.
        ∂dl = @thunk(Array(diag(ȳ, -1)))
        ∂d = @thunk(Array(diag(ȳ, 0)))
        ∂du = @thunk(Array(diag(ȳ, +1)))
        return (∂dl, ∂d, ∂du)
    end
    return y, Tridiagonal_pullback
end
Curiously, when I was trying to define the rule with ChainRules
using ChainRulesCore, LinearAlgebra

# Extend ChainRulesCore.rrule, dispatching on the Tridiagonal constructor.
function ChainRulesCore.rrule(::Type{Tridiagonal}, dl, d, du)
    y = Tridiagonal(dl, d, du)
    project_dl = ProjectTo(dl)
    project_d = ProjectTo(d)
    project_du = ProjectTo(du)
    function Tridiagonal_pullback(ȳ)
        ∂dl = diag(ȳ, -1)
        ∂d = diag(ȳ, 0)
        ∂du = diag(ȳ, +1)
        return NoTangent(), project_dl(∂dl), project_d(∂d), project_du(∂du)
    end
    return y, Tridiagonal_pullback
end
It was not working (it gave the same error as before; maybe there is a rule that takes precedence when using Zygote?)
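Independently of which AD package picks up the rule, the pullback itself can be checked with finite differences using only LinearAlgebra. The sketch below assumes a test loss `sum(W .* T)` with fixed weights `W`, so that the cotangent of the Tridiagonal matrix is `W` itself; `tridiagonal_pullback`, `W`, and `loss` are names made up for this check.

```julia
using LinearAlgebra

# Pullback of (dl, d, du) -> Tridiagonal(dl, d, du): each input receives the
# matching diagonal of the cotangent Ybar.
tridiagonal_pullback(Ybar) = (diag(Ybar, -1), diag(Ybar, 0), diag(Ybar, 1))

# Test loss L = sum(W .* T) with fixed weights W, so that dL/dT = W.
W = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0]
dl = [0.1, 0.2]
d = [1.0, 2.0, 3.0]
du = [0.3, 0.4]
loss(dl, d, du) = sum(W .* Tridiagonal(dl, d, du))

grad_dl, grad_d, grad_du = tridiagonal_pullback(W)

# Central finite difference with respect to d[2] (h is an arbitrary step).
h = 1e-6
e2 = [0.0, 1.0, 0.0]
fd_d2 = (loss(dl, d .+ h .* e2, du) - loss(dl, d .- h .* e2, du)) / (2h)
println((grad_d[2], fd_d2))
```

If the two printed numbers agree, the diagonal-extraction formula is right and any remaining failure lies in how the rule is registered or in other mutating code.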
@ChrisRackauckas, if you are willing to guide me a bit, how should we proceed?
I think I would start by working out which rules need to be created and doing some local checks. Once I have something that works locally on my laptop, I would open a PR. Does that make sense to you?
Edit: I was likely wrong, the Tridiagonal adjoint is actually not needed.
The rrules that are needed are the ones for the interpolation constructors, which requires deriving the derivative of the spline coefficients with respect to the data values.
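As a rough illustration of what such a constructor rule involves, suppose the spline coefficients `z` solve a linear system `A * z = B * u`, where `u` are the data values and `A`, `B` depend only on the knots. This setup is an assumption made for the sketch and is not copied from DataInterpolations.jl. Then for a cotangent `zbar` on the coefficients, the cotangent on the data is `B' * (A' \ zbar)`, which costs one extra solve with the transposed matrix. The matrices and the names `coeffs`, `coeffs_pullback`, and `loss` below are made up for the check.

```julia
using LinearAlgebra

# Assumed setup (illustrative only): coefficients z solve A * z = B * u,
# where A and B would be built from the knots.
A = Tridiagonal([1.0, 2.0, 1.0], [4.0, 4.0, 4.0, 4.0], [0.5, 0.0, 0.5])
B = [-1.0 1.0 0.0 0.0; 0.0 -1.0 1.0 0.0; 0.0 0.0 -1.0 1.0; 1.0 0.0 0.0 0.0]
coeffs(u) = A \ (B * u)

# Reverse-mode rule: for a cotangent zbar on z, the cotangent on u is
# B' * (A' \ zbar), i.e. one linear solve with the transposed matrix.
coeffs_pullback(zbar) = B' * (A' \ zbar)

u = [1.0, 3.0, 2.0, 5.0]
w = [0.5, -1.0, 2.0, 1.5]   # dL/dz for the test loss below
loss(u) = dot(w, coeffs(u))

grad = coeffs_pullback(w)

# Central finite differences, one coordinate of u at a time.
h = 1e-6
fd = [(loss(u .+ h .* ((1:4) .== k)) - loss(u .- h .* ((1:4) .== k))) / (2h) for k in 1:4]
println(maximum(abs.(grad .- fd)))
```

A rule of this shape avoids differentiating through the in-place coefficient setup altogether, since only the solve and its transpose appear.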
Hi @ChrisRackauckas,
no, I still have to check whether Enzyme works or not.
Regarding Zygote, I had to add the rule for the tridiagonal matrix, and after that it was able to perform the differentiation by itself. However, adding rules to other pieces of the code (e.g. the constructor of the spline) improved performance.