Optimization on Stiefel manifold with auto-differentiation

I want to minimize the energy E of some quantum state, which is a real-valued function that ultimately depends on a square real orthogonal matrix X (i.e. a matrix on the Stiefel manifold). The (Euclidean) gradient of the energy with respect to the elements of X needs to be calculated by auto-differentiation.

From what I know, optimization on the Stiefel manifold can be done using the Manopt.jl package, or OptimKit.jl together with TensorKitManifolds.jl. The auto-differentiation part can be handled by Zygote.jl or Enzyme.jl. Unfortunately, I have zero experience with any of them, and it’s a bit hard for me to follow the documentation (in particular, OptimKit doesn’t have docs yet).

If someone could provide a minimal example of such a problem, using a relatively simple target function f(X) to demonstrate the usage of these packages, it would be of great help to me. Thanks!


Appendix: The exact expression of E(X) that I’m trying to optimize is

E = \sum_{k,\sigma} \xi_k \rho_{k\sigma} + 2 \sum_k \operatorname{Re}(\Delta_k \eta_k)

where k sums over the momenta in the 1st Brillouin zone of a square lattice with periodic or anti-periodic boundary condition, and

\begin{align} \xi_k &= -2t (\cos k_x + \cos k_y) - \mu \\ \Delta_k &= 2\Delta (\cos k_x - \cos k_y) \\ \sum_k \xi_k \rho_{k\uparrow} &= \frac{1}{2} \sum_k \xi_k [1 - (G_k)_{12}] \\ \sum_k \xi_k \rho_{k\downarrow} &= \frac{1}{2} \sum_k \xi_k [1 - (G_k)_{34}] \\ \eta_k &= -\frac{1}{4}[ (G_k)_{41} + (G_k)_{32} + i (G_k)_{42} - i(G_k)_{31} ] \end{align}

Here t, \Delta, \mu are 3 fixed parameters. The matrix G_k is related to the orthogonal matrix X in the following way:

G_k = A + B(D + G_\omega(k))^{-1} B^\mathsf{T}
G_\omega(k) = \left[\bigoplus_{i=1}^{\chi} \begin{bmatrix} & -e^{-ik_x} \sigma_x \\ e^{ik_x} \sigma_x & \end{bmatrix}\right] \oplus \left[\bigoplus_{i=1}^{\chi} \begin{bmatrix} & -e^{-ik_y} \sigma_x \\ e^{ik_y} \sigma_x & \end{bmatrix}\right]

Here \sigma_x is the x Pauli matrix, and \chi is a positive integer (which is not very big). The matrices A, B, D are given by

\begin{bmatrix} A & B \\ -B^\mathsf{T} & D \end{bmatrix} = X^\mathsf{T} \left( \bigoplus_{i=1}^{4\chi+2} \begin{bmatrix} 0 & 1 \\ -1 & 0 \end{bmatrix} \right) X

The size of the matrix X is then (8\chi + 4) \times (8\chi + 4). You can see that computing the gradient by hand is really not feasible, so auto-differentiation is essentially unavoidable…
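For concreteness, here is how the building blocks above would translate to Julia (a rough, untested sketch just to pin down the conventions; the names blk, Gω, ABD, Gk are placeholders of mine):

using LinearAlgebra, SparseArrays

σx = [0.0 1.0; 1.0 0.0]

# one 4×4 block [0 -e^{-ik}σx; e^{ik}σx 0]
blk(k) = [zeros(2, 2) -cis(-k) .* σx; cis(k) .* σx zeros(2, 2)]

# G_ω(k): direct sum of χ x-blocks followed by χ y-blocks, size 8χ × 8χ
Gω(kx, ky, χ) = blockdiag([sparse(blk(kx)) for _ in 1:χ]...,
                          [sparse(blk(ky)) for _ in 1:χ]...)

# A, B, D from X via J = ⊕_{i=1}^{4χ+2} [0 1; -1 0]
function ABD(X::AbstractMatrix, χ::Int)
    J = blockdiag((sparse([0.0 1.0; -1.0 0.0]) for _ in 1:(4χ + 2))...)
    M = X' * J * X          # (8χ+4) × (8χ+4); A is the 4×4 top-left corner
    A = M[1:4, 1:4]
    B = M[1:4, 5:end]
    D = M[5:end, 5:end]
    return A, B, D
end

# G_k = A + B (D + G_ω(k))⁻¹ Bᵀ, a 4×4 complex matrix
function Gk(X, kx, ky, χ)
    A, B, D = ABD(X, χ)
    return A + B * ((D + Gω(kx, ky, χ)) \ transpose(B))
end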


Without some concrete code it is a bit hard to help. But if you compute the gradient with Zygote or Enzyme, the missing piece is that this gives the Euclidean and not the Riemannian gradient; see the full explanation e.g. at

While I am not so sure that ManifoldDiff.jl can directly work with Zygote or Enzyme (I think it should), you can also use riemannian_gradient directly for the conversion.
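For a simple stand-in cost (not your energy function), the conversion could look roughly like this untested sketch:

using LinearAlgebra, Manifolds, ManifoldDiff, Manopt, Zygote

n = 4
C = randn(n, n)
f(X) = sum(abs2, C .- X)        # simple stand-in cost, just for illustration

M = OrthogonalMatrices(n)       # the square case of the Stiefel manifold
p = Matrix(qr(randn(n, n)).Q)   # some orthogonal starting point

eg, = Zygote.gradient(f, p)          # Euclidean gradient from AD
rg = riemannian_gradient(M, p, eg)   # converted to the Riemannian gradient

# plugging the conversion into a Manopt solver
q = gradient_descent(M, (M, X) -> f(X),
    (M, X) -> riemannian_gradient(M, X, Zygote.gradient(f, X)[1]), p)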

Your appendix is nice for context, but it is not something I could easily implement or test; it is too complex for that, and it lacks concrete data.
So I cannot actually help with code for your specific problem here.

You could also just optimize g(Y) = f(Y (Y^\mathsf{T} Y)^{-1/2}) over unconstrained real square matrices Y.

The polar decomposition X = Y (Y^T Y)^{-1/2} automatically satisfies X^T X = I, and the function Y / sqrt(Hermitian(Y'Y)) should be differentiable with standard AD packages (e.g. ChainRules.jl has a rule for the symmetric-matrix square root).

(This is a generalization of a simple trick to optimize on the unit sphere: Optimization on unit sphere? - #3 by stevengj).
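For instance, with a simple quadratic stand-in for f (not your energy function), the whole trick is just:

using LinearAlgebra, Zygote

C = randn(4, 4)
f(X) = tr(X' * C * X)          # simple stand-in target, just for illustration

# optimize over unconstrained Y; the polar factor is always orthogonal
g(Y) = f(Y / sqrt(Hermitian(Y' * Y)))

Y = randn(4, 4)
X = Y / sqrt(Hermitian(Y' * Y))
@show X' * X ≈ I               # true up to rounding

∇Y, = Zygote.gradient(g, Y)    # plain AD, no manifold machinery needed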


@stevengj This is a very useful trick. Thanks! I’ll try this if I encounter issues with Manopt.

@kellertuer The code is in my GitHub repo

The energy function can be constructed from

using GaussianfPEPS

# set energy parameters
# for "d-wave" state, Δx = Δy = -Δ
t, Δx, Δy, mu = 1.0, 0.5, -0.5, -0.6
# create the Brillouin zone with anti-PBC on x-direction, and PBC on y-direction
# for a square lattice with 100 x 100 sites
bz = BrillouinZone((100, 100), (false, true))
# Np is a constant, which is often just set to 2
# (corresponding to spin-1/2 fermions)
Np = 2
# energy function
f(X) = BCS.energy_peps(fiducial_cormat(X), bz, Np; t, Δx, Δy, mu)

which aims to be the Julia equivalent of

which was written in Python using PyManopt; the AD was handled with JAX there. It seems that PyManopt has some way to convert the Euclidean derivatives produced by JAX to a manifold derivative. I tried to be compatible with it, but it uses a convention different even from its own accompanying paper, so I gave up on figuring it out.

I encountered some problems due to sparse matrices in my code. For now I just apply @stevengj’s trick, so I don’t need Manopt to get the derivatives.

using Zygote, Random, LinearAlgebra, GaussianfPEPS

t, Δx, Δy, mu = 1.0, 0.5, -0.5, -0.6
bz = BrillouinZone((10, 10), (false, true))
Np = 2

function myfun(Y::AbstractMatrix)
    # polar factor of Y: X is orthogonal for any invertible Y
    X = Y / sqrt(Hermitian(Y' * Y))
    G = fiducial_cormat(X)
    return BCS.energy_peps(G, bz, Np; t, Δx, Δy, mu)
end

χ = 1
N = Np + 4χ
Random.seed!(0)
x = rand(2N, 2N);
g = Zygote.gradient(myfun, x)

Then I get the error

Need an adjoint for constructor SparseArrays.SparseMatrixCSC{ComplexF64, Int64}. Gradient is of type Matrix{ComplexF64}

Does this mean that I need to define an rrule for the sparse matrices I created?

Not quite, as explained here: Zygote.jl: How to get the gradient of sparse matrix - #6 by stevengj — if you are going to write an rrule, it typically has to be for the function that includes both the sparse-matrix construction and how you use the sparse matrix. (This isn’t usually so hard, though.)
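For example, a toy version with a hypothetical apply_sparse that builds a diagonal sparse matrix and immediately applies it (real inputs assumed) might look like:

using ChainRulesCore, SparseArrays

# build the sparse matrix *and* use it inside one function...
apply_sparse(θ::AbstractVector, v::AbstractVector) = spdiagm(0 => θ) * v

# ...so that a single rrule covers construction and use together
function ChainRulesCore.rrule(::typeof(apply_sparse), θ, v)
    y = apply_sparse(θ, v)          # yᵢ = θᵢ vᵢ
    function apply_sparse_pullback(ȳ)
        return NoTangent(), ȳ .* v, ȳ .* θ
    end
    return y, apply_sparse_pullback
end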

Where are the sparse matrices in your problem description above?

(You might also try Enzyme.jl.)

The sparse matrices are G_\omega(k) (called cormat_virtual in my code) and J = \bigoplus_{i=1}^{4\chi+2} \begin{bmatrix} 0 & 1 \\ -1 & 0 \end{bmatrix} (its construction function in my code is get_J).

I found out that Enzyme supports array mutation, so I can avoid using sparse matrices if Enzyme also has trouble handling them.
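For instance, G_\omega(k) could be filled densely with mutation instead of blockdiag (a rough sketch along the lines of the appendix formulas; untested with Enzyme):

function Gω_dense(kx, ky, χ)
    G = zeros(ComplexF64, 8χ, 8χ)
    ks = vcat(fill(kx, χ), fill(ky, χ))
    for (b, k) in enumerate(ks)
        r = 4(b - 1)                  # offset of this 4×4 block
        G[r + 1, r + 4] = -cis(-k)    # top-right block: -e^{-ik} σx
        G[r + 2, r + 3] = -cis(-k)
        G[r + 3, r + 2] = cis(k)      # bottom-left block: e^{ik} σx
        G[r + 4, r + 1] = cis(k)
    end
    return G
end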

Now I no longer construct J explicitly, but G_\omega(k) is still a sparse matrix. When using Enzyme, although the sparse matrix seems to cause no trouble, it still complains “Enzyme compilation failed due to illegal type analysis.”

I take the derivative using the convenience function gradient:

g = Enzyme.gradient(Reverse, myfun, x)

The error message (which is quite long) begins with

Enzyme compilation failed due to illegal type analysis.
 This usually indicates the use of a Union type, which is not fully supported with Enzyme.API.strictAliasing set to true [the default].
 Ideally, remove the union (which will also make your code faster), or try setting Enzyme.API.strictAliasing!(false) before any autodiff call.
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)
 Failure within method: MethodInstance for getproperty(::QRPivoted{ComplexF64, Matrix{ComplexF64}, Vector{ComplexF64}, Vector{Int64}}, ::Symbol)
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erroneous code with Cthulhu.jl
Caused by:
Stacktrace:
 [1] triu!
   @ ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:441
 [2] getproperty
   @ ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/qr.jl:497

But I’m not aware of any union types in my code relevant to energy_peps.
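In the meantime, I will try the workaround suggested by the message itself (I have not yet verified that it fixes the QRPivoted failure):

using Enzyme
Enzyme.API.strictAliasing!(false)   # must run before any autodiff call
g = Enzyme.gradient(Reverse, myfun, x)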