Hi, I’m trying to train an MLP model to predict the number of real roots of polynomials. x_train and x_test contain arrays of polynomial coefficients, such as [[-204, 20, 13, 1, 0]], and y_train and y_test contain the number of real roots of each polynomial, such as 1, 2, 5, … My polynomials have degree up to 100, so each one has at most 101 coefficients and between 0 and 100 real roots. That is where the 101 in the output layer comes from. I’m stuck on the error below. I’ve tried hard to feed the data to the MLP properly, but I haven’t managed it. Can you please help me with this issue?
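For reference, here is a minimal sketch in plain Julia (no Flux needed) of one common way to lay out this kind of data. It parses each line without `eval`, pads every coefficient vector with zeros to a fixed length of 101, and stacks the results as columns, because Flux's `Dense` layers treat each column as one sample. The helper names `parse_coeffs`, `pad_coeffs` and `build_features` are hypothetical. The sketch assumes each line of x_train.txt holds one literal like `[-204, 20, 13, 1, 0]`. It also assumes the constant term comes first, which the post does not state. If the highest-degree coefficient comes first, the zeros should be prepended instead.

```julia
# Hypothetical helper: parse one line such as "[-204, 20, 13, 1, 0]"
# into a Vector{Float32} without calling eval.
function parse_coeffs(line::AbstractString)
    body = strip(strip(line), ['[', ']'])          # drop whitespace and brackets
    return parse.(Float32, strip.(split(body, ',')))
end

# Hypothetical helper: append zeros so every polynomial has n coefficients.
# ASSUMPTION: coefficients are ordered constant term first.
function pad_coeffs(c::AbstractVector{<:Real}, n::Int = 101)
    length(c) <= n || throw(ArgumentError("polynomial has more than $n coefficients"))
    return vcat(Float32.(c), zeros(Float32, n - length(c)))
end

# Hypothetical helper: one column per polynomial, giving a (features × samples) matrix.
function build_features(lines, n::Int = 101)
    cols = [pad_coeffs(parse_coeffs(l), n) for l in lines if !isempty(strip(l))]
    return reduce(hcat, cols)
end

X = build_features(["[-204, 20, 13, 1, 0]", "[3, 0, -1, 2]"])
println(size(X))   # prints (101, 2)
```

With real files this would be called as `build_features(readlines("/home/user/Desktop/x_train.txt"))`. With this layout, the first layer's input size would be 101 rather than 2.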
Here is my code:
```julia
using Flux

function input()
    ## x_train
    lines = Tuple(readlines("/home/user/Desktop/x_train.txt"))
    x_train = []
    for i in lines
        push!(x_train, convert(Array{Float32}, eval(Meta.parse(i))))
    end
    x_train = reshape(x_train, 2, :)
    ## y_train
    lines = Tuple(readlines("/home/user/Desktop/y_train.txt"))
    y_train = []
    for i in lines
        push!(y_train, eval(Meta.parse(i)))
    end
    y_train = reshape(y_train, 2, :)
    ## x_test
    lines = Tuple(readlines("/home/user/Desktop/x_test.txt"))
    x_test = []
    for i in lines
        push!(x_test, convert(Array{Float32}, eval(Meta.parse(i))))
    end
    x_test = reshape(x_test, 1, :)
    ## y_test
    lines = Tuple(readlines("/home/user/Desktop/y_test.txt"))
    y_test = []
    for i in lines
        push!(y_test, eval(Meta.parse(i)))
    end
    y_test = reshape(y_test, 1, :)
    return x_train, x_test, y_train, y_test
end

x_train, x_test, y_train, y_test = input()
train_size = 18

function NeuralNetwork()
    return Chain(
        Dense(2, 9, relu),
        Dense(9, 101), softmax
    )
end

# Organizing the data in batches
X = hcat(x_train, y_train)
Y = vcat(ones(train_size), zeros(train_size))
Y = reshape(Y, 18, 2)
println(size(X))
println(size(Y))
data = Flux.Data.DataLoader((X, Y'), batchsize=100, shuffle=true);

# Defining our model, optimization algorithm and loss function
m = NeuralNetwork()
opt = Descent(0.05)
loss(x, y) = sum(Flux.Losses.binarycrossentropy(m(x), y))
ps = Flux.params(m)
epochs = 20
for i in 1:epochs
    Flux.train!(loss, ps, data, opt)
end
```
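On the target side, a softmax over 101 outputs is usually trained against one-hot columns, one per polynomial. The targets need the same number of columns as the feature matrix so that `DataLoader` sees matching observation counts. The sketch below is plain Julia. `onehot_counts` is a hypothetical helper, and root counts are assumed to be integers from 0 to 100. Flux's own `Flux.onehotbatch(y, 0:100)` builds a Bool-valued one-hot matrix with the same layout. `Flux.Losses.crossentropy` is the loss usually paired with a softmax output, whereas `binarycrossentropy` is meant for independent yes/no outputs.

```julia
# Hypothetical helper: root counts 0..(nclasses-1) -> (nclasses × N) one-hot matrix.
function onehot_counts(counts::AbstractVector{<:Integer}, nclasses::Int = 101)
    Y = zeros(Float32, nclasses, length(counts))
    for (j, k) in enumerate(counts)
        0 <= k < nclasses || throw(ArgumentError("root count $k outside 0:$(nclasses - 1)"))
        Y[k + 1, j] = 1f0   # row k + 1 because Julia arrays are 1-based
    end
    return Y
end

Y = onehot_counts([1, 2, 5])
println(size(Y))             # prints (101, 3)
println(argmax(Y[:, 3]) - 1) # prints 5: decoding a column back to a root count
```

The same `argmax(...) - 1` decoding applies to a column of softmax outputs when reading off the predicted number of roots.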
Error message I got:

```
ERROR: LoadError: DimensionMismatch("dimensions must match: a has dims (Base.OneTo(5),), b has dims (Base.OneTo(4),), mismatch at 1")
```
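Julia reports this exception when two arrays of different lengths are added with `+`. Here is one plausible way it could happen in the code above. This is an assumption, not verified against the actual data files. `reshape(x_train, 2, :)` produces a matrix whose entries are whole coefficient vectors rather than numbers. The first `Dense` layer's matrix product would then have to add such vectors together, for example a 5-coefficient and a 4-coefficient polynomial. The snippet below reproduces the exception type with two such vectors and shows that zero-padding to a common length removes it.

```julia
# Two coefficient vectors of different lengths (5 and 4 entries).
a = Float32[-204, 20, 13, 1, 0]
b = Float32[3, 0, -1, 2]

# Adding arrays of different lengths throws a DimensionMismatch.
err = try
    a + b
    nothing
catch e
    e
end
println(typeof(err))   # prints DimensionMismatch

# Padding both to a common length (here 101) makes the shapes agree.
pad(v, n) = vcat(v, zeros(Float32, n - length(v)))
s = pad(a, 101) + pad(b, 101)
println(size(s))       # prints (101,)
```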
Thanks in advance!