ANN: Knet v1.2.0: iterators, iterators, iterators

The new Knet release is all about iterators: iterators for minibatching, iterators for training, iterators for monitoring, convergence etc. Why am I so excited about iterators all of a sudden? Allow me to explain:

Knet has used iterators for data generation since 2015. That was about it until recently, when I was looking for a way to improve the training interface. See, at the core of every deep learning project there is a training loop that looks like this:

function train(model,data)
    for (x,y) in data
        # improve model parameters so model(x) approaches y
    end
end

And these things can run for hours or days. You want the user to have full control of this loop: how many iterations to run, how to detect convergence and quit, how to monitor progress, how to take model snapshots or measure dev accuracy every n iterations, etc.
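For concreteness, here is a minimal runnable version of that loop. It assumes a hypothetical one-weight linear model (`Linear`, not part of Knet) trained by plain gradient descent on squared loss. Every control decision, here just the number of passes, has to be fixed by the caller up front or hard-coded inside `train`:

```julia
# Hypothetical toy model (not Knet): a single weight w, predicting y ≈ w * x.
mutable struct Linear
    w::Float64
end
(m::Linear)(x) = m.w * x

# The plain training loop: one gradient-descent step on squared loss per (x, y).
function train(model::Linear, data; lr = 0.1)
    for (x, y) in data
        grad = 2 * (model(x) - y) * x   # derivative of (w*x - y)^2 with respect to w
        model.w -= lr * grad
    end
    return model
end

data = [(1.0, 3.0), (2.0, 6.0), (0.5, 1.5)]   # noise-free samples of y = 3x
model = train(Linear(0.0), repeat(data, 50))  # 50 passes over the data
println(model.w)                              # close to 3.0
```

Adding "stop on convergence" or "report every n steps" to this version means editing the body of `train` itself.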

My original (non)solution was to write a new train function for every experiment. Why restrict the user with a bad interface when they can write their own 5-line loop? (Of course, you can take this idea a bit further and start wondering why write any package at all, but that's another discussion.)

My next (pseudo)solution was to provide a train function with lots of keyword arguments. I soon gave up on that idea when it became clear that I was on my way to implementing a Turing complete programming language using keyword arguments.

Then I thought I had a brilliant flash of insight based on callback functions. See, if train just accepts a callback function that gets called inside the for loop, the user can implement any behavior:

function train(model,data,callback)
    for (x,y) in data
        callback() || break
        # improve model parameters so model(x) approaches y
    end
end

You want to display a progress bar, do something every n iterations, or quit after N iterations? Just implement some callback function with state and you are all set! Brilliant? Everybody hated it. Including me. It turns out callback functions are awkward to write and do not lead to very readable code.
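A minimal sketch of the callback approach, using the same kind of hypothetical one-weight `Linear` model (not Knet's API). Even the simplest policy, quitting after n updates, needs a closure (`stop_after`, a hypothetical helper) that hides a mutable counter:

```julia
mutable struct Linear
    w::Float64
end
(m::Linear)(x) = m.w * x

# Callback-style training: the loop asks callback() before every update
# and stops as soon as it returns false.
function train(model::Linear, data, callback; lr = 0.1)
    for (x, y) in data
        callback() || break
        model.w -= lr * 2 * (model(x) - y) * x
    end
    return model
end

# Even "quit after n iterations" needs hidden mutable state:
# a closure that counts how many times it has been called.
function stop_after(n)
    count = 0
    return () -> (count += 1; count <= n)
end

data = [(1.0, 3.0), (2.0, 6.0), (0.5, 1.5)]
model = train(Linear(0.0), Iterators.cycle(data), stop_after(300))  # 300 updates, then stop
```

Combining two policies (say, a step limit plus a progress display) would mean writing yet another stateful closure that calls both.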

Then finally I rediscovered iterators, and iterators that wrap other iterators (inspired by Tqdm.jl). I knew iterators could be lazy collections that produce their next element only when asked (here is a summary with doc links to refresh your memory). See, once you implement the training loop as an iterator, you can pause, restart and terminate it whenever you want:

train(model,data) = (update!(model,x,y) for (x,y) in data)  # update! (hypothetical): improve the model on (x,y) and return the loss
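Here is a runnable sketch of that one-liner, assuming a hypothetical `Linear` model and a hypothetical `step!` that performs one gradient step and returns the loss. Building the generator runs nothing; each `iterate` call performs exactly one update, and the returned state lets you resume later:

```julia
mutable struct Linear
    w::Float64
end
(m::Linear)(x) = m.w * x

# Hypothetical helper: one gradient-descent step on (x, y); returns the loss before the update.
function step!(model::Linear, x, y; lr = 0.1)
    err = model(x) - y
    model.w -= lr * 2 * err * x
    return err^2
end

# The training loop as a lazy generator: each element is one update's loss,
# and no update happens until someone asks for the next element.
train(model::Linear, data) = (step!(model, x, y) for (x, y) in data)

data = [(1.0, 3.0), (2.0, 6.0), (0.5, 1.5)]
model = Linear(0.0)
t = train(model, Iterators.cycle(data))   # infinite, but nothing has run yet
w0 = model.w                              # still 0.0

loss1, state = iterate(t)          # run exactly one update, then pause
w1 = model.w
loss2, state = iterate(t, state)   # resume from the saved state: one more update
```

Since the generator keeps no position of its own, starting a fresh `for` loop over `t` restarts from the first batch while the model keeps its learned weight.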

What I realized is that iterators also turn the for loop inside out, making its guts visible so that you have explicit control: you can monitor and display its progress, take snapshots, or do whatever else you want, all with explicit and readable code. Here are some actual examples from Knet v1.2.0 (sgd is a training iterator, f is the model, d is the data):

  • To display a progress bar use progress(sgd(f,d)).
  • To run until convergence use converge(sgd(f,cycle(d))).
  • To run multiple epochs use sgd(f,repeat(d,n)).
  • To run a given number of iterations use sgd(f,take(cycle(d),n)).
  • To do a task every n iterations use (task(x) for x in every(n, sgd(f,cycle(d)))).
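The sketch below mimics the iteration-control bullets using only `Base.Iterators`, with a hypothetical stand-in for `sgd` (one update per element, yielding that update's loss). It is not Knet's implementation, and `progress`/`converge` are left out. For arbitrary iterators, `flatten(repeated(d, n))` plays the role of `repeat(d, n)`, which Base defines for arrays rather than for arbitrary iterators:

```julia
using Base.Iterators: cycle, take, flatten, repeated

mutable struct Linear
    w::Float64
end
(m::Linear)(x) = m.w * x

# Hypothetical stand-in for a training iterator like sgd(f, d):
# one update per element of d, yielding the loss measured before the update.
function sgdstep!(f::Linear, x, y; lr = 0.1)
    err = f(x) - y
    f.w -= lr * 2 * err * x
    return err^2
end
sgd(f, d) = (sgdstep!(f, x, y) for (x, y) in d)

every(n, itr) = (x for (i, x) in enumerate(itr) if i % n == 0)

d = [(1.0, 3.0), (2.0, 6.0), (0.5, 1.5)]
f = Linear(0.0)

n_iters  = length(collect(sgd(f, take(cycle(d), 7))))       # exactly 7 updates
n_epochs = length(collect(sgd(f, flatten(repeated(d, 2))))) # 2 epochs = 6 updates
sampled  = collect(take(every(5, sgd(f, cycle(d))), 4))     # every 5th loss, 4 times (20 updates)
```

None of the wrappers know anything about training; they only count and forward elements, which is what makes them composable.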

Each of these functions (progress, converge, sgd, etc.) takes and returns iterators, so they can be composed like crazy. Here is an example from the Knet tutorial that (1) trains a model on dtrn, (2) measures the loss on dtst every 100 iterations, (3) quits when dtst performance converges, and (4) displays a progress bar:

a = adam(model,cycle(dtrn))
b = (model(dtst) for _ in every(100,a))
c = converge(b, alpha=0.1)
progress!(c, alpha=1)

The code reads like the English description! Imagine trying to implement this using keyword arguments or callback functions… and that is why I am excited about iterators.
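To see how such a pipeline can work without Knet, here is a self-contained sketch of the same pattern minus the progress bar. Everything in it is hypothetical: `gd` stands in for `adam` (full-batch gradient descent on one toy minibatch), and `converge` is a simple rule that stops once an exponential moving average (EMA) of the values changes by less than `rtol` relative to its previous value. Knet's own `converge` may use a different criterion:

```julia
using Base.Iterators: cycle

mutable struct Linear
    w::Float64
end
(m::Linear)(x) = m.w * x

# Mean squared error of the model on one minibatch (x, y) of vectors.
loss(model, (x, y)) = sum(abs2, model(x) .- y) / length(y)

# Hypothetical stand-in for adam(model, data): one gradient step per minibatch,
# yielding the training loss measured before the step.
function gdstep!(model::Linear, batch; lr = 0.01)
    x, y = batch
    l = loss(model, batch)
    model.w -= lr * 2 * sum((model(x) .- y) .* x) / length(y)
    return l
end
gd(model, data) = (gdstep!(model, b) for b in data)

every(n, itr) = (x for (i, x) in enumerate(itr) if i % n == 0)

# Hypothetical converge sketch (not Knet's implementation): pass values through
# while tracking their EMA; stop once the EMA's relative change drops below rtol.
struct Converge{I}
    itr::I
    alpha::Float64
    rtol::Float64
end
converge(itr; alpha = 0.1, rtol = 1e-3) = Converge(itr, alpha, rtol)
Base.IteratorSize(::Type{<:Converge}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{<:Converge}) = Base.EltypeUnknown()

function Base.iterate(c::Converge, state = (NaN, nothing))
    ema, st = state
    y = st === nothing ? iterate(c.itr) : iterate(c.itr, st)
    y === nothing && return nothing
    x, st = y
    isnan(ema) && return x, (Float64(x), st)          # first value seeds the EMA
    newema = c.alpha * x + (1 - c.alpha) * ema
    abs(newema - ema) <= c.rtol * abs(ema) && return nothing   # converged: stop
    return x, (newema, st)
end

dtrn = [([1.0, 2.0, 3.0], [3.1, 5.9, 9.2])]   # one training minibatch
dtst = ([1.5, 2.5], [4.4, 7.6])               # dev set
model = Linear(0.0)

a = gd(model, cycle(dtrn))                       # training iterator
b = (loss(model, dtst) for _ in every(100, a))   # dev loss every 100 updates
c = converge(b, alpha = 0.1)                     # stop when dev loss settles
devlosses = collect(c)                           # drive the whole pipeline
```

One consequence of this design: the value that triggers convergence is consumed but not yielded. Yielding it and stopping on the next call would need one more piece of state.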

enjoy,
deniz

P.S. The more nitpicky reader will probably point out that I should have called these things generators or coroutines or streams or something other than iterators, but you get the idea.

P.P.S. every(n,itr) = (x for (i,x) in enumerate(itr) if i%n == 0) should be a Julia primitive.
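The one-liner works as is on any iterator, finite or infinite, because both the generator and `enumerate` are lazy:

```julia
# Keep every nth element of any iterator, lazily.
every(n, itr) = (x for (i, x) in enumerate(itr) if i % n == 0)

collect(every(3, 5:15))                     # [7, 10, 13]
first(every(100, Iterators.countfrom(1)))   # 100, even though countfrom never ends
```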


This is a really nice idea.

This is really cool! It's not my intention to do self-promotion, but this reminds me of the obsession with iterators I had last summer. If I understand this correctly, you're talking about the same idea of wrapping iterables, right?


This looks awesome and your enthusiasm about iterators is contagious. I can see how I might use them for numerical simulations as well :+1:

IterTools.jl has takenth, which is essentially the same:

help?> IterTools.takenth
  takenth(xs, n)

  Iterate through every nth element of xs.

  julia> collect(takenth(5:15,3))
  3-element Array{Int64,1}:
    7
   10
   13

Thanks, that blog post is awesome. I actually return to it from time to time as a reference, next to the Julia docs :smile:


Thanks for sharing. I wish I had read your post sooner!

Nice, @denizyuret @lostella, I like those iterators.
Did you ever consider putting some of them in a separate lightweight package, or adding them to IterTools.jl?
Halting, side effects, sampling and timing are useful in so many contexts…

I did. One more package of iteration tools does not make much sense to me; it would probably be better to include them in IterTools. The sample one could really just be an option of takenth (something like include_last=false).
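One possible reading of that option, sketched as a hypothetical `Every` iterator (the exact semantics of the sample iterator and of `include_last` are assumptions here): yield every nth element and, if requested, also the final element when the count is not a multiple of n. Knowing that an element is the last one requires one element of lookahead, which matters if the wrapped iterator has side effects such as training updates:

```julia
# Hypothetical sketch, not the IterTools API: every nth element, plus optionally the last one.
struct Every{I}
    n::Int
    itr::I
    include_last::Bool
end
Every(n, itr; include_last = false) = Every(n, itr, include_last)
Base.IteratorSize(::Type{<:Every}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{<:Every}) = Base.EltypeUnknown()

# State is (count so far, result of looking one element ahead).
function Base.iterate(e::Every, state = (0, iterate(e.itr)))
    i, y = state
    while y !== nothing
        x, st = y
        i += 1
        y = iterate(e.itr, st)    # lookahead: is there another element after x?
        if i % e.n == 0 || (e.include_last && y === nothing)
            return x, (i, y)
        end
    end
    return nothing
end

collect(Every(3, 5:15))                       # [7, 10, 13]
collect(Every(3, 5:15; include_last = true))  # [7, 10, 13, 15]
```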
