Could somebody please explain the use of Flux.functor? Functors return functions, but I do not understand the use case. I would also like to understand the use of @functor when using Flux.

I tried the following and it worked:

a = (x,y) -> x^2 + y
b(x) = a(x,3)
b(3) == a(3,3)

returns true, as it should. So what does Flux.functor buy me? Thanks.


I’m sure someone else can put it better, but it is basically a convenience for exploring and modifying models. For “normal” usage you don’t need to pay any attention to it.

Calling functor(x) returns 1) the values that x is comprised of (its "children") and 2) a function which creates a new x when given values of the same type as those returned in 1). For example, p, f = functor(x); y = f(p) makes y an identical copy of x.
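To make the idea concrete, here is a toy sketch of that destructure/reconstruct pattern. MyDense and myfunctor are made-up names for illustration, not Flux's actual implementation:

```julia
# A toy layer struct with two fields.
struct MyDense
    W
    b
end

# Analogue of functor(x): return the children as a NamedTuple,
# plus a function that rebuilds the layer from replacement children.
myfunctor(d::MyDense) = ((W = d.W, b = d.b), p -> MyDense(p.W, p.b))

d = MyDense([1.0 2.0; 3.0 4.0], [0.0, 0.0])
p, re = myfunctor(d)

d2 = re(p)                          # identical copy of d
d3 = re((W = 2 .* p.W, b = p.b))    # copy with modified weights
```

The point is that the reconstructor lets you build a new, fully functional layer after transforming the children in some way.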

As a motivating example: when you create a Flux model, e.g. model = Chain(Dense(2,3), SkipConnection(Chain(Dense(3,4), Dense(4,3)), +)), the parameters end up "buried" in nested structs, making them cumbersome to retrieve.

You do, however, need them, because they are what you want the AD to compute gradients with respect to when training the model. Flux exports a function called params which recursively goes through everything returned in 1), searches for numerical arrays, and returns them in a Params struct.
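A hedged sketch of that recursive collection, using toy types and a hand-written children function in place of what @functor generates (Flux's real params is more general than this):

```julia
# Toy nested model types, standing in for Flux layers.
struct MyDense; W; b; end
struct MyChain; layers; end

# What @functor effectively provides: the children of each node.
children(d::MyDense) = (d.W, d.b)
children(c::MyChain) = c.layers
children(x) = ()    # leaves with no children

# Recursively walk the struct tree and collect numerical arrays,
# mimicking what Flux.params does.
function collect_params!(out, x)
    if x isa AbstractArray{<:Number}
        push!(out, x)
    else
        foreach(c -> collect_params!(out, c), children(x))
    end
    return out
end

model = MyChain((MyDense(rand(3, 2), zeros(3)),
                 MyDense(rand(2, 3), zeros(2))))
ps = collect_params!(Any[], model)    # W1, b1, W2, b2
```

Note that the collected arrays are the same objects stored in the model (not copies), which is what lets the optimiser update the model in place.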

As for 2), it is used when moving a model to the GPU. By default, all Flux layers create normal Arrays, which live in RAM. If you want your model to run on the GPU, you need to give all those layers GPU arrays (e.g. CuArrays). Doing that by hand is also a bit of a pain, so Flux exports the gpu function, which recurses through functor, maps every numerical Array it finds to a CuArray, and builds new layers from them using 2) above.
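A sketch of that rebuild-with-transformed-arrays idea, again with toy types. Since this example has to run without a GPU, it converts arrays to Float32 as a CPU stand-in for the Array-to-CuArray mapping; myfmap is a made-up name for the recursion that gpu (via fmap) performs:

```julia
# Toy layer plus its functor-style (children, reconstructor) pair.
struct MyDense; W; b; end
myfunctor(d::MyDense) = ((W = d.W, b = d.b), p -> MyDense(p.W, p.b))

# Recursively map f over every numerical array in the struct tree,
# rebuilding each node with its reconstructor from 2).
function myfmap(f, x)
    x isa AbstractArray{<:Number} && return f(x)
    p, re = myfunctor(x)
    return re(map(c -> myfmap(f, c), p))
end

d = MyDense(rand(Float64, 3, 2), zeros(Float64, 3))
d32 = myfmap(x -> Float32.(x), d)    # new layer, arrays converted
```

With real Flux you would map cu (from CUDA.jl) instead of Float32 conversion, and get back a structurally identical model whose arrays live on the GPU.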

For both of these cases you could of course create all the needed layers individually and keep track of them in some other way, but that would make things much more cumbersome.

Hope this helps.