I have some custom array types and a custom implementation of multiplication, e.g.
function (A::CustomMatrix * b::CustomVector)
    some_specialized_code()
end
Now if I use Zygote to differentiate through code that does A*b somewhere, Zygote uses the adjoint defined here, which is:
@adjoint function(A::AbstractMatrix * x::AbstractVector)
    return A * x, Δ::AbstractVector->(Δ * x', A' * Δ)
end
But in this particular case, the resulting adjoint is extremely inefficient compared to what Zygote would get by simply differentiating through my custom implementation in some_specialized_code().
My question is whether there is a way to keep Zygote from “pirating” its own adjoint like this, or whether there is an easy way to write a custom adjoint that forwards to some_specialized_code() instead? Thanks.
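function (A::CustomMatrix * b::CustomVector)
    some_specialized_code(A, b)
end
# Zygote.pullback returns (value, back), which is the pair @adjoint expects,
# so Zygote differentiates through some_specialized_code itself.
@adjoint function (A::CustomMatrix * b::CustomVector)
    Zygote.pullback(some_specialized_code, A, b)
end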
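To see why forwarding to Zygote.pullback works, here is a minimal sketch on plain arrays. The some_specialized_code below is only a hypothetical stand-in (an ordinary matrix-vector product), not the custom implementation from the question. It shows that pullback returns the value together with a closure that maps an output cotangent to one gradient per argument, which is the same shape an @adjoint body has to return.

```julia
using Zygote

# Hypothetical stand-in for the specialized routine; the real one would
# operate on CustomMatrix / CustomVector.
some_specialized_code(A, b) = A * b

A = [1.0 2.0; 3.0 4.0]
b = [1.0, 1.0]

# pullback returns the primal value and a closure `back`.
y, back = Zygote.pullback(some_specialized_code, A, b)

# `back` takes the cotangent of the output and returns one gradient per input.
dA, db = back([1.0, 1.0])

println(y)   # [3.0, 7.0]
println(dA)  # [1.0 1.0; 1.0 1.0]
println(db)  # [4.0, 6.0]
```

Because the return value of pullback already has the (value, back) form, the adjoint definition can simply return it unchanged.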
Follow-up question: is there an even more automatic way to do it that doesn’t refer to some_specialized_code()? I ask because in some cases I don’t actually have access to this function, as it’s just a function in Base. This happens e.g. for b' * A. If Zygote differentiated through Base’s definition, which is (A'*b)', things would work fine, since A' isa CustomMatrix and so it’d fall to the adjoint definition suggested by Chris above. Instead, though, it’s using Zygote’s own definition for the adjoint of (::AbstractMatrix * ::AbstractMatrix), which is now inefficient.
I guess ideally I need something like:
@adjoint function (b_adj::Adjoint{<:Any, CustomVector} * A::CustomMatrix)
    # teach Zygote that for this call, explicitly derive
    # through the code that `b_adj * A` actually calls, rather
    # than using the rule for adjoint of `b_adj * A`
end
Is anything like this possible? Or maybe there’s a better design that side-steps the need for this? Thanks.