Hi everyone,
I have been using `LoopVectorization.jl` with very satisfactory results to speed up some calculations.
Now I have hit a problem: combining it with AD!
Although my use case is not very complicated (it is just an array contraction), I am not able to differentiate through it.
I have read Zygote's documentation, so I am aware of its limitations when dealing with mutating functions.
I will consider the simpler case of a matrix multiplication, but the same applies to more complex scenarios (the ones I actually care about).

For instance, let us consider the following (which I just copied from the `LoopVectorization` docs):

```julia
function mygemmavx!(C, A, B)
    @turbo for m ∈ axes(A,1), n ∈ axes(B,2)
        Cmn = zero(eltype(C))
        for k ∈ axes(A,2)
            Cmn += A[m,k] * B[k,n]
        end
        C[m,n] = Cmn
    end
end
```
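As written, this cannot be differentiated by Zygote because it mutates `C` in place. A minimal reproduction of the failure (a sketch; `f` is a hypothetical wrapper, and the exact error text depends on the Zygote version):

```julia
using Zygote, LoopVectorization

function mygemmavx!(C, A, B)
    @turbo for m ∈ axes(A,1), n ∈ axes(B,2)
        Cmn = zero(eltype(C))
        for k ∈ axes(A,2)
            Cmn += A[m,k] * B[k,n]
        end
        C[m,n] = Cmn
    end
end

A = rand(4, 4); B = rand(4, 4)

# Hypothetical wrapper: allocate C, fill it by mutation, reduce to a scalar.
f(A, B) = (C = similar(A); mygemmavx!(C, A, B); sum(C))

# Zygote.gradient(f, A, B)  # errors: "Mutating arrays is not supported"
```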

I tried adding `Zygote.Buffer` in an allocating `mygemmavx`, but it is several orders of magnitude slower.
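For reference, the kind of `Zygote.Buffer` variant I mean looks roughly like this (a sketch, not my exact code; note that `@turbo` cannot be used on a `Buffer`, and every `setindex!` is tracked by Zygote, which explains the slowdown):

```julia
using Zygote

function mygemm_buffer(A, B)
    C = Zygote.Buffer(A, size(A,1), size(B,2))  # mutable container Zygote can track
    for m in axes(A,1), n in axes(B,2)
        Cmn = zero(eltype(A))
        for k in axes(A,2)
            Cmn += A[m,k] * B[k,n]
        end
        C[m,n] = Cmn
    end
    return copy(C)  # copy(::Buffer) freezes it back into an ordinary Array
end

# gradient(A -> sum(mygemm_buffer(A, B)), A) works, but is very slow.
```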
I then tried `Tullio.jl`. Although a bit slower than `LoopVectorization`, it worked nicely for the gradients!

```julia
using Tullio, LoopVectorization, BenchmarkTools
mul(A, B) = @tullio C[i,k] := A[i,j] * B[j,k]
W = rand(100, 100); x = rand(100, 100);
@benchmark sum(mul(W, x))

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  19.080 μs …   4.181 ms  ┊ GC (min … max):  0.00% … 95.39%
 Time  (median):     43.110 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   47.226 μs ± 140.766 μs  ┊ GC (mean ± σ):  11.02% ±  3.68%

  [histogram omitted]
  19.1 μs         Histogram: frequency by time         63.5 μs <

 Memory estimate: 80.72 KiB, allocs estimate: 52.
```
```julia
@benchmark gradient(W -> sum(mul(W,x)), W)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   74.694 μs …  16.137 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     148.622 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   161.961 μs ± 292.893 μs  ┊ GC (mean ± σ):  9.82% ± 6.28%

  [histogram omitted]
  74.7 μs          Histogram: frequency by time          321 μs <

 Memory estimate: 242.59 KiB, allocs estimate: 159.
```

While standard `*` gave

```julia
@benchmark gradient(W -> sum(W*x), W)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  41.687 μs …   3.389 ms  ┊ GC (min … max):  0.00% … 96.11%
 Time  (median):     51.237 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   71.793 μs ± 206.830 μs  ┊ GC (mean ± σ):  18.93% ±  6.48%

  [histogram omitted]
  41.7 μs       Histogram: log(frequency) by time       236 μs <

 Memory estimate: 236.11 KiB, allocs estimate: 35.
```

`Tullio` is within a factor of two of the standard `*` (which is actually great).
For more complicated contractions, `Tullio` still works, but the gradient computation becomes around 15 times slower than the forward pass.
So, my question is: is this currently the best we can do in `Julia` to quickly contract arrays and obtain gradients? Is there an option I am missing, or am I doing something wrong? Should I implement custom rules? (For the weird tensor contraction they should be straightforward to derive, especially following the guide in `ChainRules`.)

This is used within `Turing.jl`: I want to do HMC, so quick gradients are a requirement for good computational performance.
(Tagging @Elrod @mcabbott but insights from anyone are very welcome)
Marco


Do you have some examples of your complicated contractions? If they are allowed by TensorOperations.jl, then you should also try that. It now has built-in support for taking gradients.

Otherwise your best bet is probably to write gradient rules, which themselves call Tullio. The macro can derive the expressions you need:

```julia
julia> @tullio C[i,k] := A[i,j] * B[j,k] verbose=1
│   inbody =
│    2-element Vector{Any}:
│     :(𝛥A[i, j] = 𝛥A[i, j] + 𝛥ℛ[i, k] * conj(B[j, k]))
│     :(𝛥B[j, k] = 𝛥B[j, k] + 𝛥ℛ[i, k] * conj(A[i, j]))
```

Then your `rrule` will contain something like `back(dC) = (NoTangent(), @tullio(dA[i, j] := dC[i, k] * conj(B[j, k])), @tullio(dB[j, k] := dC[i, k] * conj(A[i, j])))`.
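Spelled out, such a rule might look like this (a sketch for the real-valued case, assuming `mul` as defined above; untested):

```julia
using ChainRulesCore, Tullio, LoopVectorization

mul(A, B) = @tullio C[i,k] := A[i,j] * B[j,k]

function ChainRulesCore.rrule(::typeof(mul), A, B)
    C = mul(A, B)
    function mul_pullback(dC)
        # Each gradient gets its own loop nest, so each can be threaded freely.
        @tullio dA[i,j] := dC[i,k] * conj(B[j,k])
        @tullio dB[j,k] := dC[i,k] * conj(A[i,j])
        return NoTangent(), dA, dB
    end
    return C, mul_pullback
end
```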

This is not how Tullio's gradients work. Instead, like the forward pass, it always writes one loop nest, writing into `𝛥A` and `𝛥B` simultaneously. This limits parallelism (it can safely multi-thread only over `j`) and seems to play less well with LoopVectorization.jl. Unfortunately, changing this would be quite messy, and I'm unlikely to get around to it.


Hi @mcabbott ,
So, a more complicated contraction is, for instance, the following (I am also trying the package you suggested):

```julia
using TensorOperations

function tullio_conv(W, v)
    return @tullio C[i,k] := W[i,j,k,l] * v[j,l]
end

function tensor_conv(W, v)
    return @tensor C[i,k] := W[i,j,k,l] * v[j,l]
end
```

```julia
W = rand(2, 2, 37, 1400)
x = rand(2, 1400)
@benchmark sum(tullio_conv(W, x))

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  36.650 μs … 147.037 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     37.334 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   38.074 μs ±   2.253 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  [histogram omitted]
  36.6 μs         Histogram: frequency by time         45.3 μs <

 Memory estimate: 688 bytes, allocs estimate: 2.

@benchmark sum(tensor_conv(W, x))

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  141.516 μs …   3.715 ms  ┊ GC (min … max):  0.00% … 92.18%
 Time  (median):     203.635 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   297.734 μs ± 433.301 μs  ┊ GC (mean ± σ):  26.07% ± 15.99%

  [histogram omitted]
  142 μs        Histogram: log(frequency) by time       2.61 ms <

 Memory estimate: 1.59 MiB, allocs estimate: 137.
```

Regarding AD, only `Tullio` seems to work:

```julia
@benchmark gradient(W -> sum(tullio_conv(W, x)), W)

BenchmarkTools.Trial: 8735 samples with 1 evaluation.
 Range (min … max):  406.598 μs …   4.155 ms  ┊ GC (min … max):  0.00% … 86.13%
 Time  (median):     467.586 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   570.568 μs ± 404.231 μs  ┊ GC (mean ± σ):  12.25% ± 13.93%

  [histogram omitted]
  407 μs        Histogram: log(frequency) by time       2.65 ms <

 Memory estimate: 1.61 MiB, allocs estimate: 59.
```
Error stacktrace for TensorOperations

MethodError: no method matching StridedViews.StridedView(::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})

Closest candidates are:
StridedViews.StridedView(::PermutedDimsArray{T, N, P}) where {T, N, P}
@ StridedViews ~/.julia/packages/StridedViews/dcnHM/src/stridedview.jl:51
StridedViews.StridedView(::Base.ReshapedArray)
@ StridedViews ~/.julia/packages/StridedViews/dcnHM/src/stridedview.jl:50
StridedViews.StridedView(::SubArray)
@ StridedViews ~/.julia/packages/StridedViews/dcnHM/src/stridedview.jl:49
β¦

Stacktrace:
[1] tensorcontract!(C::Array{Float64, 4}, pC::Tuple{NTuple{4, Int64}, Tuple{}}, A::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::Matrix{Float64}, pB::Tuple{Tuple{}, Tuple{Int64, Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero, #unused#::TensorOperations.Backend{:StridedBLAS})
@ TensorOperations ~/.julia/packages/TensorOperations/7VyQe/src/implementation/abstractarray.jl:63
[2] tensorcontract!(C::Array{Float64, 4}, pC::Tuple{NTuple{4, Int64}, Tuple{}}, A::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::Matrix{Float64}, pB::Tuple{Tuple{}, Tuple{Int64, Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero)
@ TensorOperations ~/.julia/packages/TensorOperations/7VyQe/src/implementation/abstractarray.jl:35
[3] (::TensorOperationsChainRulesCoreExt.var"#58#65"{Tuple{Tuple{Int64, Int64}, Tuple{}}, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Array{Float64, 4}, Tuple{Tuple{Int64, Int64}, Tuple{Int64, Int64}}, Symbol, Matrix{Float64}, Tuple{Tuple{Int64, Int64}, Tuple{}}, Symbol, VectorInterface.One, Tuple{}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, NTuple{4, Base.OneTo{Int64}}}}}})()
@ TensorOperationsChainRulesCoreExt ~/.julia/packages/TensorOperations/7VyQe/ext/TensorOperationsChainRulesCoreExt.jl:99
[4] unthunk
@ ~/.julia/packages/ChainRulesCore/7MWx2/src/tangent_types/thunks.jl:204 [inlined]
[5] wrap_chainrules_output
@ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:110 [inlined]
[6] map (repeats 4 times)
@ ./tuple.jl:276 [inlined]
[7] wrap_chainrules_output
@ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:111 [inlined]
[8] ZBack
@ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:211 [inlined]
[9] Pullback
@ ./In[35]:25 [inlined]
[10] (::Zygote.Pullback{Tuple{typeof(tensor_conv), Array{Float64, 4}, Matrix{Float64}}, …})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[11] Pullback
@ ./In[38]:1 [inlined]
[12] (::Zygote.Pullback{Tuple{var"#47#48", Array{Float64, 4}}, …})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[13] (::Zygote.var"#75#76"{…})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:97
[15] var"##core#1252"()
@ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:489
[16] var"##sample#1253"(::Tuple{}, __params::BenchmarkTools.Parameters)
@ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:495
[17] _run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; verbose::Bool, pad::String, kwargs::Base.Pairs{Symbol, Integer, NTuple{4, Symbol}, NamedTuple{(:samples, :evals, :gctrial, :gcsample), Tuple{Int64, Int64, Bool, Bool}}})
@ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:99
[18] #invokelatest#2
@ ./essentials.jl:821 [inlined]
[19] invokelatest
@ ./essentials.jl:816 [inlined]
[20] #run_result#45
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]
[21] run_result
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]
[22] run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, kwargs::Base.Pairs{Symbol, Integer, NTuple{5, Symbol}, NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample), Tuple{Bool, Int64, Int64, Bool, Bool}}})
@ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117
[23] run (repeats 2 times)
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117 [inlined]
[24] #warmup#54
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:169 [inlined]
[25] warmup(item::BenchmarkTools.Benchmark)
@ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:168
[26] top-level scope
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:393

From your message, I understand that I am not doing anything wrong. Is this right, or am I missing something?

The error might be avoided by using `sum(abs2, ...)` instead of `sum`.

Zygote's rule for `sum` uses FillArrays, which this PR may remove. It's an optimisation to save a small allocation, and it causes all kinds of issues (normally in toy problems like this). Here `@tensor` doesn't understand this kind of lazy array; see e.g. this question.


Thanks, that fixed it.
It is still slower than `Tullio`'s gradient, though.

```julia
@benchmark gradient(W -> sum(abs2, tullio_conv(W, x)), W)

BenchmarkTools.Trial: 3244 samples with 1 evaluation.
 Range (min … max):  966.187 μs …  17.276 ms  ┊ GC (min … max):  0.00% …  0.00%
 Time  (median):       1.160 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):     1.538 ms ± 903.611 μs  ┊ GC (mean ± σ):  21.29% ± 22.59%

  [histogram omitted]
  966 μs        Histogram: log(frequency) by time       3.88 ms <

 Memory estimate: 7.99 MiB, allocs estimate: 624.
```

Regarding the performance, is there anything (in principle) that might be done to improve it?
If the problem is "just" with parallelization, I might still be happy (if it doesn't parallelize, I can simply run more chains in parallel), but I would like to understand whether there is anything I can do better.

Ok. `@tensor` needs `permutedims` here, which is quite expensive. Things you might try are (1) writing `rrule`s for Tullio, and, perhaps more importantly, (2) re-ordering array indices to minimise permutations.
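For point (2), one possible reordering for the contraction above (a hypothetical sketch; whether it helps depends on the sizes and should be benchmarked):

```julia
using TensorOperations

# Group the open indices (i,k) and the contracted indices (j,l) contiguously,
# so @tensor can reshape W into a matrix and hit BLAS without an expensive
# permutedims on every call.
Wp = permutedims(W, (1, 3, 2, 4))   # W[i,j,k,l] -> Wp[i,k,j,l], done once
tensor_conv2(Wp, v) = @tensor C[i,k] := Wp[i,k,j,l] * v[j,l]
```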

Answering later with more precise benchmarks.

1. Array index reordering puts `TensorOperations.jl` at the same level as `Tullio` for the last example.
2. Adding `Zygote.@adjoint` rules improves performance.

For the first scenario (the standard matrix multiplication) I get a considerable speedup.
Without adding the rule, I have

```julia
@benchmark gradient(W -> sum(abs2, mul(W,x)), W)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   73.643 μs …  16.139 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     161.584 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   175.547 μs ± 237.904 μs  ┊ GC (mean ± σ):  7.80% ± 7.28%

  [histogram omitted]
  73.6 μs          Histogram: frequency by time          245 μs <

 Memory estimate: 320.59 KiB, allocs estimate: 160.
```

With the rule, I get

```julia
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  22.003 μs …   2.901 ms  ┊ GC (min … max):  0.00% … 94.00%
 Time  (median):     52.789 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   59.898 μs ± 122.754 μs  ┊ GC (mean ± σ):  11.28% ±  5.44%

  [histogram omitted]
  22 μs           Histogram: frequency by time          102 μs <

 Memory estimate: 160.44 KiB, allocs estimate: 80.
```

I will do the same later for the weirder contraction.
So, the lessons learnt are:

1. Remember the performance-tips section of the Julia docs (array index ordering matters).
2. Write custom rules to boost performance.

Will update later with the more complex scenario.
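For reference, the kind of rule meant in point 2 can be sketched like this (real-valued case; untested, and it overrides the gradient Tullio would otherwise derive):

```julia
using Zygote, Tullio, LoopVectorization

mul(A, B) = @tullio C[i,k] := A[i,j] * B[j,k]

# Each gradient gets its own Tullio expression, hence its own loop nest.
Zygote.@adjoint function mul(A, B)
    C = mul(A, B)
    back(dC) = (@tullio(dA[i,j] := dC[i,k] * B[j,k]),
                @tullio(dB[j,k] := dC[i,k] * A[i,j]))
    return C, back
end
```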


Itβs not productive/human time efficient to need to implement all these rules, but it is a fairly reliable way to get good performance.
SimpleChains.jl does this to hit its performance targets.

LoopVectorization.jl should automatically split those into separate loops. If you have an example where it fails to do that, you could share it.
Although I'm unlikely to get around to addressing it, I can answer questions and point someone toward how to figure out what is going on, why, and how to fix it.