Why is this case type-unstable?

The following MWE behaves unexpectedly (for me):

julia> function f(t)
         return t
       end
f (generic function with 1 method)

julia> @code_warntype sum(f(1) for i in 1:2)
Variables
  #self#::Core.Compiler.Const(sum, false)
  a::Base.Generator{UnitRange{Int64},var"#15#16"}

Body::Int64
1 ─ %1 = Base.sum(Base.identity, a)::Int64
└──      return %1

It is type stable. However, if I pack this function inside a tuple:

julia> tup = (f, )
(f,)

julia> @code_warntype sum(tup.f(1) for i in 1:2)
Variables
  #self#::Core.Compiler.Const(sum, false)
  a::Base.Generator{UnitRange{Int64},var"#17#18"}

Body::Any
1 ─ %1 = Base.sum(Base.identity, a)::Any
└──      return %1

It becomes type unstable.

What I am trying to accomplish is the following: I have many different functions that I want to evaluate, and then I want to add up the results:

julia> σ₁(t) = t
σ₁ (generic function with 1 method)

julia> σ₂(t) = 2t
σ₂ (generic function with 1 method)

julia> σ₃(t) = π * t
σ₃ (generic function with 1 method)

julia> σ₄(t) = sin(t)
σ₄ (generic function with 1 method)

julia> σ = (σ₁, σ₂, σ₃, σ₄)
(σ₁, σ₂, σ₃, σ₄)

julia> @code_warntype sum(σ[i](1.0) for i in 1:4)
Variables
  #self#::Core.Compiler.Const(sum, false)
  a::Base.Generator{UnitRange{Int64},var"#21#22"}

Body::Any
1 ─ %1 = Base.sum(Base.identity, a)::Any
└──      return %1

How can I avoid this kind of instability?

tup is a global variable, and that is encoded in the generator. I don’t even think the first case works, BTW. Did you mean to use a named tuple?

That’s because it’s an error:

julia> sum(tup.f(1) for i in 1:2)
ERROR: type Tuple has no field f

That said, even if you correct this error it’ll be type unstable because tup is a non-constant global:

julia> @code_warntype sum(tup[1](1) for i in 1:2)
Variables
  #self#::Core.Compiler.Const(sum, false)
  a::Base.Generator{UnitRange{Int64},var"#13#14"}

Body::Any

This is because I can later change what the name tup is used for. If we make it const it’ll be type stable.
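
A minimal sketch of the two usual ways around a non-constant global, assuming the names tup_const, total_global and total_arg (all hypothetical, not taken from the posts above): either declare the global const, or pass the tuple as a function argument so that its concrete type is known when the method is compiled.

```julia
f(t) = t

# Option 1: a const global cannot be rebound to a value of another type,
# so the compiler can rely on what it currently holds.
const tup_const = (f,)
total_global() = sum(tup_const[1](1) for i in 1:2)

# Option 2: pass the tuple as an argument (a "function barrier").
# The method is compiled for the concrete type of `tup`.
total_arg(tup) = sum(tup[1](1) for i in 1:2)

total_global()    # inspect with: @code_warntype total_global()
total_arg((f,))   # inspect with: @code_warntype total_arg((f,))
```

Both calls should be worth checking with @code_warntype on your own Julia version, since inference limits have changed between releases.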

On top of that, σ[i] is very hard to infer for a heterogeneous tuple; f(1.0) for f in σ is easier. However, with 4 elements you are actually hitting some tuple/recursion length limit. 3 elements work on 1.3.0:

julia> const σ2 = (σ₁, σ₂, σ₃)
(σ₁, σ₂, σ₃)

julia> @code_warntype sum(f(1.0) for f in σ2)
Variables
  #self#::Core.Compiler.Const(sum, false)
  a::Core.Compiler.Const(Base.Generator{Tuple{typeof(σ₁),typeof(σ₂),typeof(σ₃)},var"#23#24"}(var"#23#24"(), (σ₁, σ₂, σ₃)), false)

Body::Float64
1 ─ %1 = Base.sum(Base.identity, a)::Float64
└──      return %1
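
Another way to iterate over the functions instead of indexing, sketched here under the assumption that all functions must be evaluated anyway: map over the tuple and sum the resulting tuple. The helper name sum_all is hypothetical. As far as I know, map over a short tuple is handled element by element, so the result type is usually inferred even though every function has its own type; I have not checked where the length limit lies.

```julia
σ₁(t) = t
σ₂(t) = 2t
σ₃(t) = π * t
σ₄(t) = sin(t)

# `sum_all` is a hypothetical helper; `fs` is any tuple of callables.
# `map` over a tuple returns a tuple of results, which `sum` then adds up.
sum_all(fs::Tuple, t) = sum(map(f -> f(t), fs))

sum_all((σ₁, σ₂, σ₃, σ₄), 1.0)   # inspect with @code_warntype
```

Passing the tuple as an argument also sidesteps the global-variable issue discussed earlier in the thread.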

Thanks for your replies @yuyichao and @mbauman and sorry for the late response.

I want to make some remarks:

  1. Indeed, it was a typo: I meant tup[1], not tup.f.
  2. In the MWE that I posted, I did not declare σ as const. However, in the example I was testing locally, σ is a variable of known type, so the type instability comes from what @yuyichao pointed out: σ[i] is very hard to infer for a heterogeneous tuple.

Is there any way I can avoid this? Let me explain where I am having this problem.

I have many functions σ that have either interpolated data or algebraic expressions (or a mixture of both). Then, I would like to compute a value using all those functions in a compact way. For example:

using Parameters
Params = @with_kw (

  N = 5,

  # functions sigma can be anything. In this case they are
  # equivalent and simple algebraic expressions.
  σ₁ =t -> 1. * t,
  σ₂ =t -> 1. * t,
  σ₃ =t -> 1. * t,
  σ₄ =t -> 1. * t,
  σ₅ =t -> 1. * t,
  σ′ = (σ₁, σ₂, σ₃, σ₄, σ₅),

  q = searchsortedfirst,

  δ  = 0.5,
  τ  = [δ for i in 1:N+1],
  T′ = cumsum(τ),
  τ′ = @view τ[2:end]
)

function drift!(du, u, p, t)

  @unpack (N, σ₁, σ₂, σ₃, σ₄, σ₅, σ′, q, δ , τ , T′, τ′) = p

  for i in 1:N
    du[i] = σ′[i](t) * u[i] * sum((τ′[j] * σ′[j](t) * u[j]) / (1 + τ′[j] * u[j]) for j = q(T′, t):i)
  end
end

const p = Params()
const du = zeros(p.N)
const u = ones(p.N)

drift!(du, u, p, 0.1)

@code_warntype drift!(du, u, p, 0.1)

Is there any way to make Julia treat drift! as type stable (without writing out the summation explicitly)?

Please avoid the following solution:

function drift!(du, u, p, t)

  @unpack (N, σ₁, σ₂, σ₃, σ₄, σ₅, σ′, q, δ , τ , T′, τ′) = p

  # σ = (σ₁(t), σ₂(t), σ₃(t), σ₄(t), σ₅(t))
  σ = [σ₁(t), σ₂(t), σ₃(t), σ₄(t), σ₅(t)]

  for i in 1:N
    du[i] = σ[i] * u[i] * sum((τ′[j] * σ[j] * u[j]) / (1 + τ′[j] * u[j]) for j = q(T′, t):i)
  end
end

because the user can only write an expression for du using the provided parameters; they cannot add extra statements to the drift! function. In other words, I would like to know whether I can make this function type stable by changing only the du assignment.

For example, this could be a solution:

using Parameters
Params = @with_kw (

  N = 5,

  # functions sigma can be anything. In this case they are
  # equivalent and simple algebraic expressions.
  σ₁ =t -> 1. * t,
  σ₂ =t -> 1. * t,
  σ₃ =t -> 1. * t,
  σ₄ =t -> 1. * t,
  σ₅ =t -> 1. * t,
  σ′ = (σ₁, σ₂, σ₃, σ₄, σ₅),

  # Set this new parameter:
  σ = t -> [σ₁(t), σ₂(t), σ₃(t), σ₄(t), σ₅(t)],

  q = searchsortedfirst,

  δ  = 0.5,
  τ  = [δ for i in 1:N+1],
  T′ = cumsum(τ),
  τ′ = @view τ[2:end]
)

function drift!(du, u, p, t)

  @unpack (N, σ₁, σ₂, σ₃, σ₄, σ₅, σ′, σ, q, δ , τ , T′, τ′) = p

  # for i in 1:N
  #   du[i] = σ′[i](t) * u[i] * sum((τ′[j] * σ′[j](t) * u[j]) / (1 + τ′[j] * u[j]) for j = q(T′, t):i)
  # end

  for i in 1:N
    du[i] = σ(t)[i] * u[i] * sum((τ′[j] * σ(t)[j] * u[j]) / (1 + τ′[j] * u[j]) for j = q(T′, t):i)
  end
end

const p = Params()
const du = zeros(p.N)
const u = ones(p.N)

drift!(du, u, p, 0.1)

@code_warntype drift!(du, u, p, 0.1)

Now this is type stable (as expected), but every call σ(t)[i] evaluates all five functions and builds a vector just to read one entry, making the function more expensive than it needs to be.

Any thoughts?

Finally, this version is type stable and changes only the parameters:

using Parameters
Params = @with_kw (

  N = 5,

  # functions sigma can be anything. In this case they are
  # equivalent and simple algebraic expressions.
  σ₁ =t -> 1. * t,
  σ₂ =t -> 1. * t,
  σ₃ =t -> 1. * t,
  σ₄ =t -> 1. * t,
  σ₅ =t -> 1. * t,
  σ′ = (σ₁, σ₂, σ₃, σ₄, σ₅),

  # σ = t -> [σ₁(t), σ₂(t), σ₃(t), σ₄(t), σ₅(t)],

  σ = (t, i) -> begin
    if i == 1
      return σ₁(t)
    elseif i == 2
      return σ₂(t)
    elseif i == 3
      return σ₃(t)
    elseif i == 4
      return σ₄(t)
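    elseif i == 5
      return σ₅(t)
    else
      # Without this branch the function would return `nothing` for any other i.
      error("σ: index i must be in 1:5")
    end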
  end,

  q = searchsortedfirst,

  δ  = 0.5,
  τ  = [δ for i in 1:N+1],
  T′ = cumsum(τ),
  τ′ = @view τ[2:end]
)

function drift!(du, u, p, t)

  @unpack (N, σ₁, σ₂, σ₃, σ₄, σ₅, σ′, σ, q, δ , τ , T′, τ′) = p

  # for i in 1:N
  #   du[i] = σ′[i](t) * u[i] * sum((τ′[j] * σ′[j](t) * u[j]) / (1 + τ′[j] * u[j]) for j = q(T′, t):i)
  # end

  # for i in 1:N
  #   du[i] = σ(t)[i] * u[i] * sum((τ′[j] * σ(t)[j] * u[j]) / (1 + τ′[j] * u[j]) for j = q(T′, t):i)
  # end

  for i in 1:N
    du[i] = σ(t, i) * u[i] * sum((τ′[j] * σ(t, j) * u[j]) / (1 + τ′[j] * u[j]) for j = q(T′, t):i)
  end
end

const p = Params()
const du = zeros(p.N)
const u = ones(p.N)

drift!(du, u, p, 0.1)

@code_warntype drift!(du, u, p, 0.1)

@yuyichao, I wonder if you have a better approach. Thanks!
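
One possible generalization of the if/elseif chain, as a sketch only: a small recursive helper that walks the tuple one element at a time, so that every call site sees a function of known type. The name call_ith is hypothetical, and I am assuming that all functions return the same type (Float64 here) and that the tuple is short enough for the compiler to follow the recursion.

```julia
# Hypothetical helper: call the i-th function of the tuple `fs` at `t`.
# Each recursion step peels off the first element, so the function being
# called always has a known type, like the hand-written if/elseif chain.
@inline call_ith(fs::Tuple, i::Int, t) =
    i == 1 ? first(fs)(t) : call_ith(Base.tail(fs), i - 1, t)

# Running out of elements means `i` was out of range.
call_ith(::Tuple{}, i::Int, t) = throw(ArgumentError("index out of range"))

fs = (t -> 1.0 * t, t -> 2.0 * t, t -> 3.0 * t)
call_ith(fs, 2, 0.5)   # calls the second function only
```

In drift!, the assignment would then use call_ith(σ′, i, t) and call_ith(σ′, j, t) in place of σ′[i](t) and σ′[j](t), so only the requested function is evaluated. Whether the helper can live among the parameters depends on the setup, so please check the result with @code_warntype.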