Why is this MLP slower in Flux than in TensorFlow?

I am working with some MLPs and noticed that TensorFlow is much faster than Flux. In the examples below, Flux requires about 20 minutes and TensorFlow requires just over a minute. Am I doing something incorrectly?

Thank you in advance for your feedback.


using MKL, Flux, Distributions, Random, ProgressMeter
using Flux: params


# Draw a random mean and standard deviation for one Normal distribution
function rand_parms()
    μ = rand(Uniform(-3, 3))
    σ′ = rand(Uniform(.1, 2))
    return (;μ,σ′)
end

# n samples from one random Normal(μ, σ′); each column is [μ, σ′, x]
function make_training_data(n)
    output = fill(0.0, 3, n)
    μ, σ′ = rand_parms()
    x = rand(Normal(μ, σ′), n)
    for (i, v) in enumerate(x)
        output[:, i] = [μ, σ′, v]
    end
    return output
end

# number of parameter vectors for training 
n_parms = 2500
# number of data points per parameter vector 
n_samples = 250
# training data
train_x = mapreduce(_ -> make_training_data(n_samples), hcat, 1:n_parms)
# true values 
train_y = map(i -> pdf(Normal(train_x[1,i], train_x[2,i]), train_x[3,i]), 1:size(train_x,2))
train_y = reshape(train_y, 1, length(train_y))
train_data = Flux.Data.DataLoader((train_x, train_y), batchsize=1000)

model = Chain(
    Dense(3, 100, tanh),
    Dense(100, 100, tanh),
    Dense(100, 120, tanh),
    Dense(120, 1, identity)
)

# loss function
loss_fn(a, b) = Flux.huber_loss(model(a), b) 

# optimization algorithm 
opt = ADAM(0.002)

n_epochs = 50

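meter = Progress(n_epochs)
train_loss = zeros(n_epochs)
# plain loop: `@showprogress` here would create a second meter alongside `meter`
for i in 1:n_epochs
    Flux.train!(loss_fn, params(model), train_data, opt)
    train_loss[i] = loss_fn(train_x, train_y)
    # advance the progress meter and display the current loss
    next!(meter; showvalues = [(:loss, train_loss[i])])
end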


import tensorflow as tf
import numpy as np
from scipy.stats import norm
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense
import matplotlib.pyplot as plt
import time

n_parms = 2_500
n_points = 250

x_train = np.zeros((n_parms * n_points, 3))
row = 0
for _ in range(n_parms):
    mu = np.random.uniform(-3, 3)
    sigma = np.random.uniform(.1, 2)
    for _ in range(n_points):
        x = np.random.normal(mu, sigma)
        x_train[row,:] = np.array([mu,  sigma, x])
        row = row + 1
y_train = norm.pdf(x_train[:,2], x_train[:,0], x_train[:,1])


model = Sequential([
    Flatten(input_shape = (3, 1)),
    Dense(100, activation = 'tanh'),
    Dense(100, activation = 'tanh'),
    Dense(120, activation = 'tanh'),
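    Dense(1, activation = 'linear')
])
# optimizer line was lost; Adam(learning_rate=0.002) is an assumption mirroring the Flux ADAM(0.002)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.002),
              loss='huber', metrics=[tf.keras.metrics.RootMeanSquaredError()])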

start_time = time.time()
losses = model.fit(x_train, y_train, epochs = 50,
          batch_size = 1000)
end_time = time.time()

print('run time: ', end_time - start_time)

The Julia code leaves all the inputs as Float64, while the TF code uses Float32 by default. Make sure to write literals as Float32 (e.g. `1f0` or `0.5f0` instead of `1.0` or `0.5`), and convert non-literals (e.g. with `Float32(x)`). With just those changes, the Flux version runs an order of magnitude faster on my machine.
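A minimal sketch of those literal and conversion rules in plain Julia (no packages needed); the variable names are illustrative only:

```julia
# Float64 is Julia's default type for floating-point literals
a = 0.1            # Float64
b = 0.1f0          # Float32 literal (the `f0` suffix)
c = 1f0            # also Float32

# One Float64 literal is enough to promote a Float32 computation back to Float64
mixed = b * 2.0    # Float64
kept  = b * 2f0    # Float32
# Integer literals do not promote Float32
also_kept = b * 2  # Float32

# Converting existing (non-literal) data
x64 = rand(3, 4)                      # Matrix{Float64}
x32 = Float32.(x64)                   # elementwise conversion
y32 = convert(Matrix{Float32}, x64)   # equivalent whole-array conversion

println(typeof(a), " ", typeof(b), " ", typeof(mixed), " ", typeof(kept))
```

The `b * 2.0` case is the easy one to miss: a single Float64 literal inside a loss function or data pipeline quietly brings back the slower type.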

I think that fast_tanh should be used instead of tanh. That might make some difference as well.

Thank you for your help. Indeed, switching to Float32 gave an order-of-magnitude speedup; it now runs in about 2 minutes and 30 seconds. One of the great things about Julia is that I did not need to convert with `Float32(x)`. Instead, I initialized `output = zeros(Float32, 3, n)` for the `train_x` data, and `train_y` was automatically Float32.
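For reference, a dependency-free sketch of why preallocating a Float32 buffer is enough. `make_training_data_f32` and `normal_pdf` are hypothetical names for this illustration, and `rand`/`randn` stand in for the Distributions.jl calls in the original script:

```julia
# Hypothetical Float32 variant of `make_training_data` from the question.
# Distributions.jl sampling is replaced by rand/randn to keep the sketch dependency-free.
function make_training_data_f32(n)
    output = zeros(Float32, 3, n)      # Float32 buffer, as described in the post
    μ = 6 * rand() - 3                 # Float64 draw, uniform on [-3, 3)
    σ = 0.1 + 1.9 * rand()             # Float64 draw, uniform on [0.1, 2)
    for i in 1:n
        v = μ + σ * randn()            # Float64 draw from Normal(μ, σ)
        # Storing Float64 values into a Float32 array converts them to Float32
        output[:, i] = [μ, σ, v]
    end
    return output
end

# Generic normal density: Float32 inputs keep every intermediate in Float32
normal_pdf(x, μ, σ) = exp(-((x - μ) / σ)^2 / 2) / (σ * sqrt(2 * oftype(σ, π)))

train_x = make_training_data_f32(250)
train_y = [normal_pdf(train_x[3, i], train_x[1, i], train_x[2, i]) for i in axes(train_x, 2)]

println(eltype(train_x), " ", eltype(train_y))
```

The same type propagation presumably explains why `pdf` returned Float32 once its arguments came from a Float32 `train_x`.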

I have two remaining questions. First, where do I find `fast_tanh`? I could not find anything with a Google search, but I did find FastActivations.jl. Would that be comparable to TF? Second, how much of the remaining difference might be due to Zygote.jl? I noticed poor performance when using it with Turing, although Zygote may well be optimized for neural networks.


Julia has a module, `Base.FastMath`, that provides versions of math functions that may violate strict IEEE semantics. As far as I know these functions are not exported, but you can call them by their fully qualified names.


Do not expect large gains, though: a few percent, maybe.
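A minimal sketch, assuming the module meant here is `Base.FastMath` (it ships with Julia, so nothing needs installing). Both calls below should match the regular `tanh` up to rounding; for some functions the fast variant may simply fall back to the standard implementation, which would fit the small gains mentioned above:

```julia
x = range(-3f0, 3f0; length = 7)

# Option 1: call the unexported fast variant by its qualified name
y1 = Base.FastMath.tanh_fast.(x)

# Option 2: let the @fastmath macro rewrite `tanh` to its fast variant
y2 = @fastmath tanh.(x)

# Reference values from the regular IEEE-respecting tanh
y_ref = tanh.(x)

println(maximum(abs.(y1 .- y_ref)), " ", maximum(abs.(y2 .- y_ref)))
```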

As of recent versions of Flux, you don't have to find anything; the library handles that for you :slight_smile: (see basic.jl at v0.13.0 in FluxML/Flux.jl on GitHub).
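A small sketch of the function in question, assuming NNlib.jl (the package Flux uses for activations) at version 0.8 or later is installed; the name there is `tanh_fast`, not `fast_tanh`. As I understand the linked source, Flux 0.13's `Dense` substitutes `tanh` with this function automatically, so the model itself needs no change:

```julia
using NNlib  # assumed: NNlib ≥ 0.8 is installed

x = randn(Float32, 1000)

# NNlib's faster approximation of tanh versus the standard tanh
y_fast = NNlib.tanh_fast.(x)
y_ref  = tanh.(x)

# Maximum absolute deviation between the two
err = maximum(abs.(y_fast .- y_ref))
println("max abs difference: ", err)
```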

As for Zygote: some, but probably not too much. Assuming you're measuring the end-to-end runtime of your script, there will be at least 30 s (and likely more) of pure compilation latency in there too. You could try timing the training loop in isolation to see how long that takes.
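One way to separate compilation latency from run time is to time the same call twice. In this sketch `fake_epoch` is a hypothetical stand-in for the `Flux.train!` call, which is not reproduced here:

```julia
# Hypothetical stand-in for one training epoch
function fake_epoch(x)
    s = zero(eltype(x))
    for v in x
        s += tanh(v)
    end
    return s
end

x = rand(Float32, 10^6)

t_first  = @elapsed fake_epoch(x)   # first call: includes JIT compilation
t_steady = @elapsed fake_epoch(x)   # second call: already compiled

println("first call: ", t_first, " s, second call: ", t_steady, " s")
```

`@time` can be used the same way and, in recent Julia versions, also reports the share of time spent compiling.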

Yes, though again the issues with Turing appear to be primarily compilation-related; see the GitHub issue "Zygote's compilation scales badly with the number of `~` statements" (TuringLang/Turing.jl #1754).