This is an MWE for a problem I'm having. I've adapted matrix multiplication to accept `trans` arguments that tell whether each matrix is to be transposed.

For the adjoint definition, I looked at `NNlib.batched_mul` for reference.

When I try to get the gradient of `x * x'` w.r.t. `x` using `my_mul`, I get the error shown below.

```
using Zygote

x = rand(2, 3)

# transA/transB are either 'N' (no transpose) or 'T' (transpose)
function my_mul(A, B, transA, transB)
    (transA == 'N' ? A : A') * (transB == 'N' ? B : B')
end

my_mul(x, x, 'N', 'T')

Zygote.@adjoint function my_mul(A, B, transA, transB)
    d = Dict('N' => 'T', 'T' => 'N')  # flips the transpose flag
    C = my_mul(A, B, transA, transB)
    C, Δ -> (my_mul(Δ, B, 'N', d[transB]), my_mul(A, Δ, d[transA], 'N'))
end

Zygote.gradient(x) do x
    sum(x * x')
end # expected result

Zygote.gradient(x) do x
    sum(my_mul(x, x, 'N', 'T'))
end # unexpected: throws the error below
```

```
ERROR: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 2 and 3")
Stacktrace:
[1] _bcs1 at .\broadcast.jl:501 [inlined]
[2] _bcs at .\broadcast.jl:495 [inlined]
[3] broadcast_shape at .\broadcast.jl:489 [inlined]
[4] combine_axes at .\broadcast.jl:484 [inlined]
[5] instantiate at .\broadcast.jl:266 [inlined]
[6] materialize at .\broadcast.jl:837 [inlined]
[7] accum(::Array{Float64,2}, ::Array{Float64,2})
at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\lib\lib.jl:16
[8] (::Zygote.var"#41#42"{typeof(∂(#13))})(::Float64) at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\interface.jl:45
[9] gradient(::Function, ::Array{Float64,2}) at C:\Users\jules\.julia\packages\Zygote\Xgcgs\src\compiler\interface.jl:54
[10] top-level scope at REPL[1]:1
[11] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1088
```

I believe the failed broadcast comes from `Zygote.accum`, which tries to add the two adjoint results together.
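
To make the suspected shape mismatch concrete, here is a minimal sketch in plain Julia (no Zygote needed) of the two terms the pullback above would return for `my_mul(x, x, 'N', 'T')`. It assumes the incoming cotangent `Δ` is `ones(2, 2)`, which is what `sum` should pass back. The names `dA`, `dB`, and `dB_fixed` are hypothetical, and the final transpose is only a guess at a possible fix, not something I have confirmed.

```
# Sketch: shapes of the two pullback terms for my_mul(x, x, 'N', 'T').
x = rand(2, 3)
Δ = ones(2, 2)            # assumed cotangent, same shape as x * x'

dA = Δ * x                # my_mul(Δ, B, 'N', 'N'): 2×3, same shape as x
dB = x' * Δ               # my_mul(A, Δ, 'T', 'N'): 3×2, not the shape of x

println(size(dA), " vs ", size(dB))   # adding these two is what fails

# Assumption: when transB == 'T', the second term is the cotangent of B',
# so transposing it gives something with the shape of B.
dB_fixed = (x' * Δ)'      # 2×3
println(dA + dB_fixed ≈ 2 .* (Δ * x))
```

If this reading is right, it would also explain why using a copy hides the problem: the two terms are never added, so the mismatched shape of the second one goes unnoticed.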

Note that this only happens when I multiply an array with itself; replacing one of the `x`'s with a copy makes the error go away.

When I take the gradient of `sum(x*x')` directly, I do get a sensible answer.

Does anybody know what I’m doing wrong here?

Thanks in advance,

Jules