Knet.jl: Simple MLP for Iris Dataset

Hi everyone, I’m currently studying how Knet works (I haven’t finished reading the docs yet). I’m trying to build a classification model for the iris dataset, using its four features as the predictors (xtrn1) and the species as the label (ytrn1).

Here’s my code, based on Knet’s LeNet example on GitHub.

using Knet
using RDatasets

iris = dataset("datasets", "iris");
xtrn1 = Matrix(iris[:, 1:4]);
ytrn1 = iris[:, 5];
ytrn1 = map(x -> x == "setosa" ? 1 : x == "versicolor" ? 2 : 3, ytrn1);
dtrn1 = minibatch(Float32.(xtrn1'), ytrn1, 10);

# Define the Dense layer
struct Dense; w; b; f; end
Dense(i::Int, o::Int, f = relu) = Dense(param(o, i), param0(o), f) # constructor
(d::Dense)(x) = d.f.(d.w * mat(x) .+ d.b) # define method for dense layer

# Define Chain layer
struct Chain; layers; end
(c::Chain)(x) = (for l in c.layers; x = l(x); end; x) # define method for feed-forward
(c::Chain)(x, y) = nll(c(x), y, dims = 1) # define method for negative-log likelihood loss

# Define the Model
model = Chain((Dense(4, 10), Dense(10, 3), x -> softmax(x, dims = 1)))
adam!(model, repeat(dtrn1, 10)) # train the model
accuracy(model, dtrn1)

So ytrn1 is an array of 1’s, 2’s, and 3’s corresponding to the species. I did not transform it into one-hot vectors, since I noticed that Knet’s LeNet example on GitHub uses the labels as is (without one-hot encoding). My question is: is this how Knet works by design? This is in contrast to Flux, where we encode the response variable as one-hot vectors for multiclass problems.
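
On the label question, here is a minimal sketch in plain Julia (no Knet or Flux calls) of why integer labels and one-hot vectors carry the same information for a negative log-likelihood loss: indexing the log-probability of the true class gives the same number as masking the log-probabilities with a one-hot matrix and summing. The helper names `logsoftmax_cols`, `nll_int`, and `nll_onehot` are hypothetical, chosen for illustration only.

```julia
# Column-wise log-softmax of a (classes x samples) score matrix.
function logsoftmax_cols(scores)
    shifted = scores .- maximum(scores, dims = 1)   # subtract column max for numerical stability
    return shifted .- log.(sum(exp.(shifted), dims = 1))
end

# Loss from integer labels: index the log-probability of the true class.
function nll_int(scores, labels)
    logp = logsoftmax_cols(scores)
    return -sum(logp[labels[j], j] for j in eachindex(labels)) / length(labels)
end

# Loss from a one-hot matrix: mask the log-probabilities and sum.
function nll_onehot(scores, onehot)
    logp = logsoftmax_cols(scores)
    return -sum(onehot .* logp) / size(onehot, 2)
end

scores = Float32[2.0 0.5 0.1;
                 1.0 1.5 0.2;
                 0.1 0.3 3.0]          # 3 classes x 3 samples
labels = [1, 2, 3]                     # integer class labels, one per column
onehot = Float32[labels[j] == i for i in 1:3, j in 1:3]

loss_int = nll_int(scores, labels)
loss_onehot = nll_onehot(scores, onehot)
println(loss_int ≈ loss_onehot)
```

So whether a library expects integers or one-hot vectors is mostly an interface choice; the integer form simply avoids building the mask.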

Further, I used the Float32.(xtrn1') conversion because if I use Float64.(xtrn1') I get the following error:
ERROR: Gradient type mismatch: w::Array{Float32,1} g::Array{Float64,1}
I would also like to understand why this happens.
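
One likely explanation, sketched in plain Julia under the assumption that Knet's `param` and `param0` create `Float32` arrays by default (which the `w::Array{Float32,1}` in the error message suggests): Julia promotes a `Float32` weight combined with a `Float64` input to `Float64`, so the gradient flowing back would be `Float64` while the stored weight is `Float32`. The arrays below are stand-ins, not Knet parameters.

```julia
w = randn(Float32, 3, 4)      # stand-in for a Float32 weight matrix
b = zeros(Float32, 3)         # stand-in for a Float32 bias vector
x64 = rand(Float64, 4, 10)    # Float64 input batch
x32 = Float32.(x64)           # same batch converted to Float32

y64 = w * x64 .+ b            # mixed precision is promoted to Float64
y32 = w * x32 .+ b            # all-Float32 stays Float32

println(eltype(y64))
println(eltype(y32))
```

If that assumption holds, keeping the inputs and the parameters in the same element type, as the Float32 conversion does, avoids the mismatch.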

Lastly, the accuracy I get from this is 0.3333333333333333 most of the time, and sometimes 0.49333333333333335 or 0.66. So I’m not sure whether I specified the data, the model, and the training correctly.
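
One thing that may contribute to the low accuracy, sketched in plain Julia: the iris rows are sorted by species, so consecutive batches of 10 each contain a single class unless the data is shuffled first. The helper `batch_labels` is hypothetical and only mimics splitting into consecutive batches; it is not Knet's `minibatch`. Two other points worth checking, stated as assumptions rather than confirmed causes: as I understand it, Knet's `nll` applies a log-softmax to its input itself, so the final `softmax` layer may be redundant, and `Dense(10, 3)` applies the default `relu` to the output scores.

```julia
using Random

labels = repeat(1:3, inner = 50)     # 150 labels sorted by class, like iris

# Split a label vector into consecutive full batches of the given size.
batch_labels(y, batchsize) = [y[i:i + batchsize - 1] for i in 1:batchsize:length(y) - batchsize + 1]

sorted_batches = batch_labels(labels, 10)
classes_per_sorted_batch = [length(unique(b)) for b in sorted_batches]

shuffled = shuffle(MersenneTwister(1), labels)   # fixed seed for reproducibility
shuffled_batches = batch_labels(shuffled, 10)
classes_per_shuffled_batch = [length(unique(b)) for b in shuffled_batches]

println(maximum(classes_per_sorted_batch))   # sorted data: each batch holds one class
println(sum(classes_per_shuffled_batch) / length(classes_per_shuffled_batch))
```

If I recall correctly, Knet's `minibatch` accepts a `shuffle = true` keyword, which may be worth trying here.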

Thank you very much for any help.

In case someone is interested in this question, I made a blog post about this here.

@alasaadstat I am following your very informative blog post. One very small thing: could you perhaps change the font? I find it very ‘thin’.
Someone here will point out that I can change this in my browser (Chrome).