Cubic Spline Interpolation for unevenly sampled data with gradient support?

I’m trying to use Flux.jl to optimize a loss function that performs cubic spline interpolation, so I need a package with gradient support. Right now I’m working with DataInterpolations.jl. Here is the code I’m using for testing, which fails with an error about too few arguments:

julia> using DataInterpolations, Flux, Zygote

julia> function foo(x::Vector)
           y = 2 .* x.^2 .+ cos.(x);
           x2 = [2.0, 5.0, 9.0, 9.5]
           y2 = DataInterpolations.CubicSpline(y, x).(x2)
           return sum(y2)
       end
foo (generic function with 1 method)

julia> x = [1:10.;];

julia> Flux.gradient(foo, x)
ERROR: ArgumentError: new: too few arguments (expected 4)
Stacktrace:
  [1] __new__
    @ ~/.julia/packages/Zygote/dABKa/src/tools/builtins.jl:9 [inlined]
  [2] adjoint
    @ ~/.julia/packages/Zygote/dABKa/src/lib/lib.jl:293 [inlined]
  [3] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
  [4] _pullback
    @ /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/LinearAlgebra/src/tridiag.jl:485 [inlined]
  [5] _pullback(::Zygote.Context{false}, ::Type{LinearAlgebra.Tridiagonal{Float64, Vector{Float64}}}, ::Vector{Float64}, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
  [6] _pullback
    @ /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/LinearAlgebra/src/tridiag.jl:520 [inlined]
  [7] _pullback
    @ ~/.julia/packages/DataInterpolations/Al4Ib/src/interpolation_caches.jl:160 [inlined]
  [8] _pullback(::Zygote.Context{false}, ::Type{CubicSpline}, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
  [9] _pullback
    @ ./REPL[86]:3 [inlined]
 [10] _pullback(::Zygote.Context{false}, ::typeof(cspline_interp), ::Vector{Float64}, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [11] _pullback
    @ ./REPL[87]:4 [inlined]
 [12] _pullback(ctx::Zygote.Context{false}, f::typeof(foo), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [13] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:44
 [14] pullback
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:42 [inlined]
 [15] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:96
 [16] top-level scope
    @ REPL[90]:1
 [17] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52

julia> 

Using ForwardDiff does work, though:

julia> ForwardDiff.gradient(foo, x)
10-element Vector{Float64}:
 -2.071298673189042e-6
  0.4980577531556391
 -6.450221711938862e-6
  1.1919435010926138e-5
 -0.007078293628581947
 -0.0004909205140410721
 -0.001571414632713099
 -0.03183108029526967
 -1.6174037110341146
  0.5156259493105049
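Since ForwardDiff already differentiates `foo`, one possible workaround is to let Zygote hand just that function to ForwardDiff. Zygote provides `Zygote.forwarddiff(f, x)`, which returns `f(x)` but makes Zygote differentiate `f` in forward mode instead of tracing it in reverse. A minimal sketch, assuming the DataInterpolations and Zygote versions listed below (`foo_fd` is a made-up name). Forward mode costs time roughly proportional to `length(x)`, which is fine for a handful of inputs but not for large parameter vectors:

```julia
using DataInterpolations, ForwardDiff, Zygote

function foo(x::Vector)
    y = 2 .* x.^2 .+ cos.(x)
    x2 = [2.0, 5.0, 9.0, 9.5]
    return sum(DataInterpolations.CubicSpline(y, x).(x2))
end

# Hypothetical wrapper: Zygote.forwarddiff(f, x) evaluates f(x) as usual, but its
# pullback is computed with ForwardDiff dual numbers, so Zygote never traces into
# CubicSpline (and so never reaches the Tridiagonal constructor in the trace above).
foo_fd(x) = Zygote.forwarddiff(foo, x)

x = [1:10.;]
g = Zygote.gradient(foo_fd, x)[1]   # Flux.gradient(foo_fd, x) uses the same Zygote machinery
```

Inside a Flux loss, `foo_fd` can be called wherever `foo` was; only the spline part runs in forward mode, the rest of the model stays reverse mode.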
My package versions:

Status `~/.julia/environments/v1.8/Project.toml`
  [a0f608ac] AffineInvariantMCMC v1.0.2
  [7d9fca2a] Arpack v0.5.3
  [5c4adb95] AstroAngles v0.1.3
  [c7932e45] AstroLib v0.4.2
  [c61b5328] AstroTime v0.7.0
⌃ [6e4b80f9] BenchmarkTools v1.3.1
  [336ed68f] CSV v0.10.7
  [49dc2e85] Calculus v0.5.1
  [35d6a980] ColorSchemes v3.19.0
  [5ae59095] Colors v0.12.8
  [861a8166] Combinatorics v1.0.2
  [b0b7db55] ComponentArrays v0.13.4
  [2569d6c7] ConcreteStructs v0.2.3
  [9c784101] CubicSplines v0.2.1
  [717857b8] DSP v0.7.7
⌃ [a93c6f00] DataFrames v1.3.6
  [82cc6244] DataInterpolations v3.10.1
  [864edb3b] DataStructures v0.18.13
  [39dd38d3] Dierckx v0.5.2
  [7a1cc6ca] FFTW v1.5.0
  [525bcba6] FITSIO v0.17.0
  [9aa1b823] FastClosures v0.3.2
  [5789e2e9] FileIO v1.16.0
  [1a297f60] FillArrays v0.13.5
  [6a86dc24] FiniteDiff v2.15.0
  [26cc04aa] FiniteDifferences v0.12.25
  [587475ba] Flux v0.13.7
  [59287772] Formatting v0.4.2
  [f6369f11] ForwardDiff v0.10.32
  [a75be94c] GalacticOptim v3.4.0
  [c27321d9] Glob v1.3.0
  [bb4c363b] GridInterpolations v1.1.2
  [82e4d734] ImageIO v0.6.6
  [6218d12a] ImageMagick v1.2.2
  [916415d5] Images v0.25.2
  [5903a43b] Infiltrator v1.6.3
  [a98d9a8b] Interpolations v0.14.6
  [42fd0dbc] IterativeSolvers v0.9.2
  [033835bb] JLD2 v0.4.25
  [b964fa9f] LaTeXStrings v1.3.0
  [2ee39098] LabelledArrays v1.12.5
  [fc60dff9] LombScargle v1.0.3
  [bdcacae8] LoopVectorization v0.12.136
  [2fda8390] LsqFit v0.13.0
  [d41bc354] NLSolversBase v7.8.2
  [15e1cf62] NPZ v0.4.2
  [77ba4419] NaNMath v1.0.1
  [b946abbf] NaNStatistics v0.6.14
  [86f7a689] NamedArrays v0.9.6
  [429524aa] Optim v1.7.3
  [7f7a1694] Optimization v3.9.2
  [253f991c] OptimizationFlux v0.1.1
  [42dfb2eb] OptimizationOptimisers v0.1.0
  [5432bcbf] PaddedViews v0.5.11
  [d96e819e] Parameters v0.12.3
  [58dd65bb] Plotly v0.4.1
  [f0f68f2c] PlotlyJS v0.18.10
⌃ [91a5bcdd] Plots v1.35.6
⌃ [c3e4b0f8] Pluto v0.19.9
  [7f904dfe] PlutoUI v0.7.48
  [f27b6e38] Polynomials v3.2.0
  [92933f4c] ProgressMeter v1.7.2
  [438e738f] PyCall v1.94.1
  [d330b81b] PyPlot v2.11.0
  [6099a3de] PythonCall v0.9.9
  [189a3867] Reexport v1.2.2
  [295af30f] Revise v3.4.0
  [f2b01f46] Roots v2.0.8
⌃ [0bca4576] SciMLBase v1.67.0
  [fc659fc5] SkyCoords v1.0.2
  [276daf66] SpecialFunctions v2.1.7
  [928aab9d] SpecialMatrices v2.0.1
  [a25cea48] SpecialPolynomials v0.4.2
  [aedffcd0] Static v0.7.7
  [90137ffa] StaticArrays v1.5.9
  [2913bbd2] StatsBase v0.33.21
  [0c5d862f] Symbolics v4.13.0
  [bd369af6] Tables v1.10.0
  [e88e6eb3] Zygote v0.6.49
  [8bb1440f] DelimitedFiles
  [37e2e46d] LinearAlgebra
  [10745b16] Statistics

Insights appreciated!

I don’t want to pollute your new post and make people think a solution has been found, so I’m posting here. I tried to run your code and I don’t see how to fix the bug, but at least you’re not alone :muscle:
Do you have any idea why the ForwardDiff version works here but fails in Optimization.jl, DataInterpolations.jl and Gradients?
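On why ForwardDiff works: ForwardDiff just runs the ordinary code with `Dual` numbers, so DataInterpolations builds its `Tridiagonal` system as usual, only with a `Dual` element type. Zygote instead has to differentiate the constructor itself. The trace fails at tridiag.jl:485, inside the `Tridiagonal` constructor called from DataInterpolations’ interpolation_caches.jl:160. In Julia 1.8, `LinearAlgebra.Tridiagonal` has four fields (`dl`, `d`, `du`, plus `du2`, used for pivoting in `lu!`), while its three-argument inner constructor passes only three values to `new`; that appears to be what Zygote’s generic `__new__` adjoint rejects with "expected 4". One way around it is to avoid `Tridiagonal` entirely. Below is a hand-written natural cubic spline (not DataInterpolations’ code; `natural_cubic_spline` and `spline_eval` are hypothetical names) that assembles the system with `diagm` and solves it with a dense `\`, without in-place mutation. I’m assuming Zygote can differentiate it through its rules for `diagm`, indexing and `\`, but I have not checked this against the versions listed above, and the dense solve only makes sense for a modest number of knots:

```julia
using LinearAlgebra

# Hypothetical helper: natural cubic spline through (t[i], u[i]).
# t must be strictly increasing (uneven spacing is fine); at least 3 knots are required.
# No Tridiagonal and no in-place mutation; Zygote-friendliness is an untested assumption.
function natural_cubic_spline(u::AbstractVector, t::AbstractVector)
    n = length(t)
    n == length(u) || throw(DimensionMismatch("u and t must have the same length"))
    n >= 3 || throw(ArgumentError("need at least 3 knots"))
    h = t[2:end] .- t[1:end-1]              # interval widths, length n-1
    slope = (u[2:end] .- u[1:end-1]) ./ h   # secant slopes, length n-1
    # Equations for the interior second derivatives M[2], ..., M[n-1]:
    # h[i-1]*M[i-1] + 2*(h[i-1] + h[i])*M[i] + h[i]*M[i+1] = 6*(slope[i] - slope[i-1])
    d = 2 .* (h[1:end-1] .+ h[2:end])       # main diagonal, length n-2
    off = h[2:end-1]                        # sub/super diagonal, length n-3
    A = diagm(0 => d, 1 => off, -1 => off)  # dense symmetric tridiagonal matrix
    rhs = 6 .* (slope[2:end] .- slope[1:end-1])
    z = zero(eltype(rhs))
    M = vcat(z, A \ rhs, z)                 # natural end conditions: M[1] = M[n] = 0
    return (t = t, u = u, h = h, M = M)
end

# Evaluate the spline at a single point x; outside [t[1], t[end]] the first or last
# cubic piece is simply extended.
function spline_eval(s, x)
    t, u, h, M = s.t, s.u, s.h, s.M
    i = clamp(searchsortedlast(t, x), 1, length(t) - 1)  # interval containing x
    a = t[i+1] - x
    b = x - t[i]
    return (M[i] * a^3 + M[i+1] * b^3) / (6 * h[i]) +
           (u[i] / h[i] - M[i] * h[i] / 6) * a +
           (u[i+1] / h[i] - M[i+1] * h[i] / 6) * b
end
```

In `foo`, the line `y2 = DataInterpolations.CubicSpline(y, x).(x2)` would become `y2 = spline_eval.(Ref(natural_cubic_spline(y, x)), x2)`. If DataInterpolations uses different end conditions, the values will not match it exactly.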