Hey,
I am currently converting some models from PyTorch to Flux and can't figure out why I am having performance issues.
My function in PyTorch:
import torch as tc
from pytorch3d.ops import knn_gather  # assuming knn_gather comes from PyTorch3D

def stiffness_reg(x, idx):
    x_nn = knn_gather(x, idx)  # basically indexing a point cloud
    norm1 = tc.linalg.norm(x_nn[:, :, 0, :] - x_nn[:, :, 1, :], dim=2)
    norm2 = tc.linalg.norm(x_nn[:, :, 0, :] - x_nn[:, :, 2, :], dim=2)
    norm3 = tc.linalg.norm(x_nn[:, :, 0, :] - x_nn[:, :, 3, :], dim=2)
    norm = tc.sum(norm1 + norm2 + norm3)
    return norm
And in Julia:
function stiffness_reg(deformation, idx)
    deformation_indexed = deformation[idx, :]  # indexing a point cloud
    norm1 = deformation_indexed[:, 1, :] .- deformation_indexed[:, 2, :]
    norm2 = deformation_indexed[:, 1, :] .- deformation_indexed[:, 3, :]
    norm3 = deformation_indexed[:, 1, :] .- deformation_indexed[:, 4, :]
    # norm1 = sqrt.(reduce((x, y) -> x + abs(y)^2, norm1; dims=2, init=0.0f0))
    # norm2 = sqrt.(reduce((x, y) -> x + abs(y)^2, norm2; dims=2, init=0.0f0))
    # norm3 = sqrt.(reduce((x, y) -> x + abs(y)^2, norm3; dims=2, init=0.0f0))
    norm1 = sum(norm1, dims=2)  # I know this is not the same as taking a norm
    norm2 = sum(norm2, dims=2)
    norm3 = sum(norm3, dims=2)
    sum(norm1 .+ norm2 .+ norm3)
end
Now keep in mind that there are differences, since the commented-out version doesn't work in Flux at all (a sketch of the row-wise norm I was actually aiming for is below). But even with the plain sum in its place, the Julia version takes roughly 4-5 times longer than in PyTorch.
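For reference, what those commented-out reduce calls were trying to express is a per-row Euclidean norm. The closest Zygote-friendly formulation I know of uses sum(abs2, ...; dims=2), roughly like this (just a sketch with a hypothetical name, and I haven't checked whether it is actually any faster):

function stiffness_reg_norm(deformation, idx)
    d = deformation[idx, :]                   # (N, K, 3) after indexing with an (N, K) idx
    rownorm(A) = sqrt.(sum(abs2, A; dims=2))  # row-wise 2-norm; sum(abs2, ...; dims) should have a ChainRules gradient rule
    n1 = rownorm(d[:, 1, :] .- d[:, 2, :])
    n2 = rownorm(d[:, 1, :] .- d[:, 3, :])
    n3 = rownorm(d[:, 1, :] .- d[:, 4, :])
    return sum(n1 .+ n2 .+ n3)
end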
That is in contrast to the base loss function:
loss_fn(new_points, close_points) = sqrt(sum((new_points - close_points).^2))
which is approx 1.5 times faster than in PyTorch.
What could I be doing wrong? I've searched for a long time and still haven't found an answer.
I've read that the sum() and reduce() functions are not well optimized in Flux, but then how else am I supposed to do this?
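For completeness, this is roughly how the regularizer is used on my side; the array sizes and the direct Flux.gradient call below are only placeholders for my actual model and training loop:

using Flux  # Flux re-exports Zygote's gradient

deformation = rand(Float32, 1000, 3)  # placeholder point cloud, not my real data
idx = rand(1:1000, 1000, 4)           # placeholder 4-nearest-neighbour indices
grads = Flux.gradient(d -> stiffness_reg(d, idx), deformation)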