Understanding do notation and type stability for functions that take other functions as arguments

If I do

function apply_fn2a(fns::Tuple,a)
    Tuple(fn(a) for fn in fns)
end

using StatsBase
const x = rand(1:5,1000)
@code_warntype apply_fn2a((countmap, sum, mean), x)

then the inferred output type is just a generic Tuple. However, in the examples below Julia correctly infers a specific tuple type as output. Is there a way to make the code above infer the output tuple type?

For example

function apply_fn2a(fn::Function,a)
    fn(a)
end

Initially I wondered how type stability works with apply_fn2a, since its output type depends on fn. But do notation is a language feature, and from my testing I can see that Julia can infer the output type of apply_fn2a if fn is output-type-stable (if that’s the right term), which is very nice and clever. E.g. @code_warntype apply_fn2a(sum, rand(1000)) and

@code_warntype apply_fn2a(rand(1000)) do x
    x[end] - x[1]
end

show that Julia can infer the types.
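
As a rough illustration of why this works: each Julia function has its own type, so the compiler can specialize on which function was passed, and a do block is just syntax for passing an anonymous function as the first argument. The sketch below uses a hypothetical helper named apply_one (not from the original post) in place of apply_fn2a.

```julia
using Test  # provides @inferred

# Hypothetical stand-in for apply_fn2a; the name is only for illustration.
apply_one(fn, a) = fn(a)

# Every named function has its own singleton type, which is a subtype of Function.
@show typeof(sum) <: Function
@show typeof(sum) === typeof(maximum)

v = rand(1000)

# A do block passes an anonymous function as the first argument,
# so these two calls are equivalent.
r1 = apply_one(x -> x[end] - x[1], v)
r2 = apply_one(v) do x
    x[end] - x[1]
end

# @inferred throws an error if the inferred return type
# does not match the type of the value actually returned.
s = @inferred apply_one(sum, v)
```

Because the method is compiled separately for typeof(sum) and for the anonymous function's type, the return type is known in each case.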

But what if I want to write a function that lets a user supply multiple functions, each with a different output type? I found that I can just use a tuple of functions and Julia will still correctly infer the types:

function apply_fn2a(fn::Tuple{Function, Function},a)
    (fn[1](a), fn[2](a))
end

const x = rand(1:5, 100)
using StatsBase
@code_warntype apply_fn2a((countmap, sum), x)

Note that Tuple{...} constructs a tuple type, while tuple and (...) construct a tuple. You probably want something like

function apply_fn2a(fns::Tuple,a)
    tuple((fn(a) for fn in fns)...)
end

which is inferred properly.


Is it really inferred properly? It’s showing red in @code_warntype output. The return type looks a bit odd with the first element of the tuple hardened and the rest in a Vararg…

julia> @code_warntype apply_fn2a((countmap, sum, mean), x)
Body::Tuple{Dict{Int64,Int64},Vararg{Union{Float64, Int64, Dict{Int64,Int64}},N} where N}
1 ─ %1 = %new(getfield(Main, Symbol("##19#20")){Array{Int64,1}}, a)::getfield(Main, Symbol("##19#20")){Array{Int64,1}}
│   %2 = %new(Base.Generator{Tuple{typeof(countmap),typeof(sum),typeof(mean)},getfield(Main, Symbol("##19#20")){Array{Int64,1}}}, %1, (StatsBase.countmap, sum, Statistics.mean))::Base.Generator{Tuple{typeof(countmap),typeof(sum),typeof(mean)},getfield(Main, Symbol("##19#20")){Array{Int64,1}}}
│   %3 = (Core._apply)(Main.tuple, %2)::Tuple{Dict{Int64,Int64},Vararg{Union{Float64, Int64, Dict{Int64,Int64}},N} where N}
└──      return %3

map(f -> f(a), fns) infers correctly with up to 15 functions in Julia 1.1. Past that, I think you’ll need a generated function.
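
A minimal sketch of the map-based version, assuming a short tuple of functions. The name apply_all is hypothetical, and Statistics is used for mean so the snippet does not depend on StatsBase.

```julia
using Test        # provides @inferred
using Statistics  # provides mean

# Hypothetical name; map over a tuple returns a tuple whose
# element types can be inferred one by one for short tuples.
apply_all(fns::Tuple, a) = map(f -> f(a), fns)

x = rand(1:5, 1000)

# @inferred throws if the return type cannot be inferred concretely.
out = @inferred apply_all((sum, mean, length), x)
```

Here each element keeps its own type (an integer sum, a floating-point mean, an integer length) instead of collapsing into a Union.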


I think that’s just a heuristic limit in Base and not a compiler limitation:

https://github.com/JuliaLang/julia/blob/89baf832b221c5959ebdeb86b28f2725d7ce4b3d/base/tuple.jl#L144-L145

https://github.com/JuliaLang/julia/blob/89baf832b221c5959ebdeb86b28f2725d7ce4b3d/base/tuple.jl#L148-L154

It’s pretty easy to roll your own map for tuples if you really need it

mapt(f, xs) = mapargs(f, xs...)
mapargs(f, x, xs...) = (f(x), mapargs(f, xs...)...)
mapargs(f) = ()

fns = ntuple(x -> y -> x + y, 20)

using Test
@inferred mapt(f -> f(0), fns)

Actually, 0 .|> fns can be inferred, too.
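
A small sketch of the broadcast-pipe form. One assumption worth flagging: when the argument is an array rather than a scalar, it probably needs to be wrapped in Ref so that broadcasting treats it as a single value instead of iterating over its elements. Whether the result is inferred may depend on the Julia version, so this only shows the values.

```julia
# Scalar argument: each function in the tuple is applied to 0.
out0 = 0 .|> (y -> y + 1, y -> y + 2.0)

# Array argument: Ref(x) makes broadcasting treat x as one value,
# so each function receives the whole array.
x = [3, 1, 2]
fns = (sum, length, first)
out = Ref(x) .|> fns
```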
