Understanding Flux.jl's use of `gradient` and `params`

I’m reading the “60-minute blitz” for Flux, and I’m trying to understand more deeply how the package works. I have the following code:

using Flux: gradient, params

W = randn(3, 5)
b = zeros(3)
x = [3,1,0,1,2]

y(x) = sum(W * x .+ b)

grads = gradient(()->y(x), params([W, b]))

grads[W], grads[b]

What I find odd in the code above is that gradient is given a function, ()->y(x), that takes no arguments, yet it somehow knows that inside y() I’m using the variables W and b. Why does Flux use this roundabout way? Why not just:

func(W,b) = sum(W*x.+b)
grads = gradient(func, W,b)

Also, with this first implementation grads = gradient(()->y(x), params([W, b])), I can use grads[W], but if I change the value of W, then grads[W] won’t work anymore… So, what exactly is grads storing?

The motivation for params is better explained in the main docs: Basics · Flux. There are both advantages and disadvantages to using implicit params; if you’re interested, have a read through the discussion around FluxML/Optimisers.jl.

Thanks for the answer. But even after reading the documentation link, it’s still not clear to me what exactly params is storing. In the code I posted above, I find it odd that I’m passing a function without arguments (i.e. ()->y(x)) and yet gradient somehow knows that I’m taking the derivative with respect to the weights only. I guess I’m trying to understand how this “implicitness” works, because I don’t fully understand what is going on with params, or what type of “entity” it stores. I thought it contained a “pointer” to W and b, but it seems that’s not the case.

I’ll read the discussion in your second link; perhaps it will clarify the “implicitness” you talked about.

Maybe this bit of Zygote’s docs is helpful too? Params stores an IdSet, which is a relative of Base’s IdDict; both use objectid(W) as keys, which is not unlike a pointer to the array.

julia> Params([W,b]).params
Zygote.IdSet{Any} with 2 elements:
  [0.0, 0.0, 0.0]
  [-0.39509977063332824 -0.1822281117995581 … -1.5008473601509207 -0.5464118649620875; -0.714130900…
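
To make the "keyed by objectid" point concrete, here is a small sketch using only Base's IdDict (not Zygote's internal IdSet, so treat it as an analogy rather than the actual implementation). The names store, W_copy and W_old are made up for illustration. It shows why looking up a gradient by W keeps working after an in-place update but stops working once W is rebound to a new array:

```julia
W = randn(3, 5)

# An IdDict looks keys up by identity (===), not by contents (==).
store = IdDict{Any,Any}()
store[W] = "gradient for W"   # placeholder value standing in for a gradient

# A copy has equal contents but is a different object.
W_copy = copy(W)
copy_found = haskey(store, W_copy)      # false

# Mutating W in place keeps the same object, so the key still matches.
W .+= 1.0
mutated_found = haskey(store, W)        # true

# Rebinding the name W to a fresh array creates a new object.
W_old = W
W = randn(3, 5)
rebound_found = haskey(store, W)        # false
old_found = haskey(store, W_old)        # true
```

So the lookup behaves much like a pointer comparison: the entry follows the original array object, not the variable name and not the values inside it.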

Maybe the thing to know is that this knowledge is used only right at the end. Roughly, Zygote runs the function forwards, then pushes the gradient backwards, and at the end discards everything except (in this case) the gradient corresponding to the weights – Params just tells it what to keep.
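
As a rough illustration of "Params just tells it what to keep", the sketch below differentiates the same zero-argument function twice, asking for a different parameter each time. It assumes Zygote is installed and uses Zygote's gradient and Params directly; the names loss, gs_W and gs_b are only for this example.

```julia
using Zygote: gradient, Params

W = randn(3, 5)
b = zeros(3)
x = [3, 1, 0, 1, 2]

# A zero-argument function that closes over the globals W, b and x.
loss() = sum(W * x .+ b)

# Same function, different Params: only the requested gradient is looked up.
gs_W = gradient(loss, Params([W]))
gs_b = gradient(loss, Params([b]))

dW = gs_W[W]   # 3×5 array; each row equals x, since d/dW sum(W*x .+ b) = ones(3) * x'
db = gs_b[b]   # length-3 array of ones
```

The function being differentiated is identical in both calls; only the set of arrays passed in Params changes which gradients you index afterwards.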

Building on Michael’s explanation, you can think of gradient as having two mutually exclusive “modes”:

  1. gradient(f, args...). This does the intuitive thing of differentiating with respect to every value in args.
  2. gradient(f, ps::Params). This does the objectid-based tracking described above. It’s also why you’ll often find that the IdDict wrapped by the Grads struct returned from gradient has keys that weren’t present in the original IdSet in Params. You can still access them, but the public interface of Grads pretends they don’t exist, because we assume most users are only interested in gradients for the parameters they passed in (e.g. if you want to do some sort of regularization, you wouldn’t want rogue values contributing to that calculation).
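
The two modes can be compared side by side on the example from the original question. This is a minimal sketch assuming Zygote is installed (Flux re-exports the same gradient); the variable names dW, db and gs are just for illustration.

```julia
using Zygote: gradient, Params

W = randn(3, 5)
b = zeros(3)
x = [3, 1, 0, 1, 2]

# Mode 1: explicit arguments. The result is a tuple with one gradient per argument.
dW, db = gradient((W, b) -> sum(W * x .+ b), W, b)

# Mode 2: implicit parameters. The result is a Grads object,
# indexed by the parameter arrays themselves.
gs = gradient(() -> sum(W * x .+ b), Params([W, b]))

same_W = dW ≈ gs[W]
same_b = db ≈ gs[b]
```

Both modes should produce the same numbers here; they differ in how you say what to differentiate and in how you look the results up afterwards.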