How to stop Zygote from using its default @adjoint for transpose on my custom matrix representation

Hello,

```julia
# Minimal working example
using Flux, Zygote
struct MyMat{T}
    m::Matrix{T}
    n::Int
end

Flux.trainable(a::MyMat) = (a.m,)

Matrix(a::MyMat) = a.m

import Base: transpose
# these two functions differ in name only
function transpose(a::MyMat)
    m = mytranspose(a.m)
    MyMat(m, a.n)
end
function MatTranspose(a::MyMat)
    m = mytranspose(a.m)
    MyMat(m, a.n)
end

function mytranspose(m)
    Matrix(transpose(m))
end

Zygote.@adjoint function mytranspose(m)
    mytranspose(m), Δ -> (transpose(Δ), )
end

# this works
a = MyMat(rand(5, 5), 5)
psa = Flux.params(a)
Zygote.gradient(() -> sum(Matrix(MatTranspose(a))), psa)[a.m]
# this errors
b = MyMat(rand(5, 5), 5)
psb = Flux.params(b)
Zygote.gradient(() -> sum(Matrix(transpose(b))), psb)[b.m]
```

I’m trying to implement a custom matrix representation that works with Flux and Zygote. When I tried to define differentiable transposition by extending `Base.transpose`, I got the following error when taking the gradient of a function that uses the transposition:

```
# Produced by the last line of the MWE
ERROR: LoadError: MethodError: no method matching transpose(::NamedTuple{(:m, :n),Tuple{FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}},Nothing}})
```

Somehow my struct gets converted into a NamedTuple during the gradient computation. After some debugging I found that this is caused by Zygote’s default @adjoint for the `transpose` function, defined in Zygote’s source on GitHub:

```julia
@adjoint function transpose(x)
  back(Δ) = (transpose(Δ),)
  back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,)
  return transpose(x), back
end
```

Without this @adjoint my code works as expected, which I confirmed both by giving the function a different name (as in the MWE) and by commenting out the @adjoint in Zygote’s source.
I also tried writing my own @adjoint for my custom representation, but I kept getting the same error, even after confirming that Zygote was using my @adjoint.

Is there a simple way to stop Zygote from using its own @adjoint when it differentiates through my representation?

I think the adjoint you need is:

```julia
@adjoint Matrix(a::MyMat) = a.m, Δ -> (MyMat(Matrix(Δ), 0), )
```

I think what’s happening is that the default pullback for `Matrix` doesn’t quite work with your matrix type, so it falls back to returning a NamedTuple `(m = ..., n = ...)` rather than a proper `MyMat`. The next step of backpropagation then tries to take the transpose of that NamedTuple, which fails.
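The failure can be reproduced without Zygote. In this minimal sketch, the two methods of Zygote’s generic `transpose` pullback quoted in the question are copied into a standalone function (`back` is a hypothetical name used only here) and called with three cotangent shapes: an ordinary matrix, the `(parent = ...,)` structural form of a `Transpose` wrapper, and a MyMat-like NamedTuple `(m = ..., n = nothing)`. Only the last one reaches `transpose(::NamedTuple)`, which has no method, giving the same `MethodError` as in the original post.

```julia
using LinearAlgebra

# Zygote's generic pullback for `transpose`, copied from the adjoint quoted above
# into a plain function (the name `back` is only for illustration).
back(Δ) = (transpose(Δ),)
back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,)

# 1. Cotangent is an ordinary matrix: the generic method transposes it.
dense = back([1.0 2.0; 3.0 4.0])[1]

# 2. Structural cotangent of a `Transpose` wrapper, `(parent = ...,)`: unwrapped.
wrapped = back((parent = [1.0 2.0; 3.0 4.0],))[1]

# 3. Structural cotangent of a MyMat-like struct with fields `m` and `n`:
#    falls through to `transpose(::NamedTuple)`, which has no method.
structural = (m = ones(2, 2), n = nothing)
failed = try
    back(structural)
    false
catch err
    err isa MethodError
end

println(dense == [1.0 3.0; 2.0 4.0])   # true
println(wrapped == [1.0 2.0; 3.0 4.0]) # true
println(failed)                        # true
```

The `NamedTuple{(:parent,)}` method only recognises the field layout of `Transpose`, so any other struct’s structural cotangent ends up in the generic branch.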

By the way, this is in some sense the same question I asked here (How to deal with Zygote sometimes "pirating" its own adjoints with worse ones?). There is a generic solution given there that might help, since in my experience this sort of thing happens a lot with Zygote and custom matrix types.

The topic you linked describes exactly the problem I am facing, so I applied the solution given there (I had searched with Google rather than on this forum, so unfortunately I did not find it earlier). The second gradient no longer errors, but for some reason it returns `nothing` instead of an array. I don’t see how that is possible: as I understand the AD, both gradients should now run the same code. Curiously, computing the same gradient without `Flux.params` gives the correct result.

```julia
using Flux, Zygote
struct MyMat{T}
    m::Matrix{T}
    n::Int
end

Flux.trainable(a::MyMat) = (a.m,)

Matrix(a::MyMat) = a.m

import Base: transpose
function transpose(a::MyMat)
    MatTranspose(a)
end

function MatTranspose(a::MyMat)
    m = mytranspose(a.m)
    MyMat(m, a.n)
end

function mytranspose(m)
    Matrix(transpose(m))
end

Zygote.@adjoint function mytranspose(m)
    mytranspose(m), Δ -> (transpose(Δ), )
end

Zygote.@adjoint function transpose(a::MyMat)
    Zygote.pullback(MatTranspose, a)
end

# this works
a = MyMat(rand(5, 5), 5)
psa = Flux.params(a)
Zygote.gradient(() -> sum(Matrix(MatTranspose(a))), psa)[a.m]
# this results in nothing
b = MyMat(rand(5, 5), 5)
psb = Flux.params(b)
Zygote.gradient(() -> sum(Matrix(transpose(b))), psb)[b.m]
# this now gives the correct result
x = rand(5, 5)
Zygote.gradient(x -> sum(Matrix(transpose(MyMat(x, 0)))), x)[1]
```

I am currently looking into implementing the adjoint for `Matrix`, as it is not as straightforward as the solution provided in the linked topic.
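A possible explanation for the `nothing`, offered as an assumption rather than a confirmed diagnosis: with implicit parameters, Zygote appears to record a parameter’s gradient only when the parameter array is read (here `a.m`) inside code it traces with the same context that `gradient(f, ps)` created. `Zygote.pullback(MatTranspose, a)` starts a new context, so the read of `b.m` is invisible to the `Params` gradient, while the explicit-argument gradient just passes the cotangent back and therefore still works. The sketch below threads the outer context through instead, using the `__context__` variable that `@adjoint` provides and Zygote’s internal, unexported `Zygote._pullback`, whose pullback also returns a gradient for `MatTranspose` itself that is discarded. It uses `Zygote.Params` directly in place of `Flux.params`; because `_pullback` is internal, its behaviour may differ between Zygote versions.

```julia
using Zygote
import Base: transpose

struct MyMat{T}
    m::Matrix{T}
    n::Int
end

Matrix(a::MyMat) = a.m

mytranspose(m) = Matrix(transpose(m))
Zygote.@adjoint mytranspose(m) = mytranspose(m), Δ -> (transpose(Δ),)

MatTranspose(a::MyMat) = MyMat(mytranspose(a.m), a.n)
transpose(a::MyMat) = MatTranspose(a)

# Differentiate MatTranspose with the *outer* context (`__context__`) instead of
# calling Zygote.pullback, which would start a new, separate context.
Zygote.@adjoint function transpose(a::MyMat)
    y, inner_back = Zygote._pullback(__context__, MatTranspose, a)
    function transpose_pullback(Δ)
        g = inner_back(Δ)      # (gradient for MatTranspose, gradient for a) or nothing
        g === nothing && return nothing
        return Base.tail(g)    # keep only the gradient for `a`
    end
    return y, transpose_pullback
end

b = MyMat(rand(3, 3), 3)
ps = Zygote.Params([b.m])
gs = Zygote.gradient(() -> sum(Matrix(transpose(b))), ps)
println(gs[b.m])   # compare with the `nothing` reported above
```

Reading `a.m` inside `MatTranspose` while it is traced with the shared context is what lets the parameter gradient be accumulated, under the assumption above.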

Defining the adjoint for `Matrix` runs into the same problem as the approach in my previous reply: differentiating with `Flux.params` returns `nothing`, while differentiating without `Flux.params` returns the expected array.
These are the adjoints I added to the code in the original post:

```julia
Zygote.@adjoint MyMat(m, n) = MyMat(m, n), Δ -> (Δ.m, Δ.n)
Zygote.@adjoint Matrix(a::MyMat) = a.m, Δ -> (MyMat(Matrix(Δ), 0), )
```

This is hard to use in my real code, because `MyMat.m` is not the matrix the struct represents (unlike in the MWE) but a set of parameters that define it, so none of the functions defined on `MyMat` are trivial. Having an adjoint (cotangent) of type `MyMat` is therefore quite confusing for me, and I’d rather avoid it.
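If a `MyMat`-typed cotangent is the main obstacle, one option (a sketch under the assumptions stated here, not tested against the real representation) is to return a NamedTuple with one entry per field, which is the structural form Zygote itself produces for structs: `m` carries the gradient with respect to the parameters and `n` is `nothing`. The `Matrix(a::MyMat) = copy(a.m)` definition is a hypothetical stand-in for building the represented matrix from its parameters; in real code the pullback would map `Δ` back onto the parameters. Because this adjoint reads `a.m` outside traced code, it is probably not enough on its own for the `Flux.params` case.

```julia
using Zygote

struct MyMat{T}
    m::Matrix{T}   # parameters that define the represented matrix
    n::Int
end

# Hypothetical stand-in for "build the represented matrix from the parameters".
Matrix(a::MyMat) = copy(a.m)

# Return the cotangent for `a` as a NamedTuple mirroring MyMat's fields:
# `m` gets the gradient w.r.t. the parameters (collected from a possibly lazy
# array into a plain Array), `n` is treated as non-differentiable.
Zygote.@adjoint Matrix(a::MyMat) = Matrix(a), Δ -> ((m = collect(Δ), n = nothing),)

x = rand(4, 4)
g = Zygote.gradient(a -> sum(Matrix(a)), MyMat(x, 0))[1]
println(g)
```

With this shape, downstream code never has to construct a `MyMat` just to hold a gradient; it only reads `g.m`.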