How can I go about updating scalar parameters in Flux.jl?

question
differentiation
flux

#1

Hey everyone and happy holidays,

I’ve been playing around with Flux.jl because I’d like to base my new Bayesian deep learning package on this framework. While running the examples from the documentation, I noticed something I couldn’t wrap my head around.

The first example in the “Basic Usage” section of the Flux documentation that uses the `update!` function looks something like this:

```julia
using Flux
using Flux.Tracker
using Flux.Tracker: update!

W, b = param(rand(2, 5)), param(rand(2))

predict(x) = W*x .+ b
loss(x, y) = sum((y .- predict(x)).^2)

x, y = rand(5), rand(2) # Dummy data
pars = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), pars)

update!(W, -0.1*grads[W])
loss(x, y)
```

which works as expected. So far, so good. However, when I go back to the preceding example in the same section, which uses scalar parameters, and try to update the parameters of that model, it fails. The following code shows the issue:

```julia
using Flux
using Flux.Tracker
using Flux.Tracker: update!

W, b = param(2), param(3)

predict(x) = W*x + b
loss(x, y) = sum((y - predict(x))^2)

x, y = 4, 15
pars = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), pars)

update!(W, -0.1*grads[W])  # this is the line that throws the error below
loss(x, y)
```

The error I get is:

```
ERROR: MethodError: no method matching copyto!(::Float64, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0},Tuple{},typeof(+),Tuple{Float64,Float64}})
```

So, as far as I can tell, this is some kind of broadcasting error, yet neither function in the second example uses broadcasting. A simple explanation could be that Flux doesn’t support updating scalar parameters. That seems a bit counterintuitive, though, since the manual shows an example of computing gradients of a function parameterized by scalars. Has anyone run into this, or did I make a mistake somewhere?
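The same error can be reproduced in plain Julia, without Flux. This sketch assumes, going only by the error message rather than by Tracker's source, that `update!` applies an in-place broadcast such as `x .+= Δ` to the parameter's underlying value. For an array that works, but a `Float64` is immutable, so there is no `copyto!` method that can write into it. A 0-dimensional array created with `fill` is a mutable container for a single number and accepts the same update. The variable names are illustrative only.

```julia
# A Float64 is immutable: an in-place broadcast into it ends up calling
# copyto!(::Float64, ::Broadcasted{DefaultArrayStyle{0},...}), which has no method.
w_scalar = 2.0
scalar_error = try
    w_scalar .+= 3.2          # same shape of update as update!(W, -0.1*grads[W])
    nothing
catch err
    err
end
println(typeof(scalar_error))  # MethodError

# A 0-dimensional array also holds a single value, but it is mutable,
# so the same in-place broadcast succeeds.
w_box = fill(2.0)
w_box .+= 3.2
println(w_box[])               # read the stored value with []
```

This suggests the problem lies in how the in-place update handles scalar (non-array) parameters rather than in the gradient computation itself.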


#2

Does `grads[W]` return a vector or a scalar?


#3

`grads[W]` returns

```
-32.0 (tracked)
```

i.e. a scalar value.
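The value `-32.0` can be checked by hand. For the scalar model, loss(W, b) = (y - (W x + b))², so ∂loss/∂W = -2x(y - Wx - b) and ∂loss/∂b = -2(y - Wx - b). With W = 2, b = 3, x = 4 and y = 15 the residual is 15 - 11 = 4, which gives -2·4·4 = -32 for W and -8 for b. The sketch below uses plain Julia with these hand-derived derivatives (no Tracker; the helper names are hypothetical) and takes the same step the question attempts, W ← W - 0.1·∂loss/∂W, by rebinding the variable instead of mutating it.

```julia
# Scalar model from the question: predict(x) = W*x + b, loss = (y - predict(x))^2
W, b = 2.0, 3.0
x, y = 4.0, 15.0

loss(W, b) = (y - (W * x + b))^2

# Hand-derived partial derivatives of the loss (hypothetical helpers)
dloss_dW(W, b) = -2 * x * (y - (W * x + b))
dloss_db(W, b) = -2 * (y - (W * x + b))

gW = dloss_dW(W, b)   # should match grads[W] reported above
gb = dloss_db(W, b)

# One gradient step on W with learning rate 0.1, done by creating a new
# value rather than updating the old one in place
W_new = W - 0.1 * gW
println((gW, gb, loss(W, b), loss(W_new, b)))
```

With a step size of 0.1 the loss actually goes up here (from 16 to roughly 77.4). For fixed b the loss is quadratic in W with second derivative 2x² = 32, so any step size above 2/32 = 0.0625 overshoots; that is a property of this toy problem, not of Flux.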


#4

I have a fix in this PR. Thanks for reporting.