Creating a macro to call @tullio with arbitrary dimensions

Hello everyone,

Custom Macro

I am hoping to write a macro that calls @tullio to evaluate a multilinear form of arbitrary order M. I think this should follow from an easy-ish fix similar to an old thread… but I am stuck.

Mathematical Description

I have the input arrays $A \in \mathbb{R}^{d \times \cdots \times d}$ (with $M+1$ factors of $d$) and $e \in \mathbb{R}^{d}$, and the output array $b \in \mathbb{R}^{d}$. In Einstein summation notation (implicit sum over the repeated indices $a_1, \dots, a_M$), I want the following operation:

$$b_i = A_{i, a_1, \dots, a_M}\, e_{a_1} e_{a_2} \cdots e_{a_M}$$
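For checking results, it may help to have a slow but simple reference implementation of this contraction that works for any M without macros. The following is a minimal plain-Julia sketch (linform_reference is a hypothetical name, not part of Tullio) that writes out the implied sums over a_1, ..., a_M with CartesianIndices:

```julia
# Hypothetical reference helper (not part of Tullio):
# b[i] = sum over a_1..a_M of A[i, a_1, ..., a_M] * e[a_1] * ... * e[a_M]
function linform_reference(A::AbstractArray{<:Real}, e::AbstractVector{<:Real})
    T = promote_type(eltype(A), eltype(e))
    b = zeros(T, size(A, 1))
    # iterate over every multi-index (a_1, ..., a_M) of the trailing M dimensions
    for I in CartesianIndices(axes(A)[2:end])
        w = one(T)
        for a in Tuple(I)          # w = e[a_1] * e[a_2] * ... * e[a_M]
            w *= e[a]
        end
        for i in axes(A, 1)
            b[i] += A[i, I] * w    # A[i, I] is A[i, a_1, ..., a_M]
        end
    end
    return b
end
```

For M = 2 this should agree with @tullio b[i] = A[i,a1,a2]*e[a1]*e[a2] up to floating-point rounding.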

I can accomplish this through string manipulation and Meta.parse, but this does not compile down: the parsed expression is passed to eval at run time on every call, so it is recompiled every time (compare the "eval expr" timings in the output below).
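The string detour is not strictly needed to obtain the expression: the same Expr can be assembled directly from symbols, which is what a macro can return so that the code is spliced in at expansion time and compiled once, like hand-written code. A minimal plain-Julia sketch (no Tullio required; the variable names are illustrative) comparing the two constructions for M = 3:

```julia
M = 3

# (1) String route, as in the MWE below: build text, then parse it.
str = "b[i] = A[i" * prod(",a$j" for j in 1:M) * "]" * prod("*e[a$j]" for j in 1:M)
ex_from_string = Meta.parse(str)

# (2) Direct route: build the same expression from symbols, no strings involved.
Ainds = [Symbol(:a, j) for j in 1:M]        # [:a1, :a2, :a3]
eprod = [:(e[$a]) for a in Ainds]           # [:(e[a1]), :(e[a2]), :(e[a3])]
ex_direct = :(b[i] = *(A[i, $(Ainds...)], $(eprod...)))
```

Both routes yield the same Expr, because the parser turns a chain like x*y*z into a single n-ary call to *.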

I have tried to craft my own macro following the documentation, but the macro I wrote does not appear to perform the intended operation: b is never modified.
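As far as I can tell, the myMacro in the MWE below puts its whole body inside the returned quote. The symbol lists are then built when the generated code runs, and the final expr is simply returned as a value, so @tullio never executes. That would be consistent with b staying at [0.0 0.0]. Two hypothetical toy macros illustrate the difference between work done at run time and work done at expansion time:

```julia
# Hypothetical toy macro: everything inside the returned quote runs when the
# generated code runs, so the call evaluates to an Expr value that is never executed.
macro tuple_at_runtime(n)
    return quote
        local syms = [Symbol("a", j) for j in 1:$n]
        Expr(:tuple, syms...)
    end
end

# Hypothetical toy macro: the symbols are computed while the macro expands, and the
# resulting expression is spliced into the caller's code (esc: use the caller's a1, a2, ...).
macro tuple_at_expansion(n)
    syms = [Symbol("a", j) for j in 1:n]
    return esc(Expr(:tuple, syms...))
end

a1, a2, a3 = 10, 20, 30
r1 = @tuple_at_runtime 3      # an Expr, :((a1, a2, a3)), nothing is evaluated
r2 = @tuple_at_expansion 3    # the tuple (10, 20, 30)
```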

Minimum Working Example

using Tullio, Random
##
macro myMacro(b, A, e, M)
    return quote 
        local val_a_max = $M
        local args1 = [Symbol("a$j") for j = 1:val_a_max]
        # @info "args1 = $args1"
        local args2 = [(:ref, esc(e), Symbol("d$j")) for j = 1:val_a_max]
        # @info "args2 = $args2"
        expr = esc(:(@tullio esc(b)[i] = prod(esc(A)[i, $(args1...)], $(args2...))))
        # Meta.show_sexpr(expr)
        # println()
        expr
    end
end
#
function test_this()
    Random.seed!(1)
    d = 2 
    M_max = 5
    # M, A, e and b are assigned as globals in the loop below so that eval(expr) can see them
    for a in 2:M_max
        global M = a
        global A = rand((d for _ in 1:M+1)...)
        global e = rand(d)
        global b = zeros(d)
        @info "M = $M"
        # my custom macro
        @info "- myMacro"
        @time @myMacro b A e M
        @info "-- $(b')"
        # hard coded calls to tullio macro
        @info "- hard coded tullio macro"
        @time begin 
            if M == 2 
                @tullio b[i] = A[i,a1,a2]*e[a1]*e[a2]
            elseif M == 3 
                @tullio b[i] = (A[i,a1,a2,a3]
                    *e[a1]*e[a2]*e[a3])
            elseif M == 4 
                @tullio b[i] = (A[i,a1,a2,a3,a4]
                    *e[a1]*e[a2]*e[a3]*e[a4])
            elseif M == 5 
                @tullio b[i] = (A[i,a1,a2,a3,a4,a5]
                    *e[a1]*e[a2]*e[a3]*e[a4]*e[a5])
            elseif M == 6 
                @tullio b[i] = (A[i,a1,a2,a3,a4,a5,a6]
                    *e[a1]*e[a2]*e[a3]*e[a4]*e[a5]*e[a6])
            elseif M == 7 
                @tullio b[i] = (A[i,a1,a2,a3,a4,a5,a6,a7]
                    *e[a1]*e[a2]*e[a3]*e[a4]*e[a5]*e[a6]*e[a7])
            elseif M == 8 
                @tullio b[i] = (A[i,a1,a2,a3,a4,a5,a6,a7,a8]
                    *e[a1]*e[a2]*e[a3]*e[a4]*e[a5]*e[a6]*e[a7]*e[a8])
            elseif M == 9 
                @tullio b[i] = (A[i,a1,a2,a3,a4,a5,a6,a7,a8,a9]
                    *e[a1]*e[a2]*e[a3]*e[a4]*e[a5]*e[a6]*e[a7]*e[a8]*e[a9])
            elseif M == 10 
                @tullio b[i] = (A[i,a1,a2,a3,a4,a5,a6,a7,a8,a9,a10]
                    *e[a1]*e[a2]*e[a3]*e[a4]*e[a5]*e[a6]*e[a7]*e[a8]*e[a9]*e[a10])
            else 
                @error "not implemented"
            end
        end
        @info "-- $(b')"
        # eval expr
        b = zeros(d)
        cmd_str = "@tullio b[i] = A[i"*prod(",a$j" for j in 1:M)*"]"*prod("*e[a$j]" for j in 1:M)
        expr = Meta.parse(cmd_str)
        @info "- eval expr"
        @time eval(expr)
        @info "-- $(b')"
        println("=============================================================")
    end
end
##
# first time everything jit compiles  
println("First evaluation of the test. We expect everything to have high compilation time")
test_this()
#
println()
println()
# second run: only the Meta.parse/eval path should still spend time compiling
println("Second evaluation of the test. We expect only Meta.parse to have compilation time")
test_this()

Output

First evaluation of the test. We expect everything to have high compilation time
[ Info: M = 2
[ Info: - myMacro
  0.020640 seconds (90.94 k allocations: 4.514 MiB, 99.45% compilation time)
[ Info: -- [0.0 0.0]
[ Info: - hard coded tullio macro
  0.118539 seconds (474.54 k allocations: 24.213 MiB, 3.53% gc time, 99.81% compilation time)
[ Info: -- [0.6458665226117443 0.8712762271312573]
[ Info: - eval expr
  0.107026 seconds (322.60 k allocations: 16.671 MiB, 96.68% compilation time)
[ Info: -- [0.6458665226117443 0.8712762271312573]
=============================================================
[ Info: M = 3
[ Info: - myMacro
  0.000021 seconds (76 allocations: 2.688 KiB)
[ Info: -- [0.0 0.0]
[ Info: - hard coded tullio macro
  0.141966 seconds (474.08 k allocations: 24.201 MiB, 99.93% compilation time)
[ Info: -- [0.047684484532381215 0.055926908789111836]
[ Info: - eval expr
  0.138103 seconds (370.36 k allocations: 18.921 MiB, 2.69% gc time, 97.45% compilation time)
[ Info: -- [0.047684484532381215 0.055926908789111836]
=============================================================
[ Info: M = 4
[ Info: - myMacro
  0.000032 seconds (90 allocations: 3.234 KiB)
[ Info: -- [0.0 0.0]
[ Info: - hard coded tullio macro
  0.181469 seconds (557.63 k allocations: 28.450 MiB, 99.95% compilation time)
[ Info: -- [0.24707741708999534 0.3316466084333281]
[ Info: - eval expr
  0.170222 seconds (435.33 k allocations: 22.167 MiB, 97.76% compilation time)
[ Info: -- [0.24707741708999534 0.3316466084333281]
=============================================================
[ Info: M = 5
[ Info: - myMacro
  0.000021 seconds (104 allocations: 3.719 KiB)
[ Info: -- [0.0 0.0]
[ Info: - hard coded tullio macro
  0.245771 seconds (628.29 k allocations: 31.886 MiB, 1.32% gc time, 99.96% compilation time)
[ Info: -- [0.06775196978697368 0.09599998137184107]
[ Info: - eval expr
  0.237658 seconds (494.51 k allocations: 24.935 MiB, 98.09% compilation time)
[ Info: -- [0.06775196978697368 0.09599998137184107]
=============================================================


Second evaluation of the test. We expect only Meta.parse to have compilation time
[ Info: M = 2
[ Info: - myMacro
  0.000013 seconds (62 allocations: 2.203 KiB)
[ Info: -- [0.0 0.0]
[ Info: - hard coded tullio macro
  0.000013 seconds (12 allocations: 224 bytes)
[ Info: -- [0.6458665226117443 0.8712762271312573]
[ Info: - eval expr
  0.101354 seconds (322.58 k allocations: 16.674 MiB, 96.98% compilation time)
[ Info: -- [0.6458665226117443 0.8712762271312573]
=============================================================
[ Info: M = 3
[ Info: - myMacro
  0.000016 seconds (76 allocations: 2.688 KiB)
[ Info: -- [0.0 0.0]
[ Info: - hard coded tullio macro
  0.000013 seconds (15 allocations: 272 bytes)
[ Info: -- [0.047684484532381215 0.055926908789111836]
[ Info: - eval expr
  0.136961 seconds (370.35 k allocations: 18.923 MiB, 2.04% gc time, 97.48% compilation time)
[ Info: -- [0.047684484532381215 0.055926908789111836]
=============================================================
[ Info: M = 4
[ Info: - myMacro
  0.000023 seconds (90 allocations: 3.234 KiB)
[ Info: -- [0.0 0.0]
[ Info: - hard coded tullio macro
  0.000017 seconds (18 allocations: 336 bytes)
[ Info: -- [0.24707741708999534 0.3316466084333281]
[ Info: - eval expr
  0.177276 seconds (435.32 k allocations: 22.165 MiB, 97.77% compilation time)
[ Info: -- [0.24707741708999534 0.3316466084333281]
=============================================================
[ Info: M = 5
[ Info: - myMacro
  0.000022 seconds (104 allocations: 3.719 KiB)
[ Info: -- [0.0 0.0]
[ Info: - hard coded tullio macro
  0.000016 seconds (21 allocations: 384 bytes)
[ Info: -- [0.06775196978697368 0.09599998137184107]
[ Info: - eval expr
  0.233445 seconds (494.50 k allocations: 24.936 MiB, 1.05% gc time, 98.06% compilation time)
[ Info: -- [0.06775196978697368 0.09599998137184107]
=============================================================

This macro might do what you want:

using Tullio: @tullio

macro tullioform(b, A, e, M)
    bb, AA, ee = esc(b), esc(A), esc(e)
    Ainds = [Symbol(:a, k) for k in 1:M]
    eprod = [:($ee[$(Symbol(:a, k))]) for k in 1:M]
    quote
        let b = $bb
            @tullio b[i] = *($AA[i, $(Ainds...)], $(eprod...))
        end
    end
end

Example:

julia> A = rand(2, 3, 3, 3); e = rand(3); b = zeros(2);

julia> @tullioform b A e 3
2-element Vector{Float64}:
 4.3187279479751375
 3.6292691407204396

julia> b === @tullioform b A e 3
true

General Use Case

Thanks to @matthias314 for looking at this with me! Unfortunately, while this code works well in isolation, there is a hiccup: if we pass the macro a variable M holding the order instead of an integer literal, the macro no longer works…

@tullioform b A e 2   # works
M = 2
@tullioform b A e M   # does not work
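The reason is that a macro receives the syntax written at the call site, not run-time values: @tullioform b A e 2 receives the integer 2, so 1:M works inside the macro, whereas @tullioform b A e M receives the symbol :M, and 1:M then becomes 1:(:M), which throws an error. A tiny hypothetical macro makes this visible:

```julia
# Hypothetical macro: reports the type of the *syntax* it receives.
macro argtype(x)
    return typeof(x)          # spliced into the caller's code as a constant
end

M = 2
t_literal = @argtype 2        # Int: the literal 2 itself
t_variable = @argtype M       # Symbol: just the name :M, not its value 2
```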

This can be worked around, at least for global variables, with the following small change to the macro:

macro tullioform(b, A, e, M)
    bb, AA, ee = esc(b), esc(A), esc(e)
    # Obtain the value behind the symbol M by evaluating it at expansion time.
    # Note: this eval runs in the global scope of the module where the macro is defined.
    M_val = eval(M)
    println("Macro called with M = $M_val")
    Ainds = [Symbol(:a, k) for k in 1:M_val]
    eprod = [:($ee[$(Symbol(:a, k))]) for k in 1:M_val]
    return quote
        let b = $bb
            @tullio b[i] = *($AA[i, $(Ainds...)], $(eprod...))
        end
    end
end

Hygiene Issues

These problems show up as soon as we call the macro from a local scope, such as a function body or a for loop. Consider:

d, M_max = 2, 5
for m in 1:M_max
  M = m
  A = rand((d for _ in 1:M+1)...)
  e = rand(d)
  b = zeros(d)
  @tullioform b A e M
end

Unfortunately, this results in the following error:

ERROR: LoadError: UndefVarError: `M` not defined in `Main` 

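This happens because a macro is expanded before the surrounding code runs: the eval inside the macro only sees global variables of the module where the macro is defined, and the loop variable M is local and does not even have a value yet at expansion time. Moreover, a macro inside a loop is expanded only once, so it could not pick a different M on each iteration anyway.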
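The failure can be reproduced without Tullio. The sketch below uses a hypothetical macro that mimics M_val = eval(M): it works for a global, but defining a function that passes it a local variable fails already while the function is being defined, because that is when the macro expands:

```julia
# Hypothetical macro mimicking `M_val = eval(M)` from the modified @tullioform.
# eval runs at expansion time, in the global scope of this module.
macro value_via_eval(x)
    return eval(x)
end

N_global = 3
v = @value_via_eval N_global     # fine: N_global is a global, v == 3

# Expanding the macro inside a function body fails: K_local is a local variable
# that does not exist (let alone have a value) when the macro expands.
expansion_failed = try
    @eval function uses_local()
        K_local = 4
        return @value_via_eval K_local
    end
    false
catch err
    true
end
```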

Final Solution

The final solution (for me, at least) comes from generating, with @eval, one method per array rank that can be compiled down. Don't worry, we still end up using the macro we have been working on, but with a few tricks.
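The key trick is that @eval interpolation with $M splices the value of a variable into the code before the macro expands, so the macro sees a literal rather than the symbol :M. A minimal sketch with a hypothetical macro that, like @_tullioform, insists on an integer literal:

```julia
# Hypothetical macro that only accepts an integer literal, like @_tullioform.
macro times_ten(M)
    M isa Integer || error("expected an integer literal, got: ", M)
    return M * 10             # computed at expansion time
end

results = Int[]
for M in 1:3
    # `$M` inserts the *value* of the loop variable into the expression,
    # so the macro expands with the literals 1, 2, 3 instead of the symbol :M.
    push!(results, @eval @times_ten $M)
end
```

Each @eval call still compiles code at run time, so this pattern is best used once at load time to define methods, rather than inside a hot loop.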

using Tullio: @tullio
macro _tullioform(b, A, e, M)
    if !(M isa Integer)
        error("@tullioform requires M to be an integer literal. Got: ", M)
    end
    Ainds = [Symbol(:a, k) for k in 1:M]
    eprod = [:($e[$(Symbol(:a, k))]) for k in 1:M]
    ex = :($b[i] = *($A[i, $(Ainds...)], $(eprod...)))
    return esc(quote 
        @tullio $ex
    end)
end
##
M_MAX = 50
for M in 2:M_MAX
    Mp1 = M + 1
    @eval begin 
        function tullioform!(b::Vector{<:Real}, A::Array{<:Real, $Mp1}, e::Vector{<:Real})
            @_tullioform b A e $M 
            return b
        end
    end
end

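The macro builds the expression we want @tullio to evaluate (properly escaped), and it only accepts an integer literal for M. The for loop then uses @eval to splice each value of M into the code as a literal, generating one method of tullioform! per rank of A (M + 1 dimensions). Julia compiles each method the first time it is called with a given argument type and specializes it for that rank. This code can be tested with the following: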
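To see the method-generation pattern in isolation, here is a sketch that replaces Tullio with explicit nested loops built as expressions. The function name linform! is hypothetical, and plain loops will not match Tullio's performance; the point is only that @eval produces one ordinary method per rank of A:

```julia
# Sketch: one method of the hypothetical function `linform!` per rank M + 1 of A,
# each containing M explicit nested loops generated as expressions (no Tullio).
for M in 1:4
    Ainds = [Symbol(:a, k) for k in 1:M]           # [:a1, ..., :aM]
    eprod = [:(e[$a]) for a in Ainds]              # [:(e[a1]), ..., :(e[aM])]
    # innermost statement: acc += A[i, a1, ..., aM] * e[a1] * ... * e[aM]
    body = :(acc += *(A[i, $(Ainds...)], $(eprod...)))
    # wrap the statement in one `for` loop per summation index a_k
    for (k, a) in enumerate(Ainds)
        body = quote
            for $a = axes(A, $(k + 1))
                $body
            end
        end
    end
    @eval function linform!(b::Vector{<:Real}, A::Array{<:Real,$(M + 1)}, e::Vector{<:Real})
        for i in axes(A, 1)
            acc = zero(eltype(b))
            $body
            b[i] = acc
        end
        return b
    end
end
```

Because each method has a fixed number of loops and a concrete array rank in its signature, it is compiled once per argument type and reused afterwards, in the same way as the generated tullioform! methods.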

using Random
function test_all(d = 2, seed=1, M_MAX = 5)
    Random.seed!(seed)
    for M in 2:M_MAX
        println("-------------------------------------------------------------")
        println("- M = $M")
        A = rand((d for _ in 1:M+1)...)
        e = rand(d)
        b = zeros(d)
        @info "- custom macro"
        @time tullioform!(b, A, e)
        @info "-- $(b')"
        @info "- hard coded macro"
        b = zeros(d)
        @time begin 
            if M == 2 
                @tullio b[i] = A[i,a1,a2]*e[a1]*e[a2]
            elseif M == 3 
                @tullio b[i] = (A[i,a1,a2,a3]
                    *e[a1]*e[a2]*e[a3])
            elseif M == 4 
                @tullio b[i] = (A[i,a1,a2,a3,a4]
                    *e[a1]*e[a2]*e[a3]*e[a4])
            elseif M == 5 
                @tullio b[i] = (A[i,a1,a2,a3,a4,a5]
                    *e[a1]*e[a2]*e[a3]*e[a4]*e[a5])
            else 
                @error "not implemented"
            end
        end
        @info "-- $(b')"
    end
end
## 
d = 5
M_MAX = 5
println("=============================================================")
println("First run: Compile time should be a factor")
test_all(d, 1, M_MAX)
##
println("=============================================================")
println("Second run: Compile time should NOT be a factor")
test_all(d, 1, M_MAX)
println("=============================================================")

Output

=============================================================
First run: Compile time should be a factor
-------------------------------------------------------------
- M = 2
[ Info: - custom macro
  0.110359 seconds (338.54 k allocations: 17.587 MiB, 99.92% compilation time)
[ Info: -- [2.266382452055628 3.720653596334037 3.5303713328800947 3.029815857736556 3.5458117652904355]
[ Info: - hard coded macro
  0.105300 seconds (318.48 k allocations: 16.515 MiB, 99.95% compilation time)
[ Info: -- [2.266382452055628 3.720653596334037 3.5303713328800947 3.029815857736556 3.5458117652904355]
-------------------------------------------------------------
- M = 3
[ Info: - custom macro
  0.142950 seconds (389.60 k allocations: 20.028 MiB, 99.95% compilation time)
[ Info: -- [21.66869584807132 19.221745537741267 22.37165953801042 22.641925913474022 21.30457819290222]
[ Info: - hard coded macro
  0.141125 seconds (365.59 k allocations: 18.750 MiB, 99.96% compilation time)
[ Info: -- [21.66869584807132 19.221745537741267 22.37165953801042 22.641925913474022 21.30457819290222]
-------------------------------------------------------------
- M = 4
[ Info: - custom macro
  0.178966 seconds (459.03 k allocations: 23.495 MiB, 1.80% gc time, 99.96% compilation time)
[ Info: -- [29.47781306202569 28.360920426867278 29.768770348832547 28.011725722784554 29.994603556109308]
[ Info: - hard coded macro
  0.181621 seconds (429.90 k allocations: 21.946 MiB, 99.97% compilation time)
[ Info: -- [29.47781306202569 28.360920426867278 29.768770348832547 28.011725722784554 29.994603556109308]
-------------------------------------------------------------
- M = 5
[ Info: - custom macro
  0.244355 seconds (522.15 k allocations: 26.470 MiB, 99.97% compilation time)
[ Info: -- [55.271239017691826 54.37312505932062 54.5786869409735 54.23750041903694 54.57361004157535]
[ Info: - hard coded macro
  0.242419 seconds (488.42 k allocations: 24.694 MiB, 99.97% compilation time)
[ Info: -- [55.271239017691826 54.37312505932062 54.5786869409735 54.23750041903694 54.57361004157535]

=============================================================
Second run: Compile time should NOT be a factor
-------------------------------------------------------------
- M = 2
[ Info: - custom macro
  0.000002 seconds
[ Info: -- [2.266382452055628 3.720653596334037 3.5303713328800947 3.029815857736556 3.5458117652904355]
[ Info: - hard coded macro
  0.000016 seconds (12 allocations: 240 bytes)
[ Info: -- [2.266382452055628 3.720653596334037 3.5303713328800947 3.029815857736556 3.5458117652904355]
-------------------------------------------------------------
- M = 3
[ Info: - custom macro
  0.000002 seconds
[ Info: -- [21.66869584807132 19.221745537741267 22.37165953801042 22.641925913474022 21.30457819290222]
[ Info: - hard coded macro
  0.000010 seconds (15 allocations: 288 bytes)
[ Info: -- [21.66869584807132 19.221745537741267 22.37165953801042 22.641925913474022 21.30457819290222]
-------------------------------------------------------------
- M = 4
[ Info: - custom macro
  0.000005 seconds
[ Info: -- [29.47781306202569 28.360920426867278 29.768770348832547 28.011725722784554 29.994603556109308]
[ Info: - hard coded macro
  0.000015 seconds (18 allocations: 352 bytes)
[ Info: -- [29.47781306202569 28.360920426867278 29.768770348832547 28.011725722784554 29.994603556109308]
-------------------------------------------------------------
- M = 5
[ Info: - custom macro
  0.000021 seconds
[ Info: -- [55.271239017691826 54.37312505932062 54.5786869409735 54.23750041903694 54.57361004157535]
[ Info: - hard coded macro
  0.000029 seconds (21 allocations: 400 bytes)
[ Info: -- [55.271239017691826 54.37312505932062 54.5786869409735 54.23750041903694 54.57361004157535]
=============================================================