Batched matmul with einsum

I’m performing a batched matmul using Einsum.jl. Is there a clean way to have the leading dimensions inferred, the way NumPy’s einsum lets you write ... in place of them?

For example, I’m currently doing this:

@einsum C[k, l, m, j] = A[i, j] * B[k, l, m, j]
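For reference, Einsum.jl sums over every index that appears only on the right-hand side, so in the line above i is contracted away. The plain-Julia loop below spells out that reading as a sketch. The function name `einsum_ref` is hypothetical and not part of Einsum.jl, and it assumes A is n×J and B is K×L×M×J.

```julia
# Hypothetical reference implementation of
#   @einsum C[k, l, m, j] = A[i, j] * B[k, l, m, j]
# `i` appears only on the right-hand side, so it is summed over.
function einsum_ref(A::AbstractMatrix, B::AbstractArray{<:Any,4})
    K, L, M, J = size(B)
    size(A, 2) == J || throw(DimensionMismatch("A and B must share the j dimension"))
    T = promote_type(eltype(A), eltype(B))
    C = zeros(T, K, L, M, J)
    # Innermost loop over k, the first (fastest-varying) index of C and B.
    for j in 1:J, m in 1:M, l in 1:L, k in 1:K
        acc = zero(T)
        for i in axes(A, 1)   # contraction over i
            acc += A[i, j] * B[k, l, m, j]
        end
        C[k, l, m, j] = acc
    end
    return C
end
```

Read this way, each entry is B[k, l, m, j] scaled by the j-th column sum of A, so if a true batched matmul was intended, the contracted index probably needs to appear in B as well.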

But it would be great to be able to write something like this instead:

@einsum C[..., m, j] = A[i, j] * B[..., m, j]

so that the same line works for a B with any number of leading (batch) dimensions.

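I’m not sure whether this can be inferred, but on a separate note, you’d probably be better off making k and l the trailing dimensions. Julia arrays are column-major (unlike NumPy’s default row-major layout), so the first index varies fastest in memory; putting the batch indices last keeps each m×j slice contiguous.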
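One way to check the layout argument is to compare element strides, using Base's `strides`, for the two orderings. The sizes below are arbitrary and chosen only for illustration.

```julia
# Column-major layout: the first index varies fastest in memory.
# Illustrative sizes: m = 2, j = 3 (slice dims), k = 4, l = 5 (batch dims).
m, j, k, l = 2, 3, 4, 5

batch_last  = zeros(m, j, k, l)   # layout C[m, j, k, l]
batch_first = zeros(k, l, m, j)   # layout C[k, l, m, j]

# Strides are measured in elements, not bytes.
@show strides(batch_last)    # (1, 2, 6, 24)
@show strides(batch_first)   # (1, 4, 20, 40)

# Strides of a single m×j slice for a fixed batch index (1, 1).
@show strides(view(batch_last, :, :, 1, 1))    # (1, 2): one contiguous block
@show strides(view(batch_first, 1, 1, :, :))   # (20, 40): scattered in memory
```

With the batch indices last, each m×j slice has unit stride and a column stride equal to m, so it is a single contiguous run of memory. That is the layout BLAS-backed routines handle most efficiently.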

One option might be to use reshape to reinterpret any number of trailing dimensions as a single dimension, so that the einsum only ever has to be written for a fixed number of dimensions.
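Here is a minimal sketch of the reshape idea, using only Base and the LinearAlgebra standard library. It assumes the batch dimensions are trailing, so B has layout (M, J, batch...). It also assumes the intended operation is multiplying a fixed matrix A into every M×J slice of B. The helper name `batched_left_mul` is hypothetical.

```julia
using LinearAlgebra

# Hypothetical helper: C[:, :, b...] = A * B[:, :, b...] for every batch index b,
# where B has layout (M, J, batch...) with any number of trailing batch dims.
function batched_left_mul(A::AbstractMatrix, B::AbstractArray)
    ndims(B) >= 2 || throw(ArgumentError("B needs at least two dimensions"))
    M, J = size(B, 1), size(B, 2)
    size(A, 2) == M || throw(DimensionMismatch("size(A, 2) must equal size(B, 1)"))
    batchdims = size(B)[3:end]   # () when B is a plain matrix

    # Collapse all trailing batch dimensions into one (no copy for an Array).
    B3 = reshape(B, M, J, :)
    T = promote_type(eltype(A), eltype(B))
    C3 = similar(B3, T, size(A, 1), J, size(B3, 3))

    # Each slice is contiguous because the batch index is last.
    for b in axes(B3, 3)
        mul!(view(C3, :, :, b), A, view(B3, :, :, b))
    end

    # Restore the original batch shape.
    return reshape(C3, size(A, 1), J, batchdims...)
end

A = rand(3, 4)
C = batched_left_mul(A, rand(4, 5, 2, 6))   # size(C) == (3, 5, 2, 6)
```

The reshape is what lets a fixed-rank kernel serve any number of batch dimensions. The explicit loop could presumably be replaced by a single 3-D Einsum.jl expression on the reshaped arrays, such as @einsum C3[i, j, b] := A[i, m] * B3[m, j, b].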
