Broadcasting in TensorOperations.jl (and applying ForwardDiff)

Hi everyone,

[Edit] Question 1:

I have a quick question about using TensorOperations.jl. How would one go about broadcasting an operation such as

using TensorOperations

function g(F)
    @tensor F[i,j] * F[i,j]
end

which works if F is two-dimensional (say of shape (3,3)). What if I want to broadcast this over an n-dimensional array, say of shape (3,3,100,4,...)? In NumPy I would do something like:

from numpy import einsum
einsum("ij...,ij...", F, F)

where the ellipsis (...) controls the broadcasting.

I have checked Tullio.jl too, but I couldn't see anything equivalent there. Any pointers would be helpful.

Thanks

Edit:

I did find that EllipsisNotation.jl allows one to index into an array as F[1,1,..], but that is not directly of use in this case.

Question 2:

The end goal is to be able to differentiate through g using ForwardDiff and broadcast the result. This brings me to my next question. Say I have this function:

f(u) = @. 1. + u[1]^2 + u[2]^2

I would take its gradient as

using ForwardDiff: gradient

g = x -> gradient(f, x)
v = rand(2)
@assert g(v) == 2 .* v   # true

Now what if my input to f is higher-dimensional, say of shape (2,100,4)? How would I broadcast its gradient over the trailing axes? I have looked at the posts on Discourse, but if you feel this has been answered before, feel free to point me there.
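
To make the goal concrete, the explicit loop I am hoping to avoid would be something like this (untested sketch; out is just a name I made up):

using ForwardDiff: gradient

u = rand(2, 100, 4)
out = similar(u)
for idx in CartesianIndices(axes(u)[2:end])   # loop over the trailing axes
    out[:, idx] = gradient(f, u[:, idx])      # gradient of each length-2 slice
end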

Many thanks!


Nothing really auto-broadcasts over extra dimensions; Julia seldom thinks of N-dimensional arrays as collections of lower-dimensional ones. But you can make slices and apply functions to those in various ways, for instance:

julia> using TensorOperations, TensorCast, ForwardDiff, LinearAlgebra

julia> g(F::AbstractMatrix) = @tensor F[i,j] * F[i,j];  # as in question, is just `dot`

julia> g(rand(2,3))
2.4423897829331964

julia> g3(F::AbstractArray{<:Any, 3}) = @tensor out[k] := F[i,j,k] * F[i,j,k];  # not Einstein, won't work

julia> g3(rand(2,3,4)) |> size
ERROR: TensorOperations.IndexError{String}("non-matching indices between left and right hand side: …

julia> g3b(x::AbstractArray{<:Any, 3}) = g.(eachslice(x, dims=3));

julia> g3c(x::AbstractArray{<:Any, 3}) = @cast out[k] := g(x[:,:,k]);  # the same thing

julia> x3 = rand(4,2,3);

julia> g3b(x3) == g3c(x3)
true

julia> g3c(x3) |> size
(3,)

julia> gn(x) = mapslices(g, x, dims=(1,2));

julia> gn(x3) |> size
(1, 1, 3)

julia> gn(rand(4, 2, 5, 6, 7)) |> size
(1, 1, 5, 6, 7)

julia> vec(gn(x3)) ≈ g3b(x3)
true

And gradients:

julia> ForwardDiff.gradient(g, rand(2,3))
2×3 Matrix{Float64}:
 1.6148    1.72041   1.46002
 0.840556  0.139716  0.254563

julia> grad3(x::AbstractArray{<:Any, 3}) = @cast out[k,i,j] := ForwardDiff.gradient(g, x[:,:,k])[i,j];

julia> grad3(x3) |> size
(3, 4, 2)

julia> ForwardDiff.jacobian(g3b, x3) |> size  # not the same thing, doesn't know about sparsity
(3, 24)

julia> reshape(ForwardDiff.jacobian(g3b, x3), 3, 8, 3);  # now take nonzero parts:

julia> vcat(ans[1:1,:,1], ans[2:2,:,2], ans[3:3,:,3]) ≈ reshape(grad3(x3), 3, :)
true

There has been discussion of trying to write something like Jax's vmap for Julia, rewriting lower-dimensional operations into higher-dimensional ones. But notice that, both for g and for its gradient, there isn't an obvious higher-dimensional function to replace the one on matrices. That is where you might hope such a transformation would lead to efficiency gains, for instance by replacing many dot products with one matrix multiplication. Without that, it's just a loop over slices, which is what mapslices or TensorCast already does.
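
For reference, the loop over slices amounts to something like this, with g and x3 as above (gn_loop is just an illustrative name for this sketch):

julia> function gn_loop(x::AbstractArray{<:Any, 3})
           out = similar(x, size(x, 3))
           for k in axes(x, 3)
               out[k] = g(view(x, :, :, k))  # apply the matrix version to each slice
           end
           out
       end;

julia> gn_loop(x3) ≈ vec(gn(x3))
true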

The notation of mapslices lets it work like EllipsisNotation when the .. would be at the end, as in gn(x) above. It doesn't accept multiple arguments, so there is no question of combining different array sizes the way broadcasting does. (Something like that for TensorCast & Tullio is somewhere on the wish-list.)


Thanks @mcabbott!
Indeed, Jax's vmap was something I had tried before, but it turned out not to be fast enough for my application (finite elements).


Maybe you know this, but if the dimensions you’re dealing with are in fact as small as 3x3, then reinterpreting as StaticArrays is another way to go. And since that gives you an array of arrays, you can broadcast over the container as normal.
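
For example, something along these lines (a sketch; the reinterpret(reshape, ...) form needs Julia 1.6+, and xs & grads are just names I made up):

julia> using StaticArrays, ForwardDiff, LinearAlgebra

julia> x3 = rand(3, 3, 100);

julia> xs = reinterpret(reshape, SMatrix{3,3,Float64,9}, reshape(x3, 9, :));  # a 100-element vector of 3×3 SMatrices

julia> dot.(xs, xs) ≈ vec(sum(abs2, x3; dims=(1,2)))  # g from above is just dot, broadcast over the container
true

julia> grads = map(m -> ForwardDiff.gradient(F -> dot(F, F), m), xs);  # one static 3×3 gradient per slice

julia> grads[1] ≈ 2 * xs[1]
true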


Thanks, I always forget about StaticArrays :). Indeed, for my application I only need 3x3 and 2x2.

This should do it for me!

Thanks again
