Could anyone help me understand why calling `Flux.params()` inside a gradient changes the gradient? Below is a minimal example where, in `get_grad_2`, I instantiate a variable `prowling_variable = Flux.params(m)`. The gradients returned in `g1` are all 0.0, but `g2` outputs all 1.0. It is as if the loss had `sum(prowling_variable)` added as a regularizer. You can see that `g3.grads == g2.grads`, which suggests that calling `Flux.params` inside the gradient is equivalent to summing the parameters and adding that sum to the loss.

I know this can be avoided by storing the variables in a struct and not calling `Flux.params` inside `gradient` (there is a sketch of that workaround at the end of this post), but it still seems like it could be a bug.
```julia
using Flux, BenchmarkTools

m = Chain(Dense(100, 50, relu), Dense(50, 2), softmax);
opt = Descent(0.01);
data, labels = rand(Float32, 100, 100), zeros(Float32, 2, 100);
loss(m, x, y) = sum(Flux.crossentropy(m(x), y));

# Plain implicit-parameter gradient: no Flux.params call inside the closure.
function get_grad(m, data, labels)
    gs = gradient(Flux.params(m)) do
        l = loss(m, data, labels)
    end
end

# Identical, except Flux.params(m) is assigned to a variable inside the
# closure. The variable is never used, yet the gradients change.
function get_grad_2(m, data, labels)
    gs = gradient(Flux.params(m)) do
        prowling_variable = Flux.params(m)
        l = loss(m, data, labels)
    end
end

# Explicitly adds the sum of all parameters to the loss, for comparison.
function get_grad_3(m, data, labels)
    gs = gradient(Flux.params(m)) do
        l = loss(m, data, labels) + sum([sum(p) for p in Flux.params(m)])
    end
end

g1 = get_grad(m, data, labels)
g2 = get_grad_2(m, data, labels)
g3 = get_grad_3(m, data, labels)

println(g1.grads == g2.grads)  # false
println(g2.grads == g3.grads)  # true
```
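To make the "all 0.0 vs. all 1.0" observation concrete, here is a quick check I would expect to pass if my reading is right. It assumes the gradients come back as arrays (rather than `nothing`) and simply indexes Zygote's `Grads` by the parameter arrays:

```julia
# Sanity check of the claim above: g1 should be identically zero, and every
# entry of g2 should be exactly 1.0 larger than the matching entry of g1,
# i.e. the stray Flux.params call acts like a sum-of-parameters regularizer.
ps = Flux.params(m)
println(all(all(iszero, g1[p]) for p in ps))          # expect true
println(all(all(isone, g2[p] .- g1[p]) for p in ps))  # expect true
```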
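For reference, this is roughly the workaround I mean: take an explicit gradient with respect to the model object itself, so `Flux.params` never appears inside the differentiated code. `get_grad_explicit` is just a name I made up for this sketch.

```julia
# Workaround sketch: differentiate with respect to the model struct directly
# (explicit mode), so no Flux.params call happens inside the gradient.
function get_grad_explicit(m, data, labels)
    gs = gradient(model -> loss(model, data, labels), m)
    return gs[1]  # nested NamedTuple of gradients mirroring the model's fields
end

g_explicit = get_grad_explicit(m, data, labels)
```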