Flux.jl RNN performance

I have been trying out Flux.jl for RNNs and I really like the syntax. I am wondering how I should expect performance to compare to, say, PyTorch, and whether there are things I should look out for when defining models in order to get good performance.

To get myself started, I wrote a simple example of a custom RNN that learns to generate a sine wave. (I am using Float32 variables everywhere because it seems to improve performance by 20% or so, and it makes for a fair comparison with PyTorch.) This kind of model is pretty representative of my use case, in which I'd like to be able to define RNN models by hand.
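Written out, the model in the script that follows is a leaky RNN integrated with forward-Euler steps of size $\Delta t$ (`dt`). Here $w_s$, $J$, $w_r$, $b$ are `ws`, `J`, `wr`, `b`, $x_0$ is the noisy initial condition, $s_t$ and $r^{\mathrm{targ}}_t$ are the input and target at step $t$, and the squared norm is summed over readouts and batch columns:

$$
\begin{aligned}
x_t &= x_{t-1} + \Delta t\,\bigl(-x_{t-1} + \tanh(w_s s_t + J x_{t-1} + b)\bigr), \qquad t = 1,\dots,N_T,\\
r_t &= w_r x_t,\\
\mathcal{L} &= \frac{1}{B}\sum_{t=1}^{N_T} \bigl\lVert r_t - r^{\mathrm{targ}}_t \bigr\rVert_F^2 .
\end{aligned}
$$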

``````julia
using Flux

N = 200 #number of hidden units
S = 1 #number of inputs
R = 1 #number of readouts
B = 20 #batches per epoch

Nepochs = 100
dt = .1f0
T = 20
period = 10f0 #sine wave period
NT = Int(T/dt) #number of timesteps

lr = .001f0 #learning rate
sig = 0.1f0 #std deviation of initial condition

t = dt*(1:NT)

s = zeros(Float32,NT,S,B) #input, in this case just set to zero
rtarg = zeros(Float32,NT,R,B) #target output
for bi = 1:B
    rtarg[:,1,bi] = sin.(2*pi*t/period)
end

ws = param(randn(Float32,N,S)/Float32(sqrt(S))) #input to hidden
J = param(randn(Float32,N,N)/Float32(sqrt(N))) #hidden to hidden
wr = param(randn(Float32,R,N)/Float32(sqrt(N))) #hidden to readout
b = param(zeros(Float32,N,1)) #bias

opt = ADAM(params(ws,J,wr,b), lr) #optimiser (Flux 0.6-style API: parameters first)

function calcloss(x0)
    loss = 0
    x = x0
    for ti = 1:NT
        x += dt*(-x + tanh.(ws*s[ti,:,:] + J*x .+ b)) #Euler step of the leaky RNN
        r = wr*x #readout

        loss += sum((r-rtarg[ti,:,:]).^2)/B
    end

    return loss
end

xinit = randn(Float32,N,1)

prevt = time()
for ei = 1:Nepochs
    print(ei,"\r")

    x0 = xinit .+ sig*randn(Float32,N,B) #initial condition

    Flux.train!(calcloss,[[x0]],opt)
end
print("train time: ",time() - prevt)
``````

Running this takes around 25 seconds on my CPU. As a comparison, the equivalent PyTorch code below takes 4-5 seconds, and the difference seems to grow more dramatic with N.

``````python
import time

import numpy as np
import torch

N = 200 #number of hidden units
S = 1 #number of inputs
R = 1 #number of readouts
B = 20 #batches per epoch

Nepochs = 100
dt = .1
T = 20
period = 10 #sine wave period
NT = int(T/dt) #number of timesteps

lr = .001 #learning rate
sig = 0.1 #std deviation of initial condition

t = dt*np.arange(1, NT+1) #same time points as Julia's dt*(1:NT)

s = torch.zeros(NT,S,B) #input, in this case just set to zero
rtarg = torch.zeros(NT,R,B) #target output
for bi in range(B):
    rtarg[:,0,bi] = torch.from_numpy(np.sin(2*np.pi*t/period)).float()

#cast to float32 after scaling so the parameters match the float32 inputs
ws0 = (np.random.standard_normal([N,S])/np.sqrt(S)).astype(np.float32) #input to hidden
J0 = (np.random.standard_normal([N,N])/np.sqrt(N)).astype(np.float32) #hidden to hidden
wr0 = (np.random.standard_normal([R,N])/np.sqrt(N)).astype(np.float32) #hidden to readout
b0 = np.zeros([N,1]).astype(np.float32) #bias

#trainable parameters and optimiser
ws = torch.tensor(ws0, requires_grad=True)
J = torch.tensor(J0, requires_grad=True)
wr = torch.tensor(wr0, requires_grad=True)
b = torch.tensor(b0, requires_grad=True)
opt = torch.optim.Adam([ws,J,wr,b], lr=lr)

xinit = torch.randn(N,1)

prevt = time.time()
for ei in range(Nepochs):
    print(ei,"\r",end='')
    x = xinit + sig*torch.randn(N,B) #initial condition
    xa = torch.zeros(NT,N,B)
    r = torch.zeros(NT,R,B)

    for ti in range(NT):
        x = x + dt*(-x + torch.tanh(ws.mm(s[ti,:,:]) + J.mm(x) + b))
        xa[ti,:,:] = x
        r[ti,:,:] = wr.mm(x)

    loss = torch.sum(torch.pow(r-rtarg,2))/B
    opt.zero_grad() #clear gradients accumulated in the previous epoch
    loss.backward()
    opt.step()

print("train time: ",time.time() - prevt)
``````

As Flux.jl is still young, I'm sure it's not been fully optimized, but I'm wondering if I'm doing everything right or if there are any tricks I could use to speed up my model. I'm also wondering in general what should be expected in terms of performance vs. other frameworks.


Maybe the problem is that you are benchmarking in global scope?
See https://docs.julialang.org/en/v1/manual/performance-tips/#Avoid-global-variables-1

Try putting everything inside a function.

That's worth checking, but in this case it doesn't seem to affect the time at all. The bottleneck is the call to Flux.train!().

Try asking in the Flux Slack channel.

It would be great if you could try to narrow this down a bit more; maybe strip it down to a single forward-backward iteration and then try making the program gradually smaller. It's likely that there's a single kernel here which isn't well-optimised and is hurting performance.

OK, here's a more stripped-down version. Now I just have a linear RNN where x(t+1) = J*x(t), the objective is simply to make x go to zero, and I'm doing one training step.

This code runs in 0.5-0.6 seconds (about 4 times as long as the PyTorch code further down this post): about 0.1 s for the forward pass and 0.4-0.5 s for the backward pass. I'm timing the second run of the file from the REPL, so compilation time is not included.

``````julia
using Flux

N = 500 #number of hidden units
B = 20 #batch size
NT = 500 #number of timesteps

lr = .001f0

J = param(randn(Float32,N,N)/Float32(sqrt(N))) #hidden to hidden
opt = ADAM(params(J), lr) #optimiser (ADAM assumed; Flux 0.6-style API)

t1 = time()
function calcloss(x0)
    loss = 0
    x = x0
    for ti = 1:NT
        x = J*x
        loss += sum(x.^2)
    end

    return loss
end

x0 = randn(Float32,N,B) #initial condition
loss = calcloss(x0)
println("forward time: ",time() - t1)

t2 = time()
Flux.Tracker.back!(loss)
opt() #apply the update

println("backward time: ",time() - t2)
println("total time: ",time() - t1)
``````

This is the equivalent in PyTorch, and it runs in about 0.15 seconds (0.05 s for the forward pass, 0.1 s for the backward pass; the actual update in opt.step() is negligible).

``````python
import time

import numpy as np
import torch

N = 500 #number of hidden units
B = 20 #batch size
NT = 500 #number of timesteps

lr = .001

J0 = (np.random.standard_normal([N,N])/np.sqrt(N)).astype(np.float32) #hidden to hidden
J = torch.tensor(J0, requires_grad=True)
opt = torch.optim.Adam([J], lr=lr) #ADAM assumed, to match the Flux version

t1 = time.time()
x0 = torch.randn(N,B) #initial condition
x = x0
loss = 0
for ti in range(NT):
    x = J.mm(x)
    loss += torch.sum(torch.pow(x,2))
print("forward time: ",time.time() - t1)

t2 = time.time()
loss.backward()
opt.step()
print("backward time: ",time.time() - t2)
print("total time: ",time.time() - t1)
``````

I also ran this comparison on my laptop (fewer cores) and the difference was a factor of 2.5 rather than 4.
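The forward/backward ratio can be sanity-checked with a rough operation count. Below is a minimal NumPy sketch (not Flux or PyTorch code; `forward` and `backward` are illustrative names) of a hand-written backward pass through time for the linear RNN above, with loss = sum over t of sum(x_t^2). Each reverse step performs two matrix products the size of the forward step's single product, so, ignoring memory traffic and framework overhead, a backward pass of roughly twice the forward time is what the arithmetic predicts. That is close to the PyTorch numbers above (0.05 s vs 0.1 s), while the Flux backward pass takes four to five times as long as its forward pass.

```python
import numpy as np


def forward(J, x0, n_steps):
    """Linear RNN x_t = J @ x_{t-1} with loss = sum over t of sum(x_t**2).

    Returns the loss and the stored states [x_0, x_1, ..., x_T], which the
    hand-written backward pass below needs."""
    xs = [x0]
    loss = 0.0
    for _ in range(n_steps):
        x = J @ xs[-1]          # one (N x N) @ (N x B) product per step
        xs.append(x)
        loss += np.sum(x ** 2)
    return loss, xs


def backward(J, xs):
    """Backpropagation through time for the loss above; returns dloss/dJ.

    Each step does two products the size of the forward one:
    J.T @ g (push the gradient back to x_{t-1}) and g @ x_{t-1}.T (gradient for J)."""
    dJ = np.zeros_like(J)
    g = np.zeros_like(xs[-1])   # dloss/dx_{t+1}; zero beyond the last step
    for t in range(len(xs) - 1, 0, -1):
        g = 2.0 * xs[t] + J.T @ g   # total dloss/dx_t: direct term + later steps
        dJ += g @ xs[t - 1].T       # x_t = J @ x_{t-1} contributes g x_{t-1}^T
    return dJ
```

The sketch stores every state, so memory grows as NT × N × B; that is the usual trade-off that lets the backward pass avoid recomputing the forward one.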

I tried a few of the usual suspects here, like worrying about type-stability, but did not get much improvement. Would be interested to see if there are better ideas.

``````julia
using BenchmarkTools

@btime calcloss($x0)                ## 162.306 ms (20501 allocations: 38.66 MiB)
@btime Tracker.back!(calcloss($x0)) ## 1.565 s (103505 allocations: 632.36 MiB)

function calcloss2(x0, NT=NT, J=J)
    x = J*x0
    loss = sum(z->z^2, x)
    for ti = 2:NT
        x = J*x
        loss += sum(z->z^2, x)
    end
    loss
end

function calcloss3(x0, NT=NT, J=J)
    xs = [ J*x0 ]
    for ti = 2:NT
        push!(xs, J*xs[ti-1])
    end
    loss = sum(sum(z->z^2, x) for x in xs)
end

@code_warntype calcloss(x0)
@code_warntype calcloss2(x0, NT, J)
@code_warntype calcloss3(x0, NT, J)

@btime calcloss2($x0)                ## 156.008 ms (12996 allocations: 38.51 MiB)
@btime calcloss2($x0, $NT, $J)       ## 152.126 ms (12996 allocations: 38.51 MiB)
@btime Tracker.back!(calcloss2($x0, $NT, $J)) ## 1.051 s (24996 allocations: 592.03 MiB)

@btime calcloss3($x0)                ## 155.222 ms (13007 allocations: 38.52 MiB)
@btime Tracker.back!(calcloss3($x0, $NT, $J)) ## 1.172 s (25007 allocations: 592.04 MiB)
``````

I've been playing with this further and am now looking at a single matrix multiplication. The code is below.

Forward pass:
Flux: 221.375 ms (36 allocations: 19.07 MiB)
PyTorch: 200 ms

Backward pass:
Flux: 456.944 ms (166 allocations: 152.59 MiB)
PyTorch: 200 ms

Update step (opt() in Flux, opt.step() in PyTorch):
Flux: 712.069 ms (22 allocations: 496 bytes)
PyTorch: 150 ms

The important difference is probably the backward pass. I am doing this on a 2-core machine, and based on what I saw earlier the difference will probably be larger on my 4-core machine (I can check that tomorrow). Also, it's not because of the x^2: performance is the same if the loss is just sum(x).

The update step is also surprisingly different, so I guess the ADAM implementation is less efficient in Flux. However, for deep networks or RNNs with many timesteps this should be quick compared to the backward pass, so itâ€™s probably not important. The performance is closer between Flux and PyTorch if I use the SGD optimizer.

``````julia
using BenchmarkTools
using Flux

N = 5000 #number of hidden units
B = 500 #batch size

lr = .001f0

J = param(randn(Float32,N,N)/Float32(sqrt(N))) #hidden to hidden
opt = ADAM(params(J), lr) #ADAM, as discussed below (Flux 0.6-style API)

function calcloss(x0)
    x = x0
    x = J*x
    loss = sum(x.^2)

    return loss
end

x0 = randn(Float32,N,B) #initial condition
@btime loss = calcloss(x0) #221.375 ms (36 allocations: 19.07 MiB)

loss = calcloss(x0)

@btime Flux.Tracker.back!(loss) #456.944 ms (166 allocations: 152.59 MiB)

@btime opt() #712.069 ms (22 allocations: 496 bytes)
``````

PyTorch code:

``````python
import time

import numpy as np
import torch

N = 5000 #number of hidden units
B = 500 #batch size

lr = .001

J0 = (np.random.standard_normal([N,N])/np.sqrt(N)).astype(np.float32) #hidden to hidden
J = torch.tensor(J0, requires_grad=True)
opt = torch.optim.Adam([J], lr=lr) #ADAM, matching the Flux benchmark

x0 = torch.randn(N,B)
t1 = time.time()
x = x0
x = J.mm(x)
loss = torch.sum(torch.pow(x,2))
print("forward time: ",time.time() - t1)

t2 = time.time()
loss.backward()
print("backward time: ",time.time() - t2)

t3 = time.time()
opt.step()
print("gradient application time: ",time.time() - t3)
``````
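The update step discussed above is purely elementwise, so its cost scales with the number of parameters (N² = 25 million entries of J in this benchmark) and not with the batch size or the number of timesteps; that is why it should matter less than the backward pass for deep networks or long RNNs. Below is a minimal NumPy sketch of one ADAM step written with a preallocated scratch buffer. The class name `InPlaceAdam` is hypothetical, the hyperparameters are the usual ADAM defaults, and this is not the Flux or PyTorch implementation.

```python
import numpy as np


class InPlaceAdam:
    """Sketch of an ADAM optimiser for one parameter array, updated in place.

    Every operation is elementwise, so a step costs a handful of passes over
    the parameter array regardless of batch size or sequence length."""

    def __init__(self, param, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.param = param                 # updated in place by step()
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = np.zeros_like(param)      # first-moment (mean) estimate
        self.v = np.zeros_like(param)      # second-moment estimate
        self.tmp = np.empty_like(param)    # scratch buffer reused every step
        self.t = 0                         # step counter for bias correction

    def step(self, grad):
        self.t += 1
        b1, b2, tmp = self.beta1, self.beta2, self.tmp
        # m <- b1*m + (1 - b1)*g
        np.multiply(grad, 1 - b1, out=tmp)
        self.m *= b1
        self.m += tmp
        # v <- b2*v + (1 - b2)*g^2
        np.square(grad, out=tmp)
        tmp *= 1 - b2
        self.v *= b2
        self.v += tmp
        # tmp <- sqrt(v_hat) + eps, where v_hat = v / (1 - b2^t)
        np.sqrt(self.v, out=tmp)
        tmp /= np.sqrt(1 - b2 ** self.t)
        tmp += self.eps
        # param <- param - lr * m_hat / (sqrt(v_hat) + eps), m_hat = m / (1 - b1^t)
        np.divide(self.m, tmp, out=tmp)
        tmp *= self.lr / (1 - b1 ** self.t)
        self.param -= tmp
```

Each call makes about a dozen elementwise passes over arrays the size of the parameter; for the 5000 × 5000 float32 J here, each such array is about 100 MB, so a step like this is likely limited by memory bandwidth rather than arithmetic.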

I'm too lazy to write it myself, but how does this compare with a hand-written backward pass? And (having written such a function) how fast is an `@grad calcloss(x0) = ...` definition using it?

Maybe you could try limiting the number of BLAS threads for a "fair" comparison.
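On the Python side, a minimal way to pin the CPU work to one thread is sketched below. The environment variable names are the common OpenMP, MKL and OpenBLAS ones; which of them actually matters depends on how numpy/PyTorch were built, which is an assumption about the local setup. On the Julia side the corresponding call is `BLAS.set_num_threads(1)` from the `LinearAlgebra` standard library (Julia 0.7/1.0).

```python
import os

# BLAS/OpenMP backends generally read these variables once, when the library
# is loaded, so they must be set before numpy or torch is imported.
for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"):
    os.environ[var] = "1"

try:
    import torch

    torch.set_num_threads(1)  # PyTorch's intra-op (CPU kernel) thread pool
except ImportError:
    torch = None  # PyTorch not installed; only the environment variables apply
```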

I tried it on my 2-core computer with a single thread for both Flux and PyTorch. Again the forward pass was about the same for both (~500ms) but for PyTorch the backward pass was the same as the forward pass and in Flux it took about 2.5x as long.

I would need to do some research to figure out how to implement a hand-written backward pass using Flux.Tracker, but I will look into it.
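For comparison with the Tracker versions that follow, here is a framework-free NumPy sketch of a hand-written gradient for loss = sum((J*x0).^2): a forward function returns the loss together with a pullback closure that maps the upstream gradient `D` to dloss/dJ = 2 * D * x * x0^T, where x = J*x0. The names `loss_recompute`, `loss_saved` and `pullback` are illustrative; this mirrors the idea behind a custom `@grad` rule rather than Flux.Tracker's actual machinery. The first variant re-does J*x0 inside the pullback; the second captures x from the forward pass, saving one N×N by N×B matrix product.

```python
import numpy as np


def loss_recompute(J, x0):
    """loss = sum((J @ x0)**2); the pullback re-does the forward product."""
    loss = np.sum((J @ x0) ** 2)

    def pullback(D):
        x = J @ x0                   # second N x N by N x B product
        return (2.0 * D * x) @ x0.T  # dloss/dJ, scaled by the upstream gradient D
    return loss, pullback


def loss_saved(J, x0):
    """Same loss; the pullback closes over x from the forward pass."""
    x = J @ x0
    loss = np.sum(x ** 2)

    def pullback(D):
        return (2.0 * D * x) @ x0.T  # reuses x: only the dJ product remains
    return loss, pullback
```

If x0 itself needed a gradient (it does not here, since it is a fixed input), the pullback would also have to return J.T @ (2 * D * x), a second product of the same size; a backward pass that skips it does roughly the work of one forward product.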

OK, I had a go. My first gradient here re-does the matrix multiplication of the forward pass during the reverse pass. On my slow machine, this is about as fast as the automatic backward pass above.

The second way instead re-uses this `x`, and is twice as fast.

I don't know if this is what's going on, because the way Flux gets the derivative of `sum(z->z^2,x)` involves broadcasting, and perhaps ForwardDiff, and is quite complicated.

``````julia
using Flux
using Flux.Tracker: @grad, data, back!, TrackedArray, track

## hand-written gradient of sum((J*x0).^2) with respect to J;
## D is the incoming (upstream) gradient of the scalar loss
function calcloss_grad(x0, J, D)
    x = J * x0                 #re-does the forward matrix multiplication
    D1 = 2 .* D .* x           #dloss/dx
    DJ = D1 * transpose(x0)    #dloss/dJ (returned)
end

Flux.back!(calcloss(x0))

@btime calcloss_grad($x0, $(data(J)), 1)    ## 891.765 ms (6 allocations: 114.44 MiB)

## stitch that into Flux:
calcloss_2(x0, J=J) = sum(z->z^2, J * x0)
calcloss_2(x0, J::TrackedArray) = track(calcloss_2, x0, J)
@grad calcloss_2(x0, J) = calcloss_2(x0, data(J)), D -> (nothing, calcloss_grad(x0, data(J), D) ,)

@btime calcloss_2($x0, $J)
loss = calcloss_2(x0, J)
@btime back!($loss)                         ## 906.353 ms (9 allocations: 114.44 MiB)

## better version, saving x for re-use
calcloss_3(x0, J=J) = sum(z->z^2, J * x0)
calcloss_3(x0, J::TrackedArray) = track(calcloss_3, x0, J)