RFC – ergonomic juliacall syntax for passing Python variables to Julia

That’s why you wrap the additional argument in a closure, as I said. But you can pass a Python lambda, as I suggested — you shouldn’t need to seval a Julia closure.

This is the standard way to pass additional arguments to functions passed to higher-order functions, regardless of whether you are doing cross-language calling. Using globals for this purpose is a bad habit.

(Even for calls within the same language, I find that this is one of the most common confusions in programming with higher-order functions. I’ve lost count of the number of times people have asked how to pass additional arguments to an objective function in NLopt in Python or Julia or …, or how to pass parameters to an integrand in QuadGK.jl etc. Closures and higher-order functions seem to be a major omission from typical programming pedagogy.)

The point is, you should treat the Julia function exactly as you would treat a Python function that had 4 arguments which you wanted to pass to a 3-argument higher-order function: you would use a lambda. You shouldn't need to learn a new syntax.
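In pure Python, the same pattern looks like the following (the evaluate and weighted_loss names here are hypothetical, purely for illustration):

def weighted_loss(predict, target, weights):
    return sum(w * (p - t) ** 2 for p, t, w in zip(predict, target, weights))

# A higher-order function that expects a 2-argument loss:
def evaluate(loss, predict, target):
    return loss(predict, target)

weights = [0.5, 1.0, 2.0]
# Capture the extra argument with a lambda instead of reaching for a global:
evaluate(lambda p, t: weighted_loss(p, t, weights), [1.0, 2.0], [1.1, 1.9])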

Telling people “write functions, not scripts; pass arguments, not globals” seems like it should be the default solution here. But I accept that lots of people like to write scripts full of globals (that’s what they are used to, and for little projects it can be the easiest approach), so it’s not crazy to provide a nicer syntax. (I did exactly this in PyCall.jl, as I said above!) But it shouldn’t be the first resort, and I don’t think using functions should be viewed as awkward.


Sadly this isn’t possible due to Python’s GIL: SymbolicRegression.jl would need to reacquire the GIL every time it evaluates the loss function, which isn’t compatible with its default multithreaded mode, not to mention the need to differentiate through the loss.

After exploring several approaches to making Python->Julia variable passing more ergonomic while maintaining proper scoping, here are a few more ideas. I think I agree with @Benny that the context-manager approach is probably best avoided, as it introduces too much “spooky action at a distance” and is less transparent.

0. Current approach using closures:

np_weights = ...  # some NumPy array of weights

elementwise_loss = jl.seval(
    """
    weights -> let
        function my_weighted_loss(predict, target)
            return sum(
                i -> weights[i] * abs2(predict[i] - target[i]),
                eachindex(predict, target, weights)
            )
        end
    end
    """
)(np_weights)

While this works, it has several drawbacks: (1) it’s unintuitive to Python users, (2) it reduces code readability, and (3) it requires writing a closure for every variable.

Some alternatives:

1. Easier closure

Using the new general Fix{n} from “Create Base.Fix as general Fix1/Fix2 for partially-applied functions” (JuliaLang/julia#54653, also available in Compat.jl), we could avoid the GIL problems created by a Python lambda, and the readability issues of the chained jl.seval(...)(np_weights) closure call above. This would look like the following:

jl.seval("""
function my_weighted_loss(predict, target, weights)
    return sum(
        i -> weights[i] * abs2(predict[i] - target[i]),
        eachindex(predict, target, weights)
    )
end
""")

elementwise_loss = jl.Fix[3](jl.my_weighted_loss, np_weights)

The nice part about this is that the Fix{3} is a Julia function, so (1) it would immediately do a one-time conversion of np_weights into Julia, rather than performing the conversion at each call, and (2) this would avoid the GIL issue mentioned above. In other words, this would open a nice route for @stevengj’s recommended approach above.
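Until a jl.Fix wrapper exists, roughly the same thing can be spelled today by pulling the constructor through seval. A sketch, assuming a Julia build that includes the merged Fix PR (or Compat providing Compat.Fix):

Fix3 = jl.seval("Base.Fix{3}")  # or "Compat.Fix{3}" on older Julia versions
elementwise_loss = Fix3(jl.my_weighted_loss, np_weights)
# np_weights is converted to Julia once, at construction; subsequent calls
# run entirely on the Julia side, avoiding the GIL issue described above.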

There are some downsides that I see: (1) is that users might reflexively try to use lambda and hit these problems, (2) jl.Fix[3] indexes the 3rd argument, whereas a Python user might expect it to be the 4th with their 0-indexed brain, (3) there isn’t a multi-arg or keyword-arg jl.Fix yet, and (4) using [n] to pass an argument sorta breaks all Python intuition, sadly.

Maybe a special jl.lambda could be created to simplify this? e.g., jl.lambda(f, {1: x, 3: y, "kwarg": z}), which would return a Julia callable. (A prototype of exactly this idea appears as bind at the end of this post.)

2. Directly interpolating keywords to jl.seval (or make a new function jl.teval)

elementwise_loss = jl.seval(
    """
    function my_weighted_loss(predict, target)
        return sum(
            i -> $(weights)[i] * abs2(predict[i] - target[i]),
            eachindex(predict, target, $(weights))
        )
    end
    """,
    weights=np_weights
)

Here $(weights) is interpolated from the passed keyword arguments. You can’t interpolate with a plain Python f-string, since that splices in the string representation of the object rather than the object itself. Also, the reason for $(), as discussed previously, is that {} conflicts with Julia’s curly-brace syntax for type parameters. I guess $ could also be problematic if there are @eval calls in the evaluated code, so maybe an alternative syntax is needed.
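To make the f-string pitfall concrete:

src = f"sum({np_weights})"
# str(np_weights) is pasted into the source, yielding something like
# "sum([0.1 0.2 0.3])": a textual copy (not even valid Julia) rather than
# a reference to the live Python object.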

Relatedly, @Benny found that $ is already used for string interpolation elsewhere in the Python standard library (string.Template), so it could be a good option for that reason: string.Template("...").substitute(d) could even serve as the first stage of processing.
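Here is a minimal sketch of what a jl.teval built on string.Template could look like internally. Note the assumptions: stock Template expects $weights or ${weights} placeholders rather than the $(weights) spelling above (matching $(...) would require a custom Template.pattern), and the helper name teval is hypothetical:

from string import Template
from juliacall import Main as jl

def teval(src, **vars):
    # Replace each $name placeholder with a plain identifier, wrap the
    # source in a Julia closure over those identifiers, and call it once
    # with the Python values (a one-time conversion into Julia).
    names = list(vars)
    body = Template(src).substitute({k: k for k in names})  # "$weights" -> "weights"
    closure = jl.seval(f"({', '.join(names)}) -> begin\n{body}\nend")
    return closure(*vars.values())

The example above would then read teval(src, weights=np_weights).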

3. Chained calls with Namespace:

elementwise_loss = (
    jl.Namespace(weights=np_weights)
      .seval("""
        function my_weighted_loss(predict, target)
            return sum(
                i -> $(weights)[i] * abs2(predict[i] - target[i]),
                eachindex(predict, target, $(weights))
            )
        end
    """)
)

I chose Namespace as it’s the more Pythonic name, echoing argparse.Namespace in the standard library. (Noting the comment from the PythonCall.jl README: “the Python code looks like Python and the Julia code looks like Julia.”)

We could also let this be a mutable object:

scope = jl.Namespace()
scope.weights = np_weights
# ^ Stores the name "weights" together with the value np_weights
elementwise_loss = scope.seval("""
    function my_weighted_loss(predict, target)
        return sum(
            i -> $(weights)[i] * abs2(predict[i] - target[i]),
            eachindex(predict, target, $(weights))
        )
    end
    """)

In general I do think it’s good to explicitly indicate interpolation, rather than just writing a bare weights here. By being more explicit, we can throw an error if the user forgets to pass a variable (which string.Template("...").substitute(d) already handles). It also avoids the silent failure where weights already happens to exist in the Julia global namespace, so no UndefVarError would ever be raised to flag the mistake!
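For instance, string.Template fails loudly on a missing name:

from string import Template

Template("$weights[i] * abs2(predict[i] - target[i])").substitute({})
# KeyError: 'weights'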

One option for a Pythonic automatic closure for any-args any-kwargs:

First, the Julia partial.jl:

struct Partial{F,ARG<:NamedTuple,KWS<:NamedTuple}
    f::F
    args::ARG  # e.g., NamedTuple{(Symbol(1),)}((1.0,)) for 1.0 as first positional argument
    kws::KWS  # e.g., NamedTuple{(:kw,)}((y,)) for kw=y
end

# @generated is needed to force inference of the key values
@generated function _make_all_args_keys(::Val{N}) where {N}
    return :($(ntuple(i -> Symbol(i), Val(N))))
end

function (p::Partial{F})(args::Vararg{Any,N}; kws...) where {F,N}
    all_args_keys = _make_all_args_keys(Val(N + length(p.args)))
    input_args_keys = filter(k -> k ∉ keys(p.args), all_args_keys)
    input_args = NamedTuple{input_args_keys}(args)
    sorted_args = let all_args = (; p.args..., input_args...)
        # This sorts the keys in order Symbol(1), Symbol(2), ...
        map(i -> all_args[i], all_args_keys)
    end
    return p.f(sorted_args...; kws..., p.kws...)
end

# Helper for the Python side: build a NamedTuple from parallel key/value iterables
_make_named_tuple(keys, values) = NamedTuple{(keys...,)}((values...,))

Then, the Python side:

from juliacall import Main as jl

jl.include("partial.jl")

def bind(f, args):
    _args = {jl.Symbol(k): v for (k, v) in args.items() if isinstance(k, int)}
    _kwargs = {jl.Symbol(k): v for (k, v) in args.items() if not isinstance(k, int)}

    jl_args = jl._make_named_tuple(_args.keys(), _args.values())
    jl_kwargs = jl._make_named_tuple(_kwargs.keys(), _kwargs.values())

    return jl.Partial(f, jl_args, jl_kwargs)

Then we can use it like this:

In [12]: jl.seval("f(x, y; kw1, kw2) = (@show x y kw1 kw2; nothing)")

In [13]: partial_f = bind(jl.f, {2: 5.0, "kw2": "blah"})

In [14]: partial_f(1.0, kw1=22)
x = 5.0
y = 1.0
kw1 = 22
kw2 = "blah"

Meaning the example above would look like:

from juliacall import bind  # hypothetical: if `bind` were shipped with juliacall

jl.seval("""
function my_weighted_loss(predict, target, weights)
    return sum(
        i -> weights[i] * abs2(predict[i] - target[i]),
        eachindex(predict, target, weights)
    )
end
""")

elementwise_loss = bind(jl.my_weighted_loss, {3: np_weights})

And the evaluation would be pure Julia. Still not sure if I like this over the other ones though…