Thanks very much for your reply.
I am still very puzzled after reading the docs again. I'm working on a simplified example, but I cannot make it work. Maybe you can help me get the simple case working, and then I can figure out my actual use case?
```julia
using Flux

# Suppose I have four sentences, each 20 words long,
# with an embedding dimension of 10 per word
a1 = rand(Float32, 10, 20)
a2 = rand(Float32, 10, 20)
a3 = rand(Float32, 10, 20)
a4 = rand(Float32, 10, 20)
fake_train = cat(a1, a2, a3, a4; dims = 3)   # 10×20×4

fake_labels = Flux.onehotbatch([0, 1, 0, 1], [0, 1])   # 2×4
# First question: is this step necessary to give the labels a shape similar to fake_train?
fake_labels = reshape(fake_labels, (2, 1, 4))

DL = Flux.DataLoader((data = fake_train, label = fake_labels), batchsize = 2)
```
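For what it's worth, here is what I believe the loader yields per iteration (the sizes in the comments are my expectation, not verified output):

```julia
# Peek at one batch: the DataLoader slices along the last dimension
# and keeps the NamedTuple structure.
batch = first(DL)
typeof(batch)      # NamedTuple{(:data, :label), ...}
size(batch.data)   # (10, 20, 2) -- two sentences per batch
size(batch.label)  # (2, 1, 2)
```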
```julia
l1 = LSTM(10, 5)   # 10-dim word embeddings in, 5-dim hidden state out
d1 = Dense(5, 2)   # decode the final hidden state to 2 logits

function model(x, scanner, encoder)
    state = scanner(x)[:, end]   # keep only the last column of the LSTM output
    Flux.reset!(scanner)
    encoder(state)
end
```
```julia
ps = params(l1, d1)
opt = ADAM(1e-3)
loss(x, y) = Flux.logitbinarycrossentropy(model(x, l1, d1), y)

# Returns a number, not an error, so the loss function is working, I believe:
loss(fake_train[:, :, 1], fake_labels[:, :, 1])
```
```julia
for i = 1:3
    @info i
    Flux.train!(loss, ps, DL, opt)
end
```
```
ERROR: LoadError: MethodError: no method matching loss(::NamedTuple{(:data, :label), Tuple{Array{Float32, 3}, Array{Bool, 3}}})
```
This error makes sense given the way I defined `loss(x, y)`: the DataLoader yields each batch as a single NamedTuple, so `train!` calls `loss` with one argument, not two.
The docs say: "If `d` is a tuple of arguments to `loss` call `loss(d...)`, else call `loss(d)`." I tried to follow that advice literally (i.e., `train!(x -> loss(x...), ...)`), but that also generated an error:
```
ERROR: LoadError: BoundsError: attempt to access 5×20×2 Array{Float32, 3} at index [1:5, 20]
```
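If I read the splat correctly, `loss(d...)` becomes `loss(d.data, d.label)`, so `model` now receives the whole 10×20×2 batch, and `scanner(x)` comes back three-dimensional, which seems consistent with the 5×20×2 array in the BoundsError. A sketch of my reading (not verified):

```julia
# My understanding of what loss receives after splatting the NamedTuple:
x = first(DL).data    # 10×20×2 -- a whole batch, not a single sentence
y = first(DL).label   # 2×1×2
# Inside model, scanner(x) is then 5×20×2, and scanner(x)[:, end]
# is presumably where the indexing goes wrong.
```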
I have also tried defining `loss(x, y) = logitbinarycrossentropy.(...` (i.e., broadcasting) and `loss((x, y)) = ...` (i.e., defining the loss over a tuple directly), but these generate errors of their own.
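For completeness, the tuple version was along these lines (reconstructed from memory, so the exact body is a guess):

```julia
# Destructure the (data, label) pair directly in the signature:
loss((x, y)) = Flux.logitbinarycrossentropy(model(x, l1, d1), y)
```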
Maybe I am a little closer to an MWE now…?