How to use Flux.train! to train a custom layer?

Hi

I am trying to build a toy custom layer following the instructions in the Building Layers section of the Flux docs (Basics · Flux), and to train it with Flux.train!. However, the loss does not decrease and the parameters do not seem to change. Here are my code and output.

using Flux
using Base.Iterators: repeated

function Linear(in, out)
    W = param(randn(out, in))
    x -> W * x
end

model = Linear(10, 1)
loss(x, y) = Flux.mse(model(x), y)
opt = ADAM()
dataset = repeated((train_data, target), 10)
evalcb = () -> @show(loss(train_data, target))
println(params(model))
Flux.train!(loss, params(model), dataset, opt, cb = evalcb)

The output is

Params([])
loss(train_data, target) = 39.910205426070895 (tracked)
loss(train_data, target) = 39.910205426070895 (tracked)
loss(train_data, target) = 39.910205426070895 (tracked)
loss(train_data, target) = 39.910205426070895 (tracked)
loss(train_data, target) = 39.910205426070895 (tracked)
loss(train_data, target) = 39.910205426070895 (tracked)
loss(train_data, target) = 39.910205426070895 (tracked)
loss(train_data, target) = 39.910205426070895 (tracked)
loss(train_data, target) = 39.910205426070895 (tracked)
loss(train_data, target) = 39.910205426070895 (tracked)

How can I use the train! method to train a model with a custom layer? Thank you very much!

This is the problem: params(model) prints as Params([]). Since it doesn't contain any parameters, none get updated.

You can access the parameter as Linear(2,2).W, because W is captured by the closure and stored as a field of the returned function. So I think this works:

Flux.train!(loss, Params([model.W]), dataset, opt, cb=evalcb)
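
For example, something like this (a quick sketch using the closure-based Linear from the original post; the captured W shows up as a field of the returned function):

m = Linear(2, 2)   # m is the closure x -> W * x
m.W                # the captured weight matrix (a TrackedArray, since it came from param)
Params([m.W])      # collect it explicitly into a Params object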

But if you look at what Flux does, it usually defines a struct to hold the parameters and then makes that struct callable. @treelike then adds the struct to the list of things that params (and some other functions) understand.
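
A sketch of that pattern (following the Affine example in the Flux basics docs, using the same Tracker-era param API as above; run it in a fresh session, since Linear is currently bound to a function):

struct Linear
    W
    b
end

# construct the layer with tracked, randomly initialised parameters
Linear(in::Integer, out::Integer) = Linear(param(randn(out, in)), param(zeros(out)))

# make the struct callable, so model(x) works as before
(m::Linear)(x) = m.W * x .+ m.b

# register the fields so params(model) (and gpu, etc.) can find them
Flux.@treelike Linear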

The suggested

Flux.train!(loss, Params([model.W]), dataset, opt, cb = evalcb)

does not work for me; the parameter list still comes out blank, so it seems I cannot access the attribute captured inside the function. But maintaining a struct to hold the parameters works. Now train! updates the parameters and the loss decreases. Thank you very much!
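
For completeness, with the struct version of Linear above, the training code from the original post can stay essentially the same; the difference is that params(model) now actually contains W and b (a sketch, reusing dataset and evalcb from before):

model = Linear(10, 1)
loss(x, y) = Flux.mse(model(x), y)
opt = ADAM()
println(params(model))   # now prints the tracked W and b instead of Params([])
Flux.train!(loss, params(model), dataset, opt, cb = evalcb)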