Zygote gradient slow for TriMAP

Hello, I am trying to implement TriMAP using Zygote’s gradient function, but it is very slow and my data is quite high dimensional. Are there any tips for improving the performance of the gradient computation?

The following function is where the optimization happens, and it is what slows everything down: it takes more than 11 GB for 300 samples of dimension 26x26 (MNIST dataset), and I see similar, maybe slightly lower, numbers on the Swiss roll, so it does not seem to be a dimensionality issue but a sample-count issue (how I time it is shown right after the function).

"""
    trimap(X, maxoutdim, maxiter, initialize, lr, weight_temp, m₁, m₂, nntype)

Implements the TriMAP routine described in https://arxiv.org/abs/1910.00204.
"""
function trimap(X :: AbstractMatrix{T},
                maxoutdim :: Integer = 2,
                maxiter :: Integer = 400,
                initialize :: Symbol = :pca,
                lr :: T = 0.5,
                weight_temp :: T = 0.5,
                m₁ :: Int = 10,
                m₂ :: Int = 6,
                nntype = ManifoldLearning.BruteForce) where {T <: Real}

    d, n = size(X)

    Y = if initialize == :pca
            predict(fit(ManifoldLearning.PCA, X, maxoutdim=maxoutdim), X)
        elseif initialize == :random
            rand(T, maxoutdim, n)
        else
            error("Unknown initialization")
    end

    # Nearest neighbors
    NN = fit(nntype, X) 
    
    # initialize triplets and weights
    triplets, weights = generate_triplets(X, NN, m₁, m₂, weight_temp)

    # Optimization of the embedding
    gain = zeros(T, size(Y))
    vel = zeros(T, size(Y))

    embedding_gradient(y) = gradient(embedding -> trimap_loss(embedding, triplets, weights), y)
    @inbounds for i in 1:maxiter
        gamma = if i > SWITCH_ITER
            FINAL_MOMENTUM
        else
            INIT_MOMENTUM
        end
        # Note: this is also slower with ForwardDiff
        grad = embedding_gradient(Y .+ gamma .* vel)[1]
        Y, gain, vel = update_embedding(Y, grad, vel, gain, lr, i)
    end

    return Trimap{nntype,T}(d, NN, Y)
end
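
This is roughly how I call and time it (random data with the same shape as my flattened MNIST subset, just to show the measurement):

X = rand(26 * 26, 300)     # 300 samples, flattened 26x26 images
@time model = trimap(X)    # @time also reports the total allocations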

And the loss function is defined as

function trimap_loss(embedding, # :: AbstractMatrix{T},
                     triplets :: AbstractMatrix{Int},
                     weights :: AbstractVector{T}) where {T <: Real}

    # points chosen for the triplets
    pointsI = embedding[:, triplets[1, :]]
    # points close to the points in I
    pointsJ = embedding[:, triplets[2, :]]
    # points not in the neighborhood of point i
    pointsK = embedding[:, triplets[3, :]]
    # evaluate the distances
    sim_distances = 1.0 .+ squared_euclidean_distance(pointsI, pointsJ)
    out_distances = 1.0 .+ squared_euclidean_distance(pointsI, pointsK)
    vv = @. weights / (1.0 + out_distances / sim_distances)
    loss = Statistics.mean(vv)
    return loss
end
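
For context, the per-triplet loss from the paper, as I understand it, is

$$
\ell_{ijk} = w_{ijk} \, \frac{s(y_i, y_k)}{s(y_i, y_j) + s(y_i, y_k)},
\qquad
s(y_a, y_b) = \frac{1}{1 + \lVert y_a - y_b \rVert^2},
$$

which, after dividing numerator and denominator by s(y_i, y_j) * s(y_i, y_k), is exactly the weights / (1.0 + out_distances / sim_distances) expression above; the code then takes the mean over all triplets. The squared distances come from this helper: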

function squared_euclidean_distance(x1 :: AbstractMatrix{T},
                                    x2 :: AbstractMatrix{T}) where {T <: Real}
    @fastmath nn = @. (x1 - x2)^2
    # sum over the coordinates (rows) of each column, then drop the singleton
    # dimension so the result is a vector with one squared distance per point
    @fastmath vec(sum(nn, dims=1))
end
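
To see where the time and memory go, I also benchmark a single gradient evaluation of the loss in isolation, roughly like this (dummy triplets and weights, with sizes comparable to my MNIST test):

using Zygote, Statistics

n = 300                        # number of samples
m = 3_000                      # number of triplets (roughly 10 per sample)
Y0 = randn(2, n)               # a 2-D embedding, as after the PCA initialization
triplets = rand(1:n, 3, m)     # dummy (i, j, k) indices, one triplet per column
weights = rand(m)              # dummy triplet weights

# a single gradient evaluation; the full run repeats this maxiter times
@time gradient(e -> trimap_loss(e, triplets, weights), Y0)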

Is there anything I can do to improve the performance of this code? Should I still use Zygote, or are there better-suited packages for this?

I also tried ReverseDiff and ForwardDiff, but to no avail (even recording the computation in a GradientTape and reusing it did not make it faster; a sketch of that attempt is below). What do you think the issue is?
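
For completeness, this is roughly what the GradientTape attempt looked like (a sketch; Y, gamma and vel are the variables from the optimization loop above):

using ReverseDiff

f = e -> trimap_loss(e, triplets, weights)                  # triplets and weights are captured by the closure
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, Y))  # record and compile the tape once
grad = similar(Y)

# reuse the compiled tape every iteration instead of re-recording the computation
ReverseDiff.gradient!(grad, tape, Y .+ gamma .* vel)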

Best,