This is an MWE for a problem I'm having. I've adapted matrix multiplication to accept trans
arguments that say whether each matrix should be transposed.
For the adjoint definition, I looked at NNlib.batched_mul.
When I try to take the gradient of sum(x*x') with respect to x using my_mul, I get the error shown below.
x = rand(2, 3)
function my_mul(A, B, transA, transB)  # transA/transB is either 'N' or 'T'
    (transA == 'N' ? A : A') * (transB == 'N' ? B : B')
end
my_mul(x, x, 'N', 'T')
using Zygote
Zygote.@adjoint function my_mul(A, B, transA, transB)
    d = Dict('N' => 'T', 'T' => 'N')
    C = my_mul(A, B, transA, transB)
    C, Δ -> (my_mul(Δ, B, 'N', d[transB]), my_mul(A, Δ, d[transA], 'N'))
end
Zygote.gradient(x) do x
    sum(x*x')
end # expected result
Zygote.gradient(x) do x
    sum(my_mul(x, x, 'N', 'T'))
end # unexpected!
ERROR: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 2 and 3")
Stacktrace:
[1] _bcs1 at .\broadcast.jl:501 [inlined]
[2] _bcs at .\broadcast.jl:495 [inlined]
[3] broadcast_shape at .\broadcast.jl:489 [inlined]
[4] combine_axes at .\broadcast.jl:484 [inlined]
[5] instantiate at .\broadcast.jl:266 [inlined]
[6] materialize at .\broadcast.jl:837 [inlined]
[7] accum(::Array{Float64,2}, ::Array{Float64,2})
at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\lib\lib.jl:16
[8] (::Zygote.var"#41#42"{typeof(∂(#13))})(::Float64) at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\interface.jl:45
[9] gradient(::Function, ::Array{Float64,2}) at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\interface.jl:54
[10] top-level scope at REPL[1]:1
[11] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1088
I believe the failed broadcast comes from Zygote.accum, which tries to add the two adjoint results together.
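To illustrate that suspicion without Zygote, here is a minimal sketch that computes by hand the two values my pullback returns for my_mul(x, x, 'N', 'T'). It assumes Δ is a 2×2 array of ones (the cotangent that sum passes back) and that accum adds the two results with broadcasting, as the stack trace suggests. The names dA, dB and accum_failed are only for illustration.

```julia
x = rand(2, 3)
Δ = ones(2, 2)   # assumed cotangent: d(sum(C))/dC for the 2×2 product C = x * x'

# What the pullback returns for transA = 'N', transB = 'T':
dA = Δ * x       # my_mul(Δ, B, 'N', d['T']) == my_mul(Δ, x, 'N', 'N')
dB = x' * Δ      # my_mul(A, Δ, d['N'], 'N') == my_mul(x, Δ, 'T', 'N')

@show size(dA)   # (2, 3), same as size(x)
@show size(dB)   # (3, 2), the transpose of size(x)

# Adding the two contributions by broadcasting fails on the mismatched shapes.
accum_failed = try
    dA .+ dB
    false
catch err
    err isa DimensionMismatch
end
@show accum_failed
```

Because both arguments are the same x, the two contributions have to be summed, which would explain why passing a copy for one argument avoids the error: each array then receives only one contribution. A gradient with respect to B would normally have the shape of B, so the (3, 2) result may be the part to look at.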
Note that this only happens when I multiply an array with itself; replacing one of the x's with a copy makes the error go away.
When I take the gradient of sum(x*x') directly, I do get a sensible answer.
Does anybody know what I’m doing wrong here?
Thanks in advance,
Jules