Plots for 3 values

Hello,

Is there a way to make this example (https://juliaacademy.com/courses/deep-learning-with-flux-jl/lectures/9624892) work with all three colours? I tried it myself and the model seems to train fine, but I am not able to plot the result…

My code is below

using Pkg
Pkg.add("CSV")
Pkg.add("DataFrames")
Pkg.add("Flux")
Pkg.add("Plots")
using CSV, DataFrames, Flux, Plots

# Directory holding the tab-separated fruit colour tables. Change this one line to
# point the script at a different copy of the course data.
const DATA_DIR = "C:\\Users\\t.berky\\OneDrive\\Plocha\\BP\\julia\\juliaAcademy\\Courses\\Deep learning with Flux\\data"

"""
    load_fruit(filename)

Read the tab-delimited file `filename` located in `DATA_DIR` and return it as a
`DataFrame`. The file is parsed with `normalizenames=true`.
"""
function load_fruit(filename)
    path = joinpath(DATA_DIR, filename)
    return DataFrame(CSV.File(path; delim='\t', normalizenames=true))
end

apples1 = load_fruit("Apple_Golden_1.dat")
apples2 = load_fruit("Apple_Golden_2.dat")
apples3 = load_fruit("Apple_Golden_3.dat")
bananas = load_fruit("bananas.dat")
grapes1 = load_fruit("Grape_White.dat")
grapes2 = load_fruit("Grape_White_2.dat")

# Stack the per-file tables into a single table per fruit.
apples = vcat(apples1, apples2, apples3)
grapes = vcat(grapes1, grapes2)

# One feature vector per table row. Note the component order: [red, blue, green].
x_apples = [[row.red, row.blue, row.green] for row in eachrow(apples)]
x_bananas = [[row.red, row.blue, row.green] for row in eachrow(bananas)]
x_grapes = [[row.red, row.blue, row.green] for row in eachrow(grapes)]
xs = vcat(x_apples, x_bananas, x_grapes)

# One-hot targets in the same order as `xs`:
# class 1 = apples, class 2 = bananas, class 3 = grapes.
ys = vcat(
    fill(Flux.onehot(1, 1:3), length(x_apples)),
    fill(Flux.onehot(2, 1:3), length(x_bananas)),
    fill(Flux.onehot(3, 1:3), length(x_grapes)),
);

# Classifier: 3 colour inputs -> 4 sigmoid units -> 3 class scores, which
# `softmax` turns into class probabilities.
model = Chain(
    Dense(3, 4, σ),
    Dense(4, 3, identity),
    softmax,
)

# Loss: cross-entropy between the model output for `x` and the one-hot target `y`.
function L(x, y)
    return Flux.crossentropy(model(x), y)
end

opt = Descent(0.1)

# The whole data set is used as a single batch.
databatch = (Flux.batch(xs), Flux.batch(ys))

# Train on that batch 1000 times; the progress message is throttled with a
# timeout of 5. The temporaries live in a `let` so they do not become globals.
let
    batches = Iterators.repeated(databatch, 1000)
    report = Flux.throttle(() -> println("Probíhá trénování"), 5)
    Flux.train!(L, params(model), batches, opt; cb = report)
end

# Model output for the first apple sample and the first banana sample.
W = model(x_apples[1])
X = model(x_bananas[1])

using Plots
"""
    plot_decision_boundaries(model, x_apples, x_bananas, x_grapes; green=nothing)

Plot the samples and, for each of the three classes, the contour where the model
output for that class crosses 0.5, in the red/blue plane.

`model` is called with a 3-element vector `[red, blue, green]` and its result is
indexed with `1`, `2` and `3` (apples, bananas, grapes). Each `x_*` argument is a
collection of `[red, blue, green]` feature vectors.

The model has three inputs but the plot has only two axes, so the green input is
held constant while red and blue sweep `0:0.01:1`.

# Keywords
- `green`: green value used for the slice. When `nothing` (the default), the mean
  green value over all supplied samples is used.

# Returns
The plot object, so it can be displayed or handed to `png`.

# Throws
- `ArgumentError` if `green` is not given and no samples are supplied.
"""
function plot_decision_boundaries(model, x_apples, x_bananas, x_grapes; green=nothing)
    slice = if green === nothing
        greens = [x[3] for group in (x_apples, x_bananas, x_grapes) for x in group]
        isempty(greens) &&
            throw(ArgumentError("no samples supplied; pass `green` explicitly"))
        sum(greens) / length(greens)
    else
        green
    end

    plt = plot(; xlabel="hodnota červené", ylabel="hodnota modré")

    # The contour callback receives only the two grid coordinates (red, blue);
    # the third model input is the fixed green slice.
    grid = 0:0.01:1
    for (class, colour) in enumerate((:red, :blue, :green))
        contour!(
            plt, grid, grid, (r, b) -> model([r, b, slice])[class];
            # Two almost-equal levels and a single-colour gradient: presumably so
            # that only the 0.5 boundary is drawn, as one solid line.
            levels=[0.5, 0.501], color=cgrad([colour, colour]), colorbar=:none,
        )
    end

    # Samples are [red, blue, green]: x axis is component 1, y axis is component 2
    # (blue), matching the axis labels and the contour grid.
    scatter!(plt, first.(x_apples), getindex.(x_apples, 2);
             m=:cross, label="apples", color=:red)
    scatter!(plt, first.(x_bananas), getindex.(x_bananas, 2);
             m=:circle, label="bananas", color=:blue)
    scatter!(plt, first.(x_grapes), getindex.(x_grapes, 2);
             m=:square, label="grapes", color=:green)

    return plt
end
# Build the figure once and reuse it for both display and saving, instead of
# evaluating the model over the whole grid twice.
decision_plot = plot_decision_boundaries(model, x_apples, x_bananas, x_grapes)
png(decision_plot, "3_1_1")