# Stop gradients from propagating through some branches of computation tree in Flux

Is it possible to perform some computation in my loss function that is excluded from the gradient, while still using the simple `Flux.train!` API?

TensorFlow and Keras have something called `stop_gradient`, which can be applied to an output to signal that it should be treated as a constant (https://www.tensorflow.org/api_docs/python/tf/stop_gradient). This is very handy when programming things like contrastive divergence or expectation maximization, where part of the computation graph should not be taken into account in the loss gradients.

Otherwise I can write the training loop myself, but having `stop_gradient` in Keras was so handy that I think Flux could have something similar?

In Flux 0.10, you can use `Zygote.dropgrad` for this.
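For instance, anything wrapped in `dropgrad` is still evaluated, but it is treated as a constant when the gradient is taken. A minimal sketch:

```julia
using Zygote

f(x) = x^2 + Zygote.dropgrad(x^2)  # second term is evaluated but treated as a constant
Zygote.gradient(f, 3.0)            # (6.0,) rather than (12.0,)
```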


Thanks, this looks like just what I need.

However I was expecting to be able to mutate arrays within a `dropgrad` block. Unfortunately something like this:

```julia
using Zygote

A = randn(2, 2)
function f(x)
    Zygote.dropgrad(A .= x)  # mutate A inside dropgrad
    return sum(A + x)
end
f'(randn(2, 2))
```

doesn't work.

What did you expect this to do? The code will still execute as normal, just no gradient will be taken. You need to put `dropgrad` on the line below, where it is actually used.

You mean like this?

```julia
julia> function f(x)
           A .= x
           return sum(Zygote.dropgrad(A) + x)
       end

julia> f'(randn(2, 2))
ERROR: Can't differentiate gc_preserve_end expression
```

It still doesn't work. I want to treat `A` as a constant when taking the gradient of `f`, and be able to mutate the contents of `A`.

Could you provide an example of what you're actually trying to achieve? You could put all the mutation in a separate function and surround it with `dropgrad`, but it's difficult to say in general.

I have a function `f(x)` and an external array `A`. When `f` executes, it mutates `A` and then uses the mutated array to compute its output. However, when the gradient is computed, I want `A` to be treated as if it were a constant array. In my example above:

```julia
function f(x)
    A .= x
    return sum(A + x)
end
```

I expect `f'(x)` to be an array of ones (whereas if `A` is taken into account the gradient would be an array of twos).
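A non-mutating sketch of the two behaviours, just to illustrate the ones-versus-twos distinction (hypothetical example, not my actual code):

```julia
using Zygote

x = randn(2, 2)
Zygote.gradient(x -> sum(Zygote.dropgrad(x) + x), x)[1]  # all ones: the first term is a constant
Zygote.gradient(x -> sum(x + x), x)[1]                   # all twos: both terms depend on x
```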

I am programming something like contrastive divergence. In these kinds of methods, you have a mutable state that you update, but the gradient of the loss should not consider these changes. It's the same use case as https://www.tensorflow.org/api_docs/python/tf/stop_gradient.

Ok I found a way to do it:

```julia
A = randn(5)
mut(x) = (A .= x)
```

BTW it's very cool that this example realizes the gradient is the same for all components and returns a `FillArray` of ones instead of an ordinary `Array`.
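A minimal sketch of the idea, assuming `Zygote.@nograd` is what keeps the mutation out of the gradient:

```julia
using Zygote

A = randn(5)
mut(x) = (A .= x)   # mutates the external array A
Zygote.@nograd mut  # Zygote treats mut as non-differentiable and never builds a pullback for it

f(x) = (mut(x); sum(A + x))  # A behaves as a constant in the gradient
f'(randn(5))                 # Fill of ones: the gradient of sum(const + x)
```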