I am hoping to implement a custom “layer” soon.
- I think the requirements for this are:
  a) it accepts input from the previous layer and produces outputs,
  b) `Flux.params()` can find its parameters,
  c) `gpu()` can find its parameters to transfer to the GPU.
  Is this sort-of correct?
- How do I cause b) and c) to happen?
This flux-custom-model-feedback-of-the-output-layer-to-the-input-layer discussion indicates that overriding `Flux.trainable(mytype)` and applying `Flux.@functor mytype` are involved.
However, what is the difference between `Flux.trainable` and `Flux.@functor`? The documentation seems to suggest that these have the same purpose:

> The first way of achieving this is through overloading the `trainable` function.
>
> Another way of achieving this is through the `@functor` macro directly.
Actually, I think they should be different. A model might need some extra arrays that are not directly parameters for AD, and `gpu()` should still move these extra arrays to the GPU. (Just a guess; a rough sketch of the kind of layer I have in mind is below.)
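Something like this hypothetical layer is what I mean (the names are made up, and I don't yet know how to hook it into Flux):

```julia
# Hypothetical layer: W and b should be trained, but `mask` is a fixed
# array that the forward pass needs. gpu() should still move `mask`,
# even though it is not a parameter for AD.
struct MaskedDense
    W
    b
    mask   # fixed, not trained
end

(m::MaskedDense)(x) = (m.mask .* m.W) * x .+ m.b
```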
I am still trying to figure this out too, but here is what I understand (perhaps incorrectly) so far.
The `trainable` function takes a layer and returns its trainable parameters (`Flux.params` uses `trainable` to collect params from all layers in a model). If you implement a new `trainable` method for your layer, Flux will use that to sort out how to differentiate and update it.
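For your `MaskedDense` sketch above, I think (untested) that would look something like this:

```julia
using Flux

# Return only the fields that should be trained: Flux.params will then
# collect W and b but skip mask, so mask never receives gradient updates.
Flux.trainable(m::MaskedDense) = (m.W, m.b)
```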
The other option is to make your layer a callable struct and apply the `@functor` macro. The macro registers your type in a standard shape used inside Flux, so that the already implemented `trainable` method (and things like `gpu`) work on it. That shape is a tuple of (the struct's fields, a function that rebuilds the struct from those fields).
I have only used the second option, so I’m not too sure what else might be needed to get the first to work.
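A minimal sketch of the second option applied to your layer (untested; the values passed to the constructor are just for illustration):

```julia
using Flux

# Register the type with Flux/Functors so that params(), gpu(), cpu(), etc.
# can walk its fields. All three fields (W, b, mask) become children,
# which is what lets gpu() move mask even though it is not trained.
Flux.@functor MaskedDense

m = MaskedDense(randn(Float32, 5, 10), zeros(Float32, 5), Float32.(rand(Bool, 5, 10)))
Flux.params(m)   # W and b only, if the trainable method above is also defined
```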
The `cpu` and `gpu` functions are wonderfully sneaky: they just walk the fields of your model (through the same functor machinery), find any arrays, and convert them to `Array` or `CuArray`. So it doesn't matter whether something is a trainable parameter or not; if your layer stores it, it will be copied and will use the correct method for the device you asked for.
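E.g., continuing the sketch above (this assumes CUDA.jl is installed and a GPU is present; otherwise `gpu` is a no-op and everything stays as plain Arrays):

```julia
using Flux
# using CUDA   # needed for the arrays to actually become CuArrays

m = MaskedDense(randn(Float32, 5, 10), zeros(Float32, 5), Float32.(rand(Bool, 5, 10)))

m_gpu = m |> gpu               # W, b and mask are all converted for the device
x = rand(Float32, 10) |> gpu
y = m_gpu(x)                   # forward pass runs on the GPU if one is available

m_cpu = m_gpu |> cpu           # convert everything back to plain Arrays
```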
The most helpful resource for sorting this out is the source.
This is helpful, thank you.
I wish the docs covered this.