# Automatic differentiation with tensor contraction

While doing automatic differentiation I ran into some problems. First, let's look at my examples:

Running Julia 1.9.3:

``````julia
julia> using TensorRules

julia> using Zygote

julia> using TensorOperations

julia> using LinearAlgebra

julia> @∇ function f1(x)
           x1 = reshape(x[1:18], (3,3,2))
           x2 = reshape(x[19:36], (3,3,2))
           @tensor x3[a,d] := x1[a,b,c] * x2[d,b,c]
           return tr(x3)
       end
f1 (generic function with 1 method)

julia> @∇ function f2(x)
           x1 = reshape(x[1:18], (3,3,2))
           x2 = reshape(x[19:36], (3,3,2))
           x3 = ncon((x1, x2), ([-1,1,2], [-2,1,2]))
           return tr(x3)
       end
f2 (generic function with 1 method)
julia> x0=rand(36)
36-element Vector{Float64}:
0.7403090894287457
0.04934650826409148
0.8207837917971361
0.7327209084655903
0.6350406011661587
⋮
0.8111887446781879
0.030783329577497076
0.7872188382518067
0.5265345127135576
0.0943159141149933

julia> @show  f1(x0) f2(x0)
f1(x0) = 4.539496023288027
f2(x0) = 4.539496023288027
4.539496023288027

julia> @show (gradient(f1, x0))[1]
(gradient(f1, x0))[1] = [0.8898426474337621, 0.2897372497786118, 0.9810211091684449, 0.09345708511644779, 0.23858447792575943, 0.24385079572202506, 0.1828454340913519, 0.7635199630683113, 0.05167534139207508, 0.5915424957033092, 0.11155341718052214, 0.2926138350003188, 0.5980007034355839, 0.8111887446781879, 0.030783329577497076, 0.7872188382518067, 0.5265345127135576, 0.0943159141149933, 0.7403090894287457, 0.04934650826409148, 0.8207837917971361, 0.7327209084655903, 0.6350406011661587, 0.6029706814690249, 0.6208911879774303, 0.9034078024769735, 0.08319290996558881, 0.7910177397084297, 0.4126201610953655, 0.23032762918277405, 0.13805554898704864, 0.03284589370750646, 0.6441741642051572, 0.9784624449592936, 0.6860065739408643, 0.47463419102934457]
36-element Vector{Float64}:
0.8898426474337621
0.2897372497786118
0.9810211091684449
0.09345708511644779
0.23858447792575943
⋮
0.03284589370750646
0.6441741642051572
0.9784624449592936
0.6860065739408643
0.47463419102934457

julia> @show (gradient(f2, x0))[1]
ERROR: MethodError: no method matching adjoint(::Zygote.Context{false}, ::typeof(sort), ::Vector{Int64}; rev::Bool)

Closest candidates are:
adjoint(::ZygoteRules.AContext, ::typeof(sort), ::AbstractArray; by) got unsupported keyword argument "rev"
@ Zygote none:0
adjoint(::ZygoteRules.AContext, ::Base.Fix2, ::Any) got unsupported keyword argument "rev"
@ Zygote none:0
adjoint(::ZygoteRules.AContext, ::Base.Fix1, ::Any) got unsupported keyword argument "rev"
@ Zygote none:0
...

Stacktrace:
[1] kwerr(::NamedTuple{(:rev,), Tuple{Bool}}, ::Function, ::Zygote.Context{false}, ::Function, ::Vector{Int64})
@ Base ./error.jl:165
@ ./none:0 [inlined]
[3] _pullback
[4] _pullback
@ ~/.julia/packages/TensorOperations/QcEK4/src/functions/ncon.jl:44 [inlined]
[5] _pullback(::Zygote.Context{false}, ::TensorOperations.var"##ncon#211", ::Nothing, ::Nothing, ::typeof(ncon), ::Tuple{Array{Float64, 3}, Array{Float64, 3}}, ::Tuple{Vector{Int64}, Vector{Int64}}, ::Vector{Bool}, ::Nothing)
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[6] _pullback
@ ~/.julia/packages/TensorOperations/QcEK4/src/functions/ncon.jl:27 [inlined]
[7] _pullback(::Zygote.Context{false}, ::typeof(ncon), ::Tuple{Array{Float64, 3}, Array{Float64, 3}}, ::Tuple{Vector{Int64}, Vector{Int64}}, ::Vector{Bool}, ::Nothing)
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[8] _pullback
@ ~/.julia/packages/TensorOperations/QcEK4/src/functions/ncon.jl:27 [inlined]
[9] _pullback(::Zygote.Context{false}, ::typeof(ncon), ::Tuple{Array{Float64, 3}, Array{Float64, 3}}, ::Tuple{Vector{Int64}, Vector{Int64}})
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[10] _pullback
@ ./REPL[6]:4 [inlined]
[11] _pullback(ctx::Zygote.Context{false}, f::typeof(f2), args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[12] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:44
[13] pullback
@ ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:42 [inlined]
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:96
[15] top-level scope
@ show.jl:1128
``````

The difference between f1(x) and f2(x) is how they perform the tensor contraction: f1 uses the `@tensor` macro, while f2 calls the `ncon` function.

It seems that TensorRules.jl combined with TensorOperations.jl only works for `@tensor` expressions, not for `ncon` calls.

However, for the algorithm I am building, the `ncon` function is much more convenient.
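(For instance — an illustrative sketch, not from my actual algorithm — with `ncon` the contraction network is ordinary data, so longer chains can be written or even generated programmatically, without inventing index labels:)

``````julia
using TensorOperations  # provides ncon

A, B, C = rand(2, 3), rand(3, 4), rand(4, 2)
# The index lists are plain vectors: positive labels are contracted,
# negative labels become open indices of the result.
y = ncon((A, B, C), ([-1, 1], [1, 2], [2, -2]))
# y is equivalent to the matrix product A * B * C
``````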

So how can I modify the code so that the `ncon` function works with automatic differentiation?
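Incidentally, for this toy example the whole computation collapses to a dot product, which gives a handy package-free reference for checking gradients (`f_ref` and `grad_ref` are just illustrative names):

``````julia
# Pure-Base reference for f1/f2: tr(x3) with
# x3[a,d] = Σ_{b,c} x1[a,b,c] * x2[d,b,c] reduces to
# Σ_{a,b,c} x1[a,b,c] * x2[a,b,c], i.e. a dot product of the two halves of x,
# because reshape preserves the element order.
function f_ref(x::AbstractVector)
    s = 0.0
    for i in 1:18
        s += x[i] * x[i + 18]
    end
    return s
end

# The exact gradient is therefore just the two halves swapped,
# which matches the gradient(f1, x0) output above.
grad_ref(x) = vcat(x[19:36], x[1:18])
``````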


Not a complete answer, but TensorOperations.jl does now support gradients, as noted e.g. in this post. Trying the code without TensorRules.jl (i.e. deleting `@∇` above):

``````julia
julia> @show gradient(f1,x0)[1]
ERROR: MethodError: no method matching StridedViews.StridedView(::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
Stacktrace:
[1] tensorcontract!(C::Array{…}, pC::Tuple{…}, A::Diagonal{…}, pA::Tuple{…}, conjA::Symbol, B::Array{…}, pB::Tuple{…}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero, ::TensorOperations.Backend{…})
@ TensorOperations ~/.julia/packages/TensorOperations/7VyQe/src/implementation/diagonal.jl:32
[2] tensorcontract!(C::Array{…}, pC::Tuple{…}, A::Diagonal{…}, pA::Tuple{…}, conjA::Symbol, B::Array{…}, pB::Tuple{…}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero)
@ TensorOperations ~/.julia/packages/TensorOperations/7VyQe/src/implementation/abstractarray.jl:37
[3] (::TensorOperationsChainRulesCoreExt.var"#58#65"{…})()
@ TensorOperationsChainRulesCoreExt ~/.julia/packages/TensorOperations/7VyQe/ext/TensorOperationsChainRulesCoreExt.jl:99
``````

TensorOperationsChainRulesCoreExt is meant to support Zygote, but Strided + FillArrays seems to be a bad combination. The offending `Fill` comes from the gradient rule for `tr`, which we can avoid like so:

``````julia
julia> function f1(x)
           x1 = reshape(x[1:18], (3,3,2))
           x2 = reshape(x[19:36], (3,3,2))
           @tensor x3[a,d] := x1[a,b,c] * x2[d,b,c]
           return sum(identity, diag(x3))  # instead of tr
       end
f1 (generic function with 1 method)

julia> @show  f1(x0) f2(x0);
f1(x0) = 2.7627671592513434
f2(x0) = 2.7627671592513434

julia> @show (gradient(f1, x0))[1];
(gradient(f1, x0))[1] = [0.8092069165842924, 0.9687321783662441, 0.3597888712665691, 0.35083048773751657, 0.5084430655437714, 0.09401047700572329,
``````

Sorry for noticing this a bit late. I have opened an issue to track progress and added this to my TODO list. I will definitely add support for `ncon`, and will also have a look at whether the Strided + FillArrays problem can be handled.
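In the meantime, one possible workaround is to wrap the specific network in a helper function with a hand-written reverse rule, so Zygote never traces into `ncon` itself. This is only a sketch for the two-tensor network from f2 above (`contract_pair` is a hypothetical name, and the rule is written with `@tensor`, which already has working gradient support):

``````julia
using ChainRulesCore, TensorOperations

# Hypothetical helper wrapping the network from f2.
contract_pair(x1, x2) = ncon((x1, x2), ([-1, 1, 2], [-2, 1, 2]))

# Hand-written reverse rule: for y[a,d] = Σ_{b,c} x1[a,b,c] * x2[d,b,c],
# each adjoint is a contraction of the output cotangent ȳ with the
# other input over the matching indices.
function ChainRulesCore.rrule(::typeof(contract_pair), x1, x2)
    y = contract_pair(x1, x2)
    function contract_pair_pullback(ȳ)
        ȳm = unthunk(ȳ)
        @tensor x̄1[a, b, c] := ȳm[a, d] * x2[d, b, c]
        @tensor x̄2[d, b, c] := ȳm[a, d] * x1[a, b, c]
        return NoTangent(), x̄1, x̄2
    end
    return y, contract_pair_pullback
end
``````

The price is that each distinct network needs its own rule, which is exactly the bookkeeping that built-in `ncon` support should eventually remove.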