Hi,
I am calculating the gradient of a loss function, but I also want to get back some intermediate values that are computed along the way.
So in addition to the loss and the gradient of the loss, I want my function to return extra values.
Here is what I have so far:
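using Zygote

function my_fun(p, x)
    return p .* x .^ 2
end

# Returns the scalar loss together with the intermediate vector.
function my_loss(fun, p, x)
    vec = fun(p, x)
    return sum(vec), vec
end

# Floating-point arrays, so the gradients are floats as well.
my_p = [2.0]
my_x = [1.0, 2.0, 3.0, 4.0, 5.0]

# `params` is exported by Flux, not by Zygote; `Zygote.Params` works without Flux.
(loss_value, vec), gradients = Zygote.pullback(() -> my_loss(my_fun, my_p, my_x), Zygote.Params([my_p]))
println("loss: ", loss_value)
println("vector: ", vec)

# Seed the loss with 1 and the vector output with zeros so that only the loss contributes.
relevant_gradients = gradients((1, zeros(size(vec))))
println("gradients: ", relevant_gradients.grads)
println("params: ", relevant_gradients.params)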
So in this example, I only want three things: the loss value, the gradient of the loss with respect to `my_p`, and the values of `my_fun(my_p, my_x)`.
I do not want any gradients with respect to `my_x`, and I also do not want any gradients of the output of `my_fun` itself with respect to `my_x` or `my_p`. In my example I avoid this by feeding `(1, zeros(size(vec)))` to `gradients`, so the vector output contributes nothing, but this seems… messy?
Is there a more elegant/easier way to do this?
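For comparison, here is a variant I sketched. It assumes that `Zygote.withgradient`, when the function returns a tuple, differentiates only the first element and passes the rest through unchanged; I believe this is the case in recent Zygote versions, but I have not checked which release introduced it. The names `result` and `grad_p` are just my own.

```julia
using Zygote

my_fun(p, x) = p .* x .^ 2

my_p = [2.0]
my_x = [1.0, 2.0, 3.0, 4.0, 5.0]

# Assumption: for a tuple return value, `withgradient` treats the first
# element as the loss and returns the remaining elements as-is.
result = Zygote.withgradient(my_p) do p
    vec = my_fun(p, my_x)
    return sum(vec), vec   # vec is the auxiliary output
end

loss_value, vec = result.val
grad_p = result.grad[1]    # gradient of the loss with respect to my_p only

println("loss: ", loss_value)
println("vector: ", vec)
println("gradient wrt my_p: ", grad_p)
```

Since only `my_p` is passed as an argument, no gradient for `my_x` is returned, and no seed for the vector output has to be built by hand.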
Thanks in advance for the help!