Yes, the loss_fn is similar to what I get.
I may have just misunderstood you in a previous post, but I got that from here: Zygote Performance (Again…) - General Usage - JuliaLang
Back to the topic at hand: the long and short of it seems to be that there aren't optimized adjoints in place yet for mapreduce-like functions (including `sum(f, ...)`).
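For reference, `sum(f, xs)` fuses the map into the reduction, while `sum(f.(xs))` first materializes `f.(xs)` by broadcasting and then applies a plain `sum`. Splitting the reduction this way is one workaround sometimes suggested when the fused form lacks a fast adjoint, since the gradient then goes through Zygote's broadcasting rules; whether it actually helps here is an assumption worth benchmarking. This minimal sketch (with a made-up `f`) only checks that the two forms compute the same value:

```julia
# Hypothetical element-wise function, for illustration only.
f(x) = x^2 + sin(x)

fused(xs)   = sum(f, xs)    # mapreduce form: no intermediate array
unfused(xs) = sum(f.(xs))   # broadcast first, then a plain sum

xs = collect(range(0.0, 1.0; length = 5))
println(fused(xs) ≈ unfused(xs))   # true
```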
I have now rewritten the function in Julia using Tullio (which someone also recommended):
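```julia
using Tullio, ForwardDiff   # grad=Dual relies on ForwardDiff
using LinearAlgebra         # for norm

# c appears only on the right, so Tullio sums over it: ys[r] = Σ_c f(xs[r, c]).
rowmap(f, xs) = @tullio ys[r] := (f)(xs[r,c]) grad=Dual

function stiffness_reg(deformation, idx)
    deformation_indexed = deformation[idx, :]   # size(idx)..., size(deformation, 2)
    # differences between the first vertex and each of the other three
    deformations1 = (deformation_indexed[:, 1, :] .- deformation_indexed[:, 2, :])
    deformations2 = (deformation_indexed[:, 1, :] .- deformation_indexed[:, 3, :])
    deformations3 = (deformation_indexed[:, 1, :] .- deformation_indexed[:, 4, :])
    norm_vec1 = rowmap(norm, deformations1)
    norm_vec2 = rowmap(norm, deformations2)
    norm_vec3 = rowmap(norm, deformations3)
    sum(norm_vec1 + norm_vec2 + norm_vec3)
end
```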
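Because `c` appears only on the right-hand side of the `@tullio` expression, Tullio sums over it, and `norm` applied to a single number returns its absolute value. So `rowmap(norm, xs)` gives the sum of absolute values of each row (an L1 norm), not the Euclidean length of the row. If the Euclidean length is what the regularizer should use (an assumption about intent), this sketch contrasts the two using only Base and LinearAlgebra; `rowmap_plain` and `rownorm2` are hypothetical names:

```julia
using LinearAlgebra

# Plain-Julia emulation of rowmap(f, xs): ys[r] = Σ_c f(xs[r, c]).
rowmap_plain(f, xs) = vec(sum(f.(xs); dims = 2))

# Hypothetical Euclidean alternative: the 2-norm of each row.
rownorm2(xs) = vec(sqrt.(sum(abs2, xs; dims = 2)))

xs = [3.0 -4.0 0.0;
      1.0  0.0 2.0]

println(rowmap_plain(norm, xs))  # [7.0, 3.0]: norm of a scalar is abs, so a row-wise L1 norm
println(rownorm2(xs))            # row-wise Euclidean norm: 5.0 and √5
```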
And the benchmarks are:
PyTorch:
```
<torch.utils.benchmark.utils.common.Measurement object at 0x00000133FCB28DC0>
reg(x, y)
setup: from __main__ import reg
  6.68 ms
  1 measurement, 1000 runs , 1 thread
```
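Julia:
```julia
using BenchmarkTools, Zygote

@benchmark grad = gradient(params) do
    new_points = register(data_src)   # forward pass (its result is not used by the loss below)
    loss = stiffness_reg(register.vec, vec_idxs)
    return loss
end
```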
```
BenchmarkTools.Trial: 14 samples with 1 evaluation.
 Range (min … max):   89.977 ms …    1.889 s  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     127.545 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   462.202 ms ± 683.802 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▃
  ██▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▄▁▄ ▁
  90 ms           Histogram: frequency by time          1.89 s <

 Memory estimate: 101.66 KiB, allocs estimate: 2009.
```
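As a cross-check for both the value and the timing, a Base-only variant of `stiffness_reg` could be swapped into the same `gradient` benchmark. This is a sketch under assumptions: `deformation` is an N×3 matrix, `idx` is an M×4 integer matrix whose first column is a vertex and whose other columns are its neighbours, and each term is meant to be the Euclidean length of a difference vector. `stiffness_reg_ref` is a hypothetical name; it avoids array mutation, which Zygote does not support.

```julia
# Hypothetical Base-only variant of stiffness_reg.
# Assumes deformation is N×3 and idx is an M×4 integer matrix
# (column 1 = vertex, columns 2:4 = its neighbours).
function stiffness_reg_ref(deformation::AbstractMatrix, idx::AbstractMatrix{<:Integer})
    d = deformation[idx, :]                 # M×4×3 array
    total = zero(eltype(deformation))
    for k in 2:size(d, 2)
        diff = d[:, 1, :] .- d[:, k, :]     # M×3 differences to neighbour k
        # Euclidean length of each row, summed over all rows
        total += sum(sqrt.(sum(abs2, diff; dims = 2)))
    end
    return total
end

# Toy example: one vertex (row 1) with three neighbours (rows 2 to 4).
deformation = [0.0 0.0 0.0;
               3.0 4.0 0.0;
               1.0 0.0 0.0;
               0.0 0.0 2.0]
idx = [1 2 3 4]

println(stiffness_reg_ref(deformation, idx))   # 8.0 (= 5 + 1 + 2)
```

For the same toy input, the Tullio-based `stiffness_reg` would return 10.0, since it sums absolute values rather than Euclidean lengths. Also note that the derivative of `sqrt` is unbounded at zero, so identical rows would produce non-finite gradients; adding a small epsilon inside the `sqrt` is a common guard.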