Hi,
I am trying (and so far failing, in terms of precision) to use neural networks to fit stochastic optimal control problems. Since that problem has many moving parts and I would like to reach 1e-6 precision, I went back to something much simpler: fitting x^2 at n points. Even on this toy problem I fail to get 1e-6 precision within a reasonable number of iterations / amount of time.
Therefore my question is:
Do you have any hints as to how to efficiently approximate functions using neural networks?
I understand I can play with the following:
- network depth
- network width
- activation function
- optimiser
- optimiser parameter scheduling
- batch normalisation
and playing around with these I found that:
- increasing depth and width helps reach higher precision, but not reliably, and at some point it becomes computationally inefficient
- ADAM seems to converge fastest (if it reaches 1e-6 at all)
- 1e-5 can be attained much more reliably and efficiently
- learning rate scheduling helps a lot in getting to lower errors faster (see the sketch after this list)
- batch normalisation is supposed to help when the sampling points are random, but in my case it makes convergence much harder
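To illustrate what I mean by learning rate scheduling, here is a minimal sketch of a plain exponential decay of `opt.eta` (the constants are placeholders I picked for illustration, not tuned values); the Animations.jl schedule in the script below does essentially the same thing:

# Minimal sketch: decay the learning rate exponentially from η0 towards ηmin.
# The decay constant is a placeholder, not a tuned value.
exp_schedule(step; η0 = 5e-3, ηmin = 1e-5, decay = 5e-5) =
    max(ηmin, η0 * exp(-decay * step))

# inside the training loop one would then set, e.g.
# opt.eta = exp_schedule(length(losses))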
To make it more concrete, here is the full script:
using Flux, Statistics, Plots, Animations, Random, Dates
# Optimiser
opt = ADAM()
# Learning rate scheduling function
learning_rate_function = Animations.Animation([0, 200000],
                                               [.005, 1e-5],
                                               [Animations.expin(1e-4)])
# Sample points
Random.seed!(1)
input = rand(1,100) .+ .5
output = input .^ 2
# Neural Network and parameters
nn_width = 10
m = Chain(Dense(1, nn_width, tanh),
          Dense(nn_width, nn_width, tanh),
          Dense(nn_width, nn_width, tanh),
          Dense(nn_width, 1))
θ = params(m)
# Loss function
loss() = mean(abs, m(input) .- output)
losses = Float64[]
η = Float64[]
start = Dates.now()
# Training loop
for i in 1:500000
    # uncomment the next line for the learning rate schedule (commented = constant learning rate)
    # opt.eta = learning_rate_function(size(losses)[1])
    push!(η, opt.eta)
    # Gradient with respect to parameters
    ∇ = Flux.gradient(θ) do
        loss()
    end
    push!(losses, loss())
    # Update parameters
    Flux.update!(opt, θ, ∇)
    # Print progress
    if size(losses)[1] % 10000 == 6000
        p1 = plot(losses, title = "Mean absolute error")
        p2 = plot(η, title = "η")
        p = plot(p1, p2, layout = (2,1), legend = false, yaxis = :log)
        savefig(p, "training_jl.png")
    end
    # Stopping criterion
    if losses[end] < 1e-6
        println("Tolerance reached in ", round(Dates.now() - start, Dates.Minute), " and ", size(losses)[1], " steps.")
        break
    end
end
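For what it's worth, this is roughly how I check the final precision on points the optimiser never saw; a minimal sketch, with an arbitrary test grid over the training interval [0.5, 1.5]:

# Minimal sketch: evaluate the trained network on a fresh grid and report
# the worst-case absolute error (the grid itself is arbitrary).
test_input  = reshape(collect(range(0.5, 1.5, length = 201)), 1, :)
test_output = test_input .^ 2
println("max |error| on test grid: ", maximum(abs, m(test_input) .- test_output))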