Is it possible to specialize a ZygoteRule using a Holy trait? I’m not sure how I would go about doing this.
OK, here’s what I have come up with so far (doesn’t work, but shows what I’m trying to do):
using Zygote

abstract type SymbolicTrait end
struct IsSymbolic <: SymbolicTrait end
struct NotSymbolic <: SymbolicTrait end

function SymbolicTrait(x)
    if typeof(x) <: Int
        IsSymbolic()
    else
        NotSymbolic()
    end
end

function bloop(x)
    x^2
end

Zygote.@adjoint function bloop(x)
    bloop_adjoint(SymbolicTrait(x), x) # How do I actually refer to the adjoint defined by the @adjoint macro?
end

Zygote.@adjoint function bloop(::IsSymbolic, x)
    function bloop_adjoint(d)
        2*d
        print("Symbolic")
    end
    (x^2, bloop_adjoint)
end

Zygote.@adjoint function bloop(::NotSymbolic, x)
    function bloop_adjoint(d)
        2*d
        print("NonSymbolic")
    end
    (x^2, bloop_adjoint)
end
I’m using Int/Float to stand in for Symbolic/numeric types for simplicity of the example.
julia> Zygote.gradient(bloop,1)
ERROR: UndefVarError: bloop_adjoint not defined
Stacktrace:
[1] adjoint
@ ./REPL[7]:2 [inlined]
[2] _pullback(__context__::Zygote.Context, 276::typeof(bloop), x::Int64)
@ Main ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57
[3] _pullback(f::Function, args::Int64)
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:34
[4] pullback(f::Function, args::Int64)
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:40
[5] gradient(f::Function, args::Int64)
@ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:58
[6] top-level scope
@ REPL[10]:100:
The function you call within bloop(x), bloop_adjoint, isn’t visible within that bloop method: it only exists as a local function inside your other two methods. Instead, try returning bloop(SymbolicTrait(x),x).
You could also at that point remove the @adjoint annotations from your other two bloop methods, as they’re only called internally from your first bloop method; Zygote will only deal with the primal and adjoint returned from bloop(x), which wraps the other two methods.
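For illustration, here is a minimal sketch of that arrangement. It assumes the trait-specific methods are moved into a plain helper, here given the hypothetical name bloop_rule, so that only the single-argument bloop carries the @adjoint. It also assumes you want the pullback to return the real derivative of x^2 as a tuple; in the code from the question the print call is the last expression, so the pullback would return nothing.

```julia
using Zygote

abstract type SymbolicTrait end
struct IsSymbolic <: SymbolicTrait end
struct NotSymbolic <: SymbolicTrait end

# Int stands in for "symbolic" and everything else for "numeric", as in the question.
SymbolicTrait(::Int) = IsSymbolic()
SymbolicTrait(::Any) = NotSymbolic()

bloop(x) = x^2

# Plain helper methods (no @adjoint). `bloop_rule` is a hypothetical name.
# Each method returns the primal value and a pullback closure.
function bloop_rule(::IsSymbolic, x)
    function back(d)
        println("Symbolic")
        return (2x * d,)   # one entry per argument of bloop
    end
    return x^2, back
end

function bloop_rule(::NotSymbolic, x)
    function back(d)
        println("NonSymbolic")
        return (2x * d,)
    end
    return x^2, back
end

# The only @adjoint: compute the trait, then dispatch on it.
Zygote.@adjoint bloop(x) = bloop_rule(SymbolicTrait(x), x)

Zygote.gradient(bloop, 3)    # takes the IsSymbolic branch
Zygote.gradient(bloop, 3.0)  # takes the NotSymbolic branch
```

Keeping the helpers under a separate name avoids adding two-argument methods to bloop itself, but defining them as bloop(::IsSymbolic, x) and bloop(::NotSymbolic, x) without @adjoint, as suggested above, should work the same way.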
Lastly, unless you’re specifically overloading a method which has an existing adjoint, you might consider moving to ChainRulesCore to define your rrules, rather than sticking with ZygoteRules.@adjoint. I believe the plan is to migrate away from ZygoteRules and over to ChainRules exclusively in the future.
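As a rough sketch of what the ChainRulesCore version could look like, the same trait dispatch can sit behind a single rrule. The helper name bloop_rrule is hypothetical, and the two branches are identical here apart from their comments, since the example only illustrates the dispatch. Note that an rrule pullback returns one extra leading entry for the function itself, which is NoTangent() for a plain function.

```julia
using ChainRulesCore

abstract type SymbolicTrait end
struct IsSymbolic <: SymbolicTrait end
struct NotSymbolic <: SymbolicTrait end

# Int stands in for "symbolic" and everything else for "numeric", as in the question.
SymbolicTrait(::Int) = IsSymbolic()
SymbolicTrait(::Any) = NotSymbolic()

bloop(x) = x^2

# Hypothetical helper: one method per trait, each returning (primal, pullback).
function bloop_rrule(::IsSymbolic, x)
    # Symbolic-specific handling would go here.
    bloop_pullback(dy) = (NoTangent(), 2x * dy)
    return x^2, bloop_pullback
end

function bloop_rrule(::NotSymbolic, x)
    # Numeric-specific handling would go here.
    bloop_pullback(dy) = (NoTangent(), 2x * dy)
    return x^2, bloop_pullback
end

# The single rrule computes the trait and forwards to the matching helper.
function ChainRulesCore.rrule(::typeof(bloop), x)
    return bloop_rrule(SymbolicTrait(x), x)
end

# The rule can be checked directly, without any AD package loaded.
y, pb = ChainRulesCore.rrule(bloop, 3.0)
```

As far as I know, recent Zygote versions pick up rrules defined this way automatically, so Zygote.gradient(bloop, 3.0) should use this rule once Zygote is loaded.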