 # Understanding Flux.jl use of `gradient` and `params`

I’m reading the “60-minutes blitz” for Flux, and I’m trying to understand more deeply how the package works. Now,
I have the following code:

```julia
using Flux: gradient, params

W = randn(3, 5)
b = zeros(3)
x = [3, 1, 0, 1, 2]

y(x) = sum(W * x .+ b)

grads = gradient(() -> y(x), params([W, b]))
```

What I find odd in the code above is that `gradient` is supposedly taking a function `() -> y(x)` that has no arguments. BUT the function somehow knows that inside `y()` I’m using the variables `W` and `b`. Why does Flux use this roundabout way? Why not just:

```julia
func(W, b) = sum(W * x .+ b)
```

Also, with this first implementation `grads = gradient(()->y(x), params([W, b]))`, I can use `grads[W]`, but if I change the value of `W`, then `grads[W]` won’t work anymore… So, what exactly is `grads` storing?

The motivation for `params` is better explained in the Basics section of the main Flux docs. There are both advantages and disadvantages to using implicit params: have a read through the discussion around the FluxML/Optimisers.jl repository if you’re interested.


Thanks for the answer. But it’s still not clear to me (after reading the documentation link) what exactly `params` is storing. Like in the code I posted above, I find it odd that I’m passing a function without arguments (i.e. `() -> y(x)`) and yet somehow `gradient` knows that I’m taking the derivative with respect to the weights only. I guess I’m trying to understand how this “implicitness” works, because I’m not fully understanding what is going on with `params`. Like, what type of “entity” does it store? I thought it contained a “pointer” to `W` and `b`, but it seems that’s not the case.

Maybe this bit of Zygote’s docs is helpful too? It stores an `IdSet`, a relative of Base’s `IdDict`; both use `objectid(W)` as keys, which is not unlike a pointer to the array.

```julia
julia> Params([W, b]).params
Zygote.IdSet{Any} with 2 elements:
  [0.0, 0.0, 0.0]
  [-0.39509977063332824 -0.1822281117995581 … -1.5008473601509207 -0.5464118649620875; -0.714130900…
```
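To see why identity-based keys behave the way the question describes, here’s a Base-only sketch (no Flux needed; the dictionary and its contents are made up for illustration) of how an `IdDict` reacts to copying and rebinding:

```julia
# Plain Base Julia: an IdDict keys on object identity (objectid), not on
# value equality -- the same mechanism behind the Params/Grads lookup.
W = randn(3, 5)
d = IdDict(W => "gradient of W")

d[W]                 # works: `W` is the exact array used as the key

haskey(d, copy(W))   # false: equal values, but a different object

W = randn(3, 5)      # rebind the *name* W to a brand-new array
haskey(d, W)         # false: the old array is still a key in `d`,
                     # but the name `W` no longer refers to it
```

This is exactly why `grads[W]` stops working after `W` is reassigned: the lookup is by the array’s identity, and the name now points at a different array.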

Maybe the thing to know is that this knowledge is used only right at the end. Roughly, Zygote runs the function forwards, then pushes the gradient backwards, and at the end discards everything except (in this case) the gradient corresponding to the weights – `Params` just tells it what to keep.
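That “keep only what `Params` names” step can be mimicked in plain Base Julia. A toy sketch of the idea, not Zygote’s actual implementation: pretend the backward pass has produced a gradient for every array it touched, and `Params` is just the filter applied at the end.

```julia
# Toy model (not Zygote internals): gradients accumulate for every array
# the backward pass touches, keyed by identity...
W = randn(3, 5)
b = zeros(3)
tmp = randn(3)                      # stand-in for some intermediate value

all_grads = IdDict{Any,Any}(
    W   => ones(3, 5),
    b   => ones(3),
    tmp => ones(3),                 # a "rogue" gradient nobody asked for
)

# ...and Params just says which entries to keep at the end.
tracked = IdDict(W => true, b => true)
kept = IdDict(k => v for (k, v) in all_grads if haskey(tracked, k))

haskey(kept, W)    # true: W was named in Params
haskey(kept, tmp)  # false: discarded at the end
```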


Building on Michael’s explanation, you can think of `gradient` having two mutually exclusive “modes”:

1. `gradient(f, args...)`. This does the intuitive thing of differentiating wrt. every value in `args`.
2. `gradient(f, ps::Params)`. This does the `objectid`-based tracking described above. It’s also why you’ll often find that the `IdDict` wrapped by the `Grads` struct returned from `gradient` has keys that weren’t present in the original `IdSet` in `Params`. You can still access them, but the public interface of `Grads` pretends they don’t exist, because we assume most users are only interested in gradients for the passed-in parameters (e.g. if you want to do some sort of regularization, you wouldn’t want rogue values contributing to that calculation).
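The two modes can be put side by side. A short sketch, assuming a Flux version where the implicit `params` API is still available:

```julia
using Flux: gradient, params

W = randn(3, 5)
b = zeros(3)
x = [3, 1, 0, 1, 2]
loss(W, b) = sum(W * x .+ b)

# Mode 1: explicit -- differentiate wrt. every positional argument;
# the result is a tuple, one gradient per argument.
dW, db = gradient(loss, W, b)

# Mode 2: implicit -- W and b are tracked by objectid via Params,
# and the result is a Grads you index by the arrays themselves.
gs = gradient(() -> loss(W, b), params([W, b]))

gs[W] == dW   # true: same gradient, just looked up by identity
```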