# Converting PyTorch to Flux while keeping performance

Hey,
I am currently converting some models from PyTorch to Flux and can't figure out why I'm having performance issues.
My function in PyTorch:

```python
import torch as tc

def stiffness_reg(x, idx):
    # knn_gather basically indexes a point cloud
    # (e.g. pytorch3d.ops.knn_gather)
    x_nn = knn_gather(x, idx)
    norm1 = tc.linalg.norm(x_nn[:, :, 0, :] - x_nn[:, :, 1, :], dim=2)
    norm2 = tc.linalg.norm(x_nn[:, :, 0, :] - x_nn[:, :, 2, :], dim=2)
    norm3 = tc.linalg.norm(x_nn[:, :, 0, :] - x_nn[:, :, 3, :], dim=2)

    norm = tc.sum(norm1 + norm2 + norm3)
    return norm
```

And in Julia

```julia
function stiffness_reg(deformation, idx)
    deformation_indexed = deformation[idx, :]  # indexing a point cloud

    norm1 = deformation_indexed[:, 1, :] .- deformation_indexed[:, 2, :]
    norm2 = deformation_indexed[:, 1, :] .- deformation_indexed[:, 3, :]
    norm3 = deformation_indexed[:, 1, :] .- deformation_indexed[:, 4, :]

    # norm1 = sqrt.(reduce((x, y) -> x + abs(y)^2, norm1; dims=2, init=0.0f0))
    # norm2 = sqrt.(reduce((x, y) -> x + abs(y)^2, norm2; dims=2, init=0.0f0))
    # norm3 = sqrt.(reduce((x, y) -> x + abs(y)^2, norm3; dims=2, init=0.0f0))

    norm1 = sum(norm1, dims=2)  # I know this is not the same as taking a norm
    norm2 = sum(norm2, dims=2)
    norm3 = sum(norm3, dims=2)

    sum(norm1 .+ norm2 .+ norm3)
end
```

Now keep in mind that there are differences, as the commented-out version doesn't work in Flux at all. But even using a simple `sum` instead, it takes ~4-5 times longer than in PyTorch.
That is in contrast to the base loss function:

```julia
loss_fn(new_points, close_points) = sqrt(sum((new_points - close_points).^2))
```

which is approx 1.5 times faster than in PyTorch.

What could I be doing wrong? I've searched for a long time and still haven't found an answer.
I've read that the `sum()` and `reduce()` functions are not optimized in Flux, but then how else am I supposed to do this?


May I ask where you read this? They most certainly are optimized, and that has nothing to do with Flux. What do you see when benchmarking both the PyTorch version (using PyTorch Benchmark — PyTorch Tutorials 1.12.1+cu102 documentation) and the Julia version (using Home · BenchmarkTools.jl)?

Edit: this is what I get locally, on CPU:

```julia
julia> x, y = ntuple(_ -> rand(Float32, 128, 128), 2);

julia> @benchmark loss_fn($x, $y)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  31.691 μs …  1.276 ms  ┊ GC (min … max):  0.00% … 91.96%
 Time  (median):     34.208 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   41.239 μs ± 83.340 μs  ┊ GC (mean ± σ):  13.99% ±  6.72%

  ▃▆█▆▁                                                       ▁
  ███████▇▆▄▅▆▆▄▅▅▆▇▆▆▅▅▅▅▅▄▄▃▅▅▅▃▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▆█▇▇█ █
  31.7 μs      Histogram: log(frequency) by time      87.9 μs <
```

vs:

```python
import torch
import torch.utils.benchmark as benchmark

def loss_fn(x, y):
    return ((x - y) ** 2).sum().sqrt()

x, y = [torch.rand(128, 128, requires_grad=True) for _ in range(2)]

t = benchmark.Timer(stmt="loss_fn(x, y)", setup="from __main__ import loss_fn", globals={"x": x, "y": y})
```
```python
>>> t.timeit(10_000)
<torch.utils.benchmark.utils.common.Measurement object at 0x7ff5cec3dac0>
loss_fn(x, y)
setup: from __main__ import loss_fn
  93.99 us
  1 measurement, 10000 runs, 1 thread
```

Yes, the `loss_fn` timing is similar to what I get.
I might have just misunderstood you in a previous post, but I got that from here: Zygote Performance (Again…) - General Usage - JuliaLang:

> Back to the topic at hand, the long and short of it seems to be that there aren't optimized adjoints in place for mapreduce-like functions (including `sum(f, ...)`) yet

Now I have rewritten the function in Julia using Tullio (also recommended by someone):

```julia
using Tullio, LinearAlgebra  # norm comes from LinearAlgebra; grad=Dual needs ForwardDiff available

rowmap(f, xs) = @tullio ys[r] := (f)(xs[r, c]) grad=Dual

function stiffness_reg(deformation, idx)
    deformation_indexed = deformation[idx, :]

    deformations1 = deformation_indexed[:, 1, :] .- deformation_indexed[:, 2, :]
    deformations2 = deformation_indexed[:, 1, :] .- deformation_indexed[:, 3, :]
    deformations3 = deformation_indexed[:, 1, :] .- deformation_indexed[:, 4, :]

    norm_vec1 = rowmap(norm, deformations1)
    norm_vec2 = rowmap(norm, deformations2)
    norm_vec3 = rowmap(norm, deformations3)

    sum(norm_vec1 + norm_vec2 + norm_vec3)
end
```

And the benchmarks are:
PyTorch:

```python
<torch.utils.benchmark.utils.common.Measurement object at 0x00000133FCB28DC0>
reg(x, y)
setup: from __main__ import reg
  6.68 ms
  1 measurement, 1000 runs, 1 thread
```

Julia:

```julia
julia> @benchmark grad = gradient(params) do
           new_points = register(data_src)
           loss = stiffness_reg(register.vec, vec_idxs)
           return loss
       end
BenchmarkTools.Trial: 14 samples with 1 evaluation.
 Range (min … max):   89.977 ms …    1.889 s  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     127.545 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   462.202 ms ± 683.802 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▃
  ██▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▄▁▄ ▁
  90 ms            Histogram: frequency by time          1.89 s <

 Memory estimate: 101.66 KiB, allocs estimate: 2009.
```

Yes, but note here how you’re calling `sum(xs)` and not `sum(f, xs)`.
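To make the distinction concrete, here is a quick sketch (`abs2` chosen arbitrarily): `sum(f, xs)` goes through a generic, slow adjoint under Zygote, while materialising the broadcast first lets the fast `sum(xs)` adjoint kick in.

```julia
using Zygote

xs = rand(Float32, 128, 128)

# sum(f, xs) hits Zygote's generic (slow) adjoint for mapreduce-like calls,
# while materialising the broadcast first uses the fast sum(xs) adjoint.
g1 = gradient(x -> sum(abs2, x), xs)   # slower path
g2 = gradient(x -> sum(abs2.(x)), xs)  # usually faster path
```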

You may need to import LoopVectorization in order for Tullio to generate a fully optimized kernel. More importantly, I would extract `deformation_indexed[:, 1, :]` into its own local variable to potentially save on a lot of compute/memory overhead.
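For illustration, a sketch of that extraction applied to the Tullio version above (same names as before; `anchor` is just a new local):

```julia
function stiffness_reg(deformation, idx)
    deformation_indexed = deformation[idx, :]
    anchor = deformation_indexed[:, 1, :]  # sliced once instead of three times

    norm_vec1 = rowmap(norm, anchor .- deformation_indexed[:, 2, :])
    norm_vec2 = rowmap(norm, anchor .- deformation_indexed[:, 3, :])
    norm_vec3 = rowmap(norm, anchor .- deformation_indexed[:, 4, :])

    sum(norm_vec1 + norm_vec2 + norm_vec3)
end
```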

Also, what is `register`? It seems like there is more code here that may have an influence on performance (e.g. if `register` is a mutable struct), so a MWE would be much appreciated.


Thank you for the tips, I will for sure try them.

As for `register()`, it's a simple Flux model:

```julia
struct Register{A<:AbstractArray}
    vec::A
end

Flux.@functor Register  # makes the struct's fields trainable

function (a::Register)(x::AbstractArray)
    x .+ a.vec
end
```
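Roughly, it's used like this (a hypothetical sketch; array sizes made up):

```julia
# Hypothetical usage: a learnable per-point offset added to the input cloud.
register_model = Register(zeros(Float32, 1024, 3))
new_points = register_model(data_src)  # data_src being a 1024×3 array of points
```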

It's then called in a simple for loop together with some nearest-neighbour searches (those are about 3 times more performant than in PyTorch, so I don't suspect them of being the bottleneck).
Unfortunately I can't share the exact code as it's something I am adapting from work. But outside of this function and the knn searches (based on the terrific KristofferC/NearestNeighbors.jl: High performance nearest neighbor data structures and algorithms for Julia. (github.com)), not much is happening computation-wise.
This is the simplified train loop:

```julia
for i = 1:1000
    new_points = register_model(data_src)

    # Keep the knn search out of the AD graph; return its results
    # from the do-block so they are visible below.
    close_points, vec_idxs = Zygote.ignore() do
        nearest_points_estimation()
    end

    loss = loss_fn(new_points, close_points) + stiffness_reg(register_model.vec, vec_idxs)
end
```

For the normalization, I followed the suggestion at Functions for normalizing vectors · Issue #12047 · JuliaLang/julia · GitHub. Try `norm_vec1 = sqrt.(sum(abs2, deformations1, dims=2))`.
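Roughly, that replaces the `rowmap(norm, ...)` lines above with something like this (a sketch, reusing the same variable names):

```julia
# Row-wise Euclidean norms from AD-friendly building blocks:
# sum(abs2, ...; dims=2) and the broadcasted sqrt both have fast adjoints.
norm_vec1 = sqrt.(sum(abs2, deformations1, dims=2))
norm_vec2 = sqrt.(sum(abs2, deformations2, dims=2))
norm_vec3 = sqrt.(sum(abs2, deformations3, dims=2))
```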

It really helped, actually. (I also tried some naive pre-allocation, but it didn't really do anything.)

Previous one, with Tullio:

```julia
julia> @benchmark grad = gradient(params) do
           new_points = register(data_src)
           loss = stiffness_reg(deformations1, deformations2, deformations3, norm_vec1, norm_vec2, norm_vec3, register.vec, vec_idxs)
           return loss
       end
BenchmarkTools.Trial: 27 samples with 1 evaluation.
 Range (min … max):   76.742 ms …    1.171 s  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):      80.287 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   213.052 ms ± 310.449 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █
  █▅▅▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▃▁▁▁▁▁▃ ▁
  76.7 ms          Histogram: frequency by time          1.17 s <

 Memory estimate: 104.22 KiB, allocs estimate: 2020.
```

Current one, with the proposed solution:

```julia
julia> @benchmark grad = gradient(params) do
           new_points = register(data_src)
           loss = stiffness_reg(deformations1, deformations2, deformations3, norm_vec1, norm_vec2, norm_vec3, register.vec, vec_idxs)
           return loss
       end
BenchmarkTools.Trial: 154 samples with 1 evaluation.
 Range (min … max):  14.629 ms …    1.028 s  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     15.123 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   32.863 ms ± 105.738 ms  ┊ GC (mean ± σ):  0.33% ± 2.91%

  █
  █▅▁▄▄▁▁▁▄▁▁▁▁▁▄▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄ ▄
  14.6 ms       Histogram: log(frequency) by time       581 ms <

 Memory estimate: 110.03 KiB, allocs estimate: 1968.
```

I'll see if it can be optimised even more somewhere, but thank you for your time.


One more advanced trick: you can define a function like the one in Flux.jl/recurrent.jl at v0.13.3 · FluxML/Flux.jl · GitHub, which divides an array of fixed size along one dim into independent views (4 in your case). Its `rrule` is an optimization which reduces the 4 × `length(deformation_indexed)` allocations + accumulations you'd ordinarily get from repeating `deformation_indexed[:, i, :]` four times down to just one.
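A minimal sketch of that idea, assuming ChainRulesCore, with the name `splitdim2` made up for illustration (the helper in the linked recurrent.jl is the real precedent):

```julia
using ChainRulesCore

# Split a 3-D array into N views along its second dimension.
splitdim2(x::AbstractArray{<:Any,3}, ::Val{N}) where {N} =
    ntuple(i -> view(x, :, i, :), N)

function ChainRulesCore.rrule(::typeof(splitdim2), x::AbstractArray{<:Any,3}, v::Val{N}) where {N}
    function splitdim2_pullback(dys)
        dx = zero(x)  # a single allocation shared by all N slices
        for (i, dy) in enumerate(unthunk(dys))
            dy isa AbstractZero && continue
            view(dx, :, i, :) .+= unthunk(dy)
        end
        return (NoTangent(), dx, NoTangent())
    end
    return splitdim2(x, v), splitdim2_pullback
end

# Usage in stiffness_reg: one indexing pass instead of four, e.g.
# anchor, p2, p3, p4 = splitdim2(deformation_indexed, Val(4))
```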
