# Keyword Arguments and Type Stability

I need to remain flexible (I’m testing the code) and I ran into the following problem.

MWE

```julia
param = 2
x = rand(100)

foo(x; β = param) = sum(x .* β)

@code_warntype foo(x)             # type unstable
@code_warntype foo(x; β = param)  # type stable
```

I was surprised that `param` is treated as a global variable in the first case. Is there any way to make `foo(x)` type stable without resorting to `const param = 2.0`?

I want to keep `param` flexible while letting the compiler infer its type. In other words, I want `foo(x)` to behave as if I had called it through `foo(x; β = param)`, without typing `β = param` on every call (otherwise I’d rather avoid keyword arguments).

I tried

```julia
foo(x; β::T = param) where {T}           = sum(x .* β)  # not solving it
foo(x; β::T = param) where {T<:Number}   = sum(x .* β)  # not solving it
foo(x; β::T = param) where {T<:Float64}  = sum(x .* β)  # this works but is not flexible
```

Many Thanks!


Have you tried this?

```julia
foo(x; β::typeof(param) = param) = sum(x .* β)
```

Thanks!!! This is what happens when you think about an easy problem in a really complicated way…

It seems the compiler could easily figure out the type on its own? I’m sure there must be some good reason why it doesn’t; perhaps the compiler could automatically wrap this function in another function? I don’t know. Thanks again.

Consider this:

```julia
julia> param = 2
       foo(x; β = param) = sum(x .* β)
       @show foo(5)
       param = 3
       @show foo(5)
       param = 4.0
       @show foo(5);
foo(5) = 10
foo(5) = 15
foo(5) = 20.0
```

`param`’s type and value are free to change, and such changes are immediately reflected in the next call to `foo`.

However, if we annotate the type:

```julia
julia> param = 2
       foo(x; β::typeof(param) = param) = sum(x .* β)
       @show foo(5)
       param = 3
       @show foo(5)
       param = 4.0
       @show foo(5);
foo(5) = 10
foo(5) = 15
ERROR: MethodError: no method matching var"#foo#14"(::Float64, ::typeof(foo), ::Int64)
```

when `param`’s type changes, you get an error, but you don’t see it until `foo` is called again. To avoid the error, you have to re-define the function after changing the type of `param`. Alternatively, you can declare the type of `param` with e.g. `param::Int = 2`.
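For reference, a minimal sketch of the typed-global approach (requires Julia 1.8 or newer; names reused from the thread):

```julia
param::Int = 2   # typed global: the binding's type is now fixed to Int
foo(x; β = param) = sum(x .* β)

foo([1, 2, 3])   # 12, and the default-β lookup is type-stable
param = 5        # fine: same type
param = 4.0      # converted via convert(Int, 4.0), so param == 4
# param = 4.5    # would throw an InexactError
```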

There’s still a slight performance penalty even when it’s type-stable, because the value has to be looked up on each function call, so if `param`’s value won’t change between function definitions you might prefer to use a `let` block:

```julia
julia> param = 2
       foo(x; β = param) = sum(x .* β)
       bar(x; β::typeof(param) = param) = sum(x .* β)
       let param = param
           global baz(x; β = param) = sum(x .* β)
       end
baz (generic function with 1 method)

julia> using BenchmarkTools
       @btime foo($(1,2,3))
       @btime bar($(1,2,3))
       @btime baz($(1,2,3))
  18.630 ns (1 allocation: 32 bytes)
  5.900 ns (0 allocations: 0 bytes)
  2.000 ns (0 allocations: 0 bytes)
12
```

This is what I had in mind before this discussion, since calling `foo(x; β = param)` works fine.

```julia
x = [1,2,3]
param = 3.0
foo(x; β = param) = sum(x .* β)

# function called by the user
param = 3.0
foo(x)

# what Julia calls is literally
foo(x) = sum(x .* param)  # type unstable

# what I expected to be called
foo(x; β = param)
# Julia recognizes `param` as local if you call the function this way,
# hence it's type stable. So it's as if Julia called
#     foo(x, param)
# rather than foo(x)
```

It seems natural to me that, when you define a function with a keyword argument, you’re hinting that the `foo(x; β = param)` form is what should effectively be called.

I imagine it conflicts with something else. I’m just saying what seems natural, not necessarily what’s feasible.

If I understand what you’re trying to do, it seems like the best approach is to set:

```julia
param::Float64 = 0.0
foo(x; β = param) = sum(x .* β)
```

and then the access to `β` will be type-stable because `param` has its type fixed. When users attempt to set `param` to an integer value, it automatically gets converted to a float:

```julia
julia> param = 2
2

julia> param
2.0
```

The drawback to this is that the user won’t be able to adjust the precision of the global `param`, but with `Float64` that’s often not necessary anyway.
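For instance, a small sketch of that drawback (again assuming Julia 1.8+ typed globals):

```julia
param::Float64 = 0.0
param = Float32(1.5)    # silently converted: the global stays a Float64
typeof(param)           # Float64, not Float32
```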

To address what I detect seems to be a bit of confusion (correct me if I’m mis-estimating the situation!):

Even with the above type-stable definition, if the user writes this code in global scope:

```julia
x = (1,2,3)
foo(x)
```

the call to `foo` will be type-unstable anyway, because the type of global `x` is not constant, and runtime dispatch must occur.

It might seem like it’s type-stable, because `@code_warntype` says so,

```julia
julia> @code_warntype foo(x)
MethodInstance for foo(::Tuple{Int64, Int64, Int64})
  from foo(x; β) @ Main REPL[1]:2
Arguments
  #self#::Core.Const(foo)
  x::Tuple{Int64, Int64, Int64}
Body::Float64
1 ─ %1 = Main.:(var"#foo#3")(Main.param, #self#, x)::Float64
└──      return %1
```

but the `@code_warntype` macro isn’t telling the whole story. The macro doesn’t actually know whether `x` is `const` or not, or whether it has its type fixed or not; it just runs `typeof(x)` at that moment to find the concrete type of its value at that moment and runs with it. For a more accurate depiction that includes the effects of `x`’s type-instability, call `@code_warntype (()->foo(x))()` to wrap `x` in a function that captures it. Benchmarks are also educational.
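To illustrate the closure trick, here is a self-contained sketch (a typed global keeps the `β` default stable, as above, so only `x`’s instability remains; requires Julia 1.8+):

```julia
param::Float64 = 2.0               # typed global, so the β default is stable
foo(x; β = param) = sum(x .* β)
x = (1, 2, 3)                      # untyped global

@code_warntype foo(x)              # Body::Float64: looks type-stable
@code_warntype (() -> foo(x))()    # Body::Any: x's instability now shows up
```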

Another note: notice from the above that a function `var"#foo#3"` has been declared behind the scenes, and is being called by `foo(x)`. And notice that `foo(x)` passes `param` to it as an argument. Here’s another perspective:

```julia
julia> @code_lowered foo(x)
CodeInfo(
1 ─ %1 = Main.:(var"#foo#3")(Main.param, #self#, x)
└──      return %1
)
```

(“Lowered” code is still Julia, just a subset of the language after the parser has transformed a bunch of syntax sugar into a lower-level form.)

Basically, `foo(x)` is actually:

```julia
foo(x) = var"#foo#3"(param, foo, x)
```

hence, the global `param` is simply captured and passed to a separate function:

```julia
julia> methods(var"#foo#3")
# 1 method for generic function "#foo#3" from Main:
 [1] var"#foo#3"(β, ::typeof(foo), x)
     @ REPL[1]:2

julia> @code_lowered var"#foo#3"(param, foo, x)
CodeInfo(
1 ─      nothing
│   %2 = Base.broadcasted(Main.:*, x, β)
│   %3 = Base.materialize(%2)
│   %4 = Main.sum(%3)
└──      return %4
)
```

This is the function you actually wrote. Functions that take keywords are actually broken up into multiple functions behind the scenes. (Try `@code_lowered foo(x; β=param)` too, and see if you can find `var"#foo#3"` (or whatever it’s called on your system)!)

Thanks for your detailed answer!!! It’s not actually for a package or for any user. I was just testing some code and didn’t want to specify `β` all the time. But I was surprised that this created a type instability and wanted to know why. I rarely work with keyword arguments.

I know that `foo2() = foo(x)` creates a type instability, but I expected that keyword arguments would behave like this example (I mean in terms of results, even if under the hood Julia does something else).

```julia
param = 1
foo(x; β = param) = sum(x .+ β)
foo([1,2])  # type unstable

## I expected the same behavior as
foo2(x, nt) = sum(x .+ nt.β)

param = 1
nt = (; β = param)
@code_warntype foo2([1,2], nt)  # 5 and type stable

param = 3.0
nt = (; β = param)
@code_warntype foo2([1,2], nt)  # 9.0 and type stable
```

But Julia is doing the following (maybe not exactly this, but for practical purposes it has the same implications)

```julia
x = [1,2]
param = 1

foo(x; β = param) = sum(x .+ β)

# equivalent to
foo2(x; β) = sum(x .+ β)
foo2(x)    = sum(x .+ param)

foo2(x)  # type unstable because `param` is a global
```

Thanks again!

IIUC, it would be more like

```julia
foo(x; β = param) = sum(x .+ β)

# "translated to"
foo_(x, β) = sum(x .+ β)    # internal function, always expects 2 arguments
foo(x)     = foo_(x, param) # user-facing method with only one provided argument
foo(x; β)  = foo_(x, β)     # user-facing method when a keyword argument is provided
```

This makes a difference because even when `β` is not provided and `param` is used by default, there is still a function barrier. Meaning that dynamic dispatch will stop at the point where `foo_` is called, because it will be specialized for the concrete types of arguments provided for `x` and `β`.
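A self-contained sketch of that barrier in action (the `foo_` name mirrors the translation above; Julia generates its own hidden name in practice):

```julia
param = 2                     # untyped global

foo_(x, β) = sum(x .+ β)      # internal function: specialized for concrete argument types
foo(x) = foo_(x, param)       # one dynamic dispatch here, then fast specialized code runs

foo([1, 2])                   # 7
```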


As explained above, if `nt` is a global, this will not be type-stable in practice. The fact that `@code_warntype` says it’s stable is a measurement artifact.

Ah! In that case, you should consider this approach:

```julia
param() = 0.0
foo(x; β = param()) = sum(x .* β)
```

By expressing a constant as a function, its type is inferrable without having to declare it `const`, and you can update its value and its type without having to restart your Julia session every time.
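A quick sketch of this function-as-constant pattern in use:

```julia
param() = 0.0
foo(x; β = param()) = sum(x .* β)

foo([1.0, 2.0])    # 0.0

param() = 2        # redefining the method recompiles its callers
foo([1.0, 2.0])    # 6.0: new value and new type, still inferred
```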


`param()` is exactly what I was looking for!!

What do you mean by this being an artifact and not type stable?

```julia
x = [1,2,3]
sum(x)
```

thanks again


I was going to come up with an example, but I realized that I’m wrong, at least in the way that matters 😅

Dynamic dispatch indeed occurs when you call `sum(x)`, because `x` is a global variable of unknown type, but that takes twenty nanoseconds or so and is a one-time cost per “shift-enter”, so it’s totally irrelevant to the user experience.


Let me try and show a simple enough but not-too-contrived example of this:

```julia
julia> β₀ = 1.0
1.0

julia> rescale(x; β=β₀) = x * β
rescale (generic function with 1 method)
```

When testing interactively, it appears as though passing `β₀` as argument to `rescale` is type stable:

```julia
julia> @code_warntype rescale(1.0, β=β₀)
MethodInstance for (::var"#rescale##kw")(::NamedTuple{(:β,), Tuple{Float64}}, ::typeof(rescale), ::Float64)
  from (::var"#rescale##kw")(::Any, ::typeof(rescale), x) in Main at REPL[2]:1
Arguments
  _::Core.Const(var"#rescale##kw"())
  @_2::NamedTuple{(:β,), Tuple{Float64}}
  @_3::Core.Const(rescale)
  x::Float64
Locals
  β::Float64
  @_6::Float64
Body::Float64
[...]
```

But this is only due to the global variable `β₀` having a value (of a determined, concrete type) at the moment you called `@code_warntype`. What this really tells you is: if we know the types of `x` and `β₀`, then type inference can determine the type of the result. But there is a catch: since `β₀` is a global, it can be rebound to any value of any type at any time without the compiler knowing anything about it. In an interactive context, `@code_warntype` bases its analysis on the type of the value `β₀` is currently bound to, but in general this information will only be known at runtime (right when `rescale` is about to get called). This is what we see below:

```julia
julia> foo(x) = rescale(x, β=β₀)
foo (generic function with 1 method)

julia> @code_warntype foo(1.0)
MethodInstance for foo(::Float64)
  from foo(x) in Main at REPL[4]:1
Arguments
  #self#::Core.Const(foo)
  x::Float64
Body::Any
[...]
```

When `foo` gets compiled, there is no way the compiler can prove that `β₀` will always have the same type, so inference can’t tell you anything about the return type of `rescale`. I guess this is what @uniment wanted to say: `rescale` appears to be type-stable (because it is: when the types of its arguments are known, its return type can be inferred), but passing a global variable as an argument to `rescale` is not type-stable (because the type of the global variable can’t be known ahead of time).

As is often the case, whether this matters in practice depends a lot on your use case. Since `rescale` is itself type stable, and calling it implies the presence of a function barrier, dynamic dispatch will happen only when the function is called, but the compiler will be able to generate an efficient specialization. So it all boils down to whether your function is called in a hot loop, and how much work it performs.

Here is a case where the type instability caused by the global variable `β₀` will matter: the function does almost nothing and it is called in a hot loop

```julia
julia> function rescaled_sum1(xs)
           acc = zero(eltype(xs))
           for x in xs
               acc += rescale(x) # β not provided => β₀ is used, but the compiler does not know its type when rescaled_sum1 is compiled
           end
           acc
       end
rescaled_sum1 (generic function with 1 method)

julia> using BenchmarkTools

julia> x = rand(1000);

julia> @btime rescaled_sum1($x)
  24.325 μs (3000 allocations: 46.88 KiB)
489.3193654828957
```

In contrast, here is a case where the presence of a function barrier restores almost all the performance: there will only be one dynamic dispatch up front (which is responsible for the only allocation seen), and its cost is diluted in the quantity of work performed in the inner loop (which has been specialized and compiled to efficient code)

```julia
julia> function rescaled_sum2(xs; β=β₀)
           acc = zero(eltype(xs))
           for x in xs
               acc += rescale(x; β) # the type of β is known at least when rescaled_sum2 gets compiled
           end
           acc
       end
rescaled_sum2 (generic function with 1 method)

julia> @btime rescaled_sum2($x)
  858.532 ns (1 allocation: 16 bytes)
489.3193654828957
```

And for reference, here is the same case where the global variable is not used any more, making all dynamic dispatch disappear. This confirms that the only allocation above was due to the run-time dispatch, which all in all incurred about 30 ns of extra time.

```julia
julia> @btime rescaled_sum2($x, β=1.0)
  820.988 ns (0 allocations: 0 bytes)
489.3193654828957
```