How do I get the expected return type of a polymorphic function at compile time?

I find myself frequently running into this problem - I have a function f that returns a vector, and a function g that pre-allocates an array and uses f to fill it column by column. In the event that f raises an exception, I would like that column to be filled with NaN.

So something like this:

function g(n, m, args...)
    a = fill(NaN, m, n)  # always a Matrix{Float64}, whatever f returns
    for i in 1:n
        try
            a[:, i] = f(i, args...)
        catch
            # f threw: leave column i filled with NaN
        end
    end
    return a
end

The problem is that f is polymorphic and it may not be trivial to determine what the element type of a should be from args alone. If f is type stable the compiler should know this, but I don’t know how to access that information inside g in a way that the compiler can also predict the type of a.

The way I have been doing this in general is calling f(1, args...) before allocating the array and then using the eltype of that, but this does not work when the function is not guaranteed to return.

You have run into one of the trickiest problems in using Julia performantly with generic functions and pre-allocated buffers.

Unfortunately, there is no generic solution. The cleanest way is probably to define an auxiliary function g_elt that works on types. Sometimes you can assemble it from other functions that work on types, e.g. see this and similar discussions.
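A minimal sketch of what such an auxiliary function could look like. The body of f here is a hypothetical stand-in (the thread never shows it), and g is reduced to a single extra argument x to keep things short. The point is only that g_elt takes types and returns a type, so it never has to call f.

```julia
# Hypothetical stand-in for the thread's f: the element type follows x.
f(i, x) = fill(i * x, 3)

# Hand-written companion that works on types only: given the type of x,
# return the element type g should allocate. It must be kept in sync with f.
g_elt(::Type{X}) where {X} = promote_type(Int, X)

function g(n, m, x)
    T = g_elt(typeof(x))
    a = fill(convert(T, NaN), m, n)  # throws if T has no NaN (e.g. Int)
    for i in 1:n
        try
            a[:, i] = f(i, x)
        catch
            # f threw: leave column i filled with NaN
        end
    end
    return a
end
```

The drawback is the obvious one: g_elt duplicates knowledge about f by hand, so it is only as correct as whoever maintains it.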

There is also Base.return_types, although relying on that information is somewhat frowned upon as it is an implementation detail.
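For illustration, this is roughly how Base.return_types could be wired into g. The body of f is a hypothetical stand-in that throws for i == 2, and passing f as the first argument is my assumption, not something from the original post. As noted above, the result depends on what inference happens to figure out, so treat this as a sketch rather than a recommendation.

```julia
# Hypothetical stand-in for the thread's f; it throws for i == 2
# so that the catch branch is exercised.
f(i, x) = i == 2 ? error("no result") : fill(i * x, 3)

function g(f, n, m, args...)
    # Ask inference what f(::Int, args...) returns. There is one entry per
    # matching method; first() throws if no method matches at all.
    rts = Base.return_types(f, (Int, map(typeof, args)...))
    T = eltype(first(rts))
    # fill! converts NaN to T, and throws if T cannot represent it.
    a = fill!(Matrix{T}(undef, m, n), NaN)
    for i in 1:n
        try
            a[:, i] = f(i, args...)
        catch
            # f threw: leave column i filled with NaN
        end
    end
    return a
end

a = g(f, 3, 3, 1.5f0)
```

If inference gives up, T may come out as an abstract type such as Any. The code still runs in that case, but the array is no longer concretely typed.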

This is strange to me, because (assuming f is type stable) the compiler does know ahead of time what the return type will be. It seems like we should be able to access that somehow.

Realistically, I can probably get away with just writing another function that gets the return type of f for several special cases of input, but I am surprised there’s no general way of doing this.

Maybe I am missing something, but in the following Julia is able to infer the output type (albeit with really simple functions). What if you pass f as an argument to g? My understanding is that if f is type-stable given the inputs i and args, then the compiler should be able to infer the return type.

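using InteractiveUtils  # provides @code_warntype outside the REPL

f1(x, y) = x * y
f2(x, y) = "$x$y"

g(f, x, y) = f(x, y)

@code_warntype g(f1, 1, 2)  # inferred return type: Int
@code_warntype g(f2, 1, 2)  # inferred return type: String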

The compiler may know, especially if the function is type stable. But even though the compiler is getting better and better, there is no guarantee, so you may not be able to rely on this in general.

As @mauro3 pointed out, you have some tools for inspecting what the compiler knows, but it is generally considered a bad habit to make results dependent on what the compiler can infer, as this can change (usually for the better).

As has been pointed out, in general this is not solvable, but here I have some code that attempts anyway, though the employed solution is frowned upon.

This is in general a tricky issue, and it’s been discussed before (see Ridiculous idea: types from the future and the various threads that link to it).

But before we get too deep into the weeds of the various solutions to the problem, it would be helpful to understand what you actually want. NaN is a specific value of type Float64. If f() returns columns whose element type is also Float64, then the behavior of your function is obvious. But what if it doesn’t? If f() returns a Float32, then do you want a to have an element type of Float64? Or do you want a to have element type Float32 with all the NaNs replaced with their 32-bit equivalents? And what if f() returns a type which does not have a NaN (like Int, for example)? Should the element type of a be Union{Float64, Int}?

Unless you actually mean NaN in the IEEE floating-point sense, I would suggest not using it at all and instead using Union{T, Missing} or Union{T, Nothing} (where T is the output element type of f()). See First-Class Statistical Missing Values Support in Julia 0.7 for more on why this is now a recommended approach in Julia 1.0.
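A small sketch of the Union{T, Missing} variant. It assumes T is already known and passed in by the caller, which sidesteps the question of how to obtain it. The body of f and the name g_missing are hypothetical.

```julia
# Hypothetical stand-in for the thread's f; it throws for i == 2.
f(i, x) = i == 2 ? error("no result") : fill(i * x, 3)

# T is supplied by the caller in this sketch.
function g_missing(f, ::Type{T}, n, m, args...) where {T}
    # Every entry starts out as missing, whatever T is.
    a = Matrix{Union{T, Missing}}(missing, m, n)
    for i in 1:n
        try
            a[:, i] = f(i, args...)
        catch
            # f threw: leave column i as missing
        end
    end
    return a
end

a = g_missing(f, Float64, 3, 3, 1.5)
```

Unlike NaN, this works for any T, including element types such as Int that have no NaN value.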

It is my impression that all the elements are filled up, so a = Array{T}(undef, m, n) may be better.

Perhaps, indeed. But filled up with what? Is convert(T, NaN) what the OP actually wants? Or is NaN just a flag to signify missingness?

Presumably (from the code) the value is irrelevant, since it is overwritten, the key is calculating the type T.

This approach can be extended to pre-allocated buffers if applied twice (inner-outer, just guess a buffer type and get a new one if guessed wrong, doing it in each iteration), but I need to think about how to package and simplify this.
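One way to read "guess a buffer type and get a new one if guessed wrong" is the sketch below. It is my interpretation, not the packaged solution mentioned above, and the body of f and the name g_widen are hypothetical. It starts from the narrowest float in Base (Float16) and widens the whole buffer whenever a column needs a wider element type.

```julia
# Hypothetical stand-in for the thread's f; it throws for i == 2.
f(i, x) = i == 2 ? error("no result") : fill(i * x, 3)

function g_widen(f, n, m, args...)
    a = fill(NaN16, m, n)  # initial guess: Float16
    for i in 1:n
        col = try
            f(i, args...)
        catch
            nothing
        end
        col === nothing && continue  # f threw: keep NaN in column i
        T = promote_type(eltype(a), eltype(col))
        if T != eltype(a)
            # Guessed wrong: replace the buffer with a wider copy.
            # NaN entries stay NaN under float-to-float conversion.
            a = convert(Matrix{T}, a)
        end
        a[:, i] = col
    end
    return a
end
```

The price is that g_widen itself is not type stable, since the type of a can change inside the loop, so callers would want a function barrier before doing heavy work on the result.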

But it’s not: in the OP’s example, the entries are only overwritten if f() does not throw an exception.

Good point, I missed this. Indeed missing is a better suggestion.

It’s not ideal, for performance reasons, to rely on errors for control flow. Make your function f return missing instead of throwing an exception, and then either give the result array the element type Union{Missing,eltype(f(...))}, or, if you really want NaN, only assign when the function result is not missing.
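A sketch of the second option (keep NaN, skip the assignment when the result is missing). The names f_or_missing and g_nan are made up for illustration, and T is again assumed to be known by the caller.

```julia
# Hypothetical f that signals failure by returning missing instead of throwing.
f_or_missing(i, x) = i == 2 ? missing : fill(i * x, 3)

# T is supplied by the caller in this sketch.
function g_nan(f, ::Type{T}, n, m, args...) where {T}
    a = fill(convert(T, NaN), m, n)  # throws if T has no NaN
    for i in 1:n
        col = f(i, args...)
        if !ismissing(col)
            a[:, i] = col
        end
    end
    return a
end

a = g_nan(f_or_missing, Float32, 3, 3, 1.5f0)
```

Note that f_or_missing now returns a Union of Missing and a vector type, so it is no longer type stable in the strict sense, although small unions like this are generally handled well by the compiler.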

As for the return type of the function, maybe you could use some dummy arguments you know will work for the first call?

In this specific case the return value of f is always going to be a subtype of Real. Specifically, single or double precision floats or ForwardDiff.Dual, which do all have an NaN value.

But this is just one case of a more general problem I run into often. It would be nice if I could do this in general as long as @code_warntype shows that f is type stable with the given arguments.

I’ve come across this issue several times. I imagine there can’t be a fully general solution, because you could do things that depend on the inferred type and in turn affect the inferred type, so you either spiral into hell or converge to a fixed point. There’s probably a good enough solution though.