_generic_matmatmul error in sciml_train

`collect(Iterators.flatten(...))` seems like a complicated way to make a vector here, and it is perhaps difficult for AD to handle. Some simpler alternatives are:

julia> u = [1,2,3];

julia> v = collect(Iterators.flatten(u * transpose(u)));

julia> v == vec(u * transpose(u))
true

julia> v == vec(kron(u,u))
true

Trying these out, it appears that ReverseDiff does not like `u * transpose(u)` (surely a bug), but the `kron` option works:

julia> ReverseDiff.gradient([1,2,3]) do u
         v = vec(u * transpose(u))
         v[1]
       end
ERROR: DimensionMismatch: matrix A has dimensions (3,3), matrix B has dimensions (1,3)

julia> ReverseDiff.gradient([1,2,3]) do u
         v = vec(kron(u,u))
         v[1]
       end
3-element Vector{Int64}:
 2
 0
 0

julia> Zygote.gradient([1,2,3]) do u
         v = vec(u * transpose(u))
         v[1]
       end
([2.0, 0.0, 0.0],)

julia> Zygote.gradient([1,2,3]) do u
         v = vec(kron(u,u))
         v[1]
       end
([2.0, 0.0, 0.0],)

# Both ReverseDiff and Zygote error with collect(Iterators.flatten(...))

julia> ReverseDiff.gradient([1,2,3]) do u
         v = collect(Iterators.flatten(u * transpose(u)))
         v[1]
       end
ERROR: DimensionMismatch: matrix A has dimensions (3,3), matrix B has dimensions (1,3)
Stacktrace:
  [1] _generic_matmatmul!(C::Vector{Int64}, tA::Char, tB::Char, A::Matrix{Int64}, B::Vector{Int64}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
...

julia> Zygote.gradient([1,2,3]) do u
         v = collect(Iterators.flatten(u * transpose(u)))
         v[1]
       end
ERROR: Mutating arrays is not supported -- called copyto!(::Vector{Int64}, _...)
Stacktrace:
...
  [5] (::typeof(∂(_collect)))(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/DkIUK/src/compiler/interface2.jl:0
...
1 Like