I am facing two separate issues: one concerns Reactant initialization options to improve performance, the other concerns avoiding unnecessary latency in a neural-network training loop via compilation.
Note that my code works fine without the extra options I am trying to tweak here for performance. The CPU is a 9900X (12 cores/24 threads) and the GPU is an RTX 5070 Ti (Linux, Julia 1.11.9).
Reactant initialization options
- When run with the GPU backend, XLA preallocates 75% of device VRAM regardless of how much is available. I have tried several things to disable this, but none worked:
```julia
using Preferences, UUIDs

Preferences.set_preferences!(
    UUID("3c362404-f566-11ee-1572-e11a4b42c853"),
    "xla_gpu_fraction" => 0.4,
)

using Lux, Random, Optimisers, Reactant, Enzyme, MLUtils, ProgressMeter, XLSX
```
Another attempt:
```julia
using Preferences, UUIDs

ENV["XLA_FLAGS"] = "--xla_gpu_fraction=0.4"

using Lux, Random, Optimisers, Reactant, Enzyme, MLUtils, ProgressMeter, XLSX
```
The Preferences approach has no effect, and the ENV approach freezes Julia. Running either approach with "XLA_PYTHON_CLIENT_PREALLOCATE" set to false had no effect either.
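For completeness, this is how I set the preallocation variable just mentioned; the variable name is borrowed from JAX's documentation, and whether Reactant's XLA runtime reads it at all is part of what I am unsure about:

```julia
# Set before Reactant is loaded, the way JAX requires for its own runtime;
# this is an assumption on my part — I could not find it documented for Reactant.
ENV["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

using Lux, Random, Optimisers, Reactant, Enzyme, MLUtils, ProgressMeter, XLSX
```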
- With the CPU backend, training uses only 6 threads, resulting in 25% CPU usage as observed in the system monitor. I do not know how to change that; setting the number of Julia threads has no impact. Both Reactant.addressable_devices() and Reactant.devices() output:
```
1-element Vector{Reactant.XLA.PJRT.Device}:
 Reactant.XLA.PJRT.Device(Ptr{Nothing} @0x000000002ac8ee50, "CPU:0 cpu")
```
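To rule out Julia-side threading as the culprit, I checked the session itself after starting it with more threads; Julia reports the thread count correctly, so the limit appears to be inside the XLA CPU backend (illustrative check, counts are from my machine):

```julia
# Session started with: julia -t 24
Threads.nthreads()  # 24, yet the XLA CPU backend still uses only ~6 threads
```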
Avoiding training latency by compiling
Following several tutorials from the Lux examples on the homepage, I initially arrived at the following form.
Training part:
```julia
total_loss = ConcretePJRTNumber{Float32, 1}(0.0f0)
for (x, y) in train_loader
    (_, loss, _, train_state) = Training.single_train_step!(
        AutoReactant(), MSELoss(agg=sum), (x, y), train_state
    )
    total_loss += loss
end
```
Validation part:
```julia
total_val_loss = 0.0f0
st_ = Lux.testmode(train_state.states)
for (x, y) in val_loader
    ŷ, st_ = model_compiled(x, train_state.parameters, st_)
    ŷ, y = cdev(ŷ), cdev(y)
    total_val_loss += sum(abs2, ŷ .- y)
end
```
Both the accumulation of total_loss in the training part and the cdev()/total_val_loss handling in the validation part slow the computation down by about 20% each in my test case. I managed to resolve the validation part (the 20% penalty is essentially eliminated).
Compile:
```julia
function val_step(x, y, p, s, tl)
    ŷ, s = model(x, p, s)
    loss = sum(abs2, ŷ .- y)
    return tl + loss, s
end

val_step_compiled = @compile val_step(
    first(val_loader)[1], first(val_loader)[2],
    ps, Lux.testmode(st), ConcretePJRTNumber{Float32, 1}(0.0f0)
)
```
Loop:
```julia
total_val_loss = ConcretePJRTNumber{Float32, 1}(0.0f0)
st_ = Lux.testmode(train_state.states)
for (x, y) in val_loader
    total_val_loss, st_ = val_step_compiled(x, y, train_state.parameters, st_, total_val_loss)
end
```
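After the loop I read the scalar back once; I am assuming the ordinary scalar conversion is what triggers the single device-to-host transfer here:

```julia
# One host transfer after the whole loop instead of one per batch
final_val_loss = Float32(total_val_loss)
```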
While this works for validation, I do not know how to do the same for the training part: single_train_step! cannot be placed inside @compile, since it already does device wrapping and compilation internally, and Training.compute_gradients cannot be used this way either.
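For concreteness, the shape of the fully compiled step I am after is sketched below, mirroring val_step above. I do not know how (or whether) the gradient computation and Optimisers.update! can be traced inside @compile like this; that is the crux of my question:

```julia
# Hypothetical: a train step that, like val_step above, carries the running
# loss through the compiled function so no host transfer happens per batch.
function train_step(x, y, ps, st, opt_st, tl)
    loss, grads = ...  # somehow obtain loss and gradients traceably here
    opt_st, ps = Optimisers.update!(opt_st, ps, grads)
    return tl + loss, ps, st, opt_st
end

train_step_compiled = @compile train_step(
    first(train_loader)[1], first(train_loader)[2],
    ps, st, opt_st, ConcretePJRTNumber{Float32, 1}(0.0f0)
)
```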