Flux Loss is not converging when trying to approximate a multivariate function

Hi, I am trying to use an MLP to approximate a multivariate function, y = 4*x1^2*x3 - 6*x2*x4 with each input drawn from 0.01 to 1.00, but the loss gets stuck very quickly and I cannot get a model that approximates the function well. The same setup in Python with TensorFlow converges quickly. My Julia code is below.

using Flux
using Flux.Optimise: update!

model2 = Chain(Dense(4, 64),
               Dense(64, 32, tanh),
               Dense(32, 32, tanh),
               Dense(32, 1)) |> gpu
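
# Optional sanity check (illustration only): Dense layers take samples as
# columns, so a 4×N input batch should map to a 1×N output.
@show size(model2(gpu(rand(Float32, 4, 10))))   # expected: (1, 10)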
# Build the dataset on the CPU first; filling a GPU array one element at a time
# would rely on slow (or disallowed) scalar indexing.
training_data_2 = zeros(Float32, 1000, 5)

for sets in 1:1000
    entries = Float32[]
    for i in 1:4
        push!(entries, rand(1:100) / 100)   # inputs drawn from 0.01, 0.02, ..., 1.00
    end
    y = 4 * entries[1]^2 * entries[3] - 6 * entries[2] * entries[4]
    training_data_2[sets, 1:4] = entries
    training_data_2[sets, 5] = y
end
training_data_2 = training_data_2 |> gpu   # move to the GPU only once it is filled
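
# The same dataset can also be built without the loop; this vectorized version
# (a sketch, kept equivalent to the loop above) just resamples data of the same form:
X = rand(1:100, 1000, 4) ./ 100f0                             # 1000×4 inputs in 0.01:0.01:1.00
Y = 4 .* X[:, 1] .^ 2 .* X[:, 3] .- 6 .* X[:, 2] .* X[:, 4]   # target for each row
training_data_2 = hcat(X, Y) |> gpu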
# Per-batch loss. Rows of x are samples, so transpose to the (features, batch)
# layout that Dense expects. Note that Flux.mse already takes a mean, so the
# extra division by the batch size scales the loss down further.
function loss_func(x, y)
    return Flux.mse(model2(transpose(x)), y) / size(x, 1)
end
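
# Quick shape check on one batch (just for illustration; any 100 rows would do):
xb = training_data_2[1:100, 1:4]       # 100×4 inputs
yb = training_data_2[1:100, 5]         # 100-element target vector
@show size(model2(transpose(xb)))       # (1, 100): a row matrix, not a vector
@show loss_func(xb, yb)                 # scalar loss for this batch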
losses = Float32[]   # per-batch loss history filled in by the callback

function callback2!(x, y)
    loss = loss_func(x, y)
    push!(losses, loss)
end

opt = ADAM(0.001)
ps2 = Flux.params(model2)   # parameters to be optimised
loss_epochs = Float32[]     # average loss per epoch

function my_custom_train2!(loss_func, ps, data, opt, cb, loss_epochs, batch_size = 100)
    len = size(data, 1)
    total_loss = 0.0
    steps = 0
    i = 1
    while i <= len
        j = min(i + batch_size - 1, len)   # last batch may be shorter
        x = data[i:j, 1:4]
        y = data[i:j, 5]
        gs = gradient(ps) do
            # Code inserted here is differentiated; unless you need the gradient
            # of it, it is better to do the work outside this block.
            loss_func(x, y)
        end
        # Anything that needs the batch loss or the gradients (e.g. logging with
        # TensorBoardLogger.jl, to see whether they blow up) can go here.
        # Equivalently to the call below, one could loop over ps and apply
        # update!(opt, p, gs[p]) for each parameter p.
        update!(opt, ps, gs)
        cb(x, y)
        steps += 1
        total_loss += loss_func(x, y)
        # A validation-set check and early-stopping test could also go here.
        i = j + 1
    end
    final_loss = total_loss / steps
    push!(loss_epochs, final_loss)
    println("Final average loss for this epoch: ", final_loss)
end
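
# Computing one gradient by hand (a sketch; it uses the first 100 rows) is an easy
# way to check that gradients actually flow through the model before running epochs:
gs_check = gradient(() -> loss_func(training_data_2[1:100, 1:4],
                                    training_data_2[1:100, 5]), ps2)
@show any(p -> gs_check[p] !== nothing, ps2)   # should print true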

for iter in 1:100
    my_custom_train2!(loss_func, ps2, training_data_2, opt, callback2!, loss_epochs)
end
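
For reference, the same training could also be written with Flux's built-in helpers instead of my hand-rolled loop. This is only a sketch under my assumptions (Flux's DataLoader, which batches along the last dimension so samples have to be columns, and the implicit-parameter form of train! that matches the params/ADAM style above):

Xt = permutedims(training_data_2[:, 1:4])          # 4×1000, samples as columns
Yt = permutedims(training_data_2[:, 5:5])          # 1×1000 targets
loader = Flux.Data.DataLoader((Xt, Yt), batchsize = 100, shuffle = true)
simple_loss(x, y) = Flux.mse(model2(x), y)
for epoch in 1:100
    Flux.train!(simple_loss, ps2, loader, opt)
end

If that version converges while my custom loop does not, the problem is presumably in my batching/update code rather than in the model itself.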