Zygote.jl: DimensionMismatch: matrix is not square error

Hi, I am using Zygote.jl. I’d like to extend a 3x3 matrix to 5x3 by multiplying it with a 5x3 identity-like matrix, as in this MWE:

using Zygote
using LinearAlgebra

function test( c )
    m = rand(3,3) * c
    @show m 
    M = Matrix(1.0I, 5, 3) * m
    @show M 
    return sum(M)
end 

gradient(test, 0.5)

Then I get

ERROR: LoadError: DimensionMismatch("matrix is not square: dimensions are (5, 3)")
Stacktrace:
 [1] checksquare at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/LinearAlgebra.jl:223 [inlined]
 [2] tr at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/LinearAlgebra/src/dense.jl:331 [inlined]
 [3] #827 at /root/.julia/packages/Zygote/CgsVi/src/lib/array.jl:686 [inlined]
 [4] #3238#back at /root/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [5] test at /root/codes/test_zygote/test_chi.jl:7 [inlined]
 [6] (::typeof(∂(test)))(::Float64) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface2.jl:0
 [7] (::Zygote.var"#41#42"{typeof(∂(test))})(::Float64) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface.jl:41
 [8] gradient(::Function, ::Float64) at /root/.julia/packages/Zygote/CgsVi/src/compiler/interface.jl:59
 [9] top-level scope at /root/codes/test_zygote/test_chi.jl:12
in expression starting at /root/codes/test_zygote/test_chi.jl:12

Does Zygote.jl currently not support differentiation through non-square matrix multiplication?

My Zygote version is v0.6.8, and my Julia version is 1.5.3.

Thanks for any reply.

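The bug is not in the matrix multiplication but in differentiating the constructor: gradient(c -> sum(Matrix(c*I, 5, 3)), 1) fails the same way, because Zygote's rule for it tries to compute tr of a 5x3 matrix (as in tr(Matrix(I, 5, 3))), and tr requires a square matrix.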

But you don’t need this here: the matrix is constant, so you can construct const m53 = Matrix(1.0I, 5, 3) once, outside the gradient call.
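
For reference, a minimal sketch of that suggestion applied to the MWE above. The name m53 is the one from this reply; the rest is the original test function with the @show calls dropped. It assumes, as in the original post, that Zygote treats rand(3,3) as a constant.

```julia
using Zygote
using LinearAlgebra

# Build the constant 5x3 "identity" once, at top level, so Zygote never
# has to differentiate the Matrix(::UniformScaling, 5, 3) constructor.
const m53 = Matrix(1.0I, 5, 3)

function test(c)
    m = rand(3, 3) * c   # 3x3, depends on c
    M = m53 * m          # 5x3, plain matrix multiplication
    return sum(M)
end

gradient(test, 0.5)      # returns a 1-tuple holding d(sum(M))/dc
```

Only the multiplication m53 * m is differentiated here, and that works fine for non-square matrices.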

Thanks a lot, @mcabbott!
I can also use Zygote.@ignore so that Zygote does not try to differentiate the construction of this non-square matrix. This is the working version:

using Zygote
using LinearAlgebra

function test( c )
    m = rand(3,3) * c
    @show m 

    M = Zygote.@ignore Matrix(1.0I, 5, 3)
    M = M * m
    @show M 
    return sum(M)
end 

gradient(test, 0.5)
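
As a sanity check, here is a small sketch that compares the result with the value expected by hand. It swaps rand(3,3) for a fixed matrix R so the answer is reproducible; R and test_fixed are names made up for this check, not part of the code above. Since sum(P * (R * c)) = c * sum(P * R), and P * R is just R with two zero rows appended, the derivative with respect to c should be sum(R).

```julia
using Zygote
using LinearAlgebra

# Hypothetical fixed stand-in for rand(3,3), so the result is reproducible.
const R = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0]

function test_fixed(c)
    # Constant 5x3 matrix; @ignore keeps Zygote from differentiating its construction.
    P = Zygote.@ignore Matrix(1.0I, 5, 3)
    return sum(P * (R * c))
end

g, = gradient(test_fixed, 0.5)

# Analytic derivative: d/dc [c * sum(P * R)] = sum(P * R) = sum(R)
@assert g ≈ sum(R)
```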