Flux.jl simple classification, add to Model Zoo?

I just got a simple classification example working with Flux.jl. I was originally looking for something like it but did not find anything, so I have tried to write comments along the way at the level I would have needed yesterday :slight_smile: The comments do get in the way of the code a bit, but they are placed exactly where they are most needed, which I think helps newcomers (like me) understand what is going on.

The example with all its comments can feel overwhelming, but everything is explained, so the reader always has the option to understand what is going on. That is what I was missing in the examples I found.

What I would like to hear is: would this example be welcome in the Flux Model Zoo? And if yes, are there suggestions for improvement?

One improvement would be adding titles to the plots and automating the iterative training+plotting, which I was unable to get working in my first attempts (see the sketch at the end of this post for one possible approach).

##! Classifying input points as above or below a given line
#* Comments are marked by a hash sign and a star, which looks good with the "Better Comments" extension for VSCode.
#* Code blocks are separated by a double hash sign.
#* Run the script line by line, or block by block.
using Flux
using Flux: train!, params, onehotbatch, onecold

#* The line used as reference for above or below
y(x) = 4x + 2

#* Function to give a vector of output classifications based on an input matrix.
#* A full column is an input. Row 1 is x-coordinate, row 2 is y-coordinate
function classify(points)
    outputs = []
    for col in eachcol(points)
        y(col[1]) < col[2] && push!(outputs, :above)
        y(col[1]) ≥ col[2] && push!(outputs, :below)
    end
    return outputs .|> identity  #* Broadcasting identity narrows the Vector{Any} to a Vector{Symbol}, which makes Makie.jl happy
end
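#* A quick check on a hand-made matrix (not part of the original example): the point (0, 10)
#* lies above the line since y(0) = 2 < 10, and (1, 0) lies below it since y(1) = 6 ≥ 0.
classify([0.0 1.0; 10.0 0.0])  #* Expected: [:above, :below]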

#* Generating the data. `inputs` are the points fed into the
#* Artificial Neural Network (ANN), `outputs` are the classifications it should learn.
inputs_train = vcat(0:0.01:1 |> transpose, [y(x) + randn() * x for x = 0:0.01:1] |> transpose)
inputs_test = vcat(0.005:0.01:1 |> transpose, [y(x) + randn() * x for x = 0.005:0.01:1] |> transpose)
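
#* Not in the original, just a quick look at the shape: each column is one point,
#* row 1 holds the x-coordinates and row 2 the y-coordinates.
size(inputs_train)  #* (2, 101)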

inputs_train |> classify #* Checking that we get outputs that make sense

##
#* Using the `classify` function to generate the correct classification for this supervised learning
outputs_train = classify(inputs_train)
outputs_test = classify(inputs_test)

#* One-hot encoding to translate the classifications into something that makes more sense to the ANN
categories = [:above, :below]
outputs_train_encoded = onehotbatch(outputs_train, categories)
outputs_test_encoded = onehotbatch(outputs_test, categories)
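
#* For instance (not in the original): the encoding of three classifications is a 2×3 one-hot
#* matrix with a single `true` per column, in the row of that column's category.
onehotbatch([:above, :below, :above], categories)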

##* Plotting the data generated thus far, mostly as a sanity-check
#* AlgebraOfGraphics provides convenient color by category
using AlgebraOfGraphics
#* Because `data` will be used for something else:
using AlgebraOfGraphics: data as AOGdata



axis = (width = 500, height = 500)  #* Axis attributes passed to every `draw` call below

#* Putting our data into a named tuple of vectors, which AlgebraOfGraphics treats just like a DataFrame.
#* Actually making a DataFrame gives the same result.
train_data = (
    xs = inputs_train[1, :], 
    ys = inputs_train[2, :], 
    label = outputs_train
)

#* Let's see if our dataset looks right by plotting it.
#* If you don't understand the plotting below but want to, see the AlgebraOfGraphics documentation. It is not important for the machine learning itself.
train_data_plot = AOGdata(train_data) * mapping(:xs, :ys, color = :label)
the_line = AOGdata((xs = 0:0.1:1, ys = y.(0:0.1:1))) * mapping(:xs, :ys) * linear()
draw(train_data_plot + the_line; axis)

##
#* Putting the data in a tuple inside a vector, because that is how Flux's `train!` expects its data.
data = [(inputs_train, outputs_train_encoded)]
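#* (Not part of the original example: for larger datasets one would normally split the data into
#* mini-batches, e.g. with Flux's DataLoader, rather than passing everything as a single tuple.)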

#* Making the ANN:
model = Chain(
    Dense(2, 5),    #* 2 inputs because an input point has 2 coordinates
    Dense(5, 5, relu),    #* Arbitrary choice of hidden layer size and activation
    Dense(5, 2, σ)  #* 2 outputs because there are 2 categories. σ activation so that output can be interpreted as probabilities.
)
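
#* Not in the original: a quick look at the untrained model's output for the first 3 points.
#* Each column holds the two class scores (one per category) for one input point.
model(inputs_train[:, 1:3])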

#* Defining the loss function.
loss(inputs, outputs) = Flux.Losses.binarycrossentropy(model(inputs), outputs)
#* Crossentropy is good for classification, binary because there are only 2 categories. Since the last
#* layer already applies σ, we use `binarycrossentropy`; if you drop the σ from the last layer, use
#* `logitbinarycrossentropy` instead, which is numerically more stable (as recommended in the docstring for `crossentropy`).
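
#* Not in the original: the loss of the untrained model, as a baseline to compare against after training.
loss(inputs_train, outputs_train_encoded)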


opt = ADAM()  #* ADAM is a solid default optimiser and works well for classification problems
parameters = params(model)   #* A variable that refers to the model parameters


##
#* See performance without training

decisions = onecold(model(inputs_test), categories)
model_test_data = (
    xs = inputs_test[1, :],
    ys = inputs_test[2, :],
    decisions = decisions
)
model_test_plot = AOGdata(model_test_data) * mapping(:xs, :ys, color = :decisions) + the_line
draw(model_test_plot; axis)


##
#* Train for 500 rounds, plotting the result at the end of this block. Run this block several times to watch the network learn!
for _ = 1:500
    train!(loss, parameters, data, opt)
end

#* `onecold` does the reverse of `onehotbatch`, translating the categories back from the encoding the computer uses to human-readable symbols.
#* The output is the most likely category for each point.
decisions = onecold(model(inputs_test), categories)
count(decisions .== outputs_test)   #* Count how many points the ANN got right, if wanted
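#* Not in the original: the same count as a fraction, i.e. the accuracy on the test data.
count(decisions .== outputs_test) / length(decisions)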


#* Plotting the test-data (which the model has not trained on), colored by the model's categorization.
model_test_data = (
    xs = inputs_test[1, :], 
    ys = inputs_test[2, :], 
    decisions = decisions
)
model_test_plot = AOGdata(model_test_data) * mapping(:xs, :ys, color = :decisions) + the_line
draw(model_test_plot; axis)

The plots:
Sanity check of model data:

After 1000 train! calls:

Another 1000:

Another 1000:

Another 1000:
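
For the automation of the iterative training + plotting (and the plot titles) that I mentioned above, here is a rough sketch of how I imagine it could look. It reuses the names from the example (model, loss, parameters, data, opt, inputs_test, categories, AOGdata, the_line) and passes a title through the axis keyword of draw. I have not verified this end to end, so treat it as a starting point rather than a finished solution:

##
#* Repeat "train 500 times, then plot" four times, giving each plot a title with the total number of train! calls so far.
for i = 1:4
    for _ = 1:500
        train!(loss, parameters, data, opt)
    end
    decisions = onecold(model(inputs_test), categories)
    plot_data = (xs = inputs_test[1, :], ys = inputs_test[2, :], decisions = decisions)
    plt = AOGdata(plot_data) * mapping(:xs, :ys, color = :decisions) + the_line
    fig = draw(plt; axis = (width = 500, height = 500, title = "After $(i * 500) train! calls"))
    display(fig)  #* One could also save each figure to a file here instead of displaying it
end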