Is it possible to selectively update the params in Flux.jl?

I can’t seem to update the weights selectively. In the example below, x is a set of 3 weights and I want to update only, say, the first component, but I can’t get it to work. Is there a way to do that?

using Flux
using Flux.Tracker: update!

x = param(rand(3))
loss() = sum((x .- float.(1:3)).^2)
gs = Tracker.gradient(() -> loss(), params(x))

println(x)
update!(x, -0.01gs[x]) # works
println(x)

println(x[1])
update!(x[1], -0.01gs[x][1]) # doesn't work: x[1] is a scalar
println(x[1])

I am trying to use Flux and its AD to reproduce this: https://www.autodeskresearch.com/publications/samestats

I found that multiplying the gradient by a mask works, but I am almost sure that it’s not the best solution:

println(x)
update!(x, -0.01gs[x].*[1,0,0]) # works
println(x)
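The masking idea can be sketched in plain Julia, without Flux or Tracker, to show which entries change. This is only an illustration under stated assumptions: the gradient of the same loss is written out by hand, and the names masked_step!, mask, and lr are hypothetical, not part of any Flux API.

```julia
# Plain-Julia sketch of a masked gradient step; no Flux or Tracker involved.
# The gradient of sum((x .- target).^2) is written by hand as 2 .* (x .- target).

target = float.(1:3)
loss(x) = sum((x .- target) .^ 2)
grad(x) = 2 .* (x .- target)

# Hypothetical helper: step only the entries where mask is true.
function masked_step!(x, g, mask; lr = 0.01)
    x .-= lr .* g .* mask   # masked-out entries get a zero update
    return x
end

x = [0.5, 0.5, 0.5]
mask = [true, false, false]   # only the first component may change

before = copy(x)
masked_step!(x, grad(x), mask)
```

A Bool mask and the [1,0,0] vector used above behave the same way here: entries multiplied by zero receive no update.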

The update!(x[1], ...) call can’t work because x[1] is a scalar, and a scalar can’t be mutated in place.

I don’t get it… I just want to selectively update the weights, and clearly it can be done with the method in my second post. I’m wondering if there’s a better way.

You can do as follows:

println(x[1])
x.data[1] += -0.01gs[x].data[1]  # write directly into the underlying array
Tracker.tracker(x).grad .= 0     # reset the stored gradient
println(x[1])

I am trying to understand this. What’s the purpose of setting the .grad to 0?

Without this step, the gradient will accumulate during each backward propagation, which is not what we want.
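A toy model in plain Julia may make the accumulation concrete. It is not Tracker's actual implementation: grad_buffer stands in for the stored .grad, and backward! is a hypothetical function that adds a hand-written gradient into that buffer.

```julia
# Toy model of gradient accumulation; this is NOT how Tracker is implemented.
target = float.(1:3)
x = [0.5, 0.5, 0.5]
grad_buffer = zeros(3)   # stands in for the stored .grad

# Hypothetical backward pass: adds the new gradient into the buffer.
backward!(buf, x) = (buf .+= 2 .* (x .- target); buf)

backward!(grad_buffer, x)
first_pass = copy(grad_buffer)

backward!(grad_buffer, x)   # no reset in between, so the buffer doubles
accumulated = copy(grad_buffer)

grad_buffer .= 0            # reset before the next pass
backward!(grad_buffer, x)
after_reset = copy(grad_buffer)
```

In this model, skipping the reset means the second pass yields twice the true gradient, so any update computed from the buffer would be twice as large as intended.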

Are you meant to set the grad to 0 for all entries, or just for the ones you are not updating?

Tracker.tracker(x).grad .= 0 will reset all grads to 0.
Tracker.tracker(x).grad[1] = 0 will only reset the first grad to 0.

I think this should be a standard feature. Is it worth opening a new issue in Flux.jl?
