Hi @gmantegazza,
The simplest answer is that you’re slightly off with your splatting.
# Current function
NextStateValue(X...)=ApproximateValue(Weights,BasisFunctions,Horizons,X...)
# Change it to:
NextStateValue(X...)=ApproximateValue(Weights,BasisFunctions,Horizons,X)
The first function is going to try to “splat” X into multiple arguments, whereas the second is going to pass X as a tuple. This is a very subtle part of Julia:
julia> bar(x...) = length(x), x
bar (generic function with 1 method)
julia> foo1(x...) = bar(x...)
foo1 (generic function with 1 method)
julia> foo2(x...) = bar(x)
foo2 (generic function with 1 method)
julia> foo1(1, 2)
(2, (1, 2))
julia> foo2(1, 2)
(1, ((1, 2),))
foo1 is calling bar(1, 2), whereas foo2 is calling bar((1, 2)). The simplest rule of thumb is that either both the function definition and the call need a splat (like ApproximateValue(Weights,BasisFunctions,Horizons,State...)), or neither of them does.
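As a minimal sketch of that rule of thumb applied to a wrapper like NextStateValue: the helpers weighted_sum and weighted_sum_splat below are hypothetical stand-ins (they are not from your code), chosen only to show which call style matches which definition.

```julia
# Hypothetical stand-in for ApproximateValue: it expects the state as ONE
# collection argument (no splat in its definition).
weighted_sum(weights, state) = sum(w * s for (w, s) in zip(weights, state))

# Hypothetical variant that expects the state as SEPARATE arguments
# (splat in its definition).
weighted_sum_splat(weights, state...) = sum(w * s for (w, s) in zip(weights, state))

weights = [0.5, 0.25, 0.25]

# Inside a varargs function, X is a tuple. Pass it on unsplatted when the
# callee takes a single collection...
next_value_a(X...) = weighted_sum(weights, X)

# ...or splat it again when the callee is itself a varargs function.
next_value_b(X...) = weighted_sum_splat(weights, X...)

# Mismatch: splatting into the single-collection version throws a
# MethodError, because weighted_sum has no method taking four arguments.
next_value_bad(X...) = weighted_sum(weights, X...)

println(next_value_a(1.0, 2.0, 4.0))  # 2.0
println(next_value_b(1.0, 2.0, 4.0))  # 2.0
```

Both consistent versions give the same result; only the mismatched one fails, and it fails at call time rather than at definition time, which is part of why this is easy to miss.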
I’m not really sure I understand your function (what is the findall(Horizons trying to do?), but perhaps this rewrite will help you out:
using JuMP
using LinearAlgebra
function ApproximateValueH(State, BasisFunctions, BasicHorizon_ind)
    f(y, u, g, z) = y + sum(g[i] * (z[i] - u[i]) for i in 1:length(z))
    return maximum(
        f(b2, b1[:, 2], b1[:, 1], State) for
        (b1, b2) in BasisFunctions[BasicHorizon_ind]
    )
end
N = 5
Horizons = [2, 5, 15, 30, 45]
Weights = ones(length(Horizons)) ./ length(Horizons)
BasisFunctions = [[[rand(N,2), rand()] for j in 1:2] for i in 1:length(Horizons)]
function NextStateValue(X...)
    # X arrives as a tuple of the N splatted arguments; pass it on unsplatted as the state.
    return sum(
        Weights[Hix] * ApproximateValueH(X, BasisFunctions, Hix)
        for (Hix, H) in enumerate(Horizons)
    )
end
model = Model()
@variable(model, NewState[1:N] >= 0)
register(model, :NextValueS, N, NextStateValue; autodiff = true)
@NLexpression(model, my_expr, NextValueS(NewState...))
You might also run into some performance problems with a user-defined function with such a large number of arguments, but that’s a secondary consideration once you have something working.