Hello all,
I have been struggling a bit to take the gradient of a loss of a Flux model (via Zygote) with respect to both the model parameters and the model input.
Is it even sensible to do this in one pass?
Consider this simple setup:
```julia
using Flux

layer = Dense(2, 3)
layer_params = params(layer)
loss_fn(y_pred, y) = Flux.Losses.mse(y_pred, y)

x = rand(2)       # current sample
target = ones(3)
```
Now both of the following calls work as expected. Taking the gradient with respect to the model parameters:
```julia
gradient(() -> loss_fn(layer(x), target), layer_params)
```
and with respect to the input:
```julia
gradient(_x -> loss_fn(layer(_x), target), x)
```
But due to the way gradients with respect to implicit parameters are taken, I could not get both in one call. For example, this does not work:
```julia
gradient(_x -> loss_fn(layer(_x), target), x, layer_params)
```
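As far as I understand, Zygote's explicit mode can differentiate with respect to the layer struct and the input in one call; the gradient for the layer then comes back as a `NamedTuple` rather than a `Grads` object, which may or may not be what I want. A sketch of what I mean (assuming a recent Flux where the `Dense` fields are `weight` and `bias`):

```julia
using Flux

layer = Dense(2, 3)
loss_fn(y_pred, y) = Flux.Losses.mse(y_pred, y)
x = rand(2)
target = ones(3)

# Explicit mode: pass the layer and the input as plain arguments.
# `gradient` returns one gradient per argument; the layer's gradient
# is a NamedTuple mirroring its fields (weight, bias, σ).
grad_layer, grad_x = gradient((m, _x) -> loss_fn(m(_x), target), layer, x)

size(grad_layer.weight)  # (3, 2), same shape as the weight matrix
size(grad_x)             # (2,), same shape as the input
```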
In some old discussion (which sadly I cannot find anymore) I read that you can wrap `x` as a `Flux.Params` object. However, there only appears to be a method with signature `gradient(::Function, ::Params)`, so currently (for multiple `Params`) I do
```julia
input_params = params(x)
ps = union(input_params, layer_params)
gradient(() -> loss_fn(layer(x), target), ps)
```
and this works.
However, I wonder whether this is performant, especially if I loop over multiple samples `x`.
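For completeness, what I currently run looks roughly like this (again assuming recent Flux field names; note that a `Params` also supports `push!`, so the `union` is not strictly needed):

```julia
using Flux

layer = Dense(2, 3)
loss_fn(y_pred, y) = Flux.Losses.mse(y_pred, y)
x = rand(2)       # current sample; in practice a new x every iteration
target = ones(3)

# Collect the model parameters and the input into one Params object.
ps = params(layer)
push!(ps, x)      # add the input to the implicit parameter set

grads = gradient(() -> loss_fn(layer(x), target), ps)

# The Grads object is indexed by the tracked arrays themselves:
grad_x = grads[x]             # gradient w.r.t. the input
grad_W = grads[layer.weight]  # gradient w.r.t. the weight matrix
```

The part that worries me performance-wise is that `ps` has to be rebuilt for every new sample `x`.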
Is there some other way to achieve what I am trying to do? Or is it just a dumb idea?