Stop gradients from propagating through some branches of the computation tree in Flux

Is it possible to perform some computation in my loss function that is excluded from the gradient, while still using the simple Flux.train! API?

TensorFlow and Keras have something called stop_gradient, which can be applied to an output to signal that it should be treated as a constant. This is very handy when programming things like contrastive divergence or expectation maximization, where part of the computation graph should not be taken into account in the loss gradients.

Otherwise I can write the training loop myself, but stop_gradient was so handy in Keras that I think Flux could have something similar?

In Flux 0.10, you can use Zygote.dropgrad for this.
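For illustration, a small sketch of the stop_gradient behaviour on a scalar. It assumes a recent Zygote with ChainRulesCore installed, where `ChainRulesCore.ignore_derivatives(x)` is used here in the role that `Zygote.dropgrad(x)` played in the Flux 0.10 era; on an old setup, `Zygote.dropgrad` should be the drop-in name.

```julia
using Zygote, ChainRulesCore

# Both factors are differentiated: d/dx (x * x) = 2x
g_full = gradient(x -> x * x, 3.0)[1]

# The second factor is treated as a constant c (like tf.stop_gradient):
# d/dx (x * c) = c, and c equals x = 3.0 at this point
g_stopped = gradient(x -> x * ignore_derivatives(x), 3.0)[1]

println((g_full, g_stopped))
```

The forward value is the same in both cases; only the backward pass differs.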


Thanks, this looks like just what I need.

However, I was expecting to be able to mutate arrays within a dropgrad block. Unfortunately, something like this:

using Zygote
A = randn(2,2)
function f(x)
  Zygote.dropgrad(A .= x)  # A .= x is evaluated before dropgrad is called
  return sum(A + x)
end

doesn’t work.

What did you expect this to do? The code will still execute as normal, just no gradient will be taken. You need to put dropgrad on the line below, where it is actually used.

You mean like this?

julia> function f(x)
       A .= x
       return sum(Zygote.dropgrad(A) + x)
       end
julia> f'(randn(2,2))
ERROR: Can't differentiate gc_preserve_end expression

It still doesn’t work. I want to treat A as a constant when taking the gradient of f, and be able to mutate the contents of A.

Could you provide an example of what you’re actually trying to achieve? You could put all the mutation in a separate function and surround it with dropgrad, but it’s difficult to say in general.
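A minimal sketch of the "wrap the mutation" suggestion, assuming a recent Zygote plus ChainRulesCore, with `ChainRulesCore.ignore_derivatives` standing in for `Zygote.dropgrad`. The key difference from `dropgrad(A .= x)` is that the mutation is passed as a zero-argument closure, so `A .= x` runs inside the ignored call instead of being evaluated in the differentiated code first.

```julia
using Zygote, ChainRulesCore

const A = zeros(2, 2)   # external buffer that f overwrites on every call

function f(x)
    # The closure is only called inside ignore_derivatives, so Zygote never
    # sees the in-place broadcast and does not try to differentiate it.
    ignore_derivatives() do
        A .= x
    end
    return sum(A + x)   # A now holds a copy of x but is treated as a constant
end

x = randn(2, 2)
g = gradient(f, x)[1]   # every entry is 1, since A contributes no gradient
```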

I have a function f(x) and an external array A. When f executes, it mutates A and then uses the mutated array to compute its output. However, when the gradient is computed, I want A to be treated as if it were a constant array. In my example above:

function f(x)
  A .= x
  return sum(A + x)
end

I expect f'(x) to be an array of ones (whereas if A is taken into account the gradient would be an array of twos).
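Spelled out for the two-line f above, writing A for the array right after the in-place assignment (so its entries equal those of x):

```latex
f(x) = \sum_{i,j} \bigl(A_{ij} + x_{ij}\bigr), \qquad A_{ij} = x_{ij}

\text{A treated as a constant:} \quad \frac{\partial f}{\partial x_{ij}} = 1

\text{A tracked as a function of } x: \quad f(x) = \sum_{i,j} 2\,x_{ij} \;\Rightarrow\; \frac{\partial f}{\partial x_{ij}} = 2
```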

I am programming something like contrastive divergence. In these kinds of methods, you have a mutable state that you update, but the gradient of the loss should not take those updates into account. It’s the same use case as
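As a toy illustration of that pattern (not an actual contrastive divergence implementation), the sketch below keeps a persistent buffer `state` that a hypothetical `update_state!` refreshes from the current parameters on every loss call, while the loss uses the refreshed buffer only as a constant. All names are made up for the example, and it assumes ChainRulesCore's `@non_differentiable` is available alongside Zygote.

```julia
using Zygote, ChainRulesCore

# Hypothetical persistent state, overwritten in place on every loss call
const state = zeros(3)

# Hypothetical stand-in for a sampler step: refresh `state` from θ plus noise
function update_state!(s, θ)
    s .= tanh.(θ .+ 0.1 .* randn(length(s)))
    return nothing
end
# Hide the update from Zygote: it runs in the forward pass but has no derivative
ChainRulesCore.@non_differentiable update_state!(::Any, ::Any)

function loss(θ, data)
    update_state!(state, θ)
    # "Positive" term minus "negative" term; `state` enters only as a constant
    return sum(θ .* data) - sum(θ .* state)
end

θ = [0.5, -1.0, 2.0]
data = [1.0, 2.0, 3.0]
g = gradient(t -> loss(t, data), θ)[1]   # equals data .- state
```

Because the update is declared non-differentiable, the gradient with respect to θ only sees the two linear terms, not how `state` itself depends on θ.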

Ok I found a way to do it:

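using Zygote
A = randn(5)
mut(x) = (A .= x)              # overwrite A in place
Zygote.@nograd mut             # Zygote runs mut but does not differentiate through it
f(x) = (mut(x); sum(x .+ A))
f'(randn(5))                   # returns an array of ones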

BTW it’s very cool that this example realizes the gradient is the same for all components and returns a FillArray of ones instead of an ordinary Array.
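If `Zygote.@nograd` is unavailable or deprecated in your Zygote version, the same fix can be sketched with `ChainRulesCore.@non_differentiable`, assuming ChainRulesCore is installed alongside Zygote. The structure is identical; only the declaration line changes.

```julia
using Zygote, ChainRulesCore

A = randn(5)                        # external buffer that f overwrites
mut(x) = (A .= x)                   # in-place update of A
# Declare mut as having no derivative: Zygote still runs it in the
# forward pass but never tries to differentiate the mutation inside it.
ChainRulesCore.@non_differentiable mut(::Any)

f(x) = (mut(x); sum(x .+ A))        # A is read as a constant here

x = randn(5)
g = gradient(f, x)[1]               # all ones, since A contributes no gradient
```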


See also: