A slightly more comprehensible `Fix{N}`

I had the pleasure of fixing several arguments with Base.Fix. Nesting it as Fix{2}(Fix{5}(Fix{1}(...))) isn’t very …er… ergonomic.
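
With nested wrappers, each Fix{k} counts positions among the arguments that are still free, not among the original ones. A small sketch that replays this bookkeeping. The helper `absolute_positions` is hypothetical and not part of Base:

```julia
# Map the N of each nested Fix wrapper (innermost first) to the absolute
# argument position it ends up fixing. `absolute_positions` is a hypothetical helper.
function absolute_positions(relpos)
    fixed = Int[]
    for r in relpos
        pos, seen = 0, 0
        # walk positions 1, 2, 3, ... and stop at the r-th one that is still free
        while seen < r
            pos += 1
            pos in fixed || (seen += 1)
        end
        push!(fixed, pos)
    end
    return Tuple(fixed)
end

# Fix{2}(Fix{5}(Fix{1}(f, a), b), c) applies 1, then 5, then 2
absolute_positions((1, 5, 2))  # (1, 6, 3)
```

So the "5" in the middle wrapper actually fixes the sixth argument, and the outer "2" fixes the third.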

I wrote a slightly different Fix which can take a tuple of positions as a parameter, so one can do

fun(u,v,w,x,y,z) = "$u $v $w $x $y $z"
f = Fix{(1,4,5,2)}(fun, 1, 4, 5, 2)
f(-3, -6)
g = Fix{2}(f, -6)
g(-3)

Any thoughts? I think it’s backwards compatible with the current Fix. Should I submit a PR?

Code
const _stable_typeof = Base._stable_typeof  # throw in this if running outside of Base

struct Fix{N,F,T} <: Function
    f::F
    x::T

    function Fix{N}(f::F, x...) where {N, F}
        if N isa Int
            fix = (N,)
        elseif N isa NTuple{L, Int} where L
            fix = N
        else
            throw(ArgumentError(LazyString("expected type parameter in Fix to be `Int` or a tuple of `Ints`, but got `", N, "::", typeof(N), "`")))
        end

        any(<(1), N) && throw(ArgumentError(LazyString("expected type parameter in Fix to be integers greater than 0, but got ", N)))

        length(N) == length(x) || throw(ArgumentError(LazyString("type parameter in Fix specifies $(length(N)) fixed arguments $N, but got $(length(x)): ",x)))

        new{fix, _stable_typeof(f), _stable_typeof(x)}(f, x)
    end
end

@generated function (f::Fix{N,F,T})(args...; kws...) where {N,F,T}

    callexpr = :(f.f(; kws...))
    allargs = callexpr.args

    offset = length(allargs)  # for function name and parameters

    # make room for all args
    resize!(allargs, offset+length(args)+length(N))
    for i in offset+1:length(allargs)
        allargs[i] = :undef
    end

    # fill in the fixed args
    for n in 1:length(N)
        allargs[N[n]+offset] = :(f.x[$n])
    end

    # and the others from args
    nextarg = 1
    for i in offset+1:length(allargs)
        if allargs[i] === :undef
            allargs[i] = :(args[$nextarg])
            nextarg += 1
        end
    end

    return callexpr
end
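
The generator above only rearranges expressions, so its logic can be lifted into an ordinary function to inspect the call it builds. A sketch; `build_call` is a hypothetical name, and `nfree` stands in for `length(args)` inside the generator:

```julia
# Same steps as the @generated call operator, but as a plain function returning the Expr
function build_call(N::Tuple, nfree::Int)
    callexpr = :(f.f(; kws...))
    allargs = callexpr.args
    offset = length(allargs)               # slots for the callee and the keyword parameters
    resize!(allargs, offset + nfree + length(N))
    for i in offset+1:length(allargs)
        allargs[i] = :undef                # placeholder marking a not-yet-filled slot
    end
    for n in 1:length(N)
        allargs[N[n]+offset] = :(f.x[$n])  # fixed value n goes to absolute position N[n]
    end
    nextarg = 1
    for i in offset+1:length(allargs)      # remaining slots take the call's arguments in order
        if allargs[i] === :undef
            allargs[i] = :(args[$nextarg])
            nextarg += 1
        end
    end
    return callexpr
end

ex = build_call((1, 4, 5, 2), 2)
```

For Fix{(1,4,5,2)} called with two arguments, this corresponds to the call f.f(f.x[1], f.x[4], args[1], f.x[2], f.x[3], args[2]; kws...).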

I’m wondering if it would make sense to order the fixed arguments and to support iterated fixing of arguments. The following would then be equal (and of the same type):

Fix{(1,2)}(f, -1, -2)
Fix{(2,1)}(f, -2, -1)
Fix{1}(Fix{2}(f, -2), -1)
Fix{1}(Fix{1}(f, -1), -2)

This would make it more user-friendly to dispatch on such types.

Ordering the fixed arguments would be reasonably easy if only there were a sortperm(::Tuple). Supporting iterated fixing might be somewhat more involved, but it is probably a good idea.
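
A simple, though not type-stable, workaround is to go through sortperm(collect(N)). A sketch with a hypothetical `canonicalize` helper that sorts the positions and permutes the fixed values alongside them:

```julia
# Sort fixed positions and reorder the fixed values the same way.
# `canonicalize` is hypothetical; collect/sortperm allocate, so this is for
# illustration rather than for use inside a hot path.
function canonicalize(positions::Tuple, values::Tuple)
    length(positions) == length(values) || throw(ArgumentError("need one value per position"))
    allunique(positions) || throw(ArgumentError("positions must be distinct"))
    perm = sortperm(collect(positions))
    return Tuple(positions[i] for i in perm), Tuple(values[i] for i in perm)
end

canonicalize((2, 1), (-2, -1))  # ((1, 2), (-1, -2))
```

With this normalization, Fix{(1,2)}(f, -1, -2) and Fix{(2,1)}(f, -2, -1) would end up with the same type parameter and the same stored values.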

Intuitively, I imagine this could work recursively. There seems to be something like that in Curry.jl.

I’m not saying it’s necessarily a good idea but maybe something to consider.


There’s one thing I don’t understand in the Fix in Base. It calls Base._stable_typeof which is defined as

_stable_typeof(x) = typeof(x)
_stable_typeof(::Type{T}) where {T} = @isdefined(T) ? Type{T} : DataType

What does this function do? Or more specifically, for which inputs does it return DataType?

Some specific issues:

  • Currently Base.Fix2 === Base.Fix{2}. With your change you presumably want to have Base.Fix2 === Base.Fix{(2,)}, which would be breaking.

  • With your change T(args...) isa T would not hold for T === Base.Fix{1}, for example.
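
A minimal sketch of the second point, using a hypothetical `Toy` type whose constructor normalizes an Int parameter to a one-element tuple, the way the proposed Fix does with new{fix, ...}:

```julia
struct Toy{N}
    # Normalize Toy{k} to Toy{(k,)}, mirroring the proposal's inner constructor
    function Toy{N}() where {N}
        return new{N isa Int ? (N,) : N}()
    end
end

t = Toy{1}()
t isa Toy{1}     # false: the constructed object is a Toy{(1,)}
t isa Toy{(1,)}  # true
```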

Potential issue (not sure): are you sure that arguments which are types could be supported in a performant manner with your proposal? Keep in mind:

julia> (Int, String)
(Int64, String)

julia> typeof(ans)
Tuple{DataType, DataType}

In general, the proposal “smells bad” to me, because it complicates the implementation for no certain benefit. Furthermore, as extensively discussed already on the original Fix PR, something like this should go into a package first, instead of forcing it into Base prematurely.


The first method is for non-Types, the second method is for Types. The branch in the second method is for handling incomplete types.
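
A local copy of the two methods (renamed `stable_typeof` so the sketch doesn't reach into Base internals) shows what the second method buys for ordinary type arguments. The DataType branch is the fallback for the incomplete types mentioned above and is not exercised here; `Wrapper` and `wrap` are hypothetical names:

```julia
stable_typeof(x) = typeof(x)
stable_typeof(::Type{T}) where {T} = @isdefined(T) ? Type{T} : DataType

typeof(Int)         # DataType: shared by many types, says nothing about Int itself
stable_typeof(Int)  # Type{Int64}: remembers exactly which type was passed
stable_typeof(1.0)  # Float64: ordinary values are unaffected

# Why it matters: a wrapper whose field type comes from stable_typeof keeps the specific type
struct Wrapper{T}
    x::T
end
wrap(x) = Wrapper{stable_typeof(x)}(x)

fieldtype(typeof(wrap(Int)), :x)  # Type{Int64}
```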

Ah, I hadn’t thought about that. That’s a good point. For arguments that are types, like Int, we have Int isa Type{Int}, so the T in Fix can be Type{Int} (as well as the non-specific DataType), but this relation does not extend to tuples, i.e. we do not have (Int,) isa Tuple{Type{Int}}, so the T in Fix has to be Tuple{DataType}, which is non-specific and probably not performant.

So, one would have to use single-argument fixing for types.

Interestingly, while I can’t use

x = Tuple{Type{Int},Type{String}}((Int, String))

because the type of a tuple is always computed from the runtime types of its elements (and typeof(Int) === DataType):

julia> typeof(x)
Tuple{DataType, DataType}

It’s possible to get around this by using a NamedTuple:

julia> x = NamedTuple{(:a,:b), Tuple{Type{Int},Type{String}}}((Int, String))
@NamedTuple{a::Type{Int64}, b::Type{String}}((Int64, String))

julia> typeof(x)
@NamedTuple{a::Type{Int64}, b::Type{String}}

While typeof(x[1]) === DataType, we have fieldtype(typeof(x), 1) === Type{Int}, so the sharper type information in the NamedTuple is used for inference. At least it isn’t flagged by @code_warntype or JET, and it doesn’t seem to allocate on use. So, by letting the T in Fix be a NamedTuple (with dummy names), performance can be achieved for type arguments as well.
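
A sketch of that conversion, mirroring what the updated code later does with Symbol('f', i) names. The helper name `as_typed_namedtuple` is hypothetical, and `stable_typeof` is a local copy of Base._stable_typeof:

```julia
stable_typeof(x) = typeof(x)
stable_typeof(::Type{T}) where {T} = @isdefined(T) ? Type{T} : DataType

# Wrap a tuple in a NamedTuple with dummy names (:f1, :f2, ...), declaring each
# field type with stable_typeof so that type arguments keep their Type{T}.
function as_typed_namedtuple(x::Tuple)
    nms = ntuple(i -> Symbol(:f, i), length(x))
    types = Tuple{map(stable_typeof, x)...}
    return NamedTuple{nms, types}(x)
end

x = (Int, String, 1.0)
typeof(x)  # Tuple{DataType, DataType, Float64}
nt = as_typed_namedtuple(x)
fieldtype(typeof(nt), 1)  # Type{Int64}
```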

In short we have:

T = Tuple{Type{Int}, Type{String}}
NT = @NamedTuple{a::Type{Int}, b::Type{String}}
x = (Int, String)

julia> NT(x)::NT
@NamedTuple{a::Type{Int64}, b::Type{String}}((Int64, String))

julia> T(x)::T
ERROR: TypeError: in typeassert, expected Tuple{Type{Int64}, Type{String}}, got a value of type Tuple{DataType, DataType}

Or,

julia> @code_typed Tuple{Type{Int}}((Int,))[1]
CodeInfo(
1 ─ %1 = $(Expr(:boundscheck))::Bool
│   %2 =   builtin Base.getfield(t, i, %1)::DataType
└──      return %2
) => DataType

julia> @code_typed NamedTuple{(:a,), Tuple{Type{Int}}}((Int,))[1]
CodeInfo(
1 ─       builtin Base.getfield(t, i)::Type{Int64}
└──     return Int64
) => Type{Int64}

Maybe I’m misreading the proposal, but I thought this was about adding a constructor to Fix whose result would still be a nested Fix type, so Fix{(3,1)}(f, a, b) => Fix{1}(Fix{3}(f, a), b). That wouldn’t break Fix2; it would just make it easier to fix multiple arguments at once.
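
Under that reading, each absolute position only has to be shifted left by the number of earlier-fixed positions that lie to its left. A sketch of the conversion; `relative_positions` is a hypothetical helper, and it assumes distinct positions listed in the order the nested wrappers would be applied (innermost first):

```julia
# Convert absolute positions (innermost wrapper first) to the N of each nested Fix
function relative_positions(abspos::Tuple)
    rel = Int[]
    for (k, p) in enumerate(abspos)
        earlier = abspos[1:k-1]
        # every earlier-fixed position to the left of p removes one slot before it
        push!(rel, p - count(<(p), earlier))
    end
    return Tuple(rel)
end

relative_positions((3, 1))        # (3, 1), i.e. Fix{1}(Fix{3}(f, a), b)
relative_positions((1, 4, 5, 2))  # (1, 3, 3, 1)
```

So Fix{(1,4,5,2)}(fun, 1, 4, 5, 2) would correspond to Fix{1}(Fix{3}(Fix{3}(Fix{1}(fun, 1), 4), 5), 2).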

In my opinion, notation like f = Fix{(1,4,5,2)}(fun, 1, 4, 5, 2) is really hard to read in code you didn’t write yourself. You have to stop and think for a second to understand what is actually called under the hood when f(x) is invoked. And the more arguments you fix, the more complicated it gets.

I can see several alternatives:

  • Create a macro for this exact purpose: @fix fun(1, 2, _, 4, 5, _) or @fix(fun, 1, 2, _, 4, 5, _). I think Underscores.jl already has something similar.
  • Allow a single-argument, chainable Fix constructor: fun |> Fix{5}(5) |> Fix{4}(4) |> Fix{2}(2) |> Fix{1}(1). Note that the order is important here: fixing the highest position first means the lower positions do not shift.
  • Write a lambda: (u, v) -> fun(1, 2, u, 4, 5, v). There should be no performance difference, and the readability is much better.
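
The chainable variant can be sketched without touching Base. To stay self-contained, the sketch defines its own minimal single-position wrapper; `FixAt` and `fixat` are hypothetical names, and the tuple slicing in the call is not tuned for type stability:

```julia
# Minimal single-position fixer: FixAt{N}(f, x) calls f with x inserted at position N
struct FixAt{N,F,T} <: Function
    f::F
    x::T
end
FixAt{N}(f::F, x::T) where {N,F,T} = FixAt{N,F,T}(f, x)
(g::FixAt{N})(args...) where {N} = g.f(args[1:N-1]..., g.x, args[N:end]...)

# Curried constructor so the wrappers can be chained with |>
fixat(N, x) = f -> FixAt{N}(f, x)

fun(u, v, w, x, y, z) = (u, v, w, x, y, z)
h = fun |> fixat(5, 5) |> fixat(4, 4) |> fixat(2, 2) |> fixat(1, 1)
h(:a, :b)  # (1, 2, :a, 4, 5, :b)
```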

I agree that the macro solution @fix is very readable. It would be very easy to build this on top:

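using MacroTools  # provides @capture

macro fix(call)
    # split fun(a, _, c, _) into the callee f and the argument list xs
    @capture(call, f_(xs__)) || throw(ArgumentError("@fix expects a function call, got $call"))
    fixpos = findall(!=(:_), xs)
    fixargs = (xs[fixpos]...,)
    quote
        FixArg{$((fixpos...,))}($(esc(f)), $(esc.(fixargs)...))
    end
end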
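
For readers who would rather avoid the MacroTools dependency, the parsing that @capture(call, f_(xs__)) performs can be done with plain Expr inspection. A sketch with hypothetical names; instead of building a FixArg, the macro just returns the pieces so they can be inspected:

```julia
# Split a call like fun(1, _, 3.0, _) into the callee, the fixed positions and
# the expressions at those positions. `split_placeholders` and `@fixparts` are hypothetical.
function split_placeholders(call)
    Meta.isexpr(call, :call) || throw(ArgumentError("expected a function call, got $call"))
    f, xs = call.args[1], call.args[2:end]
    any(x -> Meta.isexpr(x, :parameters) || Meta.isexpr(x, :kw), xs) &&
        throw(ArgumentError("keyword arguments are not handled in this sketch"))
    pos = findall(!=(:_), xs)
    return f, Tuple(pos), xs[pos]
end

macro fixparts(call)
    f, pos, vals = split_placeholders(call)
    # evaluate the callee and the fixed values in the caller's scope
    return :(($(esc(f)), $pos, ($(map(esc, vals)...),)))
end

fun(args...) = args
@fixparts fun(1, _, 3.0, _)  # (fun, (1, 3), (1, 3.0))
```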

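The lambda variant captures the fixed values in a closure if they aren’t literals, with all the problems closure capture has in practice (for example, a captured variable that is reassigned ends up in an untyped Core.Box).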
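
One well-known example of those problems, as a small sketch (`make_closure` is a hypothetical name): a captured variable that is reassigned after the closure is created must be shared with the enclosing function, so Julia stores it in a Core.Box whose contents are untyped. A Fix-style struct always keeps the value in a typed field:

```julia
function make_closure(a)
    g = x -> x + a
    a = 2a  # reassignment after capture forces a Core.Box
    return g
end

g = make_closure(1)
g(1)                      # 3: the closure sees the updated a == 2
fieldtype(typeof(g), :a)  # Core.Box

# Base.Fix2 stores the fixed value in a concretely typed field instead
fieldtype(typeof(Base.Fix2(+, 1)), :x)  # Int64
```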

But I think the ordering and merging suggested by @matthias314 really is good for performance, since it doesn’t rely on the inner function being inlined. This also works fine for type arguments.


An advantage of something like Fix{(1,4,5,2)} is that you can dispatch on such a type. This would be lost with anonymous functions. In order to be really useful, this should be combined with ordering the arguments (hence not (1,4,5,2)) and with merging iterated Fix wrappers. I admit that four fixed arguments is a contrived example; two would be more realistic.
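
The dispatch point can already be seen with the existing Base.Fix2. A small sketch; `describe` is a hypothetical function:

```julia
# Methods can target a Fix wrapper and read its fields...
describe(f::Base.Fix2) = "Fix2 of $(f.f) with second argument $(f.x)"
# ...while an anonymous function only reaches the generic fallback
describe(f) = "opaque function"

describe(Base.Fix2(-, 3))  # "Fix2 of - with second argument 3"
describe(x -> x - 3)       # "opaque function"
```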

I’ll see if I can polish it and make it into a package:

Updated code
using MacroTools
const _stable_typeof = Base._stable_typeof


struct FixArg{N,F,T} <: Function
    f::F
    x::T
    function FixArg{N,F,T}(f, x) where {N,F,T}
        if N isa Tuple
            (all(t -> isa(t, Int), N) && issorted(N) && all(≥(1), N)) ||
                throw(ArgumentError(LazyString("expected type parameter in Fix to be `Int` or an ordered tuple of `Int`s, but got `", N, "::", typeof(N), "`")))

            length(N) == length(x) || throw(ArgumentError(LazyString("type parameter in Fix specifies $(length(N)) fixed arguments $N, but got $(length(x)): ",x)))
        else
            (N isa Int && N ≥ 1) || throw(ArgumentError(LazyString("expected type parameter in Fix to be `Int` or an ordered tuple of `Int`s, greater than zero, but got `", N, "::", typeof(N), "`")))
        end
        new{N,F,T}(f, x)
    end
end

_ziptuples(tup::Tuple...) = ntuple(Val(length(tup[1]))) do i
    map(t -> t[i], tup)
end

function _sortperm(tup::Tuple)
    seq = ntuple(identity, Val(length(tup)))
    s = sort(_ziptuples(tup, seq))
    last.(s)
end


function FixArg{N}(f, x...) where {N}
    if N isa Tuple
        order = _sortperm(N)
        sortedN = ntuple(i -> N[order[i]], Val(length(N)))
        ox = ntuple(i -> x[order[i]], Val(length(N)))

        types = ntuple(Val(length(N))) do i
            _stable_typeof(x[order[i]])
        end
        nm = ntuple(i->Symbol('f',i), Val(length(types)))
        nt = NamedTuple{nm, Tuple{types...}}
        FixArg{sortedN, _stable_typeof(f), nt}(f, ox)
    else
        x1 = x[1]
        FixArg{N, _stable_typeof(f), _stable_typeof(x1)}(f, x1)
    end
end


function _mergesortedtuples(tup1, tup2, ord1=tup1, ord2=tup2)
    idx1 = Ref(0)
    idx2 = Ref(0)
    ntuple(Val(length(tup1) + length(tup2))) do i
        if idx2[] >= length(tup2)
            return tup1[idx1[] += 1]
        end
        if idx1[] >= length(tup1)  # tup1 exhausted: take the rest from tup2
            return tup2[idx2[] += 1]
        end
        b = ord2[idx2[]+1] < ord1[idx1[]+1]
        if b
            return tup2[idx2[] += 1]
        else
            return tup1[idx1[] += 1]
        end
    end
end


# Fixing arguments in a FixArg function merges the two into a single FixArg
@generated function FixArg{N}(f::FixArg{innerN, innerF, innerT}, outerx...) where {N, innerN, innerF, innerT}

    fix = isa(N, Int) ? (N,) : N
    innerT = innerN isa Int ? Tuple{innerT} : innerT

    length(outerx) == length(N) || throw(ArgumentError(LazyString("number of arguments must equal number of positions")))
    seq = ntuple(identity, Val(length(N)))
    _sorted = sort(_ziptuples(fix, outerx, seq))
    outerN = first.(_sorted)
    outertypes = map(x -> x[2], _sorted)
    order = last.(_sorted)

    allseq = ntuple(identity, Val(maximum(innerN)+maximum(outerN)))
    holes = filter(∉(innerN), allseq)
    newouter = map(i -> holes[i], outerN)

    # now, merge innerN and newouter, and their types
    newN = _mergesortedtuples(newouter, innerN)
    types = _mergesortedtuples(outertypes, fieldtypes(innerT), newouter, innerN)


    nm = ntuple(i -> Symbol('f',i), Val(length(types)))
    newtype = NamedTuple{nm, Tuple{types...,}}


    # then create a call to the inner constructor, and fill it with the right arguments
    callexpr = :(FixArg{$newN, innerF, $newtype}(f.f, ()))

    args = callexpr.args[3].args
    resize!(args, length(newN))
    innerarg = 0
    outerarg = 0
    for i in 1:length(newN)
        if newN[i] in innerN
            if innerN isa Int
                args[i] = :(f.x)
            else
                args[i] = :(f.x[$(innerarg += 1)])
            end
        else
            args[i] = :(outerx[$(order[outerarg += 1])])
        end
    end

    return callexpr
end


@generated function (f::FixArg{N})(args...; kws...) where {N}

    callexpr = :(f.f(; kws...))
    allargs = callexpr.args

    offset = length(allargs)  # for function name and parameters

    # make room for all args
    resize!(allargs, offset+length(args)+length(N))
    for i in offset+1:length(allargs)
        allargs[i] = :undef
    end


    # fill in the fixed args
    if N isa Int
        allargs[N+offset] = :(f.x)
    else
        for n in 1:length(N)
            allargs[N[n]+offset] = :(f.x[$n])
        end
    end

    # and the others from args
    nextarg = 1
    for i in offset+1:length(allargs)
        if allargs[i] === :undef
            allargs[i] = :(args[$nextarg])
            nextarg += 1
        end
    end

    return callexpr
end

(f::FixArg{1})(arg; kws...) = f.f(f.x, arg; kws...)
(f::FixArg{2})(arg; kws...) = f.f(arg, f.x; kws...)

const FixArg1{F,T} = FixArg{1,F,T}
const FixArg2{F,T} = FixArg{2,F,T}

macro fix(call)
    @capture(call, f_(xs__)) || throw(ArgumentError("@fix expects a function call, got $call"))
    fixpos = findall(!=(:_), xs)
    fixargs = (xs[fixpos]...,)
    quote
        FixArg{$((fixpos...,))}($(esc(f)), $(esc.(fixargs)...))
    end
end
Some tests
fun(u,v,w,x,y,z) = (u,v,w,x,y,z)
fun1(u,v,w,x,y,z) = z


f = FixArg{(1,4,5,2)}(fun, 1, 4, 5, 2)
@assert f === FixArg{(1,2,4,5)}(fun, 1, 2, 4, 5)
@assert f(-3.0, -6) === (1,2,-3.0,4,5,-6)
g = FixArg{2}(f, "-6")
@assert g(-3%Int8) === (1,2,-3%Int8,4,5,"-6")

f1 = FixArg{(5,3)}(fun, -5.0, -3)
f2 = FixArg{(3,5)}(fun, -3, -5.0)
f3 = FixArg{3}(FixArg{5}(fun, -5.0), -3)
f4 = FixArg{4}(FixArg{3}(fun, -3), -5.0)
f5 = @fix fun(_, _, -3, _, -5.0, _)

@assert allequal((f1, f2, f3, f4, f5))

g1 = FixArg{3}(fun, Vector)
g2 = FixArg{4}(g1, Float64)
@assert g2 === FixArg{(3,5)}(fun, Vector, Float64)
@assert g2(2,3,4,5) === (2, 3, Vector, 4, Float64, 5)

# combination:
f = FixArg{1}(FixArg{(1,2)}(FixArg{1}(FixArg{1}(fun,1), 2), 3, 4), 5)
@assert f == FixArg{(1,2,3,4,5)}(fun, 1,2,3,4,5)
@assert f(6) === (1,2,3,4,5,6)

T = FixArg{1}
@assert T(fun, 1.0) isa T

@assert FixArg1 === FixArg{1}
@assert FixArg2 === FixArg{2}

# inference with type arguments:
function fun2(::Type{T1}, ::Type{T2}, ::Type{T3}, x) where {T1,T2,T3}
  T = typejoin(T1,T2,T3)
  convert(T, x)
end

f2 = FixArg{(2,3)}(fun2, Float16, Float64)
@assert code_typed(f2, (Type{Float16}, Float32))[1].second === Float32
f2 = FixArg{2}(fun2, Int8)
@assert code_typed(f2, (Type{Int16}, Type{UInt32}, Float32))[1].second === Int64
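
Since _mergesortedtuples is index-heavy, a plain loop version is handy as a reference when testing it, in particular on inputs where the first tuple has more than one element before the second one starts. A sketch; `merge_sorted` is a hypothetical name, and it allocates a vector, so it is not meant for the generated-function path:

```julia
# Merge two tuples whose keys are each sorted; values come out in key order,
# taking from the first tuple on ties. Keys default to the values themselves.
function merge_sorted(vals1::Tuple, vals2::Tuple, keys1=vals1, keys2=vals2)
    out = Any[]
    i, j = 1, 1
    while i <= length(vals1) || j <= length(vals2)
        if j > length(vals2) || (i <= length(vals1) && keys1[i] <= keys2[j])
            push!(out, vals1[i]); i += 1
        else
            push!(out, vals2[j]); j += 1
        end
    end
    return Tuple(out)
end

merge_sorted((1, 2), (3,))                   # (1, 2, 3)
merge_sorted((:a, :b), (:c,), (1, 5), (3,))  # (:a, :c, :b)
```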

If time allows, try comparing against pre-existing attempts before registering, to avoid cluttering the General registry package namespace. For example:
