# Automatic differentiation with tensor contraction

While doing automatic differentiation I ran into some problems. First, let's look at my examples:

Running Julia 1.9.3:

``````julia
julia> using TensorRules

julia> using Zygote

julia> using TensorOperations

julia> using LinearAlgebra

julia> @∇ function f1(x)
           x1 = reshape(x[1:18], (3,3,2))
           x2 = reshape(x[19:36], (3,3,2))
           @tensor x3[a,d] := x1[a,b,c] * x2[d,b,c]
           return tr(x3)
       end
f1 (generic function with 1 method)

julia> @∇ function f2(x)
           x1 = reshape(x[1:18], (3,3,2))
           x2 = reshape(x[19:36], (3,3,2))
           x3 = ncon((x1, x2), ([-1,1,2], [-2,1,2]))
           return tr(x3)
       end
f2 (generic function with 1 method)
julia> x0=rand(36)
36-element Vector{Float64}:
0.7403090894287457
0.04934650826409148
0.8207837917971361
0.7327209084655903
0.6350406011661587
⋮
0.8111887446781879
0.030783329577497076
0.7872188382518067
0.5265345127135576
0.0943159141149933

julia> @show  f1(x0) f2(x0)
f1(x0) = 4.539496023288027
f2(x0) = 4.539496023288027
4.539496023288027

julia> @show (gradient(f1, x0))[1]
(gradient(f1, x0))[1] = [0.8898426474337621, 0.2897372497786118, 0.9810211091684449, 0.09345708511644779, 0.23858447792575943, 0.24385079572202506, 0.1828454340913519, 0.7635199630683113, 0.05167534139207508, 0.5915424957033092, 0.11155341718052214, 0.2926138350003188, 0.5980007034355839, 0.8111887446781879, 0.030783329577497076, 0.7872188382518067, 0.5265345127135576, 0.0943159141149933, 0.7403090894287457, 0.04934650826409148, 0.8207837917971361, 0.7327209084655903, 0.6350406011661587, 0.6029706814690249, 0.6208911879774303, 0.9034078024769735, 0.08319290996558881, 0.7910177397084297, 0.4126201610953655, 0.23032762918277405, 0.13805554898704864, 0.03284589370750646, 0.6441741642051572, 0.9784624449592936, 0.6860065739408643, 0.47463419102934457]
36-element Vector{Float64}:
0.8898426474337621
0.2897372497786118
0.9810211091684449
0.09345708511644779
0.23858447792575943
⋮
0.03284589370750646
0.6441741642051572
0.9784624449592936
0.6860065739408643
0.47463419102934457

julia> @show (gradient(f2, x0))[1]
ERROR: MethodError: no method matching adjoint(::Zygote.Context{false}, ::typeof(sort), ::Vector{Int64}; rev::Bool)

Closest candidates are:
adjoint(::ZygoteRules.AContext, ::typeof(sort), ::AbstractArray; by) got unsupported keyword argument "rev"
@ Zygote none:0
adjoint(::ZygoteRules.AContext, ::Base.Fix2, ::Any) got unsupported keyword argument "rev"
@ Zygote none:0
adjoint(::ZygoteRules.AContext, ::Base.Fix1, ::Any) got unsupported keyword argument "rev"
@ Zygote none:0
...

Stacktrace:
[1] kwerr(::NamedTuple{(:rev,), Tuple{Bool}}, ::Function, ::Zygote.Context{false}, ::Function, ::Vector{Int64})
@ Base ./error.jl:165
@ ./none:0 [inlined]
[3] _pullback
[4] _pullback
@ ~/.julia/packages/TensorOperations/QcEK4/src/functions/ncon.jl:44 [inlined]
[5] _pullback(::Zygote.Context{false}, ::TensorOperations.var"##ncon#211", ::Nothing, ::Nothing, ::typeof(ncon), ::Tuple{Array{Float64, 3}, Array{Float64, 3}}, ::Tuple{Vector{Int64}, Vector{Int64}}, ::Vector{Bool}, ::Nothing)
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[6] _pullback
@ ~/.julia/packages/TensorOperations/QcEK4/src/functions/ncon.jl:27 [inlined]
[7] _pullback(::Zygote.Context{false}, ::typeof(ncon), ::Tuple{Array{Float64, 3}, Array{Float64, 3}}, ::Tuple{Vector{Int64}, Vector{Int64}}, ::Vector{Bool}, ::Nothing)
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[8] _pullback
@ ~/.julia/packages/TensorOperations/QcEK4/src/functions/ncon.jl:27 [inlined]
[9] _pullback(::Zygote.Context{false}, ::typeof(ncon), ::Tuple{Array{Float64, 3}, Array{Float64, 3}}, ::Tuple{Vector{Int64}, Vector{Int64}})
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[10] _pullback
@ ./REPL[6]:4 [inlined]
[11] _pullback(ctx::Zygote.Context{false}, f::typeof(f2), args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[12] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:44
[13] pullback
@ ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:42 [inlined]
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:96
[15] top-level scope
@ show.jl:1128
``````

The difference between f1(x) and f2(x) is how they perform the tensor contraction: f1 uses the `@tensor` macro, while f2 calls the `ncon` function.

It seems that TensorRules.jl combined with TensorOperations.jl only works for `@tensor` expressions, not for `ncon` calls.

However, for the algorithm I am building, the `ncon` function is much more convenient.
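(For instance — an illustrative sketch, not from my actual algorithm — with `ncon` the contraction network is ordinary data, so longer chains can be written or even generated programmatically, without inventing index labels:)

``````julia
using TensorOperations  # provides ncon

A, B, C = rand(2, 3), rand(3, 4), rand(4, 2)
# The index lists are plain vectors: positive labels are contracted,
# negative labels become open indices of the result.
y = ncon((A, B, C), ([-1, 1], [1, 2], [2, -2]))
# y is equivalent to the matrix product A * B * C
``````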

So how can I modify the code so that the `ncon` function works with automatic differentiation?
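Incidentally, for this toy example the whole computation collapses to a dot product, which gives a handy package-free reference for checking gradients (`f_ref` and `grad_ref` are just illustrative names):

``````julia
# Pure-Base reference for f1/f2: tr(x3) with
# x3[a,d] = Σ_{b,c} x1[a,b,c] * x2[d,b,c] reduces to
# Σ_{a,b,c} x1[a,b,c] * x2[a,b,c], i.e. a dot product of the two halves of x,
# because reshape preserves the element order.
function f_ref(x::AbstractVector)
    s = 0.0
    for i in 1:18
        s += x[i] * x[i + 18]
    end
    return s
end

# The exact gradient is therefore just the two halves swapped,
# which matches the gradient(f1, x0) output above.
grad_ref(x) = vcat(x[19:36], x[1:18])
``````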


Not a complete answer, but TensorOperations.jl does now support gradients, as noted e.g. in this post. Trying the code without TensorRules.jl (i.e. deleting `@∇` above):

``````julia
julia> @show gradient(f1,x0)[1]
ERROR: MethodError: no method matching StridedViews.StridedView(::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
Stacktrace:
[1] tensorcontract!(C::Array{…}, pC::Tuple{…}, A::Diagonal{…}, pA::Tuple{…}, conjA::Symbol, B::Array{…}, pB::Tuple{…}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero, ::TensorOperations.Backend{…})
@ TensorOperations ~/.julia/packages/TensorOperations/7VyQe/src/implementation/diagonal.jl:32
[2] tensorcontract!(C::Array{…}, pC::Tuple{…}, A::Diagonal{…}, pA::Tuple{…}, conjA::Symbol, B::Array{…}, pB::Tuple{…}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero)
@ TensorOperations ~/.julia/packages/TensorOperations/7VyQe/src/implementation/abstractarray.jl:37
[3] (::TensorOperationsChainRulesCoreExt.var"#58#65"{…})()
@ TensorOperationsChainRulesCoreExt ~/.julia/packages/TensorOperations/7VyQe/ext/TensorOperationsChainRulesCoreExt.jl:99
``````

TensorOperationsChainRulesCoreExt is meant to support Zygote, but Strided + FillArrays seems to be a bad combination. The offending `Fill` comes from the gradient rule for `tr`, which we can avoid like so:

``````julia
julia> function f1(x)
           x1 = reshape(x[1:18], (3,3,2))
           x2 = reshape(x[19:36], (3,3,2))
           @tensor x3[a,d] := x1[a,b,c] * x2[d,b,c]
           return sum(identity, diag(x3))  # instead of tr
       end
f1 (generic function with 1 method)

julia> @show  f1(x0) f2(x0);
f1(x0) = 2.7627671592513434
f2(x0) = 2.7627671592513434

julia> @show (gradient(f1, x0))[1];
(gradient(f1, x0))[1] = [0.8092069165842924, 0.9687321783662441, 0.3597888712665691, 0.35083048773751657, 0.5084430655437714, 0.09401047700572329,
``````

Sorry for noticing this a bit late. I have opened an issue to track progress and added this to my TODO list. I will definitely add support for `ncon`, and will also have a look at whether the Strided + FillArrays problem can be handled.
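In the meantime, one possible workaround is to wrap the specific network in a helper function with a hand-written reverse rule, so Zygote never traces into `ncon` itself. This is only a sketch for the two-tensor network from f2 above (`contract_pair` is a hypothetical name, and the rule is written with `@tensor`, which already has working gradient support):

``````julia
using ChainRulesCore, TensorOperations

# Hypothetical helper wrapping the network from f2.
contract_pair(x1, x2) = ncon((x1, x2), ([-1, 1, 2], [-2, 1, 2]))

# Hand-written reverse rule: for y[a,d] = Σ_{b,c} x1[a,b,c] * x2[d,b,c],
# each adjoint is a contraction of the output cotangent ȳ with the
# other input over the matching indices.
function ChainRulesCore.rrule(::typeof(contract_pair), x1, x2)
    y = contract_pair(x1, x2)
    function contract_pair_pullback(ȳ)
        ȳm = unthunk(ȳ)
        @tensor x̄1[a, b, c] := ȳm[a, d] * x2[d, b, c]
        @tensor x̄2[d, b, c] := ȳm[a, d] * x1[a, b, c]
        return NoTangent(), x̄1, x̄2
    end
    return y, contract_pair_pullback
end
``````

The price is that each distinct network needs its own rule, which is exactly the bookkeeping that built-in `ncon` support should eventually remove.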