I was surprised that type inference does not work in this case:
for i=1:3
    @eval $(Symbol("f"*string(i)))(x) = x^$(i)+$(i)
end
#
@code_warntype f1(1.1) # it is Float64, fine!
#
function test(i)
    list_of_functions = [f1,f2,f3]
    return list_of_functions[i](rand())
end
@code_warntype test(1) # Any
What would be the correct way to call a function by a given index?
For the record, here are a few other “wrong” ways to write the same code:
function testII(i)
    return @eval $(Symbol("f"*string(i)))(rand())
end
@code_warntype testII(1) # Any
#
function testIII(i)
    f = f1
    i==2 && (f = f2)
    i==3 && (f = f3)
    return f(rand())
end
@code_warntype testIII(1) # Any
In the next two examples the correct type is inferred; however, they are inconvenient:
# the call is written out separately for every f[i]
function testIV(i)
    i==1 && return f1(rand())
    i==2 && return f2(rand())
    return f3(rand())
end
@code_warntype testIV(1) # Float64
#
# not really what was required
# also, the code is not readable for students
for i=1:3
    @eval function $(Symbol("test_"*string(i)))()
        return $(Symbol("f"*string(i)))(rand())
    end
end
@code_warntype test_2() # Float64
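The if-chain version above infers correctly because each branch calls one concrete function. A similar unrolled dispatch can be produced without writing a branch per function by recursing over a tuple's head and tail. This is a sketch of a different technique from the ones in the thread, assuming the functions are stored in a tuple rather than a Vector{Function}; `call_nth`, `FUNCS` and `test_unrolled` are hypothetical names:

```julia
# The same three functions as above: f_i(x) = x^i + i
f1(x) = x^1 + 1
f2(x) = x^2 + 2
f3(x) = x^3 + 3

# Hypothetical helper: call the i-th function of the tuple `fs` on `x`.
# Each step either calls the head of the tuple or recurses on its tail, so
# every call site sees one concrete function type, like the hand-written if-chain.
call_nth(fs::Tuple, i::Integer, x) =
    i == 1 ? first(fs)(x) : call_nth(Base.tail(fs), i - 1, x)

# The tuple ran out before reaching index i: the index was out of range
call_nth(::Tuple{}, i::Integer, x) = throw(BoundsError())

const FUNCS = (f1, f2, f3)   # a tuple, not a Vector{Function}

test_unrolled(i) = call_nth(FUNCS, i, rand())

println(call_nth(FUNCS, 2, 2.0))                  # f2(2.0) = 2.0^2 + 2
println(Base.return_types(test_unrolled, (Int,)))
```

Since `i` is only known at run time, this still branches at run time; the point is that each branch calls a concretely typed function, so the compiler can infer the common return type.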
The type information is lost at [f1,f2,f3] because that vector is of type Vector{Function}. A tuple, by contrast, stores the type of each of its three elements (i.e. the concrete type of each function), so you need:
function test(i)
    list_of_functions = (f1,f2,f3)
    return list_of_functions[i](rand())
end
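The difference between the two containers described above can be checked directly by comparing their types and asking the compiler what it infers for the Vector version. A minimal sketch; `test_vec` is a hypothetical name, and `Base.return_types` is used here as a scriptable alternative to reading `@code_warntype` output:

```julia
# Hypothetical stand-ins for the f1, f2, f3 generated at the top of the thread
f1(x) = x^1 + 1
f2(x) = x^2 + 2
f3(x) = x^3 + 3

v = [f1, f2, f3]   # a Vector has a single element type: the join of the three is Function
t = (f1, f2, f3)   # a Tuple keeps the concrete type of each element

println(eltype(v))   # Function
println(typeof(t))   # the tuple type lists typeof(f1), typeof(f2), typeof(f3)

# Indexing the Vector only tells the compiler "some Function", so the call result is Any
test_vec(i) = [f1, f2, f3][i](rand())
println(Base.return_types(test_vec, (Int,)))
```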
The variable i is still not known at compile time, but thanks to constant propagation Julia figures it out; you just need to test the whole thing inside another function. It’s more obvious that it’s working if the three functions return different types.
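A minimal sketch of that suggestion, assuming three hypothetical f’s that return different types (Int, Float64, String) so the effect is easy to spot. Calls with a literal index are wrapped in small functions such as test1(); the exact inference results may differ between Julia versions:

```julia
# Hypothetical f's with deliberately different return types
f1(x) = 1         # Int
f2(x) = 2.0       # Float64
f3(x) = "three"   # String

function test(i)
    list_of_functions = (f1, f2, f3)
    return list_of_functions[i](rand())
end

# The literal index inside these wrappers is a compile-time constant,
# so constant propagation can pick the matching tuple element
test1() = test(1)
test2() = test(2)

println(Base.return_types(test, (Int,)))   # i unknown: a Union (or wider) type
println(Base.return_types(test1, ()))      # narrowed to Int
println(Base.return_types(test2, ()))      # narrowed to Float64
```

Note that `@code_warntype test(1)` analyses `test(::Int)` using only the argument types, not the value 1, so it cannot show the narrowed result; checking `test1()` instead does.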
Thank you very much, that is good to know!
Would it propagate constant tuples or more complicated structures as well?
However, it still does not solve my problem. I know that the return type of test(i) is the same as the return type of any of the f functions (they all, supposedly, return the same type). I guess there should be a way to let Julia figure this out. Any ideas?
Sorry, I’m not quite understanding why it doesn’t. The version I gave above will make it so test(i) is inferred correctly, including if all the f’s return the same type. Isn’t that what you were looking for?
Yes, I thought so from your reply. Then I tested the following (Julia Version 1.1.1 (2019-05-16)):
f1(x) = 1.0
f2(x) = 1.0
f3(x) = 1.0
function test(i)
    list_of_functions = (f1,f2,f3)
    return list_of_functions[i](rand())
end
@code_warntype test(1) # Any
and it did not work for me.
Are you saying that the type is inferred for you with this same code?
Or does it work only when the constant is propagated during compilation, as in your example?
Ah, I see. Yes, what you’re seeing is right; it’s just that constant propagation only works for code inside functions, which is why I put the call inside one, test1() = test(1), in my example. In general, Julia only applies all possible optimizations, constant propagation included, to code inside functions, so you’ll definitely want to put all speed-critical code inside them.
In my working example there are many function indices like i, and they are wrapped in more complicated arrays, so it is not straightforward (although possible) to precompile for the given indices.
In principle, the problem would be solved if I could specify the type of a function, as in C++.
Ah, OK, I think I understand the confusion. I don’t mean that you actually need to write functions test1, test2, etc. I just mean that your calls to test(1), test(2), etc. (for any given i) need to be literally inside some function, and then they will be inferred correctly. The only place they will not be inferred correctly is when you call them directly from the REPL (because not all optimizations are turned on there), which is what you did with @code_warntype test(1), getting Any.
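using FunctionWrappers: FunctionWrapper   # the package is FunctionWrappers.jl
using BenchmarkTools
# a wrapper type that fixes the signature Float64 -> Float64
const F64F64Func = FunctionWrapper{Float64,Tuple{Float64}}
const funcs_FW = F64F64Func[]   # concretely typed: every call returns a known Float64
const funcs = Function[]        # abstractly typed: every call is dynamically dispatched
for i in 1:100
    push!(funcs_FW, F64F64Func(x -> x + i))
    push!(funcs, x -> x + i)
end
# named run_all rather than run to avoid clashing with Base.run
function run_all(fs)
    s = 0.0
    for i in 1.0:100.0
        for f in fs
            s += f(i)
        end
    end
    return s
end
@btime run_all(funcs_FW)
@btime run_all(funcs)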
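If the common return type is known in advance, a lighter-weight alternative (not from the thread) is a type assertion on the call result. A sketch, assuming every function returns Float64; `FUNCS` and `test_assert` are hypothetical names. Unlike FunctionWrapper, this does not remove the dynamic dispatch of the call itself; it only makes the result type concrete for the code that uses it:

```julia
# Hypothetical functions that all return Float64
f1(x) = 1.0
f2(x) = 2.0
f3(x) = 3.0

const FUNCS = [f1, f2, f3]   # Vector{Function}: the call below is still a dynamic dispatch

# `::Float64` asserts the result type: the compiler can rely on it afterwards,
# and Julia throws a TypeError at run time if a function returns anything else
test_assert(i) = FUNCS[i](rand())::Float64

println(Base.return_types(test_assert, (Int,)))
```

If one of the functions returns a different type, the assertion fails loudly with a `TypeError` instead of silently widening the inferred type.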