Get intermediate results from gradient computation

Hi,

I am calculating the gradient of a function, but I also want some of the other values that are computed along the way.

So in addition to the loss and the gradient of the loss, I want my function to return more values.

Here is what I have so far:

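```julia
using Zygote

function my_fun(p, x)
    return p .* x .^ 2
end

function my_loss(fun, p, x)
    y = fun(p, x)
    return sum(y), y   # (scalar loss, auxiliary output vector)
end

my_p = [2.0]
my_x = [1.0, 2.0, 3.0, 4.0, 5.0]

# Implicit-parameter mode. `Params` is Zygote's constructor (`params` comes from Flux, not Zygote).
(loss_value, y), gradients = Zygote.pullback(() -> my_loss(my_fun, my_p, my_x), Params([my_p]))

println("loss: ", loss_value)
println("vector: ", y)
# Seed the pullback: 1 for the loss, zeros for the auxiliary vector
relevant_gradients = gradients((1, zeros(size(y))))
println("gradient wrt my_p: ", relevant_gradients[my_p])
println("params: ", relevant_gradients.params)
```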
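
For reference, here is why the seed passed to `gradients` (1 for the loss, zeros for the vector) isolates the loss gradient in this example. Writing $y_i = p\,x_i^2$ and $\ell = \sum_i y_i$, the pullback of the tuple output $(\ell, y)$ with seed $(\bar\ell, \bar y)$ returns

$$
\bar p = \bar\ell\,\frac{d\ell}{dp} + \sum_i \bar y_i\,\frac{dy_i}{dp} = \sum_i \left(\bar\ell + \bar y_i\right) x_i^2 .
$$

With $\bar\ell = 1$ and $\bar y = 0$ this reduces to $d\ell/dp = \sum_i x_i^2 = 1 + 4 + 9 + 16 + 25 = 55$ for `my_x = [1, 2, 3, 4, 5]`. Any nonzero entry in $\bar y$ would be added into the result, which is why the vector's part of the seed has to be zero.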

So in this example, I only want the loss value, the gradient of the loss with respect to `my_p`, and the values of `my_fun(my_p, my_x)`.
I do not need any gradients with respect to `my_x`, nor any gradient of the output of `my_fun` with respect to `my_x` or `my_p`. In my example I avoid the latter by passing `1` for the loss and an all-zeros array for the vector to `gradients`, but this seems… messy?
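
A minimal sketch of the same computation in Zygote's explicit style, for comparison: the closure is differentiated with respect to `p` only, so `my_x` is treated as a constant, and the auxiliary vector is seeded with `nothing` (which Zygote accepts as a zero cotangent) instead of an explicit zeros array. The Float64 values for `my_p` and `my_x` are a choice made for this sketch.

```julia
using Zygote

my_fun(p, x) = p .* x .^ 2

function my_loss(fun, p, x)
    y = fun(p, x)
    return sum(y), y   # (scalar loss, auxiliary output vector)
end

my_p = [2.0]
my_x = [1.0, 2.0, 3.0, 4.0, 5.0]

# Explicit mode: only `p` is an argument of the closure, so only it gets a gradient.
(loss_value, y), back = Zygote.pullback(p -> my_loss(my_fun, p, my_x), my_p)

# `nothing` acts as a zero cotangent: no sensitivity flows back from `y`.
(grad_p,) = back((1.0, nothing))

println("loss: ", loss_value)
println("vector: ", y)
println("gradient wrt my_p: ", grad_p)
```

`back` returns one gradient per argument of the closure, hence the one-element tuple destructuring.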

Is there a more elegant/easier way to do this?
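
Newer Zygote releases also document an auxiliary-output form of `withgradient`: if the function returns a tuple, only its first element is differentiated, while the whole tuple is returned as `val`. A sketch, assuming the installed Zygote version supports this (check `?Zygote.withgradient`; older versions expect a scalar return value):

```julia
using Zygote

my_fun(p, x) = p .* x .^ 2

function my_loss(fun, p, x)
    y = fun(p, x)
    return sum(y), y   # first element: loss (differentiated); second: auxiliary output
end

my_p = [2.0]
my_x = [1.0, 2.0, 3.0, 4.0, 5.0]

# Assumes auxiliary-output support: gradient of first(f(p)), value of the full tuple.
res = Zygote.withgradient(p -> my_loss(my_fun, p, my_x), my_p)

loss_value, y = res.val   # full tuple returned by my_loss
grad_p = res.grad[1]      # one entry per argument; here only `p`

println("loss: ", loss_value)
println("vector: ", y)
println("gradient wrt my_p: ", grad_p)
```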

Thanks in advance for the help!