I’m having a hard time writing a type-stable function where one branch calls a complicated function and another one doesn’t. Here’s a minified example:

using Distributions
function f(μ::Real, σ::Real, a::Real)
    if isnan(μ) || isnan(σ)
        return 0.0
    else
        return pdf(Normal(μ, σ), a)
    end
end

Don’t ask why I want to return 0 when μ or σ is NaN! This is a minified example, but hopefully the semantics I’m trying to achieve are clear even if unmotivated.

I would like f to be type stable. However, you’ve probably spotted an issue: the 0.0 is a Float64, so a call such as f(3.0f0, 4.0f0, 1.0f0) (all Float32) would not be type stable. I could use a heuristic like zero(promote_type(typeof(μ), typeof(σ))) in the first branch, but that means I basically have to reverse engineer the promotion behaviour of pdf and Normal! (To see why it’s hard: do you know whether the heuristic I just wrote is correct? Maybe the type of a factors in too?) This doesn’t scale to more complicated functions.
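To make that concrete, here is the best guess I can write down (a sketch, with muladd standing in for pdf so the snippet runs without Distributions; whether the promote_type guess matches what pdf and Normal actually do is exactly what I can’t tell):

```julia
# muladd stands in for pdf(Normal(μ, σ), a); the promote_type line
# is my guessed heuristic for the zero branch.
function f_guess(μ::Real, σ::Real, a::Real)
    if isnan(μ) || isnan(σ)
        return zero(float(promote_type(typeof(μ), typeof(σ), typeof(a))))
    else
        return muladd(μ, σ, a)
    end
end

f_guess(NaN32, 4.0f0, 1.0f0)  # 0.0f0 under this guess
f_guess(3.0f0, 4.0f0, 1.0f0)  # 13.0f0
```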

I would just like to tell the compiler: do the reasonable thing here… whatever type the other branch would normally be, make the first branch a zero of that type, so that f can be type stable when possible. Is there a way to express this ergonomically?

This is a weak point in Julia’s goal of type-generic code.

If the cost of pdf is small or the NaN branch is rarely taken, you could do something simple like

function f(μ::Real, σ::Real, a::Real)
    p = pdf(Normal(μ, σ), a)
    if isnan(μ) || isnan(σ)
        return zero(p)
    else
        return p
    end
end

I didn’t bother to load up Distributions.jl for this, but when I used muladd(μ, σ, a) instead of pdf and looked at the @code_native, it had actually optimized the function so that p is only computed when the value is really needed: in the zero branch it simply returned a zero of the appropriate type without evaluating the function call. The same might happen in your example, depending on how much the compiler can infer about pdf.
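For reference, the muladd version I checked looks like this (a sketch; Test.@inferred is just a quick way to confirm that the inferred return type is concrete):

```julia
using Test

# muladd stands in for the expensive call; the key point is that
# zero(p) automatically has whatever type the "real" branch produces.
function f(μ::Real, σ::Real, a::Real)
    p = muladd(μ, σ, a)
    if isnan(μ) || isnan(σ)
        return zero(p)
    else
        return p
    end
end

@inferred f(3.0f0, 4.0f0, 1.0f0)  # passes; returns a Float32
@inferred f(3.0, 4.0, 1.0)        # passes; returns a Float64
```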

Otherwise this gets more difficult and relies on introspection. Base.return_types may be helpful, but it’s really not a function that should be relied upon (it’s also not public API, so it may break in the future), and it requires some gymnastics to use in a type-stable way. Hopefully someone else can offer something better than that.
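For completeness, a minimal sketch of that introspection (again with muladd standing in for pdf, and with the caveat that Base.return_types is internal and may return more than one candidate type):

```julia
rts = Base.return_types(muladd, Tuple{Float32, Float32, Float32})

# If inference produced a single concrete type, use its zero;
# otherwise fall back to Float64 (an arbitrary choice for the sketch).
T = (length(rts) == 1 && isconcretetype(only(rts))) ? only(rts) : Float64
zero(T)  # 0.0f0 when inference succeeds
```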

Indeed it’s tricky! It’s been discussed plenty how easy it is to write non-type-stable code; I hope some of the discussion about better static systems in Julia can lead to a solution to this problem.

Type stability in a strong sense (as in Rust, for example) is not possible because Julia is a dynamic language. If I’m not getting this wrong, how about adding some constraints:

using Distributions
function f(μ::T, σ::T, a::T)::T where {T<:Real}
    if isnan(μ) || isnan(σ)
        return zero(T)
    else
        return pdf(Normal(μ, σ), a)
    end
end

using Distributions
function f(μ::Real, σ::Real, a::Real)::Float64
    if isnan(μ) || isnan(σ)
        return 0.0
    else
        return convert(Float64, pdf(Normal(μ, σ), a))::Float64
    end
end

But that just turns everything into Float64, which I doubt is desirable in the general case. What if the input is a dual number, a complex number, or a BigFloat, for example? (The example specifies Real inputs, but the question is, I believe, intended to be broader than the specific example.)

I’d like to add that “small type instabilities” (that is, a Union{...} of only a few types) are not that bad for performance, since Julia does union-splitting (see this excellent blog post by @tim.holy for an explanation). So it is not necessary to make functions 100% type stable all of the time.

Not having concrete type-inference may complicate downstream inference, though.

julia> f(x) = x > 1 ? 0.0 : x # a simple function that returns one of two types
f (generic function with 1 method)
julia> @code_typed f(1)
CodeInfo(
1 ─ %1 = Base.slt_int(1, x)::Bool
└── goto #3 if not %1
2 ─ return 0.0
3 ─ return x
) => Union{Float64, Int64}
julia> @code_typed (() -> [f(i) for i in 1:1])()
CodeInfo(
1 ─ goto #3 if not true
2 ─ nothing::Nothing
3 ┄ %3 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Vector{Int64}, svec(Any, Int64), 0, :(:ccall), Vector{Int64}, 1, 1))::Vector{Int64}
│ Base.arrayset(true, %3, 1, 1)::Vector{Int64}
│ %5 = invoke Base.collect_to!(%3::Vector{Int64}, $(QuoteNode(Base.Generator{UnitRange{Int64}, var"#8#10"}(var"#8#10"(), 1:1)))::Base.Generator{UnitRange{Int64}, var"#8#10"}, 2::Int64, 1::Int64)::Union{Vector{Int64}, Vector{Real}}
└── goto #4
4 ─ return %5
) => Union{Vector{Int64}, Vector{Real}}

Perhaps the Vector{Real} may be narrowed to Vector{Float64} with some careful promotion, but this may fail to return concrete element types in general. There are also cases where one might desire concrete type-inference to avoid any allocation.

I agree that this is bad if you put values into containers, which then need to have abstract element types.
I am not sure what you mean about complicating downstream inference, though. Doing

julia> f(x) = x > 5 ? 0.0 : x # type unstable
julia> g(::Float64) = "float"
julia> g(::Int) = "int"
julia> h(x) = g(f(x))
julia> using BenchmarkTools; @btime h(x) setup=(x=rand(1:10);)
2.305 ns (0 allocations: 0 bytes)

So no allocations. Union-splitting is not to be confused with dynamic dispatch, where Julia needs to resolve the method entirely at runtime. For union-splitting, all possible types and the corresponding methods are fully inferred, and the correct one is simply selected by a conditional. In the simple example above, each of the calls to g is even fully inlined.

Sorry, I think I was still thinking of containers, as is the case when multiple values are returned from a function. E.g.:

julia> f(i) = i > 0 ? i : 0.0
f (generic function with 1 method)
julia> g(i) = f(i), nothing
g (generic function with 1 method)
julia> @btime g(i) setup=(i=rand(1:4))
74.393 ns (1 allocation: 16 bytes)
(2, nothing)
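By contrast, if the inner function is made type stable first, the tuple type is concrete and the allocation should disappear (a sketch; f2 is a hypothetical stabilized variant, and exact timings will vary):

```julia
using Test

f2(i) = i > 0 ? float(i) : 0.0   # always returns Float64 for an Int input
g2(i) = (f2(i), nothing)

@inferred g2(2)  # inferred as Tuple{Float64, Nothing}
```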

I have multiple situations like this in code I write for stochastic simulation and inference. There’s a hack I use because it makes everything type stable for ForwardDiff: I select one of the function’s arguments and use its type as a parameter, making sure it’s an argument that will be autodiff’d. In your case I would hack it around like this:

using Distributions
function f(μ::T, σ::Real, a::Real) where {T<:Real}
    res = zero(T)
    if !(isnan(μ) || isnan(σ))
        res += pdf(Normal(μ, σ), a)
    end
    return res
end

@Vasily_Pisarev’s way of doing this is the correct one (imo), but I am often too lazy to apply it everywhere. I should probably refactor some more stuff…