[ANN] Flux v0.5

Hey all,

I’m pleased to announce the 0.5 release of the Flux machine learning library. Here’s an incomplete list of things that have changed since the last announcement.

  • Experimental JIT compilation work for models, applying optimisations such as pre-allocating memory.
  • Run models in the browser via Flux.JS
  • Among many GPU performance improvements, CUDNN integration for RNNs – RNNs will now be much faster, no changes to your code needed!
  • A new and improved N-dimensional API for convolutions, whose CPU versions now have pure-Julia implementations.
  • Regularisation of model weights (a short sketch follows this list)
  • Stable APIs for saving and loading models
  • Tracked scalars for more advanced AD use cases
  • Much new functionality, including the GRU RNN layer, numerically stable log versions of softmax and sigmoid, binary cross entropy, permutedims, the Kronecker product, and more.
  • Many more models in the model zoo
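As a rough illustration of the regularisation item, here is a minimal sketch that adds a weight penalty to the loss. It assumes the 0.5-era API, where a Dense layer exposes its parameters as m.W and m.b and Flux provides an mse loss; treat those names as assumptions rather than a documented recipe.

using Flux

m = Dense(10, 2)

# L2-style penalty, written with ordinary (tracked) array operations
penalty() = sum(m.W .^ 2) + sum(m.b .^ 2)

# the penalty simply becomes part of the loss that back! differentiates
loss(x, y) = Flux.mse(m(x), y) + 0.01 * penalty()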

Thanks to all contributors to Flux!

– Mike

28 Likes

This looks great!

Both tracked scalars and derivatives of kron will be useful for something I was trying to do, not involving neural networks at all. So maybe this is a good place to ask: What are your thoughts on using Flux.Tracker as a way to do AD more generally?

I’ve learned quite a bit by taking it apart, and it seems good at what I want right now (once I fix a few bugs!). But I presume the fact that it’s not its own package means I should be wary about its future. I have heard about Cassette.jl, but I don’t understand much about what it would replace, or when.

Please do! It should be pretty effective for any array code in Julia – that’s all ML models are anyway. And any improvements you can make will certainly improve things for everyone.
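For instance, a minimal sketch of taking a gradient of plain array code with the tracker, no layers involved; this assumes the param / back! / grad usage that also appears further down this thread.

using Flux
using Flux.Tracker: grad

x = param([1.0 2.0; 3.0 4.0])   # wrap the input so operations on it are recorded
y = sum(x .^ 2) + sum(x)        # any composition of tracked array operations
Flux.back!(y)                   # reverse pass from the scalar output
grad(x)                         # dy/dx, here 2 .* x .+ 1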

I’d even be open to splitting it into its own package if there’s a good reason to, though I’d hope that there’s no real downside to just importing Flux directly.

If all goes well with Cassette/Capstan, we’d like to replace Flux’s AD with that, but that’s a long way off (perhaps a year or more). For the time being, you can consider the AD to be part of the supported Flux interface and rely on it.

OK, that’s good to hear.

The new Base.kron is much neater than my implementation! I’ll see if there are any other bits to contribute back.

One more question. Suppose I want to provide (an approximation to) the gradient of some big slow function f, as below. Most of the work of calculating this f∇(x) is just calculating f(x). I think this has already been done on the forward pass… is there any way that I can access it, or save it for use by some f∇(x, fval)?

using Flux
using Flux.Tracker: back, track, grad, TrackedArray

## f(x) and its gradient f∇(x) are defined elsewhere; f has scalar output
f(x::TrackedArray) = track(f, x)

Flux.Tracker.back(::typeof(f), Δ, x) = back(x, Δ * f∇(x))

xp = param(rand(3, 3))
fxp = f(xp)
Flux.back!(fxp)
grad(xp)

There’s https://github.com/simonster/Memoize.jl and similar but I don’t know if using it could cause issues with AD.

Yup, easy peasy, just overload the back_ method instead.

1 Like

Perfect, thanks!

And so that I don’t forget: Flux.Tracker.back_(::typeof(f), fval, Δ, x) = back(x, Δ * f∇(x,fval))
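Putting the pieces together, here is a hedged sketch of the whole pattern. f and f∇ are stand-ins for the real functions, and data (used to strip tracking before calling f∇ in the backward pass) is assumed from the Tracker API of this era.

using Flux
using Flux.Tracker: back, track, grad, data, TrackedArray

# placeholder for the big slow function (scalar output) and its gradient;
# f∇ can reuse the value fval already computed on the forward pass
f(x) = sum(exp.(x))
f∇(x, fval) = exp.(x)

f(x::TrackedArray) = track(f, x)   # record the call on the tape

# back_ also receives the forward value fval, so f(x) is not recomputed
Flux.Tracker.back_(::typeof(f), fval, Δ, x) = back(x, Δ .* f∇(data(x), fval))

xp = param(rand(3, 3))
Flux.back!(f(xp))
grad(xp)                           # ≈ exp.(data(xp))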

Can I use for loops and arbitrary Julia constructs with Flux? How about autodiff with functions from packages not written with this in mind? What are the constraints with doing so?

This package is really great. It’s amazing how close the Julia code in the optimisers file is to the corresponding arXiv papers! The regularisation part of the documentation is great for the same reason. The code is just really welcoming to a beginner.

Just curious what the plans are for parallel CPU training?

2 Likes

Not an expert, but loops and branches that decide which operations you perform on a tracked array should be no problem.

But creating a new array by iterating over its elements (e.g. writing your own matrix multiplication with loops) will be problematic. If there’s one step that can’t be handled, it’s easy to splice in a ForwardDiff.gradient call there, or a hand-written derivative.

Right. Basically, the AD works in terms of high-level array operations (e.g. broadcast, reduce, or linear algebra). As long as you stick to those it’ll work (even in package code), even if it has crazy control flow, or uses recursion, or whatever. The treebank model is a good example of writing a model as a recursive function.

As @improbable22 said, you’ll have an issue if a function is “lower level” and works by looping over array elements, or similar, but in that case you can just tell Flux what the gradient is directly, as sketched below.
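As a concrete (hedged) sketch of that: the loop-based function below is opaque to the tracker, so we intercept tracked inputs with track and supply the gradient ourselves, here by splicing in ForwardDiff.gradient on the untracked data. slowsum is just a made-up example, and data is assumed from the Tracker API of this era.

using Flux, ForwardDiff
using Flux.Tracker: back, track, grad, data, TrackedArray

# a "low level" function: it builds its result by looping over elements,
# so the tracker cannot see through it
function slowsum(x)
    s = zero(eltype(x))
    for xi in x
        s += xi^2
    end
    return s
end

slowsum(x::TrackedArray) = track(slowsum, x)   # register the call on the tape

# supply the gradient directly; ForwardDiff runs on the plain data
# (a hand-written 2 .* data(x) would do just as well here)
Flux.Tracker.back(::typeof(slowsum), Δ, x) =
    back(x, Δ .* ForwardDiff.gradient(slowsum, data(x)))

xp = param(rand(3))
Flux.back!(slowsum(xp))
grad(xp)          # ≈ 2 .* data(xp)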

1 Like