Efficient optimization of objective with cos and sin terms

Another thing you can do is use Gröbner bases to obtain an explicit expression for the objective itself. Just add the objective to the previous system, introducing an auxiliary variable s:

$$
\begin{aligned}
y\,\langle z_1x + z_2y - d,\, z_1\rangle &= x\,\langle z_1x + z_2y - d,\, z_2\rangle \\
x^2 + y^2 - 1 &= 0 \\
\|z_1x + z_2y - d\|_2^2 - s &= 0
\end{aligned}
$$
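Where the first equation comes from: with x = cos t and y = sin t, it is the stationarity condition J'(t) = 0 for J(t) = \|z_1\cos t + z_2\sin t - d\|_2^2. So after eliminating x and y, the values of J at its critical points are roots of the resulting polynomial in s. A short derivation:

```latex
\begin{aligned}
r(t) &= z_1\cos t + z_2\sin t - d, \qquad J(t) = \|r(t)\|_2^2, \\
J'(t) &= 2\,\langle r(t),\, -z_1\sin t + z_2\cos t\rangle
       = 2\big(x\,\langle r, z_2\rangle - y\,\langle r, z_1\rangle\big), \qquad x = \cos t,\ y = \sin t, \\
J'(t) = 0 &\iff y\,\langle z_1x + z_2y - d,\, z_1\rangle = x\,\langle z_1x + z_2y - d,\, z_2\rangle .
\end{aligned}
```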

Now ask Mathematica to eliminate x and y. It will return a (horrible) degree-4 polynomial in s, and the code below takes its smallest real root as the minimum. No matter how horrible, Julia can handle it:

using Polynomials, LinearAlgebra, Plots

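# Compare the smallest real root of the quartic in s with a brute-force grid minimum
function problem_obj(dim)
    z1 = randn(dim)
    z2 = randn(dim)
    d = randn(dim)
    J(t) = norm(z1 * cos(t) + z2 * sin(t) - d)^2   # objective
    t = -π:0.01:π                                    # grid for the brute-force check
    values = objective(z1, z2, d)                    # roots of the quartic in s
    real_values = real(filter(isreal, values))       # keep only the real roots
    real_min = minimum(real_values)
    grid_min = minimum(J.(t))
    display(plot(t, J.(t)))
    return real_min, grid_min
end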

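function objective(z1, z2, d)
    # Abbreviations for the inner products appearing in the eliminated
    # quartic c0 + c1*s + c2*s^2 + c3*s^3 + c4*s^4 = 0
    k = dot(z1, z2)
    n1 = norm(z1)^2
    n2 = norm(z2)^2
    w1 = dot(d, z1)
    w2 = dot(d, z2)
    l = norm(d)^2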
    c0 =
        4 * k^6 + l^4 * (n1 - n2)^2 + (n1^2 - 4w1^2) * (-n1 * n2 + n2^2 + w1^2)^2 + 36k^3 * (2l + n1 + n2) * w1 * w2 -
        2 * (n1 * (2n1^3 - 4n1^2 * n2 + n1 * n2^2 + n2^3) + (-10n1^2 + 19n1 * n2 - 10n2^2) * w1^2 + 6w1^4) * w2^2 +
        (-8n1^2 + 8n1 * n2 + n2^2 - 12w1^2) * w2^4 - 4w2^6 +
        2l^3 * (n1 - n2) * (n1^2 - n2^2 - w1^2 + w2^2) +
        4k * w1 * w2 * ((l - n1 + 2n2) * (-((l + 2n1 - n2) * (2l + n1 + n2)) + 9w1^2) + 9 * (l + 2n1 - n2) * w2^2) +
        k^4 * (-8l^2 + n1^2 - 10n1 * n2 + n2^2 - 8l * (n1 + n2) - 12 * (w1^2 + w2^2)) +
        2l * (
            n1^4 * n2 - n1^3 * (n2^2 + w1^2 + 4 * w2^2) +
            n2 * (-4w1^2 * (n2^2 + w1^2) + (-n2 + w1) * (n2 + w1) * w2^2 + 5w2^4) -
            n1^2 * n2 * (n2^2 + 5 * (w1^2 - 2w2^2)) +
            n1 * (n2^4 + 5w1^4 + w1^2 * w2^2 - 4w2^4 + 5n2^2 * (2w1^2 - w2^2))
        ) +
        2k^2 * (
            2l^4 - n1^3 * n2 + 4n1^2 * n2^2 - n1 * n2^3 + 4 * l^3 * (n1 + n2) - n1^2 * w1^2 + n1 * n2 * w1^2 -
            10n2^2 * w1^2 + 6w1^4 - (10n1^2 - n1 * n2 + n2^2 + 42w1^2) * w2^2 + 6w2^4 -
            l * (n1^3 - 5n1^2 * n2 - 5n1 * n2^2 + n2^3 + n1 * w1^2 + 19n2 * w1^2 + (19n1 + n2) * w2^2) +
            l^2 * (n1^2 + 10n1 * n2 + n2^2 - 10 * (w1^2 + w2^2))
        ) +
        l^2 * (
            n1^4 + 2n1^3 * n2 + n2^4 + (w1^2 + w2^2)^2 - 2n1^2 * (3n2^2 + 4w1^2 + w2^2) - 2n2^2 * (w1^2 + 4w2^2) +
            2n1 * n2 * (n2^2 + 5 * (w1^2 + w2^2))
        )

    c1 =
        2 * (
            -2l^3 * (n1 - n2)^2 - n1^4 * n2 + n1^3 * n2^2 + n1^2 * n2^3 - n1 * n2^4 +
            4k^4 * (2l + n1 + n2) +
            n1^3 * w1^2 +
            5n1^2 * n2 * w1^2 - 10n1 * n2^2 * w1^2 + 4n2^3 * w1^2 - 5n1 * w1^4 + 4n2 * w1^4 - 36k^3 * w1 * w2 +
            (4n1^3 - 10n1^2 * n2 + 5n1 * n2^2 + n2^3 - (n1 + n2) * w1^2) * w2^2 +
            (4n1 - 5n2) * w2^4 - 3l^2 * (n1 - n2) * (n1^2 - n2^2 - w1^2 + w2^2) +
            l * (
                -(n1 - n2)^2 * (n1^2 + 4n1 * n2 + n2^2) + 2 * (4n1^2 - 5n1 * n2 + n2^2) * w1^2 - w1^4 +
                2 * (n1^2 - 5n1 * n2 + 4 * n2^2 - w1^2) * w2^2 - w2^4
            ) - 6k * w1 * w2 * (-2l^2 + n1^2 - 4n1 * n2 + n2^2 - 2l * (n1 + n2) + 3(w1^2 + w2^2)) +
            k^2 * (
                -8l^3 + n1^3 - 5n1^2 * n2 - 5n1 * n2^2 + n2^3 - 12l^2 * (n1 + n2) +
                n1 * w1^2 +
                19n2 * w1^2 +
                (19n1 + n2) * w2^2 - 2l * (n1^2 + 10n1 * n2 + n2^2 - 10 * (w1^2 + w2^2))
            )
        )

    c2 =
        -8k^4 +
        24k^2 * l^2 +
        24k^2 * l * n1 +
        2k^2 * n1^2 +
        6l^2 * n1^2 +
        6l * n1^3 +
        n1^4 +
        24k^2 * l * n2 +
        20k^2 * n1 * n2 - 12l^2 * n1 * n2 - 6l * n1^2 * n2 +
        2n1^3 * n2 +
        2k^2 * n2^2 +
        6l^2 * n2^2 - 6l * n1 * n2^2 - 6n1^2 * n2^2 +
        6l * n2^3 +
        2n1 * n2^3 +
        n2^4 - 20k^2 * w1^2 - 6l * n1 * w1^2 - 8n1^2 * w1^2 +
        6l * n2 * w1^2 +
        10n1 * n2 * w1^2 - 2n2^2 * w1^2 + w1^4 - 24k * l * w1 * w2 - 12k * n1 * w1 * w2 - 12k * n2 * w1 * w2 -
        20k^2 * w2^2 + 6l * n1 * w2^2 - 2n1^2 * w2^2 - 6l * n2 * w2^2 + 10n1 * n2 * w2^2 - 8n2^2 * w2^2 +
        2w1^2 * w2^2 +
        w2^4

    c3 =
        -16k^2 * l - 8k^2 * n1 - 4l * n1^2 - 2n1^3 - 8k^2 * n2 + 8l * n1 * n2 + 2n1^2 * n2 - 4 * l * n2^2 + 2n1 * n2^2 -
        2n2^3 + 2n1 * w1^2 - 2n2 * w1^2 + 8k * w1 * w2 - 2n1 * w2^2 + 2n2 * w2^2

    c4 = 4k^2 + n1^2 - 2n1 * n2 + n2^2

    p = Polynomial([c0, c1, c2, c3, c4])
    return roots(p)
end
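`problem_obj` keeps the roots for which `isreal` holds, which relies on the rootfinder returning an imaginary part of exactly zero for every real root. A tolerance-based filter is a bit more defensive. A minimal sketch (the name `smallest_real_root` and the default `rtol` are my own choices, not from the thread):

```julia
# Hypothetical helper: keep roots whose imaginary part is negligible relative
# to their magnitude, then return the smallest real part among them.
# Intended use: smallest_real_root(objective(z1, z2, d))
function smallest_real_root(r::AbstractVector; rtol=1e-8)
    keep = [z for z in r if abs(imag(z)) <= rtol * max(1.0, abs(z))]
    isempty(keep) && error("no (numerically) real roots found")
    return minimum(real, keep)
end
```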

I always forget how useful Gröbner bases are, haha. There is sumiya11/Groebner.jl on GitHub (Groebner bases in (almost) pure Julia) for a Julia implementation.


I like simple solutions, so I thought I’d try this one for comparison:


using Polynomials, LinearAlgebra, Statistics

function circ2cart(Z,d)
    # function analytical in araujoms' post above
    r1 = filter(isreal, analytical(Z[:,1],Z[:,2],d))
    y = real.(r1)
    t = asin.(y)
    x = cos.(t)
    return [x y]
end

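function gdun(X::AbstractMatrix, y::AbstractVector; k=15)
    # Gradient descent with unit-norm constraint; returns k iterates, one per row
    # Start with the OLS solution, normalized onto the unit circle
    B = X'X \ X'y
    B ./= norm(B)
    Bs = Matrix{Float64}(undef, k, 2)
    Bs[1,:] = B
    for i in 2:k
        res = X'*(y-X*B)   # negative gradient of ½‖y - X*B‖²
        # unit step, then renormalize: with ‖B‖ = 1 the denominator equals norm(B + res)
        B = (B + res)/sqrt(1.0 + 2*B'*res + norm(res)^2)
        Bs[i,:] = B
    end
    return Bs
end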

function mse(bc, bs; X=Float64[], y=Float64[])
    # MSE at the point (bc, bs) = (cos t, sin t); X and y must be passed as keywords
    B = [bc, bs]
    return mean(abs2, y-X*B)
end

And an example:

# Example vectors from Juliohm's post above
z1 = [0.021553008480093894, 0.5376179480003928, 0.056695819740902764]
z2 = [0.20214667066080527, 0.36051053300345526, 0.4475695503545505]
Z = [z1 z2]
d = [0.6774110630795951, 0.9644071492859374, 0.7483034986631949]

polysol = circ2cart(Z,d)
B = gdun(Z, d)
B

The vector points towards the OLS solution, which lies off the plot.
I start from the unit-circle point closest to the OLS solution and “walk” along the unit circle to the closest root (mse = 0.17).
The two real roots for this example are the black points. The pink lines are the MSE contours.
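The update in `gdun` is projected gradient descent on ½‖y − XB‖² with unit step: because ‖B‖ = 1 before each update, sqrt(1 + 2B'res + ‖res‖²) is exactly norm(B + res), so each step is a gradient step followed by projection back onto the unit circle. A variant with an explicit step size and a stopping tolerance could look like the sketch below. The names `projgd_circle`, `η`, `tol` and `maxiter` are hypothetical, and the default step 1/‖X‖² (instead of `gdun`'s fixed step 1) is my choice because it guarantees the objective never increases. Like `gdun`, it converges to a stationary point on the circle, which need not be the global minimum.

```julia
using LinearAlgebra

# Hypothetical sketch: projected gradient descent for
#   minimize ‖y - X*B‖²  subject to ‖B‖ = 1,  with X an n×2 matrix.
# Step size 1/L with L = opnorm(X)^2 (Lipschitz constant of the gradient),
# which makes the objective non-increasing along the iterates.
function projgd_circle(X::AbstractMatrix, y::AbstractVector;
                       η=1 / opnorm(X)^2, tol=1e-10, maxiter=100_000)
    B = normalize(X \ y)                 # normalized least-squares starting point
    for _ in 1:maxiter
        g = X' * (y - X * B)             # negative gradient of ½‖y - X*B‖²
        Bnew = normalize(B + η * g)      # gradient step, then project onto ‖B‖ = 1
        norm(Bnew - B) < tol && return Bnew
        B = Bnew
    end
    return B                             # may not have met `tol` within maxiter
end
```

With the example data, `projgd_circle(Z, d)` should end near the last row of `gdun(Z, d)` whenever both settle on the same stationary point.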

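That’s not going to work in general, because you need to consider both possibilities for x: it might be \sqrt{1-y^2} or -\sqrt{1-y^2}. Since asin returns t in [-π/2, π/2], cos.(t) is never negative, so the version above only ever produces the + branch. You can instead do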

function circ2cart(Z,d)
    # function analytical in araujoms' post above
    r1 = filter(isreal, analytical(Z[:,1],Z[:,2],d))
    y = real.(r1)
    x = sqrt.(max.(0, 1 .- y.^2))   # clamp tiny negatives caused by rounding
    return [[x;-x] [y;y]]           # both branches x = ±√(1-y²)
end
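Doubling the candidates means one more step: evaluate the objective at every (x, y) and keep the best. A minimal sketch of that selection, with a hypothetical helper name `best_circle_point` that takes the candidate y = sin t values directly (for instance the real roots returned by `analytical`), so it does not depend on `analytical` itself:

```julia
using LinearAlgebra

# Hypothetical helper: for each candidate y = sin(t) in `ys`, try both
# x = +√(1-y²) and x = -√(1-y²), and return the unit-circle point (x, y)
# with the smallest objective ‖z1*x + z2*y - d‖², plus that objective value.
function best_circle_point(z1, z2, d, ys)
    J(xv, yv) = norm(z1 * xv + z2 * yv - d)^2
    bestJ, bestx, besty = Inf, NaN, NaN
    for yc in ys
        yv = clamp(real(yc), -1.0, 1.0)   # guard against |y| slightly above 1
        xabs = sqrt(1 - yv^2)
        for xv in (xabs, -xabs)            # both branches of cos(t)
            val = J(xv, yv)
            if val < bestJ
                bestJ, bestx, besty = val, xv, yv
            end
        end
    end
    return bestx, besty, bestJ
end
```

For a minimizer with cos t < 0 (e.g. t = 2.5), only the − branch recovers it, which is exactly the case the asin-based version misses.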

Can it handle symbolic coefficients, though?

Dear Juliohm,

This can be done extremely quickly and simply (a few J evaluations plus a few thousand flops), taking a few microseconds for small vectors. @dlfivefifty and @stevengj already mentioned two of the key ideas, but I didn’t see anyone describe how the Fourier series is best evaluated, or put it all together. As your algebra already showed, J(t) is a Fourier series in t with frequencies -2, -1, 0, 1, 2 only. So you sample J at 5 equispaced t-points, use the DFT to get the coefficients, then send the coefficients of the derivative of the series to a “Fourier rootfinder” (e.g. Boyd’s approach of writing z=e^{it}, multiplying by z^2, and using the companion matrix). You will get 2 or 4 roots with real t. Evaluate J at each and return the smallest.
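To make the frequency claim concrete: expanding the square, with the abbreviations used in `objective` above (k = ⟨z_1, z_2⟩, n_1 = \|z_1\|^2, n_2 = \|z_2\|^2, w_1 = ⟨d, z_1⟩, w_2 = ⟨d, z_2⟩, l = \|d\|^2), gives the five coefficients in closed form. Multiplying J' by z^2 with z = e^{it} then gives the quartic whose roots on the unit circle are the critical points t = \arg z:

```latex
\begin{aligned}
J(t) &= n_1\cos^2 t + n_2\sin^2 t + 2k\sin t\cos t - 2w_1\cos t - 2w_2\sin t + l \\
     &= \Big(l + \tfrac{n_1+n_2}{2}\Big) - 2w_1\cos t - 2w_2\sin t
        + \tfrac{n_1-n_2}{2}\cos 2t + k\sin 2t
      = \sum_{m=-2}^{2} \hat c_m\, e^{imt}, \\
\hat c_0 &= l + \tfrac{n_1+n_2}{2}, \qquad
\hat c_{\pm 1} = -w_1 \pm i\,w_2, \qquad
\hat c_{\pm 2} = \tfrac{n_1-n_2}{4} \mp \tfrac{i k}{2}, \\
z^2 J'(t) &= \sum_{m=-2}^{2} i m\, \hat c_m\, z^{m+2}, \qquad z = e^{it}.
\end{aligned}
```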

I procrastinated and wrote the 7-line function plus a demo. (I don’t include the test code, but I have verified that it gives the actual minimum):

using FFTW
using PolynomialRoots
using LinearAlgebra
using Chairmarks

function Joptim(z1,z2,d)
	J(t) = norm(cos(t)*z1 + sin(t)*z2 - d)^2   # func
	Jhat = fftshift(ifft(J.(2pi*(0:4)/5)))     # Fourier coeffs -2,..,2
	Jphat = Jhat .* (im*(-2:2))                # F coeffs of J'
	tt = log.(roots(Jphat))*1im                # Boyd: roots of J' via Taylor
	tt = real.(tt[abs.(imag.(tt)) .< 1e-8])    # only keep near-real roots
	Jmin, imin = findmin(J.(tt))               # brute force 2 or 4 roots
	tt[imin], Jmin                             # return t_min, J_min
end

n=7   # dims
d = randn(n)
z1 = randn(n)
z2 = randn(n)
tm, Jm = Joptim(z1,z2,d)
@b Joptim(z1,z2,d)

4.855 μs (91 allocs: 5.016 KiB)
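For comparison, the same pipeline can be sketched with the LinearAlgebra standard library alone: a direct 5-point DFT for the coefficients and the eigenvalues of a companion matrix for the roots. This is my own sketch, not the code above, and the name `Joptim_stdlib` is hypothetical. It maps roots back with t = angle(z) and evaluates J at the angle of every root, which is safe because the global minimizer is among them. It assumes the degree-2 coefficient is nonzero (it fails if ‖z1‖ = ‖z2‖ and z1 ⟂ z2 exactly), which holds almost surely for random data.

```julia
using LinearAlgebra

# Hypothetical stdlib-only variant of the Fourier / companion-matrix approach.
function Joptim_stdlib(z1, z2, d)
    J(t) = norm(cos(t) * z1 + sin(t) * z2 - d)^2
    ts = 2π * (0:4) / 5                  # 5 equispaced samples: no aliasing for |m| <= 2
    Js = J.(ts)
    ms = -2:2
    # c[m] such that J(t) = Σ_m c[m] * exp(i*m*t), via a direct 5-point DFT
    c = [sum(Js .* cis.(-m .* ts)) / 5 for m in ms]
    # coefficients of p(z) = z^2 * J'(t), z = exp(i*t), ascending powers z^0 .. z^4
    a = im .* ms .* c
    # companion matrix of the monic quartic z^4 + b[4] z^3 + b[3] z^2 + b[2] z + b[1]
    b = a[1:4] ./ a[5]
    C = zeros(ComplexF64, 4, 4)
    C[2:4, 1:3] = Matrix{ComplexF64}(I, 3, 3)
    C[:, 4] = -b
    zs = eigvals(C)
    # roots on |z| = 1 are the real critical points; evaluating J at angle(z)
    # for all four roots cannot miss the minimum
    cand = angle.(zs)
    Jmin, imin = findmin(J.(cand))
    return cand[imin], Jmin
end
```

Usage mirrors `Joptim`: `tm, Jm = Joptim_stdlib(z1, z2, d)`.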

Please let me know if this is useful. Best, Alex
