What computations are fine to do within `Val`?

Here are two (highly artificial) examples:

```julia
f(::Array{T,N}, ::Array{T,M}) where {T,N,M} = zeros(T, ntuple(i -> 2, Val(N+M)))
g(::Array{T,N}, ::Array{T,M}) where {T,N,M} = zeros(T, ntuple(i -> 2, Val(round(Int, √abs(N+M)))))
```

The first is type-stable, but the second is not (you can check with `@code_warntype`). Why?
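
A scriptable alternative to reading `@code_warntype` output is `Test.@inferred` from the standard `Test` library, which throws if the inferred return type differs from the actual one. A minimal sketch; the `g` check is left commented out because, as far as I can tell, whether inference folds `round(Int, √abs(N+M))` may depend on the Julia version:

```julia
using Test

f(::Array{T,N}, ::Array{T,M}) where {T,N,M} = zeros(T, ntuple(i -> 2, Val(N+M)))
g(::Array{T,N}, ::Array{T,M}) where {T,N,M} = zeros(T, ntuple(i -> 2, Val(round(Int, √abs(N+M)))))

A = rand(2, 2)      # a Matrix{Float64}, so N = M = 2

# Passes: N + M is folded to 4 at compile time, so the result type is Array{Float64,4}
B = @inferred f(A, A)

# Throws whenever inference cannot fold round(Int, √abs(N+M)) to a constant
# @inferred g(A, A)
```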

Or even simpler:

```julia
f(::Val{N}) where {N} = Val(N^3)
g(::Val{N}) where {N} = Val(N^4)
```

Inference resolves the first to a concrete `Val` type, but not the second.

Could it have to do with the compiler's internal inference heuristics for when to give up? `N^4` might just cross a complexity threshold that `N^3` doesn't.

Is there an easy way to understand when it works and when it doesn’t?

I think this has to do with how a literal exponent is handled: `N^3` is lowered to `Base.literal_pow(^, N, Val(3))`, which has a specialized method that just expands to `x*x*x`. For `Val(4)` there is no such method, so it falls back to the generic `^`, which for integers uses a power-by-squaring algorithm (`Base.power_by_squaring`). That loop is probably too complex for the compiler to propagate constants through.
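
If that explanation is right, one workaround is to spell out the multiplications yourself, so that every step is a plain `Int` multiply the compiler can constant-fold. A minimal sketch; `g_unrolled` is a hypothetical name, not something from Base:

```julia
using Test

# N^3 with a literal exponent becomes Base.literal_pow(^, N, Val(3)), i.e. N*N*N
f(::Val{N}) where {N} = Val(N^3)

# Hypothetical workaround for N^4: two explicit multiplications instead of the
# generic integer ^, so inference never has to go through power-by-squaring
function g_unrolled(::Val{N}) where {N}
    N2 = N * N
    return Val(N2 * N2)
end

@inferred f(Val(2))           # Val{8}()
@inferred g_unrolled(Val(2))  # Val{16}()
```

This only helps when the exponent is fixed and small; it trades generality for something inference can see through.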

You can reduce the original problem to an arbitrary parametric type, as below. Removing the `sqrt` fixes inference, and so does marking the calculation as pure with `Base.@pure` (not that I would recommend this):

```julia
using InteractiveUtils  # provides @code_warntype outside the REPL

struct Foo{T} end

Base.@pure calc(x::Int, y::Int) = round(Int, sqrt(abs(x + y)))

g(::Foo{N}, ::Foo{M}) where {N,M} = Foo{calc(N,M)}()
g2(::Foo{N}, ::Foo{M}) where {N,M} = Foo{abs(N+M)}()

F = Foo{2}()
@code_warntype g(F, F)  # OK: calc is marked @pure
@code_warntype g2(F, F) # OK: no sqrt involved
```

Unfortunately, I don’t know a systematic way to diagnose these things, just trial and error.
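
It is still trial and error, but it can at least be automated. A sketch using the reflection function `Base.return_types`, which reports what inference concludes for given argument types; `infers_concretely` is a hypothetical helper name:

```julia
# Hypothetical helper (not part of Base): true when inference finds exactly one
# method match and its inferred return type is concrete, i.e. the call is type-stable
function infers_concretely(f, argtypes::Tuple)
    rts = Base.return_types(f, argtypes)
    return length(rts) == 1 && isconcretetype(only(rts))
end

struct Foo{T} end
g2(::Foo{N}, ::Foo{M}) where {N,M} = Foo{abs(N + M)}()

infers_concretely(g2, (Foo{2}, Foo{2}))           # true: the result type is Foo{4}
infers_concretely(x -> x > 0 ? 1 : 1.0, (Int,))  # false: Union{Float64, Int64}
```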

I think this is exactly the situation the `@pure` annotation is meant to solve. Basically, it promises that no methods will ever be added to `calc`, `+`, `abs`, `round` or `sqrt` (or to any function they call), and that no other global state affecting this calculation will change.

In this particular case, I think that should be safe to assume.
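
If I read the Base documentation correctly, in Julia 1.8 and later `Base.@assume_effects` is the preferred alternative to `@pure`: it lets you state which guarantees you are making rather than the all-or-nothing promise above. A sketch assuming Julia 1.8 or later, reusing the `Foo`/`calc`/`g` names from the earlier post:

```julia
using Test

struct Foo{T} end

# Assumes Julia >= 1.8, where Base.@assume_effects is available.
# :foldable asserts (roughly) that calc is consistent, effect-free and terminating,
# which allows the compiler to evaluate it at compile time when its arguments are constants.
Base.@assume_effects :foldable function calc(x::Int, y::Int)
    return round(Int, sqrt(abs(x + y)))
end

g(::Foo{N}, ::Foo{M}) where {N,M} = Foo{calc(N, M)}()

F = Foo{2}()
@inferred g(F, F)   # Foo{2}()
```

Like `@pure`, this is a promise the compiler does not check, so the same caveats about extending the functions involved apply.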

This post is very useful to me; it takes a bit of the voodoo out of `Base.@pure`. As I understand it, however, a problem would arise if the user extended `sqrt` to some custom type. What would happen in that case? Would things still work as long as `N + M` is not of that custom type, or would things break in some other way?

In this case, `calc` is only defined for `Int`, so `sqrt` will always be called with an `Int`. Adding methods to `sqrt` for other types should be fine: they could not possibly affect the result of `calc` unless the user engages in serious type piracy (e.g. redefining `sqrt(::Int)`).

Of course, extending `calc` to accept user-defined types would be a bad idea. (Also, I think `@pure` functions should not be extended at all after they have first been called.)
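
To make the dispatch argument concrete, a small sketch with a hypothetical `MyNum` type: adding a `sqrt` method for it leaves `calc` unaffected, because `calc` only ever passes an `Int` to `sqrt`. The annotation is dropped here since dispatch works the same with or without it:

```julia
# Hypothetical user-defined type, only for illustration
struct MyNum
    val::Float64
end

# Extending sqrt to MyNum adds a new method for MyNum only;
# sqrt(::Int) and sqrt(::Float64) are untouched
Base.sqrt(m::MyNum) = MyNum(sqrt(m.val))

# Plain (unannotated) version of calc from the earlier post
calc(x::Int, y::Int) = round(Int, sqrt(abs(x + y)))

calc(2, 7)          # 3: abs(9) is an Int, so the MyNum method can never be selected
sqrt(MyNum(16.0))   # MyNum(4.0)
```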

Edit: Note that this is my own interpretation of the documentation. I am by no means an expert here.
