Flux LayerNorm slower than pytorch?

I accidentally closed a similar topic I opened about an hour ago, so here I am trying again with a clearer example. The issue is that the same LayerNorm layer shows a large performance difference between PyTorch and Flux, and I don’t know whether that is expected or not.

If anybody could clarify this, or tell me where to look for an answer, I would be grateful. Here are two minimal working examples and the resulting benchmarks (both run in Pluto, in case that affects the results):

In PyTorch (called from Julia via PythonCall):

using PythonCall
using BenchmarkTools

torch = pyimport("torch")
np = pyimport("numpy")

l = torch.nn.LayerNorm(768)
v = torch.tensor(
    np.array(
        rand(Float32, 2, 128, 768)
    )
)

@benchmark l(v)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 97.483 μs … 3.871 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 135.234 μs ┊ GC (median): 0.00%
Time (mean ± σ): 138.385 μs ± 42.160 μs ┊ GC (mean ± σ): 0.00% ± 0.00%

Memory estimate: 16 bytes, allocs estimate: 1.

In Flux:

using Flux
using BenchmarkTools

jl = Flux.LayerNorm(768)
jv = rand(Float32, 768, 128, 2)

@benchmark jl(jv)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 286.678 μs … 6.144 ms ┊ GC (min … max): 0.00% … 91.80%
Time (median): 322.035 μs ┊ GC (median): 0.00%
Time (mean ± σ): 341.300 μs ± 324.158 μs ┊ GC (mean ± σ): 5.36% ± 5.33%

Memory estimate: 2.25 MiB, allocs estimate: 24.


Flux’s LayerNorm just calls https://github.com/FluxML/Flux.jl/blob/v0.12.9/src/layers/stateless.jl#L36, so if you have any ideas for optimizing it please do file a PR :slight_smile:
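
For reference, that function is roughly the following (paraphrased from the linked source, not copied exactly):

# Roughly what Flux.normalise does at that version: one pass for the mean,
# a separate pass for the (uncorrected) std, then a fused broadcast.
function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5))
    μ = mean(x, dims=dims)
    σ = std(x, dims=dims, corrected=false)  # the comment in the source notes this could reuse μ
    return (x .- μ) ./ (σ .+ ϵ)
end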

Ok, thanks! I didn’t go through the code because I thought there was no way I could understand it, but I guess I could have. Then I guess that’s just how things are!

As the comment there says, σ = std(x, dims=dims, corrected=false) could reuse the mean, since this rule accepts it (as does this one).

Out of curiosity I tried, but just replacing the normalise in the Flux example above with

using Statistics  # for mean and std

function normalise(x::AbstractArray; dims=ndims(x), ϵ=Flux.ofeltype(x, 1e-5))
    μ = mean(x, dims=dims)
    σ = std(x, dims=dims, mean=μ, corrected=false)
    return (x .- μ) ./ (σ .+ ϵ)
end

# jl = Flux.LayerNorm(768)
jv = rand(Float32, 768, 128, 2)

@benchmark normalise(jv)

gives very similar results, in fact a bit worse. Would it make a difference if I integrated the modified normalise into a struct using @functor (something like the sketch below)?
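
For concreteness, this is what I have in mind (just a sketch; MyNorm and its fields are made-up names, and @functor only registers the struct’s fields with Flux):

using Flux, Statistics

# Hypothetical wrapper layer around the modified normalise, only to illustrate
# the struct + @functor pattern; it has no trainable parameters.
struct MyNorm{T}
    ϵ::T
end
Flux.@functor MyNorm

# Forward pass: normalise over the first dimension, reusing the mean for std.
function (m::MyNorm)(x::AbstractArray; dims=1)
    μ = mean(x, dims=dims)
    σ = std(x, dims=dims, mean=μ, corrected=false)
    return (x .- μ) ./ (σ .+ m.ϵ)
end

# Usage: MyNorm(1f-5)(jv)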

No, I don’t think so. But the right comparison for the mean change is probably this, a small improvement:

julia> @btime normalise($jv; dims=1:1);
  min 240.000 μs, mean 322.445 μs (14 allocations, 770.50 KiB)

julia> @btime Flux.normalise($jv; dims=1:1);
  min 265.708 μs, mean 358.654 μs (20 allocations, 771.72 KiB)

Note that there are other steps in applying the layer; see @less jl(jv). There might be room to optimise these further.

Which package is @less from? Sorry, never mind, I found it.

I did find @less, but I don’t understand what you meant by pointing to it. Is there something in the output of @less that you think would be useful?

@less just shows you the source code and location of the method being called. In this case, https://github.com/FluxML/Flux.jl/blob/master/src/layers/normalise.jl#L175-L178. I think Michael’s point is that more is going on than just normalise, as evidenced by https://github.com/FluxML/Flux.jl/blob/master/src/layers/normalise.jl#L177. Given you mentioned above that just normalise is slower than the entire PyTorch LayerNorm forward pass, that function is still very much a culprit.
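
For anyone following along, the method that @less jl(jv) jumps to looks roughly like this (paraphrasing the linked lines, not the exact source):

# Paraphrased LayerNorm forward pass: normalise over the leading dims,
# then apply the learnable elementwise scale/shift stored in a.diag.
function (a::LayerNorm)(x)
    a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
end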

Thanks so much for the explanation, I would never have understood the point on my own. I will look into the code more closely since it is so straightforward to read, even if I won’t fully understand everything that is going on.

If I can ask, what would you say is a good starting point in the code for understanding the Flux design?

Do you mean in general or for this specific layer? For LayerNorm specifically, the implementation is not that different from what you’d write by hand using Base Julia or e.g. Numpy. .diag is extracted into its own layer type because that operation is useful outside of layer normalization. Ideally we’d like to avoid having to write a fully custom kernel like https://github.com/pytorch/pytorch/blob/ef066f0832eab3192a0610a1ecf955239c26de0b/aten/src/ATen/native/cpu/layer_norm_kernel.cpp, but if that’s not possible then such a kernel would be a welcome addition to https://github.com/FluxML/NNlib.jl.
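
To make that concrete, a hand-rolled version in plain Julia looks something like this (a sketch; layernorm_by_hand, γ and β are made-up names, and the Flux layer is essentially this normalisation with the scale/shift factored out into the reusable .diag layer):

using Statistics

# Normalise over the feature (first) dimension, then apply an elementwise
# affine transform, i.e. a learnable scale γ and shift β.
function layernorm_by_hand(x::AbstractArray, γ, β; ϵ=1f-5)
    μ = mean(x, dims=1)
    σ = std(x, dims=1, mean=μ, corrected=false)
    return γ .* (x .- μ) ./ (σ .+ ϵ) .+ β
end

x = rand(Float32, 768, 128, 2)
layernorm_by_hand(x, ones(Float32, 768), zeros(Float32, 768))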