Automatic differentiation with tensor contraction

Not a complete answer, but TensorOperations.jl now supports gradients through its ChainRulesCore extension, as noted e.g. in this post. Trying the code without TensorRules.jl (i.e. deleting the @∇ above) gives:

```
julia> @show gradient(f1,x0)[1]
ERROR: MethodError: no method matching StridedViews.StridedView(::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
Stacktrace:
  [1] tensorcontract!(C::Array{…}, pC::Tuple{…}, A::Diagonal{…}, pA::Tuple{…}, conjA::Symbol, B::Array{…}, pB::Tuple{…}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero, ::TensorOperations.Backend{…})
    @ TensorOperations ~/.julia/packages/TensorOperations/7VyQe/src/implementation/diagonal.jl:32
  [2] tensorcontract!(C::Array{…}, pC::Tuple{…}, A::Diagonal{…}, pA::Tuple{…}, conjA::Symbol, B::Array{…}, pB::Tuple{…}, conjB::Symbol, α::VectorInterface.One, β::VectorInterface.Zero)
    @ TensorOperations ~/.julia/packages/TensorOperations/7VyQe/src/implementation/abstractarray.jl:37
  [3] (::TensorOperationsChainRulesCoreExt.var"#58#65"{…})()
    @ TensorOperationsChainRulesCoreExt ~/.julia/packages/TensorOperations/7VyQe/ext/TensorOperationsChainRulesCoreExt.jl:99
```

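TensorOperationsChainRulesCoreExt is meant to support Zygote, but Strided + FillArrays seems to be a bad combination: a Fill stores a single value rather than strided memory, so there is no StridedView method for it. The Fill is produced by some gradient rules (judging by the Diagonal argument in the stacktrace, presumably the one for tr), and we can avoid it like so: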

```
julia> function f1(x)
           x1 = reshape(x[1:18],(3,3,2))
           x2 = reshape(x[19:36],(3,3,2))
           @tensor x3[a,d] := x1[a,b,c] * x2[d,b,c]
           return sum(identity, diag(x3))  # instead of tr
       end
f1 (generic function with 1 method)

julia> @show f1(x0) f2(x0);
f1(x0) = 2.7627671592513434
f2(x0) = 2.7627671592513434

julia> @show gradient(f1,x0)[1];
(gradient(f1, x0))[1] = [0.8092069165842924, 0.9687321783662441, 0.3597888712665691, 0.35083048773751657, 0.5084430655437714, 0.09401047700572329,
```
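As a sanity check on that gradient which doesn't depend on Zygote: the contraction followed by a trace is just an inner product, tr(x3) = Σ_{a,b,c} x1[a,b,c]·x2[a,b,c] = dot(x[1:18], x[19:36]), so the exact gradient is the two halves of x swapped. The sketch below uses only the LinearAlgebra stdlib; f1_plain, grad_analytic and grad_fd are hypothetical helper names, and x0 = rand(36) is an assumption since the thread's x0 isn't shown here.

```julia
using LinearAlgebra  # stdlib: dot, tr

# Hypothetical plain-Julia stand-in for f1 (no TensorOperations):
# x3[a,d] = Σ_{b,c} x1[a,b,c] * x2[d,b,c], followed by tr(x3).
function f1_plain(x)
    x1 = reshape(x[1:18], (3, 3, 2))
    x2 = reshape(x[19:36], (3, 3, 2))
    # Merge (b,c) into one index so the contraction becomes a matrix product.
    x3 = reshape(x1, 3, 6) * transpose(reshape(x2, 3, 6))
    return tr(x3)
end

# Closed form: tr(x3) == dot(x[1:18], x[19:36]), so the gradient swaps the halves.
grad_analytic(x) = vcat(x[19:36], x[1:18])

# Central finite differences as an independent check.
function grad_fd(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

x0 = rand(36)  # assumption: any length-36 vector works
@show f1_plain(x0) ≈ dot(x0[1:18], x0[19:36])
@show maximum(abs, grad_fd(f1_plain, x0) - grad_analytic(x0))
```

With Zygote and TensorOperations loaded, the corresponding check for the @tensor version would be gradient(f1, x0)[1] ≈ vcat(x0[19:36], x0[1:18]); I haven't run that here.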

```
julia> @show gradient(f2,x0)[1]
ERROR: MethodError: no method matching adjoint(::Zygote.Context{false}, ::typeof(sort), ::Vector{Int64}; rev::Bool)
```

So the ncon form (f2) still doesn't work: Zygote's adjoint for sort doesn't accept the rev keyword. I suggest opening an issue on TensorOperations.jl about this, since it seems the intention is to support derivatives.

(And perhaps a second issue there about the Strided + FillArrays problem.)
