Constant propagation: Irrationals doing much better than Float64

@oschulz and I have been talking about a deterministic RNG, with the main application being to generate an arbitrary point in the support of a given distribution. Here’s our current implementation:

using Random  # AbstractRNG, randn and randexp live in the Random stdlib

export FixedRNG
struct FixedRNG <: AbstractRNG end

Base.rand(::FixedRNG) = one(Float64) / 2
Random.randn(::FixedRNG) = zero(Float64)
Random.randexp(::FixedRNG) = one(Float64)

Base.rand(::FixedRNG, ::Type{T}) where {T<:Real} = one(T) / 2
Random.randn(::FixedRNG, ::Type{T}) where {T<:Real} = zero(T)
Random.randexp(::FixedRNG, ::Type{T}) where {T<:Real} = one(T)

# We need concrete type parameters to avoid ambiguity for these cases
for T in [Float16, Float32, Float64]
    @eval begin
        Base.rand(::FixedRNG, ::Type{$T}) = one($T) / 2
        Random.randn(::FixedRNG, ::Type{$T}) = zero($T)
        Random.randexp(::FixedRNG, ::Type{$T}) = one($T)
    end
end

In a quick test of this, we were very impressed with the constant propagation:

julia> @code_typed rand(FixedRNG(), TDist(π))
CodeInfo(
1 ─     nothing::Nothing
└──     return 0.0
) => Float64

But such great results seem specific to irrationals. Here’s the same call with a Float64 instead:

julia> @code_typed rand(FixedRNG(), Dists.TDist(3.1))
CodeInfo(
1 ── %1  = Base.getfield(d, :ν)::Float64
│    %2  = Base.ne_float(%1, %1)::Bool
│    %3  = Base.not_int(%2)::Bool
│    %4  = Base.sub_float(%1, %1)::Float64
│    %5  = Base.eq_float(%4, 0.0)::Bool
│    %6  = Base.and_int(%5, true)::Bool
│    %7  = Base.and_int(%6, true)::Bool
│    %8  = Base.not_int(%7)::Bool
│    %9  = Base.and_int(%3, %8)::Bool
└───       goto #3 if not %9
2 ──       goto #26
3 ── %12 = Base.getfield(d, :ν)::Float64
└───       goto #8 if not true
4 ── %14 = Base.lt_float(0.0, %12)::Bool
└───       goto #6 if not %14
5 ──       goto #7
6 ── %17 = invoke Distributions.DomainError(%12::Any, "Chisq: the condition ν > zero(ν) is not satisfied."::Any)::DomainError
│          Distributions.throw(%17)::Union{}
└───       unreachable
7 ──       nothing::Nothing
8 ┄─       Distributions.nothing::Nothing
└───       goto #9
9 ──       goto #10
10 ─       goto #11
11 ─ %25 = Base.div_float(%12, 2.0)::Float64
└───       goto #18 if not true
12 ─ %27 = Base.lt_float(0.0, %25)::Bool
└───       goto #16 if not %27
13 ─ %29 = Base.lt_float(0.0, 2.0)::Bool
└───       goto #15 if not %29
14 ─       goto #17
15 ─ %32 = invoke Distributions.DomainError(2.0::Any, "Gamma: the condition θ > zero(θ) is not satisfied."::Any)::DomainError
│          Distributions.throw(%32)::Union{}
└───       unreachable
16 ─ %35 = invoke Distributions.DomainError(%25::Any, "Gamma: the condition α > zero(α) is not satisfied."::Any)::DomainError
│          Distributions.throw(%35)::Union{}
└───       unreachable
17 ─       nothing::Nothing
18 ┄       Distributions.nothing::Nothing
└───       goto #19
19 ─ %41 = %new(Distributions.Gamma{Float64}, %25, 2.0)::Distributions.Gamma{Float64}
└───       goto #20
20 ─       goto #21
21 ─ %44 = invoke Distributions.rand(rng::FixedRNG, %41::Distributions.Gamma{Float64})::Float64
└───       goto #22
22 ─ %46 = Base.getfield(d, :ν)::Float64
│    %47 = Base.div_float(%44, %46)::Float64
│    %48 = Base.lt_float(%47, 0.0)::Bool
└───       goto #24 if not %48
23 ─       invoke Base.Math.throw_complex_domainerror(:sqrt::Symbol, %47::Float64)::Union{}
└───       unreachable
24 ─ %52 = Base.Math.sqrt_llvm(%47)::Float64
└───       goto #25
25 ─       nothing::Nothing
26 ┄ %55 = φ (#2 => false, #25 => true)::Bool
│    %56 = φ (#2 => true, #25 => false)::Bool
│    %57 = φ (#2 => 1, #25 => %52)::Union{Float64, Int64}
└───       goto #28 if not %55
27 ─ %59 = π (%57, Float64)
│    %60 = Base.div_float(0.0, %59)::Float64
└───       goto #31
28 ─       goto #30 if not %56
29 ─ %63 = π (%57, Int64)
│    %64 = Base.sitofp(Float64, %63)::Float64
│    %65 = Base.div_float(0.0, %64)::Float64
└───       goto #31
30 ─       Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{}
└───       unreachable
31 ┄ %69 = φ (#27 => %60, #29 => %65)::Float64
└───       return %69
) => Float64

After typing this, I thought to try Static.StaticFloat64s, which again are almost magical (cc @Zach_Christensen):

julia> @code_typed rand(FixedRNG(), TDist(static(3.1)))
CodeInfo(
1 ─     nothing::Nothing
└──     return 0.0
) => Float64

Can the Float64 call be improved, or is this a case where performance requires type-level values?

EDIT: Here are some performance results, varying between great and pretty bad:

julia> for j in 2:-1:-4
           ν = exp10(j)
           result = rand(FixedRNG(), StudentT(ν))
           println("rand(FixedRNG(), StudentT($ν)) == $result")
           @btime rand(FixedRNG(), StudentT(ν)) setup=(ν=$ν)
           println()
           end
rand(FixedRNG(), StudentT(100.0)) == 0.0
  1.409 ns (0 allocations: 0 bytes)

rand(FixedRNG(), StudentT(10.0)) == 0.0
  1.408 ns (0 allocations: 0 bytes)

rand(FixedRNG(), StudentT(1.0)) == 0.0
  43.211 ns (0 allocations: 0 bytes)

rand(FixedRNG(), StudentT(0.1)) == 0.0
  42.997 ns (0 allocations: 0 bytes)

rand(FixedRNG(), StudentT(0.01)) == 0.0
  42.757 ns (0 allocations: 0 bytes)

rand(FixedRNG(), StudentT(0.001)) == NaN
  32.614 ns (0 allocations: 0 bytes)

rand(FixedRNG(), StudentT(0.0001)) == NaN
  32.609 ns (0 allocations: 0 bytes)

There was some related discussion recently in a Slack thread (which will eventually go away) and in the comments on this issue.

It would be much better if Julia specified that our essential machine number types, when called as constructors, return values of that same type (or fail, as Int("ab") does).

There is no good reason to add a method like the one below. Still, it remains possible, and when it is done, constant propagation and type inference are affected as you have seen:

julia> Base.Float32(x::Int) = isinf(1/(x-2.0)) ? 2.0 : 
                                    Float32(Float64(x))

julia> (typeof(Float32(5)), Float32(5)), 
       (typeof(Float32(2)), Float32(2))
((Float32, 5.0f0), (Float64, 2.0))

I agree calling a constructor should return a value of that type. But I don’t see how that relates to this post. Could you elaborate?

See this:


Right, but what does this have to do with the OP?

My understanding is that safe, consistent and, where appropriate, deep constant propagation of explicitly given machine integers and floats (or something close to it, thanks to more tightly constrained type inference) is a known technique, and one that becomes available once these constructors are guaranteed to return values of their own type.
Is there willingness to enforce the rule on those constructors?
Is there some insight that makes it easier to do neatly and correctly?

Sorry I still don’t understand how this connects. Could you point to the line(s) of my code you’re talking about?

Actually, no. AFAIK, it is nothing you are doing.

Is it related to this post?

Irrationals are kept much like Vals: the value is encoded in the type itself.

julia> π
π = 3.1415926535897...

julia> dump(π)
Irrational{:π} π

julia> dump(Val{:π}())
Val{:π} Val{:π}()
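A minimal sketch of what that means, using only Base: Irrational{:π} is a singleton type, so knowing the type is enough to know the value, whereas the type Float64 says nothing about which value it holds. A Val can carry a plain Float64 in its type parameter in the same way; the half function below is a hypothetical illustration, not part of any package.

```julia
# The value of π is identified by its type alone.
@assert typeof(π) === Irrational{:π}

# A singleton type has exactly one instance, so the type determines the value...
@assert Base.issingletontype(typeof(π))
# ...while Float64 has many possible values, so its type does not.
@assert !Base.issingletontype(Float64)

# Hypothetical helper: a Float64 moved into a type parameter via Val.
# Dispatch now sees the value as part of the type, much like Irrational{:π}.
half(::Val{x}) where {x} = x / 2

println(half(Val(3.0)))  # 1.5
```

This is roughly the idea behind type-level numbers such as the StaticFloat64 mentioned above, although that package's actual implementation is not reproduced here.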

Here is your code, once with the irrational π and once with the Float64 pi (in that session pi evidently held a Float64, since Base.pi is the same Irrational{:π} as π).

In the Arguments, notice d: a Core.Const in the first case, only TDist{Float64} in the second.
In the Body, notice:
%2, and that %2 makes %3 either a Const or a Bool. As a Const, it is known at compile time; as a Bool, it must be resolved at runtime.
%13, and that %13 makes %14 either a Const or a Float64. As a Const, it is known at compile time; as a Float64, it must be resolved at runtime.


julia> @code_warntype rand(FixedRNG(), TDist(π))
MethodInstance for rand(::FixedRNG, ::TDist{Irrational{:π}})
  from rand(rng::AbstractRNG, d::TDist) @ Distributions _/tdist.jl:82
Arguments
  #self#::Core.Const(rand)
  rng::Core.Const(FixedRNG())
  d::Core.Const(TDist{Irrational{:π}}(ν=π))
Locals
  @_4::Float64
Body::Float64
1 ─ %1  = Distributions.randn(rng)::Core.Const(0.0)
│   %2  = Base.getproperty(d, :ν)::Core.Const(π)
│   %3  = Distributions.isinf(%2)::Core.Const(false)
└──       goto #3 if not %3
2 ─       Core.Const(:(@_4 = 1))
└──       Core.Const(:(goto %13))
3 ┄ %7  = Base.getproperty(d, :ν)::Core.Const(π)
│   %8  = Distributions.Chisq(%7)::Core.Const(Chisq{Irrational{:π}}(ν=π))
│   %9  = Distributions.rand(rng, %8)::Core.Const(2.141592653589793)
│   %10 = Base.getproperty(d, :ν)::Core.Const(π)
│   %11 = (%9 / %10)::Core.Const(0.6816901138162094)
│         (@_4 = Distributions.sqrt(%11))
│   %13 = @_4::Core.Const(0.8256452711765564)
│   %14 = (%1 / %13)::Core.Const(0.0)
└──       return %14


julia> @code_warntype rand(FixedRNG(), TDist(pi))
MethodInstance for rand(::FixedRNG, ::TDist{Float64})
  from rand(rng::AbstractRNG, d::TDist) @ Distributions _/tdist.jl:82
Arguments
  #self#::Core.Const(rand)
  rng::Core.Const(FixedRNG())
  d::TDist{Float64}
Locals
  @_4::Union{Float64, Int64}
Body::Float64
1 ─ %1  = Distributions.randn(rng)::Core.Const(0.0)
│   %2  = Base.getproperty(d, :ν)::Float64
│   %3  = Distributions.isinf(%2)::Bool
└──       goto #3 if not %3
2 ─       (@_4 = 1)
└──       goto #4
3 ─ %7  = Base.getproperty(d, :ν)::Float64
│   %8  = Distributions.Chisq(%7)::Chisq{Float64}
│   %9  = Distributions.rand(rng, %8)::Float64
│   %10 = Base.getproperty(d, :ν)::Float64
│   %11 = (%9 / %10)::Float64
└──       (@_4 = Distributions.sqrt(%11))
4 ┄ %13 = @_4::Union{Float64, Int64}
│   %14 = (%1 / %13)::Float64
└──       return %14
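The Union{Float64, Int64} for @_4 in the second listing comes from one branch assigning the Int literal 1 and the other a Float64 square root. Here is a minimal sketch of that shape with hypothetical function names and no Distributions.jl dependency (a simplification, not the library's actual code), next to a variant in which both branches return Float64:

```julia
# Hypothetical simplification of the branch in the listing above:
# `1` is an Int and `sqrt(x / ν)` is a Float64, so inference gives a Union.
denom_unstable(ν::Float64, x::Float64) = isinf(ν) ? 1 : sqrt(x / ν)

# Type-stable variant: one(ν) is 1.0 for a Float64 ν, so both branches agree.
denom_stable(ν::Float64, x::Float64) = isinf(ν) ? one(ν) : sqrt(x / ν)

# Ask inference for the return types (one entry per matching method).
println(Base.return_types(denom_unstable, (Float64, Float64)))
println(Base.return_types(denom_stable, (Float64, Float64)))
```

Whether removing that Union alone would let the Float64 call fold further is not tested here; it only removes the type split seen at %13.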

This (one(T) / 2) looks like something that’s not guaranteed to return a value of type T. How is the performance of randn?

Are you sure? [Well, except for integer types or irrationals, which are all defined to return Float64 (or BigFloat).]

julia> one(Posit8) / 2
Posit8(0.5)

However:

julia> one(Posit8) * 0.5
ERROR: promotion of types Posit8 and Float64 failed to change any arguments

And unlike Julia/IEEE floats, this is not defined (nor in the 2022 Posit standard):

julia> one(Posit8) ÷ 2
ERROR: rem not defined for Posit8
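For the built-in types, the claim that one(T) / 2 stays in T for floats and becomes floating point for integers can be checked directly with Base alone (Posit8 needs an extra package, so it is left out of this sketch):

```julia
# IEEE float types and BigFloat: one(T) / 2 stays in T.
for T in (Float16, Float32, Float64, BigFloat)
    @assert typeof(one(T) / 2) === T
end

# Integer types: `/` always produces a floating-point result.
@assert typeof(one(Int) / 2) === Float64
@assert typeof(one(Int32) / 2) === Float64
@assert typeof(one(BigInt) / 2) === BigFloat

println("one(T) / 2 checks passed")
```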

The issue lies in Distributions.jl: rand(::FixedRNG, ::Gamma) is somewhat recursive, which (I guess) causes Julia to be unable to infer that the function terminates.

I see two ways to fix this:

  1. using @assume_effects - ugly but it should work if all else fails
  2. refactor the code a bit to eliminate the recursion. I’ll try to do this now

I made two pull requests to Distributions.jl. Although rand is now (after applying both patches) foldable according to Julia’s effect inference, it seems that it still isn’t being replaced with a constant. Maybe I’ll look into it later.

Still, try out my pull requests and see how much they improve the situation, if you’re interested.


Thanks. Yes, that’s also why Static.static(3.1) is able to handle it just as well. What I wasn’t getting is why the non-static TDist argument makes things so slow, even with a deterministic RNG.

julia> @btime randn(FixedRNG(), Float64)
  1.334 ns (0 allocations: 0 bytes)
0.0

julia> @btime randn(FixedRNG(), Float32)
  1.335 ns (0 allocations: 0 bytes)
0.0f0

julia> @btime randn(FixedRNG(), Float16)
  1.334 ns (0 allocations: 0 bytes)
Float16(0.0)

And just for good measure (haha)

julia> @btime randn(FixedRNG(), BigFloat)
  23.285 ns (2 allocations: 104 bytes)
0.0

Great, thank you! Since the value depends on the value of ν (the result is sometimes NaN) it can’t be entirely constant. But I’d expect it to be a lot faster than it started out, with smaller generated code, even if it can’t go away completely.


How does the possibility of NaNs impede constant folding? I mean, NaNs are still the same type in Julia, so I don’t see how they’re relevant. Even throwing exceptions supposedly shouldn’t impede constant folding, according to the docs: Essentials · The Julia Language

It’s nothing about NaNs in particular. I just mean that if the result depends on the value of ν instead of just its type, then the compiled code can’t entirely get rid of the dependence on ν. But I would expect the RNG dependence to be gone, and the resulting code to be relatively small and very fast.

I guess to really test constant folding, we should be looking at code where ν is also constant, but not determined statically by its type.
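A minimal sketch of that test, using a hypothetical toy sampler instead of Distributions.jl so it runs with the standard library only. In toy_literal the value 3.1 is a constant inside the function body but is not encoded in any type, whereas in toy_arg it arrives as a runtime argument. Comparing the @code_typed output of the two is the experiment; whether the literal case actually folds will depend on the Julia version and on the callee's inferred effects, so no particular output is claimed here.

```julia
using Random

# Deterministic RNG, reduced to the one method this sketch needs.
struct FixedRNG <: AbstractRNG end
Random.randn(::FixedRNG) = 0.0

# Hypothetical toy sampler whose result depends on the value of ν, not just its type.
toy(rng::AbstractRNG, ν::Float64) = randn(rng) / sqrt(ν)

# ν is only known at run time.
toy_arg(ν::Float64) = toy(FixedRNG(), ν)

# ν is a constant literal, but its type is still plain Float64.
toy_literal() = toy(FixedRNG(), 3.1)

println(toy_arg(3.1), " ", toy_literal())
# To compare the generated code in the REPL:
#   @code_typed toy_arg(3.1)
#   @code_typed toy_literal()
```

The same pattern applies to the real case by wrapping rand(FixedRNG(), TDist(3.1)) in a zero-argument function.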
