Compiler matmul optimization prevents constant folding?

I have a use case where I iterate over a container of small vectors. Say vs is a Vector{SVector{2, Float64}}, M is a global constant SMatrix{2,2,Float64,4}, and an illustrative use case is

function f(vs)
    total = zero(eltype(vs))
    for v in vs
        total += v + M*v + M*M*v
    end
    return total
end

To my surprise, the compiler does not constant-fold M*M, because M*M*v is evaluated as M*(M*v) according to

*(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector) = A * (B*x) # from matmul.jl

Specifying (M*M)*v explicitly fixes the problem and invokes StaticArrays’s multiplication definitions. Is this the desired behavior, or should I open an issue? (I set aside the point that, even if M were not a global constant, M*M would still be loop-invariant.)
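A minimal sketch of the manual workaround: compute the loop-invariant product once, outside the loop. It assumes plain `Matrix{Float64}`/`Vector{Float64}` instead of StaticArrays types so it runs without extra packages, and passes `M` as an argument; `f_naive` and `f_hoisted` are hypothetical names. `zero(first(vs))` replaces `zero(eltype(vs))` because `zero` is not defined for the `Vector{Float64}` type itself.

```julia
using LinearAlgebra

# Naive version: M*M*v goes to the 3-argument method quoted above, i.e. M*(M*v),
# so that term costs two matrix-vector products every iteration.
function f_naive(vs, M)
    total = zero(first(vs))
    for v in vs
        total += v + M*v + M*M*v
    end
    return total
end

# Hoisted version: M*M is loop-invariant, so compute it once up front.
function f_hoisted(vs, M)
    M2 = M*M                  # one matrix-matrix product, outside the loop
    total = zero(first(vs))
    for v in vs
        total += v + M*v + M2*v
    end
    return total
end

M  = [0.9 0.1; -0.2 0.8]      # made-up values
vs = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]
f_naive(vs, M) ≈ f_hoisted(vs, M)   # equal up to rounding
```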

This is expected. Sure, the compiler could always be smarter about this, and maybe it should, but it’s pretty normal.

We’d basically need a transformation that can re-associate multiplications depending on what it knows about constants.

In floating point, (M*M)*v can give different answers than M*(M*v), so the compiler can’t be allowed to arbitrarily reorder them without fast-math flags.
So some default order has to be picked for chained operations.
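Overflow is one easy way to see that the order is observable, not just a rounding-level detail. A minimal sketch with made-up 1×1 matrices (plain `Matrix{Float64}`, not `SMatrix`):

```julia
using LinearAlgebra

A = fill(1e200, 1, 1)   # 1×1 Float64 matrix
B = fill(1e200, 1, 1)
x = [1e-200]

left  = (A * B) * x     # A*B overflows to Inf before x can scale it back down
right = A * (B * x)     # B*x is about 1.0, so the result stays finite, near 1e200

println(left)           # [Inf]
println(right)
println(A * B * x)      # unbracketed: the library picks the order
```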

Note that it still won’t do this even for types with truly associative multiplication.


I do take issue with this definition. * is defined to be left-associative (EDIT: dubious; see discussion below). There is an obvious performance reason for the above definition: a matrix multiplication chain with a vector at one end is always efficiently (optimally?) reduced from the vector end. However, there are also obvious performance reasons for @fastmath transformations, and we don’t use those without explicit permission because they violate defined semantics.

I think a bug report should be opened against the above right-associative function since it is invoked from the plain syntax A*B*x.

For your literal issue, I would use the total += v + M*v + (M*M)*v definition, since it does the compile-time M*M calculation you want. However, from a slightly broader perspective, I would write your expression as total += v + M*(v + M*v). This is the same number of operations as your (M*M)*v version after constant folding, but is usually slightly better behaved numerically. It also doesn’t require constant folding or storage of M*M.
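A quick check of that rewrite, assuming made-up plain `Matrix`/`Vector` values in place of `SMatrix`/`SVector`: both forms expand to v + Mv + M²v, and the factored one needs two matrix-vector products and no stored M*M.

```julia
using LinearAlgebra

M = [0.9 0.1; -0.2 0.8]
v = [1.0, 2.0]

# Original term, with the constant product written explicitly.
direct = v + M*v + (M*M)*v

# Horner-style factoring: v + M*(v + M*v) equals v + M*v + M*M*v in exact
# arithmetic, using two matrix-vector products and no stored M*M.
horner = v + M*(v + M*v)

direct ≈ horner   # same up to rounding
```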


If the matmuls inline, LLVM could do this in theory, but I’m guessing it is too complicated for it to pull off.

But a higher level transform would be cool.
It’d be great if more high level array-compiler like transforms were possible.

Do you have a source for that? I’ve never seen such a claim. In fact, the docstring for *(A, B::AbstractMatrix, C) explicitly says it’ll choose whichever multiplication order results in fewer operations:


  *(A, B::AbstractMatrix, C)
  A * B * C * D


  Chained multiplication of 3 or 4 matrices is done in the most efficient sequence, based on the sizes of the arrays. That is, the
  number of scalar multiplications needed for (A * B) * C (with 3 dense matrices) is compared to that for A * (B * C) to choose which
  of these to execute.

  If the last factor is a vector, or the first a transposed vector, then it is efficient to deal with these first. In particular x' * B * y
  means (x' * B) * y for an ordinary column-major B::Matrix. Unlike dot(x, B, y), this allocates an intermediate array.

  If the first or last factor is a number, this will be fused with the matrix multiplication, using 5-arg mul!.
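The size-based comparison the docstring describes can be worked out by hand. The helpers below are hypothetical and only mirror that description (they are not LinearAlgebra’s code): multiplying a p×q matrix by a q×r matrix costs p·q·r scalar multiplications, and a trailing vector is just a factor with one column.

```julia
# Scalar multiplications for a (p×q) * (q×r) product.
matmul_cost(p, q, r) = p * q * r

# Costs of the two parenthesizations of A*B*C with A: a×b, B: b×c, C: c×d.
function chain3_costs(a, b, c, d)
    left  = matmul_cost(a, b, c) + matmul_cost(a, c, d)   # (A*B)*C
    right = matmul_cost(b, c, d) + matmul_cost(a, b, d)   # A*(B*C)
    return (left = left, right = right)
end

chain3_costs(10, 100, 5, 50)   # (left = 7500, right = 75000)
chain3_costs(2, 2, 2, 1)       # (left = 12, right = 8): M*M*v with 2×2 M
```

With d = 1 (a vector at the end) the right-associated order wins, which matches the docstring’s note about dealing with a trailing vector first.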

It wouldn’t surprise me if we falsely promised left-associativity somewhere in the manual though.


Yes, I suppose my question boils down to what the default associativity is assumed to be (and perhaps what transformations are feasible when fastmath is permitted). @mikmoore for my particular use case, that’s an elegant solution! Though I think for code readability and maintainability (there are quite a few geometry gymnastics in my full case) I’d stick with (M*M). (If it’s of interest, if I remember right the performance difference was of order 10%, and that’s when the inner loop does considerably more work than the example I used here.)

https://docs.julialang.org/en/v1/manual/mathematical-operations/#Operator-Precedence-and-Associativity

The linked page describes multiplication as left-associative. However, it includes the footnote

The operators +, ++ and * are non-associative. a + b + c is parsed as +(a, b, c) not +(+(a, b), c). However, the fallback methods for +(a, b, c, d...) and *(a, b, c, d...) both default to left-associative evaluation.

So it makes the exception for parsing but doesn’t actually state any situations where exceptions are made. The wording “by default” might suggest that unless a library defines an exception, none exist. However, matrix-chain multiplication is part of the standard library and is available by default, without any explicit using. The docstring for * with matrix arguments would appear to dictate such an exception.

This is to say that there appears to be some mild disagreement within the docs. The more specific documentation of ?* has an argument to take precedence, however.

M*(M*v) requires many fewer operations than (M*M)*v, assuming M is a square matrix and v is a vector! (M*M)*v costs O(N^3) operations versus O(N^2) for M*(M*v).
The problem is, if M is an immutable global constant, we could calculate M2 = M*M ahead of time.
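Counting scalar multiplications for a dense N×N matrix M and a length-N vector v makes both sides of this concrete: right association is cheaper for a single evaluation, but once M*M is hoisted or folded, each loop iteration needs only one matrix-vector product.

```latex
\begin{aligned}
\operatorname{cost}\big((MM)\,v\big) &= N^3 + N^2, \\
\operatorname{cost}\big(M\,(Mv)\big) &= N^2 + N^2 = 2N^2, \\
\operatorname{cost}\big(M_2\,v\big) &= N^2 \qquad \text{with } M_2 = MM \text{ precomputed once (cost } N^3\text{)}.
\end{aligned}
```

For the 2×2 case in the question that is 12, 8 and 4 multiplications per iteration, respectively.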

Ah I see, yeah that’s highly misleading, because it’s talking about associativity at the parser level, not about what multiplication actually means or what gets executed as an optimization. We should probably clarify that section to make it clearer that methods like *(args...) can still choose to re-associate however they like.

(@mikmoore I suppose constant-folding (I + M + M*M)*v is the performance-optimal logical conclusion, if not the most well-behaved? Admittedly, even for associative types, that’s a much broader optimization to expect the compiler to recognize.)

Yeah, I understand that. I was just explaining that it was explicitly documented in the docstring for * that it was not in fact left-associative; it would instead choose whichever association it thinks needs the fewest operations.

I think having a compiler pass that can lift the M*M out of the loop would qualify as meeting that semantic because it does result in less runtime operations, even if implementing that pass might be hard.

Yes, although I would avoid getting too dependent on constant-folding since this is the worst way to perform this calculation when constant-folding is not possible. I’d recommend instead calculating I + M + M*M as a global constant and apply that directly.
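A minimal sketch of that suggestion, assuming plain `Matrix`/`Vector` in place of the StaticArrays types and a made-up `M`; `f_combined` is a hypothetical name, and `I` is the `UniformScaling` identity from LinearAlgebra. Note that this is the form the earlier reply called less well behaved numerically than v + M*(v + M*v).

```julia
using LinearAlgebra

const M = [0.9 0.1; -0.2 0.8]
const P = I + M + M*M      # computed once, at definition time

function f_combined(vs)
    total = zero(first(vs))
    for v in vs
        total += P*v           # one matrix-vector product per element
    end
    return total
end

vs = [[1.0, 2.0], [3.0, -1.0]]
reference = sum(v + M*v + (M*M)*v for v in vs)
f_combined(vs) ≈ reference     # same up to rounding
```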

My issue with this is that it’s getting harrowingly close to @fastmath-style reassociations. Some algorithms are dependent on things like a*b*c being evaluated in a specific order due to under/overflow concerns or for maintaining precision. There’s a branch in our current ComplexF64 multiply that takes specific pains to ensure that 3 Float64s are multiplied in a viable order.

I know this isn’t being literally proposed for IEEEFloat operands. But when those operands are 1×1 or diagonal matrices of floats, we expose the same vulnerabilities. This isn’t literally protected by IEEE 754 semantics, but it still makes me uneasy.

Is it feasible to restrict this transformation to associative types? In the opposite case, e.g. eltype(M) an integer, rational, etc., this optimizes performance while also improving stability, if the subsequent multiplication is with real numbers in v. While I don’t know much about compiler inferability, I’d suppose being aware of the data type is easier than analyzing invariants.
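A small sketch of why exact element types sidestep the problem, using plain Base/LinearAlgebra types and arbitrary values: with `Rational{Int}` every operation is exact, so both associations agree exactly; with an integer `M`, the product `M*M` is itself exact (barring `Int` overflow), so folding it first removes a rounding step rather than adding one. Whether the compiler could restrict such a transformation to these types is a separate question this does not answer.

```julia
using LinearAlgebra

# Exact element type: association cannot change the result.
M = Rational{Int}[1//2 1//3; 1//4 1//5]
v = Rational{Int}[1, 2]
(M*M)*v == M*(M*v)           # exact arithmetic is associative

# Integer M with Float64 v: Mi*Mi is exact in Int (assuming no overflow),
# so (Mi*Mi)*x rounds only in the final product, while Mi*(Mi*x) rounds twice.
Mi = [3 1; 1 2]
x  = [0.1, 0.7]
(Mi*Mi)*x, Mi*(Mi*x)         # may differ in the last bits
```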

Fastmath style reassociations are already the reality for matrix multiplication. BLAS implementations already do lots of re-association by default. The situation with linear algebra is very different from the situation with scalar code:

using StaticArrays, LinearAlgebra  # LinearAlgebra provides norm
struct MyAbstractArray{T, N} <: AbstractArray{T, N}
    A::AbstractArray{T, N}
end;
Base.size(A::MyAbstractArray) = size(A.A);
Base.getindex(A::MyAbstractArray, inds...) = getindex(A.A, inds...);
julia> let N = 3
           A = randn(N, N)
           B = randn(N, N)
           C = randn(N, N)

           sA = SMatrix{N, N}(A)
           sB = SMatrix{N, N}(B)
           sC = SMatrix{N, N}(C)
           
           norm(A * B * C) === norm(sA * sB * sC)
       end
false
julia> let N = 100
           A = randn(N, N)
           B = randn(N, N)
           myA = MyAbstractArray(A)
           myB = MyAbstractArray(B)
           
           sum(A * B) === sum(myA * myB)
       end
false

If I recall correctly, even if you hit the exact same method, you’ll get a different answer depending on your machine capabilities, like if you hit AVX-512, versus AVX2, or if you have different cache sizes that cause different tiling strategies.
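A toy illustration of how summation order alone can change a result: the two hypothetical helpers below compute the same dot product but sum the terms in opposite orders, standing in (very loosely) for different kernels or tiling strategies. The inputs are deliberately extreme so the difference is visible; in real BLAS runs it is usually confined to the last few bits.

```julia
# Same products, different summation order.
dot_left(x, y)  = foldl(+, x .* y)   # ((p1 + p2) + p3)
dot_right(x, y) = foldr(+, x .* y)   # (p1 + (p2 + p3))

x = [1e100, -1e100, 1.0]
y = [1.0, 1.0, 1.0]

dot_left(x, y)    # 1.0: the large terms cancel first, then 1.0 is added
dot_right(x, y)   # 0.0: 1.0 is absorbed into -1e100 before the cancellation
```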

Yes, but these are done at the matrix-matrix level. Here I am arguing for consistent associativity at the expression level. I’m arguing for only([x;;] * [y;;] * [z;;]) == x*y*z == (x*y)*z, which is still compatible with internal reassociations of matrix-matrix multiplication.

I think that writing A * (B * x) to force right-associativity is less onerous than hoping (or counting on the fact) that A * B * x quietly right-associates. It’s also less surprising when it does.

The overwhelming majority of the time, people are simply applying floating point operations to implement math-on-reals and haven’t given thought to this and they won’t notice or care if we re-order stuff for speed. But in very few cases people really do care and this reassociation isn’t documented well enough (in my opinion) that people can be aware of it.


Also, I’m reminded of #52333 and the hazard identified there. In that spirit, see this example:

julia> using Octonions

julia> x = Octonion(1,2,3,4,5,6,7,8); y = Octonion(1,-2,3,4,5,6,7,8); z = Octonion(1,2,-3,4,5,6,7,8);

julia> [x;;] * [y;;] * [z;;]
1×1 Matrix{Octonion{Int64}}:
 Octonion{Int64}(-652, -1064, 612, -736, -860, -1128, -1036, -1712)

julia> [x;;] * [y;;] * [z;] # re-association changes the value
1-element Vector{Octonion{Int64}}:
 Octonion{Int64}(-652, -1064, 612, -736, -1148, -888, -1420, -1376)

Reassociation outside of very controlled contexts is a significant hazard.
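The Octonions example needs a package; a self-contained stand-in for the same hazard is the 3D cross product from the LinearAlgebra standard library, which is also non-associative. This is only an analogue, not the matrix-of-octonions case above.

```julia
using LinearAlgebra

e1 = [1, 0, 0]
e2 = [0, 1, 0]

cross(cross(e1, e1), e2)   # e1 × e1 is the zero vector, so this is [0, 0, 0]
cross(e1, cross(e1, e2))   # e1 × e2 = e3 and e1 × e3 = -e2, so this is [0, -1, 0]
```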

I think N-ary * is a failed experiment, or very close to one. By association, I’m worried about N-ary + and ++ as well, although those operations are more-commonly associative.

Yes, this seems misleading. IMHO, the situation is similar to reduce, which claims:

help?> reduce
search: reduce mapreduce

  reduce(op, itr; [init])

  [...]

  The associativity of the reduction is implementation dependent. This means
  that you can't use non-associative operations like - because it is undefined
  whether reduce(-,[1,2,3]) should be evaluated as (1-2)-3 or 1-(2-3). Use
  foldl or foldr instead for guaranteed left or right associativity.
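A quick check of the two guaranteed orders that docstring points to, using Base `foldl` and `foldr` with the non-associative `-` from its own example:

```julia
xs = [1, 2, 3]

foldl(-, xs)   # (1 - 2) - 3 == -4
foldr(-, xs)   # 1 - (2 - 3) == 2
```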

Maybe + and * should be documented similarly (which they are, in a sense, but it could be misunderstood) and clearly state: “When left- or right-associativity is required, use brackets, i.e., a + (b + c) or (a + b) + c instead of a + b + c”
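A sketch of what “use brackets” looks like for the expression from the original question, with made-up plain `Matrix`/`Vector` values standing in for the StaticArrays types to keep it dependency-free. `foldl`/`foldr` over a tuple give the same guaranteed association for longer chains:

```julia
using LinearAlgebra

M = [0.9 0.1; -0.2 0.8]
v = [1.0, 2.0]

a = (M*M)*v     # explicitly left-associated: matrix-matrix product first
b = M*(M*v)     # explicitly right-associated: two matrix-vector products
c = M*M*v       # unbracketed: the library picks the order

# foldl/foldr spell out the association for a chain of any length.
foldl(*, (M, M, v)) == a    # same operations as (M*M)*v
foldr(*, (M, M, v)) == b    # same operations as M*(M*v)
```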