I’m trying to use cubic spline interpolation from DataInterpolations.jl in my loss function with ModelingToolkit.jl. There’s an older thread that seems to discuss this, but I haven’t been able to get its approach working, and it hasn’t seen traffic in some time. Please let me know if I should continue in that thread instead.
I’ve tried with and without the various `collect` calls, but I think my real problem is that I don’t understand how `@register` works.
```julia
julia> using ModelingToolkit, DataInterpolations

julia> n=10
10

julia> @parameters x[1:n] data[1:n]
2-element Vector{Symbolics.Arr{Num, 1}}:
 x[1:10]
 data[1:10]

julia> function interp(x, y, xnew)
           itp = CubicSpline(y, x)
           return itp.(xnew)
       end
interp (generic function with 1 method)

julia> @register interp(a, b, c)

julia> rms(x, y) = sqrt(sum((x .- y).^2) / length(x))
rms (generic function with 1 method)

julia> @variables c0 c1
2-element Vector{Num}:
 c0
 c1

julia> model(x, c0, c1) = @. c0 + c1 * x
model (generic function with 1 method)

julia> loss = rms(collect(data), interp(collect(x), model(collect(x), c0, c1), collect(x)))
ERROR: TypeError: non-boolean (Num) used in boolean context
Stacktrace:
 [1] lu!(A::LinearAlgebra.Tridiagonal{Num, Vector{Num}}, pivot::LinearAlgebra.RowMaximum; check::Bool)
   @ LinearAlgebra /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/LinearAlgebra/src/lu.jl:515
 [2] lu(A::LinearAlgebra.Tridiagonal{Num, Vector{Num}}, pivot::LinearAlgebra.RowMaximum; check::Bool)
   @ LinearAlgebra /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/LinearAlgebra/src/lu.jl:279
 [3] lu (repeats 2 times)
   @ /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/LinearAlgebra/src/lu.jl:278 [inlined]
 [4] \(A::LinearAlgebra.Tridiagonal{Num, Vector{Num}}, B::Vector{Real})
   @ LinearAlgebra /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/stdlib/v1.7/LinearAlgebra/src/generic.jl:1142
 [5] CubicSpline(u::Vector{Num}, t::Vector{Num})
   @ DataInterpolations ~/.julia/packages/DataInterpolations/HJsu4/src/interpolation_caches.jl:156
 [6] interp(x::Vector{Num}, y::Vector{Num}, xnew::Vector{Num})
   @ Main ./REPL[8]:2
 [7] top-level scope
   @ REPL[15]:1
```
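My current guess comes from frame [6] of the trace: `interp(x::Vector{Num}, y::Vector{Num}, xnew::Vector{Num})` ran its own body (`REPL[8]:2`), so the call never reached a registered method, and `CubicSpline` then tried to pivot an LU solve on symbolic values. If, as I understand the Symbolics docs, `@register` treats untyped arguments as scalars (`Real`), vectors of `Num` would not match the registered methods. The pure-Julia sketch below mimics that dispatch. `FakeNum` and `interp_demo` are hypothetical stand-ins, not Symbolics types or generated methods.

```julia
# Hypothetical stand-in for a symbolic scalar such as Symbolics.Num.
# It is NOT the real Num type; it only mimics "a scalar that is a Real".
struct FakeNum <: Real
    name::Symbol
end

# Plain Julia definition: with symbolic inputs, a body like this gets traced,
# which is how CubicSpline (and its LU solve) ends up seeing symbolic values.
interp_demo(x, y, xnew) = :traced_body

# Roughly what registration adds, ASSUMING untyped arguments are treated as
# scalars: a more specific method for scalar symbolic arguments only.
interp_demo(x::FakeNum, y::FakeNum, xnew::FakeNum) = :opaque_call

a, b, c = FakeNum(:a), FakeNum(:b), FakeNum(:c)

println(interp_demo(a, b, c))        # scalar args hit the "registered" method
println(interp_demo([a], [b], [c]))  # Vector{FakeNum} is not <: Real, so the plain body runs
```

This prints `opaque_call` and then `traced_body`. If that reading is right, the fix would involve registering `interp` for array arguments (or calling it per scalar), but I haven’t worked out the correct way to do that.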