Optimization on Stiefel manifold with auto-differentiation

You could also just optimize f(X) = f(Y (Y^T Y)^{-1/2}) over unconstrained real square matrices Y.

The orthogonal polar factor X = Y (Y^T Y)^{-1/2} automatically satisfies X^T X = I whenever Y has full column rank (so that Y^T Y is invertible), and the function Y / sqrt(Hermitian(Y'Y)) should be differentiable with standard AD packages (e.g. ChainRules.jl has a rule for the symmetric-matrix square root).
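To make the reparametrization concrete, here is a minimal sketch using only the LinearAlgebra standard library. The names `stiefel`, `A`, `f`, and `g` are made up for illustration, and the quadratic objective is just a placeholder. The idea would be to hand `g` to whatever unconstrained optimizer and AD package you prefer; that step is left out here since it depends on the package.

```julia
using LinearAlgebra

# Map an unconstrained matrix Y (assumed to have full column rank)
# to X = Y (YᵀY)^(-1/2), which has orthonormal columns.
stiefel(Y) = Y / sqrt(Hermitian(Y'Y))

# Hypothetical objective on the manifold, purely for illustration.
A = [2.0 1.0 0.0;
     1.0 3.0 1.0;
     0.0 1.0 4.0]
f(X) = tr(X' * A * X)

# Unconstrained objective: optimize g over Y instead of f over X.
g(Y) = f(stiefel(Y))

Y = [1.0 2.0;
     0.0 1.0;
     1.0 0.0]
X = stiefel(Y)

# X'X should be the identity up to rounding error.
orth_error = norm(X' * X - I)
```

Note that `g` is invariant under positive rescaling of `Y` (since `stiefel(c * Y) == stiefel(Y)` up to rounding for `c > 0`), so the minimizer in `Y` is not unique. That is usually harmless, but it may be worth keeping in mind if `Y` drifts towards very large or very small norms during optimization.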

(This is a generalization of a simple trick to optimize on the unit sphere: Optimization on unit sphere? - #3 by stevengj).
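As a quick sanity check of that connection: for a single column y, the formula Y (Y^T Y)^{-1/2} reduces to y / ‖y‖, which is the usual way of mapping an unconstrained vector onto the unit sphere. A small sketch (the variable names are arbitrary):

```julia
using LinearAlgebra

y = [3.0, 4.0]

# Unit-sphere normalization of an unconstrained vector.
x_sphere = y / norm(y)

# The same vector treated as a 2×1 matrix and pushed through the matrix formula.
Ycol = reshape(y, :, 1)
x_polar = Ycol / sqrt(Hermitian(Ycol' * Ycol))

# The two results agree up to rounding error.
difference = norm(vec(x_polar) - x_sphere)
```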
