I’m having a hard time writing a type-stable function where one branch calls a complicated function and another one doesn’t. Here’s a minified example:

using Distributions
function f(μ::Real, σ::Real, a::Real)
    if isnan(μ) || isnan(σ)
        return 0.0
    else
        return pdf(Normal(μ, σ), a)
    end
end

Don’t ask why I want to return 0 when μ or σ is NaN! This is a minified example, but hopefully the semantics I’m trying to achieve are clear even if unmotivated.

I would like f to be type stable. However, you’ve probably spotted an issue: the 0.0 is a Float64, so a call such as f(3.0f0, 4.0f0, 1.0f0) (all Float32) would not be type stable. I could use a heuristic like zero(promote_type(typeof(μ), typeof(σ))) in the first branch, but that means I basically have to reverse engineer the promotion behaviour of pdf and Normal! (To see why it’s hard: do you know whether the heuristic I just wrote is correct? Maybe the type of a factors in too?) This doesn’t scale to more complicated functions.
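To make that concrete, here is the best guess I can write down (a sketch, with muladd standing in for pdf so the snippet runs without Distributions; whether the promote_type guess matches what pdf and Normal actually do is exactly what I can’t tell):

```julia
# muladd stands in for pdf(Normal(μ, σ), a); the promote_type line
# is my guessed heuristic for the zero branch.
function f_guess(μ::Real, σ::Real, a::Real)
    if isnan(μ) || isnan(σ)
        return zero(float(promote_type(typeof(μ), typeof(σ), typeof(a))))
    else
        return muladd(μ, σ, a)
    end
end

f_guess(NaN32, 4.0f0, 1.0f0)  # 0.0f0 under this guess
f_guess(3.0f0, 4.0f0, 1.0f0)  # 13.0f0
```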

I would just like to tell the compiler: do the reasonable thing here… whatever type the other branch would normally be, make the first branch a zero of that type, so that f can be type stable when possible. Is there a way to express this ergonomically?

This is a weak point in Julia’s goal of type-generic code.

If the cost of pdf is small or the NaN branch is rarely taken, you could do something simple like

function f(μ::Real, σ::Real, a::Real)
    p = pdf(Normal(μ, σ), a)
    if isnan(μ) || isnan(σ)
        return zero(p)
    else
        return p
    end
end

I didn’t bother to load up Distributions.jl for this, but when I used muladd(μ, σ, a) instead of pdf and looked at the @code_native, it had actually optimized the function so that p is only computed when the value is really needed: in the zero branch it simply returned a zero of the appropriate type without evaluating the function call. The same might happen in your example, depending on how much the compiler can infer about pdf.
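For reference, the muladd version I checked looks like this (a sketch; Test.@inferred is just a quick way to confirm that the inferred return type is concrete):

```julia
using Test

# muladd stands in for the expensive call; the key point is that
# zero(p) automatically has whatever type the "real" branch produces.
function f(μ::Real, σ::Real, a::Real)
    p = muladd(μ, σ, a)
    if isnan(μ) || isnan(σ)
        return zero(p)
    else
        return p
    end
end

@inferred f(3.0f0, 4.0f0, 1.0f0)  # passes; returns a Float32
@inferred f(3.0, 4.0, 1.0)        # passes; returns a Float64
```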

Otherwise this gets more difficult and relies on introspection. Base.return_types may be helpful, but it’s really not a function that should be relied upon (it’s also not public API, so it may break in the future), and it requires some gymnastics to use in a type-stable way. Hopefully someone else can offer something better than that.
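For completeness, a minimal sketch of that introspection (again with muladd standing in for pdf, and with the caveat that Base.return_types is internal and may return more than one candidate type):

```julia
rts = Base.return_types(muladd, Tuple{Float32, Float32, Float32})

# If inference produced a single concrete type, use its zero;
# otherwise fall back to Float64 (an arbitrary choice for the sketch).
T = (length(rts) == 1 && isconcretetype(only(rts))) ? only(rts) : Float64
zero(T)  # 0.0f0 when inference succeeds
```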

Indeed it’s tricky! It’s been discussed plenty how easy it is to write non-type-stable code; I hope some of the discussion about better static systems in Julia can lead to a solution to this problem.

Type stability in a strong sense (as in Rust, for example) is not possible because Julia is a dynamic language. If I’m not getting this wrong, how about adding some constraints:

using Distributions
function f(μ::T, σ::T, a::T)::T where {T<:Real}
    if isnan(μ) || isnan(σ)
        return zero(T)
    else
        return pdf(Normal(μ, σ), a)
    end
end

using Distributions
function f(μ::Real, σ::Real, a::Real)::Float64
    if isnan(μ) || isnan(σ)
        return 0.0
    else
        return convert(Float64, pdf(Normal(μ, σ), a))::Float64
    end
end

But that just turns everything into Float64, which I doubt is desirable in the general case. What if the input is a dual number, a complex number, or a BigFloat, for example? (The example specifies Real inputs, but the question is, I believe, intended to be broader than the specific example.)

I’d like to add that “small type instabilities” (that is, a Union{...} of only a few types) are not that bad for performance, since Julia does union-splitting (see this excellent blog post by @tim.holy for an explanation). So it is not necessary to make functions 100% type stable all of the time.

Not having concrete type-inference may complicate downstream inference, though.

julia> f(x) = x > 1 ? 0.0 : x # a simple function that returns one of two types
f (generic function with 1 method)
julia> @code_typed f(1)
CodeInfo(
1 ─ %1 = Base.slt_int(1, x)::Bool
└── goto #3 if not %1
2 ─ return 0.0
3 ─ return x
) => Union{Float64, Int64}
julia> @code_typed (() -> [f(i) for i in 1:1])()
CodeInfo(
1 ─ goto #3 if not true
2 ─ nothing::Nothing
3 ┄ %3 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Vector{Int64}, svec(Any, Int64), 0, :(:ccall), Vector{Int64}, 1, 1))::Vector{Int64}
│ Base.arrayset(true, %3, 1, 1)::Vector{Int64}
│ %5 = invoke Base.collect_to!(%3::Vector{Int64}, $(QuoteNode(Base.Generator{UnitRange{Int64}, var"#8#10"}(var"#8#10"(), 1:1)))::Base.Generator{UnitRange{Int64}, var"#8#10"}, 2::Int64, 1::Int64)::Union{Vector{Int64}, Vector{Real}}
└── goto #4
4 ─ return %5
) => Union{Vector{Int64}, Vector{Real}}

Perhaps the Vector{Real} may be narrowed to Vector{Float64} with some careful promotion, but this may fail to return concrete element types in general. There are also cases where one might desire concrete type-inference to avoid any allocation.

I agree that this is bad if you put values into containers, which then need to have abstract element types.
I am not sure what you mean about complicating downstream inference, though. Doing

julia> f(x) = x > 5 ? 0.0 : x # type unstable
julia> g(::Float64) = "float"
julia> g(::Int) = "int"
julia> h(x) = g(f(x))
julia> using BenchmarkTools; @btime h(x) setup=(x=rand(1:10);)
2.305 ns (0 allocations: 0 bytes)

So no allocations. Union-splitting is not to be confused with dynamic dispatch, where Julia needs to resolve the method entirely at runtime. For union-splitting, all possible types and the corresponding methods are fully inferred, and the correct one is simply selected by a conditional. In the simple example above, each of the calls to g is even fully inlined.

Sorry, I think I was still thinking of containers, as is the case when multiple values are returned from a function. E.g.:

julia> f(i) = i > 0 ? i : 0.0
f (generic function with 1 method)
julia> g(i) = f(i), nothing
g (generic function with 1 method)
julia> @btime g(i) setup=(i=rand(1:4))
74.393 ns (1 allocation: 16 bytes)
(2, nothing)
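By contrast, if the inner function is made type stable first, the tuple type is concrete and the allocation should disappear (a sketch; f2 is a hypothetical stabilized variant, and exact timings will vary):

```julia
using Test

f2(i) = i > 0 ? float(i) : 0.0   # always returns Float64 for an Int input
g2(i) = (f2(i), nothing)

@inferred g2(2)  # inferred as Tuple{Float64, Nothing}
```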

I have multiple situations like this in code I write for stochastic simulation and inference. There’s a hack I use because it makes everything type stable for ForwardDiff: I select one of the function’s arguments and use its type as a parameter, making sure it’s an argument that will be autodiff’d. In your case I would hack it around like this:

using Distributions
function f(μ::T, σ::Real, a::Real) where {T<:Real}
    res = zero(T)
    if !(isnan(μ) || isnan(σ))
        res += pdf(Normal(μ, σ), a)
    end
    return res
end

@Vasily_Pisarev’s way of doing this is the correct one (imo), but I am often too lazy to apply it everywhere. I should probably refactor some more stuff…