Hi, I am trying to using MLP to approximate a multivariate function, however, the loss gets stuck very quickly and I am not able to get a model that can approximate the function well. I have tried the same thing in python TensorFlow and it converges very quickly
model2= Chain(Dense(4,64),
Dense(64,32,tanh),
Dense(32,32,tanh),
Dense(32,1))|> gpu
training_data_2=zeros((1000,5))|> gpu
# output=zeros((1000,1))
for sets in 1:1000
entries=[]
for i in 1:4
append!(entries,rand(1:100)/100)
end
y= 4*(entries[1]^2)*entries[3]-(6*entries[2]*(entries[4]))
training_data_2[sets,1:4]=entries
training_data_2[sets,5]=y
end
function loss_func(x,y)
return Flux.mse(model2(transpose(x)),y)/(size(x)[1])
end
function callback2!(x,y)
loss=loss_func(x,y)
push!(losses,loss)
end
opt=ADAM(0.001)
function my_custom_train2!(loss_func, ps, data, opt,cb,loss_epochs,batch_size=100)
# training_loss is declared local so it will be available for logging outside the gradient calculation.
len=size(data)[1]
i=1
j=batch_size
total_loss=0
steps=1
while j<len
x=data[i:j,1:4]
y=data[i:j,5]
gs = gradient(ps) do
training_loss = loss_func(x,y)
# Code inserted here will be differentiated, unless you need that gradient information
# it is better to do the work outside this block.
return training_loss
end
# Insert whatever code you want here that needs training_loss, e.g. logging.
# logging_callback(training_loss)
# Insert what ever code you want here that needs gradient.
# E.g. logging with TensorBoardLogger.jl as histogram so you can see if it is becoming huge.
# for pss in ps
# gs[pss] == nothing && continue
# update!(opt, x, gs[x])
# end
update!(opt, ps, gs)
cb(x,y)
steps=steps+1
total_loss=total_loss+loss_func(x,y)
# Here you might like to check validation set accuracy, and break out to do early stopping.
i=j
if j+batch_size< len
j=j+batch_size
else
j=len
end
end
final_loss= total_loss/steps
push!(loss_epochs,final_loss)
print("Final Average loss for this epochs "); println(final_loss)
end
for iter in 1:100
my_custom_train2!(loss_func, ps2,training_data_2 , opt,callback2,loss_epochs)
end