Greetings, one of the custom Lux models I am trying to develop requires an interpolation step between layers, something like the following, which errors:
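```julia
using DataInterpolations, Lux, Random, Zygote, Enzyme

# Interpolation grid, defined before use (a name other than `time` avoids shadowing Base.time)
t = collect(0:0.01:1)

# Interpolate the first row of the previous layer's output and return it as a row
function convert_interpolation_then_vector(M::AbstractMatrix)
    vectorized_data = collect.(eachslice(M, dims=1))
    itp = LagrangeInterpolation(vectorized_data[1], t)
    return itp'
end

temp_model = Chain(Dense(1, 1, use_bias = false), convert_interpolation_then_vector, Dense(1, 1))
```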
```
ERROR: Need an adjoint for constructor QuadraticInterpolation{Vector{Float64}, Vector{Float64}, true, Float64}. Gradient is of type Vector{Float64}
```
Also, gradient calculations for some of the interpolations work with Enzyme but not with Zygote:
```julia
using DataInterpolations, Enzyme, Zygote

t = collect(0:0.01:1)
x = LinearInterpolation(sin.(t), t)

Enzyme.gradient(Forward, τ -> sum(x(τ)), t)  # works
Zygote.gradient(τ -> sum(x(τ)), t)           # mutating array error
```
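A minimal sketch of why a non-mutating formulation matters for Zygote, assuming the failure comes from the vectorized call `x(t)` writing its results into a preallocated output array (an inference from the error message, not verified against the DataInterpolations source). The helper `lininterp` is hypothetical and not part of DataInterpolations; it evaluates one point at a time and builds the output with `map`, which Zygote can differentiate.

```julia
using Zygote

# Hypothetical helper (not a DataInterpolations function): linear interpolation
# of data `u` sampled on the sorted grid `t`, evaluated at a single point `τ`.
function lininterp(u::AbstractVector, t::AbstractVector, τ::Real)
    # Index of the left knot; clamp so that τ == t[end] still uses the last interval.
    i = clamp(searchsortedlast(t, τ), 1, length(t) - 1)
    w = (τ - t[i]) / (t[i+1] - t[i])
    return (1 - w) * u[i] + w * u[i+1]
end

t = collect(0:0.01:1)
u = sin.(t)
tq = t[1:end-1] .+ 0.005   # query points at the interval midpoints

# `map` builds the result without calling setindex! on a preallocated array,
# the kind of in-place write that Zygote cannot differentiate.
f(τs) = sum(map(τ -> lininterp(u, t, τ), τs))

g, = Zygote.gradient(f, tq)
```

Each entry of `g` is the slope of the interpolant on the interval containing the corresponding query point, which here equals `diff(u) ./ diff(t)`.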
But I am not sure whether Enzyme currently works for gradient calculations on Lux models. Can anyone suggest what I am missing? Thank you very much for your help.
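On the Enzyme question, a minimal sketch of reverse-mode gradients for a plain Lux `Chain` (without the interpolation layer), assuming Lux v1 and a recent Enzyme release. The loss `compute_loss`, the layer sizes and the batch size are illustrative choices, and the sketch does not check whether the interpolation layer itself is supported by Enzyme.

```julia
using Lux, Enzyme, Random

rng = Xoshiro(0)

model = Chain(Dense(1 => 1; use_bias = false), Dense(1 => 1))
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 1, 8)   # 1 feature, batch of 8

# Illustrative scalar loss: sum of squared model outputs.
function compute_loss(model, ps, st, x)
    y, _ = model(x, ps, st)
    return sum(abs2, y)
end

# Shadow with the same structure as `ps`, filled with zeros; gradients accumulate here.
dps = Enzyme.make_zero(ps)

# Reverse mode: only `ps` is differentiated; model, state and input are constants.
Enzyme.autodiff(Reverse, compute_loss, Active,
    Const(model), Duplicated(ps, dps), Const(st), Const(x))
```

After the call, `dps` has the same nested structure as `ps` and holds the gradient of the loss with respect to each parameter array.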