Flux’s typical go-to when writing models is `Chain`, and for most feedforward models with a single input (plus the occasional skip connection), it works quite beautifully.
The problem is that when we define models as functions (closures), we lose the ability to get their parameters directly or move them to the GPU. Even the example given in the Flux docs
```julia
function linear(in, out)
    W = randn(out, in)
    b = randn(out)
    x -> W * x .+ b
end
```
results in an empty parameter list:
```julia
using Flux

model = linear(3, 2)
Flux.params(model)  # Params([])
```
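For what it’s worth, the arrays are not actually lost. As far as I understand, an anonymous function in Julia is an instance of a compiler-generated struct whose fields are the captured variables (`W` and `b` here); Flux just has no `functor` definition telling it to look inside that type. A quick check with plain Base Julia (no Flux needed), redefining `linear` as above:

```julia
# Same closure-based layer as the Flux docs example, Base Julia only.
function linear(in, out)
    W = randn(out, in)
    b = randn(out)
    x -> W * x .+ b
end

model = linear(3, 2)

# The closure's type is compiler-generated (its printed name varies)...
println(typeof(model))
# ...and its captured variables are ordinary fields, :W and :b.
println(fieldnames(typeof(model)))
println(size(model.W))            # (2, 3)
println(model([1.0, 2.0, 3.0]))   # a 2-element Vector{Float64}
```

So the information `params()` and `gpu()` need is all there; the missing piece is only a way for Flux to traverse these anonymous types.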
Would it be possible to extend the `Flux.@functor` machinery so that models defined in a format like this, for instance
```julia
using Flux

function custommodel()  # defines the model parameters
    A = Dense(20, 40)
    B = Dense(40, 60)
    C = Dense(60, 20)
    return function (L, M, N, O)  # defines the forward pass
        H = A(L)                  # 40×k
        H = M * H .+ B(N) * O     # M is 60×40, so H is now 60×k
        H = C(H) + L              # 20×k, same shape as L
    end
end
```
or with any other forward pass and any number of inputs, can be written while still having methods like `params()`, `gpu()`, and so on available? The output of `custommodel()` is effectively an instance of an anonymous type as well, for instance:
```julia
mymodel = custommodel()
mymodel.B  # Dense(40, 60)
```
Since `Flux.@functor` already works on user-defined types to extend `params()` and `gpu()` to them, would it be possible to extend it to these anonymous types as well? It would be much cleaner to write more complex models this way.
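To make the question concrete, here is a rough sketch of the kind of traversal I have in mind, in plain Base Julia: walk the fields of any value, closures included, and collect the numeric arrays reachable from it. This is not how Flux does it internally (Flux goes through `Functors.functor`). `Affine` is a hypothetical stand-in for `Dense` and `collect_arrays` is a hypothetical helper; neither is Flux API. The forward pass uses `M * H` so that the shapes line up.

```julia
# Hypothetical stand-in for Flux.Dense, so the sketch needs only Base Julia.
struct Affine
    W::Matrix{Float64}
    b::Vector{Float64}
end
Affine(nin::Integer, nout::Integer) = Affine(randn(nout, nin), randn(nout))
(a::Affine)(x) = a.W * x .+ a.b

# Same structure as custommodel() above.
# Shapes: L is 20×k, M is 60×40, N is 40×j, O is j×k.
function custommodel()
    A = Affine(20, 40)
    B = Affine(40, 60)
    C = Affine(60, 20)
    return function (L, M, N, O)
        H = A(L)                  # 40×k
        H = M * H .+ B(N) * O     # 60×k
        return C(H) + L           # 20×k
    end
end

# Hypothetical helper: recursively collect numeric arrays reachable from x.
# A closure is a plain struct, so its captured variables are visited as fields.
function collect_arrays(x, acc = AbstractArray[])
    if x isa AbstractArray{<:Number}
        push!(acc, x)
    else
        for name in fieldnames(typeof(x))
            collect_arrays(getfield(x, name), acc)
        end
    end
    return acc
end

mymodel = custommodel()
ps = collect_arrays(mymodel)
println(length(ps))   # 6: one weight matrix and one bias vector per layer

y = mymodel(randn(20, 5), randn(60, 40), randn(40, 3), randn(3, 5))
println(size(y))      # (20, 5)
```

A real version would also need to rebuild the closure with transformed fields (for `gpu()`), which is the part `@functor` provides for named structs.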