Is MLJFlux much faster than TensorFlow for a simple MNIST example?

I am trying to port this TensorFlow MNIST example to MLJFlux.jl. This is what I have so far; it currently achieves an accuracy of 97.1%, which is close to that of the TensorFlow example:

```julia
using MLJ, MLJFlux, Flux

mutable struct MNISTBuilder <: MLJFlux.Builder
    n_hidden::Int
end

# Flatten each 28×28 image, one hidden ReLU layer, then one output per class.
# (ImageClassifier applies its default softmax finaliser to this output.)
function MLJFlux.build(builder::MNISTBuilder, (n1, n2), m, n_channels)
    return Chain(
        flatten,
        Dense(n1 * n2, builder.n_hidden, relu),
        Dense(builder.n_hidden, m),
    )
end

imgs = Flux.Data.MNIST.images()
labels = Flux.Data.MNIST.labels()
labels = coerce(labels, Multiclass)

@load ImageClassifier
clf = ImageClassifier(;
    builder=MNISTBuilder(128),
    optimiser=ADAM(0.001),
    loss=Flux.crossentropy,
    epochs=6,
    batch_size=128,
)

mach = machine(clf, imgs, labels)

@time evaluate!(
    mach;
    resampling=Holdout(fraction_train=5/6, shuffle=true, rng=123),
    operation=predict_mode,
    measure=[accuracy, misclassification_rate],  # adding cross_entropy here errored
    verbosity=3,
)
```
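The builder can be checked on its own, without loading any data, by building the chain and pushing a dummy batch through it. This is a minimal sketch assuming the four-argument `MLJFlux.build` signature used above and Flux's width × height × channels × batch array layout; the batch size of 4 is arbitrary. As far as I know, later MLJFlux releases also pass an `rng` argument to `build`, so the method signature may need adjusting there.

```julia
using MLJFlux, Flux

mutable struct MNISTBuilder <: MLJFlux.Builder
    n_hidden::Int
end

# Same builder as above (four-argument `build`, as in the MLJFlux version used here).
function MLJFlux.build(builder::MNISTBuilder, (n1, n2), m, n_channels)
    return Chain(
        flatten,                                # 28×28×1×B -> 784×B
        Dense(n1 * n2, builder.n_hidden, relu), # 784×B -> 128×B
        Dense(builder.n_hidden, m),             # 128×B -> 10×B (one logit per class)
    )
end

# Network for 28×28 grayscale (1-channel) images and 10 classes.
chain = MLJFlux.build(MNISTBuilder(128), (28, 28), 10, 1)

# Dummy batch of 4 random "images" in WHCN layout.
x = rand(Float32, 28, 28, 1, 4)
out = chain(x)
println(size(out))  # (10, 4)
```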

I was quite surprised that my implementation was actually a lot faster than the TensorFlow example (~12 s after a warm-up run vs. ~26 s). For the TensorFlow example, I just put `%%time` in front of the `model.fit` call. I was expecting TensorFlow to do well on such a simple example, so I am wondering whether I missed something that gives my implementation an unfair advantage over the TensorFlow one, and would appreciate feedback from people with more expertise here. Perhaps I am just too skeptical, though, and Flux really is that much faster; I definitely wouldn't have any problem with that! :smile:
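One difference between the two measurements: `@time evaluate!` covers both training and prediction on the holdout rows, whereas `%%time` in front of `model.fit` covers only the `fit` call. Below is a sketch that times MLJFlux training on its own, assuming MLJ's `partition` for a 5/6 split analogous to (not necessarily identical to) the `Holdout` above; the names `train`, `test`, `t_fit`, `t_pred` and `acc` are just illustrative.

```julia
using MLJ, MLJFlux, Flux

mutable struct MNISTBuilder <: MLJFlux.Builder
    n_hidden::Int
end

function MLJFlux.build(builder::MNISTBuilder, (n1, n2), m, n_channels)
    return Chain(flatten, Dense(n1 * n2, builder.n_hidden, relu), Dense(builder.n_hidden, m))
end

imgs = Flux.Data.MNIST.images()
labels = coerce(Flux.Data.MNIST.labels(), Multiclass)

clf = ImageClassifier(;
    builder=MNISTBuilder(128),
    optimiser=ADAM(0.001),
    loss=Flux.crossentropy,
    epochs=6,
    batch_size=128,
)
mach = machine(clf, imgs, labels)

# Shuffled 5/6 train / 1/6 test split of the row indices.
train, test = partition(eachindex(labels), 5/6, shuffle=true, rng=123)

# Warm-up run, so compilation time is excluded as in the original timing.
fit!(mach, rows=train, verbosity=0)

# Time training only. `force=true` retrains from scratch; without it MLJ
# would consider the machine up to date and skip the fit.
t_fit = @elapsed fit!(mach, rows=train, force=true, verbosity=0)

# Time prediction on the holdout rows separately, then score it.
t_pred = @elapsed predict_mode(mach, rows=test)
yhat = predict_mode(mach, rows=test)
acc = accuracy(yhat, labels[test])

println("fit: $(round(t_fit, digits=1)) s, predict: $(round(t_pred, digits=2)) s, accuracy: $(round(acc, digits=3))")
```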

This is my setup for testing:

```
julia> versioninfo()
Julia Version 1.5.2
Commit 539f3ce943 (2020-09-23 23:17 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-7500U CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-9.0.1 (ORCJIT, skylake)
Environment:
  JULIA_NUM_THREADS = 2
```

The TensorFlow version is 2.3.1, running on Python 3.8.5.

A small aside: is there any way to print the training accuracy, validation loss and validation accuracy for each epoch, as in the Python example? Adding `cross_entropy` as a measure errored for me, and I don't think measures are reported at the end of each epoch, only at the end of training.
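For per-epoch metrics, one option is to grow `epochs` one step at a time and evaluate after each step. This sketch assumes that MLJFlux performs a warm restart (continues training rather than starting over) when only `epochs` is increased on a machine refitted with the same rows; the split and the variable names are illustrative. It also assumes the `cross_entropy` error came from combining it with `operation=predict_mode`: cross-entropy needs probabilistic predictions, so it is computed from `predict` here, while accuracy uses `predict_mode`. The training accuracy is computed on the full training set after each epoch, not averaged over batches during the epoch as Keras reports it.

```julia
using MLJ, MLJFlux, Flux
using Statistics: mean

mutable struct MNISTBuilder <: MLJFlux.Builder
    n_hidden::Int
end

function MLJFlux.build(builder::MNISTBuilder, (n1, n2), m, n_channels)
    return Chain(flatten, Dense(n1 * n2, builder.n_hidden, relu), Dense(builder.n_hidden, m))
end

imgs = Flux.Data.MNIST.images()
labels = coerce(Flux.Data.MNIST.labels(), Multiclass)
train, test = partition(eachindex(labels), 5/6, shuffle=true, rng=123)

clf = ImageClassifier(;
    builder=MNISTBuilder(128),
    optimiser=ADAM(0.001),
    loss=Flux.crossentropy,
    epochs=1,          # incremented one epoch at a time in the loop below
    batch_size=128,
)
mach = machine(clf, imgs, labels)

for epoch in 1:6
    clf.epochs = epoch
    # Same rows, only `epochs` changed: assumed to trigger a warm restart,
    # i.e. one additional epoch of training rather than a full retrain.
    fit!(mach, rows=train, verbosity=0)

    # Cross-entropy on probabilistic predictions (`predict`, not `predict_mode`).
    # `mean` covers MLJ versions where the measure returns per-observation losses.
    val_loss = mean(cross_entropy(predict(mach, rows=test), labels[test]))
    train_acc = accuracy(predict_mode(mach, rows=train), labels[train])
    val_acc = accuracy(predict_mode(mach, rows=test), labels[test])

    println("epoch $epoch: acc=$(round(train_acc, digits=4)) ",
            "val_loss=$(round(val_loss, digits=4)) val_acc=$(round(val_acc, digits=4))")
end

# Training losses recorded by MLJFlux; whether earlier epochs are included
# after warm restarts may depend on the MLJFlux version.
println(report(mach).training_losses)
```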
