How to efficiently and precisely fit a function with neural networks?

Hi,

I am trying (and failing in terms of precision) to use neural networks to fit stochastic optimal control problems. Since that problem has many parts and I would like 1e-6 precision, I went back to something much simpler. But even with the much simpler problem, fitting x^2 at n points, I fail to reach 1e-6 precision within a reasonable number of iterations or amount of time.

Therefore my question is:
Do you have any hints as to how to efficiently approximate functions using neural networks?

I understand I can play with the following:

  • network depth
  • network width
  • activation function
  • optimiser
  • optimiser parameter scheduling
  • batch normalisation

and playing around with these I figured that:

  • increasing depth and width helps in achieving higher precision but not reliably and at some point it becomes computationally inefficient
  • ADAM seems to converge fastest (if it reaches 1e-6 at all)
  • 1e-5 can be attained in a much more reliable and efficient way
  • learning rate scheduling helps a lot in getting to lower errors faster
  • batch normalisation is supposed to help when sampling points are random but it makes convergence much harder
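
As a rough illustration of the learning rate scheduling mentioned in the list, here is a minimal sketch of a geometric decay from 0.005 to 1e-5 over 200000 steps, using only Base Julia. The function name `decayed_learning_rate` and its keyword arguments are made up for this example, and the curve is not identical to the `Animations.expin` easing used in the code further down; it only shows the general idea.

```julia
# Geometric (exponential) decay from η_start to η_end over n_decay steps,
# held constant at η_end afterwards. All names here are illustrative.
function decayed_learning_rate(step; η_start = 0.005, η_end = 1e-5, n_decay = 200_000)
    progress = clamp(step / n_decay, 0.0, 1.0)   # fraction of the decay phase completed
    return η_start * (η_end / η_start)^progress
end

for step in (0, 50_000, 100_000, 200_000, 300_000)
    println("step ", step, ": η = ", decayed_learning_rate(step))
end
```

In a training loop the returned value would be assigned to the optimiser's learning rate before each update, in the same place where `opt.eta` is set below.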

To make it more concrete:

using Flux, Statistics, Plots, Animations, Random, Dates


# Optimiser
opt = ADAM()

# Learning rate scheduling function
learning_rate_function = Animations.Animation([0, 200000],
                                                [.005, 1e-5],
                                                [Animations.expin(1e-4)])


# Sample points
Random.seed!(1)

input = rand(1,100) .+ .5

output = input .^ 2


# Neural Network and parameters
nn_width = 10

m = Chain(Dense(1,nn_width,tanh),
            Dense(nn_width,nn_width,tanh),
            Dense(nn_width,nn_width,tanh),
            Dense(nn_width,1))

θ = params(m)

# Loss function
loss() = mean(abs, m(input) .- output)

losses = []
η = []

start = Dates.now()

# Training loop
for i in 1:500000
    # comment out for constant learning rate
    # opt.eta = learning_rate_function(size(losses)[1])

    push!(η,opt.eta)

    # Gradient with respect to parameters
    ∇ = Flux.gradient(θ) do
        loss()
    end

    push!(losses,loss())

    # Update parameters
    Flux.update!(opt,θ,∇)

    # Print progress
    if size(losses)[1] % 10000 == 6000
        p1 = plot(losses, title = "Mean(|loss|)")
        p2 = plot(η, title = "η")
        p = plot(p1,p2,layout = (2,1),legend = false, yaxis= :log)

        savefig(p, "training_jl.png")
    end

    # Stopping criterion
    if losses[end] < 1e-6
        println("Tolerance reached in ",round(Dates.now() - start, Dates.Minute), " and ", size(losses)[1], " steps.")
        break
    end
end

Neural networks generally seem to have pretty slow convergence if you are trying to fit to high precision. (And there is not a lot of theory behind their convergence rate. See this paper on rational neural networks, however.)

How many inputs and outputs does your function have? Is it smooth?

Thanks for your quick reply and the paper. I’ll look up how to implement rational neural networks.

As to your question, I started with 2 inputs and 2 outputs on a smooth function but quickly figured that reaching satisfactory precision in a reasonable time is not trivial. Then I tried the simplest approximation problem I could come up with (see code above). Even there it is not trivial and now I am out of ideas. But before moving on to something else I wanted to ask for help in the community.

Hi,
have you tried a different loss function? Why do you use mean absolute error in your example rather than MSE?
It sounds to me like you’d want the gradients to vanish near the optimum for this kind of problem.
Also: are you aware that in your example the parameters of your NN are 32-bit floats (the default in Flux)?

A (maybe naive) side question: why that much precision? It sounds like a strange requirement for a problem with “stochastic” in its name :upside_down_face:
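
The point about vanishing gradients can be illustrated with a toy case: plain gradient descent with a fixed step on a single scalar residual, once with the absolute-value loss and once with the squared loss. This is only a sketch in Base Julia, not the Flux/ADAM setup from the first post (ADAM rescales gradients, so the effect there is less clear-cut), and the names `descend`, `grad_abs` and `grad_abs2` are invented for the example.

```julia
# Plain gradient descent on a single scalar residual r; the optimum is r = 0.
function descend(grad, r0; η = 0.1, steps = 100)
    r = r0
    for _ in 1:steps
        r -= η * grad(r)
    end
    return r
end

grad_abs(r)  = sign(r)   # derivative of |r|: magnitude stays 1 however small r gets
grad_abs2(r) = 2r        # derivative of r^2: shrinks together with r

r_mae = descend(grad_abs, 1.05)
r_mse = descend(grad_abs2, 1.05)

println("final |r| with absolute loss: ", abs(r_mae))
println("final |r| with squared loss:  ", abs(r_mse))
```

With the absolute loss the iterate keeps jumping across the optimum by a full step, so the error stalls at the scale of the step size; with the squared loss the steps shrink as the residual shrinks.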

Thanks for the suggestions.

I tried MSE and it does not improve the efficiency issue. As to your comment, why would you want the gradients to vanish near the optimum? It seems the optimisation stalls close to the optimum (1e-5) and in order to get further (<1e-6) you would need non-zero gradients!?

@32bit: I am aware of it but my thinking was that this shouldn’t matter too much. The smallest model I tried has 250 parameters, I guess this should make up for the lack of precision. But who knows. I’ll give it a try.
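
For a sense of scale on the 32-bit question, this short Base Julia check prints the machine epsilon of both float types and the spacing between adjacent Float32 values near the largest target in the x^2 example (inputs lie in [0.5, 1.5], so outputs go up to 2.25). The variable names are just for illustration.

```julia
println("eps(Float32) = ", eps(Float32))
println("eps(Float64) = ", eps(Float64))

largest_target = 1.5f0^2          # 2.25, the largest value of x^2 on [0.5, 1.5]
spacing = eps(largest_target)     # gap between neighbouring Float32 values at 2.25

println("Float32 spacing near 2.25 = ", spacing)
println("a 1e-6 error is about ", 1e-6 / spacing, " such gaps")
```

So a 1e-6 target sits only a handful of representable Float32 values away from the exact answer, which leaves very little room for rounding error accumulated through several layers.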

@precision: the big picture problem would require that the model estimates the ergodic distribution and long run steady states. Especially for long run steady states precision of the solution becomes important. Furthermore, simple perturbation is way faster (as it stands) and can achieve 1e-6 precision on simple problems.

Here is an updated version of the code taking all suggestions on board:

  • Rational activation function (I took the version suggested in the paper linked by steven)
  • 64bit precision of model parameters
  • minimise MSE instead of MAE

None of these changes improved the situation substantially. Quite amazing that fitting x^2 is such a hard problem for neural networks.

using Flux, Statistics, Plots, Animations, Random, Dates

rational(x,a,b) = (a[1] .* x .^ 3 .+ a[2] .* x .^ 2 .+ a[3] .* x .+ a[4]) ./ (b[1] .* x .^ 2 .+ b[2] .* x .+ b[3])

struct Rational_32{M<:AbstractMatrix, B, T}
  weight::M
  bias::B
  a₁::T
  a₂::T
end

function Rational_32(in::Integer, out::Integer;
               initW = nothing, initb = nothing, 
               inita₁ = nothing, inita₂ = nothing,
               init = Flux.glorot_uniform, bias=true)

  W = if initW !== nothing
    Base.depwarn("keyword initW is deprecated, please use init (which similarly accepts a function like randn)", :Dense)
    initW(out, in)
  else
    init(out, in)
  end

  b = if bias === true && initb !== nothing
    Base.depwarn("keyword initb is deprecated, please simply supply the bias vector, bias=initb(out)", :Dense)
    initb(out)
  else
    Flux.create_bias(W, bias, size(W,1))
  end

  # Numerator coefficients: call the initialiser that was actually passed in
  a₁ = if inita₁ !== nothing
    Base.depwarn("keyword inita₁ is deprecated, please simply supply the desired vectors", :Rational_32)
    inita₁(4)
  else
    Flux.ones32(4,1)
  end

  # Denominator coefficients
  a₂ = if inita₂ !== nothing
    Base.depwarn("keyword inita₂ is deprecated, please simply supply the desired vectors", :Rational_32)
    inita₂(3)
  else
    Flux.ones32(3,1)
  end

  return Rational_32(W, b, a₁, a₂)
end

Flux.@functor Rational_32

function (p::Rational_32)(x::AbstractVecOrMat)
  return rational(p.weight*x .+ p.bias, p.a₁, p.a₂)
end

(a::Rational_32)(x::AbstractArray) = 
  reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)

function Base.show(io::IO, l::Rational_32)
  print(io, "Rational_32(", size(l.weight, 2), ", ", size(l.weight, 1))
  l.bias == Flux.Zeros() && print(io, "; bias=false")
  print(io, ")")
end



# Optimiser
opt = ADAM(.005)

# Learning rate scheduling function
learning_rate_function = Animations.Animation([0, 200000],
                                                [.005, 1e-5],
                                                [Animations.expin(1e-4)])


# Sample points
Random.seed!(100)

input = rand(1,100) .+ .5

output = input .^ 2


# Neural Network and parameters
nn_width = 10

m = Chain(Rational_32(1,nn_width),
            Rational_32(nn_width,nn_width),
            Rational_32(nn_width,nn_width),
            Dense(nn_width,1))

m = f64(m)

θ = params(m)

m(input)

# Loss function
loss_abs2() = sum(abs2, m(input) .- output)
loss_abs() = mean(abs, m(input) .- output)

losses = []
η = []

start = Dates.now()

# Training loop
for i in 1:500000
    # comment out for constant learning rate
    opt.eta = learning_rate_function(size(losses)[1])

    push!(η,opt.eta)

    # Gradient with respect to parameters
    ∇ = Flux.gradient(θ) do
        loss_abs2()
    end

    push!(losses,loss_abs())

    # Update parameters
    Flux.update!(opt,θ,∇)

    # Print progress
    if size(losses)[1] % 10000 == 6000
        p1 = plot(losses, title = "Mean(|loss|)")
        p2 = plot(η, title = "η")
        p = plot(p1,p2,layout = (2,1),legend = false, yaxis= :log)

        savefig(p, "training_jl.png")
    end

    # Stopping criterion
    if losses[end] < 1e-6
        println("Tolerance reached in ",round(Dates.now() - start, Dates.Minute), " and ", size(losses)[1], " steps.")
        break
    end
end

For only a few inputs (even if there are many outputs) and smooth functions I would just use Chebyshev interpolation. e.g. via

If you want to reach very low tolerances, you may want to try a solver like L-BFGS rather than ADAM once you reach close to a minimum.
https://github.com/baggepinnen/FluxOptTools.jl
has some tools to train Flux models using Optim; see the plot in the readme illustrating the difference in convergence between ADAM and BFGS
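
To give a feel for why a quasi-Newton method can get to tight tolerances where a first-order method crawls, here is a self-contained sketch comparing fixed-step gradient descent with a small dense BFGS on an ill-conditioned quadratic. Note the assumptions: this is hand-written dense BFGS with simple Armijo backtracking, not L-BFGS, not Optim and not FluxOptTools, and the test function and all names (`gradient_descent`, `bfgs`, `A`) are chosen purely for illustration.

```julia
using LinearAlgebra

# Ill-conditioned quadratic f(x) = 0.5 * x' * A * x with minimiser x = 0.
A = Diagonal([1.0, 1000.0])
f(x)  = 0.5 * dot(x, A * x)
∇f(x) = A * x

# Fixed-step gradient descent; η is limited by the largest curvature (1000).
function gradient_descent(x0; η = 1e-3, steps = 200)
    x = copy(x0)
    for _ in 1:steps
        x = x - η * ∇f(x)
    end
    return x
end

# Dense BFGS with Armijo backtracking; H approximates the inverse Hessian.
function bfgs(x0; steps = 200, tol = 1e-10)
    x = copy(x0)
    g = ∇f(x)
    H = Matrix{Float64}(I, length(x), length(x))
    for _ in 1:steps
        norm(g) < tol && break
        d = -(H * g)                       # quasi-Newton search direction
        t = 1.0
        for _ in 1:60                      # halve the step until sufficient decrease
            f(x + t * d) <= f(x) + 1e-4 * t * dot(g, d) && break
            t /= 2
        end
        s = t * d
        x_new = x + s
        g_new = ∇f(x_new)
        y = g_new - g
        sy = dot(s, y)
        if sy > 0                          # curvature condition, keeps H positive definite
            ρ = 1 / sy
            H = (I - ρ * s * y') * H * (I - ρ * y * s') + ρ * s * s'
        end
        x, g = x_new, g_new
    end
    return x
end

x0 = [1.0, 1.0]
x_gd   = gradient_descent(x0)
x_bfgs = bfgs(x0)
println("distance to optimum after gradient descent: ", norm(x_gd))
println("distance to optimum after BFGS:             ", norm(x_bfgs))
```

The gradient descent step has to stay small because of the stiff direction, so the flat direction barely moves; BFGS learns the curvature from successive gradients and rescales the step accordingly.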

Thanks for the hint. Indeed for the simple problem this solves it fast and with high precision (1xx milliseconds for 2 inputs & 1 output + adaptive domain).
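
For the one-dimensional x^2 problem from the first post, the idea can be sketched by hand in a few lines of Base Julia. This is a textbook Chebyshev fit on Chebyshev-Gauss nodes, written for illustration only; it is not the package suggested above, and `cheb_coeffs` and `cheb_eval` are made-up names.

```julia
# Chebyshev coefficients of f on [a, b], computed from n Chebyshev-Gauss nodes.
function cheb_coeffs(f, a, b, n)
    nodes = [cos(π * (k + 0.5) / n) for k in 0:n-1]                 # nodes in [-1, 1]
    fvals = [f(0.5 * (b - a) * t + 0.5 * (b + a)) for t in nodes]   # f at the mapped nodes
    coeff(j) = 2 / n * sum(fvals[k+1] * cos(π * j * (k + 0.5) / n) for k in 0:n-1)
    return [coeff(j) for j in 0:n-1]
end

# Evaluate the interpolant at x in [a, b]; the j = 0 coefficient enters with weight 1/2.
function cheb_eval(c, a, b, x)
    t = (2x - a - b) / (b - a)                 # map x back to [-1, 1]
    θ = acos(clamp(t, -1.0, 1.0))
    return sum(c[j+1] * cos(j * θ) for j in 0:length(c)-1) - c[1] / 2
end

a, b = 0.5, 1.5
xs = range(a, b; length = 1001)

c_sq  = cheb_coeffs(x -> x^2, a, b, 5)
c_exp = cheb_coeffs(exp, a, b, 12)

err_sq  = maximum(abs(cheb_eval(c_sq, a, b, x) - x^2) for x in xs)
err_exp = maximum(abs(cheb_eval(c_exp, a, b, x) - exp(x)) for x in xs)
println("max error for x^2 with 5 nodes:     ", err_sq)
println("max error for exp(x) with 12 nodes: ", err_exp)
```

No iterative training is involved: the coefficients come from a fixed number of function evaluations, which is why this is so fast for smooth functions of one or two inputs.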

I’ll test a simple control problem later today.

In your experience, at what size does it become infeasible? I’m thinking of problems with 10-100 state variables.

You mean use ADAM to reach 1e-4 and then pass it over to L-BFGS?

I tried the linked package but without a prior pass of ADAM. L-BFGS converges faster (stopped because of failed line search) than ADAM but never reached 1e-6.

That is one option. You can use L-BFGS from the start, but ADAM is designed to work well for problems with lots of data and many parameters (stochastic gradient descent), so if your problem is large, it might be faster to prime L-BFGS with an initial point from ADAM.

I assume that the problem is nonconvex? Neither ADAM nor L-BFGS can really guarantee to solve it to optimality for you then. Once again, ADAM might give you a better starting point for L-BFGS in this case.

No, the cost of Chebyshev interpolation increases exponentially with the number of variables. I would typically use it only for < 5 input variables.

(There are other methods you could try, e.g. sparse grids or radial basis functions. But I’m very skeptical that you’re going to get 6 digits of accuracy with any interpolation method for 10–100 variables unless something is extremely special about your function that you can exploit, e.g. if it is separable.)
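
The exponential growth is easy to make concrete. Assuming a full tensor-product grid with the same number of points in every dimension (6 per dimension is used here only as an example value), the number of function evaluations is that count raised to the number of inputs. The helper name `tensor_grid_points` is made up for this sketch.

```julia
# Number of points in a full tensor-product grid; BigInt avoids overflow for large n_dims.
tensor_grid_points(points_per_dim, n_dims) = big(points_per_dim)^n_dims

for d in (1, 2, 5, 10, 100)
    println(d, " inputs: ", tensor_grid_points(6, d), " grid points")
end
```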

Maybe I’m misunderstanding, but I think the OP wants to achieve a really low value of the loss function (on validation data as well as on training data) — a fit that is accurate to 6 digits — which is much harder than simply minimizing the loss function accurately.

To get a high-accuracy fit, in general you will need to increase the size of the training set, increase the size of the neural network (the number of “fit parameters”), and minimize the loss function accurately until your desired fit tolerance (on validation data) is achieved. Unfortunately, as I said, as far as I know there is not a lot of theory about the rate at which this process converges, especially in high dimensions. But from general principles it should be extremely hard to get high accuracy in high dimensions.

(Most machine-learning work is focused on achieving quite low accuracy fits, not 6-digit accuracy, in part because the data is typically so noisy that a high-accuracy fit is meaningless.)
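
The difference between driving the training loss to zero and getting an accurate fit can be seen even without a neural network. The sketch below uses a polynomial through 8 equally spaced training points as a stand-in model (an assumption made only to keep the example short and deterministic) and compares the error at the training points with the error on a dense validation grid. All names are illustrative.

```julia
using LinearAlgebra

target(x) = exp(x)
vandermonde(x, degree) = [xi^j for xi in x, j in 0:degree]   # monomial basis matrix

x_train = collect(range(0.5, 1.5; length = 8))     # training points
x_valid = collect(range(0.5, 1.5; length = 201))   # dense held-out grid

# Degree 7 through 8 points: the model fits the training data (almost) exactly.
coef = vandermonde(x_train, 7) \ target.(x_train)

train_err = maximum(abs, vandermonde(x_train, 7) * coef .- target.(x_train))
valid_err = maximum(abs, vandermonde(x_valid, 7) * coef .- target.(x_valid))

println("max error on training points:   ", train_err)
println("max error on validation points: ", valid_err)
```

The training error is at rounding level, yet the validation error is several orders of magnitude larger; only the latter says how accurate the fit is between the sample points.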

This seems to work reasonably well on small problems. I tried ADAM for at most 100k iterations, or until it reaches 5e-5, and then passed the result to L-BFGS. I get to 1e-6 fairly consistently in about a minute, which is OK for bigger problems. I will see how it scales. Furthermore, I tested with and without all other suggestions so far, and it seems the rational activation function, Float64 and MSE are definitely helpful.

Here is some code for the ADAM + L-BFGS implementation:

using Flux, Statistics, Plots, Animations, Random, Dates, FluxOptTools, Zygote, Optim

rational(x,a,b) = (a[1] .* x .^ 3 .+ a[2] .* x .^ 2 .+ a[3] .* x .+ a[4]) ./ (b[1] .* x .^ 2 .+ b[2] .* x .+ b[3])

struct Rational_32{M<:AbstractMatrix, B, T}
  weight::M
  bias::B
  a₁::T
  a₂::T
end

function Rational_32(in::Integer, out::Integer;
               initW = nothing, initb = nothing, 
               inita₁ = nothing, inita₂ = nothing,
               init = Flux.glorot_uniform, bias=true)

  W = if initW !== nothing
    Base.depwarn("keyword initW is deprecated, please use init (which similarly accepts a function like randn)", :Dense)
    initW(out, in)
  else
    init(out, in)
  end

  b = if bias === true && initb !== nothing
    Base.depwarn("keyword initb is deprecated, please simply supply the bias vector, bias=initb(out)", :Dense)
    initb(out)
  else
    Flux.create_bias(W, bias, size(W,1))
  end

  # Numerator coefficients: call the initialiser that was actually passed in
  a₁ = if inita₁ !== nothing
    Base.depwarn("keyword inita₁ is deprecated, please simply supply the desired vectors", :Rational_32)
    inita₁(4)
  else
    Flux.ones32(4,1)
  end

  # Denominator coefficients
  a₂ = if inita₂ !== nothing
    Base.depwarn("keyword inita₂ is deprecated, please simply supply the desired vectors", :Rational_32)
    inita₂(3)
  else
    Flux.ones32(3,1)
  end

  return Rational_32(W, b, a₁, a₂)
end

Flux.@functor Rational_32

function (p::Rational_32)(x::AbstractVecOrMat)
  return rational(p.weight*x .+ p.bias, p.a₁, p.a₂)
end

(a::Rational_32)(x::AbstractArray) = 
  reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)

function Base.show(io::IO, l::Rational_32)
  print(io, "Rational_32(", size(l.weight, 2), ", ", size(l.weight, 1))
  l.bias == Flux.Zeros() && print(io, "; bias=false")
  print(io, ")")
end



# Optimiser
opt = ADAM()

# Learning rate scheduling function
learning_rate_function = Animations.Animation([0, 200000],
                                                [.005, 1e-5],
                                                [Animations.expin(1e-4)])


# Sample points
Random.seed!(1100)

batch_size = 100

input = [rand(1,batch_size) .* .1 .+ .14; randn(1,batch_size) .* .0068]


# Neural Network and parameters
nn_width = 20

m = Chain(Rational_32(2,nn_width),
            Rational_32(nn_width,nn_width),
            Rational_32(nn_width,nn_width),
            Rational_32(nn_width,nn_width),
            Dense(nn_width,1))



output = 0.327 .* exp.(input[2,:]') .* input[1,:]' .^ 0.33

m = f64(m)

θ = params(m)

m(input)

# Loss function
loss_abs2() = sum(abs2, m(input) .- output)
loss_abs() = mean(abs, m(input) .- output)

losses = []
η = []

start = Dates.now()

# Training loop
for i in 1:100000
    # comment out for constant learning rate
    opt.eta = learning_rate_function(size(losses)[1])

    push!(η,opt.eta)

    # Gradient with respect to parameters
    ∇ = Flux.gradient(θ) do
        loss_abs2()
    end

    push!(losses,loss_abs())

    # Update parameters
    Flux.update!(opt,θ,∇)

    # Print progress
    if size(losses)[1] % 10000 == 6000
        p1 = plot(losses, title = "Mean(|loss|)")
        p2 = plot(η, title = "η")
        p = plot(p1,p2,layout = (2,1),legend = false, yaxis= :log)

        savefig(p, "training_jl.png")
    end

    # Stopping criterion
    if losses[end] < 4e-5 && size(losses)[1] > 50000
        println("Tolerance reached in ",round(Dates.now() - start, Dates.Minute), " and ", size(losses)[1], " steps.")
        break
    end
end


Zygote.refresh()

lossfun, gradfun, fg!, p0 = optfuns(loss_abs2, θ)

res = Optim.optimize(Optim.only_fg!(fg!), p0, Optim.Options(iterations=10000, store_trace=true))

println("Tolerance ",loss_abs()," reached in ",round(Dates.now() - start, Dates.Minute), ".")

A remaining question is how to implement changing data in the loss function. Once you move to autoregressive processes with fixed points, you would need to alternate between generating new data from the approximation and improving the approximation on the new data with L-BFGS. I figured it out in Flux, but with Optim I don’t see how to do it.

My experience from playing around is that increasing the size of the training set and making the network larger do help in achieving a better fit but it becomes too expensive in terms of computational time.

Another approach is Taylor series. They are fast but imprecise further away from the approximation point.
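
A quick numerical check of that trade-off, using a second-order Taylor expansion of exp around 0 as an arbitrary example (the function and the name `taylor_exp` are chosen only for illustration):

```julia
# Taylor polynomial of exp around 0, truncated at the given order.
taylor_exp(x, order) = sum(x^k / factorial(k) for k in 0:order)

for x in (0.1, 0.5, 1.0, 2.0)
    println("x = ", x, ": error of 2nd-order expansion = ", abs(exp(x) - taylor_exp(x, 2)))
end
```

The error grows roughly with the cube of the distance from the expansion point, so a fit that is excellent near the steady state can be poor a few standard deviations away.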

There should be some insights from the unknown function approximation literature. I’ll have a look later.

If you have a semi-analytical expression for the function you are trying to approximate, then there might be a lot more tricks.

Usually, Taylor series are used only in a small portion of the domain near function zeros or other singularities; continued-fraction expansions are a typical tool in other portions of the domain.

I finally took some time to code a simple dynamic programming problem and solve it with Chebyshev polynomials.

Using Chebyshev polynomials on this size of problem is definitely fast and precise. Is there anything else you can think of for higher dimensional problems other than sparse grids (e.g. Smolyak) or finite elements!?

Here is the code:


using FastChebInterp, Plots, Dates

function simulate(state, parameters, guess)

    # Parameters
    @views alpha        = parameters[1]
    @views beta         = parameters[2]
    @views std_epsilon_Z= parameters[3]

    # State
    @views epsilon_Z___ϵ___  = state[1]
    @views K___t₋___           = max(eps(state[2]),state[2])

    # Guess
    @views K___t₀___  = max(eps(guess[1]),guess[1])
    @views K___t₊___  = max(eps(guess[2]),guess[2])


    
    C___t₀___ = exp(std_epsilon_Z .* epsilon_Z___ϵ___) .* K___t₋___.^alpha .- K___t₀___

    C___t₊___ = K___t₀___.^alpha .- K___t₊___
    
    K = C___t₀___ ./ C___t₊___ .* beta * alpha .* K___t₀___.^(alpha)

    return K

end


lb = [-3, .14]
ub = [3, .24]

n_polys = (5,5)

x = chebpoints(n_polys, lb, ub)

parameters = [0.330, 0.990, .01]


start = Dates.now()

init_guess = fill([.2,.2],n_polys.+1)

pars = fill(parameters,n_polys.+1)

c = chebinterp(simulate.(x,pars,init_guess),lb,ub)

old = fill(.18,n_polys.+1)
old_hat = fill(.18,n_polys.+1)


new_guess = fill([0.0,0.0],n_polys.+1)

for ii in 1:100

  new  = c.(x)
  new_hat = c.([[0,i] for i in new])
  
  if maximum([maximum(abs,(old-new)./old), maximum(abs,(old_hat-new_hat)./old_hat)]) < 1e-6
    println("Tolerance reached in ",Dates.now() - start, " and ", ii, " steps.")    
    break
  end

  for i in 1:n_polys[1]+1
    for k in 1:n_polys[2]+1
      new_guess[i,k] = [new[i,k], new_hat[i,k]]
    end
  end

  c = chebinterp(simulate.(x,pars,new_guess),lb,ub)

  old = new
  old_hat = new_hat
end


# Steady state
c([0,0.188299632443837]) - 0.188299632443837


# IRFs and true solution
irf = fill(0.0,40)
irf[1] = c([1,0.188299632443837])

for i in 2:40
  irf[i] = c([0,irf[i-1]])
end


true_sol = fill(0.0,40)
true_sol[1] = .3267 * exp(.01 * 1) * 0.188299632443837 ^ .33

for i in 2:40
  true_sol[i] = .3267 * true_sol[i-1] ^ .33
end

plot(true_sol)
plot!(irf)