Hi all,
I am solving an optimization problem by minimizing the Augmented Lagrangian (AL) of a constrained, parameterized problem. There are about 400 parameters, and the AL returns a scalar that is to be minimized. I have already solved the problem, but I am experimenting with different ways of computing the gradient: finite differences and forward-mode AD (both already provided by `Optim`), and `Zygote` for reverse-mode AD. This matters to me because I want to view my algorithm through a neural-network lens. There are also several reasons why I am not doing this directly with `Flux`.
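For reference, here is a minimal sketch of the three ways of supplying the gradient to `Optim` that I am comparing. It uses a hypothetical stand-in objective `h` with 400 parameters, not my actual problem; `h_grad!` is a name I made up for the Zygote-based gradient:

```julia
using Optim, Zygote

# Hypothetical stand-in objective: 400 inputs, one scalar output, minimum at x = 0
h(x) = sum(abs2, x) + sum(x)^2
x0 = rand(400)

# 1. Finite differences: Optim's default when no gradient is supplied
res_fd = optimize(h, x0, ConjugateGradient())

# 2. Forward-mode AD (ForwardDiff) selected through the autodiff keyword
res_fw = optimize(h, x0, ConjugateGradient(); autodiff = :forward)

# 3. Reverse-mode AD (Zygote) supplied as an in-place gradient g!(G, x)
function h_grad!(G, x)
    G .= Zygote.gradient(h, x)[1]
    return G
end
res_rv = optimize(h, h_grad!, x0, ConjugateGradient())
```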
It turns out that the options provided by `Optim` are much faster than `Zygote`'s reverse mode. With `Zygote`, a single iteration of the algorithm takes about three times as long as the complete solution with `autodiff = :forward` in `Optim`.
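To separate compilation time from run time, this is roughly how one could time a single gradient evaluation in each mode. `h` is again a hypothetical stand-in for the AL, and only `@elapsed` from Base is used (BenchmarkTools would give more reliable numbers):

```julia
using ForwardDiff, Zygote

# Hypothetical stand-in for an augmented Lagrangian with 400 parameters
h(x) = sum(abs2, x) + 3.0 * (sum(x) - 1)^2
x = rand(400)

# First calls trigger compilation; keep them out of the timings
gf = ForwardDiff.gradient(h, x)
gr = Zygote.gradient(h, x)[1]

nrep = 100
t_fw = @elapsed for _ in 1:nrep
    ForwardDiff.gradient(h, x)
end
t_rv = @elapsed for _ in 1:nrep
    Zygote.gradient(h, x)
end
println("ForwardDiff: ", t_fw / nrep, " s per gradient")
println("Zygote:      ", t_rv / nrep, " s per gradient")
```

Both gradients should agree with the analytic one, 2x + 6(sum(x) - 1), up to rounding.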
I am aware of these discussions here:
To give an idea of what I am doing, I created this minimal toy problem:
```julia
using Optim, Zygote

μ = 2
υ = rand()

# Objective function
f(x::AbstractVector) = x[1]^2 + x[2]^2
# Constraint
g(x::AbstractVector) = x[1] + x[2] - 1
# Augmented Lagrangian
L(x::AbstractVector) = f(x) + υ*g(x) + μ*g(x)^2

# This is where I write the reverse-mode gradient
function g!(G, x)
    G .= gradient(L, x)[1]
end

x₀ = rand(2)
VIOLₖ = 10e12
k = 0
ϵ = 10e-4
cond = false
while cond == false
    global μ, υ, x₀, VIOLₖ, k, cond  # lets the loop update these globals when run as a script
    # Optimization
    @label innerloop
    println("---------------------------------------------------------------------------------------------------")
    if k > 15 @goto terminate end
    println("Iteration: $k")
    opt = optimize(L, g!, x₀, ConjugateGradient(), Optim.Options(show_trace=true, g_tol=10e-3, show_every=10))
    # Storing solution
    x₀ = opt.minimizer
    x = x₀
    # Checking violations
    VIOLₖ₋₁ = copy(VIOLₖ)
    VIOLₖ = abs(g(x))
    if VIOLₖ < ϵ
        println("*Solution found: x=$x")
        @goto terminate
    elseif VIOLₖ < 0.9*VIOLₖ₋₁
        println("*VIOL condition validated...")
        @goto outerloop
    else
        μ = 10*μ
        println("*Updating penalty term...")
        println("Penalty terms: μ=$μ")
        k += 1
        @goto innerloop
    end
    # Update Lagrange multipliers
    @label outerloop
    println("*Accessing outerloop...")
    υ = υ + 2*μ*g(x)
    k += 1
    @goto innerloop
    @label terminate
    cond = true
end
```
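For reference, the multiplier update in the outer loop comes from the stationarity condition of the inner problem. Writing the multiplier as $\upsilon$ and the penalty as $\mu$, the toy Lagrangian and its gradient are

$$
\mathcal{L}(x) = f(x) + \upsilon\, g(x) + \mu\, g(x)^2,
\qquad
\nabla \mathcal{L}(x) = \nabla f(x) + \bigl(\upsilon + 2\mu\, g(x)\bigr)\nabla g(x).
$$

At an inner minimizer $x_k$ we have $\nabla \mathcal{L}(x_k) = 0$, which is the KKT stationarity condition $\nabla f(x_k) + \upsilon_{\text{new}}\, \nabla g(x_k) = 0$ with $\upsilon_{\text{new}} = \upsilon + 2\mu\, g(x_k)$; that is the update applied before the next inner solve.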
In the actual problem, I have more constraints (like `g` above), and they involve some integration (for which I am using `FastGaussQuadrature`, just out of curiosity).
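As an illustration of the kind of constraint I mean, here is an integral constraint evaluated with Gauss–Legendre quadrature and differentiated with `Zygote`. The integrand, the target value, the node count and the names `V` and `gint` are hypothetical, not my real constraints; the nodes and weights are computed once, outside the function being differentiated:

```julia
using FastGaussQuadrature, LinearAlgebra, Zygote

# Gauss–Legendre nodes and weights on [-1, 1], computed once outside the AD call
const tq, wq = gausslegendre(20)

# Hypothetical integrand p(t; x) = x[1] + x[2]*t + x[3]*t^2, evaluated at all nodes as V * x
const V = tq .^ (0:2)'

# Hypothetical integral constraint: ∫ p(t; x) dt over [-1, 1], minus 1
gint(x) = dot(wq, V * x) - 1

x = [1.0, 2.0, 3.0]
gint(x)                      # ∫(1 + 2t + 3t²) dt - 1 = 4 - 1 = 3 (up to rounding)
Zygote.gradient(gint, x)[1]  # ≈ [2, 0, 2/3], i.e. the integrals of 1, t and t²
```

Writing the quadrature as a weighted sum over precomputed nodes keeps the differentiated code to a matrix-vector product and a dot product, both of which have reverse-mode rules.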
In the actual problem, `g!` is written exactly as in the toy example above.
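One variant worth testing is the same loop wrapped in a function, with the multiplier `υ` and the penalty `μ` passed explicitly to the Lagrangian instead of being read from non-constant globals (the Julia performance tips warn that non-constant globals make code type-unstable). This is a hypothetical restructuring (`solve_al`, `kmax`, `Lk` and `gk!` are names I made up), not something I have benchmarked, and it drops the printing for brevity:

```julia
using Optim, Zygote

f(x) = x[1]^2 + x[2]^2                 # objective
g(x) = x[1] + x[2] - 1                 # equality constraint g(x) = 0
L(x, υ, μ) = f(x) + υ*g(x) + μ*g(x)^2  # augmented Lagrangian with explicit υ, μ

# μ starts as a Float64 so that its type does not change when it is scaled
function solve_al(; μ = 2.0, υ = rand(), ϵ = 10e-4, kmax = 15)
    x₀ = rand(2)
    VIOLₖ = 10e12
    for k in 0:kmax
        # Freeze the current υ and μ so the closure captures plain values
        Lk = let υ = υ, μ = μ
            x -> L(x, υ, μ)
        end
        # Reverse-mode gradient, written the same way as g! above
        gk!(G, x) = (G .= Zygote.gradient(Lk, x)[1])
        opt = optimize(Lk, gk!, x₀, ConjugateGradient(), Optim.Options(g_tol = 10e-3))
        x₀ = Optim.minimizer(opt)
        VIOLₖ₋₁, VIOLₖ = VIOLₖ, abs(g(x₀))
        if VIOLₖ < ϵ
            return x₀                  # constraint satisfied
        elseif VIOLₖ < 0.9*VIOLₖ₋₁
            υ += 2*μ*g(x₀)             # multiplier update (outer loop)
        else
            μ *= 10                    # increase the penalty
        end
    end
    return x₀
end
```

Calling `solve_al()` should return a point close to `[0.5, 0.5]` for this toy problem.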
I am especially confused because the rule of thumb says that reverse mode is better suited to functions with many inputs and few outputs, which is precisely my case.
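The rule of thumb comes from how the work scales. ForwardDiff propagates a chunk of `c` directional derivatives per evaluation of the function, so one gradient with `n` inputs costs `cld(n, c)` evaluations, whereas reverse mode needs one forward and one backward pass regardless of `n`. The asymptotic advantage of reverse mode therefore only shows up if its per-operation overhead is small compared to those extra sweeps. A small sketch that counts the evaluations with a hypothetical counter `ncalls` (ForwardDiff only; doing the same with Zygote would require keeping the counter out of the differentiated code):

```julia
using ForwardDiff

const ncalls = Ref(0)  # hypothetical counter of function evaluations

function hcount(x)
    ncalls[] += 1
    return sum(abs2, x)
end

n = 400
x = rand(n)
counts = Int[]
for c in (1, 4, 12)
    # Fix the chunk size explicitly instead of letting ForwardDiff choose it
    cfg = ForwardDiff.GradientConfig(hcount, x, ForwardDiff.Chunk{c}())
    ncalls[] = 0
    ForwardDiff.gradient(hcount, x, cfg)
    push!(counts, ncalls[])
    println("chunk size $c: $(ncalls[]) evaluations (cld(n, c) = $(cld(n, c)))")
end
```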
Any comments on why I am getting such an unworkably slow computation time with `Zygote` are very welcome.
If you are an expert on the matter and could help me further by looking into my full code (including a review of other potential performance issues), please don't hesitate to send me a message to discuss a potential (paid) consultation (I apologize if this goes against any community guidelines).
Thank you in advance.