I am trying to port this TensorFlow MNIST example to MLJFlux.jl. This is what I have so far; it currently achieves an accuracy of 97.1%, which is pretty similar to the TensorFlow example:
```julia
using MLJ, MLJFlux, Flux

mutable struct MNISTBuilder <: MLJFlux.Builder
    n_hidden::Int
end

function MLJFlux.build(builder::MNISTBuilder, (n1, n2), m, n_channels)
    return Chain(
        flatten,
        Dense(n1 * n2, builder.n_hidden, relu),
        Dense(builder.n_hidden, m),
    )
end

imgs = Flux.Data.MNIST.images()
labels = Flux.Data.MNIST.labels()
labels = coerce(labels, Multiclass)

@load ImageClassifier

clf = ImageClassifier(;
    builder=MNISTBuilder(128),
    optimiser=ADAM(0.001),
    loss=Flux.crossentropy,
    epochs=6,
    batch_size=128,
)

mach = machine(clf, imgs, labels)

@time evaluate!(
    mach;
    resampling=Holdout(fraction_train=5/6, shuffle=true, rng=123),
    operation=predict_mode,
    measure=[accuracy, #=cross_entropy, =#misclassification_rate],
    verbosity=3,
)
```
I was quite surprised that my implementation was actually a lot faster than the TensorFlow example (~12s after a warmup run vs ~26s). For the TensorFlow example, I just put `%%time` in front of the `model.fit` call. I was expecting TensorFlow to do well on such a simple example, so I am wondering whether I missed anything that gives my implementation an unfair advantage over the TensorFlow one, and I would appreciate any feedback from people with more expertise here. Perhaps I am just too skeptical though, and Flux really is that much faster; I definitely wouldn’t have any problem with that!
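For context on the warmup run: Julia's first call to a function includes JIT compilation, so I discard the first run and only time subsequent ones. A tiny self-contained illustration of the effect:

```julia
# First call pays a one-off compilation cost; later calls do not.
f(x) = sum(sin, x)
x = rand(10^6)

@time f(x)  # includes compile time
@time f(x)  # warmed up: mostly just the computation
```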
This is my setup for testing:
```
julia> versioninfo()
Julia Version 1.5.2
Commit 539f3ce943 (2020-09-23 23:17 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-7500U CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-9.0.1 (ORCJIT, skylake)
Environment:
  JULIA_NUM_THREADS = 2
```
TensorFlow version is 2.3.1 with Python 3.8.5.
A small aside: is there any way to also print the training accuracy, validation loss, and validation accuracy for each epoch, like in the Python example? Adding `cross_entropy` as a measure errored for me, and I don’t think measures are reported at the end of each epoch, only at the end of training.
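To illustrate the kind of per-epoch reporting I mean, here is a rough plain-Flux sketch (with random placeholder data standing in for MNIST, and `logitcrossentropy` because this sketch's model has no final softmax; in the MLJFlux code above, the softmax finaliser is applied by `ImageClassifier` itself):

```julia
using Flux, Statistics

# Placeholder data in place of MNIST, just for illustration
X = rand(Float32, 784, 1000)
y = Flux.onehotbatch(rand(0:9, 1000), 0:9)

model = Chain(Dense(784, 128, relu), Dense(128, 10))
loss(x, y) = Flux.logitcrossentropy(model(x), y)  # logits, no softmax layer
opt = ADAM(0.001)

for epoch in 1:6
    Flux.train!(loss, Flux.params(model), [(X, y)], opt)
    acc = mean(Flux.onecold(model(X)) .== Flux.onecold(y))
    println("epoch $epoch: loss = $(loss(X, y)), accuracy = $acc")
end

final_acc = mean(Flux.onecold(model(X)) .== Flux.onecold(y))
```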