Thanks Professor! The sample code is quite useful. Now the AD part can run without problem, but there are still issues in the last step with optimize
. I think it’s easier to format things here than in emails.
First, to get Zygote working, I have to construct G_\omega(k) without any use of the SparseArray package, and store it as a dense matrix (these changes are in my ad-opt
branch). The code I end up using is:
using Random
using LinearAlgebra
using GaussianfPEPS
using GaussianfPEPS: rand_orth
using Zygote
using TensorKit, OptimKit, TensorKitManifolds
import TensorKitManifolds.Stiefel: StiefelTangent
# d-wave BCS Hamiltonian parameters
t, Δx, Δy, mu = 1.0, 0.5, -0.5, 0.0
# number of physical (complex) fermions
# and virtual (complex) fermions along each lattice direction
Np, χ = 2, 1
N = Np + 4χ
Random.seed!(0)
# create a random real orthogonal matrix
x = TensorMap(rand_orth(2N), ℂ^(2N) → ℂ^(2N));
# start with a small BZ for quick tests
Lx, Ly = 10, 10
bz = BrillouinZone((Lx, Ly), (false, true));
function myfun(X)
# extract matrix elements from the TensorMap
G = fiducial_cormat(reshape(X.data, (2N, 2N)))
return BCS.energy_peps(G, bz, Np; t, Δx, Δy, mu)
end
function fg(X)
fval, pb = Zygote.pullback(myfun, X)
grad = pb(1.)[1] # sample code doesn't have this [1]
proj_grad = TensorKitManifolds.Stiefel.project!(grad, X)
return fval, proj_grad
end
myscale!(Δ::StiefelTangent, α::Real) = rmul!(Δ, α)
myadd!(Δy::StiefelTangent, Δx::StiefelTangent, α::Real) = axpy!(α, Δx, Δy)
result = optimize(
fg, x, LBFGS(); retract = Stiefel.retract, inner = Stiefel.inner,
transport = Stiefel.transport, scale! = myscale!, add! = myadd!
)
The function fg
can run successfully for a given input TensorMap X
. But when calling optimize
, I get the error (I’m using OptimKit v0.4.0):
MethodError: no method matching optimize(::typeof(fg), ::TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}, ::LBFGS{Float64, HagerZhangLineSearch{Rational{Int64}}}; retract::typeof(retract), inner::typeof(TensorKitManifolds.inner), transport::typeof(transport), scale!::typeof(myscale!), add!::typeof(myadd!))
This error has been manually thrown, explicitly, so the method may exist but be intentionally marked as unimplemented.
Closest candidates are:
optimize(::Any, ::Any, ::LBFGS; precondition, finalize!, shouldstop, hasconverged, retract, inner, transport!, scale!, add!, isometrictransport) got unsupported keyword argument "transport"
@ OptimKit ~/.julia/packages/OptimKit/G6i79/src/lbfgs.jl:56
optimize(::Any, ::Any, !Matched::ConjugateGradient; precondition, finalize!, shouldstop, hasconverged, retract, inner, transport!, scale!, add!, isometrictransport) got unsupported keyword argument "transport"
@ OptimKit ~/.julia/packages/OptimKit/G6i79/src/cg.jl:68
optimize(::Any, ::Any, !Matched::GradientDescent; precondition, finalize!, shouldstop, hasconverged, retract, inner, transport!, scale!, add!, isometrictransport) got unsupported keyword argument "transport"
@ OptimKit ~/.julia/packages/OptimKit/G6i79/src/gd.jl:51
Removing the transport
argument doesn’t solve the problem; the error message become
ArgumentError: arrays must have the same axes for copy! (consider using `copyto!`)
Stacktrace:
[1] copy!
@ ./abstractarray.jl:920 [inlined]
[2] stiefelexp(W::TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}, A::TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}, Z::TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}, α::Float64)
@ TensorKitManifolds.Stiefel ~/.julia/packages/TensorKitManifolds/D6miB/src/stiefel.jl:178
[3] retract_exp(W::TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}, Δ::StiefelTangent{TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}, TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}}, α::Float64)
@ TensorKitManifolds.Stiefel ~/.julia/packages/TensorKitManifolds/D6miB/src/stiefel.jl:188
[4] #retract#4
@ ~/.julia/packages/TensorKitManifolds/D6miB/src/stiefel.jl:104 [inlined]
[5] retract
@ ~/.julia/packages/TensorKitManifolds/D6miB/src/stiefel.jl:102 [inlined]
[6] takestep(iter::OptimKit.HagerZhangLineSearchIterator{Float64, typeof(fg), typeof(retract), typeof(TensorKitManifolds.inner), TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}, StiefelTangent{TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}, TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}}, Rational{Int64}}, α::Float64)
@ OptimKit ~/.julia/packages/OptimKit/G6i79/src/linesearches.jl:279
[7] iterate(iter::OptimKit.HagerZhangLineSearchIterator{Float64, typeof(fg), typeof(retract), typeof(TensorKitManifolds.inner), TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}, StiefelTangent{TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}, TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}}, Rational{Int64}})
@ OptimKit ~/.julia/packages/OptimKit/G6i79/src/linesearches.jl:180
[8] (::HagerZhangLineSearch{Rational{Int64}})(fg::typeof(fg), x₀::TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}, η₀::StiefelTangent{TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}, TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}}, fg₀::Tuple{Float64, StiefelTangent{TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}, TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}}}; retract::Function, inner::typeof(TensorKitManifolds.inner), initialguess::Float64, acceptfirst::Bool, maxiter::Int64, maxfg::Int64, verbosity::Int64)
@ OptimKit ~/.julia/packages/OptimKit/G6i79/src/linesearches.jl:131
[9] optimize(fg::typeof(fg), x::TensorMap{Float64, ComplexSpace, 1, 1, Vector{Float64}}, alg::LBFGS{Float64, HagerZhangLineSearch{Rational{Int64}}}; precondition::typeof(OptimKit._precondition), finalize!::typeof(OptimKit._finalize!), shouldstop::OptimKit.DefaultShouldStop, hasconverged::OptimKit.DefaultHasConverged{Float64}, retract::Function, inner::typeof(TensorKitManifolds.inner), transport!::typeof(OptimKit._transport!), scale!::typeof(myscale!), add!::typeof(myadd!), isometrictransport::Bool)
@ OptimKit ~/.julia/packages/OptimKit/G6i79/src/lbfgs.jl:107
[10] top-level scope
@ ~/GitHub/Playground_Julia/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X13sZmlsZQ==.jl:1