Correct way of using value type argument


#1

In a separate post (https://discourse.julialang.org/t/type-stability-issues-when-using-staticarrays/19126), I learned a few things about value-type arguments, but I still have a question. In the following code, it seems to me that foo2 is the correct way to use Val{ }. However, the benchmarks look fine for foo1 as well. Is foo1 also correct?

using StaticArrays
using BenchmarkTools
using Test

function foo1(::Val{n}) where n
	xi = gausspoints1d(Val(n))
	w = gaussweights1d(Val(n))
end

function foo2(::Val{n}) where n
	xi = gausspoints1d_2(Val(n))
	w = gaussweights1d_2(Val(n))
end

function gausspoints1d_2(::Val{n}) where n
	return gausspoints1d(Val(n))
end

function gaussweights1d_2(::Val{n}) where n
	return gaussweights1d(Val(n))
end

function gausspoints1d(::Val{9})
	return @SVector [-0.968160239507626089835576202904,
                             -0.836031107326635794299429788070,
                             -0.613371432700590397308702039341,
                             -0.324253423403808929038538014643,
                             0.0,
                             0.324253423403808929038538014643,
                             0.613371432700590397308702039341,
                             0.836031107326635794299429788070,
                             0.968160239507626089835576202904]
end

function gausspoints1d(::Val{10})
	return @SVector [-0.973906528517171720077964012084,
                              -0.865063366688984510732096688423,
                              -0.679409568299024406234327365115,
                              -0.433395394129247290799265943166,
                              -0.148874338981631210884826001130,
                               0.148874338981631210884826001130,
                               0.433395394129247290799265943166,
                               0.679409568299024406234327365115,
                               0.865063366688984510732096688423,
                               0.973906528517171720077964012084]
end

function gaussweights1d(::Val{9})
	return @SVector [0.0812743883615744119718921581105,
                             0.180648160694857404058472031243,
                             0.260610696402935462318742869419,
                             0.312347077040002840068630406584,
                             0.330239355001259763164525069287,
                             0.312347077040002840068630406584,
                             0.260610696402935462318742869419,
                             0.180648160694857404058472031243,
                             0.0812743883615744119718921581105]
end

function gaussweights1d(::Val{10})
	return @SVector [0.0666713443086881375935688098933,
                             0.149451349150580593145776339658,
                             0.219086362515982043995534934228,
                             0.269266719309996355091226921569,
                             0.295524224714752870173892994651,
                             0.295524224714752870173892994651,
                             0.269266719309996355091226921569,
                             0.219086362515982043995534934228,
                             0.149451349150580593145776339658,
                             0.0666713443086881375935688098933]
end

# tests
n = 10
time1() = @btime foo1(Val($n)) 
time2() = @btime foo2(Val($n))

time1()
time2()

@inferred foo1(Val(n))
@inferred foo2(Val(n))

#2

The difference is one more indirection, correct? What are you trying to achieve with this?


#3

@mauro3 Type stability. @inferred says that both are type stable, but from the performance tips I understood that foo2 is the correct usage of Val{ }. My question is why foo1 apparently works fine as well.


#4

I would think that both ways are correct, but foo2 is needlessly complicated: your functions gausspoints1d_2 and gaussweights1d_2 are merely aliases for the existing functions gausspoints1d and gaussweights1d.

Maybe it will be clearer this way: in the following example, you would probably agree that foo is useless because it only forwards its argument to bar:

foo(x) = bar(x)

Your functions behave essentially in the same way:

foo(::Val{n}) where n = bar(Val(n))

is equivalent to

foo(x::Val{n}) where n = bar(x)

because Val(n) is the only possible value of type Val{n}. Therefore we necessarily have x === Val(n).
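To make the equivalence concrete, here is a minimal sketch with a hypothetical callee `bar` and two hypothetical forwarding functions, `foo_rebuild` and `foo_forward` (none of these names come from the thread). It checks that `Val(3)` is the unique, field-less instance of `Val{3}`, so both forms hand the exact same object to `bar`, and `@inferred` accepts either one.

```julia
using Test

# Hypothetical callee: reads n from the type parameter
bar(::Val{n}) where n = n + 1

# Rebuilds a Val from the type parameter (the foo2-style indirection)
foo_rebuild(::Val{n}) where n = bar(Val(n))

# Forwards the argument unchanged
foo_forward(x::Val{n}) where n = bar(x)

x = Val(3)
println(x === Val(3))                       # true: Val{3} has exactly one instance
println(fieldcount(Val{3}))                 # 0: all the information lives in the type
println(foo_rebuild(x) == foo_forward(x))   # true

# Both versions infer a concrete return type
@inferred foo_rebuild(x)
@inferred foo_forward(x)
```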


#5

You want to structure your code as follows:

function read_in_data(file::String)
    # Here we read in data from `file`
    # and create a bunch of objects.
    # Since we do not know what will be in the file,
    # the compiler cannot know the types of everything here.
    type1, type2, type3 = read_stuff(file)
    run_core_computation(type1, type2, type3)
end

function run_core_computation(type1, type2, type3)
    # Here we run the core computation. Since the concrete types of
    # type1, type2, type3 are known when this function is called, it is
    # compiled and optimized for them.
    for i in 1:1000000
        do_stuff(type1, type2, type3)
    end
end   
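Below is a runnable sketch of this structure, under two assumptions: an `IOBuffer` stands in for the file, and `read_stuff` is a hypothetical helper (not a library function) that returns an `Int` or a `Float64` depending on the text. The type of `x` therefore cannot be inferred inside `read_in_data`, yet the loop in `run_core_computation` is still compiled for whichever concrete type arrives.

```julia
using Test

# Hypothetical helper: parse one number whose type depends on the input text
function read_stuff(io::IO)
    s = strip(read(io, String))
    v = tryparse(Int, s)
    return v === nothing ? parse(Float64, s) : v   # Union{Int, Float64}
end

# Core computation: specialized on the concrete type of `x` when it is called
function run_core_computation(x)
    acc = zero(x)
    for _ in 1:1000
        acc += x
    end
    return acc
end

# Outer function: the type of `x` is not inferable here, which is fine
function read_in_data(io::IO)
    x = read_stuff(io)
    return run_core_computation(x)   # function barrier: one dynamic dispatch
end

read_in_data(IOBuffer("2"))         # 2000
read_in_data(IOBuffer("0.5"))       # 500.0
@inferred run_core_computation(2)   # the kernel itself is type stable
```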

#6

Thanks @kristoffer.carlsson. However, this takes me back to my initial code, which was structured the way you describe. I repeatedly saw comments about type stability on the forum and asked myself what it was, since I never took it into account when I started coding. So I started playing with the @inferred macro in different parts of my code and got errors in many places. My concern is the following: say I put @inferred on a top-level function call and get no error. Fine. But say I now put @inferred on a call to an inner function and suddenly get an error (as a novice, I am not sure whether this is the intended use of @inferred). Can it then be said that the code is type unstable? Here is some code that illustrates my concern:

using StaticArrays
using BenchmarkTools
using Test

function type_matches()
	xi = gausspoints1d(10)
	w = gaussweights1d(10)
end

# @inferred is moved inside and an error is thrown
function type_does_not_match()
	xi = gausspoints1d(10)
	w = gaussweights1d(10)
	@inferred gausspoints1d(10)
	@inferred gaussweights1d(10)
end

function gausspoints1d(n::Int)
	n == 9 && return @SVector [-0.968160239507626089835576202904,
                             -0.836031107326635794299429788070,
                             -0.613371432700590397308702039341,
                             -0.324253423403808929038538014643,
                             0.0,
                             0.324253423403808929038538014643,
                             0.613371432700590397308702039341,
                             0.836031107326635794299429788070,
                             0.968160239507626089835576202904]
  n == 10 && return @SVector [-0.973906528517171720077964012084,
                              -0.865063366688984510732096688423,
                              -0.679409568299024406234327365115,
                              -0.433395394129247290799265943166,
                              -0.148874338981631210884826001130,
                               0.148874338981631210884826001130,
                               0.433395394129247290799265943166,
                               0.679409568299024406234327365115,
                               0.865063366688984510732096688423,
                               0.973906528517171720077964012084]
	throw(ArgumentError("gausspoints1d(): n = $n is not a valid value"))
end

function gaussweights1d(n::Int)
	n == 9 && return @SVector [0.0812743883615744119718921581105,
                             0.180648160694857404058472031243,
                             0.260610696402935462318742869419,
                             0.312347077040002840068630406584,
                             0.330239355001259763164525069287,
                             0.312347077040002840068630406584,
                             0.260610696402935462318742869419,
                             0.180648160694857404058472031243,
                             0.0812743883615744119718921581105]
  n == 10 && return @SVector [0.0666713443086881375935688098933,
                              0.149451349150580593145776339658,
                              0.219086362515982043995534934228,
                              0.269266719309996355091226921569,
                              0.295524224714752870173892994651,
                              0.295524224714752870173892994651,
                              0.269266719309996355091226921569,
                              0.219086362515982043995534934228,
                              0.149451349150580593145776339658,
                              0.0666713443086881375935688098933]
	throw(ArgumentError("gaussweights1d(): n = $n is not a valid value"))
end

@inferred type_matches()
type_does_not_match()

In this code, type_matches() is type stable, since @inferred correctly infers its return type. However, if @inferred is moved inside the function (I gave that version a different name, type_does_not_match(), to make the comparison), an error is thrown saying that the return type cannot be inferred. So it is essentially the same code, but according to @inferred one version is type stable and the other is not. Do I only need to care about type stability in the top-level function, or also in the inner functions?


#7

You typically don’t care if the creation of the gausspoints is type stable. That happens once and in a function “high up” the call stack. You then pass the created gausspoints to a function that does the actual computation and the type will be known in those functions and performance will be good.

Create the types (which is very quick), pass them to functions doing computations on them.
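Applied to the quadrature code from this thread, a minimal sketch might look like the following. To stay self-contained it uses plain tuples and 2- and 3-point Gauss-Legendre rules on [-1, 1] instead of the `SVector`s above; `gauss_rule`, `quadrature` and `integrate_with` are hypothetical names. `Val(n)` is built once from a runtime `Int`, which costs one dynamic dispatch, and the returned tuples are then passed to a kernel that is compiled for their concrete types.

```julia
using Test

# Hypothetical stand-ins for gausspoints1d/gaussweights1d: (points, weights) on [-1, 1]
gauss_rule(::Val{2}) = ((-1/sqrt(3), 1/sqrt(3)), (1.0, 1.0))
gauss_rule(::Val{3}) = ((-sqrt(3/5), 0.0, sqrt(3/5)), (5/9, 8/9, 5/9))

# Kernel: the tuple lengths are part of the argument types, so this is type stable
function quadrature(f, xi, w)
    s = zero(eltype(w))
    for (x, wx) in zip(xi, w)
        s += wx * f(x)
    end
    return s
end

# Outer function: n is only known at runtime, so the rule's type is not inferable here
function integrate_with(f, n::Int)
    xi, w = gauss_rule(Val(n))
    return quadrature(f, xi, w)   # function barrier
end

integrate_with(x -> x^2, 2)   # ≈ 2/3
integrate_with(x -> x^4, 3)   # ≈ 2/5 (the 3-point rule is exact up to degree 5)

xi3, w3 = gauss_rule(Val(3))
@inferred quadrature(x -> x^2, xi3, w3)
```

The design choice is the one described above: the value-dependent step (picking the rule) stays cheap and "high up", while the repeated work happens behind the barrier.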


#8

There are a bunch of issues mixed together in this particular question. One of the things that’s going on in your particular case is constant propagation, which Julia has gotten much better at in v1.0.

Here’s a much simpler example. Let’s define an obviously type-unstable function:

julia> function f(x)
         if x < 0.5
           return 1.0
         else
           return 1
         end
       end
f (generic function with 1 method)

julia> @inferred(f(1))
ERROR: return type Int64 does not match inferred return type Union{Float64, Int64}
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope at none:0

Now let’s define a function g which calls f with some constant argument 1:

julia> function g()
         f(1)
       end
g (generic function with 1 method)

julia> @inferred g()
1

g() is inferred correctly! That’s pretty cool. What’s happening is that during the compilation of g(), Julia was able to propagate that constant value of 1 when compiling its call to f() and was able to figure out that the if x < 0.5 condition would always be false and therefore the output would always be an Int. That’s pretty neat, but it’s not really relevant to your example.

For example, if we change g() to call f with a value which cannot be propagated as a constant (like an input variable, a random number, or something read from a file), then g() is just as hard to infer as f():

julia> function g()
         f(rand())
       end
g (generic function with 1 method)

julia> @inferred g()
ERROR: return type Int64 does not match inferred return type Union{Float64, Int64}
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope at none:0
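This also explains why `@inferred` seems to disagree with itself in `type_does_not_match`. As far as I understand the `Test` standard library, `@inferred f(args...)` compares the actual result with what inference predicts for the argument *types* (via `Base.return_types`), so a literal such as `1` or `10` is only seen as an `Int` and no constant is propagated. A minimal check, redefining the same `f` as above:

```julia
using Test

function f(x)
    if x < 0.5
        return 1.0
    else
        return 1
    end
end

# What inference can say about f for *any* Int argument (the value is not used)
println(Base.return_types(f, (Int,)))   # a single entry: Union{Float64, Int64}

# Hence @inferred f(1) throws, just like @inferred gausspoints1d(10) inside type_does_not_match
try
    @inferred f(1)
catch err
    println("@inferred threw an ErrorException: ", err isa ErrorException)
end
```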

So this is interesting but not really relevant to what you want to do.

Your real application is closer to something like this:

function long_running_computation(x)
  for i in 1:100
     println(x)
  end
end

function g()
  x = f(rand())
  long_running_computation(x)
end

So the type of the variable x inside g() cannot be inferred, but it doesn’t matter, because once you pass the value of x, whatever type it happens to be, Julia will call the correct, specialized compiled version of long_running_computation() for the particular type of x.

This is what we mean when we refer to a “function barrier”: outside of long_running_computation, the type of x might not be inferrable, but once you pass a particular value into long_running_computation, there’s no further cost.


#9

Wow! That was really very instructive. Thank you @rdeits


#10

In the long run, it would be great to have these things documented. The compiler is getting more and more clever over time, but understanding what works and what doesn’t, and how to fix it, is usually something that one learns from experience, by trial and error, and reading existing code, especially in Base.