Custom Flux.jl loss function errors when evaluating an Interpolations.jl interpolant

I was not able to fix the problem. My guess is that Interpolations.jl is not compatible with Zygote.jl. A workaround I found was to write my own interpolation type and function. Here is a working example, in case anyone is interested:

using Flux, Zygote
using LinearAlgebra
using Interpolations

# create a custom piecewise-linear (linear spline) interpolator type

"""
    CustomInterpolator(x, y)

Piecewise-linear interpolant through the knots `(x[i], y[i])`.

A small hand-written stand-in for an Interpolations.jl interpolant, built only from
indexing and arithmetic so it can be used as a layer of a Flux model that is
differentiated with Zygote.

`x` holds the knot positions and must be non-empty and sorted in non-decreasing order.
`y` holds the value at each knot and must have the same length as `x`. The values can be
anything supporting scalar multiplication and addition (numbers, arrays, ...).
Both arguments are converted to `Vector`; a `Vector` argument is stored as-is (no copy).

Throws `ArgumentError` if `x` is empty, unsorted, or not the same length as `y`.
"""
struct CustomInterpolator{TX,TY}
    x::Vector{TX}
    y::Vector{TY}

    function CustomInterpolator(x, y)
        # Same conversion the untyped `::Vector` fields used to perform implicitly.
        xs = convert(Vector, x)
        ys = convert(Vector, y)
        isempty(xs) && throw(ArgumentError("x must contain at least one knot"))
        length(xs) == length(ys) || throw(ArgumentError(
            "x and y must have the same length, got $(length(xs)) and $(length(ys))"))
        issorted(xs) || throw(ArgumentError("x must be sorted in non-decreasing order"))
        return new{eltype(xs),eltype(ys)}(xs, ys)
    end
end

"""
    custom_interpolate(citp::CustomInterpolator, x::Number)

Linearly interpolate `citp` at `x`.

Locates the knot interval `[citp.x[i], citp.x[i+1]]` that contains `x` and returns
`(1 - t) * citp.y[i] + t * citp.y[i+1]`, where
`t = (x - citp.x[i]) / (citp.x[i+1] - citp.x[i])`.

Throws `DomainError` if `x` lies outside `[citp.x[1], citp.x[end]]` or is `NaN`.
"""
function custom_interpolate(citp::CustomInterpolator, x::Number)
    xs, ys = citp.x, citp.y
    n = length(xs)

    # The chained comparison is also false for NaN, so NaN is rejected here too.
    # (No logging macro before the throw: the exception already carries the message.)
    xs[1] <= x <= xs[end] || throw(DomainError(x, "Out of bounds"))

    # Single knot: the bounds check guarantees x == xs[1].
    n == 1 && return ys[1]

    # Last knot <= x, clamped so that x == xs[end] falls into the final interval.
    i = clamp(searchsortedlast(xs, x), 1, n - 1)
    span = xs[i+1] - xs[i]
    # A zero-width interval is only reachable when x == xs[end] and the last knots repeat.
    iszero(span) && return ys[i+1]

    # Linear interpolation between the two bracketing knots.
    t = (x - xs[i]) / span
    return (1 - t) * ys[i] + t * ys[i+1]
end

##
# Interpolator used as the model's last layer: knots at 0 and 1 carrying a 2×2 matrix
# of zeros and a 2×2 matrix of ones, respectively.
x = LinRange(0, 1, 2)
y = [zeros(2, 2), ones(2, 2)]
citp = CustomInterpolator(x, y)

# Training pairs for i = 0, 0.2, ..., 1: input ones(3)*i, target ones(2,2) - i*ones(2,2).
training_set = [(ones(3) * i, ones(2, 2) - i * ones(2, 2)) for i in 0:0.2:1]

# Two dense layers reduce the 3-vector to one output. That output is clamped to the
# interpolator's domain [0, 1] and then mapped to a 2×2 matrix by the interpolator.
model = Chain(
    Dense(3, 3),
    Dense(3, 1),
    out -> clamp(out[1], 0, 1),
    s -> custom_interpolate(citp, s),
)
opt = ADAM()
ps = Flux.params(model)
loss(x, y) = Flux.mse(model(x), y)

# One pass over the training set per epoch, then print the summed loss.
n_epochs = 1000
for epoch in 1:n_epochs
    Flux.train!(loss, ps, training_set, opt)
    epoch_loss = sum(loss(input, target) for (input, target) in training_set)
    println(epoch_loss)
end


1 Like