ChainRulesCore.rrule that just calls another rrule

I have defined a type Rotor{T} representing a unit quaternion. It’s possible to construct a Rotor{T} directly from its components, which avoids the (very slightly) expensive step of normalization. There’s also a function rotor with various methods that construct a Rotor{T} while ensuring that it is normalized. Finally, because of some complicated type trickery needed to make rotors play nicely with other types, I also have Rotor (without the {T}), which is basically a thin wrapper around rotor.

The derivatives seem to do what I expect with both Rotor{T} and rotor, but now it seems that I have to explicitly define rrules for Rotor so that ChainRules doesn’t think it should be treated the same as Rotor{T}. I thought I’d be able to do something like

ChainRulesCore.rrule(::Type{Rotor}, w, x, y, z) = ChainRulesCore.rrule(rotor, w, x, y, z)

but when I try

f(a,b,c,d) = abs2(Rotor(a,b,c,d))
Zygote.gradient(f, 1.2, 3.4, 5.6, 7.8)

I get an error:

  MethodError: no method matching iterate(::Nothing)
  Closest candidates are:
    iterate(::Union{LinRange, StepRangeLen})
     @ Base range.jl:880
    iterate(::Union{LinRange, StepRangeLen}, ::Integer)
     @ Base range.jl:880
    iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}}
     @ Base dict.jl:698
    [1] indexed_iterate(I::Nothing, i::Int64)
      @ Base ./tuple.jl:91
    [2] chain_rrule
      @ ~/.julia/packages/Zygote/4SSHS/src/compiler/chainrules.jl:223 [inlined]
    [3] macro expansion
      @ ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:101 [inlined]
    [4] _pullback(::Zygote.Context{false}, ::Type{Rotor}, ::Float64, ::Float64, ::Float64, ::Float64)
      @ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:101
    [5] _pullback
      @ ~/test_rotor_ad.jl:76 [inlined]
    [6] _pullback(::Zygote.Context{false}, ::var"#10#21", ::Float64, ::Float64, ::Float64, ::Float64)
      @ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
    [7] pullback(::Function, ::Zygote.Context{false}, ::Float64, ::Vararg{Float64})
      @ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:44
    [8] pullback(::Function, ::Float64, ::Float64, ::Vararg{Float64})
      @ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:42
    [9] gradient(::Function, ::Float64, ::Vararg{Float64})
      @ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:96
   [10] macro expansion
      @ ~/test_rotor_ad.jl:79 [inlined]

Am I making a silly mistake? Or going about this all wrong?

I think this may have something to do with RuleConfig and calling back into AD, so maybe adding a RuleConfig argument somewhere would help. But I’m not 100% sure.

Ah, yes! This seems to work perfectly:

rrule(config::RuleConfig, ::Type{Rotor}, args...) = rrule_via_ad(config, rotor, args...)

(And the RuleConfig seems to break some ambiguities that prevented me from slurping before, so it’s even easier.) Thanks very much!

Hmm… Except that if I use this, test_rrule(Rotor, w, x, y, z) fails when checking inferred types, complaining that

test_rrule: Rotor on Float64,Float64,Float64,Float64: Error During Test at ~/.julia/packages/ChainRulesTestUtils/C9L2i/src/testers.jl:202
  Got exception outside of a @test
  return type Tuple{NoTangent, Vararg{Float64, 4}} does not match inferred return type Tuple

It also fails if, instead of slurping with args..., I write the arguments out explicitly. I wonder whether this will cause slowdowns.

For now I’m just passing check_inferred=false to skip that check, but I welcome any suggestions.

I’ve encountered that before: the trouble with rrule_via_ad is that it uses the AD backend to derive the pullback, so if the backend is type-unstable, the resulting rule is too. But in this case the backend is FiniteDifferences (through ChainRulesTestUtils), so that shouldn’t happen?
Ideally I’d like to call another rrule directly, without resorting to rrule_via_ad, but I don’t know how to do that. And when I don’t know, I ruthlessly tag @oxinabox.

A lot of things break inference, so that’s not the worst. A lot of rrules have that check disabled.
And that check doesn’t test things in an ideal scenario, since the call isn’t placed somewhere it can be inlined.
So often things are not as bad as they seem.

You will have to look into why inference is failing, e.g. using Cthulhu.jl.
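As a quick first check before reaching for Cthulhu.jl, Test’s @inferred macro can be applied by hand to both the rrule call and its pullback. The failure message quoted above has the same form as the error @inferred throws ("return type ... does not match inferred return type ..."). This is a minimal sketch; square and its rule are hypothetical examples, not code from this thread.

```julia
using ChainRulesCore
using Test: @inferred

# Hypothetical function and rule, used only to illustrate the check.
square(x) = x^2

function ChainRulesCore.rrule(::typeof(square), x)
    square_pullback(ȳ) = (NoTangent(), 2x * unthunk(ȳ))
    return square(x), square_pullback
end

# @inferred throws if the inferred return type differs from the actual one,
# so these two lines check the rule itself and then its pullback.
y, pb = @inferred ChainRulesCore.rrule(square, 3.0)
grads = @inferred pb(1.0)
```

If the first call passes but the second fails, the instability is inside the pullback closure, which narrows down where to point Cthulhu.jl.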

If you are sure you will have an rrule, you can call one rrule directly from another.
The no method matching iterate(::Nothing) error suggests that no rrule for the function rotor matched that call.
In that case you need rrule_via_ad, so that the pullback can be derived via the AD system.
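The one-line forwarding from the original question works once an rrule for rotor really exists for the argument types used. Below is a minimal sketch under that assumption. The Rotor struct, rotor, and both rules are hypothetical stand-ins for the poster’s package, not their actual code; the rule for rotor uses the pullback of the normalization v ↦ v/|v|, namely Δ ↦ (Δ - v̂ (v̂⋅Δ))/|v|.

```julia
using ChainRulesCore

# Hypothetical minimal stand-in for the poster's type (not their actual code).
struct Rotor{T<:Real}
    w::T
    x::T
    y::T
    z::T
    # Inner constructor: stores the components as given, without normalizing.
    Rotor{T}(w, x, y, z) where {T<:Real} = new{T}(w, x, y, z)
end

# `rotor` always returns a normalized Rotor.
function rotor(w, x, y, z)
    n = sqrt(w^2 + x^2 + y^2 + z^2)
    T = typeof(w / n)
    return Rotor{T}(w / n, x / n, y / n, z / n)
end

# The unparameterized `Rotor` is a thin wrapper around `rotor`.
Rotor(w, x, y, z) = rotor(w, x, y, z)

# rrule for `rotor`: pullback of v ↦ v/|v| is Δ ↦ (Δ - v̂ (v̂⋅Δ)) / |v|.
function ChainRulesCore.rrule(::typeof(rotor), w, x, y, z)
    R = rotor(w, x, y, z)
    n = sqrt(w^2 + x^2 + y^2 + z^2)
    function rotor_pullback(ΔR)
        Δ = unthunk(ΔR)  # assumed to be a Tangent{<:Rotor} with fields w, x, y, z
        s = R.w * Δ.w + R.x * Δ.x + R.y * Δ.y + R.z * Δ.z
        return (NoTangent(),
                (Δ.w - R.w * s) / n, (Δ.x - R.x * s) / n,
                (Δ.y - R.y * s) / n, (Δ.z - R.z * s) / n)
    end
    return R, rotor_pullback
end

# Since an rrule for `rotor` exists, calling it directly is fine. The first
# pullback entry (tangent of the callable) is NoTangent() in both cases.
ChainRulesCore.rrule(::Type{Rotor}, w, x, y, z) = ChainRulesCore.rrule(rotor, w, x, y, z)
```

Because Type{Rotor} only matches the bare Rotor and not, say, Rotor{Float64}, this rule does not interfere with the component constructor Rotor{T}.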

The test config in ChainRulesTestUtils first tries to use the rrule directly, and if it doesn’t find one, it falls back to finite differences.

If you want to disable that fallback, you could write a new rule config:

using ChainRulesCore

struct RuleOnlyConfig <: RuleConfig{Union{HasReverseMode, HasForwardsMode}} end

function ChainRulesCore.frule_via_ad(config::RuleOnlyConfig, ȧrgs, f, args...; kws...)
    # try using a rule; if there is none, error instead of falling back to finite differences
    ret = frule(config, ȧrgs, f, args...; kws...)
    if isnothing(ret)
        error("no rule found")
    end
    return ret
end

function ChainRulesCore.rrule_via_ad(config::RuleOnlyConfig, f, args...; kws...)
    ret = rrule(config, f, args...; kws...)
    if isnothing(ret)
        error("no rule found")
    end
    return ret
end

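then pass an instance of it to test_rrule as the rule config, e.g. test_rrule(RuleOnlyConfig(), Rotor, w, x, y, z) (test_rrule accepts a RuleConfig as an optional first argument).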
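A sketch of how such a rule-only config interacts with the config-based forwarding pattern used earlier in the thread (an rrule that takes a config and calls rrule_via_ad). The functions inner, outer, and norule are hypothetical; forwarding succeeds when the target has a hand-written rule and errors loudly when it has none.

```julia
using ChainRulesCore

# Rule-only config, as in the post above (only the reverse-mode part is needed here).
struct RuleOnlyConfig <: RuleConfig{Union{HasReverseMode, HasForwardsMode}} end

function ChainRulesCore.rrule_via_ad(config::RuleOnlyConfig, f, args...; kws...)
    ret = ChainRulesCore.rrule(config, f, args...; kws...)
    isnothing(ret) && error("no rule found for $f")
    return ret
end

# Hypothetical functions: `inner` has a hand-written rule; `norule` has none.
inner(x) = 3x
ChainRulesCore.rrule(::typeof(inner), x) = (inner(x), ȳ -> (NoTangent(), 3 * unthunk(ȳ)))
norule(x) = x + 1

# `outer` forwards to `inner` through the config, like the Rotor rule earlier in the thread.
outer(x) = inner(x)
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(outer), args...)
    return ChainRulesCore.rrule_via_ad(config, inner, args...)
end

y, pb = ChainRulesCore.rrule(RuleOnlyConfig(), outer, 2.0)
```

With this config, rrule_via_ad(RuleOnlyConfig(), norule, 1.0) throws "no rule found" instead of quietly deriving a pullback some other way.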

That is a frequent error. Do you think it deserves a mention in the docs? Sometimes it’s hard to catch, because the nothing gets passed along for a while before something finally breaks.

It’s weird, because there was indeed an rrule, or at least the OP thought so.
I assume that the nothing error happens because there is a

function rrule(args...; kwargs...) end

defined somewhere. Would it make sense to get rid of it so that we can have a proper informative MethodError in these situations?
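A small sketch of the behavior being discussed, assuming only ChainRulesCore is loaded: for a function with no rule, rrule returns nothing, and destructuring that result produces the same iterate(::Nothing) error as in the original post. The function double and the helper rrule_or_error are hypothetical; the helper is just one way to turn the silent nothing into an informative error.

```julia
using ChainRulesCore

# Hypothetical function with no rrule defined for it.
double(x) = 2x

# The catch-all fallback makes rrule return `nothing` here instead of throwing.
ret = ChainRulesCore.rrule(double, 1.0)

# Destructuring that `nothing` as (y, pullback) is what produces
# `MethodError: no method matching iterate(::Nothing)`.
err = try
    y, pb = ChainRulesCore.rrule(double, 1.0)
    nothing
catch e
    e
end

# Hypothetical helper: fail early with an informative message instead.
function rrule_or_error(f, args...)
    ret = ChainRulesCore.rrule(f, args...)
    isnothing(ret) && error("no rrule defined for $f with argument types $(typeof.(args))")
    return ret
end
```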

It is an intentional feature of ChainRulesCore that rrule returns nothing when there is no rule.
It’s from before I was involved in the project, from when Julia was a different language (0.4 days? When ChainRules was part of the plans for Capstan.jl).
We should consider getting rid of it as part of ChainRules 2.0.
Feel encouraged to open an issue to revisit that decision, and we will put the 2.0 milestone on it.

That is a frequent error, do you think it deserves a mention in the docs?

Yes, we can at least add it to the FAQ.
Please open a PR.

Hi @moble,

am I right that you were, like me, looking at Rotations.jl, found unit quaternions for rotations implemented there, but found AD support missing?
I also ended up writing my own basic implementation (with the additional motivation that I wanted GPU/CUDA support).
There is an issue, though, about adding ChainRulesCore support to Rotations.jl, which mentions that PRs would be welcome.
Would you be interested in tackling this together?

Upon closer inspection, I hadn’t covered the right combination of types. Once I fixed that, it worked. Thanks!

No, I hadn’t actually seen Rotations.jl when I started my package. But more than that, I do need non-unit quaternions for a lot of things also, so it doesn’t look like it’s quite right for me. Even Quaternions.jl — which Rotations.jl appears to base its QuatRotation on — didn’t really fit what I was going for. At this point, my package has grown to be quite large, and fits my needs well, so I don’t think I’ll be contributing to Rotations.jl.

But certainly feel free to use anything I’ve developed (currently working in my chainrules branch).

Thanks for the heads up!
I’ll give it a shot at some point