DataInterpolations.jl and Zygote.jl

Hi everyone, I am trying to differentiate through an interpolation (as offered by DataInterpolations.jl).
Here is an MWE:

using Zygote, DataInterpolations
x = Array(LinRange(0,10,101))
a = 1
b = 2
c = 3
p(a,b,c) = a.+b.*x.+c.*x.^2
gradient(b -> sum(p(a,b,c)), b)

function Q(a,b,c,x)
    akima = AkimaInterpolation(p(a,b,c), x)
    return akima.(x)
end

q(a,b,c) = Q(a,b,c,x)
gradient(b -> sum(q(a,b,c)), b)

The first call to gradient works; the second one does not:

Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

So, what are the possible fixes? Am I doing something wrong?
I am considering writing a small routine to implement a quadratic interpolation from scratch, but maybe this can be solved in another way…?
Thanks to everyone,
Marco

You mean the gradient of the gradient, so the second derivative call? You didn’t include the code that errors.

Hi @ChrisRackauckas,
thank you for your quick answer.
I was a bit imprecise.
The first call to gradient (the one differentiating the p function I defined) works as expected.
The second call to gradient (the one differentiating the q function I defined) does not work.

using Zygote, DataInterpolations
x = Array(LinRange(0,10,101))
a = 1
b = 2
c = 3
p(a,b,c) = a.+b.*x.+c.*x.^2

function Q(a,b,c,x)
    akima = AkimaInterpolation(p(a,b,c), x)
    return akima.(x)
end

q(a,b,c) = Q(a,b,c,x)
gradient(b -> sum(q(a,b,c)), b)

In this MWE, the last line raises the error (full stack trace below).

Stacktrace
julia> gradient(b -> sum(q(a,b,c)), b)
ERROR: Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/array.jl:70
  [3] (::Zygote.var"#539#540"{Vector{Float64}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/array.jl:82
  [4] (::Zygote.var"#2623#back#541"{Zygote.var"#539#540"{Vector{Float64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:71
  [5] #AkimaInterpolation#5
    @ ~/.julia/packages/DataInterpolations/Pz5Mr/src/interpolation_caches.jl:159 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [7] AkimaInterpolation
    @ ~/.julia/packages/DataInterpolations/Pz5Mr/src/interpolation_caches.jl:142 [inlined]
  [8] (::Zygote.Pullback{Tuple{Type{…}, Vector{…}, Vector{…}}, Tuple{Zygote.Pullback{…}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
  [9] Q
    @ ./REPL[7]:2 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [11] q
    @ ./REPL[8]:1 [inlined]
 [12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float64, 1, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [13] #1
    @ ./REPL[9]:1 [inlined]
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
 [15] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
 [16] gradient(::Function, ::Int64, ::Vararg{Int64})
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:97
 [17] top-level scope
    @ REPL[9]:1
Some type information was truncated. Use `show(err)` to see complete types.

Thank you again!

DataInterpolations.jl currently only has the adjoints to differentiate w.r.t. t.

Ok!
So, two follow-up questions:

  1. Would there be interest in a PR adding adjoints w.r.t. u and A to DataInterpolations?
  2. If so, do you know where I could find the rules? I could then try to implement them myself.

Thanks,
Marco

Yes

You’d have to derive it. It’s just calculus on the interpolation functions so it’s not hard but someone has to do it to avoid the mutation.
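
To illustrate what "calculus on the interpolation functions" means, here is the simplest case, written out as a sketch. It assumes a piecewise-linear interpolant on knots t_i with data values u_i, which is not the Akima or quadratic-spline case discussed in this thread; it is only meant to show the shape of the derivation.

```latex
\begin{aligned}
\theta &= \frac{t - t_i}{t_{i+1} - t_i}, \qquad t \in [t_i, t_{i+1}],\\
I(t) &= (1-\theta)\,u_i + \theta\,u_{i+1},\\
\frac{\partial I(t)}{\partial u_j} &= (1-\theta)\,\delta_{ij} + \theta\,\delta_{i+1,j}.
\end{aligned}
```

For splines whose coefficients are computed from u in the constructor, the chain rule also has to pass through that coefficient computation, and that is the part that still needs to be derived.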

Or Enzyme should work.

Ok, thank you!
I will go through the math.
I started going through the code, as you said it shouldn’t be that difficult.
For instance, I started from QuadraticSpline (the one I actually need).

I think we need some additional rules for things like Tridiagonal:

Tridiagonal adjoint
# Requires: using Zygote, LinearAlgebra
Zygote.@adjoint function Tridiagonal(dl, d, du)
    y = Tridiagonal(dl, d, du)
    function Tridiagonal_pullback(ȳ)
        # Each cotangent is the matching diagonal of ȳ, returned as a plain array.
        ∂dl = Array(diag(ȳ, -1))
        ∂d  = Array(diag(ȳ, 0))
        ∂du = Array(diag(ȳ, +1))
        return (∂dl, ∂d, ∂du)
    end
    return y, Tridiagonal_pullback
end

Curiously, when I tried to define the rule with ChainRules:

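# Requires: using ChainRulesCore, LinearAlgebra
# Qualify the name so the method extends ChainRulesCore.rrule, and dispatch on
# Type{<:Tridiagonal}, since Tridiagonal is a type (constructor), not a function.
function ChainRulesCore.rrule(::Type{<:Tridiagonal}, dl, d, du)
    y = Tridiagonal(dl, d, du)
    project_dl = ProjectTo(dl)
    project_d = ProjectTo(d)
    project_du = ProjectTo(du)

    function Tridiagonal_pullback(ȳ)
        ∂dl = diag(ȳ, -1)
        ∂d  = diag(ȳ, 0)
        ∂du = diag(ȳ, +1)
        return NoTangent(), project_dl(∂dl), project_d(∂d), project_du(∂du)
    end
    # Return the pullback defined above.
    return y, Tridiagonal_pullback
end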

It was not working (it gave the same error as before; maybe there is a rule that takes precedence when using Zygote?).

@ChrisRackauckas, if you are willing to guide me a bit, how should we proceed?
I think I would start by working out which rules need to be created and doing some local checks. Once I have something that works locally on my laptop, I would open a PR. Does that make sense to you?

Edit: I was likely wrong, the Tridiagonal adjoint is actually not needed.

The rrules that are needed are the ones on the constructors for the interpolation, which need a derivation of the derivative of the spline coefficients with respect to the data values.
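
As a rough sketch of what such a rule looks like, here is a hand-written rrule for a toy linear interpolation. `lininterp` is a hypothetical helper written for this example, not a DataInterpolations.jl function, and the real spline constructors would need the derivative of their coefficients with respect to u instead of the two simple weights used here.

```julia
using ChainRulesCore

# Hypothetical toy function: linear interpolation of values u on sorted knots t,
# evaluated at a scalar s. Not part of DataInterpolations.jl.
function lininterp(u::AbstractVector, t::AbstractVector, s::Real)
    i = clamp(searchsortedlast(t, s), 1, length(t) - 1)  # index of the left knot
    θ = (s - t[i]) / (t[i+1] - t[i])
    return (1 - θ) * u[i] + θ * u[i+1]
end

# Reverse rule: derivatives w.r.t. u and s are written by hand; t is left unimplemented.
function ChainRulesCore.rrule(::typeof(lininterp), u::AbstractVector, t::AbstractVector, s::Real)
    i = clamp(searchsortedlast(t, s), 1, length(t) - 1)
    h = t[i+1] - t[i]
    θ = (s - t[i]) / h
    y = (1 - θ) * u[i] + θ * u[i+1]
    function lininterp_pullback(ȳ)
        Δ = unthunk(ȳ)
        # Filling a fresh buffer is fine here: the AD system calls this pullback
        # as a black box and never traces the setindex! calls.
        ∂u = zeros(float(eltype(u)), length(u))
        ∂u[i] = (1 - θ) * Δ
        ∂u[i+1] = θ * Δ
        ∂s = Δ * (u[i+1] - u[i]) / h
        ∂t = @not_implemented("derivative w.r.t. the knots is not derived in this sketch")
        return NoTangent(), ∂u, ∂t, ∂s
    end
    return y, lininterp_pullback
end

# Usage: evaluate the rule directly and inspect the cotangents.
t = collect(0.0:1.0:4.0)
u = t .^ 2
y, back = rrule(lininterp, u, t, 2.25)
_, ∂u, _, ∂s = back(1.0)
```

With a rule like this in place, Zygote should pick it up through ChainRules instead of tracing into the function body, which is how the mutation error is avoided.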

But did you check if Enzyme just works?

Hi @ChrisRackauckas,
no, I still have to check whether Enzyme works.
Regarding Zygote, I had to add the rule for the tridiagonal matrix, and later it was able to perform the differentiation by itself. However, adding rules to other pieces of the code (e.g. the constructor of the spline) improved performance.