How does one define an rrule for a function when the argument you want the derivative with respect to is a keyword argument (kwarg)?
Here is some previous discussion, but it is from a while ago, the overall structure seems to have changed since then, and it does not give clear instructions on how to write an rrule that supports kwargs. I noticed that Zygote supports gradients with respect to kwargs:
```julia
julia> f(a,b;c)=a*b*c
f (generic function with 1 method)

julia> gradient((x)->f(2.0,2.0,c=x), 1.0)
(4.0,)
```
But how do you define an rrule for such a function? You can easily put kwargs in the rrule's signature, but the pullback has to return a Tuple of Tangents (or plain values), with one entry for the function itself and one per positional argument, and you cannot mix named and unnamed entries in a tuple. Reversing the order does not seem to work either. I did not find anything about this in the ChainRulesCore docs. Does anyone know how to achieve this within ChainRulesCore.jl without tying it to a specific AD system?
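For concreteness, here is a minimal sketch of what can be written with ChainRulesCore for the f from the Zygote example. The keyword c can be accepted in the rrule signature and forwarded to f, but the pullback's tuple only has slots for f itself and the positional arguments a and b, so there is nowhere to put a tangent for c. The name f_pullback and the input values are illustrative choices, not part of the original question.

```julia
using ChainRulesCore

# Same function as in the Zygote example: c is a keyword argument.
f(a, b; c) = a * b * c

# The keyword can be accepted and forwarded in the rrule signature...
function ChainRulesCore.rrule(::typeof(f), a, b; c)
    y = f(a, b; c=c)
    function f_pullback(ȳ)
        Δ = unthunk(ȳ)
        # ...but the returned tuple has one entry for `f` itself and one per
        # positional argument; there is no position for a tangent w.r.t. `c`.
        return NoTangent(), Δ * b * c, Δ * a * c
    end
    return y, f_pullback
end

y, back = rrule(f, 2.0, 3.0; c=4.0)
back(1.0)  # → (NoTangent(), 12.0, 8.0): tangents for (f, a, b) only
```

This is the part that works; the open question is how (or whether) ChainRulesCore lets the pullback also report a tangent for c.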