Type inference from list of functions

I was surprised that type inference does not work in this case:

for i=1:3
    @eval $(Symbol("f"*string(i)))(x) = x^$(i)+$(i) 
end
#
@code_warntype f1(1.1) # it is Float64, fine!
#
function test(i)
    list_of_functions = [f1,f2,f3]
    return list_of_functions[i](rand())
end
@code_warntype test(1) # Any

What would be the correct way to call a function by a given index?

For the record, here are a few other “wrong” ways to write the same code:

function testII(i)
    return @eval $(Symbol("f"*string(i)))(rand())
end
@code_warntype testII(1) # Any
# 
function testIII(i)
    f = f1
    i==2 && (f = f2)
    i==3 && (f = f3)
    return f(rand())
end
@code_warntype testIII(1) # Any

In the next two examples the correct type is inferred; however, they are inconvenient:

# code for function is repeated for every f[i]
function testIV(i)
    i==1 && return f1(rand())
    i==2 && return f2(rand())
    return f3(rand())
end
@code_warntype testIV(1) # Float64
#
# not really what was required
# also, the code is not readable for students
for i=1:3
    @eval function $(Symbol("test_"*string(i)))()
        return $(Symbol("f"*string(i)))(rand())
    end
end
@code_warntype test_2() # Float64
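A possible middle ground between these two, not suggested in the thread: let the compiler generate the branch chain by recursing over a tuple of functions. `call_at` and `test_rec` are hypothetical names. Each recursive step peels off the first function with `Base.tail`, so every branch calls a function of concrete type, much like the hand-written `testIV`, without repeating the call for every `f`.

```julia
f1(x) = x^1 + 1
f2(x) = x^2 + 2
f3(x) = x^3 + 3

# Hypothetical helper: call the i-th function of a tuple by recursing over it.
# Each step compares i with 1 and either calls the first function or recurses
# on the remaining tuple, so the if/else chain is generated automatically.
call_at(fs::Tuple, i::Int, x) = i == 1 ? first(fs)(x) : call_at(Base.tail(fs), i - 1, x)

# Base case: the index ran past the end of the tuple.
call_at(::Tuple{}, i::Int, x) = throw(BoundsError())

function test_rec(i)
    return call_at((f1, f2, f3), i, rand())
end

@code_warntype test_rec(2)
```

The trade-off is the same as in `testIV`: the lookup is a linear chain of comparisons, which is fine for a handful of functions but not for hundreds.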

The type information is lost at [f1,f2,f3] because that vector is of type Vector{Function}. A tuple will instead store the type of each of its three elements (i.e. the type of each function), so you need:

function test(i)
    list_of_functions = (f1,f2,f3)
    return list_of_functions[i](rand())
end
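The difference described above can be checked directly with `eltype` and `typeof`. The functions here are stand-ins for `f1`…`f3`, and `vec_fs`/`tup_fs` are illustrative names.

```julia
f1(x) = x^1 + 1
f2(x) = x^2 + 2
f3(x) = x^3 + 3

vec_fs = [f1, f2, f3]  # Vector: elements are widened to a common abstract type
tup_fs = (f1, f2, f3)  # Tuple: each element keeps its own concrete type

println(eltype(vec_fs))                  # Function
println(isconcretetype(eltype(vec_fs)))  # false
println(typeof(tup_fs))                  # lists typeof(f1), typeof(f2), typeof(f3)
```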

The variable i is still not known at compile time, but thanks to constant propagation Julia figures it out; you just need to test the whole thing inside another function. It’s more obvious that it’s working if the three functions return different types:

f1(x) = 1
f2(x) = "a"
f3(x) = nothing
test1() = test(1)
test2() = test(2)
test3() = test(3)
@code_warntype test1() # Int
@code_warntype test2() # String
@code_warntype test3() # Nothing
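Instead of reading `@code_warntype` output, the inferred return type can also be queried with `Base.return_types`, a long-standing reflection function (not a stable public API, so best treated as a debugging aid). The definitions repeat the example above so the snippet runs on its own.

```julia
f1(x) = 1
f2(x) = "a"
f3(x) = nothing

function test(i)
    list_of_functions = (f1, f2, f3)
    return list_of_functions[i](rand())
end

# Wrappers with a literal index, so constant propagation can resolve i
test1() = test(1)
test2() = test(2)
test3() = test(3)

# Base.return_types gives one inferred return type per matching method
for g in (test1, test2, test3)
    println(g, " => ", only(Base.return_types(g, ())))
end
```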

Thank you very much, that is good to know!
Would it propagate constant tuples or more complicated structures as well?

However, it still does not solve my problem. I know that the return type of test(i) is the same as the return type of any of the f functions (they all, supposedly, return the same type). I guess there should be a way to let Julia figure this out. Any ideas?
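A common Julia idiom for this situation (not something proposed in this thread) is to assert the known return type at the call site. `test_assert` is a hypothetical name. The assertion is checked at run time, so a function that returns something else throws a `TypeError` instead of silently producing the wrong type; it only fixes the type seen by the code that uses the result, and does not change how the call itself is dispatched.

```julia
f1(x) = 1.0
f2(x) = 2.0x
f3(x) = x^2

function test_assert(i)
    list_of_functions = (f1, f2, f3)
    # The ::Float64 assertion lets inference treat the result as Float64
    # even though the index i is only known at run time.
    return list_of_functions[i](rand())::Float64
end

@code_warntype test_assert(1)
```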

Sorry, I’m not quite understanding why it doesn’t. The version I gave above will make it so test(i) is inferred correctly, including if all the f's return the same type. Isn’t that what you were looking for?

Yes, I thought so from your reply. Then I tested it (Julia Version 1.1.1 (2019-05-16)):

f1(x) = 1.0
f2(x) = 1.0
f3(x) = 1.0
function test(i)
    list_of_functions = (f1,f2,f3)
    return list_of_functions[i](rand())
end
@code_warntype test(1) # Any

and it did not work for me.

Are you saying that with the same code the type is inferred for you?
Or does it work only when the constant is propagated during compilation, as in your example?

Ah I see, yes, what you’re seeing is right. It’s just that constant propagation only works for things inside functions, which is why I put the call inside one with test1() = test(1) in my example. In general Julia only performs all possible optimizations, constant propagation included, for code inside functions, so you’ll definitely want to put all speed-critical code inside of them.

In my actual code there are many function indices like i, and they are wrapped in more complicated arrays, so it is not straightforward to precompile for the given indices (although it is possible).

In principle, the problem would be solved if I could specify the type of the function, as in C++:

std::vector<std::function<double(double)> > list_of_functions

It seems to me that the Julia compiler could handle such inference.

Thank you for the replies, anyway.
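When the functions differ only in some parameters (as with the generated `x^i + i` functions at the top), a pure-Julia route that was not raised in the thread is a callable struct: one concrete struct type plays the role of `std::function<double(double)>`, and a `Vector` of it has a concrete element type. This only applies in that special case; `PowerShift`, `functor_list` and `test_functor` are hypothetical names.

```julia
# A callable struct ("functor") with concrete field types
struct PowerShift
    p::Int       # exponent
    c::Float64   # additive constant
end

# Make instances callable: s(x) computes x^p + c
(s::PowerShift)(x::Float64) = x^s.p + s.c

# Concrete element type (Vector{PowerShift}), unlike Vector{Function};
# the default constructor converts the Int i to Float64 for field c
const functor_list = [PowerShift(i, i) for i in 1:3]

test_functor(i) = functor_list[i](rand())

@code_warntype test_functor(2)
```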

Ah OK, I think I understand the confusion. I don’t mean that you actually need to write functions test1, test2, etc. I just mean that your calls to test(1), test(2), etc. (for any arbitrary i) need to be literally inside some function, and they will be inferred correctly. The only place they will not be inferred correctly is if you call them directly from the REPL (because not all optimizations are turned on there), which is what you were doing when you ran @code_warntype test(1) and got Any.

I got it, thank you a lot!


Are only scalar constants propagated?

f1(x) = 1.0
f2(x) = 1.0
f3(x) = 1.0
function test(vect)
    i = vect[1]
    list_of_functions = (f1,f2,f3)
    return list_of_functions[i](rand())
end
#
const v = [1,2,3,1,3,3]
test1() = test(v)
@code_warntype test1() # Any
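A likely reason the vector version stays un-narrowed: `const` only fixes which array `v` refers to, while the array's contents can still be mutated, so the compiler cannot treat `v[1]` as a constant. An immutable tuple has no such problem. The sketch below assumes that explanation and uses functions with different return types so the effect is visible; `test_first`, `v_vec`, `v_tup`, `from_vec` and `from_tup` are illustrative names, and the result is worth checking with `@code_warntype` on your own Julia version.

```julia
f1(x) = 1
f2(x) = "a"
f3(x) = nothing

function test_first(vect)
    i = vect[1]                       # index taken from the container
    list_of_functions = (f1, f2, f3)
    return list_of_functions[i](rand())
end

const v_vec = [1, 2, 3, 1, 3, 3]  # mutable contents: v_vec[1] is not a compile-time constant
const v_tup = (1, 2, 3, 1, 3, 3)  # immutable: v_tup[1] can be folded to 1

from_vec() = test_first(v_vec)
from_tup() = test_first(v_tup)

@code_warntype from_vec()  # not narrowed to Int
@code_warntype from_tup()  # expected to narrow to Int via constant propagation
```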

You can do this with https://github.com/yuyichao/FunctionWrappers.jl

using FunctionWrappers: FunctionWrapper
using BenchmarkTools

const F64F64Func = FunctionWrapper{Float64,Tuple{Float64}}

const funcs_FW = F64F64Func[]  # concrete element type
const funcs = Function[]       # abstract element type

for i in 1:100
    push!(funcs_FW, F64F64Func(x -> x + i))
    push!(funcs, x -> x + i)
end

function run_all(fs)
    s = 0.0
    for i in 1.0:100.0
        for f in fs
            s += f(i)
        end
    end
    return s
end

@btime run_all(funcs_FW)
@btime run_all(funcs)