Hello, I was trying to implement TriMAP using Zygote’s gradient function, but it is very slow and the data is very high dimensional, is there any tip to improve the performance of the `gradient`

function?

The following function is where the optimization happens and it’s slowing everything down taking more than > 11GB for 300 samples in 26x26 dimension (MNIST dataset), and similar results maybe slightly less using the Swiss Role so it’s not really a dimensional issue but a sample issue,

```
function trimap(X :: AbstractMatrix{T},
maxoutdim :: Integer=2,
maxiter :: Integer=400,
initialize:: Symbol=:pca,
lr :: T = 0.5,
weight_temp :: T = 0.5,
m₁ :: Int = 10,
m₂ :: Int = 6,
nntype = ManifoldLearning.BruteForce) where {T <: Real}
"""
Implements the routine as described in
https://arxiv.org/abs/1910.00204
"""
d, n = size(X)
Y = if initialize == :pca
predict(fit(ManifoldLearning.PCA, X, maxoutdim=maxoutdim), X)
elseif initialize == :random
rand(T, maxoutdim, n)
else error("Unknown initialization")
end
# Neareest neighbors
NN = fit(nntype, X)
# initialize triplets and weights
triplets, weights = generate_triplets(X, NN, m₁, m₂, weight_temp)
# Optimization of the embedding
gain = zeros(T, size(Y))
vel = zeros(T, size(Y))
embedding_gradient(y) = gradient(embedding -> trimap_loss(embedding, triplets, weights), y)
@inbounds for i in 1:maxiter
local gamma = if i > SWITCH_ITER
FINAL_MOMENTUM
else
INIT_MOMENTUM
end
# Note that also with ForwardDiff it is slower
grad = embedding_gradient(Y .+ gamma .* vel)[1]
Y, gain, vel = update_embedding(Y, grad, vel, gain, lr, i)
end
return Trimap{nntype,T}(d, NN, Y)
```

And the loss function is defined as

```
function trimap_loss(embedding,# :: AbstractMatrix{T},
triplets :: AbstractMatrix{Int},
weights :: AbstractVector{T}) where {T <: Real}
# points chosen for the triplets
pointsI = embedding[:, triplets[1, :]]
# points close to the points in I
pointsJ = embedding[:, triplets[2, :]]
# points not in the neighborhood of points i
pointsK = embedding[:, triplets[3, :]]
# evaluate the distances
sim_distances = 1.0 .+ squared_euclidean_distance(pointsI, pointsJ)
out_distances = 1.0 .+ squared_euclidean_distance(pointsI, pointsK)
vv = @. weights / (1.0 + out_distances / sim_distances)
loss = Statistics.mean(vv)
return loss
end
function squared_euclidean_distance(x1 :: AbstractMatrix{T},
x2 :: AbstractMatrix{T}) where {T <: Real}
@fastmath nn = @. (x1 - x2)^2
@fastmath sum(nn,dims=1) # here we sum by row, so that the end result
# is a vector with each 'row' having a value
end
```

Is there anything I can do to improve the performance of this code? Should I still use Zygote, or are there better-suited packages for this?

I tried ReverseDiff and ForwardDiff too, but still to no avail (even storing the results in a GradientTape doesn’t work better). What do you think is the issue?

Best,