Lifting a Julia function into a Flux "layer"

Flux confused me for a long time. I got the first example in the docs to work:

using Flux.Tracker

W = param(rand(2, 5))
b = param(rand(2))

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2) # Dummy data
l = loss(x, y) # ~ 3

params = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), params)
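For this particular loss you can also derive the gradients by hand, which is a useful sanity check on what Tracker.gradient returns. Take the residual r = W*x .+ b .- y. Then the gradient with respect to W is 2 r xᵀ, and the gradient with respect to b is 2 r. Below is a minimal plain-Julia sketch with no Flux involved. manual_grads and fd_check are hypothetical helper names. The sketch also checks one entry against a central finite difference:

```julia
# Hand-derived gradients of loss(W, b) = sum((W*x .+ b .- y).^2).
function manual_grads(W, b, x, y)
    r = W * x .+ b .- y     # residual, same length as y
    gW = 2 .* r * x'        # ∂loss/∂W = 2 r xᵀ, same size as W
    gb = 2 .* r             # ∂loss/∂b = 2 r, same size as b
    return gW, gb
end

# Central finite difference for entry (i, j) of W, as an independent check.
function fd_check(W, b, x, y; i = 1, j = 1, h = 1e-6)
    sq_loss(Wt) = sum((Wt * x .+ b .- y) .^ 2)
    Wp = copy(W); Wp[i, j] += h
    Wm = copy(W); Wm[i, j] -= h
    return (sq_loss(Wp) - sq_loss(Wm)) / (2h)
end

W, b = rand(2, 5), rand(2)
x, y = rand(5), rand(2)
gW, gb = manual_grads(W, b, x, y)
isapprox(gW[1, 1], fd_check(W, b, x, y); rtol = 1e-4, atol = 1e-6)   # true
```

Evaluated at the same W, b, x and y, the gradients Tracker computes for W and b should agree with these up to floating-point error.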

But other examples use a params function. This seems to be required to use optimizers like SGD and ADAM.

I spent a few hours trying to use ADAM to optimize a simple example, with no luck. Flux’s optimizers seem to require a special “layer”, which is not just a function but some kind of wrapper.

Is there a way to “lift” an arbitrary function so Flux will treat it as a layer?

You need to call the function with Params so that Flux can track the gradients; params(m) only returns the Params stored inside a struct.

Ok, so the docs have this example:

opt = SGD([W, b], 0.1) # Gradient descent with learning rate 0.1

opt() # Carry out the update, modifying `W` and `b`.

So let’s try it for the case above:

using Flux
using Flux.Tracker
using Flux.Optimise

W = param(rand(2, 5))
b = param(rand(2))

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2) # Dummy data

par = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), par)

opt = SGD(par, 0.1)

This fails with:

MethodError: no method matching length(::Params)
Closest candidates are:
  length(!Matched::Core.SimpleVector) at essentials.jl:576
  length(!Matched::Base.MethodList) at reflection.jl:728
  length(!Matched::Core.MethodTable) at reflection.jl:802

But it doesn’t make sense anyway: the call takes the parameters but has no reference to the loss function.

It seems that for a proper Flux model, [W, b] or params(m) must store the loss function under the hood. But it’s not clear how to connect this to a hand-rolled loss function.

Not necessary. In many ML libraries nowadays, an optimizer is as simple as a single function that takes two arguments (a parameter and its update) and applies one to the other according to optimizer-specific rules; for gradient descent it’s just $x := x - \alpha \Delta$. Flux’s Param struct already contains the parameter value and its delta, so the optimizer doesn’t need the loss function at all.
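Here is a minimal plain-Julia sketch of the “optimizer is just a function” point. It implements the gradient descent rule x := x - αΔ with no Flux involved. sgd_step! is a hypothetical name, not part of Flux. The update needs only the parameter and its Δ, never the loss:

```julia
# Gradient descent as a plain function of a parameter x and its update Δ:
# x := x - α Δ, applied in place. The loss function never appears here.
sgd_step!(x::AbstractArray, Δ::AbstractArray, α::Real) = (x .-= α .* Δ; x)

# One step on a single parameter array.
x = [1.0, 2.0]
sgd_step!(x, [0.5, -1.0], 0.1)   # x ≈ [0.95, 2.1]

# Minimizing f(z) = sum(z .^ 2): the caller computes the gradient 2z and
# hands it over; sgd_step! only applies the rule.
z = [1.0, -2.0]
for _ in 1:100
    sgd_step!(z, 2 .* z, 0.1)
end
# each step multiplies z by 0.8, so z is now very close to zero
```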

Now I’m really confused. I can imagine precomputing the gradient and having an optimization algorithm maintain state including momentum, learning rate, etc. But this sounds like not only the gradient but also the update is computed before it is passed to the optimizer.
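The picture of an optimizer that keeps its own state while the gradient is computed elsewhere is roughly right. The sketch below shows a minimal ADAM in plain Julia, assuming the standard ADAM update rule and its usual default hyperparameters. AdamSketch and adam_step! are hypothetical names, and this is not Flux's implementation. The optimizer stores first and second moment estimates per parameter array, keyed by identity, and only ever receives the parameter and its gradient:

```julia
# Minimal ADAM sketch (hypothetical names, not Flux's implementation).
# The optimizer keeps per-parameter state: first moment m, second moment v
# and a step counter t, keyed by the identity of the parameter array.
struct AdamSketch
    η::Float64
    β1::Float64
    β2::Float64
    ϵ::Float64
    state::IdDict{Any,Any}
end
AdamSketch(η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-8) =
    AdamSketch(η, β1, β2, ϵ, IdDict{Any,Any}())

# Update x in place from its gradient g; the loss never appears here.
function adam_step!(opt::AdamSketch, x::AbstractArray, g::AbstractArray)
    m, v, t = get!(opt.state, x) do
        (zero(x), zero(x), Ref(0))   # created on the first call for this x
    end
    t[] += 1
    m .= opt.β1 .* m .+ (1 - opt.β1) .* g        # running mean of g
    v .= opt.β2 .* v .+ (1 - opt.β2) .* g .^ 2   # running mean of g²
    mhat = m ./ (1 - opt.β1^t[])                 # bias corrections
    vhat = v ./ (1 - opt.β2^t[])
    x .-= opt.η .* mhat ./ (sqrt.(vhat) .+ opt.ϵ)
    return x
end

opt = AdamSketch(0.001)
w = [1.0, 1.0]
adam_step!(opt, w, [2.0, -3.0])   # first step moves each entry by about η against the sign of g
```

The caller supplies the gradient. The optimizer then turns it into an update using the state it has accumulated so far.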

So, when you call params(m), it computes what update to apply, independent of any context?

I would really appreciate it if anyone could post an “easy example done the hard way”, maybe something like a simple linear or logistic regression using the SGD function from Flux.

No, params(m) only returns the parameters stored inside struct m.
When you compute y = f(x) with x a TrackedArray, the result y is tracked as well; calling back!(y) then computes the gradient of y with respect to x and stores it in x.grad. When opt is called, it takes x.grad and modifies x according to

  1. x.grad
  2. The specific learning rule of the optimizer, e.g., ADAM, SGD, etc.

So the whole thing hinges on the fact that you do your calculations using TrackedArray/TrackedReal so that Flux can keep track of the gradients. The optimizers store a pointer to the parameters that they are supposed to update, and the gradients are stored inside these TrackedArrays.
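You can mimic this mechanism without Flux: the gradient lives next to the parameter, and the optimizer holds references to the parameters. The sketch below is a hypothetical plain-Julia stand-in, not Flux's actual types:

- ParamSketch plays the role of a tracked parameter.
- loss_and_backward! fills the gradient buffers by hand, where Tracker's back! would do it automatically.
- SGDSketch is called with no arguments, like opt() in the docs example, because everything it needs is already stored in the parameters.

It fits a linear regression on synthetic, noise-free data:

```julia
# Hypothetical stand-in for a tracked parameter: its value plus a gradient
# buffer of the same shape (loosely mirroring "a value and its Δ").
struct ParamSketch{T<:AbstractArray{Float64}}
    value::T
    grad::T
end
ParamSketch(v::AbstractArray{Float64}) = ParamSketch(v, zero(v))

# Hypothetical optimizer holding references to the parameters it updates.
struct SGDSketch
    ps::Vector{ParamSketch}
    η::Float64
end

# Calling the optimizer with no arguments, like opt() in the docs example:
# it needs no loss function because each parameter carries its own gradient.
function (opt::SGDSketch)()
    for p in opt.ps
        p.value .-= opt.η .* p.grad
        fill!(p.grad, 0)   # gradients are accumulated below, so reset them
    end
    return nothing
end

# Mean squared error of the linear model W*X .+ b; the gradients are derived
# by hand and accumulated into the buffers (the job back! does in Tracker).
function loss_and_backward!(W::ParamSketch, b::ParamSketch, X, Y)
    n = size(X, 2)
    R = W.value * X .+ b.value .- Y      # residuals, one column per sample
    W.grad .+= (2 / n) .* (R * X')
    b.grad .+= (2 / n) .* vec(sum(R; dims = 2))
    return sum(abs2, R) / n
end

# Synthetic, noise-free data from a known linear model.
W_true, b_true = [2.0 -1.0 0.5], [0.3]
X = randn(3, 100)
Y = W_true * X .+ b_true

W = ParamSketch(zeros(1, 3))
b = ParamSketch(zeros(1))
opt = SGDSketch([W, b], 0.1)

for iter in 1:1000
    loss_and_backward!(W, b, X, Y)   # fills W.grad and b.grad
    opt()                            # reads the stored gradients and updates
end
# W.value and b.value should now be very close to W_true and b_true
```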

I’m no longer sure what you mean by “parameters”, since the term is so overloaded.

  • Param contains a value (of arbitrary type) and a $\Delta$ of the same type, which is somehow related to how the value is to be updated, and so maybe to its gradient.
  • Params is another struct, containing only a value called params.
  • But params is also a function that can be called on a model.
  • params is also the name of the argument passed to optimizers like SGD.

It seems to me that param says “this thing should be considered a parameter” and returns a TrackedArray, in order to track gradients.

I’ve gone through the docs some more, and I think I’m making progress. There’s this example:

using Flux

struct Affine
  W
  b
end

Affine(in::Integer, out::Integer) =
  Affine(param(randn(out, in)), param(randn(out)))

# Overload call, so the object can be used as a function
(m::Affine)(x) = m.W * x .+ m.b

a = Affine(10, 5)

a(rand(10)) # => 5-element vector

At the bottom it mentions treelike, which seems to be the important piece I was missing. It’s probably better to think of it as “treelike!”, because its side effect is the whole point. Once you call it, you can call train! and it actually seems to work.
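Roughly speaking, treelike is what lets params(m) find the arrays stored in the fields of a custom struct such as Affine. Below is a rough plain-Julia illustration of that idea, not Flux's implementation. collect_arrays, AffineSketch and TwoLayer are hypothetical names. The code walks a struct's fields recursively and returns the arrays themselves rather than copies, so updating them in place updates the model:

```julia
# Hypothetical analogue of params(m): recursively walk a struct's fields and
# collect every numeric array it finds. Only numeric arrays and nested structs
# are handled; this is a simplification for illustration.
collect_arrays(x::AbstractArray{<:Number}, acc = Any[]) = push!(acc, x)
function collect_arrays(m, acc = Any[])
    for name in fieldnames(typeof(m))
        collect_arrays(getfield(m, name), acc)
    end
    return acc
end

# A hand-rolled layer in the style of the Affine example (hypothetical name).
struct AffineSketch
    W::Matrix{Float64}
    b::Vector{Float64}
end
(m::AffineSketch)(x) = m.W * x .+ m.b

# A model built from two such layers (also hypothetical).
struct TwoLayer
    l1::AffineSketch
    l2::AffineSketch
end
(m::TwoLayer)(x) = m.l2(m.l1(x))

model = TwoLayer(AffineSketch(randn(4, 3), zeros(4)),
                 AffineSketch(randn(2, 4), zeros(2)))
ps = collect_arrays(model)   # [l1.W, l1.b, l2.W, l2.b], the arrays themselves

ps[4] .+= 1.0                # mutating a collected array updates the model
model.l2.b                   # now [1.0, 1.0]
```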

I’ll try to build logistic regression in this way and see how it goes. Thank you @dfdx and @baggepinnen for your help!


I just started playing with Flux and am confused by the meaning/purpose of Params as well. The intro docs are not clear at all on this.
