We already have `Base.Fix1` and `Base.Fix2` for partial function application. It seems that it would be nice, and probably often useful, to define a similar facility for the general case, available upon request. Is this something that would be considered for inclusion in `Base`?
I mocked up a quick prototype for myself (see the end of the post). It generates efficient code, at least for the cases where one or two arguments are fixed:
julia> f = bind(+, (Val(1), Val(2)), 1, 2)
(::Main.Bindings.Bind{1,Main.Bindings.Bind{1,typeof(+),Int64},Int64}) (generic function with 1 method)
julia> @code_llvm f(3, 4)
; @ /Users/mcbane1/Desktop/Bindings.jl:15 within `Bind'
define i64 @julia_Bind_15981({ { i64 }, i64 } addrspace(11)* nocapture nonnull readonly dereferenceable(16), i64, i64) {
top:
; ┌ @ Base.jl:20 within `getproperty'
%3 = getelementptr inbounds { { i64 }, i64 }, { { i64 }, i64 } addrspace(11)* %0, i64 0, i32 1
; └
; @ /Users/mcbane1/Desktop/Bindings.jl:15 within `Bind' @ /Users/mcbane1/Desktop/Bindings.jl:15
; ┌ @ Base.jl:20 within `getproperty'
%4 = getelementptr inbounds { { i64 }, i64 }, { { i64 }, i64 } addrspace(11)* %0, i64 0, i32 0, i32 0
; └
; ┌ @ operators.jl:529 within `+' @ int.jl:53
%5 = load i64, i64 addrspace(11)* %4, align 8
%6 = load i64, i64 addrspace(11)* %3, align 8
%7 = add i64 %2, %1
%8 = add i64 %7, %5
; │ @ operators.jl:529 within `+'
; │┌ @ operators.jl:516 within `afoldl'
; ││┌ @ int.jl:53 within `+'
%9 = add i64 %8, %6
; └└└
; @ /Users/mcbane1/Desktop/Bindings.jl:15 within `Bind'
ret i64 %9
}
I have also included the ability to bind an argument to a `Val` when its value is known at compile time; this results in the desired optimization as well:
julia> f = bind(+, (Val(1), Val(2)), Val(1), Val(2))
(::Main.Bindings.Bind{1,Main.Bindings.Bind{1,typeof(+),Val{1}},Val{2}}) (generic function with 1 method)
julia> f(3, 4)
10
julia> @code_llvm f(3, 4)
; @ /Users/mcbane1/Desktop/Bindings.jl:16 within `Bind'
define i64 @julia_Bind_15994(i64, i64) {
top:
; @ /Users/mcbane1/Desktop/Bindings.jl:16 within `Bind' @ /Users/mcbane1/Desktop/Bindings.jl:16
; ┌ @ operators.jl:529 within `+' @ int.jl:53
%2 = add i64 %0, 3
; │ @ operators.jl:529 within `+'
; │┌ @ operators.jl:516 within `afoldl'
; ││┌ @ int.jl:53 within `+'
%3 = add i64 %2, %1
; └└└
; @ /Users/mcbane1/Desktop/Bindings.jl:16 within `Bind'
ret i64 %3
}
It seems to me that, given how few lines of code are needed to implement this feature, there is no reason it should not be included. Below is the full implementation.
module Bindings

"""
    Bind{N}(f::Function, x)

Callable wrapper around `f` with its `N`-th positional argument fixed to `x`.

Calling `b(xs...)` forwards to `b.f` with `b.x` spliced into `xs` at position `N`.
If `x` is a `Val{X}` instance, the wrapped value `X` is what gets passed to `f`,
not the `Val` object itself.

# Throws
- `TypeError` if `N` is not an `Int`.
- `ArgumentError` if `N < 1`.
"""
struct Bind{N, F, T} <: Function
    f::F
    x::T
    function Bind{N}(f::Function, x) where N
        if !(N isa Int)
            # `TypeError` takes the offending value itself (not its type) as the
            # last argument; passing `typeof(N)` misreports what was received.
            throw(TypeError(:Bind, "Bind{N}(f, x)", Int, N))
        end
        if N < 1
            # Positions are 1-based; anything lower could never be spliced in and
            # would only surface later as a BoundsError at call time.
            throw(ArgumentError("In Bind: binding index must be at least 1, got $N"))
        end
        return new{N, typeof(f), typeof(x)}(f, x)
    end
end

# Generic call: insert the stored value at position N of the call arguments.
(b::Bind{N})(xs...) where N = b.f(arg_list_(b.x, xs, Val(N))...)

# `Val`-bound call: the value lives in the type parameter `X`, so it is taken
# from there rather than from the (empty) `Val` instance stored in `b.x`.
(b::Bind{N, F, Val{X}})(xs...) where {N, F, X} = b.f(arg_list_(X, xs, Val(N))...)

"""
    arg_list_(x, xs::Tuple, ::Val{N})

Return a tuple equal to `xs` with `x` inserted so that it ends up at position `N`.

Works recursively: peels the first element off `xs` and decrements `N` until
`N == 1`, at which point `x` is placed in front of the remaining elements.
Indexing `xs[1]` throws a `BoundsError` if `xs` runs out before `N` reaches 1
(i.e. `xs` has fewer than `N - 1` elements).
"""
function arg_list_(x, xs::Tuple, ::Val{N}) where N
    if N == 1
        return (x, xs...)
    else
        return (xs[1], arg_list_(x, Base.tail(xs), Val(N - 1))...)
    end
end

"""
    bind(f::Function, ::Val{N}, x)
    bind(f::Function, t::Tuple{Vararg{Val}}, xs...)

Partially apply `f`.

The first form fixes the `N`-th positional argument of `f` to `x`. The second
form fixes several arguments at once: `t` holds the positions (each wrapped in
`Val`, in strictly ascending order) and `xs` the matching values. An empty `t`
(with no `xs`) binds nothing and returns `f` itself.

# Throws
- `ArgumentError` if `length(t) != length(xs)`, if the positions are not in
  ascending order, or if a position is repeated.
"""
bind(f::Function, ::Val{N}, x) where N = Bind{N}(f, x)

function bind(f::Function, t::Tuple{Vararg{Val}}, xs...)
    if length(t) != length(xs)
        throw(ArgumentError("In bind: specified $(length(xs)) arguments to bind but $(length(t)) binding indices"))
    end
    # Nothing to bind: hand back the function unchanged instead of indexing
    # into an empty tuple below.
    if isempty(t)
        return f
    end
    # Early return for the trivial case.
    if length(t) == 1
        return bind(f, t[1], xs[1])
    end
    if !issorted(map(unwrap_val_, t))
        throw(ArgumentError("In bind: binding indices are required to be specified in ascending order"))
    elseif !allunique(t)
        throw(ArgumentError("In bind: binding indices must be unique"))
    end
    # Bind the first value.
    g = bind(f, t[1], xs[1])
    # `g` takes one argument fewer than `f`, so every remaining position (all of
    # which lie after the one just bound) shifts down by one.
    Ns = map(subtract_one_, Base.tail(t))
    return bind(g, Ns, Base.tail(xs)...)
end

# Extract the value carried in a `Val`'s type parameter.
unwrap_val_(::Val{X}) where X = X

# Return a `Val` whose parameter is one less than the input's.
subtract_one_(::Val{X}) where X = Val(X - 1)

end # module