ZygoteRules dispatch on traits?

Is it possible to specialize a ZygoteRule using a Holy trait? I’m not sure how I would go about doing this.

OK, here’s what I have come up with so far (doesn’t work, but shows what I’m trying to do):

```julia
using Zygote
abstract type SymbolicTrait end
struct IsSymbolic <: SymbolicTrait end
struct NotSymbolic <: SymbolicTrait end
function SymbolicTrait(x)
    if typeof(x) <: Int
        IsSymbolic()
    else
        NotSymbolic()
    end
end

function bloop(x)
    x^2
end

Zygote.@adjoint function bloop(x)
    bloop_adjoint(SymbolicTrait(x),x) #How do I actually refer to the adjoint defined by the @adjoint macro?
end

Zygote.@adjoint function bloop(::IsSymbolic, x)
    function bloop_adjoint(d)
        2*d
        print("Symbolic")
    end
    (x^2,bloop_adjoint)
end
Zygote.@adjoint function bloop(::NotSymbolic, x)
    function bloop_adjoint(d)
        2*d
        print("NonSymbolic")
    end
    (x^2,bloop_adjoint)
end
```

I’m using Int/Float to stand in for Symbolic/numeric types for simplicity of the example.

```julia
julia> Zygote.gradient(bloop,1)
ERROR: UndefVarError: bloop_adjoint not defined
Stacktrace:
 [1] adjoint
   @ ./REPL[7]:2 [inlined]
 [2] _pullback(__context__::Zygote.Context, 276::typeof(bloop), x::Int64)
   @ Main ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:57
 [3] _pullback(f::Function, args::Int64)
   @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:34
 [4] pullback(f::Function, args::Int64)
   @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:40
 [5] gradient(f::Function, args::Int64)
   @ Zygote ~/.julia/packages/Zygote/0da6K/src/compiler/interface.jl:58
 [6] top-level scope
   @ REPL[10]:1
```

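bloop_adjoint is a local function defined inside the bodies of your other two @adjoint methods, so it isn’t visible from within the bloop(x) adjoint; that is what the UndefVarError is telling you. Instead, try returning bloop(SymbolicTrait(x), x).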

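At that point you should also remove the @adjoint annotations from your other two bloop methods and make them ordinary methods. @adjoint only defines the adjoint rule, not the primal method itself, so without that change bloop(SymbolicTrait(x), x) would throw a MethodError. Those two methods are only called internally by your first bloop method; Zygote only ever sees the primal value and pullback returned by the bloop(x) adjoint, which wraps the other two. Because that pullback now belongs to bloop(x), it should return a one-element tuple holding the gradient with respect to x, e.g. (2x*d,) (the derivative of x^2 is 2x, not 2).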
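Putting those suggestions together, a corrected version might look roughly like the sketch below. It is not tested against any particular Zygote version. The dispatch-based SymbolicTrait definitions and the pullback name bloop_pullback are illustrative choices, not part of the original code, and Int again stands in for a symbolic type.

```julia
using Zygote

abstract type SymbolicTrait end
struct IsSymbolic <: SymbolicTrait end
struct NotSymbolic <: SymbolicTrait end

# Holy trait via dispatch (illustrative): Int plays the "symbolic" role
SymbolicTrait(::Int) = IsSymbolic()
SymbolicTrait(::Any) = NotSymbolic()

bloop(x) = x^2

# Ordinary (non-@adjoint) methods: each returns (primal value, pullback).
# The pullback returns a 1-tuple: one gradient for the single argument x of bloop(x).
function bloop(::IsSymbolic, x)
    function bloop_pullback(d)
        println("Symbolic")
        return (2x * d,)
    end
    return x^2, bloop_pullback
end

function bloop(::NotSymbolic, x)
    function bloop_pullback(d)
        println("NonSymbolic")
        return (2x * d,)
    end
    return x^2, bloop_pullback
end

# The only rule Zygote sees: forward to the trait-specific method
Zygote.@adjoint bloop(x) = bloop(SymbolicTrait(x), x)
```

With this, Zygote.gradient(bloop, 3) should print "Symbolic" and Zygote.gradient(bloop, 3.0) should print "NonSymbolic", both giving a gradient of 6.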

Lastly, unless you’re specifically overloading a method that already has an adjoint, you might consider moving to ChainRulesCore and defining rrules rather than sticking with ZygoteRules.@adjoint; Zygote picks up ChainRulesCore rrules automatically. I believe the plan is to migrate away from ZygoteRules and over to ChainRules exclusively in the future.
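If you do switch, the same trait forwarding carries over to an rrule. This is a minimal sketch assuming ChainRulesCore and Zygote are both installed; bloop_rrule is a hypothetical helper name. Note that an rrule pullback returns one tangent per argument including the function itself, hence the leading NoTangent().

```julia
using ChainRulesCore, Zygote

abstract type SymbolicTrait end
struct IsSymbolic <: SymbolicTrait end
struct NotSymbolic <: SymbolicTrait end

# Int stands in for a symbolic type, as in the question
SymbolicTrait(::Int) = IsSymbolic()
SymbolicTrait(::Any) = NotSymbolic()

bloop(x) = x^2

# A single rrule for bloop that forwards to a trait-specific helper
ChainRulesCore.rrule(::typeof(bloop), x) = bloop_rrule(SymbolicTrait(x), x)

# Hypothetical helpers: each returns (primal value, pullback)
function bloop_rrule(::IsSymbolic, x)
    function bloop_pullback(dy)
        println("Symbolic")
        # No tangent for the function bloop itself, then d(x^2)/dx * dy for x
        return NoTangent(), 2x * unthunk(dy)
    end
    return bloop(x), bloop_pullback
end

function bloop_rrule(::NotSymbolic, x)
    function bloop_pullback(dy)
        println("NonSymbolic")
        return NoTangent(), 2x * unthunk(dy)
    end
    return bloop(x), bloop_pullback
end
```

The trait lookup happens once, in the rrule itself, so adding a new trait value only requires another bloop_rrule method.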
