I’m trying to run a Slurm job using MPI that involves a nested loop, where the outer loop is over a smaller list of size m than the inner loop (size n). For the inner loop, I want to use as many processes as possible (nproc → n). How can I make this work if n >> m and hence nproc >> m?
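To make the intended structure concrete, here is a minimal, purely serial sketch of the two loops; the names (experiments, samples, process_sample) are placeholders and not from my actual code:

# Serial sketch of the work layout (placeholder data and functions):
experiments = [(name = "exp_$i", samples = rand(1000)) for i in 1:3]  # outer list, length m (small)

process_sample(x) = x^2  # stand-in for the expensive per-sample work

for experiment in experiments            # outer loop over m experiments
    for sample in experiment.samples     # inner loop over n samples, n >> m
        process_sample(sample)           # this is the work I want spread over nproc MPI ranks
    end
end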
Below is my current implementation, which only works as long as nproc <= m.
using BSON
using CTExperiments
using CounterfactualExplanations
using DotEnv
using Logging
using MPI
using Serialization
using TaijaParallel
DotEnv.load!()
# Get config and set up grid:
config_file = get_config_from_args()
root_name = CTExperiments.from_toml(config_file)["name"]
root_save_dir = joinpath(ENV["OUTPUT_DIR"], root_name)
exper_grid = ExperimentGrid(config_file; new_save_dir=root_save_dir)
# Initialize MPI
MPI.Init()
comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)
nprocs = MPI.Comm_size(comm)
if rank != 0
    global_logger(NullLogger())
    exper_list = nothing
else
    # Generate list of experiments and run them:
    exper_list = setup_experiments(exper_grid)
    @info "Running $(length(exper_list)) experiments ..."
end
# Broadcast exper_list from rank 0 to all ranks
exper_list = MPI.bcast(exper_list, comm; root=0)
MPI.Barrier(comm) # Ensure all ranks have received the experiment list before work is distributed
if length(exper_list) < nprocs
    @warn "There are less experiments than processes. Check CPU efficiency of job."
end
chunks = TaijaParallel.split_obs(exper_list, nprocs) # split experiments into chunks for each process
worker_chunk = MPI.scatter(chunks, comm) # distribute across processes
for (i, experiment) in enumerate(worker_chunk)
    if rank != 0
        # Shut up logging for other ranks to avoid cluttering output
        CTExperiments.shutup!(experiment.training_params)
    end

    # Setup:
    _save_dir = experiment.meta_params.save_dir
    _name = experiment.meta_params.experiment_name

    # Skip if already finished
    if has_results(experiment)
        @info "Rank $(rank): Skipping $(_name), model already exists."
        continue
    end

    # Running the experiment
    @info "Rank $(rank): Running experiment: $(_name) ($i/$(length(worker_chunk)))"
    println("Saving checkpoints in: ", _save_dir)
    model, logs = run_training(experiment; checkpoint_dir=_save_dir)

    # Saving the results:
    save_results(experiment, model, logs)
end
# Finalize MPI
MPI.Barrier(comm) # Ensure all processes reach this point before finishing
If more processes are available than there are items in exper_list, no error is thrown and the logs show the messages below, but model, logs = run_training(experiment; checkpoint_dir=_save_dir) never executes.
[ Info: Running 9 experiments ...
┌ Warning: There are less experiments than processes. Check CPU efficiency of job.
└ @ Main /scratch/paltmeyer/code/CounterfactualTraining.jl/paper/experiments/run_grid.jl:43
[ Info: Rank 0: Running experiment: experiment_1 (1/1)
[ Info: Using `MPI.jl` for multi-processing.
Any help would be much appreciated!