Writing efficient code in monadic style

Hi everyone,
I have been trying to write some Julia code in monadic style while keeping it as efficient as possible. I am running into a problem that can be illustrated by the following minimal example:

# Simple implementation of the "list" monad
bind(f, xs) = [ e for x in xs for e in f(x) ]
ret(x) = [x]

function f()
    bind([0, 1]) do x
    bind([0, 1]) do y
        ret((x, y))
    end end
end

@assert f() == [(0, 0), (0, 1), (1, 0), (1, 1)]
@code_warntype(f()) # Body::Array{_1,1} where _1...
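For readers less familiar with the list monad: `ret` wraps a value in a one-element list, and `bind` applies `f` to every element and concatenates the resulting lists. As a quick sanity check, assuming the two definitions above, they satisfy the usual monad laws on small inputs. The helpers `g` and `h` are hypothetical and exist only for this check:

```julia
# List monad as defined in the question
bind(f, xs) = [e for x in xs for e in f(x)]
ret(x) = [x]

# Two arbitrary (hypothetical) monadic functions used only for the check
g(x) = [x, x + 10]
h(x) = [2x]

xs = [0, 1, 2]

# Left identity: binding a wrapped value is the same as applying the function
println(bind(g, ret(5)) == g(5))
# Right identity: binding with ret leaves the list unchanged
println(bind(ret, xs) == xs)
# Associativity: the nesting order of binds does not change the result
println(bind(h, bind(g, xs)) == bind(x -> bind(h, g(x)), xs))
```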

Here, I would like the type system to infer the return type of f (namely Vector{Tuple{Int, Int}}), but it does not. Note that in the following simpler example, Julia successfully infers the return type of f2:

function f2()
    bind([0, 1]) do x
        ret((x, x))
    end
end
@code_warntype f2() # Body::Array{Tuple{Int64,Int64},1}...
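Besides reading the `@code_warntype` output, the inferred return type can be queried programmatically with `Base.return_types`, which returns one inferred type per matching method. A small sketch reusing the definitions from the question; the exact type it prints may vary between Julia versions:

```julia
bind(f, xs) = [e for x in xs for e in f(x)]
ret(x) = [x]

function f2()
    bind([0, 1]) do x
        ret((x, x))
    end
end

# Inferred return type(s) of f2 when called with no arguments
inferred = Base.return_types(f2, ())
println(inferred)

# Whether inference found a concrete return type
println(isconcretetype(only(inferred)))
```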

Can you see any way I can solve this problem?

Edit: I found that it is possible to solve the problem by adding the following type annotations to f:

function f()
    bind([0, 1]) do x
        bind([0, 1]) do y
            ret((x, y))
        end :: Vector{Tuple{Int, Int}}
    end :: Vector{Tuple{Int, Int}}
end

@assert f() == [(0, 0), (0, 1), (1, 0), (1, 1)]
@code_warntype(f()) # Body::Array{Tuple{Int64,Int64},1}

Any idea why these annotations are needed here?
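An alternative to annotating every call site is to pass the element type to `bind` itself, so that the result type no longer depends on what inference can figure out about the closure. A minimal sketch; `bindT` and `f_typed` are hypothetical names, and the approach trades the generality of `bind` for an explicit type argument:

```julia
# Hypothetical variant of `bind` that is told the element type of its result.
# The typed comprehension T[...] always builds a Vector{T}, whatever the closure returns.
bindT(f, ::Type{T}, xs) where {T} = T[e for x in xs for e in f(x)]
ret(x) = [x]

function f_typed()
    bindT(Tuple{Int,Int}, [0, 1]) do x
        bindT(Tuple{Int,Int}, [0, 1]) do y
            ret((x, y))
        end
    end
end

println(f_typed())
```

If the closure ever produces elements that cannot be converted to `T`, this fails at run time instead of widening the element type, which is the same contract the `::` annotations impose.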

I think type inference fails on the outer generator. You can see this by moving the inner bind into its own function:

bind(f, xs) = [ e for x in xs for e in f(x) ]
ret(x) = [x]

function f()
    bind(f_inner, [0, 1])
end

function f_inner(x)
    bind([0, 1]) do y
        ret((x, y))
    end
end

@code_warntype(f()) # Body::Array{_1,1} where _1...
@code_warntype(f_inner(0)) # OK
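One way to see why the outer level is the hard part: the nested comprehension in bind gives the same result as flattening a generator of lists, roughly `collect(Iterators.flatten(f(x) for x in xs))` (the exact lowering of comprehensions may differ between Julia versions). Presumably the return type can then only be inferred statically if inference can also work out the element type of each `f(x)`, which for the outer bind means seeing through the inner closure. A small sketch of that equivalence, with `bind_flat` as a hypothetical name:

```julia
# Comprehension-based bind from the question
bind(f, xs) = [e for x in xs for e in f(x)]
ret(x) = [x]

# Hypothetical equivalent written with Iterators.flatten over a generator of lists
bind_flat(f, xs) = collect(Iterators.flatten(f(x) for x in xs))

f_inner(x) = bind(y -> ret((x, y)), [0, 1])

a = bind(f_inner, [0, 1])
b = bind_flat(f_inner, [0, 1])
println(a == b)
```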

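I am not sure that the Julia compiler is well suited to this kind of monad acrobatics. Julia deliberately has no arrow (function) types: a function's type says nothing about its argument or return types, so the compiler has to rediscover them by inference at every call. So this programming style may well run into difficulties.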
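To make the "no arrow types" point concrete: two anonymous functions with the same intended signature `Int -> Int` still have distinct, unrelated types, and their return type is only available by running inference for specific argument types. A small illustration (the names `add1` and `add2` are made up):

```julia
# Two anonymous functions with the same intended signature Int -> Int
add1 = x -> x + 1
add2 = x -> x + 2

# Each function has its own type; there is no shared type like Int -> Int
println(typeof(add1) == typeof(add2))  # false
println(add1 isa Function)             # true

# The return type is recovered by inference for concrete argument types
println(Base.return_types(add1, (Int,)))
```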
