I was not able to fix the problem. My guess is that Interpolations.jl is not compatible with Zygote.jl. One possible workaround I found was to write a custom interpolation type and function. I include a working example below in case anyone finds it useful:
using Flux, Zygote
using LinearAlgebra
using Interpolations
"""
    CustomInterpolator(x, y)

Piecewise-linear interpolator ("linear splines") through the points `(x[k], y[k])`.

`x` holds the knot positions and must be non-empty and sorted in ascending order;
`y` holds the value at each knot and must have the same length as `x`. The values
only need to support scaling by a number and addition, so numbers as well as
equally-sized arrays work.

# Throws
- `ArgumentError` if `x` is empty or not sorted.
- `DimensionMismatch` if `x` and `y` have different lengths.
"""
struct CustomInterpolator{X<:AbstractVector,Y<:AbstractVector}
    x::X  # knot positions, ascending
    y::Y  # value attached to each knot
    function CustomInterpolator(x::AbstractVector, y::AbstractVector)
        isempty(x) && throw(ArgumentError("knot vector `x` must not be empty"))
        length(x) == length(y) || throw(DimensionMismatch(
            "`x` has $(length(x)) knots but `y` has $(length(y)) values"))
        issorted(x) ||
            throw(ArgumentError("knot vector `x` must be sorted in ascending order"))
        return new{typeof(x),typeof(y)}(x, y)
    end
end

"""
    custom_interpolate(citp::CustomInterpolator, x::Number)

Linearly interpolate `citp` at the query point `x`.

With `xs = citp.x` and `ys = citp.y`, pick the segment `xs[i] <= x <= xs[i+1]`
and return `(1 - t) * ys[i] + t * ys[i+1]` with `t = (x - xs[i]) / (xs[i+1] - xs[i])`.
Only indexing and arithmetic are used (no array mutation); the intent is to stay
differentiable by Zygote — verify in your setup.

# Throws
- `DomainError` if `x` is outside `[first(citp.x), last(citp.x)]` or is NaN.
"""
function custom_interpolate(citp::CustomInterpolator, x::Number)
    xs, ys = citp.x, citp.y
    lo, hi = first(xs), last(xs)
    # Written as a positive range test so that NaN (which fails every comparison)
    # is rejected too.
    lo <= x <= hi || throw(DomainError(x, "Out of bounds: query must lie in [$lo, $hi]"))
    n = length(xs)
    # With a single knot the bounds check forces x to equal it.
    n == 1 && return first(ys)
    # Last knot at or left of x; capped at n - 1 so that x == hi uses the final segment.
    i = min(searchsortedlast(xs, x), n - 1)
    width = xs[i + 1] - xs[i]
    # Repeated knots form a zero-width segment; return the right value instead of 0/0.
    iszero(width) && return ys[i + 1]
    t = (x - xs[i]) / width
    return (1 - t) * ys[i] + t * ys[i + 1]
end
##
# Interpolator with two knots: 0 maps to a 2×2 zero matrix, 1 to a 2×2 ones matrix.
x = LinRange(0, 1, 2)
y = [zeros(2, 2), ones(2, 2)]
citp = CustomInterpolator(x, y)

# Training pairs for i = 0, 0.2, ..., 1: input `ones(3) * i`,
# target `ones(2, 2) - i * ones(2, 2)`.
training_set = map(0:0.2:1) do i
    (ones(3) * i, ones(2, 2) - i * ones(2, 2))
end

# Two dense layers reduce the input to a single number, which is clamped into the
# knot range [0, 1] before being fed to the interpolator.
model = Chain(
    Dense(3, 3),
    Dense(3, 1),
    v -> clamp(v[1], 0, 1),
    t -> custom_interpolate(citp, t),
)
opt = ADAM()
ps = Flux.params(model)

function loss(x, y)
    return Flux.mse(model(x), y)
end

# Training: one pass over the data per epoch, then print the summed loss.
n_epochs = 1000
for _ in 1:n_epochs
    Flux.train!(loss, ps, training_set, opt)
    println(sum(loss(input, target) for (input, target) in training_set))
end