I can’t seem to update the weights selectively. See the example below, where x is a set of 3 weights and I want to update only, say, the first component, but I can’t seem to do so. Is there a way to do that?
using Flux, Flux.Tracker
using Flux.Tracker: update!

x = param(rand(3))                   # 3 tracked weights
loss() = sum((x .- float.(1:3)).^2)  # pulls x toward [1, 2, 3]
gs = Tracker.gradient(() -> loss(), params(x))

println(x)
update!(x, -0.01gs[x]) # works: updates all three components
println(x)

println(x[1])
update!(x[1], -0.01gs[x][1]) # doesn't work
println(x[1])
I am trying to use Flux and its AD to reproduce Same Stats, Different Graphs: Generating Datasets with Varied Appearance and Identical Statistics through Simulated Annealing.
I found that this works, but I am almost sure it’s not the best solution:
println(x)
update!(x, -0.01gs[x].*[1,0,0]) # works: masked gradient entries contribute nothing
println(x)
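If you want the frozen components to be chosen programmatically rather than hard-coded, the same masking idea can be written with an index list. A minimal sketch, assuming the x, loss, and gs from the first post; idx is just an illustrative name:

idx = [1]                            # components we do want to update
mask = zeros(length(x))
mask[idx] .= 1                       # 1 where we update, 0 where we freeze
update!(x, -0.01 .* gs[x] .* mask)   # masked-out gradient entries contribute nothing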
That can’t work, because x[1] is a scalar and you’re trying to mutate a scalar in place.
I don’t get it… I just want to update the weights selectively. Clearly it can be done with my method in the second post; I’m wondering if there’s a better way.
You can do as follows:
println(x[1])
x.data[1] += -0.01gs[x].data[1]  # step only the first component on the underlying array
Tracker.tracker(x).grad .= 0     # reset the accumulated gradient
println(x[1])
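If you do this in more than one place, the pattern above can be wrapped in a small helper. This is only a sketch; update_component! is not a Flux function, just an illustrative name, and it assumes gs[x] supports .data as in the snippet above:

function update_component!(x, gs, i, η = 0.01)
    x.data[i] -= η * gs[x].data[i]   # step only the i-th weight on the underlying array
    Tracker.tracker(x).grad .= 0     # clear the accumulated gradient for the next pass
    return x
end

update_component!(x, gs, 1)          # update only the first component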
I am trying to understand this. What’s the purpose of setting .grad to 0?
Without this step, the gradient will accumulate during each backward propagation, which is not what we want.
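A minimal sketch of the accumulation being described, assuming the old Flux/Tracker setup from the first post; each backward pass adds into the stored gradient until you reset it:

Tracker.gradient(() -> loss(), params(x))
println(Tracker.tracker(x).grad)   # gradient from one backward pass

Tracker.gradient(() -> loss(), params(x))
println(Tracker.tracker(x).grad)   # twice as large: the second pass added on top of the first

Tracker.tracker(x).grad .= 0       # reset so the next pass starts from zero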
Are you meant to set the grad to 0 for all components, or just for the ones you are not updating?
Tracker.tracker(x).grad .= 0
will reset all grads to 0, while
Tracker.tracker(x).grad[1] = 0
will reset only the first grad to 0.
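Putting the pieces together, here is a minimal training-loop sketch (assuming old Flux with Tracker) that repeatedly updates only the first component and clears all grads each iteration; Tracker.data is used to read the gradient whether or not gs[x] is itself tracked:

using Flux, Flux.Tracker

x = param(rand(3))
loss() = sum((x .- float.(1:3)).^2)

for step in 1:100
    gs = Tracker.gradient(() -> loss(), params(x))
    x.data[1] -= 0.01 * Tracker.data(gs[x])[1]   # step only the first weight
    Tracker.tracker(x).grad .= 0                 # reset all grads before the next pass
end

println(x)   # only x[1] has moved toward its target of 1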
I think this should be a standard feature. Is it worth opening a new issue in Flux.jl?