I am implementing my own NN “library”, but I am stuck on how best to design the network itself, i.e. which types of layers to use, how many layers, and how many neurons per layer.
Let’s consider for example the following dataset:
xtrain = [0.1 0.2; 0.3 0.5; 0.4 0.1; 0.5 0.4; 0.7 0.9; 0.2 0.1; 0.4 0.2; 0.3 0.3; 0.6 0.9; 0.3 0.4; 0.9 0.8]
ytrain = [(0.1* x[1]+0.2* x[2]+0.3)* rand(0.9:0.001:1.1) for x in eachrow(xtrain)]
xtest = [0.5 0.6; 0.14 0.2; 0.3 0.7; 20.0 40.0;]
ytest = [(0.1* x[1]+0.2* x[2]+0.3)* rand(0.9:0.001:1.1) for x in eachrow(xtest)]
After a lot of trial and error I found that something like this gives good results, even for the last test point, which is intentionally outside the training range:
l1 = FullyConnectedLayer(linearf,2,3,w=ones(3,2), wb=zeros(3))
l2 = FullyConnectedLayer(linearf,3,1, w=ones(1,3), wb=zeros(1))
mynn = buildNetwork([l1,l2],squaredCost,name="Feed-forward Neural Network Model 1")
train!(mynn,xtrain,ytrain,maxepochs=10000,η=0.01,rshuffle=false,nMsgs=10)
But the last element is captured well only because the relationship between x and y is strictly linear.
When I violate that, things become harder. I am unable to get good results with the following dataset:
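To make the linearity point concrete, here is a minimal sketch in plain Julia, independent of my library. The weights below are made-up illustrative values, not the trained ones. It shows that two stacked layers with identity activation collapse into a single affine map, so the network computes the same kind of function inside and outside the training range:

```julia
# Hypothetical weights for a 2 -> 3 -> 1 network with identity (linear) activations.
W1 = [0.2 -0.1; 0.4 0.3; -0.5 0.6]   # 3x2 hidden-layer weights
b1 = [0.1, -0.2, 0.05]               # hidden-layer biases
W2 = [0.7 -0.3 0.2]                  # 1x3 output-layer weights
b2 = [0.3]                           # output-layer bias

# Forward pass through both linear layers.
twolayer(x) = W2 * (W1 * x .+ b1) .+ b2

# The same function written as one affine map W*x + b.
W = W2 * W1
b = W2 * b1 .+ b2
onelayer(x) = W * x .+ b

x_in  = [0.5, 0.6]     # inside the training range
x_out = [20.0, 40.0]   # far outside the training range

println(twolayer(x_in)[1]  ≈ onelayer(x_in)[1])
println(twolayer(x_out)[1] ≈ onelayer(x_out)[1])
```

So with linear activations the extra layer adds no expressive power: the network can only ever fit a relation of the form w1*x1 + w2*x2 + b, which happens to be exactly the form of the first dataset.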
xtrain = [0.1 0.2; 0.3 0.5; 0.4 0.1; 0.5 0.4; 0.7 0.9; 0.2 0.1; 0.4 0.2; 0.3 0.3; 0.6 0.9; 0.3 0.4; 0.9 0.8]
ytrain = [(0.1* x[1]^2+0.2* x[2]+0.3)* rand(0.95:0.001:1.05) for x in eachrow(xtrain)]
xtest = [0.5 0.6; 0.14 0.2; 0.3 0.7; 20.0 40.0;]
ytest = [(0.1*x[1]^2+0.2*x[2]+0.3)*rand(0.95:0.001:1.05) for x in eachrow(xtest)]
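To show how far the last test point lies from the training data in this second dataset, here is a small sketch that evaluates the noise-free part of the target. The helper name `f` is only for illustration; the noise factor is left out on purpose:

```julia
# Noise-free part of the second target function (rand(...) factor omitted).
f(x) = 0.1 * x[1]^2 + 0.2 * x[2] + 0.3

xtrain = [0.1 0.2; 0.3 0.5; 0.4 0.1; 0.5 0.4; 0.7 0.9; 0.2 0.1; 0.4 0.2; 0.3 0.3; 0.6 0.9; 0.3 0.4; 0.9 0.8]
xtest  = [0.5 0.6; 0.14 0.2; 0.3 0.7; 20.0 40.0]

train_targets = [f(x) for x in eachrow(xtrain)]
test_targets  = [f(x) for x in eachrow(xtest)]

# Range of targets seen during training vs. the out-of-range test target.
println(extrema(train_targets))
println(test_targets[end])
```

The training targets stay between roughly 0.32 and 0.54, while the last test target is about 48.3, of which 40 comes from the squared term alone. Within the training range the squared term contributes at most 0.081.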
I tried other activation functions (ReLU, tanh, …) and added many more layers/neurons, but then the NN always returns the same \hat Y whatever the input.
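One mechanism that can produce this symptom, though I have not confirmed it is what happens in my case, is a hidden ReLU layer whose units are all inactive. The sketch below uses hand-picked hypothetical weights (not taken from my trained network) for which every pre-activation is negative on non-negative inputs, so the output reduces to the output bias:

```julia
relu(z) = max(zero(z), z)

# Hypothetical weights chosen so that W1*x + b1 < 0 for every non-negative input.
W1 = [-1.0 -2.0; -0.5 -1.5; -2.0 -0.3]   # 3x2 hidden-layer weights
b1 = [-0.1, -0.2, -0.3]                  # hidden-layer biases
W2 = [0.4 -0.7 0.9]                      # 1x3 output-layer weights
b2 = [0.35]                              # output-layer bias

# 2 -> 3 (ReLU) -> 1 (linear) forward pass, returning a scalar.
predict(x) = (W2 * relu.(W1 * x .+ b1) .+ b2)[1]

xs = [[0.1, 0.2], [0.5, 0.4], [0.9, 0.8], [20.0, 40.0]]
preds = [predict(x) for x in xs]
println(preds)   # every entry equals the output bias
```

In that state the gradient with respect to the hidden weights is also zero, so further training only moves the output bias, typically towards the mean of the targets.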
How can I implement a NN that can “learn” a potentially non-linear relationship and apply it to out-of-sample data as well, as in the second dataset above?
More broadly, I learned how to implement the NN algorithm itself, which was easy, but what strategies should I now use for actually building the network? Are there any references on this topic?