Converting PyTorch to Flux while keeping performance

Yes, the loss_fn is similar to what I get.
I might have just misunderstood you in a previous post, but I got that from here: Zygote Performance (Again…) - General Usage - JuliaLang

Back to the topic at hand, the long and short of it seems to be that there aren't optimized adjoints in place for mapreduce-like functions (including sum(f, ...)) yet.
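For anyone hitting the same wall, a common workaround (a sketch, assuming Zygote; `abs2` is just a stand-in loss here) is to broadcast first and then reduce, since the broadcast adjoint is well optimized while `sum(f, xs)` falls back to the generic mapreduce path:

```julia
using Zygote  # the reverse-mode AD backend used by Flux

x = randn(1_000)

# sum(f, x) goes through the generic (slower) mapreduce adjoint
g_slow = gradient(v -> sum(abs2, v), x)[1]

# sum(f.(x)) materializes the broadcast, which has a fast adjoint
g_fast = gradient(v -> sum(abs2.(v)), x)[1]

g_slow ≈ g_fast  # same gradient either way, only the speed differs
```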

Now I've rewritten the function in Julia using Tullio (also recommended by someone):

rowmap(f, xs) = @tullio ys[r] := (f)(xs[r,c]) grad=Dual
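One thing worth noting about this definition: @tullio sums over the repeated index c, so rowmap(f, xs) computes the sum of f over each row's elements, not f applied to the row as a vector. Since norm of a scalar is its absolute value, rowmap(norm, xs) gives each row's 1-norm; if the PyTorch version took the Euclidean norm per row, that's a semantic difference to double-check. A stdlib-only reference makes this visible:

```julia
using LinearAlgebra: norm

# Reference for what the @tullio rowmap computes: sum f over each row
rowmap_ref(f, xs) = [sum(f, row) for row in eachrow(xs)]

xs = [3.0 -4.0;
      1.0  2.0]

rowmap_ref(norm, xs)                # == [7.0, 3.0] — per-row sum of abs (1-norm)
[norm(row) for row in eachrow(xs)]  # == [5.0, √5]  — per-row Euclidean norm, for comparison
```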

function stiffness_reg(deformation, idx)

    deformation_indexed = deformation[idx, :]
    deformations1 = (deformation_indexed[:, 1, :] .- deformation_indexed[:, 2, :])
    deformations2 = (deformation_indexed[:, 1, :] .- deformation_indexed[:, 3, :])
    deformations3 = (deformation_indexed[:, 1, :] .- deformation_indexed[:, 4, :])

    norm_vec1 = rowmap(norm, deformations1)
    norm_vec2 = rowmap(norm, deformations2)
    norm_vec3 = rowmap(norm, deformations3)

    sum(norm_vec1 .+ norm_vec2 .+ norm_vec3)
end
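To make the shapes I'm assuming explicit (my reading of the code, not stated in the original post): deformation is an (n_points × 3) matrix and idx an (n_edges × 4) index matrix, so deformation[idx, :] is (n_edges × 4 × 3). A stdlib-only looped reference, handy for checking the Tullio version's output, might look like:

```julia
using LinearAlgebra: norm

# Loop-based reference; assumes deformation :: (n_points × 3)
# and idx :: (n_edges × 4), each row listing a point and its 3 neighbours.
function stiffness_reg_ref(deformation, idx)
    total = 0.0
    for e in axes(idx, 1)
        p1 = @view deformation[idx[e, 1], :]
        for k in 2:4
            pk = @view deformation[idx[e, k], :]
            # per-element abs summed, matching rowmap(norm, ...) above
            total += sum(abs, p1 .- pk)
        end
    end
    return total
end

deformation = randn(10, 3)
idx = rand(1:10, 5, 4)
stiffness_reg_ref(deformation, idx)  # scalar regularization value
```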

And the benchmarks are:
PyTorch

<torch.utils.benchmark.utils.common.Measurement object at 0x00000133FCB28DC0>
reg(x, y)
setup: from __main__ import reg
  6.68 ms
  1 measurement, 1000 runs , 1 thread

Julia

@benchmark grad = gradient(params) do
           new_points = register(data_src)
           loss = stiffness_reg(register.vec, vec_idxs)
           return loss
       end
BenchmarkTools.Trial: 14 samples with 1 evaluation.
 Range (min … max):   89.977 ms …    1.889 s  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     127.545 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   462.202 ms ± 683.802 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▃            
  ██▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▄▁▄ ▁
  90 ms            Histogram: frequency by time          1.89 s <

 Memory estimate: 101.66 KiB, allocs estimate: 2009.