I am trying to port this TensorFlow MNIST example to MLJFlux.jl. This is what I have so far; it currently achieves an accuracy of 97.1%, which is pretty close to the TensorFlow example:
using MLJ, MLJFlux, Flux

# Builder encoding the architecture from the TensorFlow example:
# flatten the images, one hidden layer, then the output layer.
mutable struct MNISTBuilder <: MLJFlux.Builder
    n_hidden::Int
end

# (n1, n2) is the image size, m the number of classes, and n_channels
# the number of colour channels (1 for grayscale MNIST).
function MLJFlux.build(builder::MNISTBuilder, (n1, n2), m, n_channels)
    return Chain(
        flatten,
        Dense(n1 * n2, builder.n_hidden, relu),
        Dense(builder.n_hidden, m),
    )
end

imgs = Flux.Data.MNIST.images()
labels = Flux.Data.MNIST.labels()
labels = coerce(labels, Multiclass)

@load ImageClassifier
clf = ImageClassifier(;
    builder=MNISTBuilder(128),
    optimiser=ADAM(0.001),
    loss=Flux.crossentropy,
    epochs=6,
    batch_size=128,
)

mach = machine(clf, imgs, labels)
@time evaluate!(
    mach;
    resampling=Holdout(fraction_train=5/6, shuffle=true, rng=123),
    operation=predict_mode,
    measure=[accuracy, #=cross_entropy,=# misclassification_rate],
    verbosity=3,
)
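For reference, the 97.1% is read off the return value of evaluate!, which the snippet above discards. A minimal sketch of that, assuming the measure/measurement field names of the returned object in my MLJ version:

e = @time evaluate!(
    mach;
    resampling=Holdout(fraction_train=5/6, shuffle=true, rng=123),
    operation=predict_mode,
    measure=[accuracy, misclassification_rate],
)
e.measure      # the measures that were evaluated
e.measurement  # one aggregated value per measure; the first entry is the 0.971 accuracy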
I was quite surprised that my implementation was actually a lot faster than the TensorFlow example (~12s after a warmup run vs ~26s). For the TensorFlow example, I just put %%time in front of the model.fit call. I was expecting TensorFlow to do well on such a simple example, so I am wondering whether I missed anything that gives my implementation an unfair advantage over the TensorFlow one, and I would appreciate any feedback from people with more expertise here. Perhaps I am just too skeptical, though, and Flux really is that much faster; I definitely wouldn't have any problem with that!
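To be concrete about the warmup: on the Julia side I ran the evaluation twice in the same session and reported the second timing, so the one-time compilation cost is excluded. Roughly (the same call as above, repeated):

# warmup run: pays Julia's compilation cost, timing not reported
evaluate!(mach; resampling=Holdout(fraction_train=5/6, shuffle=true, rng=123),
          operation=predict_mode, measure=[accuracy, misclassification_rate])
# timed run: the ~12s quoted above
@time evaluate!(mach; resampling=Holdout(fraction_train=5/6, shuffle=true, rng=123),
                operation=predict_mode, measure=[accuracy, misclassification_rate])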
This is my setup for testing:
julia> versioninfo()
Julia Version 1.5.2
Commit 539f3ce943 (2020-09-23 23:17 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-7500U CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-9.0.1 (ORCJIT, skylake)
Environment:
  JULIA_NUM_THREADS = 2
The TensorFlow version is 2.3.1, running on Python 3.8.5.
A small aside: is there any way to also print the training accuracy, validation loss, and validation accuracy for each epoch, like in the Python example? Adding cross_entropy as a measure errored for me, and as far as I can tell measures are only reported at the end of training, not at the end of each epoch.
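The closest workaround I have found is a manual loop over epochs. This is only a sketch: it assumes MLJFlux supports warm restarts, i.e. that bumping epochs and calling fit! again only trains the additional epochs, and that cross_entropy applied to the probabilistic predictions returns per-observation losses (hence the mean, which is harmless if it already returns a scalar):

using Statistics  # for mean

train, valid = partition(eachindex(labels), 5/6, shuffle=true, rng=123)
mach = machine(clf, imgs, labels)
for epoch in 1:6
    clf.epochs = epoch                  # warm restart: only the new epoch trains
    fit!(mach, rows=train, verbosity=0)
    probs = predict(mach, rows=valid)   # probabilistic predictions on the holdout set
    yhat = mode.(probs)                 # hard labels
    acc = accuracy(yhat, labels[valid])
    loss = mean(cross_entropy(probs, labels[valid]))
    println("epoch $epoch: val_loss = $loss, val_accuracy = $acc")
end

It is not as neat as the per-epoch log Keras prints, and it only covers the validation metrics here, but the training-set metrics could be computed the same way with rows=train.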