I am trying to optimize the Rosenbrock function using the Adam method, but I need to make sure that it performs only n/2 iterations (where n is supplied by the test code). How can I set up a counter that increments every time the function executes, and then use that count in a while loop that keeps executing the function while the number of iterations is less than or equal to n/2?
Below is my code.
Thanks!
"""
    optimize(f, g, x0, n, prob)

Minimize `f` with the Adam method, starting from `x0`, and return the final iterate.

`g` is the gradient of `f`. `n` is the evaluation budget, charged as
`(# f calls) + 2 * (# g calls)`. Every Adam step evaluates `g` exactly once and
never evaluates `f`, so each step costs 2 and the loop runs ⌊n/2⌋ steps (none
when `n < 2`). `prob` is accepted but not used here.
"""
function optimize(f, g, x0, n, prob)
    α = 0.001   # learning rate
    γv = 0.9    # decay rate of the first-moment estimate
    γs = 0.999  # decay rate of the second-moment estimate
    ϵ = 1e-8    # added to the step's denominator to avoid division by zero
    A = Adam(α, γv, γs, ϵ, 0, 0, 0)
    x′ = x0
    init!(A, f, g, x′)

    # Count evaluation cost explicitly instead of relying on an undefined global:
    # only take another step if its cost still fits within the budget `n`.
    cost_per_step = 2  # one gradient call, charged twice
    used = 0
    while used + cost_per_step <= n
        x′ = step!(A, f, g, x′)
        used += cost_per_step
    end
    return x′
end
"""
Abstract supertype for descent methods such as `Adam`.
"""
abstract type DescentMethod end

"""
    Adam(α, γv, γs, ϵ, k, v, s)

State of the Adam (adaptive moment estimation) descent method, following
Algorithm 5.8 of Kochenderfer & Wheeler, *Algorithms for Optimization*.

# Fields
- `α`: learning rate (`optimize` uses 0.001).
- `γv`: decay rate of the first-moment estimate (`optimize` uses 0.9).
- `γs`: decay rate of the second-moment estimate (`optimize` uses 0.999).
- `ϵ`: small constant added to the step's denominator (`optimize` uses 1e-8).
- `k`: number of steps taken so far.
- `v`: first-moment estimate; starts as a placeholder and is replaced by a
  vector in `init!`.
- `s`: second-moment estimate; same lifecycle as `v`.
"""
mutable struct Adam <: DescentMethod
    α
    γv
    γs
    ϵ
    k::Int  # concrete type so the step counter is not boxed as `Any`
    v       # left untyped: holds a placeholder (e.g. 0) until `init!` sets a vector
    s
end
"""
    init!(M::Adam, f, ∇f, x)

Prepare `M` for a run starting at `x`. Resets the step counter to zero and
replaces both moment estimates with fresh zero vectors, one entry per element
of `x`. `f` and `∇f` are not used. Returns `M`.
"""
function init!(M::Adam, f, ∇f, x)
    dim = length(x)
    # Two separate `zeros` calls, so `v` and `s` never share storage.
    M.k, M.v, M.s = 0, zeros(dim), zeros(dim)
    return M
end
"""
    step!(M::Adam, f, ∇f, x)

Take one Adam step from `x` and return the new point; `x` itself is not modified.

Evaluates `∇f(x)` once (`f` is never called), updates the moment estimates
`M.v` and `M.s` in place, and increments `M.k`. Returns
`x - α * v̂ ./ (sqrt.(ŝ) .+ ϵ)`, where `v̂ = v / (1 - γv^k)` and
`ŝ = s / (1 - γs^k)` are the bias-corrected moments.
"""
function step!(M::Adam, f, ∇f, x)
    α, γv, γs, ϵ = M.α, M.γv, M.γs, M.ϵ
    v, s = M.v, M.s
    g = ∇f(x)

    # Fused in-place updates: one loop per line, no temporary arrays.
    @. v = γv * v + (1 - γv) * g
    @. s = γs * s + (1 - γs) * g * g

    M.k += 1
    k = M.k

    # Bias-correction divisors, computed once as scalars. The moments start at
    # zero, so without this the early estimates would be biased toward zero.
    cv = 1 - γv^k
    cs = 1 - γs^k
    return @. x - α * (v / cv) / (sqrt(s / cs) + ϵ)
end