Custom Layer in Lux.jl

I am looking at the documentation on custom layers and do not understand the line (l::LuxLinear)(x, ps, st) = ... More specifically, I do not see where the variable l::LuxLinear is used further down when the model is run.

using Lux, Random, NNlib, Zygote

struct LuxLinear <: Lux.AbstractExplicitLayer
    init_A
    init_B
end

function LuxLinear(A::AbstractArray, B::AbstractArray)
    # Storing Arrays or any mutable structure inside a Lux Layer is not recommended
    # instead we will convert this to a function to perform lazy initialization
    return LuxLinear(() -> copy(A), () -> copy(B))
end

# `B` is a parameter
Lux.initialparameters(rng::AbstractRNG, layer::LuxLinear) = (B=layer.init_B(),)

# `A` is a state
Lux.initialstates(rng::AbstractRNG, layer::LuxLinear) = (A=layer.init_A(),)

(l::LuxLinear)(x, ps, st) = st.A * ps.B * x, st

Where does l::LuxLinear go in the following?

rng = Random.default_rng()
model = LuxLinear(randn(rng, 2, 4), randn(rng, 4, 2))
x = randn(rng, 2, 1)
ps, st = Lux.setup(rng, model)
model(x, ps, st)
gradient(ps -> sum(first(model(x, ps, st))), ps)

Thanks! Gordon.

The syntax (x::Type)(y) = ... defines a method that makes objects of type Type callable: if a is a variable of type Type and you write a(b), this method is called with a bound to x and b bound to y. The Julia manual calls these function-like objects; they are also often called functors.
You could get the same effect by defining somename(x::Type, y) and calling it as somename(a, b). The callable syntax just skips the extra function name and makes instances of the type callable directly.
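
To see this outside of Lux, here is a minimal sketch in plain Julia. The type Scaler and the function scale are hypothetical names made up for illustration; they are not part of Lux or of the original example.

```julia
# Hypothetical type for illustration only; not part of Lux.
struct Scaler
    factor::Float64
end

# Make instances of Scaler callable.
# Inside the method, `s` is the instance that was called.
(s::Scaler)(y) = s.factor * y

# The same behaviour written with an explicitly named function.
scale(s::Scaler, y) = s.factor * y

double = Scaler(2.0)

double(3.0)         # calls (s::Scaler)(y) with s = double and y = 3.0
scale(double, 3.0)  # the named-function equivalent, same result
```

Here double plays the role that model plays in the Lux code: it is the object being called, and Julia binds it to the first name in the method signature.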

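In your example, the variable model is of type LuxLinear, so when you write model(x, ps, st) you call the method (l::LuxLinear)(x, ps, st) with model bound to l. The body of that method happens not to use l, because A and B arrive through st and ps, but l::LuxLinear is still needed in the signature so that Julia dispatches on the LuxLinear type.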
