I have defined a type `Rotor{T}` representing a unit quaternion. It's possible to construct a `Rotor{T}` directly from its components, which skips the (very slightly) expensive normalization step. There's also a function `rotor` with various methods that construct a `Rotor{T}` while ensuring that it is normalized. Finally, because of the complicated type trickery needed to make rotors play nicely with other types, I also have `Rotor` (without the `{T}`), which is basically a thin wrapper around `rotor`.
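For anyone who wants to reproduce this, here is a minimal hypothetical stand-in for the setup just described. The field names `w`, `x`, `y`, `z`, the `Real` bound, and the `abs2` method are assumptions, and the real type-trickery behind `Rotor` is not reproduced. The one structural point it shows is that defining an inner constructor for `Rotor{T}` suppresses Julia's default outer constructor, which leaves `Rotor(w, x, y, z)` free to forward to the normalizing `rotor`.

```julia
# Hypothetical minimal stand-in for the types described above; field names
# and the Real bound are assumptions, not the actual implementation.
struct Rotor{T<:Real}
    w::T
    x::T
    y::T
    z::T
    # Inner constructor: build directly from components, with no normalization.
    # Defining any inner constructor also suppresses Julia's default
    # outer constructor Rotor(w::T, x::T, y::T, z::T).
    Rotor{T}(w, x, y, z) where {T<:Real} = new{T}(w, x, y, z)
end

# Normalizing constructor function.
function rotor(w, x, y, z)
    n = sqrt(w^2 + x^2 + y^2 + z^2)
    a, b, c, d = promote(w / n, x / n, y / n, z / n)
    return Rotor{typeof(a)}(a, b, c, d)
end

# `Rotor` without `{T}`: a thin wrapper around `rotor`.
Rotor(w, x, y, z) = rotor(w, x, y, z)

# Squared norm, as used in the gradient example.
Base.abs2(r::Rotor) = r.w^2 + r.x^2 + r.y^2 + r.z^2
```

With this sketch, `Rotor(1.2, 3.4, 5.6, 7.8)` returns a normalized `Rotor{Float64}` (so `abs2` is approximately 1), while `Rotor{Float64}(1.0, 2.0, 0.0, 0.0)` stores the components as given.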
The derivatives seem to do what I expect with both `Rotor{T}` and `rotor`, but now it seems that I have to explicitly define `rrule`s for `Rotor` so that ChainRules doesn't think it should be treated the same as `Rotor{T}`. I thought I'd be able to do something like

```julia
ChainRulesCore.rrule(::Type{Rotor}, w, x, y, z) = ChainRulesCore.rrule(rotor, w, x, y, z)
```
but when I try

```julia
f(a,b,c,d) = abs2(Rotor(a,b,c,d))
Zygote.gradient(f, 1.2, 3.4, 5.6, 7.8)
```

I get an error:
```
MethodError: no method matching iterate(::Nothing)

Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen})
   @ Base range.jl:880
  iterate(::Union{LinRange, StepRangeLen}, ::Integer)
   @ Base range.jl:880
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}}
   @ Base dict.jl:698
  ...

Stacktrace:
  [1] indexed_iterate(I::Nothing, i::Int64)
    @ Base ./tuple.jl:91
  [2] chain_rrule
    @ ~/.julia/packages/Zygote/4SSHS/src/compiler/chainrules.jl:223 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:101 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::Type{Rotor}, ::Float64, ::Float64, ::Float64, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:101
  [5] _pullback
    @ ~/test_rotor_ad.jl:76 [inlined]
  [6] _pullback(::Zygote.Context{false}, ::var"#10#21", ::Float64, ::Float64, ::Float64, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
  [7] pullback(::Function, ::Zygote.Context{false}, ::Float64, ::Vararg{Float64})
    @ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:44
  [8] pullback(::Function, ::Float64, ::Float64, ::Vararg{Float64})
    @ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:42
  [9] gradient(::Function, ::Float64, ::Vararg{Float64})
    @ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:96
 [10] macro expansion
    @ ~/test_rotor_ad.jl:79 [inlined]
```
Am I making a silly mistake? Or going about this all wrong?