I am trying to use an optimiser from Manopt.jl to train a Lux.jl neural network using Reactant.jl. Training is orders of magnitude slower than the same code without Reactant.jl. Does anyone know how to fix this?
The compiled functions are very fast when benchmarked on their own, but when I use them inside the optimiser they are constantly recompiled. I gather this might be because the input types of the functions are changing, but surely an optimiser would call them with consistent input types, or at least a finite number of them, so compilation should only be needed a couple of times at most.
Here is the code:
```julia
using Reactant
using Zygote, Enzyme, Lux
using Random
using ComponentArrays
using ADTypes
using BenchmarkTools
using Manifolds
using Manopt
import ManifoldsBase.inner
import ManifoldsBase.norm
import Base.-

# Overrides of the Euclidean inner product and norm using plain reductions
inner(M::Euclidean, p, X, Y) = sum(X .* Y)
norm(M::Euclidean, p, X) = (sum(X .* X) ^ (1/2))
# Workaround: unary minus for Reactant Float64 scalars
-(x::ConcretePJRTNumber{Float64, 1}) = ConcretePJRTNumber(0 - x)

Random.seed!(32)
Reactant.set_default_backend("cpu")
xdev = reactant_device()

layer_dimension = 8
model = Chain(
    Dense(1, layer_dimension, swish),
    Chain(
        SkipConnection(Dense(layer_dimension, layer_dimension, swish), +),
        Chain(
            SkipConnection(Dense(layer_dimension, layer_dimension, swish), +),
            Dense(layer_dimension, 1),
        ),
    ),
)
ps, st = Lux.setup(Random.default_rng(), model);
# make it into Float64
ps = ps |> ComponentArray .|> Float64
ps_ra = ps |> xdev
st_ra = st |> xdev

# Training and test data: y = (x + 1)^2 - 1
xx = randn(1, 100)
xx_test = randn(1, 100)
yy = (xx .+ 1.0) .^ 2.0 .- 1.0
yy_test = (xx_test .+ 1.0) .^ 2.0 .- 1.0
xx_ra = xx |> xdev
yy_ra = yy |> xdev
xx_test_ra = xx_test |> xdev
yy_test_ra = yy_test |> xdev

model_compiled = @compile model(xx_ra, ps_ra, Lux.testmode(st_ra))

# Least-squares loss: ||model(xx) - yy||^2 / 2
function Loss(params, state, model, xx, yy)
    return sum(abs2, first(model(xx, params, state)) .- yy) / 2 # + 0.01 * sum(abs2, params) / 2
end

# Gradient of the loss w.r.t. the parameters: J' * residual, as a vector-Jacobian product
function Gradient(params, state, model, xx, yy)
    residual = first(model(xx, params, state)) - yy
    gr1 = vector_jacobian_product(z -> first(model(xx, z, state)), AutoEnzyme(), params, residual)
    return gr1 #+ 0.01 .* params
end

# Jacobian of the model output w.r.t. its inputs (not used by the optimiser below)
function Jacobian(params, state, model, xx)
    return batched_jacobian(z -> first(model(z, params, state)), AutoEnzyme(), xx)
end

Loss_compiled = @compile Loss(ps_ra, Lux.testmode(st_ra), model, xx_ra, yy_ra)
Gradient_compiled = @compile Gradient(ps_ra, Lux.testmode(st_ra), model, xx_ra, yy_ra)
Jacobian_compiled = @compile Jacobian(ps_ra, Lux.testmode(st_ra), model, xx_ra)

M = Euclidean(length(ps_ra))
res = trust_regions(
    M,
    (M, p) -> Loss_compiled(p, Lux.testmode(st_ra), model, xx_ra, yy_ra),
    (M, p) -> Gradient_compiled(p, Lux.testmode(st_ra), model, xx_ra, yy_ra),
    ps_ra;
    stopping_criterion = StopAfterIteration(1000) | StopWhenGradientNormLess(1e-6),
    debug = [:Iteration, :Cost, :GradientNorm, " | ", 1, "\n", :Stop],
)
```
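One way to test the hypothesis that the argument types change between calls is to wrap the cost and gradient closures passed to `trust_regions` in a small logger that reports every new combination of argument types it sees. The sketch below assumes a hypothetical helper, `TypeLogger`, which is not part of Manopt.jl, Reactant.jl or Lux.jl; it is plain Julia and only forwards the call after recording the types.

```julia
# Hypothetical helper (not part of Manopt.jl or Reactant.jl): wraps a function
# and records each distinct tuple of argument types it is called with.
struct TypeLogger{F}
    f::F
    seen::Set{Any}
end
TypeLogger(f) = TypeLogger(f, Set{Any}())

function (t::TypeLogger)(args...)
    sig = map(typeof, args)          # tuple of the argument types of this call
    if !(sig in t.seen)
        push!(t.seen, sig)
        @info "new argument types" sig
    end
    return t.f(args...)              # forward to the wrapped function unchanged
end

# Example: the second call reuses the same types, the third introduces new ones.
f = TypeLogger((M, p) -> sum(abs2, p))
f(nothing, [1.0, 2.0])
f(nothing, [3.0, 4.0])
f(nothing, Float32[1, 2])
length(f.seen)  # 2
```

Wrapping both closures this way and inspecting `seen` after a few iterations should show whether the point `p` keeps the type of `ps_ra` throughout, or whether the optimiser hands back something else (for example a plain `Array` or a differently wrapped `ComponentArray`), which might account for the repeated compilation.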