FluxOptTools.jl
This package contains some utilities to enhance training of Flux.jl models.
Train using Optim
Optim.jl can be used to train Flux models (provided Flux is on the branch sf/zygote_updated). Here is an example:
using Flux, Zygote, Optim, FluxOptTools, Statistics
m = Chain(Dense(1,3,tanh) , Dense(3,1))
x = LinRange(-pi,pi,100)'
y = sin.(x)
loss() = mean(abs2, m(x) .- y)
Zygote.refresh()
pars = Flux.params(m)
lossfun, gradfun, fg!, p0 = optfuns(loss, pars)
res = Optim.optimize(Optim.only_fg!(fg!), p0, Optim.Options(iterations=1000, store_trace=true))
The utility provided by this package is the function optfuns, which returns three functions (lossfun, gradfun and fg!) and p0, a vectorized version of pars. L-BFGS typically has better convergence properties than, e.g., the ADAM optimizer. Here’s a benchmark where L-BFGS (red) beats ADAM with a tuned step size (blue).
The code for this benchmark is in runtests.jl.
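To make the role of fg! more concrete, here is a minimal sketch of the calling convention that Optim.only_fg! expects: a single function that writes the gradient into G when one is requested and returns the objective value when F is requested. The function name fg_sketch! and the quadratic objective are hypothetical stand-ins for a model loss; this is not the code that optfuns generates.

```julia
# Hypothetical objective standing in for a model loss: f(w) = sum((w .- 1).^2)
function fg_sketch!(F, G, w)
    if G !== nothing
        # The caller asked for the gradient: write it in place.
        G .= 2 .* (w .- 1)
    end
    if F !== nothing
        # The caller asked for the objective value: return it.
        return sum(abs2, w .- 1)
    end
    return nothing
end

w0  = zeros(3)
G   = similar(w0)
val = fg_sketch!(true, G, w0)   # request both the value and the gradient
```

A function of this shape could be passed as Optim.optimize(Optim.only_fg!(fg_sketch!), w0, ...) in the same way as the fg! returned by optfuns above.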
Visualize loss landscape
We define a plot recipe such that a loss landscape can be plotted with
using Plots
plot(loss, pars, l=0.1, npoints=50, seriestype=:contour)
The landscape is plotted by selecting two random directions and extending the current point (pars) a distance l*norm(pars) (both negative and positive) along the two random directions. The number of loss evaluations will be npoints^2.
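The procedure described above can be sketched in plain Julia as follows. The helper landscape_sketch is hypothetical and not part of the package, and it assumes that the random directions are normalized to unit length and that the loss takes a flat parameter vector; the actual plot recipe may differ in these details.

```julia
using LinearAlgebra

# Hypothetical helper: evaluate a loss on a 2-D slice through the point p.
function landscape_sketch(lossvec, p::AbstractVector; l=0.1, npoints=50)
    # Two random directions, normalized to unit length (an assumption).
    d1 = normalize(randn(length(p)))
    d2 = normalize(randn(length(p)))
    r  = l * norm(p)                       # extent along each direction
    ts = range(-r, r; length=npoints)      # negative and positive offsets
    # One loss evaluation per grid point: npoints^2 in total.
    Z = [lossvec(p .+ a .* d1 .+ b .* d2) for a in ts, b in ts]
    return ts, Z
end

p = [1.0, -2.0, 0.5]
ts, Z = landscape_sketch(w -> sum(abs2, w), p; l=0.1, npoints=5)
```

The matrix Z can then be handed to any contour or surface plotting routine, with ts as both axes.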
Flatten and Unflatten
What this package really does is flatten and reassemble the types Flux.Params and Zygote.Grads to and from vectors. These functions are used like so:
p = zeros(pars) # Creates a vector of length sum(length, pars)
copyto!(p,pars) # Store pars in vector p
copyto!(pars,p) # Reverse
g = zeros(grads) # Creates a vector of length sum(length, grads)
copyto!(g,grads) # Store grads in vector g
copyto!(grads,g) # Reverse
This is what is used under the hood in the functions returned from optfuns in order to have everything in a form that Optim understands.
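As a rough illustration of the flattening idea, the sketch below does the same round trip for a plain tuple of arrays instead of Flux.Params. The names flat_length, flatten! and unflatten! are hypothetical and exist only for this example; the package's own interface is the zeros and copyto! calls shown above.

```julia
# Hypothetical helpers, for illustration only.
flat_length(arrays) = sum(length, arrays)

# Copy every array, in order, into consecutive slots of the vector v.
function flatten!(v::AbstractVector, arrays)
    i = 1
    for a in arrays
        n = length(a)
        v[i:i+n-1] .= vec(a)
        i += n
    end
    return v
end

# Reverse operation: write consecutive slices of v back into each array.
function unflatten!(arrays, v::AbstractVector)
    i = 1
    for a in arrays
        n = length(a)
        a .= reshape(view(v, i:i+n-1), size(a))
        i += n
    end
    return arrays
end

W  = [1.0 2.0; 3.0 4.0]
b  = [5.0, 6.0]
ps = (W, b)
v  = flatten!(zeros(flat_length(ps)), ps)  # matrices are stored column-major
v .*= 2                                    # e.g. an optimizer updates the flat vector
unflatten!(ps, v)                          # the arrays now hold the updated values
```

An optimizer that only understands flat vectors can work on v, while the model keeps its parameters in their original shapes.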