I am calculating the gradient of a function, but I also want to see some other values that are computed along the way.
So in addition to the loss and the gradient of the loss, I want my function to return more values.
Here is what I have so far:
```julia
using Zygote

function my_fun(p, x)
    return p .* x .^ 2
end

# Returns (scalar loss, auxiliary vector)
function my_loss(fun, p, x)
    vec = fun(p, x)
    return sum(vec), vec
end

my_p =  my_x = [1, 2, 3, 4, 5]

# Implicit-parameter pullback over the closure that returns (loss, vec)
(loss_value, vec), gradients = Zygote.pullback(() -> my_loss(my_fun, my_p, my_x), params(my_p))
println("loss: ", loss_value)
println("vector: ", vec)

# Seed: 1 for the loss, zeros for the auxiliary vector
relevant_gradients = gradients((1, zeros(size(vec))))
println("gradients: ", relevant_gradients.grads)
println("params: ", relevant_gradients.params)
```
So in this example, I only want three things: the loss value, the gradient of the loss with respect to `my_p`, and the values of `my_fun(my_p, my_x)`.
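One straightforward way to get exactly these three outputs is explicit-mode `Zygote.gradient` taken with respect to `p` only, plus a separate plain forward pass for the vector. This is a minimal sketch with hypothetical float inputs (not the values from my code above), and it costs one extra evaluation of `my_fun`:

```julia
using Zygote

my_fun(p, x) = p .* x .^ 2

# Hypothetical illustrative inputs (not taken from the original code).
my_x = [1.0, 2.0, 3.0, 4.0, 5.0]
my_p = [0.5, 1.0, 1.5, 2.0, 2.5]

# Explicit mode: only the closure's argument `p` is differentiated;
# `my_x` is read as a constant, so no gradient is produced for it.
(grad_p,) = Zygote.gradient(p -> sum(my_fun(p, my_x)), my_p)

# Auxiliary values from an ordinary (non-differentiated) forward pass.
vals = my_fun(my_p, my_x)
loss_value = sum(vals)

println("loss: ", loss_value)
println("vector: ", vals)
println("gradient wrt p: ", grad_p)
```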
I do not want any gradients with respect to `my_x`, and I also do not want the output of `my_fun` to contribute gradients with respect to `my_x` or `my_p`. In my example I avoid this by passing `(1, zeros(size(vec)))` to `gradients`, but this seems… messy?
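A possibly tidier seed, assuming Zygote's convention that `nothing` stands for a zero cotangent: differentiate only with respect to `p` in explicit mode, seed the loss with `one(loss_value)`, and pass `nothing` for the vector instead of allocating `zeros(size(vec))`. The inputs are hypothetical illustrative values, and `vals` is used instead of `vec` to avoid shadowing `Base.vec`:

```julia
using Zygote

my_fun(p, x) = p .* x .^ 2

# Returns (scalar loss, auxiliary vector)
function my_loss(fun, p, x)
    vals = fun(p, x)
    return sum(vals), vals
end

# Hypothetical illustrative inputs.
my_x = [1.0, 2.0, 3.0, 4.0, 5.0]
my_p = [0.5, 1.0, 1.5, 2.0, 2.5]

# Explicit mode: the pullback only returns a gradient for `p`.
(loss_value, vals), back = Zygote.pullback(p -> my_loss(my_fun, p, my_x), my_p)

# Seed 1 for the loss and `nothing` (zero cotangent) for the auxiliary vector.
(grad_p,) = back((one(loss_value), nothing))

println("loss: ", loss_value)
println("vector: ", vals)
println("gradient wrt p: ", grad_p)
```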
Is there a more elegant/easier way to do this?
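For reference, newer Zygote releases document `Zygote.withgradient` as accepting a function that returns a `Tuple`: it differentiates only the first element and returns the whole tuple as `val`. This sketch assumes such a version (older releases expect a scalar output, so check the `withgradient` docstring for the installed version) and uses hypothetical inputs:

```julia
using Zygote

my_fun(p, x) = p .* x .^ 2

# First element: scalar loss; second element: auxiliary output
function my_loss(fun, p, x)
    vals = fun(p, x)
    return sum(vals), vals
end

# Hypothetical illustrative inputs.
my_x = [1.0, 2.0, 3.0, 4.0, 5.0]
my_p = [0.5, 1.0, 1.5, 2.0, 2.5]

# Assumes a Zygote version where withgradient differentiates only the
# first element of a returned Tuple and passes the rest through in `val`.
res = Zygote.withgradient(p -> my_loss(my_fun, p, my_x), my_p)

loss_value, vals = res.val
grad_p = res.grad[1]

println("loss: ", loss_value)
println("vector: ", vals)
println("gradient wrt p: ", grad_p)
```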
Thanks in advance for the help!