I’m trying to use Flux.jl to optimize a loss function that performs cubic spline interpolation, so I need an interpolation package that supports automatic differentiation. Right now I’m using DataInterpolations.jl. The following test code fails with an error about too few arguments when I take its gradient with `Flux.gradient` (which calls Zygote):
```julia
julia> using DataInterpolations, Flux, Zygote

julia> function foo(x::Vector)
           y = 2 .* x.^2 .+ cos.(x);
           x2 = [2.0, 5.0, 9.0, 9.5]
           y2 = DataInterpolations.CubicSpline(y, x).(x2)
           return sum(y2)
       end
foo (generic function with 1 method)

julia> x = [1:10.;];

julia> Flux.gradient(foo, x)
ERROR: ArgumentError: new: too few arguments (expected 4)
Stacktrace:
  [1] __new__
    @ ~/.julia/packages/Zygote/dABKa/src/tools/builtins.jl:9 [inlined]
  [2] adjoint
    @ ~/.julia/packages/Zygote/dABKa/src/lib/lib.jl:293 [inlined]
  [3] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
  [4] _pullback
    @ /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/LinearAlgebra/src/tridiag.jl:485 [inlined]
  [5] _pullback(::Zygote.Context{false}, ::Type{LinearAlgebra.Tridiagonal{Float64, Vector{Float64}}}, ::Vector{Float64}, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
  [6] _pullback
    @ /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/LinearAlgebra/src/tridiag.jl:520 [inlined]
  [7] _pullback
    @ ~/.julia/packages/DataInterpolations/Al4Ib/src/interpolation_caches.jl:160 [inlined]
  [8] _pullback(::Zygote.Context{false}, ::Type{CubicSpline}, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
  [9] _pullback
    @ ./REPL[86]:3 [inlined]
 [10] _pullback(::Zygote.Context{false}, ::typeof(cspline_interp), ::Vector{Float64}, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [11] _pullback
    @ ./REPL[87]:4 [inlined]
 [12] _pullback(ctx::Zygote.Context{false}, f::typeof(foo), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
 [13] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:44
 [14] pullback
    @ ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:42 [inlined]
 [15] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:96
 [16] top-level scope
    @ REPL[90]:1
 [17] top-level scope
    @ ~/.julia/packages/CUDA/DfvRa/src/initialization.jl:52
```
But ForwardDiff does work:

```julia
julia> ForwardDiff.gradient(foo, x)
10-element Vector{Float64}:
 -2.071298673189042e-6
  0.4980577531556391
 -6.450221711938862e-6
  1.1919435010926138e-5
 -0.007078293628581947
 -0.0004909205140410721
 -0.001571414632713099
 -0.03183108029526967
 -1.6174037110341146
  0.5156259493105049
```
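Judging by frames [4] to [6] of the stack trace, the failure seems to happen while Zygote differentiates the `LinearAlgebra.Tridiagonal` constructor that `CubicSpline` uses internally. Since ForwardDiff copes with the whole function, one workaround to try is letting Zygote hand just this function over to forward mode. This is an untested sketch: it assumes the installed Zygote provides `Zygote.forwarddiff(f, x)`, and the wrapper name `foo_fd` is hypothetical.

```julia
using DataInterpolations, Zygote, ForwardDiff

# Same test function as above: build a cubic spline through (x, y)
# and sum its values at a few fixed query points.
function foo(x::Vector)
    y = 2 .* x .^ 2 .+ cos.(x)
    x2 = [2.0, 5.0, 9.0, 9.5]
    return sum(DataInterpolations.CubicSpline(y, x).(x2))
end

# Hypothetical wrapper: Zygote.forwarddiff(f, x) returns f(x) but tells
# Zygote to differentiate f with ForwardDiff, so the spline construction
# (including the Tridiagonal matrix) is not traced in reverse mode.
foo_fd(x) = Zygote.forwarddiff(foo, x)

x = collect(1.0:10.0)
g_zygote = Zygote.gradient(foo_fd, x)[1]  # reverse-mode entry point (what Flux.gradient uses)
g_fd = ForwardDiff.gradient(foo, x)       # plain ForwardDiff, for comparison
println(g_zygote ≈ g_fd)
```

Forward mode costs roughly one pass per input dimension, so this should be cheap for a handful of spline knots but may become slow if the interpolation is differentiated with respect to many parameters.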
My environment, for reference:

```
Status `~/.julia/environments/v1.8/Project.toml`
  [a0f608ac] AffineInvariantMCMC v1.0.2
  [7d9fca2a] Arpack v0.5.3
  [5c4adb95] AstroAngles v0.1.3
  [c7932e45] AstroLib v0.4.2
  [c61b5328] AstroTime v0.7.0
⌃ [6e4b80f9] BenchmarkTools v1.3.1
  [336ed68f] CSV v0.10.7
  [49dc2e85] Calculus v0.5.1
  [35d6a980] ColorSchemes v3.19.0
  [5ae59095] Colors v0.12.8
  [861a8166] Combinatorics v1.0.2
  [b0b7db55] ComponentArrays v0.13.4
  [2569d6c7] ConcreteStructs v0.2.3
  [9c784101] CubicSplines v0.2.1
  [717857b8] DSP v0.7.7
⌃ [a93c6f00] DataFrames v1.3.6
  [82cc6244] DataInterpolations v3.10.1
  [864edb3b] DataStructures v0.18.13
  [39dd38d3] Dierckx v0.5.2
  [7a1cc6ca] FFTW v1.5.0
  [525bcba6] FITSIO v0.17.0
  [9aa1b823] FastClosures v0.3.2
  [5789e2e9] FileIO v1.16.0
  [1a297f60] FillArrays v0.13.5
  [6a86dc24] FiniteDiff v2.15.0
  [26cc04aa] FiniteDifferences v0.12.25
  [587475ba] Flux v0.13.7
  [59287772] Formatting v0.4.2
  [f6369f11] ForwardDiff v0.10.32
  [a75be94c] GalacticOptim v3.4.0
  [c27321d9] Glob v1.3.0
  [bb4c363b] GridInterpolations v1.1.2
  [82e4d734] ImageIO v0.6.6
  [6218d12a] ImageMagick v1.2.2
  [916415d5] Images v0.25.2
  [5903a43b] Infiltrator v1.6.3
  [a98d9a8b] Interpolations v0.14.6
  [42fd0dbc] IterativeSolvers v0.9.2
  [033835bb] JLD2 v0.4.25
  [b964fa9f] LaTeXStrings v1.3.0
  [2ee39098] LabelledArrays v1.12.5
  [fc60dff9] LombScargle v1.0.3
  [bdcacae8] LoopVectorization v0.12.136
  [2fda8390] LsqFit v0.13.0
  [d41bc354] NLSolversBase v7.8.2
  [15e1cf62] NPZ v0.4.2
  [77ba4419] NaNMath v1.0.1
  [b946abbf] NaNStatistics v0.6.14
  [86f7a689] NamedArrays v0.9.6
  [429524aa] Optim v1.7.3
  [7f7a1694] Optimization v3.9.2
  [253f991c] OptimizationFlux v0.1.1
  [42dfb2eb] OptimizationOptimisers v0.1.0
  [5432bcbf] PaddedViews v0.5.11
  [d96e819e] Parameters v0.12.3
  [58dd65bb] Plotly v0.4.1
  [f0f68f2c] PlotlyJS v0.18.10
⌃ [91a5bcdd] Plots v1.35.6
⌃ [c3e4b0f8] Pluto v0.19.9
  [7f904dfe] PlutoUI v0.7.48
  [f27b6e38] Polynomials v3.2.0
  [92933f4c] ProgressMeter v1.7.2
  [438e738f] PyCall v1.94.1
  [d330b81b] PyPlot v2.11.0
  [6099a3de] PythonCall v0.9.9
  [189a3867] Reexport v1.2.2
  [295af30f] Revise v3.4.0
  [f2b01f46] Roots v2.0.8
⌃ [0bca4576] SciMLBase v1.67.0
  [fc659fc5] SkyCoords v1.0.2
  [276daf66] SpecialFunctions v2.1.7
  [928aab9d] SpecialMatrices v2.0.1
  [a25cea48] SpecialPolynomials v0.4.2
  [aedffcd0] Static v0.7.7
  [90137ffa] StaticArrays v1.5.9
  [2913bbd2] StatsBase v0.33.21
  [0c5d862f] Symbolics v4.13.0
  [bd369af6] Tables v1.10.0
  [e88e6eb3] Zygote v0.6.49
  [8bb1440f] DelimitedFiles
  [37e2e46d] LinearAlgebra
  [10745b16] Statistics
```
Insights appreciated!