Custom layer in Lux

Hi,
I’m trying to implement a custom layer in Lux that is similar to a linear layer but needs to apply a user-defined mask to the weights. It needs to compute

h(x) = g(b + (W ⊙ M)(x))

where g is an activation function, b the bias vector, W the weight matrix, M the user-defined mask, and ⊙ the elementwise (Hadamard) product.
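For reference, here is a plain-Julia sketch of the formula itself (no Lux involved; the function name masked_dense is hypothetical). The mask is applied with a broadcasted elementwise product, followed by an ordinary matrix multiply:

```julia
# Hypothetical helper illustrating h(x) = g(b + (W ⊙ M) x).
# W and M are out×in, b has length out, and each column of x is one input.
function masked_dense(g, W::AbstractMatrix, b::AbstractVector,
                      M::AbstractMatrix, x::AbstractVecOrMat)
    size(W) == size(M) || throw(DimensionMismatch("mask must have the same size as W"))
    # W .* M is the elementwise (Hadamard) product; entries where M is 0 are cut.
    return g.(b .+ (W .* M) * x)
end

W = [1.0 2.0 3.0;
     4.0 5.0 6.0]
M = [1.0 0.0 1.0;
     0.0 1.0 0.0]
b = [0.5, -0.5]
x = [1.0, 1.0, 1.0]
h = masked_dense(identity, W, b, M, x)   # [4.5, 4.5]
```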

Some things that are unclear to me after reading the documentation:

  1. The weight matrix in my layer is user-definable and may need to be updated. Should the matrix be defined in the struct definition of the layer or in the parameters?
  2. My layer uses explicit function definitions, as in the tutorial linked above, but there is also the @compact macro. What exactly does this macro do?
  3. The definition of Dense calls Lux._vec and Lux._getproperty. How are these different from vec and getproperty?
  4. What is the rationale for using F1 and F2 as type parameters in the tutorial?
```julia
struct Linear{F1, F2} <: LuxCore.AbstractExplicitLayer
    in_dims::Int
    out_dims::Int
    init_weight::F1
    init_bias::F2
end
```
  1. The weight matrix in my layer is user-definable and may need to be updated. Should the matrix be defined in the struct definition of the layer or in the parameters?

In the parameters. Model structs should never contain mutable elements. See Migrating from Flux to Lux | Lux.jl Docs.
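A minimal sketch of how this applies to the masked layer, assuming the LuxCore interface from the tutorial (initialparameters, initialstates, and a call that returns the output together with the state). The layer name MaskedDense and its init_mask field are hypothetical. Putting the fixed mask in the state rather than the parameters is one option that keeps it out of training while keeping the struct free of arrays. The supertype line also assumes that Lux 1.x renamed AbstractExplicitLayer to AbstractLuxLayer.

```julia
using LuxCore, Random

# The tutorial's supertype was renamed LuxCore.AbstractLuxLayer in Lux 1.x;
# use whichever name this LuxCore version defines.
const LayerSupertype = isdefined(LuxCore, :AbstractLuxLayer) ?
    LuxCore.AbstractLuxLayer : LuxCore.AbstractExplicitLayer

# Hypothetical layer computing h(x) = g(b + (W ⊙ M) x). The struct only holds
# immutable configuration (sizes and functions), never the arrays themselves.
struct MaskedDense{A, F1, F2, F3} <: LayerSupertype
    in_dims::Int
    out_dims::Int
    activation::A
    init_weight::F1
    init_bias::F2
    init_mask::F3   # (out_dims, in_dims) -> user-defined mask M
end

# W and b are trainable, so they go in the parameters.
function LuxCore.initialparameters(rng::AbstractRNG, l::MaskedDense)
    return (weight = l.init_weight(rng, l.out_dims, l.in_dims),
            bias = l.init_bias(rng, l.out_dims))
end

# M is fixed, so it goes in the (non-trainable) state instead.
LuxCore.initialstates(::AbstractRNG, l::MaskedDense) =
    (mask = l.init_mask(l.out_dims, l.in_dims),)

# Forward pass: return the output and the (unchanged) state.
function (l::MaskedDense)(x::AbstractVecOrMat, ps, st::NamedTuple)
    y = l.activation.(ps.bias .+ (ps.weight .* st.mask) * x)
    return y, st
end

# Usage: 3 inputs, 2 outputs, and a fixed 2×3 (out_dims × in_dims) mask.
rng = Random.Xoshiro(0)
layer = MaskedDense(3, 2, tanh,
                    (rng, m, n) -> randn(rng, Float32, m, n),
                    (rng, m) -> zeros(Float32, m),
                    (m, n) -> Float32[1 0 1; 0 1 0])
ps, st = LuxCore.setup(rng, layer)
x = randn(rng, Float32, 3, 5)
y, st = layer(x, ps, st)   # y is 2×5
```

Because the mask lives in st, an optimizer that updates only ps leaves it untouched, and a different mask can be swapped in by passing a new state.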

  2. My layer uses explicit function definitions, as in the tutorial linked above, but there is also the @compact macro. What exactly does this macro do?

See Utilities | Lux.jl Docs for a detailed description. It essentially writes all the boilerplate code needed for state handling and for defining initialparameters / initialstates automatically. See this tutorial for an example showcasing both kinds of layers.

  3. The definition of Dense calls Lux._vec and Lux._getproperty. How are these different from vec and getproperty?

Mostly an implementation detail. _getproperty takes the field name wrapped in a Val as input and returns nothing if no such field is present, instead of throwing an error. _vec also accepts nothing as input (and returns nothing).
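To illustrate the described behaviour (this is not Lux's actual source), here is a plain-Julia sketch with hypothetical helper names: a Val-keyed lookup that yields nothing for a missing field, and a vec that passes nothing through. Together they let a Dense-style forward pass treat the bias as optional:

```julia
# Hypothetical stand-ins mimicking the described behaviour of Lux._getproperty
# and Lux._vec; these are not Lux's actual implementations.
getproperty_or_nothing(nt::NamedTuple, ::Val{k}) where {k} =
    hasproperty(nt, k) ? getproperty(nt, k) : nothing

vec_or_nothing(x::AbstractArray) = vec(x)
vec_or_nothing(::Nothing) = nothing

# Dense-style forward pass where the bias may be missing from the parameters.
function affine(ps::NamedTuple, x::AbstractVector)
    b = vec_or_nothing(getproperty_or_nothing(ps, Val(:bias)))
    y = ps.weight * x
    return b === nothing ? y : y .+ b
end

ps_bias = (weight = [1.0 2.0; 3.0 4.0], bias = reshape([10.0, 20.0], 2, 1))
ps_nobias = (weight = [1.0 2.0; 3.0 4.0],)
affine(ps_bias, [1.0, 1.0])     # [13.0, 27.0]
affine(ps_nobias, [1.0, 1.0])   # [3.0, 7.0]
```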

  4. What is the rationale for using F1 and F2 as type parameters in the tutorial?

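Just to specialize on the functions: because F1 and F2 are type parameters, the concrete types of init_weight and init_bias become part of the layer's type, so the compiler can generate specialized code for them instead of treating them as an abstract Function.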
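A small plain-Julia sketch of what that buys (the struct and function names are hypothetical): when a field is declared with the abstract type Function, the compiler cannot infer what calling it returns, whereas a type parameter like F1 or F2 lets it infer the result exactly.

```julia
# A field with an abstract type: the compiler only knows it holds *some* Function.
struct AbstractField
    f::Function
end

# A type parameter: the concrete function type becomes part of the struct's type,
# mirroring F1 and F2 in the tutorial's Linear layer.
struct Parametric{F}
    f::F
end

apply(layer, x) = layer.f(x)

# Inferred return type of apply(layer, ::Float64) for each layout.
only(Base.return_types(apply, (AbstractField, Float64)))            # Any
only(Base.return_types(apply, (Parametric{typeof(sin)}, Float64)))  # Float64
```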
