Any thoughts? I think it’s backwards compatible with the current Fix. Should I submit a PR?
```julia
const _stable_typeof = Base._stable_typeof # needed when running this outside of Base

struct Fix{N,F,T} <: Function
    f::F
    x::T
    function Fix{N}(f::F, x...) where {N, F}
        # normalise the fixed position(s) to a tuple of Ints
        if N isa Int
            fix = (N,)
        elseif N isa Tuple{Vararg{Int}}
            fix = N
        else
            throw(ArgumentError(LazyString("expected type parameter in Fix to be `Int` or a tuple of `Int`s, but got `", N, "::", typeof(N), "`")))
        end
        any(<(1), fix) && throw(ArgumentError(LazyString("expected type parameter in Fix to be integers greater than 0, but got ", N)))
        length(fix) == length(x) || throw(ArgumentError(LazyString("type parameter in Fix specifies $(length(fix)) fixed arguments $N, but got $(length(x)): ", x)))
        new{fix, _stable_typeof(f), _stable_typeof(x)}(f, x)
    end
end

@generated function (f::Fix{N,F,T})(args...; kws...) where {N,F,T}
    callexpr = :(f.f(; kws...))
    allargs = callexpr.args
    offset = length(allargs) # function name and keyword parameters
    # make room for all args (free + fixed)
    resize!(allargs, offset + length(args) + length(N))
    for i in offset+1:length(allargs)
        allargs[i] = :undef
    end
    # fill in the fixed args at their positions
    for n in 1:length(N)
        allargs[N[n]+offset] = :(f.x[$n])
    end
    # and fill the remaining slots from args, in order
    nextarg = 1
    for i in offset+1:length(allargs)
        if allargs[i] === :undef
            allargs[i] = :(args[$nextarg])
            nextarg += 1
        end
    end
    return callexpr
end
```
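To make the generated call operator easier to follow, here is a plain runtime sketch of the same slot filling. The helper name `interleave` is hypothetical, and unlike the generated version it builds a `Vector{Any}`, so it is not type-stable; it only illustrates the ordering logic.

```julia
# Hypothetical runtime version of the slot filling done by the generated call operator:
# fixed values go to their positions, the remaining slots are filled from `args` in order.
function interleave(positions::NTuple{K,Int}, fixed::NTuple{K,Any}, args::Tuple) where {K}
    n = K + length(args)
    out = Vector{Any}(undef, n)
    filled = falses(n)
    for (p, v) in zip(positions, fixed)
        out[p] = v
        filled[p] = true
    end
    next = 1
    for i in 1:n
        if !filled[i]
            out[i] = args[next]
            next += 1
        end
    end
    return Tuple(out)
end

interleave((1, 4), (:a, :d), (:b, :c, :e))  # → (:a, :b, :c, :d, :e)
```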
I’m wondering if it would make sense to order the fixed arguments and to support iterated fixing of arguments. The following would then be equal (and of the same type):
Ordering the fixed arguments is reasonably easy, if only there were a sortperm(::Tuple). Supporting iterated fixing might be somewhat more involved, but is probably a good idea.
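As a stopgap, Base's `sortperm` on a `Vector` can be used to normalise the positions. This is a minimal sketch, not type-stable, and `normalize_fixed` is a hypothetical name; the idea is that `Fix{(4,1)}(f, d, a)` and `Fix{(1,4)}(f, a, d)` would both reduce to the same sorted representation.

```julia
# Hypothetical helper: reorder fixed positions (and their values) so that the
# positions are increasing, using Base.sortperm on a Vector.
function normalize_fixed(positions::NTuple{K,Int}, values::NTuple{K,Any}) where {K}
    perm = Tuple(sortperm(collect(positions)))
    return map(i -> positions[i], perm), map(i -> values[i], perm)
end

normalize_fixed((4, 1), (:d, :a))  # → ((1, 4), (:a, :d))
```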
In general, the proposal “smells bad” to me, because it complicates the implementation for no clear benefit. Furthermore, as already discussed extensively on the original Fix PR, something like this should go into a package first, instead of being forced into Base prematurely.
Ah, I hadn’t thought about that. That’s a good point. For arguments that are types, like Int, we have Int isa Type{Int}, so the T in Fix can be Type{Int} (as well as the non-specific DataType). But this relation does not extend to tuples: (Int,) isa Tuple{Type{Int}} is false, since typeof((Int,)) === Tuple{DataType}. So the T in Fix has to be Tuple{DataType}, which is non-specific and probably not performant.
So one would have to use single-argument fixing for types.
If x is a NamedTuple whose first field is declared as Type{Int}, then typeof(x[1]) === DataType, but fieldtype(typeof(x), 1) === Type{Int}, so the sharper type information in the NamedTuple is used for inference. At least, it isn’t flagged by code_warntype or JET, and it doesn’t seem to allocate on use. So by letting the T in Fix be a NamedTuple (with dummy names), good performance can be achieved for type arguments as well.
In short, we have:
```julia
T = Tuple{Type{Int}, Type{String}}
NT = @NamedTuple{a::Type{Int}, b::Type{String}}
x = (Int, String)

julia> NT(x)::NT
@NamedTuple{a::Type{Int64}, b::Type{String}}((Int64, String))

julia> T(x)::T
ERROR: TypeError: in typeassert, expected Tuple{Type{Int64}, Type{String}}, got a value of type Tuple{DataType, DataType}
```
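The same comparison as a small script, using only Base (`@NamedTuple` needs Julia 1.5 or later). It shows where the sharper type lives: in the declared field type of the NamedTuple, not in the runtime type of the stored value.

```julia
# A plain Tuple records the runtime types of type-valued elements (DataType),
# while a NamedTuple with declared field types keeps Type{Int} etc.
T  = Tuple{Type{Int}, Type{String}}
NT = @NamedTuple{a::Type{Int}, b::Type{String}}
x  = (Int, String)

nt = NT(x)
typeof(x)                  # Tuple{DataType, DataType}
x isa T                    # false
typeof(nt[1])              # DataType: the runtime type of the value Int
fieldtype(typeof(nt), 1)   # Type{Int64}: the declared field type, visible to inference
```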
Maybe I’m misreading the proposal, but I thought it was about adding a constructor to Fix, with the result still being a nested Fix type. So Fix{(3,1)}(f, a, b) => Fix{1}(Fix{3}(f, a), b). That way it wouldn’t break Fix2; it would just make it easier to fix multiple arguments at once.
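A sketch of that desugaring, assuming a single-position fixer in the spirit of `Fix{N}`. To stay self-contained it defines a minimal stand-in `MyFix` (hypothetical name, not Base's type), and `nested_fix` is likewise illustrative. Positions are fixed from highest to lowest, so each remaining position still refers to the same argument of the original `f`.

```julia
# Hypothetical minimal single-position fixer (a stand-in for a Fix{N}-like type).
struct MyFix{N,F,T} <: Function
    f::F
    x::T
end
MyFix{N}(f::F, x::T) where {N,F,T} = MyFix{N,F,T}(f, x)

# Insert the fixed value at position N among the call arguments.
(g::MyFix{N})(args...) where {N} = g.f(args[1:N-1]..., g.x, args[N:end]...)

# Multi-position constructor that desugars into nested single fixes.
function nested_fix(positions::NTuple{K,Int}, f, xs::Vararg{Any,K}) where {K}
    order = sortperm(collect(positions); rev = true)  # highest position first
    g = f
    for i in order
        g = MyFix{positions[i]}(g, xs[i])
    end
    return g
end

g = nested_fix((3, 1), tuple, :a, :b)  # MyFix{1}(MyFix{3}(tuple, :a), :b)
g(:c)                                  # → (:b, :c, :a)
```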
In my opinion, notation like f = Fix{(1,4,5,2)}(fun, 1, 4, 5, 2) is really hard to read in code you did not write yourself. You have to stop and think for a second to understand what is actually called under the hood when you invoke f(x). And the more arguments you fix, the more complicated it gets.
I can see several alternatives here:
Create a macro for this exact purpose: @fix fun(1, 2, _, 4, 5, _) or @fix(fun, 1, 2, _, 5, 6, _). I think Underscores.jl already has something similar.
Allow a single-argument, chainable Fix constructor: fun |> Fix{5}(5) |> Fix{4}(4) |> Fix{2}(2) |> Fix{1}(1). Note that the order is important here.
Write a lambda: (u, v) -> fun(1, 2, u, 4, 5, v). There should be no performance difference, and the readability is much better.
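To illustrate why the order matters in the chained variant, here is a sketch that uses a hypothetical curried helper `fixat(N, x)` built from plain closures (not Base's Fix, and not a proposed API). Fixing the highest position first keeps the remaining positions aligned with `fun`'s own arguments; the lambda gives the same result.

```julia
# Hypothetical curried helper: fixat(N, x) takes a function and returns a new one
# with argument position N fixed to x (plain closures, not Base.Fix).
fixat(N::Int, x) = f -> (args...; kws...) -> f(args[1:N-1]..., x, args[N:end]...; kws...)

fun(a1, a2, a3, a4, a5, a6) = (a1, a2, a3, a4, a5, a6)

# Fix the highest position first, so the remaining positions still refer to fun's arguments.
g = fun |> fixat(5, 5) |> fixat(4, 4) |> fixat(2, 2) |> fixat(1, 1)
g(:u, :v)                                  # → (1, 2, :u, 4, 5, :v)

# The lambda alternative:
((u, v) -> fun(1, 2, u, 4, 5, v))(:u, :v)  # → (1, 2, :u, 4, 5, :v)
```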
I agree that the macro solution @fix is very readable. It would be very easy to build it on top of FixArg:
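```julia
using MacroTools  # provides @capture

macro fix(call)
    # split `fun(a, _, c)` into the function and its argument list
    @capture(call, f_(xs__)) || throw(ArgumentError("@fix expects a call expression, e.g. @fix fun(1, _, 3)"))
    fixpos = findall(!=(:_), xs)   # positions of the non-placeholder arguments
    fixargs = (xs[fixpos]...,)
    quote
        FixArg{$((fixpos...,))}($(esc(f)), $(esc.(fixargs)...))
    end
end
```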
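The position bookkeeping inside the macro can be tried with Base alone, without MacroTools or FixArg. This sketch only builds the expression the macro would produce; `FixArg` is never evaluated here.

```julia
# Base-only illustration of the bookkeeping in @fix.
ex = :(fun(10, 20, _, 40, 50, _))
xs = ex.args[2:end]                  # call arguments, without the function name
fixpos = findall(!=(:_), xs)         # positions that are not `_` placeholders
fixargs = (xs[fixpos]...,)

# The expression the macro would build (not evaluated):
expanded = :(FixArg{$((fixpos...,))}(fun, $(fixargs...)))
fixpos    # [1, 2, 4, 5]
fixargs   # (10, 20, 40, 50)
```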
The lambda variant will capture the fixed values if they aren’t literals, with all the problems this has in practice.
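A small illustration of the capture problem, based on the `abmult` example from the "Performance of captured variable" section of the Julia manual's Performance Tips: because `r` is reassigned, the closure stores it in a `Core.Box`. The `abmult_fix` variant (an illustrative name) stores the same value in `Base.Fix2`, whose field has the concrete type `Int`.

```julia
# Closure version: `r` is reassigned before capture, so lowering boxes it.
function abmult(r::Int)
    if r < 0
        r = -r
    end
    f = x -> x * r
    return f
end

# Fix-style version: `r` ends up in a concretely typed field of Base.Fix2.
function abmult_fix(r::Int)
    if r < 0
        r = -r
    end
    return Base.Fix2(*, r)
end

abmult(-3)(2)                          # → 6
abmult_fix(-3)(2)                      # → 6
fieldtype(typeof(abmult_fix(3)), :x)   # Int64
```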
But I think the ordering and combining suggested by @matthias314 really are good for performance, since they do not rely on the inner function being inlined. This works fine for type arguments, too.
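The core of combining is the position arithmetic: positions fixed by an outer Fix refer to the *remaining* ("hole") arguments of the inner one. A minimal sketch of just that step, with a hypothetical helper name and assuming both position tuples are non-empty:

```julia
# Hypothetical helper: merge outer fixed positions into inner ones.
# `inner` are positions already fixed in `f`; `outer` are positions among the
# remaining arguments of the partially applied function.
function merge_positions(inner::Tuple{Vararg{Int}}, outer::Tuple{Vararg{Int}})
    upper = maximum(inner) + maximum(outer)     # guarantees enough free slots
    holes = [i for i in 1:upper if i ∉ inner]   # free positions of f
    mapped = map(i -> holes[i], outer)          # outer positions → positions of f
    return Tuple(sort!([inner..., mapped...]))
end

merge_positions((3,), (1,))      # → (1, 3), matching Fix{1}(Fix{3}(f, a), b)
merge_positions((2, 4), (1, 3))  # → (1, 2, 4, 5)
```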
An advantage of something like Fix{(1,4,5,2)} is that you can dispatch on such a type, which would be lost with anonymous functions. To be really useful, this should be combined with ordering the arguments (hence not (1,4,5,2), but (1,2,4,5)) and with merging iterated Fix applications into one. I admit that for four fixed arguments this is contrived, but two fixed arguments would be more realistic.
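The dispatch argument can already be seen with today's `Base.Fix2`: `==(3)` returns a `Base.Fix2{typeof(==), Int}`, so a method can target it, whereas an anonymous function with the same behaviour falls through to the generic method. The function `describe` is just an illustrative name.

```julia
# Dispatching on the type of a partially applied function.
describe(p::Base.Fix2{typeof(==)}) = "equality test against $(p.x)"
describe(p::Function) = "some other predicate"

describe(==(3))        # "equality test against 3"
describe(x -> x == 3)  # "some other predicate"
```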
I’ll see if I can polish it and make it into a package:
```julia
using MacroTools
const _stable_typeof = Base._stable_typeof

struct FixArg{N,F,T} <: Function
    f::F
    x::T
    function FixArg{N,F,T}(f, x) where {N,F,T}
        if N isa Tuple
            (all(t -> isa(t, Int), N) && issorted(N) && all(≥(1), N)) ||
                throw(ArgumentError(LazyString("expected type parameter in FixArg to be `Int` or an ordered tuple of `Int`s, but got `", N, "::", typeof(N), "`")))
            length(N) == length(x) || throw(ArgumentError(LazyString("type parameter in FixArg specifies $(length(N)) fixed arguments $N, but got $(length(x)): ", x)))
        else
            (N isa Int && N ≥ 1) || throw(ArgumentError(LazyString("expected type parameter in FixArg to be `Int` or an ordered tuple of `Int`s, greater than zero, but got `", N, "::", typeof(N), "`")))
        end
        # when T is a NamedTuple but a plain Tuple was passed, build the NamedTuple
        # explicitly instead of relying on `new` to convert it
        xval = (T <: NamedTuple && x isa Tuple) ? T(x) : x
        new{N,F,T}(f, xval)
    end
end

# zip several equal-length tuples into a tuple of tuples
_ziptuples(tup::Tuple...) = ntuple(Val(length(tup[1]))) do i
    map(t -> t[i], tup)
end

# sortperm for tuples (relies on sort(::Tuple))
function _sortperm(tup::Tuple)
    seq = ntuple(identity, Val(length(tup)))
    s = sort(_ziptuples(tup, seq))
    last.(s)
end

function FixArg{N}(f, x...) where {N}
    if N isa Tuple
        # store the fixed arguments ordered by position
        order = _sortperm(N)
        sortedN = ntuple(i -> N[order[i]], Val(length(N)))
        ox = ntuple(i -> x[order[i]], Val(length(N)))
        types = ntuple(Val(length(N))) do i
            _stable_typeof(x[order[i]])
        end
        # a NamedTuple with dummy names keeps sharp field types such as Type{Int}
        nm = ntuple(i -> Symbol('f', i), Val(length(types)))
        nt = NamedTuple{nm, Tuple{types...}}
        FixArg{sortedN, _stable_typeof(f), nt}(f, ox)
    else
        x1 = x[1]
        FixArg{N, _stable_typeof(f), _stable_typeof(x1)}(f, x1)
    end
end

# merge two sorted tuples; ord1/ord2 supply the sort keys
function _mergesortedtuples(tup1, tup2, ord1=tup1, ord2=tup2)
    idx1 = Ref(0)
    idx2 = Ref(0)
    ntuple(Val(length(tup1) + length(tup2))) do i
        if idx2[] >= length(tup2)   # tup2 exhausted
            return tup1[idx1[] += 1]
        end
        if idx1[] >= length(tup1)   # tup1 exhausted
            return tup2[idx2[] += 1]
        end
        b = ord2[idx2[]+1] < ord1[idx1[]+1]
        if b
            return tup2[idx2[] += 1]
        else
            return tup1[idx1[] += 1]
        end
    end
end

# Fixing arguments in a FixArg function merges the two into a single FixArg
@generated function FixArg{N}(f::FixArg{innerN, innerF, innerT}, outerx...) where {N, innerN, innerF, innerT}
    fix = isa(N, Int) ? (N,) : N
    # field types of the inner fixed values, as a Tuple type
    innerTT = innerN isa Int ? Tuple{innerT} : innerT
    length(outerx) == length(fix) || throw(ArgumentError(LazyString("number of arguments must equal number of positions")))
    seq = ntuple(identity, Val(length(fix)))
    _sorted = sort(_ziptuples(fix, outerx, seq))
    outerN = first.(_sorted)
    outertypes = map(x -> x[2], _sorted)
    order = last.(_sorted)
    # outer positions refer to the free ("hole") positions of the inner FixArg
    allseq = ntuple(identity, Val(maximum(innerN) + maximum(outerN)))
    holes = filter(∉(innerN), allseq)
    newouter = map(i -> holes[i], outerN)
    # now, merge innerN and newouter, and their types
    newN = _mergesortedtuples(newouter, innerN)
    types = _mergesortedtuples(outertypes, fieldtypes(innerTT), newouter, innerN)
    nm = ntuple(i -> Symbol('f', i), Val(length(types)))
    newtype = NamedTuple{nm, Tuple{types...}}
    # then create a call to the inner constructor, and fill it with the right arguments
    callexpr = :(FixArg{$newN, innerF, $newtype}(f.f, ()))
    args = callexpr.args[3].args
    resize!(args, length(newN))
    innerarg = 0
    outerarg = 0
    for i in 1:length(newN)
        if newN[i] in innerN
            if innerN isa Int
                args[i] = :(f.x)
            else
                args[i] = :(f.x[$(innerarg += 1)])
            end
        else
            args[i] = :(outerx[$(order[outerarg += 1])])
        end
    end
    return callexpr
end

@generated function (f::FixArg{N})(args...; kws...) where {N}
    callexpr = :(f.f(; kws...))
    allargs = callexpr.args
    offset = length(allargs) # for function name and parameters
    # make room for all args
    resize!(allargs, offset + length(args) + length(N))
    for i in offset+1:length(allargs)
        allargs[i] = :undef
    end
    # fill in the fixed args
    if N isa Int
        allargs[N+offset] = :(f.x)
    else
        for n in 1:length(N)
            allargs[N[n]+offset] = :(f.x[$n])
        end
    end
    # and the others from args
    nextarg = 1
    for i in offset+1:length(allargs)
        if allargs[i] === :undef
            allargs[i] = :(args[$nextarg])
            nextarg += 1
        end
    end
    return callexpr
end

# fast paths for the common single-argument cases
(f::FixArg{1})(arg; kws...) = f.f(f.x, arg; kws...)
(f::FixArg{2})(arg; kws...) = f.f(arg, f.x; kws...)

const FixArg1{F,T} = FixArg{1,F,T}
const FixArg2{F,T} = FixArg{2,F,T}

macro fix(call)
    @capture(call, f_(xs__)) || throw(ArgumentError("@fix expects a call expression, e.g. @fix fun(1, _, 3)"))
    fixpos = findall(!=(:_), xs)
    fixargs = (xs[fixpos]...,)
    quote
        FixArg{$((fixpos...,))}($(esc(f)), $(esc.(fixargs)...))
    end
end
```
If time allows, try comparing against pre-existing attempts before registering, to avoid cluttering the General registry's package namespace. For example: