Writing complex Flux Models

Flux’s typical go-to when writing models is Chain, and for most single-input feedforward models (with the occasional skip connection) it works beautifully.

The problem is that when models are defined as functions, we lose the ability to collect their parameters directly or move them to the GPU. Even the example given in the Flux docs

function linear(in, out)
  W = randn(out, in)
  b = randn(out)
  x -> W * x .+ b
end

results in an empty parameter list

model = linear(3, 2)
params(model) #Params([])
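For comparison, a minimal sketch of the struct-based alternative: the arrays live in struct fields, and registering the type with `Flux.@functor` lets `params`, `gpu` and friends find them. It assumes a Flux 0.13/0.14-era release where `Flux.params` and `Flux.@functor` are both available; the name `MyLinear` is hypothetical.

```julia
using Flux  # assumes a Flux 0.13/0.14-era release (Flux.params, Flux.@functor)

# Hypothetical struct replacement for the closure-based `linear`:
# W and b are stored in fields, where Flux's functor traversal can reach them.
struct MyLinear{M<:AbstractMatrix, V<:AbstractVector}
    W::M
    b::V
end

# Convenience constructor mirroring linear(in, out)
MyLinear(in::Integer, out::Integer) = MyLinear(randn(Float32, out, in), randn(Float32, out))

# Same forward pass as the closure: W * x .+ b
(m::MyLinear)(x) = m.W * x .+ m.b

# Register the type so params, gpu, f32, ... recurse into W and b
Flux.@functor MyLinear

model = MyLinear(3, 2)
ps = Flux.params(model)  # now holds W and b instead of an empty Params
```

The cost is the one raised later in the thread: changing the fields of `MyLinear` means redefining the struct.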

Would it be possible to extend the Flux.@functor machinery so that models defined in a format like this, for instance

function custommodel() # defines model parameters
    A = Dense(20, 40)
    B = Dense(40, 60)
    C = Dense(60, 20)
    return function(L, M, N, O) # defines the forward pass
        H = A(L)
        H = H * M .+ B(N) * O
        H = C(H) + L
    end
end

(or any other forward pass, with any number of inputs) can be written while still having methods like params(), gpu() and so on available? The output of custommodel() is effectively an anonymous type as well, for instance,

mymodel = custommodel()
mymodel.B # Dense(40, 60)

Since Flux.@functor already works on user-defined types to extend params() and gpu() to them, would it be possible to extend it to these anonymous types as well? It would be much cleaner to write more complex models if that were the case.

So why not just define your custommodel as a struct here?
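For reference, a sketch of what that struct version might look like, assuming a Flux 0.13/0.14-era API (`Dense(in => out)`, `Flux.@functor`, `Flux.params`). The name `CustomModel` is hypothetical, and the forward pass is simplified so the shapes line up: the original `H * M .+ B(N) * O` appears to combine 40-row and 60-row arrays before feeding a 60-input layer.

```julia
using Flux  # assumes a Flux 0.13/0.14-era release

# Hypothetical struct version of `custommodel`: the layers become fields.
struct CustomModel{TA, TB, TC}
    A::TA
    B::TB
    C::TC
end

CustomModel() = CustomModel(Dense(20 => 40), Dense(40 => 60), Dense(60 => 20))

# Multi-input forward pass (simplified so the dimensions are consistent).
function (m::CustomModel)(L, M, N, O)
    H = m.A(L)                 # 40 × batch
    H = m.B(H .* M .+ N) .* O  # 60 × batch
    return m.C(H) .+ L         # 20 × batch, skip connection from L
end

# Lets params, gpu, etc. recurse into A, B and C.
Flux.@functor CustomModel

model = CustomModel()
batch = 4
L = rand(Float32, 20, batch)
M = rand(Float32, 40, batch)
N = rand(Float32, 40, batch)
O = rand(Float32, 60, batch)
out = model(L, M, N, O)
```

Any number of inputs works here because the forward pass is just a method on the struct.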

It’s definitely the no-brainer option; the thing is, structs can’t be redefined the way Chains can, which makes iterating on a model, especially when changing the network structure, somewhat tiresome. Admittedly it’s hardly a breaking issue, but it would be a nice quality-of-life improvement.

I think I’ve seen somewhere that closures being anonymous structs is an implementation detail that shouldn’t be relied on. Of course, if it is a single-purpose project, then by all means go ahead: implement that functor there and use it while it works.

# Too lazy to load Flux, but I think something like this might work, except that closurefunctor would be Flux.functor
julia> function test(a, b)
       return function(x)
          return a .+ b .* x
       end
       end
test (generic function with 2 methods)

julia> tt = test(1, [2,3]);

julia> fieldnames(typeof(tt))
(:a, :b)

julia> closurefunctor(x) = map(fn -> getfield(x, fn), fieldnames(typeof(x))), test
closurefunctor (generic function with 1 method)

julia> p,re = closurefunctor(tt);

julia> p
(1, [2, 3])

julia> re(4, [1,2,3])
#21 (generic function with 1 method)

julia> ttt = re(4, [1,2,3])
#21 (generic function with 1 method)

julia> ttt(3)
3-element Array{Int64,1}:
  7
 10
 13
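To see why a (children, rebuild) pair is enough for things like gpu(), here is a plain-Julia sketch of the pattern Flux builds on: transform every child, then rebuild the model from the transformed children. `mapchildren` is a hypothetical helper (a much simplified, non-recursive stand-in for what Functors' `fmap` does), Float32 conversion stands in for moving arrays to the GPU, and it still relies on closure fields matching the captured variables, which is an implementation detail.

```julia
# Closure-returning "model builder", as in the REPL session above.
function test(a, b)
    return function(x)
        return a .+ b .* x
    end
end

# Functor-style split: captured variables (the closure's fields) plus a rebuild function.
closurefunctor(f, rebuild) = (map(n -> getfield(f, n), fieldnames(typeof(f))), rebuild)

# Hypothetical helper: apply g to every child, then rebuild the closure.
function mapchildren(g, f, rebuild)
    children, re = closurefunctor(f, rebuild)
    return re(map(g, children)...)
end

tt = test(1, [2, 3])
tt32 = mapchildren(x -> Float32.(x), tt, test)  # every captured value converted to Float32
tt32(2)
```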

Shameless plug: this package supports a static computation graph format designed to be created and manipulated programmatically: NaiveNASflux.jl (GitHub: DrChainsaw/NaiveNASflux.jl, “Your local Flux surgeon”)

Doesn’t putting the model struct into a module solve that problem? I think you can redefine it then, by re-evaluating the whole module.
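A minimal plain-Julia sketch of the module suggestion; the module name `Models` and struct `Net` are hypothetical. Changing a top-level struct’s fields was an “invalid redefinition of constant” error in the Julia versions current when this thread was written, but re-evaluating the enclosing `module` block at top level replaces the whole module (Julia prints a “replacing module” warning).

```julia
# First version of the model, wrapped in a hypothetical module `Models`.
module Models
struct Net
    w::Vector{Float64}
end
(n::Net)(x) = n.w .* x
end

net1 = Models.Net([1.0, 2.0])

# Changed structure: add a bias field by re-evaluating the whole module.
module Models
struct Net
    w::Vector{Float64}
    b::Vector{Float64}
end
(n::Net)(x) = n.w .* x .+ n.b
end

net2 = Models.Net([1.0, 2.0], [0.5, 0.5])
net2([3.0, 4.0])
```

Note that `net1` keeps the old type, so any model built before the redefinition has to be rebuilt.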