Hello,
```julia
# Minimal Working Example
using Flux, Zygote

struct MyMat{T}
    m::Matrix{T}
    n::Int
end

Flux.trainable(a::MyMat) = (a.m,)
Base.Matrix(a::MyMat) = a.m

import Base: transpose

# These two functions differ in name only
function transpose(a::MyMat)
    m = mytranspose(a.m)
    MyMat(m, a.n)
end

function MatTranspose(a::MyMat)
    m = mytranspose(a.m)
    MyMat(m, a.n)
end

function mytranspose(m)
    Matrix(transpose(m))
end

Zygote.@adjoint function mytranspose(m)
    mytranspose(m), Δ -> (transpose(Δ),)
end

# This works
a = MyMat(rand(5, 5), 5)
psa = Flux.params(a)
Zygote.gradient(() -> sum(Matrix(MatTranspose(a))), psa)[a.m]

# This errors
b = MyMat(rand(5, 5), 5)
psb = Flux.params(b)
Zygote.gradient(() -> sum(Matrix(transpose(b))), psb)[b.m]
```
I’m trying to implement a custom matrix representation that works with Flux and Zygote. When I define differentiable transposition by extending `Base.transpose`, I get the following error when I take the gradient of a function that uses the transposition:
Error produced by the last line of the MWE:
```
ERROR: LoadError: MethodError: no method matching transpose(::NamedTuple{(:m, :n),Tuple{FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}},Nothing}})
```
Somehow my struct gets converted into a `NamedTuple` during the gradient computation. After some debugging, I found out that this is caused by Zygote’s default `@adjoint` for the `transpose` function (link to Zygote’s GitHub):
```julia
@adjoint function transpose(x)
    back(Δ) = (transpose(Δ),)
    back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,)
    return transpose(x), back
end
```
Without this `@adjoint`, my code works as I expect. I confirmed this both by giving the function a different name (as demonstrated in the MWE) and by commenting out the `@adjoint` in Zygote’s source.
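As far as I can tell, the failure mode is this: Zygote’s rule is declared for an untyped `x`, so it also matches `transpose(b::MyMat)`, and Zygote uses that rule instead of differentiating through my method. In the backward pass, the gradient for a `MyMat` is a `NamedTuple` with one entry per field, and no `transpose` method accepts that. Here is a minimal stand-alone sketch of just that last step. The hand-built `NamedTuple` is only an assumption that mimics the shape in the error message, with a dense `ones` matrix in place of the `Fill`:

```julia
using LinearAlgebra  # array methods for Base.transpose live here

# Hand-built stand-in for the gradient Zygote builds for a MyMat:
# one entry per struct field, `nothing` for the non-differentiable Int field `n`.
# (In the real error `m` is a FillArrays.Fill; a dense matrix is used here for simplicity.)
Δ = (m = ones(5, 5), n = nothing)

# Zygote's generic pullback is `back(Δ) = (transpose(Δ),)`, so it ends up making this call:
err = try
    transpose(Δ)
    nothing
catch e
    e
end

println(err isa MethodError)  # no `transpose` method accepts a NamedTuple
```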
I tried to write my own `@adjoint` for Zygote to use with my custom representation, but I kept getting the same error, even after confirming that Zygote was using my `@adjoint`.
Is there a simple way to force Zygote not to use its own `@adjoint` when it differentiates my representation?