Zygote vs. Forward Diff with Optim

Hi all,
I am solving a constrained, parameterized optimization problem using an Augmented Lagrangian (AL). I have about 400 parameters, and the AL returns a scalar to be minimized. I have already solved the problem, but I am experimenting with different gradient approaches: finite differences and forward-mode autodiff (both already provided by Optim), plus Zygote for reverse-mode automatic differentiation. This matters to me because I want to recast my algorithm through a neural-network (NN) lens. There are also several reasons why I am not doing this directly with Flux.

It turns out that the options provided by Optim are much faster than Zygote’s reverse mode. With Zygote, a single iteration of the algorithm takes three times as long as the complete solution does when using autodiff = :forward with Optim.

I am aware of these discussions here:

In order to provide an idea of what I am doing, I created this mini-toy problem:

using Optim, Zygote

μ = 2
υ = rand()

#Objective Function
f(x::Vector{Float64}) = x[1]^2 + x[2]^2 

#Constraint
g(x::Vector{Float64}) = x[1]   + x[2] - 1

#Augmented Lagrangian 
L(x::Vector{Float64}) = f(x) + υ*g(x) + μ*g(x)^2

#This is where I write the reverse diff gradient
function g!(G, x)
    G .= gradient(L, x)[1]
end

xₖ    = rand(2)
VIOLₖ = 10e12
k     = 0
ϵ     = 10e-4

cond = false

while cond == false
    #Optimization
    @label innerloop
    println("---------------------------------------------------------------------------------------------------")
    
    if k > 15 @goto terminate end
    println("Iteration: $k")
    
    opt = optimize(L, g!, xₖ, ConjugateGradient(), Optim.Options(show_trace=true, g_tol=10e-3, show_every=10))
    
    #Storing solution
    xₖ = opt.minimizer
    x  = xₖ

    #Checking violations
    VIOLₖ₋₁ = copy(VIOLₖ)
    VIOLₖ   = abs(g(x))

    if VIOLₖ < ϵ
        println("*Solution found: x=$x")
        @goto terminate
    elseif VIOLₖ < 0.9*VIOLₖ₋₁
        println("*VIOL condition validated...")
        @goto outerloop
    else 
        μ = 10*μ                
                      
        println("*Updating penalty term...")
        print("Penalty terms: μ=$μ")
        
        k += 1
        @goto innerloop
    end

    #Update Lagrange multipliers
    @label outerloop
    println("*Accessing outerloop...")
    υ = υ + 2*μ*g(x)
        
    k += 1
    @goto innerloop

    @label terminate
    cond = true
end
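
For reference, here is the toy problem written out as equations. This is only my own summary of the code above, with the same symbols (μ is the penalty weight, υ the multiplier estimate):

```latex
\begin{aligned}
f(x) &= x_1^2 + x_2^2, \qquad g(x) = x_1 + x_2 - 1, \\
L(x) &= f(x) + \upsilon\, g(x) + \mu\, g(x)^2, \\
\nabla L(x) &= \nabla f(x) + \bigl(\upsilon + 2\mu\, g(x)\bigr)\, \nabla g(x), \\
\upsilon &\leftarrow \upsilon + 2\mu\, g(x_k) \quad \text{(outer-loop multiplier update)}.
\end{aligned}
```

The factor multiplying ∇g in the gradient is what the outer loop uses as the new multiplier, which is why the update carries 2μ rather than μ.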

In the actual problem, I have more constraints (like g above), and they involve some integration (for which I am using FastGaussQuadrature, just out of curiosity).

The way I write g! above is exactly how I write it in the actual problem.

I am especially confused because the rule of thumb says that reverse-mode differentiation is better suited to problems with many inputs and few outputs, which is precisely my case.

Any comments on why I am getting unworkably slow computation times with Zygote are very welcome.

If you feel you are an expert on the matter and could help me further by looking into my full code (including a review of other potential performance issues), please don’t hesitate to send me a message to discuss a potential (compensated) consultation. I apologize if this goes against any community guidelines.

Thank you in advance.

With this code it’s not obvious why ForwardDiff would work better (actually it would probably fail due to strict Float64 typing).
But the fact that you use numerical integration might be a clue. Reverse-mode autodiff in general comes with a high memory cost as soon as iterative procedures are involved.
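
To make the typing remark concrete, here is a minimal sketch. The names `f_relaxed`, `g_relaxed`, `L_relaxed`, `L_strict`, `μ0` and `υ0` are made up for illustration, and the fixed values of `μ0` and `υ0` are arbitrary. As far as I know, Optim's `autodiff = :forward` goes through ForwardDiff, which calls the objective with a vector of dual numbers, so a signature restricted to `Vector{Float64}` has no matching method:

```julia
using ForwardDiff, Zygote

# Illustrative constants (assumed values, not the ones from the post)
const μ0 = 2.0
const υ0 = 0.5

# Same toy problem, but the argument type is relaxed so that
# ForwardDiff's Dual numbers are accepted.
f_relaxed(x::AbstractVector{<:Real}) = x[1]^2 + x[2]^2
g_relaxed(x::AbstractVector{<:Real}) = x[1] + x[2] - 1
L_relaxed(x::AbstractVector{<:Real}) = f_relaxed(x) + υ0 * g_relaxed(x) + μ0 * g_relaxed(x)^2

x0 = [0.3, 0.4]

grad_forward = ForwardDiff.gradient(L_relaxed, x0)   # forward mode
grad_reverse = Zygote.gradient(L_relaxed, x0)[1]     # reverse mode

# Strictly typed variant, as in the original post
L_strict(x::Vector{Float64}) = x[1]^2 + x[2]^2

# ForwardDiff passes a Vector of Dual numbers, so dispatch fails
strict_fails = try
    ForwardDiff.gradient(L_strict, x0)
    false
catch err
    err isa MethodError
end
```

Both modes should agree on the relaxed version, so any timing difference in the real problem comes from how each mode handles the rest of the code rather than from the result.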


Have you tried Integrals.jl? It defines algorithms that are more compatible with Zygote.jl thanks to custom “chain rules”. I can tell you more if you’re interested, but I think it would already solve your problem.
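
Here is a rough sketch of what I mean, assuming the Integrals.jl v4 interface (`IntegralProblem(f, domain, p)` with an integrand `f(x, p)`). The integrand and the constraint are invented for illustration and are not taken from your problem:

```julia
using Integrals, Zygote

# Hypothetical integrand: x is the integration variable, p the parameters
integrand(x, p) = p[1] * x^2 + p[2] * x

# Hypothetical constraint: the integral over [0, 1] minus 1,
# which equals p[1]/3 + p[2]/2 - 1 analytically
function constraint(p)
    prob = IntegralProblem(integrand, (0.0, 1.0), p)
    sol = solve(prob, QuadGKJL(); reltol = 1e-8, abstol = 1e-8)
    return sol.u - 1
end

p0 = [1.0, 2.0]

# Reverse-mode gradient with respect to the parameters
dconstraint = Zygote.gradient(constraint, p0)[1]
```

The idea is that Zygote does not trace through the quadrature loop itself. The package supplies a rule for the derivative of the integral with respect to `p`, which should avoid storing every intermediate step of the integration.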

Thank you, I suspected the iterative aspect had something to do with it.

I will definitely look into Integrals.jl and get back to you! Thank you again.
