Is it possible to selectively update the params in Flux.jl?

#1

I can’t seem to update the weights selectively. See the example below, where x is a set of 3 weights and I want to update only, say, the first component, but I can’t manage to do so. Is there a way to do that?

using Flux, Flux.Tracker
using Flux.Tracker: update!

x = param(rand(3))                      # 3 tracked weights
loss() = sum((x .- float.(1:3)).^2)     # minimised at x = [1, 2, 3]
gs = Tracker.gradient(() -> loss(), params(x))

println(x)
update!(x, -0.01gs[x])           # works: updates all three components
println(x)

println(x[1])
update!(x[1], -0.01gs[x][1])     # doesn't work
println(x[1])

I am trying to use Flux and its AD to do this: https://www.autodeskresearch.com/publications/samestats


#2

I found that this works, but I am almost sure it’s not the best solution:

println(x)
update!(x, -0.01gs[x] .* [1, 0, 0])   # works: the mask zeroes the update for components 2 and 3
println(x)
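
In case it helps, a sketch of generalising that mask to an arbitrary set of indices (keep is a hypothetical helper, not part of Flux):

keep(idxs, n) = [i in idxs ? 1.0 : 0.0 for i in 1:n]   # 1.0 where an update is wanted, 0.0 elsewhere
update!(x, -0.01gs[x] .* keep([1], length(x)))         # only the first component moves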

#3

That can’t work because you’re trying to mutate a scalar: x[1] is a single value, not a view into x.
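
A quick way to see this (a sketch; with the Tracker-based Flux, indexing a tracked array gives you a tracked scalar, not a view):

println(typeof(x))      # a TrackedArray: its .data can be written in place
println(typeof(x[1]))   # a tracked scalar: there is no array for update! to mutate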


#4

I don’t get it… I just want to selectively update the weights. Clearly it can be done with my method in the 2nd post; I’m wondering if there’s a better way.


#5

You can do it as follows:

println(x[1])
x.data[1] += -0.01gs[x].data[1]   # write directly to the underlying array, bypassing the tracker
Tracker.tracker(x).grad .= 0      # reset the accumulated gradient before the next backward pass
println(x[1])

#6

I am trying to understand this. What’s the purpose of setting .grad to 0?


#7

Without this step, the gradient accumulates across backward passes, which is not what we want.
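
For example (a sketch, reusing the x and loss from the first post): calling gradient twice without zeroing should leave roughly twice the single-pass gradient in the tracker.

gs = Tracker.gradient(() -> loss(), params(x))
gs = Tracker.gradient(() -> loss(), params(x))
println(Tracker.tracker(x).grad)   # roughly 2× the gradient of a single backward pass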


#9

Are you meant to set the grad to 0 for all the components, or just for the ones you are not updating?


#10

Tracker.tracker(x).grad .= 0 will reset all grads to 0.
Tracker.tracker(x).grad[1] = 0 will only reset the first grad to 0.
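
Putting the thread together, a minimal sketch of a loop that repeatedly updates only the first component, reusing the calls from the earlier posts:

for step in 1:100
    gs = Tracker.gradient(() -> loss(), params(x))
    x.data[1] += -0.01gs[x].data[1]   # move only the first weight
    Tracker.tracker(x).grad .= 0      # clear the accumulated grads before the next pass
end
println(x)   # x[1] should move toward 1.0; the other components stay where they were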


#11

I think this should be a standard feature. Is it worth opening a new issue in Flux.jl?
