I am looking at the documentation on custom layers and do not understand the line (l::LuxLinear)(x, ps, st) = ... in the code below. More specifically, I do not see where the variable l::LuxLinear is used further down when the model is run.
using Lux, Random, NNlib, Zygote
struct LuxLinear <: Lux.AbstractExplicitLayer
    init_A
    init_B
end
function LuxLinear(A::AbstractArray, B::AbstractArray)
    # Storing Arrays or any mutable structure inside a Lux Layer is not recommended
    # instead we will convert this to a function to perform lazy initialization
    return LuxLinear(() -> copy(A), () -> copy(B))
end
# `B` is a parameter
Lux.initialparameters(rng::AbstractRNG, layer::LuxLinear) = (B=layer.init_B(),)
# `A` is a state
Lux.initialstates(rng::AbstractRNG, layer::LuxLinear) = (A=layer.init_A(),)
(l::LuxLinear)(x, ps, st) = st.A * ps.B * x, st
Where did l::LuxLinear disappear in the following:
rng = Random.default_rng()
model = LuxLinear(randn(rng, 2, 4), randn(rng, 4, 2))
x = randn(rng, 2, 1)
ps, st = Lux.setup(rng, model)
model(x, ps, st)
gradient(ps -> sum(first(model(x, ps, st))), ps)
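For comparison, here is a minimal sketch in plain Julia (no Lux) of the callable-struct syntax that the line (l::LuxLinear)(x, ps, st) = ... appears to use. The types Scale and Shift are hypothetical and exist only for this illustration. The idea it tries to show: the name before :: is bound to the instance being called, and the method body is free to ignore it.

```julia
# Hypothetical types for illustration only; they are not part of Lux.
struct Scale
    factor::Float64
end

# Makes every `Scale` instance callable.
# Inside the body, `s` refers to the instance that was called.
(s::Scale)(x) = s.factor * x

# A call method whose body never touches the instance itself,
# similar in shape to the LuxLinear method in the question.
struct Shift end
(l::Shift)(x, ps, st) = (x .+ ps.b, st)

double = Scale(2.0)
y = double(3.0)   # here `s` is bound to `double`

shift = Shift()
out, st_out = shift([1.0, 2.0], (b = 1.0,), NamedTuple())   # `l` is bound to `shift`
```

If this reading is right, model(x, ps, st) dispatches to the (l::LuxLinear) method with l bound to model, and l does not show up in the body only because the arrays are read from ps and st rather than from the layer.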
Thanks! Gordon.