Rrule (or frule) with kwargs

How does one define an rrule for a function where the argument you want the derivative for is a keyword argument (kwarg)?

Here is some previous discussion

but this was a while ago, the whole structure seems to have changed since then, and it does not give clear instructions on how to write an rrule that supports kwargs. I noticed that Zygote supports gradients with respect to kwargs:

julia> f(a,b;c)=a*b*c
f (generic function with 1 method)
julia> gradient((x)->f(2.0,2.0,c=x), 1.0)
(4.0,)

But how do you define an rrule for such a function? You can easily put kwargs in the function signature, but the pullback has to return a Tuple of tangents (or base types), and you cannot mix named and unnamed entries in one tuple. Reversing the order does not seem to work either. I did not find anything on this in the ChainRulesCore docs. Does anyone know how to achieve this within ChainRulesCore.jl without tying it to a specific AD system?

ChainRules doesn't support returning derivatives with respect to keyword arguments,
only receiving the keyword-argument primals.
This is because it was found that they are basically never needed: they tend to be things like
or they tend to dispatch to something with positional arguments only under the hood.
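To illustrate what "only receiving the keyword argument primals" means in practice, here is a minimal sketch assuming a toy foo(a, b; c) = a * b * c (not from the original post). The rule can read c, but the pullback returns tangents only for foo itself, a and b; there is no slot for c.

```julia
using ChainRulesCore

# Toy function for illustration only
foo(a, b; c) = a * b * c

# The keyword c is accepted by the rule, but no tangent is returned for it
function ChainRulesCore.rrule(::typeof(foo), a, b; c)
    y = foo(a, b; c = c)
    function foo_pullback(ȳ)
        dy = unthunk(ȳ)
        # (tangent for foo, tangent for a, tangent for b): nothing for c
        return NoTangent(), dy * b * c, dy * a * c
    end
    return y, foo_pullback
end

y, pb = rrule(foo, 2.0, 3.0; c = 4.0)
grads = pb(1.0)  # (NoTangent(), 12.0, 8.0)
```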

The normal workaround is to rewrap the function into something that takes the keyword arguments in question as positional arguments, and write a rule for that.
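A minimal sketch of that workaround, assuming a toy foo(a, b; c = 42.0) and a hypothetical positional-only helper _foo (both names are illustrative). The public keyword API just forwards to _foo, and the rule is written for _foo, so c gets an ordinary positional tangent. The expectation is that an AD tracing through foo reaches _foo and uses this rule; that is an assumption about the AD, and the example below only calls rrule directly.

```julia
using ChainRulesCore

# Hypothetical positional-only kernel that carries the rule
_foo(a, b, c) = a * b * c

# Public keyword-argument API simply forwards to the kernel
foo(a, b; c = 42.0) = _foo(a, b, c)

function ChainRulesCore.rrule(::typeof(_foo), a, b, c)
    y = _foo(a, b, c)
    function _foo_pullback(ȳ)
        dy = unthunk(ȳ)
        # c is positional here, so it gets a tangent like a and b
        return NoTangent(), dy * b * c, dy * a * c, dy * a * b
    end
    return y, _foo_pullback
end

y, pb = rrule(_foo, 2.0, 3.0, 4.0)
grads = pb(1.0)  # (NoTangent(), 12.0, 8.0, 6.0)
```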

However, it is possible.
Whether or not a particular AD supports it, I can't promise.

For some function foo(x; a, b)

The following works for frules in Diffractor:

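using ChainRulesCore

# Core.kwcall(kwargs, foo, x) is what foo(x; kwargs...) lowers to (Julia 1.9+)
function ChainRulesCore.frule((_, dkwargs, _, dx), ::typeof(Core.kwcall), kwargs, ::typeof(foo), x)
    y = foo(x; kwargs...)
    dy = ...  # combine dx with the keyword tangents held in dkwargs
    return y, dy
end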

and is used, e.g., in CedarEDA.jl/src/extra_rules.jl (CedarEDA/CedarEDA.jl on GitHub, commit 759d5c5c173f509829fb741e748ecffd85303ec5).
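As a concrete instance of the forward-mode template, here is a minimal sketch assuming a hypothetical foo(x; a, b) = a * b * x and Julia 1.9+ (where Core.kwcall exists). The helper kwtangent is also hypothetical; it reads one keyword tangent and treats a zero kwargs tangent as zero. The example calls ChainRulesCore.frule directly, so it only checks the rule itself; whether a given AD passes dkwargs in exactly this Tangent form is an assumption.

```julia
using ChainRulesCore

# Hypothetical concrete stand-in for foo(x; a, b)
foo(x; a, b) = a * b * x

# Hypothetical helper: tangent of keyword `name`, or zero if no kwargs tangent was given
kwtangent(dkw::Tangent, name::Symbol) = getproperty(dkw, name)
kwtangent(::AbstractZero, ::Symbol) = ZeroTangent()

function ChainRulesCore.frule((_, dkwargs, _, dx), ::typeof(Core.kwcall), kwargs, ::typeof(foo), x)
    a, b = kwargs.a, kwargs.b
    y = foo(x; kwargs...)
    da, db = kwtangent(dkwargs, :a), kwtangent(dkwargs, :b)
    # Product rule for y = a * b * x
    dy = b * x * da + a * x * db + a * b * dx
    return y, dy
end

kw = (a = 3.0, b = 4.0)
# Perturb only the keyword `a` (da = 1, db = 0, dx = 0)
y, dy = frule((NoTangent(), Tangent{typeof(kw)}(a = 1.0, b = 0.0), NoTangent(), 0.0),
              Core.kwcall, kw, foo, 2.0)  # y == 24.0, dy == 8.0
```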

For reverse mode the equivalent is

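using ChainRulesCore

function ChainRulesCore.rrule(::typeof(Core.kwcall), kwargs, ::typeof(foo), x)
    y = foo(x; kwargs...)
    function foo_pullback(dy)
        da = haskey(kwargs, :a) ? ... : NoTangent()  # tangent for keyword a, if it was passed
        db = haskey(kwargs, :b) ? ... : NoTangent()  # tangent for keyword b, if it was passed
        dkwargs = Tangent{typeof(kwargs)}(a=da, b=db)
        dx = ...
        # one entry each for kwcall, kwargs, foo and x
        return NoTangent(), dkwargs, NoTangent(), dx
    end
    return y, foo_pullback
end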

In particular, I am not sure Zygote will handle this correctly; it might, it might not.
I know this line

drops the derivatives of keyword arguments, but I am not sure whether it will actually be hit if you provide an overload of the rrule for kwcall directly.
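To see the reverse-mode template in action without depending on any AD, here is a minimal filled-in sketch. It assumes a hypothetical foo(x; a, b) = a * b * x with both keywords required (so the haskey checks are unnecessary) and Julia 1.9+ for Core.kwcall. Because it calls ChainRulesCore.rrule directly, it says nothing about whether Zygote picks the rule up.

```julia
using ChainRulesCore

# Hypothetical concrete stand-in for foo(x; a, b)
foo(x; a, b) = a * b * x

function ChainRulesCore.rrule(::typeof(Core.kwcall), kwargs, ::typeof(foo), x)
    a, b = kwargs.a, kwargs.b
    y = foo(x; kwargs...)
    function foo_pullback(ȳ)
        dy = unthunk(ȳ)
        # Structural tangent for the NamedTuple of keyword arguments
        dkwargs = Tangent{typeof(kwargs)}(a = dy * b * x, b = dy * a * x)
        dx = dy * a * b
        # One entry each for kwcall, kwargs, foo and x
        return NoTangent(), dkwargs, NoTangent(), dx
    end
    return y, foo_pullback
end

kw = (a = 3.0, b = 4.0)
y, pb = rrule(Core.kwcall, kw, foo, 2.0)
grads = pb(1.0)  # grads[2].a == 8.0, grads[2].b == 6.0, grads[4] == 12.0
```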

Thanks a lot for your reply! I think kwargs support would be really nice, and I am sure there are many use cases. E.g., in IndexFunArrays.jl or SeparableFunctions.jl (not released), the offset, center and scale of each function can be changed from their defaults via keyword arguments. In my current use case, I want AD support for SeparableFunctions.jl to get a fast Gaussian fitting routine, where one would typically use the pos keyword argument.

Yes, defining the rule for the kwcall function is what I tried as well, but so far without success:

using ChainRulesCore
using Zygote

foo(a,b;c)=a*b*c
foo(1,2;c=3) # works
g = Core.kwcall((c = 3,), foo, 1, 2)  # works
function ChainRulesCore.rrule(::typeof(Core.kwcall), kwargs, ::typeof(foo), a, b)
    println("kwargs foo in chainrule")
    y = foo(a,b; kwargs...)
    c = kwargs.c  # c is a required keyword, and the pullback needs it
    function foo_pullback(dy)
        println("kwargs foo pulling back")
        da = dy*b*c
        db = dy*a*c
        dc = haskey(kwargs, :c) ? dy*a*b : NoTangent()
        dkwargs = Tangent{typeof(kwargs)}(c=dc)
        return NoTangent(), dkwargs, NoTangent(), da, db
    end
    return y, foo_pullback
end
gradient((x)->foo(2.0,3.0,c=x), 1.0) # works but does not use the above chain rule

As one can see, the chain rule is never used by gradient, yet Zygote is nevertheless able to deal with kwargs. The interesting definition you found in Zygote.jl seems to ignore kwargs, but the framework somehow does support them.

Since there does not seem to be a way to specify kwargs support for gradients in ChainRulesCore.jl, maybe it is time to come up with one?
E.g., one could wrap the return value in KWArgsTangent((standardTangents...), (NamedTangents...)) or something like this?
But would this be type stable?
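A minimal sketch of what such a wrapper could look like. KWArgsTangent and foo_tangents are hypothetical and not part of ChainRulesCore; the example only checks that the container itself is concretely typed and inferable when its fields are a Tuple and a NamedTuple. Whether every AD could accumulate and propagate such a type in a type-stable way is a separate, open question.

```julia
using Test  # for @inferred

# Hypothetical container (NOT part of ChainRulesCore): positional and keyword tangents side by side
struct KWArgsTangent{P<:Tuple,K<:NamedTuple}
    positional::P
    keywords::K
end

# Hypothetical pullback body for foo(a, b; c) = a * b * c
foo_tangents(dy, a, b, c) = KWArgsTangent((dy * b * c, dy * a * c), (c = dy * a * b,))

# @inferred throws if the return type cannot be inferred concretely
t = @inferred foo_tangents(1.0, 2.0, 3.0, 4.0)
```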

I would be most tempted just to change Zygote so that the code above works.
It's unambiguous, though not the nicest.
And I think it could probably be made to work.
Zygote would just need to check whether there is an actual rrule for kwfunc before falling back to calling the rule for the function with primal keywords.

And something like that would be required in order to support anything nicer anyway.
Maybe for the something nicer we can gate the rule on a config trait type parameter, like we do for the calling-back-into-AD traits.
But first we would definitely need to make the above work in Zygote, because that is where the gradients need to end up.

As a Zygote-only workaround, I think defining

using ZygoteRules

function ZygoteRules._pullback(::typeof(Core.kwcall), kwargs, ::typeof(foo), a, b)
    println("kwargs foo in zygote rule")
    y = foo(a,b; kwargs...)
    function foo_pullback(dy)
        println("kwargs foo pulling back")
        da = dy*b*c
        db = dy*a*c
        dc = haskey(kwargs, :c) ? dy*a*b : nothing
        dkwargs = (;c=dc)
        return nothing, dkwargs, nothing, da, db
    end
end

should work.

It should get hit before the ChainRules interface does.

You are right. Changing Zygote so that it can deal with kwcall rrule definitions should be the first step. Yet Zygote does not even seem to pick up your suggestion (with a few corrections):

using ZygoteRules
foo(a,b;c)=a*b*c
function ZygoteRules._pullback(::typeof(Core.kwcall), kwargs, ::typeof(foo), a, b)
    println("kwargs foo in zygote rule")
    y = foo(a,b; kwargs...)
    function foo_pullback(dy)
        println("kwargs foo pulling back")
        da = dy*b*c
        db = dy*a*c
        dc = haskey(kwargs, :c) ? dy*a*b : NoTangent()
        dkwargs = (;c=dc)
        return nothing, dkwargs, nothing, da, db
    end
end
using Zygote
gradient((x)->foo(2.0,3.0,c=x), 1.0) # does not call the above definition

Does the rrule for kwcall work directly with Diffractor.jl?

Right now, rrules with keyword arguments don't work in Diffractor at all,
though there is a PR that fixes that and makes the above rrule for kwcall work.

For Zygote, I am pretty sure that should work.
I would need to dig in and debug it to work out why it doesn't.

Oh, I realised I had the signature of _pullback wrong: I forgot the context argument,
and there were a few other issues.

I have fixed them, and the following works. 🎉

using ZygoteRules
foo(a,b;c=42.0)=a*b*c
function ZygoteRules._pullback(::ZygoteRules.AContext, A::typeof(Core.kwcall), kwargs, ::typeof(foo), a, b)
    println("kwargs foo in zygote rule")
    y = foo(a,b; kwargs...)
    c =  get(kwargs, :c, 42.0)  # insert default as we need that in pullback.
    function foo_pullback(dy)
        println("kwargs foo pulling back")
        da = dy*b*c
        db = dy*a*c
        dc = haskey(kwargs, :c) ? dy*a*b : nothing  # Zygote uses nothing for "no gradient"
        dkwargs = (;c=dc)
        return nothing, dkwargs, nothing, da, db
    end
    return y, foo_pullback
end
using Zygote
gradient((x)->foo(2.0,3.0,c=x), 1.0)

I made that a little more complex so you can see how to handle a kwarg with a default.

Thanks! This is a nice step forward. But there is still the problem that this solution is Zygote-specific, and there seems to be no way to specify it as an rrule. I also assume that ZygoteRules is a heavier package than ChainRulesCore, and wherever a kwarg needs to be differentiated, one would need to make the package depend on it.

No, ZygoteRules is no heavier than ChainRulesCore.
In fact, I expect it is technically lighter, but CRC is already so fast to load that it doesn't matter.
It's just that it is Zygote-specific and doesn't work for any of the other 9 ADs that support ChainRulesCore (to varying extents).

There is indeed no way to specify this as an rrule that Zygote will recognize.
As I said, I think what we need to do is teach Zygote to recognize this in rrule form,
and then maybe add some syntactic sugar to do this in ChainRulesCore without the kwfunc weirdness. Maybe.

The code in Zygote that would need to be changed is:

and

It is not the most complicated code in the world,
but it isn't simple either.

There seem to be other problems. I tried implementing your idea for my problem, but it fails when the data structures are a little more complex:

using Zygote, ZygoteRules

foo(a; c=[42.0, 43.0]) = (a .*c, ) 
function ZygoteRules._pullback(::ZygoteRules.AContext, A::typeof(Core.kwcall), kwargs, ::typeof(foo), a)
    println("kwargs foo in zygote rule")
    y = foo(a; kwargs...)
    c =  get(kwargs, :c, [42.0, 43.0])  # insert default as we need that in pullback.
    function foo_pullback(dy)
        println("kwargs foo pulling back: dy=$dy, c=$c")
        da = (c' * dy[1],)  # (reshape(dy .*c[1], (2,1)), reshape(dy .*c[2], (1,2))) # dy .*c
        dc = haskey(kwargs, :c) ? (a * dy[1],)  : NoTangent() # #  (reshape(dy[1] .*a, (2,1)), reshape(dy[2] .*a, (1,2))) 
        dkwargs = (;c=dc)
        @show da
        return nothing, dkwargs, nothing, da
    end
    return y, foo_pullback
end

q = pullback((x)->foo(x), 1.0)[2](([1.0, 1.0],)) # works fine and returns (85.0,)
gradient((x)->sum(.*(foo(x)...)), 2.0) # works fine!

q = pullback((x)->foo(x, c=[42.0, 43.0]), 1.0)[2](([1.0, 1.0],)) # yields an error. WHY? How to fix?
gradient((x)->sum(.*(foo(x, c=[42.0, 43.0])...)), 2.0) # yields an error. WHY?

The error is
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float64, @NamedTuple{}})(::Vector{Float64})

Since the returned result is identical in both cases and one works and the other does not, I am a bit puzzled. It seems like Zygote is doing something else behind the scenes which may need fixing to make this idea work?