Hey everyone,
Apologies for the vague title. Part of my problem is that I don’t know the right terms for what I’m doing, so I can’t search for it. Anyway:
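I have a function of N boolean arguments. I would like to be able to call it with floats between 0 and 1 instead. Each float is the probability that the corresponding argument is true (independently of the others), and the result should be the expected value of the function: its values at every combination of true/false, each weighted by the probability of that combination.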
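Written out, with $p_i$ the float passed as argument $i$ and the arguments treated as independent, this appears to be what is usually called the *multilinear extension* of a Boolean function $f$:

$$
\tilde f(p_1,\dots,p_N) = \sum_{b \in \{0,1\}^N} f(b_1,\dots,b_N) \prod_{i=1}^{N} \bigl(b_i\,p_i + (1-b_i)(1-p_i)\bigr)
$$

For $N = 1$ this reduces to $p_1 f(1) + (1-p_1) f(0)$, which is exactly what the `CallableInt{1}` method computes.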
For example:
```julia
struct CallableInt{N}
    int::Int64
end

function read_bit(x, pos)
    pos < 0 && throw(ArgumentError("Bit position $pos must be >=0"))
    return (x >> pos) & 1 != 0
end

# Bool/Integer arguments: evalpoly(2, xs) turns the tuple into an index with
# xs[1] as the least significant bit, then reads that bit of `int`.
(f::CallableInt{N})(xs::Vararg{I,N}) where {N,I<:Integer} = read_bit(f.int, evalpoly(2, xs))

# Handwritten expectations for N = 1 and N = 2.
(f::CallableInt{1})(foo) = foo*f(true) + (1-foo) * f(false)

function (f::CallableInt{2})(foo, bar)
    foo * (
        bar * f(true, true) + (1-bar) * f(true, false)
    ) +
    (1-foo) * (
        bar * f(false, true) + (1-bar) * f(false, false)
    )
end
```
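Both handwritten methods follow the same pattern: split on the first argument, weight the `true` branch by its float and the `false` branch by one minus it, then repeat for the remaining arguments. Here is a minimal sketch of that recursion for any N. It assumes the Boolean function is passed in as an ordinary callable `g`; `expectation` and `_expect` are hypothetical names, not from any package.

```julia
# Base case: every argument has been fixed to true/false, so evaluate g there.
_expect(g, ::Tuple{}, bits::Tuple) = g(bits...)

# Recursive case: branch on the first remaining probability p, weighting the
# `true` branch by p and the `false` branch by 1 - p (hypothetical helper).
function _expect(g, ps::Tuple, bits::Tuple)
    p = first(ps)
    rest = Base.tail(ps)
    return p * _expect(g, rest, (bits..., true)) +
           (1 - p) * _expect(g, rest, (bits..., false))
end

# Hypothetical entry point: expected value of g over N independent Bools,
# where ps[i] is the probability that argument i is true.
expectation(g, ps::Vararg{Float64,N}) where {N} = _expect(g, ps, ())
```

For example, `expectation(xor, 0.3, 0.5)` should be ≈ 0.5. A `CallableInt{N}` should also work as `g`, since it can be called with N Bools. Julia is not guaranteed to fully unroll this recursion for large N. `@code_warntype expectation(xor, 0.3, 0.5)` is a quick way to check whether the result type is inferred.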
The handwritten versions above are very fast, and I’m not trying to optimize them. Instead, I’m trying to write a single function that does the same thing for any N and is as fast as a handwritten version.
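A different, non-recursive sketch loops over the 2^N integers `k` and reads argument i's value from bit i - 1 of `k`. This is the same least-significant-bit-first ordering that `evalpoly(2, xs)` uses above. `ntuple(..., Val(N))` keeps every tuple length known to the compiler. `expectation_loop` and the callable argument `g` are hypothetical, not part of the original code.

```julia
function expectation_loop(g, ps::Vararg{Float64,N}) where {N}
    all(p -> 0 <= p <= 1, ps) ||
        throw(ArgumentError("All arguments must be between 0 and 1"))
    out = 0.0
    for k in 0:(2^N - 1)
        # Argument i is true exactly when bit i - 1 of k is set.
        bits = ntuple(i -> isodd(k >> (i - 1)), Val(N))
        # Probability of this particular combination of Bools.
        w = prod(ntuple(i -> bits[i] ? ps[i] : 1 - ps[i], Val(N)))
        out += g(bits...) * w
    end
    return out
end
```

This version does O(N·2^N) multiplications, compared with O(2^N) for the recursive form. The trade-off is that it contains no recursion for the compiler to unroll.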
My attempt at doing so falls short:
```julia
function (f::CallableInt{N})(args::Vararg{Float64,N}) where {N}
    out = 0.0
    all(0 .<= args .<= 1) || throw(ArgumentError("All arguments must be between 0 and 1"))
    rate_combinations = NTuple{N,NTuple{2,Float64}}((arg, 1 - arg) for arg in args)
    boolean_combinations = NTuple{N,NTuple{2,Bool}}((true, false) for _ in args)
    for (boolean_combination, rate_combination) in zip(
        Iterators.product(boolean_combinations...),
        Iterators.product(rate_combinations...)
    )
        out += f(boolean_combination...) * reduce(*, rate_combination)
    end
    return out
end
```
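One plausible (unverified) contributor to the slowdown is building the tuples with `NTuple{N,...}(generator)`. The compiler may not infer these as concretely as `map` over a tuple or `ntuple(f, Val(N))`. Below is a minimal-change sketch that keeps the `Iterators.product` structure but builds the tuples those ways. It takes the Boolean function as an argument `g`, and the name `expectation_product` is hypothetical.

```julia
function expectation_product(g, args::Vararg{Float64,N}) where {N}
    all(a -> 0 <= a <= 1, args) ||
        throw(ArgumentError("All arguments must be between 0 and 1"))
    # map over a tuple returns a tuple whose length the compiler knows.
    rate_pairs = map(a -> (a, 1 - a), args)
    # Val(N) makes the length a compile-time constant.
    bool_pairs = ntuple(_ -> (true, false), Val(N))
    out = 0.0
    # Both products have the same shape, so position i of bs and rs line up:
    # bs[i] == true pairs with args[i], bs[i] == false with 1 - args[i].
    for (bs, rs) in zip(Iterators.product(bool_pairs...),
                        Iterators.product(rate_pairs...))
        out += g(bs...) * prod(rs)
    end
    return out
end
```

Timing this against the original with BenchmarkTools’ `@btime` (interpolating the arguments with `$`) would show whether tuple construction is actually the bottleneck.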
It returns the right value, but takes very long to execute, far longer than the handwritten versions. I was wondering whether code generation (for example a `@generated` function) is the right way to go, but I know it’s a last resort in most cases. I would appreciate any advice on speeding this function up.
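If code generation is the route, a `@generated` function can emit the same nested expression as the handwritten N = 1 and N = 2 methods. This is a sketch only. `_expansion_expr` and `expectation_gen` are hypothetical names, and the helper runs when the method is compiled rather than on each call. The emitted expression contains 2^N calls to `g`, so this only makes sense for small N.

```julia
# Hypothetical helper, run at code-generation time: returns an Expr for the
# weighted sum over arguments i..N, given the Bool values fixed for 1..i-1.
function _expansion_expr(i::Int, N::Int, bits::Vector{Any})
    i > N && return :(g($(bits...)))  # leaf: call g with all N Bools fixed
    hi = _expansion_expr(i + 1, N, Any[bits..., true])
    lo = _expansion_expr(i + 1, N, Any[bits..., false])
    return :(ps[$i] * $hi + (1 - ps[$i]) * $lo)
end

# The returned expression refers to the arguments `g` and `ps` by name.
@generated function expectation_gen(g, ps::Vararg{Float64,N}) where {N}
    return _expansion_expr(1, N, Any[])
end
```

For N = 1 the generated body is `ps[1] * g(true) + (1 - ps[1]) * g(false)`, which has the same shape as the `CallableInt{1}` method. The helper is an ordinary function because closures and comprehensions are not allowed in the body of a `@generated` function.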
Edit: fixed bug in type hints.