Associating arguments with partial functions to simplify nesting

Suppose I have some nested functions, and the lowest level function has some arguments which are caches and/or outputs whose values need to be written to and retained between calls:

function a!(arg1, arg2)
    b!(arg1, arg2)
    return nothing
end 

function b!(a1, a2)
    c!(a1, a2) 
    return nothing
end

function c!(a1, a2)
    a1[:] = <something>
    ...
    return nothing
end

mem1 = <something>
mem2 = <something else>
a!(mem1, mem2)

In the above, a! and b! need the arguments to c! in their call signatures only so that they can pass them down. This means that if I want to replace c! with a different function, I have to change a! and b! as well, which is cumbersome if c! takes a lot of arguments or if I want to try many different functions.

Suppose instead I do the following:

function a!()
    b!()
    return nothing
end 

function b!()
    c!() 
    return nothing
end

function f!(a1, a2)
    a1[:] = <something>
    ...
    return nothing
end

mem1 = <something>
mem2 = <something else>
c!() = f!(mem1, mem2)
a!()

In this case, a! and b! don’t need to know anything about the arguments to c!, which makes it easier to redefine c! as different things. Is there a reason I shouldn’t do this, particularly in terms of performance? (In my application, a! will be called many times.)

I’d also welcome any other suggestions for avoiding passing the arguments down as in the first version. I thought about encapsulating all the arguments to c! in some kind of data structure and passing that down, but these functions also need to compile to a GPU kernel, so any argument that is not isbits won’t work.

Thank you for your help!


You probably want closures:

https://m3g.github.io/JuliaNotes.jl/stable/closures/
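A minimal sketch of the closure approach, assuming the caches are created inside a function rather than living as non-constant globals. The names fill_sum! (standing in for f!) and make_c! are hypothetical, and the body of fill_sum! is a placeholder computation.

```julia
# Placeholder for the OP's f! (hypothetical computation): overwrite a1 using a2.
function fill_sum!(a1, a2)
    a1 .= a2 .+ 1
    return nothing
end

# The outer functions only receive a zero-argument callable and know nothing
# about the arguments it uses internally.
function a!(c!)
    b!(c!)
    return nothing
end

function b!(c!)
    c!()
    return nothing
end

# Creating the closure inside a function makes mem1 and mem2 local, concretely
# typed captures rather than references to non-constant globals.
make_c!(mem1, mem2) = () -> fill_sum!(mem1, mem2)

mem1 = zeros(3)
mem2 = [1.0, 2.0, 3.0]
c! = make_c!(mem1, mem2)
a!(c!)   # mem1 is now [2.0, 3.0, 4.0]
```

Swapping in a different inner function only requires building a different closure; a! and b! stay unchanged.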


This looks promising, and there seems to be a CUDA example there too, thank you! The page mentions that in certain cases closures can incur a performance penalty from type instability. Any tips on how to avoid this?

It sounds like c! is a callback. You can handle it simply with:

call(cb, args...) = cb(args...)

function c!(a1, a2)    # your example 
    a1[:] = <something>
    ...
    return nothing
end

call(c!, mem1, mem2)

call would not have to change. It can pass arbitrary arguments to an arbitrarily changing callback function.
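A small sketch of swapping callbacks through the same call helper; set_ones! and copy_scaled! are hypothetical callbacks, and the second one takes an extra argument to show that call itself never changes.

```julia
call(cb, args...) = cb(args...)

# Two interchangeable callbacks (hypothetical examples).
function set_ones!(a1, a2)
    a1 .= 1
    return nothing
end

function copy_scaled!(a1, a2, s)
    a1 .= s .* a2
    return nothing
end

mem1 = zeros(3)
mem2 = [1.0, 2.0, 3.0]

call(set_ones!, mem1, mem2)        # mem1 is now [1.0, 1.0, 1.0]
call(copy_scaled!, mem1, mem2, 2)  # mem1 is now [2.0, 4.0, 6.0]
```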

If you just want to pass arguments down the chain, then use the variable-arity syntax:

function a!(args...)
    b!(args...)
    return nothing
end 

function b!(args...)
    c!(args...) 
    return nothing
end

function c!(a1, a2)
    a1[:] = <something>
    ...
    return nothing
end

mem1 = <something>
mem2 = <something else>
a!(mem1, mem2)
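Julia's performance tips note that a method may not be specialized on a Vararg argument that is only passed through to another function. If profiling shows that this matters for a! and b!, a sketch of the same chain with specialization forced via `Vararg{Any,N} where {N}` (the body of c! is a hypothetical placeholder):

```julia
# Same chain as above; the Vararg{Any,N} signature asks Julia to specialize
# a! and b! on the number and types of the forwarded arguments.
function a!(args::Vararg{Any,N}) where {N}
    b!(args...)
    return nothing
end

function b!(args::Vararg{Any,N}) where {N}
    c!(args...)
    return nothing
end

# Placeholder body for c! (hypothetical computation).
function c!(a1, a2)
    a1 .= 2 .* a2
    return nothing
end

mem1 = zeros(2)
mem2 = [1.0, 2.0]
a!(mem1, mem2)   # mem1 is now [2.0, 4.0]
```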

This is not specific to closures; it applies to Julia in general: avoid non-constant global variables. The only catch is that the closure syntax can hide the fact that a captured variable is a non-constant global, if one misses the point that the closure's variables are resolved in the enclosing scope. The “Scope” section of that link explains this in detail.
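One way to keep a closure's captures out of global scope, sketched with the same f(x,y) = x*y as in the session below: build the closure inside a let block, so mem1 and mem2 are local and the closure stores them with concrete types.

```julia
using Test

f(x, y) = x * y

# mem1 and mem2 are local to the let block, so the closure captures them as
# Float64 fields instead of referring to untyped globals.
c = let mem1 = 1.0, mem2 = 2.0
    () -> f(mem1, mem2)
end

@inferred c()   # returns 2.0; @inferred checks that the return type is inferred as Float64
```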

This is the same problem you will have if you do:

julia> mem1 = 1.
1.0

julia> mem2 = 2.
2.0

julia> f(x,y) = x*y
f (generic function with 1 method)

julia> c() = f(mem1,mem2)
c (generic function with 1 method)

julia> @code_warntype c()
Variables
  #self#::Core.Compiler.Const(c, false)

Body::Any
1 ─ %1 = Main.f(Main.mem1, Main.mem2)::Any
└──      return %1

julia>

In both cases you can solve this problem by declaring the variables as const:

julia> const a1 = 1. 
1.0

julia> const a2 = 2.
2.0

julia> d() = f(a1,a2)
d (generic function with 1 method)

julia> @code_warntype d()
Variables
  #self#::Core.Compiler.Const(d, false)

Body::Float64
1 ─ %1 = Main.f(Main.a1, Main.a2)::Core.Compiler.Const(2.0, false)
└──      return %1


Alternatively, whether you use closures or simply pass the function as an argument, you can add a function barrier:

julia> mem1 = 1.; mem2 = 2.;

julia> f(x,y) = x*y
f (generic function with 1 method)

julia> c(f) = f(mem1,mem2) # type unstable, mem1 and mem2 are global
c (generic function with 1 method)

julia> @code_warntype c(f)
Variables
  #self#::Core.Compiler.Const(c, false)
  f::Core.Compiler.Const(f, false)

Body::Any
1 ─ %1 = (f)(Main.mem1, Main.mem2)::Any
└──      return %1

julia> h(f,x,y) = f(x,y) # type stable
h (generic function with 1 method)

julia> @code_warntype h(f,mem1,mem2)
Variables
  #self#::Core.Compiler.Const(h, false)
  f::Core.Compiler.Const(f, false)
  x::Float64
  y::Float64

Body::Float64
1 ─ %1 = (f)(x, y)::Float64
└──      return %1
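Since a! will be called many times, the same barrier idea extends to the hot loop itself. A sketch, with run_many as a hypothetical name: the loop receives the function and the data as arguments, so it is compiled for their concrete types even when the caller's variables are non-constant globals.

```julia
f(x, y) = x * y

# Function barrier: everything the loop needs arrives as arguments, so the loop
# body is compiled for the concrete types of g, x and y.
function run_many(g, x, y, n)
    acc = zero(x * y)
    for _ in 1:n
        acc += g(x, y)
    end
    return acc
end

mem1 = 1.0; mem2 = 2.0          # non-constant globals, as in the session above
run_many(f, mem1, mem2, 10)     # 20.0
```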




How does the idea of non-constant variables apply to arrays? That is, if my argument is an array whose elements I need to change, but not its type or length, does it count as constant (since the binding to the array could be constant) or non-constant? If the former, how would I declare it?

Thanks for the suggestion! Unfortunately I can’t pass functions as arguments to GPU kernels since functions are not isbits. I’ll keep this in mind for the CPU component though!

Yes, it can be declared const. You can still mutate the elements of a constant array; you just cannot rebind the name:

const a = [0,0]
a[1] = 1 # works
a = [1,0] # not allowed: rebinding a const name (warning or error, depending on the Julia version)
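A sketch of using such a constant array as a cache from inside a function; buf and update_buf! are hypothetical names. Because the binding is constant, code that uses buf sees its concrete type (Vector{Float64}), while the elements are updated in place.

```julia
const buf = zeros(3)

function update_buf!(v)
    buf .= v      # in-place update; the binding `buf` itself never changes
    return nothing
end

update_buf!([1.0, 2.0, 3.0])
buf               # [1.0, 2.0, 3.0]
```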

That’s not true. Most functions are isbits, and can be passed to GPU kernels. Closures that capture type-unstable variables or non-const globals aren’t, and that’s just because the generated struct will contain a badly-typed field (as illustrated above).
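A quick CPU-side check of this with isbits, using hypothetical names. It only shows what isbits reports; whether a given kernel argument is accepted is up to the GPU package.

```julia
f(x, y) = x * y
isbits(f)      # true: a plain function is a singleton with no fields

g = let s = 2.0
    x -> s * x
end
isbits(g)      # true: the closure only captures a Float64

h = let v = [1.0, 2.0]
    x -> v[1] * x
end
isbits(h)      # false: the closure captures an Array, which is not isbits
```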

Closures have extra problems related to type instability. See this answer of mine. If you hit the problem with type-unstable closures, you may want to use the callable structs described there. However, I wonder whether you should simply pass a function (without captured arguments) and pass the arguments down the chain with args..., as others pointed out.

Well, that depends on what the OP wants exactly. With that option you cannot pass a different function to the outer function, just different arguments.

Also, in my opinion the args... option is quite limited and hard to maintain. I do not see many advantages in it compared to putting the arguments of the nested function into a data structure. That is, I prefer the second option:

julia> f(x,y) = x*y
f (generic function with 1 method)

julia> g(args...) = f(args...) # No idea what "args..." should be, and order dependent
g (generic function with 1 method)

julia> x = 1; y = 2;

julia> g(x,y)
2

julia> h(fArgs) = f(fArgs.x,fArgs.y) # I know what is going on
h (generic function with 1 method)


julia> fArgs = (x=1,y=2)
(x = 1, y = 2)

julia> h(fArgs)
2
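Relating this to the GPU constraint from the original question: a named tuple (or an immutable struct) whose fields are all isbits values is itself isbits. A sketch, where PairArgs is a hypothetical struct with the same field names, so the same h works on both:

```julia
f(x, y) = x * y
h(fArgs) = f(fArgs.x, fArgs.y)

fArgs = (x = 1, y = 2.0)
isbits(fArgs)                  # true: every field is an isbits value
h(fArgs)                       # 2.0

# A hypothetical immutable struct works with the same h and documents the field types.
struct PairArgs{T}
    x::T
    y::T
end
p = PairArgs(1.0, 2.0)
isbits(p)                      # true
h(p)                           # 2.0

isbits((x = [1.0], y = 2.0))   # false: a field holding an Array is not isbits
```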

Technically, one of these arguments may be a function.

f(args...) = g(args...)
g(f, args...) = f(args...)

I agree, but at that point we can either group the arguments in a struct and use the callable struct pattern, or use ; args... to pass around a named tuple instead of a plain tuple (with the small overhead of keyword arguments, and without specializing the function for each different group of arguments, which can be a positive in this case).
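A sketch of both options, with hypothetical names (CArgs, call_kw, scale!) and placeholder computations. In the callable struct, the arguments of the innermost function become fields, and calling the struct runs the computation; in the keyword variant, the arguments travel by name.

```julia
# Callable struct: the inner function's arguments are stored as fields.
struct CArgs{A,B}
    a1::A
    a2::B
end

function (c::CArgs)()
    c.a1 .= 2 .* c.a2   # placeholder computation
    return nothing
end

# The outer functions only see "something callable".
a!(c) = (b!(c); nothing)
b!(c) = (c(); nothing)

mem1 = zeros(2)
mem2 = [1.0, 2.0]
a!(CArgs(mem1, mem2))   # mem1 is now [2.0, 4.0]

# Keyword variant: forward the arguments by name with ; kwargs...
call_kw(cb; kwargs...) = cb(; kwargs...)
scale!(; out, inp, s = 3) = (out .= s .* inp; nothing)
call_kw(scale!; out = mem1, inp = mem2)   # mem1 is now [3.0, 6.0]
```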
