What is zip(data) doing to the data, and what type does Flux.train! expect for the data argument?
I am confused because data is an array of 60 batches, and zip has only the one argument, data. If I pass data directly, it stalls and does nothing. I suspect it has to do with the fact that zip(data) is an iterator, and maybe that is what Flux.train! needs for efficiency, but since I can't find documentation for train!, I have no clue.
Maybe here's another way to think about it (or at least, here's how I think about it):
The data argument to train! just has to be an iterable of tuples that are splatted to loss. train! pretty much just does this:
for datapoint in data
    loss(datapoint...)
end
(Of course, it's actually slightly fancier than that, since it takes the gradient and updates the parameters and all that stuff, but the training loop itself is quite straightforward, I think; the source is here, in case you'd like to take a look!)
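For a rough picture of the fancier part, here is approximately what the implicit-parameters train! does, written out by hand (just a sketch, not the actual source; the tiny Dense model and random data are placeholders):

using Flux

m    = Dense(3, 3)                                 # placeholder model
loss(x, y) = Flux.mse(m(x), y)                     # placeholder two-argument loss
data = [(rand(Float32, 3), rand(Float32, 3)) for _ in 1:5]
ps   = Flux.params(m)
opt  = Descent(0.1)

# roughly what Flux.train!(loss, ps, data, opt) does
for datapoint in data
    gs = Flux.gradient(ps) do          # gradient with respect to the tracked parameters
        loss(datapoint...)             # each datapoint is splatted into loss
    end
    Flux.Optimise.update!(opt, ps, gs) # apply one optimiser step
end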
So, if you’ve defined your loss function to look like this:
function loss(x, y)
    # ...
end
…then you can pass in a vector of (x, y) tuples to train! as the data argument. It could just as easily be loss(a, b, c), if you calculate your model's loss that way; in that case you'd pass in (a, b, c) tuples.
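As a hypothetical sketch of that multi-argument case (the model, the loss body, and the random data are made up purely to show the shape of things):

using Flux

m = Dense(4, 2)                                    # placeholder model
loss(a, b, c) = Flux.mse(m(a) .+ b, c)             # any three-argument loss works the same way
data = [(rand(Float32, 4), rand(Float32, 2), rand(Float32, 2)) for _ in 1:10]   # 3-tuples
Flux.train!(loss, Flux.params(m), data, Descent(0.1))   # each 3-tuple is splatted into loss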
This case is interesting in that — since it’s an autoencoder — the data is itself the label. That means that the loss function can be defined with just one argument.
A more typical use-case of Flux might do something like this (a rough sketch; m is the model and opt an optimiser):
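loss(x, y) = Flux.mse(m(x), y)                     # compare the prediction to the known label
Flux.train!(loss, Flux.params(m), zip(features, labels), opt)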
Where features is the vector of all the input data, and labels is the corresponding vector of known outputs. Zipping them together turns the two vectors into a single iterable in which each datapoint is in the same tuple as its label.
You could define an autoencoder with the two-argument loss above just by zipping data with itself, zip(data, data). Or you could do as the model zoo does: recognize that a single argument is sufficient, in which case the one-argument zip is just a cute way of putting each element of the data vector into a 1-tuple that can be splatted into loss.
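To see concretely what zip does in each case, here is a quick sketch with toy numbers standing in for batches:

data = [10, 20, 30]            # toy stand-ins for data batches
collect(zip(data, data))       # [(10, 10), (20, 20), (30, 30)]: 2-tuples for a two-argument loss
collect(zip(data))             # [(10,), (20,), (30,)]: 1-tuples for a one-argument loss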
Oh I see, yes, that's a good way to put it. I see now that the loss function has only one argument, because of this particular case (loss(x) = mse(m(x), x)).
Yes, that helps. And thank you also for linking the source code; it's a good idea to go look, especially in Julia, where the source code is often relatively concise and readable.