Making an efficient function for ignoring unsupported kwargs

Given a function f that may accept some keyword arguments, I want to make a function g that accepts any keyword arguments and ignores those that are unsupported by f. Below is code that achieves this goal:

function ignore_unsupported_kwargs(f)
    function g(args...; kwargs...)
        # Look up the method of f that these positional arguments dispatch to.
        fmethod = which(Tuple{typeof(f), typeof.(args)...})
        accepted_kwargs = Tuple(Base.kwarg_decl(fmethod))
        # Keep only the keywords that this method declares.
        filtered_kwargs = pairs(kwargs[intersect(accepted_kwargs, keys(kwargs))])
        return f(args...; filtered_kwargs...)
    end
    return g
end

z(;x) = x
z2 = ignore_unsupported_kwargs(z)
z2(;x=3,y=4) # works! but much slower than z, allocates
# 0.000043 seconds (30 allocations: 1.688 KiB)

However, this code seems to be quite slow: I’d ideally like there to be almost zero overhead. I’ve played around with generated functions to try and move the type-dependent computation out, but I haven’t gotten anything to work. I’d appreciate any help!

Wouldn’t it be easier to declare f(arg1, arg2,...; kw1, kw2,..., kwargs...) and simply ignore kwargs in the function body?
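
To make the suggestion concrete, here is a minimal sketch. The function name f_tolerant and its arguments are made up for illustration; the point is only that a trailing kwargs... in the signature collects any extra keywords, which the body can then ignore.

```julia
# Hypothetical example: `kwargs...` absorbs every keyword other than b and c.
f_tolerant(a; b, c, kwargs...) = a + b + c

# The extra keyword `d` is accepted and never used.
result = f_tolerant(1; b=2, c=3, d=4)
```

This only helps when you control the definition of the function, since the slurp has to be written into the signature itself.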

Right, but f is user-provided, and I want to “correct” its signature so that unsupported kwargs are dropped. So I want to take in any f and automatically generate the appropriate wrapper, along the lines of what you wrote.

I managed to come up with something that doesn’t seem to add much overhead, using @generated:

function _get_kwarg_fields(kwargs_type::Type{Iterators.Pairs{K, V, I, T}}) where {K,V,I,T}
    return fieldnames(T)
end

@generated function _ignore_unsupported_kwargs(f, args...; kwargs...)
    # Inside the generator, f, args and kwargs are bound to types, not values.
    fmethod = which(Tuple{f, args...})
    accepted_kwarg_fields = Base.kwarg_decl(fmethod)
    given_kwarg_fields = _get_kwarg_fields(kwargs)
    # Base.kwarg_decl reports a slurped keyword such as kwargs... as a symbol ending in "...".
    if any(endswith.(string.(accepted_kwarg_fields), "..."))
        return :(f(args...; kwargs...))
    end
    filtered_kwarg_fields = Tuple(intersect(accepted_kwarg_fields, given_kwarg_fields))
    return quote
        filtered_kwarg_values = map(Base.Fix1(Base.getindex, kwargs), $filtered_kwarg_fields)
        filtered_kwargs = NamedTuple{$filtered_kwarg_fields}(filtered_kwarg_values)
        return f(args...; filtered_kwargs...)
    end
end

function ignore_unsupported_kwargs(f)
    return (args...; kwargs...) -> _ignore_unsupported_kwargs(f, args...; kwargs...)
end

Let me know if you have any thoughts! An example usage is:

julia> f(; x, y) = x + y
f (generic function with 1 method)

julia> f(a::Int; b, c) = a + b + c
f (generic function with 2 methods)

julia> f2 = ignore_unsupported_kwargs(f)
#5 (generic function with 1 method)

julia> f2(; x=5, y=7, z=10)

julia> f2(3; b=4, c=10, d=12)

From what I’ve read online, the use of which in the generated function is not ideal since it depends on the global method table which may change over time. I haven’t been able to find any cases where there are any issues with the above, but I’d appreciate any help with writing an implementation of this functionality that does not violate any assumptions for generated functions.
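
One possible direction, offered as a rough sketch rather than a full answer: do the method-table lookup once, when the wrapper is created, and let the generated function see only type parameters. The names keep_fields and ignore_unsupported_kwargs_static are hypothetical. The sketch assumes f has exactly one method and no slurped keyword, so unlike the version above it does not dispatch on the positional arguments, and it will not notice methods added after wrapping.

```julia
# Hypothetical helper: keep only the fields of `nt` whose names are in `accepted`.
# The generator body reads nothing but its type parameters, so it does not
# consult the global method table.
@generated function keep_fields(nt::NamedTuple{given}, ::Val{accepted}) where {given, accepted}
    kept = filter(in(accepted), given)  # tuple of the names to keep
    return :(NamedTuple{$(QuoteNode(kept))}(nt))
end

# Hypothetical wrapper: the keyword names are looked up once, at wrap time.
# Assumes `f` has exactly one method and no `kwargs...` slurp.
function ignore_unsupported_kwargs_static(f)
    accepted = Val(Tuple(Base.kwarg_decl(only(methods(f)))))
    return (args...; kwargs...) -> f(args...; keep_fields(values(kwargs), accepted)...)
end

h(; x) = x
h2 = ignore_unsupported_kwargs_static(h)
h2_result = h2(; x=3, y=4)  # `y` is dropped before calling h
```

Wrapping the names in Val before the closure captures them puts them in the type domain, which is what lets keep_fields specialise on them. I have not benchmarked this, so whether it matches the overhead of the which-based version is an open question.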