Help optimising an expectation function

Hey everyone,

Apologies for the vague title; part of my problem is that I don’t know the words for what I’m doing, so I can’t search for it. Anyway:

I have a function of N boolean arguments. I would like to be able to call it with floats between 0 and 1, interpreted as probabilities: the result should be the expectation of the function over independent boolean inputs, where each float gives the probability that the corresponding argument is true.
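
Written out (this is just the N = 1 and N = 2 cases below, generalised):

$$f(p_1,\dots,p_N) = \sum_{b \in \{0,1\}^N} \Big(\prod_{i=1}^N p_i^{b_i} (1-p_i)^{1-b_i}\Big) \, f(b_1,\dots,b_N)$$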

For example:

struct CallableInt{N}
    int::Int64
end

function read_bit(x, pos)
    pos < 0 && throw(ArgumentError("Bit position $pos must be >=0"))
    return (x >> pos) & 1 != 0
end

# the Bool arguments, read as binary digits (via evalpoly), index into the truth table stored in f.int
(f::CallableInt{N})(xs::Vararg{I,N}) where {N,I<:Integer} = read_bit(f.int, evalpoly(2, xs))

(f::CallableInt{1})(foo) =  foo*f(true) + (1-foo) * f(false)

function (f::CallableInt{2})(foo, bar)
    foo * (
        bar * f(true, true) + (1-bar) * f(true, false)
    ) +
    (1-foo) * (
        bar * f(false, true) + (1-bar) * f(false, false)
    )
end

The above are very fast and I’m not trying to optimize them; what I’m trying to do is write a version that is generic in N and still as fast as the handwritten methods.

My attempt at doing so falls short:

function (f::CallableInt{N})(args::Vararg{Float64,N}) where {N}
    out = 0.0
    all(0 .<= args .<= 1) || throw(ArgumentError("All arguments must be between 0 and 1"))
    rate_combinations = NTuple{N,NTuple{2,Float64}}((arg, 1 - arg) for arg in args)
    boolean_combinations = NTuple{N,NTuple{2,Bool}}((true, false) for _ in args)
    # sum f over every Bool combination, weighted by the product of the matching probabilities
    for (boolean_combination, rate_combination) in zip(
        Iterators.product(boolean_combinations...),
        Iterators.product(rate_combinations...)
    )
        out += f(boolean_combination...) * reduce(*, rate_combination)
    end
    return out
end

It returns the right value, but takes a long time to execute. I was wondering if code generation is the right way to go, but I know that it’s a last resort in most cases. I would appreciate any advice on speeding this function up.

Edit: fixed bug in type hints.

Can you check that your code runs, and provide sample input? From your description it sounds like this should accept an arbitrary function somehow.

Here are some example uses; it’s kind of like fuzzy logic. But you’re right that it could be generalised to arbitrary functions. In my case I just wanted to restrict the inputs to between 0 and 1 and fix the number of arguments the function takes (and I don’t really know how to do that any other way).

f = CallableInt{2}(0b1001) # 0b1001 = unsigned 9
f(true, true) == 1
f(true, false) == 0
f(false, true) == 0
f(false, false) == 1
f(0.99, 0.99) == 0.9802
f(0.2, 0.8) == 0.32
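
Spelling out the last one: f is ==, so only the (true, true) and (false, false) terms contribute, and f(0.2, 0.8) = 0.2*0.8*1 + 0.2*0.2*0 + 0.8*0.8*0 + 0.8*0.2*1 = 0.16 + 0.16 = 0.32.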

There was also a bug in the source, as you implied, and I’ve fixed it. Served me right for trying to type-hint my functions :')

Here’s what I was trying, maybe with the bugs removed…

"""
    callall(f, p) = p * f(true) + (1-p) * f(false)
    callall(f, p, q) = p*q * f(true, true) + p*(1-q) * f(true, false) + 
                   (1-p)*q * f(false, true) + (1-p)*(1-q) * f(false, false)
    callall(f, probs::Real...) = ...

Evaluates a function `f(::Bool...)` accepting `N` arguments `2^N` times,
and adds up the results weighted by the given probabilities as shown.
"""
function callall(f::F, args::Vararg{Real,N}) where {F,N}
    for a in args
        0 <= a <= 1 || error("out of range")
    end
    out = zero(promote_type(map(typeof, args)...))
    for n in 1:2^N
        bools = ntuple(i -> Bool((n >> (i-1)) & 1), N)
        val = f(bools...)
        weight = prod(map((b,a) -> ifelse(b, a, 1-a), bools, args))
        # @show n bools val weight
        out += val * weight
    end
    out
end

callall(-, 0.3) == -0.3
callall(!, 0.3) == 0.7
callall(*, 0.3, 0.9) == 0.27

f1 = CallableInt{1}(0b1001)  # using above code
f1.([true, false])  # this is the function !

f1(0.3) == callall(f1, 0.3) == callall(!, 0.3)

f = CallableInt{2}(0b1001) # 0b1001 = unsigned 9
f.([true, false], [true false])  # truth table, this is the function ==

f(0.2, 0.8) == 0.32 == callall(f, 0.2, 0.8) == callall(==, 0.2, 0.8)

callall(Returns(true), rand(5)...) ≈ 1
callall(Returns(false), rand(5)...) == 0

using BenchmarkTools
@btime callall(*, $(Tuple(rand(3)))...);  # 5ns
@btime callall(*, $(Tuple(rand(5)))...);  # 50ns


Wow this is a really clever implementation. I especially like this line

bools = ntuple(i -> Bool((n >> (i-1)) & 1), N)

I thought about iterating over 1:2^N but couldn’t work out how to get the bools from that.
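
If I’m reading it right, for N = 3 it works like this:

julia> n = 5;

julia> ntuple(i -> Bool((n >> (i-1)) & 1), 3)  # bits of 5 = 0b101, least significant first
(true, false, true)

and n = 2^N has all of bits 0 to N-1 equal to zero, so looping over 1:2^N still hits every combination exactly once.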

Is Vararg an exception when it comes to abstract type annotations? I thought you were supposed to avoid writing abstract types in function arguments at all costs?

Are you just covering your bases with this comment, or are there some bugs that you encountered while running it?

It had all sorts of off-by-one errors at first!

I think dispatching on g(x::Real) is totally fine, and even g(x...), which is Vararg{Any}. It’s things like Real[1,2] or structs with abstractly typed fields which are the problem.
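
A quick illustration of the container case:

julia> typeof(Real[1, 2])  # abstract element type, so every element is stored boxed
Vector{Real} (alias for Array{Real, 1})

julia> typeof([1, 2])      # concrete element type, stored inline and fully inferrable
Vector{Int64} (alias for Array{Int64, 1})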

And I think my bools is the same trick as your read_bit?

I thought there would be more room for improvement, but at first I was just timing things wrong. Both are down to 1.5–2 ns per evaluation of f:

julia> f5 = CallableInt{5}(0b1001);
julia> x5 = Tuple(rand(5));

julia> @btime $f5($x5...)  # with zip(Iterators.product(...
  70.782 ns (0 allocations: 0 bytes)
0.010132227405239502

julia> @btime callall($f5, $x5...)  # only for & ntuple
  56.698 ns (0 allocations: 0 bytes)
0.010132227405239502

julia> @btime callall(+, $x5...)  # cheaper function
  47.697 ns (0 allocations: 0 bytes)
2.321889303862584

Maybe also worth noting that N real numbers for N arguments gives one possible weighting, the one where the joint probability factorises, i.e. the inputs are treated as independent.
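
For N = 2, for instance, the weights used are (p*q, p*(1-q), (1-p)*q, (1-p)*(1-q)). A fully general weighting would be any four non-negative numbers summing to 1, and something like the perfectly correlated case (1/2, 0, 0, 1/2) cannot be written in that product form.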


Highly forgivable when working with binary numbers, given Julia’s 1-based indexing!

Perhaps, but I hadn’t thought to apply it in this way.

There is nothing wrong with abstractly typed function arguments. In Julia a function has several methods, e.g.,

julia> foo(x::Real) = 2 * x
foo (generic function with 1 method)

julia> foo(x::String) = "2" * x
foo (generic function with 2 methods)

julia> meths = methods(foo)
# 2 methods for generic function "foo":
[1] foo(x::Real) in Main at REPL[1]:1
[2] foo(x::String) in Main at REPL[2]:1

Here the types are used for dispatch, e.g., if foo(1.0) is called, the first method is used. Furthermore, this method then gets specialized for the actual concrete type of the call, Float64:

julia> meths[1].specializations
svec()

julia> foo(1.0)
2.0

julia> meths[1].specializations
svec(MethodInstance for foo(::Float64), nothing, nothing, nothing, nothing, nothing, nothing, nothing)

These specializations contain the actual compiled code of a method; in particular, a method is compiled separately for each combination of concrete argument types it is called with. This is what makes Julia fast, and it is also the reason the performance tips recommend function barriers.
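
To make the function-barrier idea concrete (a minimal sketch; barrier_demo is a made-up name): even when the element type of a container is not inferrable, each call into a small kernel function pays the dispatch cost once and then runs that method's specialized code:

julia> function barrier_demo(xs)   # xs is a Vector{Any}, so element types are unknown here
           total = 0.0
           for x in xs
               total += foo(x)     # dynamic dispatch happens here; foo's body runs fully specialized
           end
           total
       end;

julia> barrier_demo(Any[1, 2.0, 3])
12.0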

The problem arises when the compiler cannot infer the concrete type of a variable, which commonly happens when structs have untyped or abstractly typed fields. In that case, the dispatch logic has to be included in the compiled code in order to select the appropriate specialization at runtime:

julia> function known()
           args = (1.0, 1)
           foo.(args)
       end
known (generic function with 1 method)

# Here the compiler can infer the types of the call and select appropriate specialization beforehand
julia> @code_warntype known()
MethodInstance for known()
  from known() in Main at REPL[18]:1
Arguments
  #self#::Core.Const(known)
Locals
  args::Tuple{Float64, Int64}
Body::Tuple{Float64, Int64}
1 ─      (args = Core.tuple(1.0, 1))
│   %2 = Base.broadcasted(Main.foo, args::Core.Const((1.0, 1)))::Core.Const(Base.Broadcast.Broadcasted(foo, ((1.0, 1),)))
│   %3 = Base.materialize(%2)::Core.Const((2.0, 2))
└──      return %3

julia> function unknown()
           args = Any[1.0, "3"]
           foo.(args)
       end
unknown (generic function with 1 method)

# Here, the dispatch needs to be done at runtime
julia> @code_warntype unknown()
MethodInstance for unknown()
  from unknown() in Main at REPL[20]:1
Arguments
  #self#::Core.Const(unknown)
Locals
  args::Vector{Any}
Body::AbstractVector
1 ─      (args = Base.getindex(Main.Any, 1.0, "3"))
│   %2 = Base.broadcasted(Main.foo, args)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(foo), Tuple{Vector{Any}}}
│   %3 = Base.materialize(%2)::AbstractVector
└──      return %3

Apologies to anyone who gets a notification from this revival, but there was one more optimisation:
changing

weight = prod(map((b,a) -> ifelse(b, a, 1-a), bools, args))

to one of

weight = prod(ifelse(b, a, 1-a) for (b, a) in zip(bools, args))
weight = mapreduce((b, a) -> ifelse(b, a, 1 - a), *, bools, args)

avoids materialising an intermediate collection before the reduction, and so runs at almost identical speed to the manually written one. While it’s not a zero-cost abstraction, it is a mere 30-nanosecond-cost abstraction.

There was, after all, a much faster version of this function waiting to be released: recurse over the arguments, and whenever an argument is an exact 0 or 1, skip the branch whose weight would be zero:

using BenchmarkTools

# With help from @mcabbott
function original_lerp(f, args::Vararg{Real,N}) where {N}
    for a in args
        0 <= a <= 1 || error("out of range")
    end
    out = zero(promote_type(map(typeof, args)...))
    for n in 1:(2^N)
        bools = ntuple(i -> Bool((n >> (i - 1)) & 1), N)
        val = f(bools...)
        weight = prod(ifelse(b, a, 1 - a) for (b, a) in zip(bools, args))
        out += val * weight
    end
    return out
end

function lerp(f, args::Vararg{Real,N}) where {N}
    for a in args
        0 <= a <= 1 || error("out of range")
    end
    head, tail... = args
    ftrue(t...) = f(true, t...)
    ffalse(t...) = f(false, t...)
    if isempty(tail)
        return head * f(true) + (1 - head) * f(false)
    elseif head isa Integer
        if head == true
            return head * lerp(ftrue, tail...)
        else
            return (1-head) * lerp(ffalse, tail...)
        end
    else
        return head * lerp(ftrue, tail...) + (1 - head) * lerp(ffalse, tail...)
    end
end

testf(xs...) = sum(xs)
# A large enough input not to be const-folded in `lerp`
input = (0, 0.5, 1, 0.5, 1, 1, 0, 0.5, 0.5, 0.5)
@benchmark original_lerp(testf, input...)
@benchmark lerp(testf, input...)
Benchmark results
# original_lerp(testf, input...)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  67.208 μs …  3.466 ms  ┊ GC (min … max): 0.00% … 96.75%
 Time  (median):     69.250 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   73.034 μs ± 59.873 μs  ┊ GC (mean ± σ):  1.26% ±  1.67%

  ▄██▅▅▅▅▄▄▄▄▅▅▄▃▂▂▁▁▁▁                                       ▂
  ███████████████████████████▇█▇▇██▇▇▆▇▇▇▆▅▆▆▆▇██▇▆▆▆▆▆▅▅▆▆▅▃ █
  67.2 μs      Histogram: log(frequency) by time      95.6 μs <

 Memory estimate: 15.73 KiB, allocs estimate: 658.

# lerp(testf, input...)
BenchmarkTools.Trial: 3426 samples with 1 evaluation.
 Range (min … max):  1.227 ms …   4.566 ms  ┊ GC (min … max): 0.00% … 69.82%
 Time  (median):     1.319 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.458 ms ± 614.546 μs  ┊ GC (mean ± σ):  9.36% ± 14.53%

  ▅█▆▃                                                        ▁
  ████▆▃▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▅▇█▇█▆▆▅▆█ █
  1.23 ms      Histogram: log(frequency) by time      4.38 ms <

 Memory estimate: 2.11 MiB, allocs estimate: 54291.

Oops I accidentally a macro
Edit: Oops I accidentally bad at Julia. I don’t think my benchmarks are at all accurate because nothing is escaped. I’m not exactly using macros properly here (because I don’t know how to). I’d appreciate any help…

macro lerp(f, varying::Real...)
    # NOTE: the generated expression is not `esc`aped, so the interpolated function name
    # resolves in the macro's defining module (Main here), which happens to work in this case
    lerp_rec(f, Tuple{}(), varying...)
end

function construct_lerp_exp(maybe_bool, true_exp, false_exp)
    if maybe_bool == true
        exp = :($true_exp)
    elseif maybe_bool == false
        exp = :($false_exp)
    else
        exp = :($true_exp + $false_exp)
    end
    return exp
end

function lerp_rec(f, fixed::Tuple{Vararg{Integer}}, varying::Real)
    true_exp = :($varying * $f($(fixed...), true))
    false_exp = :($(1-varying) * $f($(fixed...), false))
    return construct_lerp_exp(varying, true_exp, false_exp)
end

function lerp_rec(f, fixed::Tuple{Vararg{Integer}}, varying::Real...)
    true_exp = :($(varying[1]) * $(lerp_rec(f, (fixed..., true), Base.tail(varying)...)))
    false_exp = :($(1 - varying[1]) * $(lerp_rec(f, (fixed..., false), Base.tail(varying)...)))
    return construct_lerp_exp(varying[1], true_exp, false_exp)
end
New benchmark
struct SR{N}
    int
end

function read_bit(x, pos)
    pos < 0 && throw(ArgumentError("Bit position $pos must be >=0"))
    return (x >> pos) & 1 != 0
end

(rule::SR{N})(xs::Vararg{I,N}) where {N,I<:Integer} = read_bit(rule.int, evalpoly(2, xs))

sr5 = SR{5}(195)

@benchmark auto(sr5, true, false, true, false, true)
BenchmarkTools.Trial: 10000 samples with 10 evaluations.

 Range (min … max):  1.812 μs …  11.375 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     1.858 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.926 μs ± 230.515 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▁▆██▅▄▄▂▂▂▂▁▂▄▄▄▄▃                                          ▂
  ██████████████████▇▆▆▅▆▆▅▇▅▆▆▅▆▆▆▆▆▆▆▇▇▆▆▆▆▅▆▅▆▅▆▆▆▅▅▆▆▇▇▇▆ █
  1.81 μs      Histogram: log(frequency) by time      2.62 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

@benchmark fauto(sr5, true, false, true, false, true)
BenchmarkTools.Trial: 10000 samples with 54 evaluations.
 Range (min … max):  886.574 ns …  25.325 μs  ┊ GC (min … max): 0.00% … 95.21%
 Time  (median):     919.759 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   941.803 ns ± 329.273 ns  ┊ GC (mean ± σ):  0.48% ±  1.34%

  ▁▃▂▄▄▇█▅▄▄▃▁▁▂▂▁ ▁▁▁▁▁                                        ▂
  █████████████████████████▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▇▆▆▆▇█▇▆▆▆▆▅▆▅▅▅▅▆▅▄▅ █
  887 ns        Histogram: log(frequency) by time        1.2 μs <

 Memory estimate: 224 bytes, allocs estimate: 14.

@benchmark @lerp(sr5, true, false, true, false, true)
BenchmarkTools.Trial: 10000 samples with 951 evaluations.
 Range (min … max):   96.828 ns … 277.647 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):      99.632 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   101.619 ns ±   6.056 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▆▇▁ ▂██▃▃▁ ▃▃▄▂▂▂▁▁▂▃▂▁  ▁   ▁▂      ▁                        ▂
  ███▅████████████████████████████████████▇▆▆▇▇▆▆▇▇▆▆▇▆▅▆▆▆▆▆▅▆ █
  96.8 ns       Histogram: log(frequency) by time        125 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

tldr benchmarks:

  • the original version was 1.8 μs,
  • the version from yesterday was 919 ns,
  • the macro version is 96 ns

The macro essentially just writes out all of the function calls the way a human would by hand, while dropping the ones that wouldn’t contribute to the final sum (because their coefficient is zero).
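
As a concrete (hypothetical) example with a two-argument f, @lerp(f, true, 0.5) expands to roughly the following (ignoring macro hygiene, which module-qualifies the names):

true * (0.5 * f(true, true) + 0.5 * f(true, false))

so the entire false half of the truth table is never evaluated.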
