Flux running slow?

I’m new to Julia and testing it out for ML, hoping to use it for RL, which is my main area of work. I noticed that Flux seems to run much slower than I would have expected, so I created a simple example of how I’m using Flux. I recreated the same thing in Python with PyTorch (1.6), and it ran ~8 times faster. I’m not sure what I’m doing wrong; any pointers would be much appreciated.

I’m using Julia v1.5, Flux v0.11.1, Python v3.7 and PyTorch v1.6 on Ubuntu 18.04, with no GPU.

Flux example:

using Flux
using Flux: params, update!
using Dates: now
using Statistics: mean

model = Chain(
    Dense(10, 128, tanh),
    Dense(128, 128, tanh),
    Dense(128, 1))
opt = ADAM(3e-4)
p = params(model)
x = rand(Float32, 10, 2000)
y = rand(Float32, 1, 2000)

# Mean squared error, matching PyTorch's nn.MSELoss
function loss(x, y)
    ŷ = model(x)
    mean((y .- ŷ) .^ 2)
end

for j = 1:10
    st = now()
    for i = 1:10
        g = gradient(() -> loss(x, y), p)   # backward pass via Zygote
        update!(opt, p, g)                  # ADAM step
    end
    println(now() - st)   # wall-clock time for 10 updates
end
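A side note on the loss. `mean((y .- ŷ).^2)` is the mean squared error, which is what PyTorch's `nn.MSELoss` computes with its default mean reduction. `mean(y .- ŷ)^2` is a different quantity: it squares the *mean* residual, so positive and negative errors cancel before squaring. The quick check below uses made-up numbers; `mse` and `sq_mean` are hypothetical names used only for the comparison.

```julia
using Statistics: mean

# Made-up targets and predictions, just to compare the two formulas
y = Float32[1, 2, 3, 4]
ŷ = Float32[2, 1, 4, 3]

# Mean squared error: average of the squared residuals
mse(y, ŷ) = mean((y .- ŷ) .^ 2)
# Squared mean error: residuals average out *before* squaring
sq_mean(y, ŷ) = mean(y .- ŷ)^2

println(mse(y, ŷ))
println(sq_mean(y, ŷ))
```

Here the residuals are -1, 1, -1, 1. The mean squared error is therefore 1, while the squared mean error is 0 even though every prediction is wrong.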

Python/PyTorch example:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.l1 = nn.Linear(10, 128)
        self.l2 = nn.Linear(128, 128)
        self.l3 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.tanh(self.l1(x))
        x = F.tanh(self.l2(x))
        x = self.l3(x)
        return x

model = Net()
loss_f = nn.MSELoss()
opt = optim.Adam(model.parameters(), lr=3e-4)
x = torch.rand(2000, 10)
y = torch.rand(2000, 1)
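for j in range(10):
    st = time.time()
    for i in range(10):
        opt.zero_grad()          # clear gradients from the previous step
        y_hat = model(x)
        loss = loss_f(y_hat, y)
        loss.backward()          # backward pass, like gradient() in the Flux version
        opt.step()               # parameter update, like update!
    print(f"{(time.time() - st) * 1000:.0f} milliseconds")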

You can probably replace the for loop with the Flux.@epochs macro and the Flux.train! function. For example:

@time Flux.@epochs 10 Flux.train!(loss, p, [(x,y)], opt)

I timed it twice, and it ran 10 epochs in about 0.17 s.

The other thing is that some of Torch’s operations are more optimised, so you could look into using Torch.jl.

However, my CUDA version isn’t compatible at the moment, so I can’t test it.

You can try

using Torch
model = model.layers |> torch

Thanks for this. I gave it a go using
Flux.@epochs 10 Flux.train!(loss, p, [(x,y)], opt)
instead of the loop and found that it actually slowed things down: ~300 ms for the 10 updates vs ~170 ms with the original for loop and update!. Either way, it’s still significantly slower than PyTorch, which took ~20 ms for the 10 updates.

I’ll have a look at Torch.jl.

Could you measure the time per epoch instead of the total time? Zygote takes a while to compile the gradient code on the first run, so that may be dramatically inflating the total. This cost is usually amortized over a training run because the gradient is called, say, 100 batches × 10 epochs = 1,000 times. Running it on effectively 10 batches is therefore a worst-case scenario.
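One way to separate compilation from steady-state speed is to time each batch on its own and look at the first entry separately from the rest. Below is a minimal sketch using only Base and Statistics. `time_batches` and `step!` are hypothetical names. `step!` stands in for one `gradient` + `update!` step; here it is a plain matrix product with a tanh, so the sketch runs without Flux.

```julia
using Statistics: median

# Hypothetical helper: run `step!` in `nbatches` batches of `steps` calls each,
# returning the wall-clock seconds spent on every batch.
function time_batches(step!, nbatches::Int, steps::Int)
    times = Vector{Float64}(undef, nbatches)
    for b in 1:nbatches
        times[b] = @elapsed for _ in 1:steps
            step!()
        end
    end
    return times
end

# Stand-in workload (assumption): one dense tanh layer instead of a real training step
W = rand(Float32, 128, 10)
X = rand(Float32, 10, 2000)
step!() = (tanh.(W * X); nothing)

times = time_batches(step!, 10, 10)
println("first batch (includes compilation): ", round(times[1] * 1000, digits = 1), " ms")
println("median of remaining batches: ", round(median(times[2:end]) * 1000, digits = 1), " ms")
```

With a real training step, the first batch absorbs Zygote's compilation, and the median of the remaining batches is the number to compare against PyTorch.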

One other thing to try is moving all of the declarations into functions or let blocks. Using globals could have an impact on this kind of microbenchmark-style test. All of this can be done in parallel with trying out Torch.jl, of course.
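To illustrate the globals point: when a function reads a non-`const` global, the compiler cannot assume that global's type. Operations on it then fall back to dynamic dispatch. You can avoid this by passing the data in as arguments, or by binding it locally in a function or `let` block. The sketch below uses hypothetical names and a single tanh layer in place of the full model. The overhead is per call, not per element, so for a benchmark dominated by a few large array operations it may be small.

```julia
using Statistics: mean

# Non-const globals: their types can change at any time, so code reading them is not type-stable
W = rand(Float32, 128, 10)
x = rand(Float32, 10, 2000)
loss_global() = mean(abs2, tanh.(W * x))

# Same computation with the data passed in explicitly: fully inferable
loss_args(W, x) = mean(abs2, tanh.(W * x))

# A `let` block gives the closure fixed local bindings instead of globals
loss_let = let W = W, x = x
    () -> mean(abs2, tanh.(W * x))
end

println(loss_global(), " ", loss_args(W, x), " ", loss_let())
```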


I ran 10 × 10 batches, timing each batch of 10 updates. The first is slow because of compilation, as you say, but all the other batches are consistently ~170 ms per batch of 10 updates (vs 20 ms for PyTorch). The time seems to go almost entirely into calculating the gradients rather than into update!. Are there faster alternatives for calculating gradients in Flux than the default Zygote gradient()?

I also tried wrapping the test in a function to avoid global variables, as the Julia performance tips suggest, but in this case it didn’t seem to make a difference.

I guess there is also Knet’s AutoGrad.jl. The other obvious candidate would be ReverseDiff.jl, given it’s an NN, but I haven’t tried either.

I reckon Torch.jl, if it works, would be the fastest, given it’s more optimized.

Ah, I just noticed you’re running on CPU with tanh activations. The issue from the "Flux vs pytorch cpu performance" thread is most likely the culprit. Long story short, small dense MLPs with tanh on CPU hit a bunch of areas in Flux that need to be optimized, and the effect may be more or less pronounced here because you’re also running the backward pass.
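To see how much the activation function alone costs on CPU, you can time `tanh` against a ReLU over an array the size of one hidden layer's output. That is 128 × 2000, matching the hidden layers in the example above. Below is a rough sketch using only Base. `relu_` and `time_activation` are hypothetical local definitions; Flux exports its own `relu`. The timings will vary by machine and Julia version.

```julia
# A hidden-layer-sized pre-activation matrix, as in the 128-unit, 2000-sample example
z = randn(Float32, 128, 2000)
out = similar(z)

# Local ReLU so the sketch doesn't need Flux (Flux provides its own `relu`)
relu_(x) = max(zero(x), x)

# Time `reps` in-place broadcasts of activation f; `f::F` forces specialization on f.
# One warm-up call first so compilation isn't counted.
function time_activation(f::F, out, z; reps::Int = 100) where {F}
    out .= f.(z)
    return @elapsed for _ in 1:reps
        out .= f.(z)
    end
end

t_tanh = time_activation(tanh, out, z)
t_relu = time_activation(relu_, out, z)
println("tanh: ", round(t_tanh * 1000, digits = 2), " ms per 100 reps")
println("relu: ", round(t_relu * 1000, digits = 2), " ms per 100 reps")
```

In-place broadcasting into `out` keeps allocation out of the measurement, so the difference mostly reflects the cost of the activation function itself.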


The good news is that this should be much better (at least 3x) soon with my improvements to exp, and I might try my hand at the hyperbolic functions directly. The following `expm1` function for Float32 is about 2x faster.

MAX_EXP(::Type{Float32}) =  88.72284f0          # log 2^127 *(2-2^-23)
MIN_EXP(::Type{Float32}) = -17.32868f0          # log 2^-25

# 1/log(2) (For Float32 reductions)
LogBINV(::Type{Float32}) = 1.442695f0
# -log(base, 2) in upper and lower bits
LogBU(::Type{Float32})   = -0.6931472f0
LogBL(::Type{Float32})   = 1.9046542f-9

# Polynomial approximation of expm1(r) on the reduced range |r| <= log(2)/2
@inline function expm1_kernel(x::Float32)
    return x*evalpoly(x, (1.0f0, 0.5f0, 0.16666667f0, 0.04166667f0, 0.008333191f0, 0.0013887697f0, 0.00019959333f0, 2.5497458f-5))
end

@inline function expm1(x::T) where T<:Float32
    # Range reduction: x = N*log(2) + r, with log(2) split into two parts for accuracy
    N_float = round(x*LogBINV(T))
    N = unsafe_trunc(Int32, N_float)
    r = muladd(N_float, LogBU(T), x)
    r = muladd(N_float, LogBL(T), r)
    small_part = expm1_kernel(r)
    # Special cases: NaN, overflow to Inf, and results that round to -1
    if !(abs(N)<MIN_EXP(T))
        isnan(x) && return x
        x > MAX_EXP(T) && return Inf32
        x < MIN_EXP(T) && return -1f0
    end
    # twopk = 2^(N-1), built directly from the Float32 exponent bits
    twopk = reinterpret(T, (N+Int32(126)) << Int32(23))
    return muladd(twopk, small_part, twopk - .5f0)*2
end
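The range reduction the code performs can be written out as follows. This is my reading of the code, not a statement from its author. `LogBU` + `LogBL` is a two-part (Cody-Waite style) split of $-\ln 2$, and `twopk` is $2^{N-1}$, assembled by writing the biased exponent $N - 1 + 127 = N + 126$ into the exponent bits. Using $2^{N-1}$ rather than $2^N$, followed by a final doubling, presumably keeps $N = 128$ (inputs near `MAX_EXP`) from overflowing the exponent.

$$
N = \operatorname{round}\!\left(\frac{x}{\ln 2}\right), \qquad r = x - N\ln 2, \qquad |r| \lesssim \tfrac{1}{2}\ln 2
$$

$$
e^{x} - 1 = 2^{N}e^{r} - 1 = 2\left[\,2^{N-1}\left(e^{r}-1\right) + \left(2^{N-1} - \tfrac{1}{2}\right)\right]
$$

$$
e^{r} - 1 \approx r\sum_{k=0}^{7} c_k\, r^{k}, \qquad c_k \approx \frac{1}{(k+1)!}
$$

The bracket in the second line is exactly `muladd(twopk, small_part, twopk - .5f0)`, where `small_part` is the polynomial in the third line.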

Hopefully this improvement will make it into Julia 1.6.

Thanks for that, tanh does seem to be the culprit. I switched to relu and got ~30 ms per batch of 10 updates vs ~20 ms for PyTorch (also using relu), which is much better.

Unfortunately this doesn’t solve my overall problem, as I want to use it for RL, where the loss functions make extensive use of exp and log. I’ll look at your expm1 update, thanks Oscar. I coded up an implementation of PPO in Julia using Flux, which is where I first realised I might be doing something wrong, as it was running slower than expected.
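For context on why exp matters here: in PPO, the probability ratio between the new and old policies is usually computed from log-probabilities as exp(log π_new − log π_old). Every loss evaluation, and its gradient, therefore runs exp over the whole batch. Below is a minimal sketch of the clipped surrogate loss in plain Julia. `ppo_clip_loss` and its arguments are hypothetical names, and the clip range of 0.2 is a common default, not a value from this thread.

```julia
using Statistics: mean

# Clipped PPO surrogate loss (to be minimised), vectorised over a batch.
# logp_new / logp_old: log-probabilities of the taken actions; adv: advantage estimates.
function ppo_clip_loss(logp_new, logp_old, adv; clip = 0.2f0)
    ratio = exp.(logp_new .- logp_old)            # π_new / π_old, one exp per sample
    unclipped = ratio .* adv
    clipped = clamp.(ratio, 1 - clip, 1 + clip) .* adv
    return -mean(min.(unclipped, clipped))        # negated: PPO maximises the surrogate
end

# Made-up batch of three samples
logp_old = Float32[-1.0, -0.5, -2.0]
logp_new = Float32[-1.0, -0.4, -2.5]
adv      = Float32[1.0, 2.0, -1.0]
println(ppo_clip_loss(logp_new, logp_old, adv))
```

Taking the elementwise `min` of the clipped and unclipped terms is what makes the objective pessimistic: large policy changes can't increase it beyond the clip range.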

After defining that, adding the following gives a tanh that is about 2x faster than Base’s (when using my expm1 function):

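# Requires the MAX_EXP and expm1 definitions above
function mytanh(x::T) where T
    # Beyond the overflow threshold, tanh(x) is ±1 to within precision
    abs(2x) >= MAX_EXP(T) && return copysign(one(T), x)
    k = expm1(2x)
    return k/(k+2)    # tanh(x) = (e^(2x) - 1)/(e^(2x) + 1)
end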
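The identity behind `mytanh` is tanh x = (e^(2x) − 1)/(e^(2x) + 1) = k/(k + 2), with k = expm1(2x). Using `expm1` instead of `exp(2x) - 1` avoids cancellation for small x. The self-contained check below compares the identity against `Base.tanh`. It uses `Base.expm1` in place of the custom kernel, the hypothetical name `tanh_via_expm1`, and the `MAX_EXP(Float32)` value from the post inlined as the saturation threshold.

```julia
# tanh via expm1, with Base.expm1 standing in for the custom Float32 kernel
function tanh_via_expm1(x::Float32)
    # 88.72284f0 is MAX_EXP(Float32) from the post; past it, tanh is ±1 in Float32 anyway
    abs(2x) >= 88.72284f0 && return copysign(1f0, x)
    k = expm1(2x)
    return k / (k + 2)
end

# Compare against Base.tanh on a grid of Float32 inputs
xs = Float32.(range(-10, 10, length = 2001))
maxerr = maximum(abs(tanh_via_expm1(x) - tanh(x)) for x in xs)
println("max abs error vs Base.tanh: ", maxerr)
```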

Great, thanks Oscar