Let’s say I have a type `WeirdNumber <: Number` that is so weird that I don’t want the derivative of its power function (`^`) to be calculated by the `rrule` of `^` and `literal_pow`; instead, Zygote should go through and differentiate its definition. Since Zygote defines AD-specific rules for `literal_pow` with `RuleConfig` in its code base, I also had to `@opt_out` that rule. However, the following MWE doesn’t work:
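```julia
using ChainRulesCore: @opt_out, RuleConfig
import Base: ^

struct WeirdNumber <: Number
    a::Float64
end

# Power is defined through the wrapped Float64
^(x::WeirdNumber, p::Int) = WeirdNumber(x.a ^ p)

# Opt out of the generic (config-free) literal_pow rule
@opt_out rrule(::typeof(Base.literal_pow), ::typeof(^), x::WeirdNumber, ::Val{p}) where {p}
# Opt out of the AD-specific (config-dependent) literal_pow rule
@opt_out rrule(::RuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::WeirdNumber,
               ::Val{p}) where {p}

using Zygote
fun(x::WeirdNumber) = (x^4).a   # a literal exponent lowers to Base.literal_pow(^, x, Val(4))
gradient(fun, WeirdNumber(2.))
```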
Error:

```
ERROR: type Nothing has no field method
Stacktrace:
 [1] getproperty(x::Nothing, f::Symbol)
    @ Base ./Base.jl:38
 [2] has_chain_rrule(T::Type)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/chainrules.jl:21
 [3] #s2948#1074
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:20 [inlined]
 [4] var"#s2948#1074"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
 [5] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:582
 [6] _pullback
    @ ~/Applications/project/TaylorDiff.jl/.vscode/opt_lit_pow.jl:16 [inlined]
 [7] _pullback(ctx::Zygote.Context{false}, f::typeof(fun), args::WeirdNumber)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface2.jl:0
 [8] pullback(f::Function, cx::Zygote.Context{false}, args::WeirdNumber)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:44
 [9] pullback
    @ ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:42 [inlined]
 [10] gradient(f::Function, args::WeirdNumber)
    @ Zygote ~/.julia/packages/Zygote/SmJK6/src/compiler/interface.jl:96
```
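The crash happens while Zygote looks up which `rrule` method applies, and a `nothing` there is consistent with a method lookup that found no unique match. One possible reading (an assumption, not confirmed by the trace) is a dispatch ambiguity: if Zygote's own `literal_pow` rule is declared on its concrete config type with `x::Number`, while the opt-out above is declared on the abstract `RuleConfig` with `x::WeirdNumber`, neither method is more specific than the other. The toy below reproduces only that dispatch pattern in plain Julia; `Config`, `ZConfig`, `Weird` and `rule` are hypothetical stand-ins, and the code does not touch ChainRulesCore or Zygote.

```julia
# Hypothetical stand-ins, not real ChainRulesCore/Zygote types:
abstract type Config end          # plays the role of RuleConfig
struct ZConfig <: Config end      # plays the role of Zygote's concrete config type
struct Weird <: Number            # plays the role of WeirdNumber
    a::Float64
end

# "AD-specific" rule: concrete config, generic Number
rule(::ZConfig, x::Number) = :ad_specific_rule
# Opt-out written against the abstract config, as in the MWE
rule(::Config, x::Weird) = :opt_out

# Both methods apply to (ZConfig, Weird), but each is more specific in a
# different argument, so Julia throws a MethodError reporting the ambiguity.
function try_rule(cfg, x)
    try
        rule(cfg, x)
    catch err
        # In this toy the only possible MethodError here is the ambiguity
        err isa MethodError ? :ambiguous : rethrow()
    end
end

before = try_rule(ZConfig(), Weird(2.0))

# A method on the concrete config AND the concrete number type is strictly
# more specific than both methods above, so dispatch becomes unambiguous.
rule(::ZConfig, x::Weird) = :opt_out_concrete
after = try_rule(ZConfig(), Weird(2.0))

println((before, after))
```

If this reading is right, the analogous change would be to declare the opt-out against Zygote's concrete config type rather than the abstract `RuleConfig`; that is a hypothesis to test, not a verified fix.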