While doing automatic differentiation I encountered some problems. Let's look at my examples first:
Julia Version 1.9.3
julia> using TensorRules
julia> using Zygote
julia> using TensorOperations
julia> using LinearAlgebra
julia> @∇ function f1(x)
x1 = reshape(x[1:18],(3,3,2))
x2 = reshape(x[19:36],(3,3,2))
@tensor x3[a,d] := x1[a,b,c] * x2[d,b,c]
return tr(x3)
end
f1 (generic function with 1 method)
julia> @∇ function f2(x)
x1 = reshape(x[1:18],(3,3,2))
x2 = reshape(x[19:36],(3,3,2))
@tensor x3 = ncon((x1,x2),([-1,1,2],[-2,1,2]))
return tr(x3)
end
f2 (generic function with 1 method)
julia> x0=rand(36)
36-element Vector{Float64}:
0.7403090894287457
0.04934650826409148
0.8207837917971361
0.7327209084655903
0.6350406011661587
⋮
0.8111887446781879
0.030783329577497076
0.7872188382518067
0.5265345127135576
0.0943159141149933
julia> @show f1(x0) f2(x0)
f1(x0) = 4.539496023288027
f2(x0) = 4.539496023288027
4.539496023288027
julia> @show gradient(f1,x0)[1]
(gradient(f1, x0))[1] = [0.8898426474337621, 0.2897372497786118, 0.9810211091684449, 0.09345708511644779, 0.23858447792575943, 0.24385079572202506, 0.1828454340913519, 0.7635199630683113, 0.05167534139207508, 0.5915424957033092, 0.11155341718052214, 0.2926138350003188, 0.5980007034355839, 0.8111887446781879, 0.030783329577497076, 0.7872188382518067, 0.5265345127135576, 0.0943159141149933, 0.7403090894287457, 0.04934650826409148, 0.8207837917971361, 0.7327209084655903, 0.6350406011661587, 0.6029706814690249, 0.6208911879774303, 0.9034078024769735, 0.08319290996558881, 0.7910177397084297, 0.4126201610953655, 0.23032762918277405, 0.13805554898704864, 0.03284589370750646, 0.6441741642051572, 0.9784624449592936, 0.6860065739408643, 0.47463419102934457]
36-element Vector{Float64}:
0.8898426474337621
0.2897372497786118
0.9810211091684449
0.09345708511644779
0.23858447792575943
⋮
0.03284589370750646
0.6441741642051572
0.9784624449592936
0.6860065739408643
0.47463419102934457
julia> @show gradient(f2,x0)[1]
ERROR: MethodError: no method matching adjoint(::Zygote.Context{false}, ::typeof(sort), ::Vector{Int64}; rev::Bool)
Closest candidates are:
adjoint(::ZygoteRules.AContext, ::typeof(sort), ::AbstractArray; by) got unsupported keyword argument "rev"
@ Zygote none:0
adjoint(::ZygoteRules.AContext, ::Base.Fix2, ::Any) got unsupported keyword argument "rev"
@ Zygote none:0
adjoint(::ZygoteRules.AContext, ::Base.Fix1, ::Any) got unsupported keyword argument "rev"
@ Zygote none:0
...
Stacktrace:
[1] kwerr(::NamedTuple{(:rev,), Tuple{Bool}}, ::Function, ::Zygote.Context{false}, ::Function, ::Vector{Int64})
@ Base ./error.jl:165
[2] adjoint
@ ./none:0 [inlined]
[3] _pullback
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:75 [inlined]
[4] _pullback
@ ~/.julia/packages/TensorOperations/QcEK4/src/functions/ncon.jl:44 [inlined]
[5] _pullback(::Zygote.Context{false}, ::TensorOperations.var"##ncon#211", ::Nothing, ::Nothing, ::typeof(ncon), ::Tuple{Array{Float64, 3}, Array{Float64, 3}}, ::Tuple{Vector{Int64}, Vector{Int64}}, ::Vector{Bool}, ::Nothing)
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[6] _pullback
@ ~/.julia/packages/TensorOperations/QcEK4/src/functions/ncon.jl:27 [inlined]
[7] _pullback(::Zygote.Context{false}, ::typeof(ncon), ::Tuple{Array{Float64, 3}, Array{Float64, 3}}, ::Tuple{Vector{Int64}, Vector{Int64}}, ::Vector{Bool}, ::Nothing)
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[8] _pullback
@ ~/.julia/packages/TensorOperations/QcEK4/src/functions/ncon.jl:27 [inlined]
[9] _pullback(::Zygote.Context{false}, ::typeof(ncon), ::Tuple{Array{Float64, 3}, Array{Float64, 3}}, ::Tuple{Vector{Int64}, Vector{Int64}})
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[10] _pullback
@ ./REPL[6]:4 [inlined]
[11] _pullback(ctx::Zygote.Context{false}, f::typeof(f2), args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[12] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:44
[13] pullback
@ ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:42 [inlined]
[14] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:96
[15] top-level scope
@ show.jl:1128
The difference between f1(x) and f2(x) is that they perform the same tensor contraction but through different interfaces (the @tensor macro versus the ncon() function).
It seems that TensorRules.jl combined with TensorOperations.jl only supports @tensor expressions, not calls to the ncon function.
However, for the algorithm I am building, the ncon function from TensorOperations.jl is much more convenient to use.
So how can I modify the code to make the ncon function suitable for automatic differentiation?