Array Contraction, LoopVectorization & AD

Hi everyone,
I have been using LoopVectorization.jl with very satisfactory results to speed up some calculations.
Now I have hit a problem: combining it with AD!
Although my use case is not super complicated (it is just an array contraction), I am not able to differentiate through it.
I have read the Zygote documentation, so I am aware of its limitations when dealing with mutating functions.
I will consider the simpler case of a matrix multiplication, but the same applies to more complex scenarios (the one I actually care about).

For instance, let us consider the following function, which I just copied from the LoopVectorization docs:

using LoopVectorization

function mygemmavx!(C, A, B)
    @turbo for m ∈ axes(A,1), n ∈ axes(B,2)  # @turbo vectorizes the whole loop nest
        Cmn = zero(eltype(C))
        for k ∈ axes(A,2)
            Cmn += A[m,k] * B[k,n]
        end
        C[m,n] = Cmn
    end
end
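It is called with a preallocated output, e.g.

A = rand(100, 100); B = rand(100, 100);
C = similar(A, size(A, 1), size(B, 2))  # preallocate the result
mygemmavx!(C, A, B)                     # fills C in place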

I tried an allocating mygemmavx built on Zygote.Buffer, but it is several orders of magnitude slower.
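(A rough reconstruction of what that attempt looks like; note that @turbo cannot be applied to a Buffer, and Zygote differentiates each indexing operation individually, which is presumably where the slowdown comes from.)

using Zygote

function mygemm_buffer(A, B)
    C = Zygote.Buffer(A, size(A, 1), size(B, 2))  # AD-friendly mutable container
    for m in axes(A, 1), n in axes(B, 2)
        Cmn = zero(eltype(A))
        for k in axes(A, 2)
            Cmn += A[m, k] * B[k, n]
        end
        C[m, n] = Cmn
    end
    return copy(C)  # copy(::Buffer) freezes it back into an ordinary array
end

gradient(A -> sum(mygemm_buffer(A, B)), A) then works, but slowly.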
I then tried Tullio.jl. Although it is a bit slower than LoopVectorization, it works nicely with gradients!

using Tullio, LoopVectorization, Zygote, BenchmarkTools
mul(A, B) = @tullio C[i,k] := A[i,j] * B[j,k]
W = rand(100, 100); x = rand(100,100);
@benchmark sum(mul(W,x))

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  19.080 μs …   4.181 ms  ┊ GC (min … max):  0.00% … 95.39%
 Time  (median):     43.110 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   47.226 μs ± 140.766 μs  ┊ GC (mean ± σ):  11.02% ±  3.68%

  19.1 μs         Histogram: frequency by time         63.5 μs <

 Memory estimate: 80.72 KiB, allocs estimate: 52.
@benchmark gradient(W -> sum(mul(W,x)), W)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   74.694 μs …  16.137 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     148.622 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   161.961 μs ± 292.893 μs  ┊ GC (mean ± σ):  9.82% ± 6.28%

  74.7 μs          Histogram: frequency by time          321 μs <

 Memory estimate: 242.59 KiB, allocs estimate: 159.

While the standard * gave:

@benchmark gradient(W -> sum(W*x), W)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  41.687 μs …   3.389 ms  ┊ GC (min … max):  0.00% … 96.11%
 Time  (median):     51.237 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   71.793 μs ± 206.830 μs  ┊ GC (mean ± σ):  18.93% ±  6.48%

  41.7 μs       Histogram: log(frequency) by time       236 μs <

 Memory estimate: 236.11 KiB, allocs estimate: 35.

Tullio is within a factor of two of the standard * (which is actually great).
For more complicated contractions Tullio still works, but the gradient computation gets around 15 times slower than the forward pass.
So, my question is: is this currently the best we can do in Julia to contract arrays quickly and obtain gradients? Is there an option I am missing, or am I doing something wrong? Should I implement custom rules? (For the weird tensor contraction they should be straightforward to derive, especially following the guide in ChainRules.)

This is used within Turing.jl and I want to do HMC, so quick gradients are a requirement for good computational performance.
(Tagging @Elrod @mcabbott, but insights from anyone are very welcome.)
Thank you in advance,
Marco


Do you have some examples of your complicated contractions? If they are allowed by TensorOperations.jl, then you should also try that. It now has built-in support for taking gradients.
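For example, a minimal sketch of using that gradient support (the names here are mine, just for illustration):

using TensorOperations, Zygote

f(A, B) = @tensor C[i,k] := A[i,j] * B[j,k]
A, B = rand(4, 4), rand(4, 4)
gradient((A, B) -> sum(abs2, f(A, B)), A, B)  # differentiates via the ChainRulesCore extension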

Otherwise your best bet is probably to write gradient rules, which themselves call Tullio. The macro can derive the expressions you need:

julia> @tullio C[i,k] := A[i,j] * B[j,k] verbose=1
┌ Info: symbolic gradients
│   inbody =
│    2-element Vector{Any}:
│     :(𝛥A[i, j] = 𝛥A[i, j] + 𝛥ℛ[i, k] * conj(B[j, k]))
└     :(𝛥B[j, k] = 𝛥B[j, k] + 𝛥ℛ[i, k] * conj(A[i, j]))

Then your rrule will contain something like back(dC) = (NoTangent(), @tullio(dA[i, j] := dC[i, k] * conj(B[j, k])), @tullio(dB[j, k] := dC[i, k] * conj(A[i, j]))).
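Spelled out, a minimal version of such a rule might look like this (a sketch assuming the mul defined above; not code shipped by any package):

using ChainRulesCore, Tullio, LoopVectorization

mul(A, B) = @tullio C[i,k] := A[i,j] * B[j,k]

function ChainRulesCore.rrule(::typeof(mul), A, B)
    C = mul(A, B)
    function mul_pullback(dC)
        # one independent @tullio loop nest per input gradient
        @tullio dA[i,j] := dC[i,k] * conj(B[j,k])
        @tullio dB[j,k] := dC[i,k] * conj(A[i,j])
        return (NoTangent(), dA, dB)
    end
    return C, mul_pullback
end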

This is not how Tullio’s own gradients work. Instead, like the forward pass, it always writes one loop nest, writing into 𝛥A and 𝛥B simultaneously. This limits parallelism (it can safely multi-thread only over j) and seems to play less well with LoopVectorization.jl. Unfortunately changing this would be quite messy, and I’m unlikely to get around to it.


Hi @mcabbott,
Thank you for your answer.
A more complicated contraction is, for instance, the following (I am also trying the package you suggested):

using Tullio, TensorOperations

function tullio_conv(W, v)
    return @tullio C[i,k] := W[i,j,k,l] * v[j,l]
end

function tensor_conv(W, v)
    return @tensor C[i,k] := W[i,j,k,l] * v[j,l]
end

The benchmark reads

W = rand(2, 2, 37, 1400)
x = rand(2, 1400)
@benchmark sum(tullio_conv(W, x))

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  36.650 μs … 147.037 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     37.334 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   38.074 μs ±   2.253 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  36.6 μs         Histogram: frequency by time         45.3 μs <

 Memory estimate: 688 bytes, allocs estimate: 2.

@benchmark sum(tensor_conv(W, x))

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  141.516 μs …   3.715 ms  ┊ GC (min … max):  0.00% … 92.18%
 Time  (median):     203.635 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   297.734 μs ± 433.301 μs  ┊ GC (mean ± σ):  26.07% ± 15.99%

  142 μs        Histogram: log(frequency) by time       2.61 ms <

 Memory estimate: 1.59 MiB, allocs estimate: 137.

Regarding AD, only Tullio seems to work:

@benchmark gradient(W -> sum(tullio_conv(W, x)), W)

BenchmarkTools.Trial: 8735 samples with 1 evaluation.
 Range (min … max):  406.598 μs …   4.155 ms  ┊ GC (min … max):  0.00% … 86.13%
 Time  (median):     467.586 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   570.568 μs ± 404.231 μs  ┊ GC (mean ± σ):  12.25% ± 13.93%

  407 μs        Histogram: log(frequency) by time       2.65 ms <

 Memory estimate: 1.61 MiB, allocs estimate: 59.
Error stacktrace for TensorOperations:

MethodError: no method matching StridedViews.StridedView(::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})

Closest candidates are:
StridedViews.StridedView(::PermutedDimsArray{T, N, P}) where {T, N, P}
@ StridedViews ~/.julia/packages/StridedViews/dcnHM/src/stridedview.jl:51
StridedViews.StridedView(::Base.ReshapedArray)
@ StridedViews ~/.julia/packages/StridedViews/dcnHM/src/stridedview.jl:50
StridedViews.StridedView(::SubArray)
@ StridedViews ~/.julia/packages/StridedViews/dcnHM/src/stridedview.jl:49
…

Stacktrace:
[1] tensorcontract!(C::Array{Float64, 4}, pC::Tuple{NTuple{4, Int64}, Tuple{}}, A::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::Matrix{Float64}, pB::Tuple{Tuple{}, Tuple{Int64, Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero, #unused#::TensorOperations.Backend{:StridedBLAS})
@ TensorOperations ~/.julia/packages/TensorOperations/7VyQe/src/implementation/abstractarray.jl:63
[2] tensorcontract!(C::Array{Float64, 4}, pC::Tuple{NTuple{4, Int64}, Tuple{}}, A::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, pA::Tuple{Tuple{Int64, Int64}, Tuple{}}, conjA::Symbol, B::Matrix{Float64}, pB::Tuple{Tuple{}, Tuple{Int64, Int64}}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero)
@ TensorOperations ~/.julia/packages/TensorOperations/7VyQe/src/implementation/abstractarray.jl:35
[3] (::TensorOperationsChainRulesCoreExt.var"#58#65"{…})()
@ TensorOperationsChainRulesCoreExt ~/.julia/packages/TensorOperations/7VyQe/ext/TensorOperationsChainRulesCoreExt.jl:99
[4] unthunk
@ ~/.julia/packages/ChainRulesCore/7MWx2/src/tangent_types/thunks.jl:204 [inlined]
[5] wrap_chainrules_output
@ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:110 [inlined]
[6] map (repeats 4 times)
@ ./tuple.jl:276 [inlined]
[7] wrap_chainrules_output
@ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:111 [inlined]
[8] ZBack
@ ~/.julia/packages/Zygote/YYT6v/src/compiler/chainrules.jl:211 [inlined]
[9] Pullback
@ ./In[35]:25 [inlined]
[10] (::Zygote.Pullback{Tuple{typeof(tensor_conv), Array{Float64, 4}, Matrix{Float64}}, …})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[11] Pullback
@ ./In[38]:1 [inlined]
[12] (::Zygote.Pullback{Tuple{var"#47#48", Array{Float64, 4}}, …})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface2.jl:0
[13] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#47#48", Array{Float64, 4}}, …}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45
[14] gradient(f::Function, args::Array{Float64, 4})
@ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:97
[15] var"##core#1252"()
@ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:489
[16] var"##sample#1253"(::Tuple{}, __params::BenchmarkTools.Parameters)
@ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:495
[17] _run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; verbose::Bool, pad::String, kwargs::Base.Pairs{Symbol, Integer, NTuple{4, Symbol}, NamedTuple{(:samples, :evals, :gctrial, :gcsample), Tuple{Int64, Int64, Bool, Bool}}})
@ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:99
[18] #invokelatest#2
@ ./essentials.jl:821 [inlined]
[19] invokelatest
@ ./essentials.jl:816 [inlined]
[20] #run_result#45
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]
[21] run_result
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]
[22] run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, kwargs::Base.Pairs{Symbol, Integer, NTuple{5, Symbol}, NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample), Tuple{Bool, Int64, Int64, Bool, Bool}}})
@ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117
[23] run (repeats 2 times)
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117 [inlined]
[24] #warmup#54
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:169 [inlined]
[25] warmup(item::BenchmarkTools.Benchmark)
@ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:168
[26] top-level scope
@ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:393

From your message, I understand that I am not doing anything wrong. Is that right, or am I missing something?

The error might be avoided by sum(abs2, ...) instead of sum.

Zygote’s rule for sum uses FillArrays, which this PR may remove. It’s an optimisation to save a little allocation which causes all kinds of issues (normally in toy problems like this). Here @tensor doesn’t understand this type of fake array, see e.g. this question.
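For example, the seed that sum’s pullback produces can be inspected directly (on Zygote versions that still use FillArrays here):

using Zygote, FillArrays

y, back = Zygote.pullback(sum, rand(3, 3))
back(1.0)[1]  # a lazy Fill(1.0, 3, 3), not a Matrix; @tensor has no method for it

With sum(abs2, C) the pullback seed is built from C itself, so the cotangent reaching @tensor is an ordinary dense array.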


Thanks, now it is fixed.
It is slower than Tullio’s gradient:

@benchmark gradient(W -> sum(abs2, tullio_conv(W, x)), W)

BenchmarkTools.Trial: 3244 samples with 1 evaluation.
 Range (min … max):  966.187 μs …  17.276 ms  ┊ GC (min … max):  0.00% …  0.00%
 Time  (median):       1.160 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):     1.538 ms ± 903.611 μs  ┊ GC (mean ± σ):  21.29% ± 22.59%

  966 μs        Histogram: log(frequency) by time       3.88 ms <

 Memory estimate: 7.99 MiB, allocs estimate: 624.

Regarding the performance, is there anything (in principle) that might be done to improve it?
If the problem is β€œjust” with parallelization, I might still be happy (if it doesn’t parallelize, I can just run more chains in parallel), but I would like to understand whether there is anything I can do better.

Ok. @tensor needs permutedims here, which is quite expensive. Things you might try are (1) writing rrules for Tullio and, perhaps a bigger win, (2) re-ordering array indices to minimise permutations.
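On (2), a rough sketch of the idea (my illustration: Julia arrays are column-major, so moving the contracted indices j and l to the front should let @tensor reshape without an internal permutedims):

using TensorOperations

W = rand(2, 2, 37, 1400); x = rand(2, 1400);

Wp = permutedims(W, (2, 4, 1, 3))  # W[i,j,k,l] -> Wp[j,l,i,k], done once up front

tensor_conv_p(Wp, v) = @tensor C[i,k] := Wp[j,l,i,k] * v[j,l]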

I will answer with more precise benchmarks later; in short:

  1. Array re-ordering puts TensorOperations.jl at the same level as Tullio for the last example.
  2. Adding Zygote.@adjoint rules improves performance (see the sketch after this list).
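Such a rule can be written analogously to the rrule sketched earlier, via Zygote.@adjoint (again a sketch, not necessarily the exact rule behind the numbers below):

using Zygote, Tullio, LoopVectorization

Zygote.@adjoint function mul(A, B)
    C = mul(A, B)  # plain forward pass
    back(dC) = (@tullio(dA[i,j] := dC[i,k] * conj(B[j,k])),
                @tullio(dB[j,k] := dC[i,k] * conj(A[i,j])))
    return C, back
end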

For the first scenario (the standard matrix multiplication) I get a considerable speedup.
Without adding the rule, I have

@benchmark gradient(W -> sum(abs2, mul(W,x)), W)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   73.643 μs …  16.139 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     161.584 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   175.547 μs ± 237.904 μs  ┊ GC (mean ± σ):  7.80% ± 7.28%

  73.6 μs          Histogram: frequency by time          245 μs <

 Memory estimate: 320.59 KiB, allocs estimate: 160.

After adding the rule:

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  22.003 μs …   2.901 ms  ┊ GC (min … max):  0.00% … 94.00%
 Time  (median):     52.789 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   59.898 μs ± 122.754 μs  ┊ GC (mean ± σ):  11.28% ±  5.44%

  22 μs           Histogram: frequency by time          102 μs <

 Memory estimate: 160.44 KiB, allocs estimate: 80.

So, the lessons learned are:

  1. Remember the performance tips section of the Julia docs (array index ordering matters).
  2. Write custom rules to boost performance.

Will update later with the more complex scenario.


It’s not productive or time-efficient to need to implement all these rules, but it is a fairly reliable way to get good performance.
SimpleChains.jl does this to hit its performance targets.

LoopVectorization.jl should automatically be splitting those into separate loops. If you have an example where it fails to do that, you could share it.
Although I’m unlikely to get around to addressing it myself, I could answer questions and point someone toward how to figure out what is going on, why, and how to fix it.