I’m new to Flux, so apologies if this is a dumb question or has already been answered (I could not figure it out from a search here or from the model zoo).
I am trying to train an LSTM on a sequence of choices where more than one thing can be chosen per time period. Let’s say we have 3 periods and 5 possible choices. The “multi-hot” representation would look like this:
X₁ = [[0;1;0;0;1], [1;0;0;0;0], [0;0;1;1;0]]
where the choices at time t are X₁[t].
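In case it helps, here is one way to build these multi-hot vectors from the raw choice indices (the indices below are made up just to reproduce the toy example):

choices = [[2, 5], [1], [3, 4]]                      # indices chosen in periods 1, 2, 3
X₁ = [Float32.(in.(1:5, Ref(c))) for c in choices]   # Float32 to match Flux's default parameter eltype
X₁[1] == [0, 1, 0, 0, 1]                             # true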
I then create the labels by shifting the sequence one period forward, so that the target at time t is the choice vector of period t+1:
y₁ = X₁[2:end]      # targets: the next period's choices
X₁ = X₁[1:end-1]    # inputs: drop the last period, which has nothing to predict
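Just to double-check the alignment after the shift:

length(X₁) == length(y₁) == 2   # true
y₁[1] == [1, 0, 0, 0, 0]        # true: the period-2 choices, i.e. the target for X₁[1]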
Then I set up and train the following model:
using Flux
using Flux.Losses: crossentropy   # on older Flux versions this is available as Flux.crossentropy

data = Flux.Data.DataLoader((X₁, y₁))
m = Chain(LSTM(5, 5), softmax)

# the Chain already ends in softmax, so crossentropy (not logitcrossentropy) is the matching loss
loss(x, y) = sum(crossentropy.(m.(x), y))
loss(X₁, y₁)   # sanity check on the full sequence

opt = ADAM()
ps = Flux.params(m)
for _ in 1:10000
    Flux.train!(loss, ps, data, opt)
end
Which seems to work:
julia> reset!(m)
julia> m(X₁[2])
5-element Array{Float32,1}:
0.07553877
0.19324568
0.31979644
0.3028726
0.10854646
My questions are:
a) Does this seem correct?
b) When does the hidden state have to be reset (e.g. between sequences, between epochs, or inside the loss)?
c) How can I extend this example to batch training? (A rough sketch of what I have in mind is below.)
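To make (c) concrete, this is roughly the batched layout I think Flux expects, if I understand the recurrent-layer docs correctly: one features × batchsize matrix per timestep, with each column being one sequence in the batch. The second sequence X₂ here is made up purely for illustration:

X₂ = [Float32.([1, 0, 0, 1, 0]), Float32.([0, 0, 0, 0, 1]), Float32.([0, 1, 1, 0, 0])]
y₂ = X₂[2:end]
X₂ = X₂[1:end-1]

# stack the two sequences column-wise: one 5×2 matrix per timestep
Xbatch = [hcat(X₁[t], X₂[t]) for t in 1:length(X₁)]
ybatch = [hcat(y₁[t], y₂[t]) for t in 1:length(y₁)]

loss(Xbatch, ybatch)   # does the same loss work on a batch like this?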
If this is more easily done in Knet, I will try that.
Thanks! Any help is appreciated!