Converting PyTorch to Flux while keeping performance

Hey,
I am currently converting some models from PyTorch to Flux and can't figure out why I am having performance issues.
My function in PyTorch:

import torch as tc
from pytorch3d.ops import knn_gather  # assuming knn_gather comes from PyTorch3D

def stiffness_reg(x, idx):
    x_nn = knn_gather(x, idx)  # basically indexing a point cloud
    norm1 = tc.linalg.norm(x_nn[:, :, 0, :] - x_nn[:, :, 1, :], dim=2)
    norm2 = tc.linalg.norm(x_nn[:, :, 0, :] - x_nn[:, :, 2, :], dim=2)
    norm3 = tc.linalg.norm(x_nn[:, :, 0, :] - x_nn[:, :, 3, :], dim=2)

    norm = tc.sum(norm1+norm2+norm3)
    return norm

And in Julia

function stiffness_reg(deformation, idx)
    deformation_indexed = deformation[idx, :] # indexing a point cloud
    norm1 = (deformation_indexed[:, 1, :] .- deformation_indexed[:, 2, :])
    norm2 = (deformation_indexed[:, 1, :] .- deformation_indexed[:, 3, :])
    norm3 = (deformation_indexed[:, 1, :] .- deformation_indexed[:, 4, :])

    # norm1 = sqrt.(reduce((x, y) -> x+abs(y)^2, norm1; dims=2, init=0.0f0))
    # norm2 = sqrt.(reduce((x, y) -> x+abs(y)^2, norm2; dims=2, init=0.0f0))
    # norm3 = sqrt.(reduce((x, y) -> x+abs(y)^2, norm3; dims=2, init=0.0f0))

    norm1 = sum(norm1, dims=2) # I know this is not the same as taking a norm
    norm2 = sum(norm2, dims=2)
    norm3 = sum(norm3, dims=2)

    sum(norm1.+norm2.+norm3)
end

Now keep in mind that there are differences, as the commented-out version doesn't work in Flux at all. But even using a plain sum instead takes ~4-5 times longer than in PyTorch.
That is in contrast to the base loss function:

loss_fn(new_points, close_points) = sqrt(sum((new_points - close_points).^2))

which is approx 1.5 times faster than in PyTorch.

What could I be doing wrong? I've searched for a long time and still haven't found an answer.
I've read that the sum() and reduce() functions are not optimized in Flux, but then how else am I supposed to do this?


May I ask where you read this? They most certainly are optimized, and that has nothing to do with Flux. What do you see when benchmarking both the PyTorch version (using torch.utils.benchmark, see the PyTorch Benchmark tutorial) and the Julia version (using BenchmarkTools.jl)?

Edit: this is what I get locally, on CPU:

julia> x, y = ntuple(_ -> rand(Float32, 128, 128), 2);

julia> @benchmark gradient(loss_fn, $x, $y)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  31.691 μs …  1.276 ms  ┊ GC (min … max):  0.00% … 91.96%
 Time  (median):     34.208 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   41.239 μs ± 83.340 μs  ┊ GC (mean ± σ):  13.99% ±  6.72%

  ▃▆█▆▁                                                       ▁
  ███████▇▆▄▅▆▆▄▅▅▆▇▆▆▅▅▅▅▅▄▄▃▅▅▅▃▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▆█▇▇█ █
  31.7 μs      Histogram: log(frequency) by time      87.9 μs <

vs:

import torch
import torch.utils.benchmark as benchmark

def loss_fn(x, y):
    l = ((x - y) ** 2).sum().sqrt()
    return torch.autograd.grad([l], [x, y])  # same result with l.backward()

x, y = [torch.rand(128, 128, requires_grad=True) for _ in range(2)]

t = benchmark.Timer(stmt="loss_fn(x, y)", setup="from __main__ import loss_fn", globals={"x": x, "y": y})
>>> t.timeit(10_000)
<torch.utils.benchmark.utils.common.Measurement object at 0x7ff5cec3dac0>
loss_fn(x, y)
setup: from __main__ import loss_fn
  93.99 us
  1 measurement, 10000 runs , 1 thread

Yes, the loss_fn timing is similar to what I get.
I might just have misunderstood you in a previous post, but I got that from here: Zygote Performance (Again…) - General Usage - JuliaLang

Back to the topic at hand, the long and short of it seems to be that there aren’t optimized adjoints in place for mapreduce-like functions (including sum(f, ...) ) yet
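
I guess one way to check whether that still applies is to time the two spellings of the same reduction side by side (a minimal sketch, assuming Zygote and BenchmarkTools are loaded):

using Zygote, BenchmarkTools

x = rand(Float32, 128, 128)

# mapreduce-style form: sum(f, xs)
@btime gradient(x -> sum(abs2, x), $x);

# broadcast-then-sum form: sum(f.(xs))
@btime gradient(x -> sum(abs2.(x)), $x);

If the quoted claim holds for the installed version, the first form should be noticeably slower.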

Now I have rewritten the function in Julia using Tullio (as also recommended by someone):

using Tullio, LinearAlgebra  # norm comes from LinearAlgebra; grad=Dual may also need ForwardDiff loaded

rowmap(f, xs) = @tullio ys[r] := (f)(xs[r,c]) grad=Dual

function stiffness_reg(deformation, idx)

    deformation_indexed = deformation[idx, :]
    deformations1 = (deformation_indexed[:, 1, :] .- deformation_indexed[:, 2, :])
    deformations2 = (deformation_indexed[:, 1, :] .- deformation_indexed[:, 3, :])
    deformations3 = (deformation_indexed[:, 1, :] .- deformation_indexed[:, 4, :])

    norm_vec1 = rowmap(norm, deformations1)
    norm_vec2 = rowmap(norm, deformations2)
    norm_vec3 = rowmap(norm, deformations3)

    sum(norm_vec1+norm_vec2+norm_vec3)
end

And the benchmarks are:
PyTorch

<torch.utils.benchmark.utils.common.Measurement object at 0x00000133FCB28DC0>
reg(x, y)
setup: from __main__ import reg
  6.68 ms
  1 measurement, 1000 runs , 1 thread

Julia

@benchmark grad = gradient(params) do
           new_points = register(data_src)
           loss = stiffness_reg(register.vec, vec_idxs)
           return loss
       end
BenchmarkTools.Trial: 14 samples with 1 evaluation.
 Range (min … max):   89.977 ms …    1.889 s  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     127.545 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   462.202 ms ± 683.802 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▃
  ██▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▄▁▄ ▁
  90 ms            Histogram: frequency by time          1.89 s <

 Memory estimate: 101.66 KiB, allocs estimate: 2009.

Yes, but note here how you’re calling sum(xs) and not sum(f, xs).

You may need to import LoopVectorization in order for Tullio to generate a fully optimized kernel. More importantly, I would extract deformation_indexed[:, 1, :] into its own local variable to potentially save on a lot of compute/memory overhead.
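
For concreteness, the extraction might look like this (a sketch, untested, reusing the rowmap kernel from above; as far as I know LoopVectorization has to be loaded before the @tullio kernel is defined for it to take effect):

using LoopVectorization  # load before defining the @tullio rowmap kernel
using LinearAlgebra: norm

function stiffness_reg(deformation, idx)
    deformation_indexed = deformation[idx, :]
    anchor = deformation_indexed[:, 1, :]   # slice the anchor points once instead of three times

    norm_vec1 = rowmap(norm, anchor .- deformation_indexed[:, 2, :])
    norm_vec2 = rowmap(norm, anchor .- deformation_indexed[:, 3, :])
    norm_vec3 = rowmap(norm, anchor .- deformation_indexed[:, 4, :])

    sum(norm_vec1 .+ norm_vec2 .+ norm_vec3)
end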

Also, what is register? It seems like there is more code here that may have an influence on performance (e.g. if register is a mutable struct), so a MWE would be much appreciated.


Thank you for the tips, I will for sure try them.

As for register(), it's a simple Flux model:

struct Register{A<:AbstractArray}
    vec::A
end

Flux.@functor Register  # makes trainable

function (a::Register)(x::AbstractArray)
    x .+ a.vec
end

It is then called in a simple for loop with some nearest neighbour searches (those are 3 times more performant than in PyTorch, so I don't suspect them of being the bottleneck).
Unfortunately I can't share the exact code, as it's something I am adapting from work. But outside of this function and the knn searches (based on the terrific KristofferC/NearestNeighbors.jl: High performance nearest neighbor data structures and algorithms for Julia. (github.com)), not much is happening computation-wise.
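
The search itself is roughly along these lines (a simplified stand-in, not the real code; nearest_points_estimation is just a placeholder name):

using NearestNeighbors

# Simplified stand-in: build a KD-tree over the target cloud and look up
# the 4 nearest neighbours of every query point.
function nearest_points_estimation(target_points, query_points; k = 4)
    tree = KDTree(target_points)                  # target_points is a 3×N matrix
    idxs, _ = knn(tree, query_points, k, true)    # indices sorted by distance
    close_points = target_points[:, first.(idxs)] # closest target point per query
    return close_points, idxs
end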
This is the simplified training loop:

for i = 1:1000
    grad = gradient(params) do
        new_points = register_model(data_src)

        # return the knn results from the ignore block so they are visible outside it
        close_points, vec_idxs = Zygote.ignore() do
            nearest_points_estimation()
        end

        loss = loss_fn(new_points, close_points) + stiffness_reg(register_model.vec, vec_idxs)
        return loss
    end
    Flux.update!(opt, params, grad)
end

For the normalization, I followed the suggestion at Functions for normalizing vectors · Issue #12047 · JuliaLang/julia · GitHub. Try norm_vec1 = sqrt.(sum(abs2, deformations1, dims=2)).
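
Dropping that into the earlier function gives something like this (a sketch, keeping the names and the anchor-extraction idea from the previous posts):

function stiffness_reg(deformation, idx)
    deformation_indexed = deformation[idx, :]
    anchor = deformation_indexed[:, 1, :]

    # row-wise 2-norms via sqrt.(sum(abs2, ...)), as suggested above
    norm_vec1 = sqrt.(sum(abs2, anchor .- deformation_indexed[:, 2, :], dims=2))
    norm_vec2 = sqrt.(sum(abs2, anchor .- deformation_indexed[:, 3, :], dims=2))
    norm_vec3 = sqrt.(sum(abs2, anchor .- deformation_indexed[:, 4, :], dims=2))

    sum(norm_vec1 .+ norm_vec2 .+ norm_vec3)
end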

It really helped, actually. (I also tried some naive pre-allocation, but it didn't really do anything.)

Previous one with Tullio:

julia> @benchmark grad = gradient(params) do
           new_points = register(data_src)
           loss = stiffness_reg(deformations1, deformations2, deformations3, norm_vec1, norm_vec2, norm_vec3, register.vec, vec_idxs)
           return loss
       end
BenchmarkTools.Trial: 27 samples with 1 evaluation.
 Range (min … max):   76.742 ms …    1.171 s  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):      80.287 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   213.052 ms ± 310.449 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █
  █▅▅▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▃▁▁▁▁▁▃ ▁
  76.7 ms          Histogram: frequency by time          1.17 s <

 Memory estimate: 104.22 KiB, allocs estimate: 2020.

Current one with the proposed solution

julia> @benchmark grad = gradient(params) do
           new_points = register(data_src)
           loss = stiffness_reg(deformations1, deformations2, deformations3, norm_vec1, norm_vec2, norm_vec3, register.vec, vec_idxs)
           return loss
       end
BenchmarkTools.Trial: 154 samples with 1 evaluation.
 Range (min … max):  14.629 ms …    1.028 s  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     15.123 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   32.863 ms ± 105.738 ms  ┊ GC (mean ± σ):  0.33% ± 2.91%

  █
  █▅▁▄▄▁▁▁▄▁▁▁▁▁▄▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄ ▄
  14.6 ms       Histogram: log(frequency) by time       581 ms <

 Memory estimate: 110.03 KiB, allocs estimate: 1968.

I'll see if it can be optimised even further somewhere, but thank you for your time.


One more advanced trick: you can define a function like Flux.jl/recurrent.jl at v0.13.3 · FluxML/Flux.jl · GitHub, which divides an array of fixed size along one dim into independent views (4 in your case). The rrule is an optimization which reduces the 4 x length(deformation_indexed) allocations + accumulations you'd ordinarily get from repeating deformation_indexed[:, i, :] 4 times to just 1.
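
As a rough illustration of the idea (not the actual Flux code; splitdim2 is a made-up name and this is untested):

using ChainRulesCore

# Split a 3-d array into independent views along its second dimension
# (one view per neighbour column).
splitdim2(x::AbstractArray{<:Any,3}) = ntuple(i -> view(x, :, i, :), size(x, 2))

function ChainRulesCore.rrule(::typeof(splitdim2), x::AbstractArray{<:Any,3})
    function splitdim2_pullback(dys)
        # Accumulate every slice cotangent into one buffer instead of
        # materialising a full-size gradient array per slice.
        dx = zero(x)
        for (i, dy) in enumerate(dys)
            (dy === nothing || dy isa AbstractZero) && continue
            view(dx, :, i, :) .+= unthunk(dy)
        end
        return (NoTangent(), dx)
    end
    return splitdim2(x), splitdim2_pullback
end

With d1, d2, d3, d4 = splitdim2(deformation_indexed), the differences d1 .- d2 etc. then read from the same parent array instead of slicing it four times.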
