How to deal with Zygote sometimes "pirating" its own adjoints with worse ones?

The situation I’m in is that I have some custom arrays and some custom implementation of multiplication, e.g.

import Base: *  # needed to add methods to Base's `*`

function (A::CustomMatrix * b::CustomVector)
    some_specialized_code(A, b)  # custom multiplication kernel
end

Now if I use Zygote to differentiate through code that computes A*b somewhere, Zygote uses the adjoint defined here, which is:

@adjoint function(A::AbstractMatrix * x::AbstractVector)
  return A * x, Δ::AbstractVector->(Δ * x', A' * Δ)
end
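For reference, this is where that generic rule comes from. The short derivation below is a sketch assuming real-valued entries and a scalar loss; the symbols \ell, \bar{A} and \bar{x} are notation introduced here, not names from Zygote.

```latex
% y = A x, scalar loss \ell, \Delta = \partial\ell / \partial y (real entries assumed)
\begin{aligned}
d\ell &= \Delta^{\top}\,dy = \Delta^{\top}\left(dA\,x + A\,dx\right) \\
      &= \operatorname{tr}\!\left(x\,\Delta^{\top}\,dA\right) + \left(A^{\top}\Delta\right)^{\top} dx \\
\Rightarrow\quad \bar{A} &= \Delta\,x^{\top}, \qquad \bar{x} = A^{\top}\Delta
\end{aligned}
```

The two results match the tuple (Δ * x', A' * Δ) returned by the rule above. Note that Δ * x' is an outer product with the full shape of A. One plausible reason for the inefficiency (an assumption, not something measured in this thread) is that this dense result ignores whatever structure CustomMatrix has.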

But in this particular case, the resulting adjoint is extremely inefficient compared to what Zygote would get by simply taking the gradient through my custom implementation in some_specialized_code().

My question is whether there is a way to keep Zygote from “pirating” its own adjoint like this, or whether there is an easy way to write a custom adjoint that forwards the adjoint to some_specialized_code() instead? Thanks.

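function (A::CustomMatrix * b::CustomVector)
    some_specialized_code(A, b)
end

# Zygote.pullback returns (value, back), which is the form @adjoint expects,
# so Zygote differentiates through some_specialized_code itself.
@adjoint function (A::CustomMatrix * b::CustomVector)
    Zygote.pullback(some_specialized_code, A, b)
end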

Perfect, thanks!

Follow-up question: is there an even more automatic way to do it that doesn’t refer to some_specialized_code()? I ask because in some cases I don’t actually have access to this function, as it’s just a function in Base. This happens e.g. for b' * A. If Zygote differentiated through Base's definition, which is (A'*b)', things would work fine: A' isa CustomMatrix, so it would fall through to the adjoint definition suggested by Chris above. Instead, though, it’s using Zygote’s own definition for the adjoint of (::AbstractMatrix * ::AbstractMatrix), which is now inefficient.
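To spell out the identity this relies on, here is a small worked derivation. It assumes real entries, so that ' is a plain transpose; z, \bar{y}, \bar{z}, \bar{A} and \bar{b} are notation introduced here for illustration.

```latex
% y = b^{\top} A is computed as y = z^{\top} with z = A^{\top} b (real entries assumed)
\begin{aligned}
z &= A^{\top} b, \qquad y = z^{\top} = b^{\top} A \\
\bar{z} &= \bar{y}^{\top} \\
\bar{b} &= A\,\bar{z} = A\,\bar{y}^{\top} \\
\bar{A} &= \left(\bar{z}\,b^{\top}\right)^{\top} = b\,\bar{y}
\end{aligned}
```

Here \bar{y} is the incoming row cotangent. The only matrix-vector product involved is A^{\top} b and its pullback, which is why routing b' * A through (A'*b)' would end up in the custom rule for CustomMatrix * CustomVector.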

I guess ideally I need something like:

@adjoint function (b_adj::Adjoint{<:Any, CustomVector} * A::CustomMatrix)
    # teach Zygote that for this call, explicitly derive 
    # through the code that `b_adj * A` actually calls, rather 
    # than using the rule for adjoint of `b_adj * A`
end

Is anything like this possible? Or maybe there’s a better design that side-steps the need for this? Thanks.