Calling Flux.params() inside gradient changes output?

Could anyone help me understand why calling Flux.params() inside a gradient changes the gradient? Below is a minimal example where, in get_grad_2, I instantiate a variable prowling_variable = Flux.params(m). The gradients returned in g1 are all 0.0, but g2 returns all 1.0. It is as if sum(prowling_variable) had been added to the loss as a regularizer. You can see that g3.grads == g2.grads, which suggests that calling Flux.params inside the gradient is equivalent to summing the parameters and adding the result to the loss.

I know the issue can be avoided by storing the parameters in a struct and not calling Flux.params inside gradient (see the sketch after the output below), but this seems like it could be a bug.

using Flux, BenchmarkTools

m = Chain(Dense(100, 50, relu), Dense(50, 2), softmax);

opt = Descent(0.01);

data, labels = rand(Float32, 100, 100), zeros(Float32, 2, 100);

loss(m, x, y) = sum(Flux.crossentropy(m(x), y));

function get_grad(m, data, labels)
    gs = gradient(Flux.params(m)) do
        l = loss(m, data, labels)
    end
end

function get_grad_2(m, data, labels)
    gs = gradient(Flux.params(m)) do
        prowling_variable = Flux.params(m)
        l = loss(m, data, labels)
    end
end

function get_grad_3(m, data, labels)
    gs = gradient(Flux.params(m)) do
        l = loss(m, data, labels) + sum([sum(p) for p in Flux.params(m)])
    end
end

g1 = get_grad(m, data, labels)
g2 = get_grad_2(m, data, labels)
g3 = get_grad_3(m, data, labels)

println(g1.grads == g2.grads)
println(g2.grads == g3.grads)
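
For reference, here is a minimal sketch of the struct-based workaround I mentioned above (TrainState, get_grad_struct, and g4 are illustrative names, not part of the MWE): collect the parameters once, keep them in a struct together with the model, and reference the struct field inside the closure instead of calling Flux.params there.

struct TrainState
    model
    ps
end

state = TrainState(m, Flux.params(m));   # Flux.params is called exactly once, outside gradient

function get_grad_struct(state, data, labels)
    gs = gradient(state.ps) do
        l = loss(state.model, data, labels)   # no Flux.params call inside the closure
    end
end

g4 = get_grad_struct(state, data, labels)
println(g4.grads == g1.grads)   # expected to print true: matches the clean gradient from get_grad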


I don’t have an answer, but I’ve been confused by the same behaviour in my own code and had to work around it.

This is almost certainly a bug. I dropped an MWE on the most likely PR culprit, so feel free to follow or comment there.
