Multiple GPUs - One GPU per Process - Only one of four GPUs is in use

Hey, after checking with nvidia-smi, I see that only one of my four GPUs is busy when I run this code. It seems that the work is not being distributed to the workers. Am I missing something?

using Plots
using Flux
using Flux: params

# spawn one worker per device
using Distributed, CUDA
addprocs(length(devices()))
@everywhere using CUDA

# assign devices
asyncmap((zip(workers(), devices()))) do (p, d)
    remotecall_wait(p) do
        @info "Worker $p uses $d"
        device!(d)
    end
end

# Define the model
m = Chain(
  Dense(10, 5, σ),  # first layer with 10 inputs and 5 outputs, using the sigmoid function as the activation
  Dense(5, 1),      # second layer with 5 inputs and 1 output
  identity          # identity function as the activation for the output layer
) |> gpu

# Define a loss function and an optimizer
loss(x, y) = Flux.mse(m(x), y)
optimizer = ADAM()

# Generate some synthetic data in Float32
X = rand(Float32, 10, 1000) |> gpu
Y = rand(Float32, 1, 1000) |> gpu

data = [(X, Y)]

# Train the model
@time for i in 1:600
  Flux.train!(loss, params(m), data, optimizer)
end

# Make predictions on the training data
predictions = m(X)

# Plot the predictions versus the true values (copy them back to the CPU first)
scatter(vec(cpu(Y)), vec(cpu(predictions)), xlabel="True values", ylabel="Predictions", primary = false)

From worker 3: [ Info: Worker 3 uses CuDevice(1)
From worker 4: [ Info: Worker 4 uses CuDevice(2)
From worker 2: [ Info: Worker 2 uses CuDevice(0)
From worker 5: [ Info: Worker 5 uses CuDevice(3)
6.436934 seconds (9.83 M allocations: 684.209 MiB, 3.08% gc time, 1.01% compilation time: 100% of which was recompilation)

I think you are only running your script on the main process and not running anything on the worker processes. By default, everything in the main script runs on the main process (you can check this with `myid()`). To run code on the other processes you have to use functions like `remotecall_wait`, `pmap`, etc.; the main process is usually only used to coordinate the work between the workers.
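The `myid()` check can be sketched as follows, using only the Distributed standard library. The `addprocs(2)` call is an assumption added so the snippet runs on its own; in the script above the workers already exist.

```julia
using Distributed

# Assumption for a standalone run: start two local workers if none exist yet
nprocs() == 1 && addprocs(2)

# Top-level code in a script runs on the main process, whose id is always 1
main_id = myid()
println("top-level code runs on process $main_id")

# A function sent with remotecall_fetch runs on the target worker and returns its result
worker_ids = [remotecall_fetch(myid, p) for p in workers()]
println("remote calls ran on processes $worker_ids")
```

Everything in the question's script after the device assignment (building `m`, generating `X`/`Y`, the training loop) is top-level code, so it all runs on process 1 and its default GPU.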

EDIT: You may need FluxMPI.jl to make use of all the GPUs.
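Short of a full MPI setup, the `remotecall`/`pmap` approach can at least give each GPU its own job. This is a minimal sketch under several assumptions: one worker per GPU as in the question, the same pre-0.15 Flux API the question uses (`params` with implicit `train!`, `ADAM`), and `train_on_device` as a hypothetical helper name. Each worker trains an independent model on its own synthetic data and no gradients are exchanged, so this suits independent runs (e.g. a hyperparameter sweep), not data-parallel training of a single model.

```julia
using Distributed, CUDA

# Assumption: one worker per GPU, as in the question's script
nprocs() == 1 && addprocs(length(devices()))
@everywhere using Flux, CUDA

# Hypothetical helper: everything inside runs on the worker it is called on
@everywhere function train_on_device(dev_id::Int; epochs::Int = 600)
    device!(dev_id)                       # bind this task to GPU `dev_id` before allocating anything
    m = Chain(Dense(10, 5, σ), Dense(5, 1)) |> gpu
    X = rand(Float32, 10, 1000) |> gpu    # synthetic data, copied to this worker's GPU
    Y = rand(Float32, 1, 1000) |> gpu
    loss(x, y) = Flux.mse(m(x), y)
    opt = ADAM()
    for _ in 1:epochs
        Flux.train!(loss, Flux.params(m), [(X, Y)], opt)
    end
    # Report which process and GPU did the work, plus the final training loss
    return (myid(), device(), loss(X, Y))
end

# Pair each worker with one 0-based GPU index and run all trainings concurrently
gpu_ids = 0:length(devices())-1
results = asyncmap(zip(workers(), gpu_ids)) do (p, d)
    remotecall_fetch(train_on_device, p, d)
end
foreach(println, results)
```

Calling `device!` inside the remote function (rather than relying on an earlier assignment step) keeps the device choice in the same task that does the allocation and training.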

