Can Zygote do derivatives w.r.t. keyword arguments which get captured in `kwargs...`?

Zygote can do derivatives w.r.t. keyword arguments, which is super cool:

using Zygote

foo(;x) = x^2
gradient(x -> foo(x=x), 2) # gives 4

In my case though, I’m doing some programmatic stuff with the keywords of a certain function, so at some point I need to capture them all with `kwargs...` and then pass only some of them along to another function. An MWE of something like this is a function which accepts any keywords, but passes only the x keyword to another call. Here are two ways I could think of writing this:

call_with_x_keyword_1(f; kwargs...) = f(;x=kwargs[:x])
call_with_x_keyword_2(f; kwargs...) = f(;(k=>v for (k,v) in kwargs if k==:x)...)

# both return 4, and ignore the `y` or any other keywords:
call_with_x_keyword_1(foo, x=2, y=3) 
call_with_x_keyword_2(foo, x=2, y=3) 

Unfortunately, Zygote can’t compute the gradient for either one:

gradient(x -> call_with_x_keyword_1(foo, x=x), 2) # "mutating arrays is not supported"
gradient(x -> call_with_x_keyword_2(foo, x=x), 2) # "Need an adjoint for constructor Base.Iterators.Pairs{...}"

I’m wondering if there are any suggestions for how to make this work (or even whether it’s possible at all)? The mutating array error is cryptic; I’m not sure where that would come in. I guess I could try to write that adjoint for Base.Iterators.Pairs, but I’m having a hard time knowing where to even start with thinking about what the adjoint of that means. Finally, although I know that here I could do something like foo(;x, ignored...) = x^2, that doesn’t really work with my non-MWE example. Thanks for any help!
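For context on what Zygote is being asked to differentiate through, here is a minimal sketch in plain Julia (no Zygote involved) of what `kwargs` is after capture: a `Base.Iterators.Pairs` wrapping a NamedTuple, which `values(kwargs)` returns. The names `capture` and `call_with_x_keyword_3` are hypothetical, introduced only for illustration. The third variant selects keywords by name from that NamedTuple instead of indexing or iterating the pairs. Whether Zygote can differentiate through it depends on the Zygote version and is not verified here.

```julia
# Hypothetical helper names; plain Julia only, Zygote is not used here.
foo(; x) = x^2

# Return the captured keywords unchanged so they can be inspected.
capture(; kwargs...) = kwargs

kw = capture(x = 2, y = 3)
keys(kw)     # the keyword names, as a tuple of Symbols
values(kw)   # the underlying NamedTuple holding the keyword values

# Select a subset of keywords by name, then splat the resulting NamedTuple.
call_with_x_keyword_3(f; kwargs...) =
    f(; NamedTuple{(:x,)}(values(kwargs))...)

call_with_x_keyword_3(foo, x = 2, y = 3)  # forwards only x to foo
```

Selecting by name this way raises an error if `x` is not among the keywords, much like `kwargs[:x]` does in the first variant.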

What about defining a new function, e.g.

bar(x) = foo(x=x)

Still an error. I don’t think the problem is in passing the argument, but rather in doing stuff with it from inside kwargs.

Works for me

foo(;x) = x^2
bar(x) = foo(x = x)

call_with_x(f, x) = f(x)

using Zygote

gradient(x -> call_with_x(bar, x), 2)

I have Julia 1.3.1 and Zygote 0.4.7.

I suspect the issue is your call_with_x_keyword function.

Oh, that’s what you were referring to, sorry. That’s not a solution for me, since it does away entirely with capturing keywords as kwargs..., which in my case is the important piece that allows me to do some programmatic things with the arguments (and is also the part Zygote is struggling with).

Are these predetermined? E.g. they are always called x? Or is that variable?

This works for me if you know that x is always gonna be called x. See how I defined foo2

using Zygote

foo(;x) = x^2

foo2(;x, kwargs...) = foo(; x = x) # this will ignore all non `x` arguments.

gradient(x -> foo(x=x), 2) # gives 4

call_with_x_keyword_1(f; kwargs...) = f(;kwargs...)

gradient(x->call_with_x_keyword_1(foo; x = x), 2)

gradient(x->call_with_x_keyword_1(foo2; x = x), 2)

Or am I still not getting you?

I get that you want to differentiate foo(;x), but kwargs can contain parameters named something other than x, so you are doing kwargs[:x] (which I think is what Zygote does not support). But you can just break the kwargs... apart, like foo2(;x, kwargs...) = foo(;x=x)

Unfortunately, like I mentioned in the original post,