I am trying to implement a custom Lux model, but I keep running out of memory even with a small number (~500) of parameters. I have two kinds of layers (sympnet layers), each of which takes in and returns two vectors; each layer updates only one of the two:
\mathbf{q}_{\mathrm{new}} = \mathbf{q} + \Delta t\,\bigl(\mathbf{K}^\top \sigma(\mathbf{K}\mathbf{p} + \mathbf{b})\bigr) \odot \mathbf{a}, \qquad \mathbf{p}_{\mathrm{new}} = \mathbf{p} \qquad \text{(upper layer)}
\mathbf{q}_{\mathrm{new}} = \mathbf{q}, \qquad \mathbf{p}_{\mathrm{new}} = \mathbf{p} - \Delta t\,\bigl(\mathbf{K}^\top \sigma(\mathbf{K}\mathbf{q} + \mathbf{b})\bigr) \odot \mathbf{a} \qquad \text{(lower layer)}
where \Delta t is fixed, \sigma is an activation function, \odot is the elementwise product, \mathbf{a} and \mathbf{b} are trainable vectors, and \mathbf{K} is a trainable matrix. Here’s my implementation in Lux.jl:
using Lux, MLUtils
using Random
import Zygote
import Optimisers

struct GLayer{I,T,type} <: LuxCore.AbstractLuxLayer
    dim::I
    h::T
    act
end

const GLowerLayer{I,T} = GLayer{I,T,:lower}
const GUpperLayer{I,T} = GLayer{I,T,:upper}

GLowerLayer(dim::I, h::T, act) where {I,T} = GLowerLayer{I,T}(dim, h, act)
GUpperLayer(dim::I, h::T, act) where {I,T} = GUpperLayer{I,T}(dim, h, act)

function Lux.initialparameters(rng::AbstractRNG, layer::GLayer)
    a = randn(rng, layer.dim)
    K = randn(rng, layer.dim, layer.dim)
    b = randn(rng, layer.dim)
    return (; K=K, a=a, b=b)
end
# Upper layer: updates q, leaves p unchanged
function (layer::GUpperLayer)((q, p), ps, st::NamedTuple)
    K = ps.K
    a = ps.a
    b = ps.b
    act = layer.act
    h = layer.h
    q_new = q + h * K' * act.(K * p .+ b) .* a
    return (q_new, p), st
end

# Lower layer: updates p, leaves q unchanged
function (layer::GLowerLayer)((q, p), ps, st::NamedTuple)
    K = ps.K
    a = ps.a
    b = ps.b
    act = layer.act
    h = layer.h
    p_new = p - h * K' * act.(K * q .+ b) .* a
    return (q, p_new), st
end
I borrowed the loss and training functions from Lux.jl’s documentation:
function loss_fn(model, ps, st, (q, p, q_next, p_next))
    (q_pred, p_pred), new_st = model((q, p), ps, st)
    loss = MSELoss()(q_pred, q_next) + MSELoss()(p_pred, p_next)
    return loss, new_st, nothing
end
function train_model((q_training, p_training),
                     dt;
                     seed::Int=42,
                     N_EPOCHS::Int=2000,
                     hidden_dims::Int=4)
    # Training data
    q = reshape(q_training[1:end-1], 1, :)
    p = reshape(p_training[1:end-1], 1, :)
    q_next = reshape(q_training[2:end], 1, :)
    p_next = reshape(p_training[2:end], 1, :)

    # Define model
    lowerlayer = GLowerLayer(size(q, 1), dt, tanh)
    upperlayer = GUpperLayer(size(q, 1), dt, tanh)
    fivelayercombo = [lowerlayer, upperlayer, lowerlayer, upperlayer, lowerlayer]
    model = Chain(repeat(fivelayercombo, hidden_dims)...)

    # Initialize training state
    rng = Xoshiro(seed)
    ps, st = Lux.setup(rng, model)
    train_state = Training.TrainState(model, ps, st, Optimisers.Adam(0.005))
    dataloader = DataLoader((q, p, q_next, p_next); batchsize=128, shuffle=true, partial=false)

    epoch = 0
    for (q_batch, p_batch, q_next_batch, p_next_batch) in Iterators.cycle(dataloader)
        learning_rate = 0.001
        Optimisers.adjust!(train_state, learning_rate)
        _, loss, stats, train_state = Training.single_train_step!(
            AutoZygote(),
            loss_fn,
            (q_batch, p_batch, q_next_batch, p_next_batch),
            train_state;
            return_gradients=Val(false),
        )
        epoch >= N_EPOCHS && break
        epoch += 1
        if epoch % 100 == 0
            @info "Epoch: $epoch / $N_EPOCHS"
        end
    end

    smodel = StatefulLuxLayer(train_state.model, train_state.parameters, train_state.states)
    return smodel
end
dt = 0.01
p_training = cos.(0:dt:2pi)
q_training = sin.(0:dt:2pi)
@time trained_model = train_model((q_training, p_training), dt, N_EPOCHS=1000, hidden_dims=5);
This works fine for a small number of hidden layers, but it scales badly as that number grows:
hidden_dims=5: 0.500541 seconds (3.96 M allocations: 717.689 MiB, 22.07% gc time)
hidden_dims=10: 1.879723 seconds (23.93 M allocations: 2.729 GiB, 15.48% gc time)
hidden_dims=20: 133.171544 seconds (133.74 M allocations: 11.729 GiB, 0.96% gc time, 95.70% compilation time)
hidden_dims=30: 675.982679 seconds (318.60 M allocations: 26.910 GiB, 0.38% gc time, 98.61% compilation time)
For higher numbers of hidden layers, the program crashes. My guess is that the line p_new = p - h * K' * act.(K * q .+ b) .* a is allocation-heavy, but I am not sure whether this is really the issue.
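If that line is the culprit, the more fused version I would try is something like this (just a sketch; I have not benchmarked whether it actually cuts allocations):
# Sketch: same update as above, with the scalar and elementwise parts fused into one broadcast
function (layer::GLowerLayer)((q, p), ps, st::NamedTuple)
    p_new = p .- (layer.h .* ps.a) .* (ps.K' * layer.act.(ps.K * q .+ ps.b))
    return (q, p_new), st
end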
I have tried AutoEnzyme instead of AutoZygote, but it errored out:
ERROR: LoadError: EnzymeRuntimeActivityError: Detected potential need for runtime activity.
Constant memory is stored (or returned) to a differentiable variable and correctness cannot be guaranteed with static activity analysis.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#faq-runtime-activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
a) rewrite this variable to not be conditionally active (fastest performance, slower to setup), or
b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Error cannot store inactive but differentiable variable [-0.6877661591839738 -0.7125916284799615 -0.680472569108694 -0.9161659367494549 0.03998933418663416 -0.15258690864856114 -0.726479760593413 -0.7663790266757843 -0.5878571033784827 0.02999550020249566 -0.9824526126243325 -0.9499842019607608 0.9938683634116449 0.9489846193555862 -0.858934493426592 0.9326150140222005 -0.053160236717356125 -0.5590049972802488 0.8304973704919705 -0.4954973729168449 -0.8904838085819885 -0.4603659148289983 -0.5254001251818793 0.5409722203769886 0.4631912649303452 0.9494856148646305 -0.5997473287940438 -0.6312666378723208 0.9134133613412252 0.990326804156158 0.32554933451756 0.29552020666133955 0.9396454736853249 -0.43353088275271773 0.5396320487339693 -0.008407247367148618 0.5141359916531132 -0.35525355998804264 -0.669239857276262 0.9246060124080203 -0.6156301052500863 -0.31758856607203484 0.7512804051402927 0.18885889497650057 -0.3645833414243013 -0.274824670323124 -0.9969897762717695 0.8257849931056082 0.5311861979208834 -0.918070474669267 -0.6197368435949633 -0.17232087571561025 0.74570521217672 -0.5754753801952172 0.361615431964962 -0.6984184692162136 -0.49095472496260095 -0.6839659518769007 -0.9834131875473108 -0.5082790774992584 -0.999292788975378 -0.9998449300020044 0.5728674601004813 0.5325349075556212 0.28595222510483553 0.8075581004051143 -0.45146575216142315 -0.3080905586823781 0.8468318446180152 0.45430566983030646 0.9092974268256817 -0.8615969003107405 0.9520903415905158 -0.9942155195492713 -0.8266821782320357 0.5971954413623921 -0.9936910036334645 0.15931820661424598 -0.9560397542711181 0.9639829961524481 -0.11294379406346737 -0.4063057021444168 0.9665943918332975 0.999783764189357 0.9612752029752999 0.9886517628517197 0.8191915683009983 -0.9719030694018208 0.7653549525292535 -0.12805476426637968 0.8632093666488737 -0.8640123164850744 -0.17746242484086058 -0.3971481672859602 -0.23107778829939224 -0.6766367361314569 -0.9742082498528091 0.7311458297268959 -0.2017901307561289 0.41687080242921076 -0.44680005240543 -0.999997146387718 0.6961352386273567 -0.02840352588360379 0.6668696350036979 -0.9986280112074989 0.9818535303723598 -0.8739082619290224 0.09146464223243675 0.04158066243329049 -0.14786317380431852 0.935616001553386 -0.999475522827284 0.12963414261969486 0.9580158602892249 -0.9983409441568876 -0.13796586727122728 -0.6581864701049049 0.9999417202299663 0.6530407515722648 -0.32705481486974064 0.9127639402605211 0.6222335553193046 0.6828032219306397 -0.9999710363300245 -0.7435791389442746 0.23770262642713458 0.10141798631660187] into active tuple
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/Enzyme/1MkLT/src/rules/typeunstablerules.jl:20 [inlined]
[2] create_shadow_ret
@ ~/.julia/packages/Enzyme/1MkLT/src/rules/typeunstablerules.jl:3 [inlined]
[3] macro expansion
@ ~/.julia/packages/Enzyme/1MkLT/src/rules/typeunstablerules.jl:96 [inlined]
[4] runtime_tuple_augfwd(::Type{Val{(false, true)}}, ::Val{false}, ::Val{1}, ::Val{(true, true)}, ::Val{Any}, ::Matrix{Float64}, ::Nothing, ::Matrix{Float64}, ::Matrix{Float64})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/1MkLT/src/rules/typeunstablerules.jl:237
[5] GLayer
@ ~/Documents/my_gits/Julia codes/Sympnets/_research/minimal_GSymp.jl:45
[6] apply
@ ~/.julia/packages/LuxCore/qsnGJ/src/LuxCore.jl:155 [inlined]
[7] macro expansion
@ ~/.julia/packages/Lux/bpNXc/src/layers/containers.jl:0 [inlined]
[8] applychain
@ ~/.julia/packages/Lux/bpNXc/src/layers/containers.jl:570
[9] Chain
@ ~/.julia/packages/Lux/bpNXc/src/layers/containers.jl:568 [inlined]
[10] loss_fn
@ ~/Documents/my_gits/Julia codes/Sympnets/_research/minimal_GSymp.jl:52
[11] #5
@ ~/.julia/packages/Lux/bpNXc/src/helpers/training.jl:399 [inlined]
[12] augmented_julia__5_17809_inner_9wrap
@ ~/.julia/packages/Lux/bpNXc/src/helpers/training.jl:0
[13] macro expansion
@ ~/.julia/packages/Enzyme/1MkLT/src/compiler.jl:6652 [inlined]
[14] enzyme_call
@ ~/.julia/packages/Enzyme/1MkLT/src/compiler.jl:6131 [inlined]
[15] AugmentedForwardThunk
@ ~/.julia/packages/Enzyme/1MkLT/src/compiler.jl:6079 [inlined]
[16] autodiff
@ ~/.julia/packages/Enzyme/1MkLT/src/Enzyme.jl:412 [inlined]
[17] compute_gradients_impl
@ ~/.julia/packages/Lux/bpNXc/ext/LuxEnzymeExt/training.jl:17 [inlined]
[18] compute_gradients_impl_with_allocator_cache
@ ~/.julia/packages/Lux/bpNXc/src/helpers/training.jl:323 [inlined]
[19] #compute_gradients#1
@ ~/.julia/packages/Lux/bpNXc/src/helpers/training.jl:313 [inlined]
[20] compute_gradients
@ ~/.julia/packages/Lux/bpNXc/src/helpers/training.jl:311 [inlined]
[21] single_train_step_impl!
@ ~/.julia/packages/Lux/bpNXc/src/helpers/training.jl:500 [inlined]
[22] single_train_step_impl_with_allocator_cache!
@ ~/.julia/packages/Lux/bpNXc/src/helpers/training.jl:496 [inlined]
[23] #single_train_step!#7
@ ~/.julia/packages/Lux/bpNXc/src/helpers/training.jl:445 [inlined]
[24] train_model(::Tuple{Vector{Float64}, Vector{Float64}}, dt::Float64; seed::Int64, N_EPOCHS::Int64, hidden_dims::Int64)
@ Main ~/Documents/my_gits/Julia codes/Sympnets/_research/minimal_GSymp.jl:87
[25] macro expansion
@ ./timing.jl:581 [inlined]
[26] top-level scope
@ ~/Documents/my_gits/Julia codes/Sympnets/_research/minimal_GSymp.jl:315
in expression starting at /home/haroun/Documents/my_gits/Julia codes/Sympnets/_research/minimal_GSymp.jl:112
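If I read workaround (b) correctly, with Lux's Training API this would mean passing an Enzyme mode with runtime activity enabled through the AutoEnzyme backend object, roughly like the sketch below (I have not verified that Lux actually picks up the mode field):
import Enzyme
# Untested sketch of workaround (b): enable runtime activity on the reverse mode
ad = AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))
# ... then use `ad` in place of AutoZygote() in Training.single_train_step!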
Edit: It seems like the train_model function is recompiled for each number of hidden layers. Is this expected? For example, after running it again with hidden_dims=20, it is much faster:
julia> @time trained_model = train_model((q_training, p_training), dt, N_EPOCHS=1000, hidden_dims=20);
# first time: 138.925754 seconds (133.34 M allocations: 11.562 GiB, 0.96% gc time, 95.83% compilation time)
# second time: 5.642650 seconds (74.47 M allocations: 8.149 GiB, 16.65% gc time)
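My guess is that this happens because Chain stores every layer in its type, so each value of hidden_dims is a different model type and the whole training step gets compiled again. A quick check along these lines (sketch):
# Sketch: Chains of different depths have different types, so each depth compiles separately
combo = [GLowerLayer(1, dt, tanh), GUpperLayer(1, dt, tanh)]
typeof(Chain(repeat(combo, 1)...)) == typeof(Chain(repeat(combo, 2)...))  # false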