Error in a custom Flux.jl loss function that evaluates an Interpolations.jl interpolator

I am trying to implement a model with a custom loss function in Flux.jl. Below is the code for a simplified model; it produces the same error.

I have an interpolator that takes a scalar value and returns a 2x2 matrix. The goal of my model is to use 3 observations to find the best point at which to evaluate the interpolator. For this I wrote a custom loss function that computes the suggested evaluation point, evaluates the interpolator at this point, and compares the interpolated result to the true solution from the dataset.

```julia
using Flux, Zygote
using LinearAlgebra
using Interpolations

# create interpolator
x = LinRange(0,1,10)
y = [rand(2,2) for i in 1:10]
itp = interpolate(y, BSpline(Linear())) |> i -> scale(i, x)

# create training set
training_set = [(rand(3), rand(2,2)) for i in 0:0.2:1]

#build the model
model = Chain(Dense(3,1),i-> clamp(i[1],0,1))
opt = Descent()
ps = Flux.params(model)

function loss(evaluation_point, solution)
    interpolated = itp(model(evaluation_point))
    return norm(interpolated - solution)
end

# training (fails with the error below)
n_epochs = 100
for epoch in 1:n_epochs
    Flux.train!(loss, ps, training_set, opt)
    println(sum([loss(i[1],i[2]) for i in training_set]))
end
```

This returns the following error:

```
ERROR: DimensionMismatch("matrix A has dimensions (2,2), vector B has length 1")
  [1] generic_matvecmul!(C::Vector{Matrix{Float64}}, tA::Char, A::Matrix{Float64}, B::StaticArrays.SVector{1, Matrix{Float64}}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra C:\Users\thega\AppData\Local\Programs\Julia-1.7.2\share\julia\stdlib\v1.7\LinearAlgebra\src\matmul.jl:713
  [2] mul!
    @ C:\Users\thega\AppData\Local\Programs\Julia-1.7.2\share\julia\stdlib\v1.7\LinearAlgebra\src\matmul.jl:81 [inlined]
  [3] mul!
    @ C:\Users\thega\AppData\Local\Programs\Julia-1.7.2\share\julia\stdlib\v1.7\LinearAlgebra\src\matmul.jl:275 [inlined]
  [4] *
    @ C:\Users\thega\AppData\Local\Programs\Julia-1.7.2\share\julia\stdlib\v1.7\LinearAlgebra\src\matmul.jl:51 [inlined]
  [5] interpolate_pullback
    @ C:\Users\thega\.julia\packages\Interpolations\Glp9h\src\chainrules\chainrules.jl:13 [inlined]
  [6] ZBack
    @ C:\Users\thega\.julia\packages\Zygote\H6vD3\src\compiler\chainrules.jl:204 [inlined]
  [7] Pullback
    @ c:\Users\thega\Desktop\Question\main.jl:21 [inlined]
  [8] (::typeof(∂(loss)))(Δ::Float64)
    @ Zygote C:\Users\thega\.julia\packages\Zygote\H6vD3\src\compiler\interface2.jl:0
  [9] #212
    @ C:\Users\thega\.julia\packages\Zygote\H6vD3\src\lib\lib.jl:203 [inlined]
 [10] #1750#back
    @ C:\Users\thega\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [11] Pullback
    @ C:\Users\thega\.julia\packages\Flux\0c9kI\src\optimise\train.jl:102 [inlined]
 [12] (::typeof(∂(λ)))(Δ::Float64)
    @ Zygote C:\Users\thega\.julia\packages\Zygote\H6vD3\src\compiler\interface2.jl:0
 [13] (::Zygote.var"#93#94"{Params, typeof(∂(λ)), Zygote.Context})(Δ::Float64)
    @ Zygote C:\Users\thega\.julia\packages\Zygote\H6vD3\src\compiler\interface.jl:357
 [14] gradient(f::Function, args::Params)
    @ Zygote C:\Users\thega\.julia\packages\Zygote\H6vD3\src\compiler\interface.jl:76
 [15] macro expansion
    @ C:\Users\thega\.julia\packages\Flux\0c9kI\src\optimise\train.jl:101 [inlined]
 [16] macro expansion
    @ C:\Users\thega\.julia\packages\Juno\n6wyj\src\progress.jl:134 [inlined]
 [17] train!(loss::Function, ps::Params, data::Vector{Tuple{Vector{Float64}, Matrix{Float64}}}, opt::Descent; cb::Flux.Optimise.var"#40#46")
    @ Flux.Optimise C:\Users\thega\.julia\packages\Flux\0c9kI\src\optimise\train.jl:99
 [18] train!(loss::Function, ps::Params, data::Vector{Tuple{Vector{Float64}, Matrix{Float64}}}, opt::Descent)
    @ Flux.Optimise C:\Users\thega\.julia\packages\Flux\0c9kI\src\optimise\train.jl:97
 [19] top-level scope
    @ c:\Users\thega\Desktop\Question\main.jl:28
```

So there is a dimension mismatch somewhere, yet evaluating the loss function on its own works fine:

```julia
loss(training_set[1][1], training_set[1][2])
```

After playing around a bit, I found that the problem is the gradient computation:

```julia
gradient(loss, training_set[1][1], training_set[1][2])
```

I am pretty new to Julia and don't see a way to fix this. Or maybe it is just not possible to implement it the way I planned. I hope you can help me out.

I was not able to fix the problem. My guess is that Interpolations.jl is not compatible with Zygote.jl. A possible workaround I found was writing a custom interpolator type and interpolation function. Here is a working example in case anyone is interested:

```julia
using Flux, Zygote
using LinearAlgebra

# a custom piecewise-linear interpolator type
struct CustomInterpolator{X<:AbstractVector,Y<:AbstractVector}
    x::X  # sorted knot locations
    y::Y  # values at the knots (here 2x2 matrices)
    function CustomInterpolator(x::X, y::Y) where {X<:AbstractVector,Y<:AbstractVector}
        @assert issorted(x) "knots must be sorted"
        @assert length(x) == length(y) "need one value per knot"
        return new{X,Y}(x, y)
    end
end

function custom_interpolate(citp::CustomInterpolator, x::Number)
    # check bounds (throw, so an out-of-range point is not silently used)
    if x > citp.x[end] || x < citp.x[1]
        throw(DomainError(x, "Out of bounds"))
    end
    # find the index i with citp.x[i] <= x <= citp.x[i+1]
    i = 1
    while i < length(citp.x) - 1 && citp.x[i+1] < x
        i += 1
    end
    left_value, right_value = citp.x[i], citp.x[i+1]
    # do a linear interpolation between the two selected knots
    w = (x - left_value) / (right_value - left_value)
    return (1 - w) * citp.y[i] + w * citp.y[i+1]
end

# create custom interpolator
x = LinRange(0,1,2)
y = [zeros(2,2), ones(2,2)]
citp = CustomInterpolator(x,y)

# create training set
training_set = [(ones(3)*i, ones(2,2) - i*ones(2,2)) for i in 0:0.2:1]

#build the model
model = Chain(Dense(3,3), Dense(3,1), i-> clamp(i[1],0,1), i->custom_interpolate(citp,i))
opt = ADAM()
ps = Flux.params(model)
loss(x,y) = Flux.mse(model(x), y)

# training
n_epochs = 1000
for epoch in 1:n_epochs
    Flux.train!(loss, ps, training_set, opt)
    println(sum([loss(i[1],i[2]) for i in training_set]))
end
```

I played with this a little bit, and I think it is a bug: "rrule fails for matrix-valued interpolation" (Issue #486, JuliaMath/Interpolations.jl on GitHub).
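For anyone wondering where the mismatch comes from, here is my own sketch of the math (not taken from the Interpolations.jl source). With a scalar evaluation point $t$, a matrix-valued interpolant $Y(t) \in \mathbb{R}^{2\times 2}$, the target matrix $S$ from the training set and the incoming cotangent $\Delta Y = \partial L / \partial Y$, the chain rule needs an elementwise contraction (a Frobenius inner product), not a matrix-vector product:

$$
\frac{\partial L}{\partial t} = \sum_{i,j} (\Delta Y)_{ij}\, Y'_{ij}(t) = \langle \Delta Y,\, Y'(t) \rangle_F,
\qquad
\Delta Y = \frac{Y(t) - S}{\lVert Y(t) - S \rVert_F}
\quad \text{for} \quad L = \lVert Y(t) - S \rVert_F .
$$

The stack trace instead shows `interpolate_pullback` multiplying a 2×2 `Matrix{Float64}` (presumably $\Delta Y$) by a length-1 `SVector` that holds $Y'(t)$ (frames [1] and [5]), which is a matrix-vector product with incompatible sizes and hence the `DimensionMismatch`.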

Auto-differentiation support in Interpolations.jl is relatively recent.

I could use some help in maintaining these features in Interpolations.jl.
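Until the rule is fixed, one possible workaround that keeps Interpolations.jl is to route the call through a small wrapper with its own reverse rule that contracts the incoming cotangent with dY/dt elementwise. This is an untested sketch, not part of the Interpolations.jl API: `eval_itp` is a hypothetical helper name, and it assumes ChainRulesCore.jl is installed and that `Interpolations.gradient` returns the derivative of a 1-D interpolant as a length-1 container (as the `SVector{1, Matrix{Float64}}` in the stack trace suggests).

```julia
using Interpolations, ChainRulesCore, LinearAlgebra, Zygote

# Hypothetical helper: evaluate a 1-D, matrix-valued interpolant at a scalar t.
eval_itp(itp, t::Real) = itp(t)

# Custom reverse rule, so Zygote uses it instead of the rule for itp(t).
function ChainRulesCore.rrule(::typeof(eval_itp), itp, t::Real)
    y = itp(t)
    # dY/dt as a matrix; assumes the gradient of a 1-D interpolant has one entry
    dydt = only(Interpolations.gradient(itp, t))
    function eval_itp_pullback(Δy)
        # elementwise contraction sum(Δy .* dY/dt), i.e. a Frobenius inner product
        Δt = dot(unthunk(Δy), dydt)
        return NoTangent(), NoTangent(), Δt
    end
    return y, eval_itp_pullback
end

# Small check on a grid where Y(t) = (1 + 9t) * A, so dY/dt = 9A
A = [1.0 2.0; 3.0 4.0]
knots = LinRange(0, 1, 10)
vals = [k * A for k in 1:10]
itp = scale(interpolate(vals, BSpline(Linear())), knots)

# analytically, d/dt sum(Y(t)) = 9 * sum(A)
g = Zygote.gradient(t -> sum(eval_itp(itp, t)), 0.37)[1]
println(g)
```

In the loss from the question, `itp(model(evaluation_point))` would then become `eval_itp(itp, model(evaluation_point))`.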
