I have been trying out Flux.jl for RNNs and I really like the syntax. I am wondering how I should expect its performance to compare to that of, say, PyTorch, and whether there is anything I should look out for when defining models in order to get good performance.
To get myself started, I wrote a simple example of a custom RNN that learns to generate a sine wave (I am using Float32 variables everywhere because it seems to improve performance by 20% or so and makes for a fair comparison with PyTorch). This kind of model is pretty representative of my use case, in which I’d like to be able to define RNN models by hand.
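Written out, the update that both code listings implement appears to be a forward-Euler step of a rate network followed by a linear readout. The symbols mirror the variable names in the code (`ws`, `J`, `wr`, `b`, `s`, `rtarg`, `dt`, `NT`, `B`); the time subscript and the Frobenius-norm notation are introduced here only for illustration, with $x_t$ an $N \times B$ matrix holding one column per batch element:

```latex
\begin{aligned}
x_t &= x_{t-1} + \Delta t \,\bigl(-x_{t-1} + \tanh(w_s s_t + J x_{t-1} + b)\bigr), \\
r_t &= w_r x_t, \\
\mathcal{L} &= \frac{1}{B} \sum_{t=1}^{N_T} \bigl\lVert r_t - r^{\mathrm{targ}}_t \bigr\rVert_F^2 .
\end{aligned}
```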
using Flux
N = 200 #number of hidden units
S = 1 #number of inputs
R = 1 #number of readouts
B = 20 #batch size (one batch per epoch)
Nepochs = 100
dt = .1f0
T = 20
period = 10f0 #sine wave period
NT = Int(T/dt) #number of timesteps
lr = .001f0 #learning rate
sig = 0.1f0 #std deviation of initial condition
t = dt*(1:NT)
s = zeros(Float32,NT,S,B) #input, in this case just set to zero
rtarg = zeros(Float32,NT,R,B) #target output
for bi = 1:B
    rtarg[:,1,bi] = sin.(2*pi*t/period)
end
ws = param(randn(Float32,N,S)/Float32(sqrt(S))) #input to hidden
J = param(randn(Float32,N,N)/Float32(sqrt(N))) #hidden to hidden
wr = param(randn(Float32,R,N)/Float32(sqrt(N))) #hidden to readout
b = param(zeros(Float32,N,1)) #bias
function calcloss(x0)
    loss = 0
    x = x0
    for ti = 1:NT
        x += dt*(-x + tanh.(ws*s[ti,:,:] + J*x .+ b))
        r = wr*x
        loss += sum((r-rtarg[ti,:,:]).^2)/B
    end
    return loss
end
opt = ADAM((ws,J,wr,b),lr)
xinit = randn(Float32,N,1)
prevt = time()
for ei = 1:Nepochs
    print(ei,"\r")
    x0 = xinit .+ sig*randn(Float32,N,B) #initial condition
    Flux.train!(calcloss,[[x0]],opt)
end
print("train time: ",time() - prevt)
Running this takes around 25 seconds on my CPU. As a comparison, the equivalent PyTorch code below takes 4-5 seconds, and the difference seems to grow more dramatic with N.
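One caveat on how these numbers are measured: the Julia timer above starts before the first call to `Flux.train!`, so the 25 seconds includes one-time JIT compilation along with the training itself. How much of the gap that accounts for is not something the figures above can tell us. A minimal sketch of timing with an untimed warm-up call, written in Python to match the second listing; `time_after_warmup` and `dummy_epoch` are hypothetical names used only for illustration, and `dummy_epoch` is a stand-in for one real training epoch:

```python
import time

def time_after_warmup(step, n_warmup=1, n_timed=5):
    """Call `step` n_warmup times untimed, then return the mean
    wall-clock seconds per call over n_timed further calls."""
    for _ in range(n_warmup):
        step()  # absorbs one-time costs (compilation, caches, allocation)
    start = time.perf_counter()
    for _ in range(n_timed):
        step()
    return (time.perf_counter() - start) / n_timed

def dummy_epoch():
    # Hypothetical stand-in for one training epoch; replace with the real one.
    sum(i * i for i in range(10_000))

per_epoch = time_after_warmup(dummy_epoch)
print(f"mean time per epoch: {per_epoch:.6f} s")
```

The same idea carries over to the Julia script by running one epoch before recording `prevt`, so that both frameworks are compared on steady-state time per epoch.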
import torch
import time
from torch.autograd import Variable
import numpy as np
N = 200 #number of hidden units
S = 1 #number of inputs
R = 1 #number of readouts
B = 20 #batch size (one batch per epoch)
Nepochs = 100
dt = .1
T = 20
period = 10 #sine wave period
NT = int(T/dt) #number of timesteps
lr = .001 #learning rate
sig = 0.1 #std deviation of initial condition
t = dt*np.arange(NT)
s = Variable(torch.zeros(NT,S,B),requires_grad=False) #input, in this case just set to zero
rtarg = Variable(torch.zeros(NT,R,B),requires_grad=False) #target output
for bi in range(B):
    rtarg[:,0,bi] = torch.FloatTensor(np.sin(2*np.pi*t/period))
ws0 = np.random.standard_normal([N,S]).astype(np.float32)/np.sqrt(S) #input to hidden
J0 = np.random.standard_normal([N,N]).astype(np.float32)/np.sqrt(N) #hidden to hidden
wr0 = np.random.standard_normal([R,N]).astype(np.float32)/np.sqrt(N) #hidden to readout
b0 = np.zeros([N,1]).astype(np.float32) #bias
ws = Variable(torch.from_numpy(ws0),requires_grad=True)
J = Variable(torch.from_numpy(J0),requires_grad=True)
wr = Variable(torch.from_numpy(wr0),requires_grad=True)
b = Variable(torch.from_numpy(b0),requires_grad=True)
opt = torch.optim.Adam([J,wr,b,ws],lr=lr)
xinit = torch.randn(N,1)
prevt = time.time()
for ei in range(Nepochs):
    print(ei,"\r",end='')
    x = xinit + sig*torch.randn(N,B) #initial condition
    xa = torch.zeros(NT,N,B)
    r = torch.zeros(NT,R,B)
    for ti in range(NT):
        x = x + dt*(-x + torch.tanh(ws.mm(s[ti,:,:]) + J.mm(x) + b))
        xa[ti,:,:] = x
        r[ti,:,:] = wr.mm(x)
    loss = torch.sum(torch.pow(r-rtarg,2))/B
    loss.backward()
    opt.step()
    opt.zero_grad()
print("train time: ",time.time() - prevt)
As Flux.jl is still young, I’m sure it’s not been fully optimized, but I’m wondering if I’m doing everything right or if there are any tricks I could use to speed up my model. I’m also wondering in general what should be expected in terms of performance vs. other frameworks.