BLAS performance issues for common neural network patterns


#1

Hi,

I’ve been following Julia development on and off, and with the buzz around the latest release I decided to give it another try by writing a small neural network library. I work with Torch daily, so I gave it a similar API.

However, my code turned out to perform terribly compared to Torch on a forward/backward pass through a fully connected network with around 500,000,000 parameters:

| Parameters | Torch | Julia |
|---|---|---|
| 500,000,000 | 0.43 s | 2.84–2.98 s |
| 68,000,000 | 0.05–0.06 s | 0.35–0.36 s |
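The headline parameter count can be reproduced from the weight shapes in the mock code. This is a small sketch, assuming the mock's shapes match the real network and that the layers have no biases (the mock has none); the shapes for the 68,000,000-parameter configuration aren't given, so it isn't covered. It also shows that W1 holds over 99% of the weights, so the first layer's matrix–vector work should dominate the runtime.

```julia
# Weight shapes (rows × cols) taken from the mock network in this post.
shapes = [(2048, 512 * 512), (1024, 2048), (10, 1024)]

# A bias-free fully connected layer has rows * cols weights.
nparams = sum(r * c for (r, c) in shapes)
println(nparams)                       # prints 538978304

# Share of the weights that live in the first layer (W1).
share_W1 = shapes[1][1] * shapes[1][2] / nparams
println(round(share_W1; digits = 3))   # prints 0.996
```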

So I broke it down to the main operations; you can see the code below. It performs exactly like the network I wrote. Are there any obvious mistakes that might be killing performance? Any suggestions on how to improve things?

```julia
module Test

# const globals keep mockNN type-stable (non-const globals are a known Julia slowdown)
const T = Float32
const W1 = rand(T, 2048, 512 * 512)
const W2 = rand(T, 1024, 2048)
const W3 = rand(T, 10, 1024)
const dW1 = zero(W1)
const dW2 = zero(W2)
const dW3 = zero(W3)
const out1, out2, out3 = zeros(T, 2048), zeros(T, 1024), zeros(T, 10)
const dOut1, dOut2, dOut = zeros(T, 2048), zeros(T, 1024), zeros(T, 512 * 512)

function mockNN(input::Array{Float32, 1}, error::Array{Float32, 1})
  # Forward
  BLAS.gemv!('N', T(1.0), W1, input, T(0.0), out1)
  BLAS.gemv!('N', T(1.0), W2, out1, T(0.0), out2)
  BLAS.gemv!('N', T(1.0), W3, out2, T(0.0), out3)

  # Backward
  # ∂E/∂inputs and ∂E/∂W
  fill!(dW3, 0)
  fill!(dOut2, 0)
  BLAS.gemv!('N', T(1.0), W3', error, T(0.0), dOut2)
  BLAS.ger!(T(1.0), error, out2, dW3)

  fill!(dW2, 0)
  fill!(dOut1, 0)
  BLAS.gemv!('N', T(1.0), W2', dOut2, T(0.0), dOut1)
  BLAS.ger!(T(1.0), dOut2, out1, dW2)

  fill!(dW1, 0)
  fill!(dOut, 0)
  BLAS.gemv!('N', T(1.0), W1', dOut1, T(0.0), dOut)
  BLAS.ger!(T(1.0), dOut1, input, dW1)
end

input = rand(T, 512 * 512)
error = rand(T, 10)
@time mockNN(input, error)
for i in 1:10
  input = rand(T, 512 * 512)
  error = rand(T, 10)
  @time mockNN(input, error)
end

end
```
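For reference, the backward pass above follows the standard gradients of a bias-free linear layer $y = Wx$ (the mock applies no activation between layers). Writing $\delta$ for the gradient arriving from the layer above:

$$
\delta = \frac{\partial E}{\partial y}, \qquad
\frac{\partial E}{\partial x} = W^{\top}\delta, \qquad
\frac{\partial E}{\partial W} = \delta\, x^{\top}
\qquad\Bigl(\text{componentwise: } \frac{\partial E}{\partial x_j} = \sum_i W_{ij}\,\delta_i,\quad \frac{\partial E}{\partial W_{ij}} = \delta_i\, x_j\Bigr)
$$

For the last layer, $\delta$ is `error` and $x$ is `out2`, so `dOut2` $= W_3^{\top}\,$`error` is a matrix–vector product with a transposed matrix, and `dW3` $=$ `error` $\cdot$ `out2`$^{\top}$ is the rank-1 update that `BLAS.ger!` performs. Since `ger!` computes $A \leftarrow \alpha x y^{\top} + A$, it accumulates, which is why `dW3` is zeroed first. Each `gemv!` call passes $\beta = 0$, so its output is overwritten and the `fill!(dOut2, 0)`-style calls are redundant.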

#2

Have you checked that both languages are using the same BLAS?
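One way to check on the Julia side is a sketch like the following, assuming Julia 1.7 or later, where `BLAS.get_config()` reports the libraries loaded through libblastrampoline (older versions, like the one in this thread, exposed `BLAS.vendor()` instead). The thread count is worth comparing too, since Torch and Julia may not default to the same number of BLAS threads.

```julia
using LinearAlgebra

# Julia ≥ 1.7: which BLAS/LAPACK libraries are currently loaded.
println(BLAS.get_config())

# Number of threads BLAS will use for calls such as gemv! and ger!.
println("BLAS threads: ", BLAS.get_num_threads())
```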


#3

It may be that the transposes take up a good chunk of the time:

```julia
julia> @time W1';
  7.697140 seconds (7 allocations: 2.000 GB, 20.98% gc time)
```
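The 2.000 GB figure is consistent with `W1'` building a full transposed copy: W1 holds 2048 × 512² Float32 values, i.e. 2³¹ bytes (2 GiB). Since Julia 0.7, `'` returns a lazy `Adjoint` wrapper that shares memory with its parent instead. A small sketch comparing the two, assuming Julia 1.x and a scaled-down stand-in matrix; `lazy_bytes` and `eager_bytes` are hypothetical helpers, and the exact byte counts will vary by version.

```julia
using LinearAlgebra

# Size of W1 from the thread: 2048 × (512 * 512) Float32 entries.
bytes_W1 = 2048 * 512 * 512 * sizeof(Float32)
println(bytes_W1 / 2^30, " GiB")   # prints "2.0 GiB"

# Smaller stand-in matrix so the comparison runs quickly.
W = rand(Float32, 1024, 2048)

# Measured inside functions so only the expression itself is counted.
lazy_bytes(A) = @allocated adjoint(A)          # lazy Adjoint wrapper, no data copied
eager_bytes(A) = @allocated copy(adjoint(A))   # materialises a new 2048 × 1024 Matrix

lazy_bytes(W); eager_bytes(W)                  # warm-up calls to exclude compilation
println("lazy:  ", lazy_bytes(W), " bytes")
println("eager: ", eager_bytes(W), " bytes")
```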

#4

Try replacing

```julia
BLAS.gemv!('N', T(1.0), W3', error, T(0.0), dOut2)
```

(and likewise the other two transposed calls) with

```julia
BLAS.gemv!('T', T(1.0), W3, error, T(0.0), dOut2)
```

That is, don’t form the transpose yourself; pass the `'T'` flag and let `gemv!` apply it.
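A quick check that the `'T'` flag gives the same result as multiplying by an explicitly transposed copy. This is a sketch assuming Julia 1.x; the `mul!` line on the lazy `W3'` should reach a BLAS `gemv` call on the parent array rather than copying it, but treat that dispatch detail as version-dependent. `err` stands in for the thread's `error` to avoid shadowing `Base.error` at the REPL.

```julia
using LinearAlgebra

T = Float32
W3 = rand(T, 10, 1024)
err = rand(T, 10)

# gemv! with 'T' computes y = alpha * W3' * err + beta * y without copying W3.
dOut2_blas = zeros(T, 1024)
BLAS.gemv!('T', one(T), W3, err, zero(T), dOut2_blas)

# Julia ≥ 0.7 alternative: mul! on the lazy adjoint, no explicit BLAS call.
dOut2_mul = zeros(T, 1024)
mul!(dOut2_mul, W3', err)

# Reference result from an explicitly materialised transpose.
dOut2_ref = Matrix(W3') * err
```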

With these changes, my little Retina MacBook gives:

```julia
julia> for i in 1:10
         input = rand(T, 512 * 512)
         error = rand(T, 10)
         @time mockNN(input, error)
       end
  1.101238 seconds (101 allocations: 5.438 KB)
  1.073498 seconds (15 allocations: 240 bytes)
  1.090495 seconds (15 allocations: 240 bytes)
  1.095570 seconds (15 allocations: 240 bytes)
  1.079725 seconds (15 allocations: 240 bytes)
  1.089084 seconds (15 allocations: 240 bytes)
  1.088494 seconds (15 allocations: 240 bytes)
  1.074428 seconds (15 allocations: 240 bytes)
  1.343097 seconds (15 allocations: 240 bytes)
  3.410145 seconds (15 allocations: 240 bytes)
```

EDIT: before the change, it was:

```julia
julia> @time mockNN(input, error);
 22.707278 seconds (28 allocations: 2.008 GB, 5.36% gc time)
```
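Putting the suggestion together, here is a minimal sketch of the whole mock pass with the `'T'` flag, assuming Julia 1.x. It is not the code anyone in this thread ran: the module-level globals are moved into a hypothetical `MockNet` struct, `mockNN!` is a hypothetical name, and the sizes are scaled down so it runs quickly.

```julia
using LinearAlgebra

# Hypothetical container for the mock network's weights and buffers.
struct MockNet{T}
    W1::Matrix{T}; W2::Matrix{T}; W3::Matrix{T}
    dW1::Matrix{T}; dW2::Matrix{T}; dW3::Matrix{T}
    out1::Vector{T}; out2::Vector{T}; out3::Vector{T}
    dOut1::Vector{T}; dOut2::Vector{T}; dOut::Vector{T}
end

function MockNet(::Type{T}, nin, h1, h2, nout) where {T}
    W1, W2, W3 = rand(T, h1, nin), rand(T, h2, h1), rand(T, nout, h2)
    MockNet(W1, W2, W3, zero(W1), zero(W2), zero(W3),
            zeros(T, h1), zeros(T, h2), zeros(T, nout),
            zeros(T, h1), zeros(T, h2), zeros(T, nin))
end

function mockNN!(net::MockNet{T}, input::Vector{T}, err::Vector{T}) where {T}
    # Forward: out = W * x
    BLAS.gemv!('N', one(T), net.W1, input, zero(T), net.out1)
    BLAS.gemv!('N', one(T), net.W2, net.out1, zero(T), net.out2)
    BLAS.gemv!('N', one(T), net.W3, net.out2, zero(T), net.out3)

    # Backward: 'T' makes BLAS use W' without materialising it.
    # beta = 0 overwrites each dOut buffer, so they need no fill!.
    BLAS.gemv!('T', one(T), net.W3, err, zero(T), net.dOut2)
    fill!(net.dW3, zero(T))                         # ger! accumulates, so zero first
    BLAS.ger!(one(T), err, net.out2, net.dW3)       # dW3 += err * out2'

    BLAS.gemv!('T', one(T), net.W2, net.dOut2, zero(T), net.dOut1)
    fill!(net.dW2, zero(T))
    BLAS.ger!(one(T), net.dOut2, net.out1, net.dW2)

    BLAS.gemv!('T', one(T), net.W1, net.dOut1, zero(T), net.dOut)
    fill!(net.dW1, zero(T))
    BLAS.ger!(one(T), net.dOut1, input, net.dW1)
    return net
end

# Scaled-down sizes; the thread used 512 * 512 inputs and 2048/1024 hidden units.
net = MockNet(Float32, 64 * 64, 256, 128, 10)
input, err = rand(Float32, 64 * 64), rand(Float32, 10)
mockNN!(net, input, err)           # first call includes compilation
@time mockNN!(net, input, err)
```

Passing the buffers explicitly also avoids the non-const globals of the original mock, so the function body is type-stable without `const` declarations.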

#5

Torch was compiled against OpenBLAS, and Julia uses the BLAS it ships with.


#6

Yes, it was really dumb of me not to read the BLAS interface carefully :(.
Thanks for pointing out the transpose, much appreciated! Torch and Julia are now in the same ballpark.