Alternative to FluxOptTools?

Dear all,
I have an optimization problem in which some of the functions are approximated by NNs, so I am using Flux to represent them. The package FluxOptTools has allowed me to use Optim to optimize my “loss” function L(), which depends on these NNs, NNx and NNu. I created an MWE below, which is a reduced toy version of my real implementation. L() is an Augmented Lagrangian; the full code has the proper loop to update the multipliers and penalty terms, but for brevity I removed them from this MWE.

My question is: is there any alternative to FluxOptTools for doing what I am doing, one that would give me access to algorithms Optim.jl doesn’t provide?

My domain consists of time, which is integrated out in f() and in the constraints h(), and x₀, which is iterated over in J(). Because of this structure, I could not get Flux’s train! function to work. I have also looked into Optimization.jl and Surrogates.jl but couldn’t figure out an alternative approach. I want to try optimization algorithms that Optim.jl does not provide, such as Adam().

I have spent a couple of weeks trying to figure this out without success, so I am asking for your help.

Here is the MWE (it runs), so you can get a better idea of what I am doing:

Feel free to criticize anything you might find odd. Thank you!

```julia
using Optim, LinearAlgebra, FastGaussQuadrature, QuadGK, Zygote, Flux, FluxOptTools

const T  = 5
const X₀ = 5
const N  = 2
const M  = 4
const r  = 4
icrange = collect(0:0.5:X₀)

x₂⁰ = 1.0 # initial velocity
xT¹, xT² = 3.0, 0.0 # final position and velocity
xT = [xT¹, xT²]

NNumodel = Chain(
    Dense(2, 5, elu),
    Dense(5, 1)
) |> f64

NNu(t, x₀) = NNumodel(vcat(t, x₀))

NNxmodel = Chain(
    Dense(2, 5, elu),
    Dense(5, 2)
) |> f64

NNx(t, x₀) = NNxmodel(vcat(t, x₀))

# Central finite difference of NNx with respect to t (returns an N-vector)
function ẋ(t, x₀)
    Δt = 1e-5  # finite-difference step
    return (NNx(t + Δt, x₀) - NNx(t - Δt, x₀)) ./ (2Δt)
end

# Objective functional
J(t) = r*sum(NNu(t, x₀)[1]^2 for x₀ in icrange)

# Quadrature transform functions: map nodes/weights from [-1, 1] to [0, T]
t(x) = 1/2*T*x + 1/2*T
W(w) = T/2*w

# Integration nodes
const pts = 60
x, w = gausslobatto(pts)
xₜ = t.(x)
wₜ = W.(w)

#--- Objective functional integral ---#
f() = dot(wₜ, J.(xₜ))

#--- True dynamics ---#
g(t, x₀) = [0 1; 0 0]*NNx(t, x₀) + [0, 1]*NNu(t, x₀)[1]

#--- Equality constraints ---#
h₁() = sum(dot(wₜ, 0.5*norm.(ẋ.(xₜ, Ref(x₀)) - g.(xₜ, Ref(x₀))).^2) for x₀ in icrange)
h₂() = sum(1/2*norm(NNx(T, x₀) - xT)^2        for x₀ in icrange)
h₃() = sum(1/2*norm(NNx(0, x₀) - [x₀, x₂⁰])^2 for x₀ in icrange)

function L()  # "Loss": Augmented Lagrangian
    return f() + υ₁*h₁() + υ₂*h₂() + υ₃*h₃() + μ₁*(h₁())^2 + μ₂*(h₂())^2 + μ₃*(h₃())^2
end

μ₁ = μ₂ = μ₃ = 10.0
υ₁ = υ₂ = υ₃ = 50*rand()

Zygote.refresh()
θ = Flux.params(NNumodel, NNxmodel)
Lfun, gradfun, fg!, p0 = optfuns(L, θ)

opt = optimize(Optim.only_fg!(fg!), p0, LBFGS(),
               Optim.Options(iterations=5_000, g_tol=1e-2, store_trace=true,
                             show_every=5, show_trace=true))
```
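
For reference, this is what the MWE computes, read off the code above (with $\hat x$ standing for `NNx`, $u$ for the scalar output of `NNu`, $\mathcal I$ for `icrange` $= \{0, 0.5, \dots, X_0\}$, and $x_k, w_k$ for the 60 Gauss–Lobatto nodes and weights on $[-1, 1]$):

$$
\begin{aligned}
t_k &= \tfrac{T}{2}\,x_k + \tfrac{T}{2}, \qquad \tilde w_k = \tfrac{T}{2}\,w_k, \qquad \int_0^T \varphi(t)\,dt \approx \sum_{k=1}^{60} \tilde w_k\,\varphi(t_k),\\
f &= \sum_{k} \tilde w_k\, J(t_k), \qquad J(t) = r \sum_{x_0 \in \mathcal I} u(t,x_0)^2,\\
g(t,x_0) &= \begin{bmatrix}0 & 1\\ 0 & 0\end{bmatrix}\hat x(t,x_0) + \begin{bmatrix}0\\ 1\end{bmatrix} u(t,x_0), \qquad
\dot{\hat x}(t,x_0) \approx \frac{\hat x(t+\Delta t,x_0)-\hat x(t-\Delta t,x_0)}{2\,\Delta t},\\
h_1 &= \sum_{x_0\in\mathcal I}\sum_{k} \tilde w_k\,\tfrac12\big\lVert \dot{\hat x}(t_k,x_0) - g(t_k,x_0)\big\rVert^2,\\
h_2 &= \sum_{x_0\in\mathcal I} \tfrac12 \big\lVert \hat x(T,x_0) - x_T\big\rVert^2, \qquad
h_3 = \sum_{x_0\in\mathcal I} \tfrac12 \big\lVert \hat x(0,x_0) - (x_0,\,x_2^0)^\top\big\rVert^2,\\
L &= f + \sum_{i=1}^{3}\big(\upsilon_i\,h_i + \mu_i\,h_i^2\big).
\end{aligned}
$$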

The internals of FluxOptTools are very simple; you could easily modify them to provide an interface for any other optimizer. The whole Optim interface is only 25 LOC.
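
As a rough illustration: `optfuns` already returns an Optim-style `fg!` and a flat starting vector `p0`, so any optimizer that only needs a loss and a gradient can be driven by them. Below is a minimal sketch of a hand-written Adam loop over a plain vector. `adam!` is a hypothetical helper (not part of FluxOptTools or Optim), and it assumes `fg!(F, G, x)` follows the `Optim.only_fg!` convention: it writes the gradient into `G` when `G !== nothing` and returns the loss when `F !== nothing`. The toy quadratic at the end only checks the loop.

```julia
# Hypothetical helper: a plain Adam loop driven by an Optim-style fg!(F, G, x).
# Assumption: fg! writes the gradient into G when G !== nothing and returns the
# loss when F !== nothing (the Optim.only_fg! convention).
function adam!(fg!, x::AbstractVector{<:AbstractFloat}; η = 1e-3, β₁ = 0.9,
               β₂ = 0.999, ϵ = 1e-8, iterations = 1_000)
    m = zero(x)        # running mean of gradients (first moment)
    v = zero(x)        # running mean of squared gradients (second moment)
    G = similar(x)     # gradient buffer filled by fg!
    loss = NaN
    for k in 1:iterations
        loss = fg!(0.0, G, x)               # loss at the current x; gradient into G
        @. m = β₁ * m + (1 - β₁) * G
        @. v = β₂ * v + (1 - β₂) * G^2
        b₁, b₂ = 1 - β₁^k, 1 - β₂^k         # bias-correction factors
        @. x -= η * (m / b₁) / (sqrt(v / b₂) + ϵ)
    end
    return x, loss   # loss is from the last evaluation, before the final step
end

# Toy check on f(x) = ½‖x - c‖², whose gradient is x - c.
c = [1.0, -2.0, 3.0]
function quad_fg!(F, G, x)
    G === nothing || (G .= x .- c)
    return F === nothing ? nothing : 0.5 * sum(abs2, x .- c)
end

xopt, lossopt = adam!(quad_fg!, zeros(3); η = 0.05, iterations = 2_000)
```

With the MWE, the (untested) idea would be something like `adam!(fg!, copy(p0); η = 1e-3, iterations = 5_000)`; the step size and iteration count there are placeholders, not tuned values.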


It is easy to write the functions f, g! and fg! by hand; the only painful part is that you need to convert the structured representation of Flux’s models into a vector. You can do this “easily” using Flux’s destructure and restructure (see the Optimisers.jl API documentation), but it might fail (give wrong results) in some corner cases. You might therefore be better off going with Lux.jl, which was designed with this application in mind (see “All about Lux” in the Lux.jl documentation).
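
A minimal sketch of writing `fg!` by hand with `destructure`, assuming a recent Flux (one that accepts the `Dense(in => out, σ)` syntax and provides `Flux.destructure`) together with Zygote and Optim. The toy `loss` is a hypothetical stand-in for the MWE’s `L()`; the key change is that it takes the models as an argument instead of reading the globals `NNumodel`/`NNxmodel`, so that `re(θ)` can rebuild them from the flat vector `θ`. The names `loss`, `ts`, `NNu_opt` and `NNx_opt` are illustrative only, and the corner cases mentioned above still apply.

```julia
using Flux, Zygote, Optim

# Same shapes as the MWE's networks.
NNumodel = Chain(Dense(2 => 5, elu), Dense(5 => 1)) |> f64
NNxmodel = Chain(Dense(2 => 5, elu), Dense(5 => 2)) |> f64

# Flatten both models into one vector; re(θ) rebuilds the (NNu, NNx) tuple.
p0, re = Flux.destructure((NNumodel, NNxmodel))

# Hypothetical toy loss that takes the models explicitly (stand-in for L()).
const ts = collect(range(0.0, 1.0; length = 20))
function loss(models)
    mu, mx = models
    s = 0.0
    for t in ts
        # fit u(t, 0.5) ≈ sin(t) and push x(t, 0.5) toward zero
        s += (mu([t, 0.5])[1] - sin(t))^2 + sum(abs2, mx([t, 0.5]))
    end
    return s
end

# Optim-style combined value/gradient function on the flat vector θ.
function fg!(F, G, θ)
    if G !== nothing
        out = Zygote.withgradient(p -> loss(re(p)), θ)
        G .= out.grad[1]
        return out.val
    end
    return F === nothing ? nothing : loss(re(θ))
end

res = optimize(Optim.only_fg!(fg!), p0, LBFGS(), Optim.Options(iterations = 200))
NNu_opt, NNx_opt = re(Optim.minimizer(res))   # trained models rebuilt from the vector
```

Since `fg!` only sees a plain vector, the same function could equally be handed to a hand-written loop or to another vector-based optimizer.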

I hope this sends you in the right direction.
