Dear all,
I have an optimization problem in which some of the functions are approximated by neural networks, so I am using Flux to represent them. The package FluxOptTools has allowed me to use Optim to optimize my “loss” function L(), which depends on these networks, NNx and NNu. I created a MWE below; it is a reduced, toy version of my real implementation. L() is an Augmented Lagrangian, and the full code has the proper outer loop that updates the multipliers and penalty terms, but for brevity I removed it from this MWE (I sketch roughly what that loop looks like at the very end of this post).
My question is: is there any alternative to FluxOptTools for doing what I am doing, so that I have access to algorithms Optim.jl doesn’t provide?
My domain is time, which is integrated out in f() and in the constraints h(), and the initial condition x₀, which is iterated over in J(); there is no conventional dataset to loop over, which is why I failed when trying to use Flux’s train! function. I have also looked into Optimization.jl and Surrogates.jl but couldn’t figure out an alternative approach. I want to try optimization algorithms Optim.jl does not provide, like Adam().
I have spent a couple of weeks trying to figure this out without success, so I am asking for your help.
Here is the MWE, which runs, so you can get a better idea of what I am doing.
Feel free to criticize anything you might find odd. Thank you!
using Optim, LinearAlgebra, FastGaussQuadrature, QuadGK, Zygote, Flux, FluxOptTools
#--- Problem data ---#
const T = 5               # final time of the horizon [0, T]
const X₀ = 5              # largest initial position in the training grid
const N = 2               # state dimension
const M = 4
const r = 4               # control weight in the objective
icrange = collect(0:0.5:X₀)   # grid of initial positions x₀
x₂⁰ = 1.0                     # initial velocity
xT¹, xT² = 3.0, 0.0           # target final position and velocity
xT = [xT¹, xT²]
# Control network: (t, x₀) ↦ u
NNumodel = Chain(
    Dense(2, 5, elu),
    Dense(5, 1)
) |> f64
NNu(t, x₀) = NNumodel(vcat(t, x₀))

# State network: (t, x₀) ↦ x ∈ ℝ²
NNxmodel = Chain(
    Dense(2, 5, elu),
    Dense(5, 2)
) |> f64
NNx(t, x₀) = NNxmodel(vcat(t, x₀))
# Time derivative of the NN state via central finite differences
function ẋ(t, x₀)
    Δt = 10e-6
    return [ (NNx(t + Δt, x₀)[i] - NNx(t - Δt, x₀)[i])/(2*Δt) for i in 1:N ]
end
# Objective functional: running control cost r·u², summed over initial conditions
J(t) = r*sum(NNu(t, x₀)[1]^2 for x₀ in icrange)
# Quadrature transform: map Gauss–Lobatto nodes and weights from [-1, 1] to [0, T]
t(x) = 1/2*T*x + 1/2*T
W(w) = T/2*w
#Integration nodes
const pts = 60
x, w = gausslobatto(pts)
xₜ = t.(x)
wₜ = W.(w)
#--- Objective functional integral: f() ≈ ∫₀ᵀ J(t) dt via the quadrature above ---#
f() = dot(wₜ, J.(xₜ))
#--- True dynamics (double integrator: ẋ₁ = x₂, ẋ₂ = u) ---#
g(t, x₀) = [0 1; 0 0]*NNx(t, x₀) + [0, 1]*NNu(t, x₀)[1]
#--- Equality constraints ---#
# h₁: dynamics residual ẋ − g, integrated over time and summed over initial conditions
h₁() = sum(dot(wₜ, 0.5*norm.(ẋ.(xₜ, Ref(x₀)) - g.(xₜ, Ref(x₀))).^2) for x₀ in icrange)
# h₂: terminal condition x(T) = xT
h₂() = sum(1/2*norm(NNx(T, x₀) - xT)^2 for x₀ in icrange)
# h₃: initial condition x(0) = (x₀, x₂⁰)
h₃() = sum(1/2*norm(NNx(0, x₀) - [x₀, x₂⁰])^2 for x₀ in icrange)
function L() # "Loss" -- Augmented Lagrangian
    return f() + υ₁*h₁() + υ₂*h₂() + υ₃*h₃() + μ₁*(h₁())^2 + μ₂*(h₂())^2 + μ₃*(h₃())^2
end
μ₁ = μ₂ = μ₃ = 10.0        # penalty weights
υ₁ = υ₂ = υ₃ = 50*rand()   # multiplier estimates (updated in the outer loop of the full code)
Zygote.refresh()
# Collect the trainable parameters of both networks and build Optim-compatible
# closures over a single flat parameter vector p0.
θ = Flux.params(NNumodel, NNxmodel)
Lfun, gradfun, fg!, p0 = optfuns(L, θ)
opt = optimize(Optim.only_fg!(fg!), p0, LBFGS(), Optim.Options(iterations=5_000, g_tol=10e-3, store_trace=true, show_every=5, show_trace=true))
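For completeness, the outer loop I removed from the MWE does the standard augmented-Lagrangian updates; schematically it wraps the call above roughly like this (simplified):

# Simplified sketch of the removed outer loop: re-solve the inner problem, then
# update the multipliers υ from the current constraint values and grow the
# penalty weights μ.
for k in 1:10
    global υ₁, υ₂, υ₃, μ₁, μ₂, μ₃
    _, _, fg!, p0 = optfuns(L, θ)
    optimize(Optim.only_fg!(fg!), p0, LBFGS(), Optim.Options(iterations=1_000))
    υ₁ += 2μ₁*h₁(); υ₂ += 2μ₂*h₂(); υ₃ += 2μ₃*h₃()   # first-order multiplier update
    μ₁ *= 2; μ₂ *= 2; μ₃ *= 2                        # increase the penalties
end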