Flux/Zygote: Gradient with respect to inputs and implicit parameters (in 2021)

Hello all,

I have struggled a bit to get the gradient(s) of a loss for a Flux model, computed with Zygote, with respect to both the model parameters and the model input.
Is it even sensible to do this in one pass?
Consider this simple setup:

```julia
using Flux
layer = Dense(2,3)
layer_params = params( layer )
loss_fn(y_pred, y) = Flux.Losses.mse(y_pred, y)

x = rand(2)      # current sample
target = ones(3)
```

Now both calls work as expected, taking the gradient with respect to the model parameters:

```julia
gradient( () -> loss_fn( layer(x), target ),  layer_params )
```

and with respect to the inputs:

```julia
gradient( ( _x ) -> loss_fn( layer( _x ), target ),  x )
```

But due to the way we take gradients with respect to implicit parameters, I could not get both in one call, e.g., this does not work:

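```julia
# errors: gradient(f, args...) calls f(args...), so the one-argument closure receives two arguments
gradient( ( _x ) -> loss_fn( layer( _x ), target ), x, layer_params )
```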

In some old discussion (which sadly I cannot find anymore) I read that you can wrap x as a Flux.Params object. However, there only appears to be a method with the signature

```julia
gradient( :: Function, :: Params )
```

so currently I combine the two Params objects with union:

```julia
input_params = params(x)
ps = union( input_params, layer_params )
gradient( () -> loss_fn( layer( x ), target ), ps )
```

and this works.
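For reference, a minimal sketch of reading both gradients back out of the combined Grads object inside a per-sample loop. It assumes the implicit-parameter API used in this thread (Flux.params with Zygote, as in Flux releases of that era); `samples`, `input_grads` and `param_grads` are hypothetical names, and the inputs are Float32 only to match the layer's default weight type.

```julia
using Flux

layer = Dense(2, 3)
layer_params = params(layer)
loss_fn(y_pred, y) = Flux.Losses.mse(y_pred, y)
target = ones(Float32, 3)

# Hypothetical data: five random input samples (Float32 to match the layer's weights)
samples = [rand(Float32, 2) for _ in 1:5]
input_grads = []   # d(loss)/dx for each sample
param_grads = []   # [d(loss)/d(weight), d(loss)/d(bias)] for each sample

for x in samples
    # Build a fresh Params that tracks this sample plus the layer's weight and bias
    ps = union(params(x), layer_params)
    gs = gradient(() -> loss_fn(layer(x), target), ps)

    push!(input_grads, gs[x])                          # same shape as x
    push!(param_grads, [gs[p] for p in layer_params])  # one entry per layer parameter
end
```

Grads is indexed by the tracked arrays themselves (identity, not value), which is why `gs[x]` must use the same array object that was passed to `params(x)`.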

However, I wonder if it is performant, especially if I loop over multiple samples x.
Is there some other way to achieve this, or is it just a dumb idea?

I think your solution is the only one at the moment.
I also think the overhead is small: `union` essentially creates a shallow copy of an IdDict, which should be cheap compared to the cost of computing the gradient.

You can check this yourself: run a few iterations that take the gradient with respect to the parameters only (no union), then a few with your solution. The performance difference should be small.
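One way to run that comparison is sketched below with BenchmarkTools.jl's `@btime` (an added dependency, not part of the original thread). The wrapper functions `grad_params_only` and `grad_params_and_input` are hypothetical names, and no timings are claimed here; the numbers depend on your machine and package versions.

```julia
using Flux, BenchmarkTools

layer = Dense(2, 3)
layer_params = params(layer)
loss_fn(y_pred, y) = Flux.Losses.mse(y_pred, y)
x = rand(Float32, 2)
target = ones(Float32, 3)

# Baseline: gradient w.r.t. the layer parameters only (no union)
grad_params_only(layer, x, target, ps) =
    gradient(() -> loss_fn(layer(x), target), ps)

# Proposed: also track the input by merging it into a fresh Params on every call
grad_params_and_input(layer, x, target, ps) =
    gradient(() -> loss_fn(layer(x), target), union(params(x), ps))

@btime grad_params_only($layer, $x, $target, $layer_params)
@btime grad_params_and_input($layer, $x, $target, $layer_params)
```

Wrapping each variant in a function and interpolating the arguments with `$` keeps untyped global-variable access out of the measurement, so the gap between the two timings approximates the cost of the `union` plus differentiating with respect to one extra array.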