Help optimising an expectation function

Hey everyone,

Apologies for the vague title. I think part of my problem is that I don’t know the words for what I’m doing, so I can’t search for it. Anyway:

I have a function of N boolean arguments. I would like to be able to call it with floats between 0 and 1, where each float is interpreted as the probability that the corresponding argument is true, and the result is the expectation of the function over all true/false combinations weighted by those probabilities.

For example:

struct CallableInt{N}
    int::Int64
end

function read_bit(x, pos)
    pos < 0 && throw(ArgumentError("Bit position $pos must be >=0"))
    return (x >> pos) & 1 != 0
end

(f::CallableInt{N})(xs::Vararg{I,N}) where {N,I<:Integer} = read_bit(f.int, evalpoly(2, xs))

(f::CallableInt{1})(foo) =  foo*f(true) + (1-foo) * f(false)

function (f::CallableInt{2})(foo, bar)
    foo * (
        bar * f(true, true) + (1-bar) * f(true, false)
    ) +
    (1-foo) * (
        bar * f(false, true) + (1-bar) * f(false, false)
    )
end

The above are very fast and I’m not trying to optimize them. Instead, I’m trying to write a single method that does the same thing generically in N and is as fast as a handwritten version.

My attempt at doing so falls short:

function (f::CallableInt{N})(args::Vararg{Float64,N}) where {N}
    out = 0.0
    all(0 .<= args .<= 1) || throw(ArgumentError("All arguments must be between 0 and 1"))
    rate_combinations = NTuple{N,NTuple{2,Float64}}((arg, 1 - arg) for arg in args)
    boolean_combinations = NTuple{N,NTuple{2,Bool}}((true, false) for _ in args)
    for (boolean_combination, rate_combination) in zip(
        Iterators.product(boolean_combinations...),
        Iterators.product(rate_combinations...)
    )
        out += f(boolean_combination...) * reduce(*, rate_combination)
    end
    return out
end

It returns the right value, but takes very long to execute. I was wondering if code generation is the right way to go, but I know that it’s a last resort in most cases. I would appreciate any advice on speeding this function up.
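One alternative to code generation is plain recursion over the argument list, which mirrors the nesting of the handwritten two-argument version. The following is only a sketch under assumptions: `expectation` is a hypothetical name that does not appear in this thread, it skips the range check on the probabilities, and its performance has not been measured here.

```julia
# Hypothetical helper (not from the thread): peel off one probability at a time.
# Base case: a single remaining argument.
expectation(f, p::Real) = p * f(true) + (1 - p) * f(false)

# Recursive case: fix the first argument to true/false, recurse on the rest.
function expectation(f, p::Real, rest::Real...)
    on_true  = expectation((bs...) -> f(true, bs...), rest...)
    on_false = expectation((bs...) -> f(false, bs...), rest...)
    return p * on_true + (1 - p) * on_false
end

expectation(!, 0.3)        # ≈ 0.7
expectation(*, 0.3, 0.9)   # ≈ 0.27
expectation(==, 0.2, 0.8)  # ≈ 0.32
```

Because the recursion depth is fixed by the number of arguments, the compiler has a chance to unroll it, but whether it matches the handwritten methods would need to be benchmarked.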

Edit: fixed bug in type hints.

Can you check that your code runs, and provide sample input? From your description it sounds like this should accept an arbitrary function somehow.

Here are some example uses; it’s kind of like fuzzy logic. But you’re right that it could be generalised to arbitrary functions. In my case I just wanted to limit the function values to between 0 and 1 and fix the number of inputs they take (and I don’t really know how to do that any other way).

f = CallableInt{2}(0b1001) # 0b1001 = unsigned 9
f(true, true) == 1
f(true, false) == 0
f(false, true) == 0
f(false, false) == 1
f(0.99, 0.99) ≈ 0.9802  # ≈ rather than ==, since these are floating-point results
f(0.2, 0.8) ≈ 0.32

There was also a bug in the source as you implied and I fixed this. Served me right for trying to type hint my functions :')

Here’s what I was trying, maybe with the bugs removed…

"""
    callall(f, p) = p * f(true) + (1-p) * f(false)
    callall(f, p, q) = p*q * f(true, true) + p*(1-q) * f(true, false) + 
                   (1-p)*q * f(false, true) + (1-p)*(1-q) * f(false, false)
    callall(f, probs::Real...) = ...

Evaluates a function `f(::Bool...)` accepting `N` arguments `2^N` times,
and adds up the results weighted by the given probabilities as shown.
"""
function callall(f::F, args::Vararg{Real,N}) where {F,N}
    for a in args
        0 <= a <= 1 || error("out of range")
    end
    out = zero(promote_type(map(typeof, args)...))
    for n in 1:2^N
        bools = ntuple(i -> Bool((n >> (i-1)) & 1), N)
        val = f(bools...)
        weight = prod(map((b,a) -> ifelse(b, a, 1-a), bools, args))
        # @show n bools val weight
        out += val * weight
    end
    out
end

callall(-, 0.3) == -0.3
callall(!, 0.3) == 0.7
callall(*, 0.3, 0.9) == 0.27

f1 = CallableInt{1}(0b1001)  # using above code
f1.([true, false])  # this is the function !

f1(0.3) == callall(f1, 0.3) == callall(!, 0.3)

f = CallableInt{2}(0b1001) # 0b1001 = unsigned 9
f.([true, false], [true false])  # truth table, this is the function ==

f(0.2, 0.8) ≈ 0.32 ≈ callall(f, 0.2, 0.8) ≈ callall(==, 0.2, 0.8)

callall(Returns(true), rand(5)...) ≈ 1
callall(Returns(false), rand(5)...) == 0

@btime callall(*, $(Tuple(rand(3)))...);  # 5ns
@btime callall(*, $(Tuple(rand(5)))...);  # 50ns

Wow this is a really clever implementation. I especially like this line

bools = ntuple(i -> Bool((n >> (i-1)) & 1), N)

I thought about iterating over 1:2^N but couldn’t think up how to get the bools from that.
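To see what that line produces, here is a small sketch that collects the tuples for `N = 2` (the variable name `combos` is only for illustration). Each `n` contributes its lowest `N` bits, lowest bit first, so the last value `n = 2^N` has all of those bits zero and yields the all-false combination.

```julia
N = 2
# Bit i-1 of n decides the i-th Bool.
combos = [ntuple(i -> Bool((n >> (i - 1)) & 1), N) for n in 1:2^N]
# combos == [(true, false), (false, true), (true, true), (false, false)]
```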

Is Vararg an exception when it comes to abstractly typed typehints? I thought you should avoid at all costs writing abstract types in function arguments?

Are you just covering your bases with this comment or are there some bugs that you encountered while running it?

It had all sorts of off-by-one errors at first!

I think dispatching on g(x::Real) is totally fine, and even g(x...), which is Vararg{Any}. It’s things like Real[1,2] (a container with an abstract element type) or structs with abstractly typed fields which are the problem.
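A small sketch of the difference, using two made-up structs (`LooseBox` and `TightBox` are hypothetical names, not from this thread). The point is only that the stored type is abstract in one case and concrete in the other:

```julia
struct LooseBox            # hypothetical: field declared with an abstract type
    x::Real
end

struct TightBox{T<:Real}   # hypothetical: field type is a parameter
    x::T
end

isconcretetype(fieldtype(LooseBox, :x))           # false
isconcretetype(fieldtype(TightBox{Float64}, :x))  # true
eltype(Real[1, 2])                                # Real
eltype([1, 2])                                    # Int
```

An abstract annotation on a function argument has no such cost, because the method is still compiled for the concrete type it is actually called with.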

And I think my bools is the same trick as your read_bit?

I thought there would be more room for improvement, but at first I was just timing things wrong. Both are down to roughly 1.5-2ns per evaluation of f (each call with 5 arguments makes 2^5 = 32 evaluations):

julia> f5 = CallableInt{5}(0b1001);
julia> x5 = Tuple(rand(5));

julia> @btime $f5($x5...)  # with zip(Iterators.product(...
  70.782 ns (0 allocations: 0 bytes)
0.010132227405239502

julia> @btime callall($f5, $x5...)  # only for & ntuple
  56.698 ns (0 allocations: 0 bytes)
0.010132227405239502

julia> @btime callall(+, $x5...)  # cheaper function
  47.697 ns (0 allocations: 0 bytes)
2.321889303862584

Maybe also worth noting that N real numbers for N arguments give only one possible kind of weighting: the one where the probability factorises, i.e. the arguments are treated as independent.
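Written out (the notation here is my own, not from the thread), with $p_i$ the probability that argument $i$ is true, the factorised weighting that `callall` computes is:

```latex
\mathbb{E}[f] \;=\; \sum_{b \in \{0,1\}^N} f(b_1, \dots, b_N)
    \prod_{i=1}^{N} p_i^{\,b_i} \, (1 - p_i)^{\,1 - b_i}
```

A general joint distribution over the $2^N$ combinations would instead need $2^N - 1$ independent weights, one per combination minus the normalisation constraint.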

Highly forgivable when working with binary numbers, given Julia’s 1-based indexing!

Perhaps, but I hadn’t thought to apply it in this way.

There is nothing wrong with abstractly typed function arguments. In Julia a function has several methods, e.g.,

julia> foo(x::Real) = 2 * x
foo (generic function with 1 method)

julia> foo(x::String) = "2" * x
foo (generic function with 2 methods)

julia> meths = methods(foo)
# 2 methods for generic function "foo":
[1] foo(x::Real) in Main at REPL[1]:1
[2] foo(x::String) in Main at REPL[2]:1

Here the types are used for dispatch: if foo(1.0) is called, the first method is selected. This method then gets specialized for the concrete argument type of the call, here Float64:

julia> meths[1].specializations
svec()

julia> foo(1.0)
2.0

julia> meths[1].specializations
svec(MethodInstance for foo(::Float64), nothing, nothing, nothing, nothing, nothing, nothing, nothing)

These specializations hold the compiled code of a method, one for each combination of concrete argument types it was called with. This is what makes Julia fast, and it is the mechanism behind the function-barrier technique described in the performance tips.
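As a rough illustration of a function barrier (the names `kernel` and `outer` are made up for this sketch): the outer function holds a value whose concrete type depends on a runtime flag, and hands it to an inner function that is compiled separately for each concrete type it receives.

```julia
# Inner "kernel": compiled once per concrete vector type it is called with.
kernel(v::AbstractVector) = sum(x -> 2 * x, v)

function outer(use_floats::Bool)
    # v is a Vector{Float64} or a Vector{Int}, depending on a runtime value.
    v = use_floats ? [1.0, 2.0] : [1, 2]
    # Inside kernel, the element type is known to the compiler.
    return kernel(v)
end

outer(true)   # 6.0
outer(false)  # 6
```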

The problem arises when the compiler cannot infer the concrete type of a variable, which commonly happens when structs have untyped or abstractly typed fields. In this case, the dispatch logic has to be included in the compiled code in order to select the appropriate specialization at runtime:

julia> function known()
           args = (1.0, 1)
           foo.(args)
       end
known (generic function with 1 method)

# Here the compiler can infer the types of the call and select appropriate specialization beforehand
julia> @code_warntype known()
MethodInstance for known()
  from known() in Main at REPL[18]:1
Arguments
  #self#::Core.Const(known)
Locals
  args::Tuple{Float64, Int64}
Body::Tuple{Float64, Int64}
1 ─      (args = Core.tuple(1.0, 1))
│   %2 = Base.broadcasted(Main.foo, args::Core.Const((1.0, 1)))::Core.Const(Base.Broadcast.Broadcasted(foo, ((1.0, 1),)))
│   %3 = Base.materialize(%2)::Core.Const((2.0, 2))
└──      return %3

julia> function unknown()
           args = Any[1.0, "3"]
           foo.(args)
       end
unknown (generic function with 1 method)

# Here, the dispatch needs to be done at runtime
julia> @code_warntype unknown()
MethodInstance for unknown()
  from unknown() in Main at REPL[20]:1
Arguments
  #self#::Core.Const(unknown)
Locals
  args::Vector{Any}
Body::AbstractVector
1 ─      (args = Base.getindex(Main.Any, 1.0, "3"))
│   %2 = Base.broadcasted(Main.foo, args)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(foo), Tuple{Vector{Any}}}
│   %3 = Base.materialize(%2)::AbstractVector
└──      return %3