Distributed Data Parallel training with 2 GPUs fails with Flux.jl on AMD GPUs

I am using Flux (with the PR "Fix missing imports in FluxMPIExt" by Alexander-Barth, FluxML/Flux.jl#2589, which fixes some MPI-related imports) on two AMD GPUs with
Distributed Data Parallel (DDP) training. However, the training fails at the 2nd iteration.

Here is a minimal reproducer:

serial = get(ENV,"SERIAL","false") == "true"

import AMDGPU
using Flux
using Optimisers
using Zygote
using Statistics
using Random
if !serial
    import MPI
end

Random.seed!(42)

function pprintln(backend,args...)
    MPI.Barrier(backend.comm)
    print("rank ",DistributedUtils.local_rank(backend),": ")
    println(args...)
end
pprintln(::Nothing,args...) = println(args...)

AMDGPU.allowscalar(false)

@show Flux.MPI_ROCM_AWARE

if !serial
    const backend_type = MPIBackend
    DistributedUtils.initialize(backend_type)
    backend = DistributedUtils.get_distributed_backend(backend_type)
else
    backend = nothing
end

T = Float32
device = gpu

x = randn(T,256,256,32,16*2) |> device

channels = 2 .^ vcat(5:7,6:-1:5)

model = Chain(
    [Conv((3,3),channels[i] => channels[i+1],pad=SamePad(),selu) for i in 1:length(channels)-1]...
)

losses = T[]
model = model |> device

loss(x,y) = mean((x-y).^2)

opt_s = Optimisers.Adam(1f-4) # NaN in model weights at 2nd iteration
#opt_s = Optimisers.Descent(0.01f0) # ok

if !serial
    data = DistributedUtils.DistributedDataContainer(backend, x)
    model = DistributedUtils.synchronize!!(backend, DistributedUtils.FluxDistributedModel(model); root=0)
    opt = DistributedUtils.DistributedOptimizer(backend, opt_s)
else
    data = x
    opt = opt_s
end

opt_state = Optimisers.setup(opt, model)

# NaN disappears if this is uncommented
# println("opt_state after setup ",opt_state.layers[1].weight.state[1][1:1])

if !serial
    opt_state = DistributedUtils.synchronize!!(backend, opt_state; root=0)
end

dl = Flux.DataLoader(data,batchsize=16)

for i = 1:4
    global model, opt_state
    for (j,x_batch) in enumerate(dl)
        val, grads = Flux.withgradient(model) do m
            loss(x_batch,m(x_batch))
        end

        push!(losses, val)        
        opt_state, model = Optimisers.update(opt_state, model, grads[1])

        pprintln(backend,"update ",i," ",model.layers[1].weight[1:1])
    end
end

pprintln(backend,"losses ",losses)

Already at the 2nd mini-batch, the model weights on rank 1 become NaN with the Adam optimizer, even with a very small learning rate (1e-6):

Flux.MPI_ROCM_AWARE = true
Flux.MPI_ROCM_AWARE = true
rank 1: update 1 Float32[0.005779413]
rank 0: update 1 Float32[0.005779413]
rank 1: update 2 Float32[NaN]
rank 0: update 2 Float32[0.0056868508]
rank 1: update 3 Float32[NaN]
rank 0: update 3 Float32[0.005596291]
rank 1: update 4 Float32[NaN]
rank 0: update 4 Float32[0.0055066617]
rank 1: losses Float32[2.0405662, NaN, NaN, NaN]
rank 0: losses Float32[2.040605, 2.0056882, 1.9737886, 1.9429569]
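(To pin down where the non-finite values first appear, a check like the following could be added right after the Optimisers.update call. check_finite is a hypothetical helper, not part of the reproducer above; it assumes Flux.trainables and the pprintln defined there.)

# Hypothetical diagnostic (not in the original script): check the loss and each
# trainable array for non-finite values right after the parameter update.
function check_finite(backend, model, val)
    isfinite(val) || pprintln(backend, "non-finite loss: ", val)
    for (i, p) in enumerate(Flux.trainables(model))
        if !all(isfinite, p)
            pprintln(backend, "non-finite values in trainable array ", i)
        end
    end
end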

I do not have this issue when the code is run serially (setting the environment variable SERIAL=true):

Flux.MPI_ROCM_AWARE = true
update 1 Float32[0.005779413]
update 1 Float32[0.0056795366]
update 2 Float32[0.0055798953]
update 2 Float32[0.0054805833]
update 3 Float32[0.005381727]
update 3 Float32[0.005283449]
update 4 Float32[0.0051858225]
update 4 Float32[0.0050889943]
losses Float32[2.040605, 2.0056605, 1.9714396, 1.9378068, 1.9048955, 1.872551, 1.8409457, 1.8098731]

I also do not have the issue if I use the stateless Optimisers.Descent optimizer.
Surprisingly, the NaN during training also disappears when I print the optimizer's state after Optimisers.setup (the commented-out line above).

But even with this print statement in place, when I let the code run for 100 iterations, the loss values of the two ranks diverge noticeably after a while (0.5730838 for rank 0 vs. 1.1246178 for rank 1):

rank 1: losses Float32[2.0405662, 2.0056431, 1.9737768, 1.9429692, 1.9128205, 1.8832046, 1.8541049, 1.8254741, 1.7973075, 1.7696016, 1.7423553, 1.7155664, 1.6892323, 1.6633472, 1.63791, 1.6129179, 1.5883675, 1.5642548, 1.5405741, 1.5173209, 1.4944896, 1.4720752, 1.4500718, 1.4285016, 1.4072559, 1.386411, 1.3659607, 1.3473692, 1.329339, 1.3116252, 1.2980983, 1.2847195, 1.2713451, 1.258033, 1.2468462, 1.2437806, 1.2413762, 1.2389867, 1.2366787, 1.2343965, 1.2320442, 1.2297226, 1.2274451, 1.2251523, 1.2229185, 1.2207301, 1.2185131, 1.2163403, 1.2142112, 1.2125139, 1.2104331, 1.2085547, 1.2068253, 1.2052276, 1.2037412, 1.201832, 1.200054, 1.1983342, 1.1966426, 1.1949832, 1.1933335, 1.1916914, 1.1900835, 1.1884861, 1.1868973, 1.1853313, 1.1837924, 1.1822832, 1.1806498, 1.1790379, 1.1774572, 1.1758871, 1.1743231, 1.1727754, 1.1712611, 1.1697775, 1.1681856, 1.1666263, 1.1646206, 1.1626871, 1.160771, 1.1589265, 1.1571369, 1.1554058, 1.1537917, 1.1519566, 1.1498919, 1.1479226, 1.1459588, 1.1438408, 1.1418104, 1.1397337, 1.1376653, 1.13569, 1.1338012, 1.1320505, 1.1300945, 1.1282241, 1.1264334, 1.1246178]

rank 0: losses Float32[2.040605, 2.0056882, 1.9737886, 1.9429556, 1.9127839, 1.8831487, 1.8540053, 1.8253373, 1.7971399, 1.7694099, 1.7421459, 1.7153467, 1.6890085, 1.6631256, 1.6376983, 1.6127218, 1.588191, 1.5641009, 1.5404449, 1.5172176, 1.4944139, 1.4720266, 1.450049, 1.4284747, 1.4072984, 1.386515, 1.3661177, 1.3460996, 1.3264556, 1.3071792, 1.288264, 1.2697039, 1.2514939, 1.2336284, 1.2161016, 1.198908, 1.1820419, 1.1654985, 1.149272, 1.1333566, 1.1177475, 1.1024392, 1.0874268, 1.072705, 1.0582687, 1.0441121, 1.0302305, 1.0166179, 1.0032698, 0.9901813, 0.97734797, 0.9647645, 0.95242596, 0.94032794, 0.92846537, 0.91683334, 0.9054272, 0.8942421, 0.8832733, 0.87251693, 0.8619681, 0.8516228, 0.84147656, 0.8315253, 0.8217648, 0.8121911, 0.8028003, 0.7935883, 0.7845508, 0.7756847, 0.76698613, 0.7584513, 0.7500768, 0.74185914, 0.7337947, 0.72588044, 0.7181125, 0.7104877, 0.70300275, 0.6956547, 0.6884402, 0.68135583, 0.674399, 0.6675668, 0.6608563, 0.65426487, 0.6477898, 0.6414286, 0.63517886, 0.6290378, 0.623003, 0.6170719, 0.611242, 0.6055112, 0.5998771, 0.5943377, 0.5888909, 0.58353436, 0.5782659, 0.5730838]

Is there something wrong with my Julia code above?
If not, I am wondering whether there is a synchronization issue between the different MPI processes, or whether some uninitialized memory is being used.
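One way to probe the synchronization hypothesis would be to compare a weight norm across the ranks after each update. A sketch (reusing pprintln and backend.comm from the reproducer above; the array is copied to the host first so that GPU-aware MPI is not involved):

# Sketch: compare the norm of the first conv weight across ranks.
# If the ranks were in sync, the max and min should agree up to round-off.
w_norm = sum(abs2, Array(model.layers[1].weight))
w_max = MPI.Allreduce(w_norm, MPI.MAX, backend.comm)
w_min = MPI.Allreduce(w_norm, MPI.MIN, backend.comm)
pprintln(backend, "weight-norm spread across ranks: ", w_max - w_min)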

My environment:

julia 1.11.2

⌃ [21141c5a] AMDGPU v1.2.2
  [0a1fb500] BlockDiagonals v0.1.42
  [052768ef] CUDA v5.6.1
⌃ [13f3f980] CairoMakie v0.12.18
⌃ [b0b7db55] ComponentArrays v0.15.22
  [efc8151c] DIVAnd v2.7.12
  [cf87cc76] DataAssim v0.4.1
  [8bb1440f] DelimitedFiles v1.9.1
  [4e2335b7] FlowMatching v0.1.0 `..`
  [587475ba] Flux v0.16.3 `~/.julia/dev/Flux`
  [db073c08] GeoMakie v0.7.10
  [033835bb] JLD2 v0.5.11
  [f1d291b0] MLUtils v0.4.7
  [da04e1cc] MPI v0.20.22
  [3da0fdf6] MPIPreferences v0.1.11
  [85f8d34a] NCDatasets v0.14.6
  [3bd65402] Optimisers v0.4.4
  [21216c6a] Preferences v1.4.3
  [10745b16] Statistics v1.11.1
⌃ [e88e6eb3] Zygote v0.7.3
  [02a925ec] cuDNN v1.4.1
  [ade2ca70] Dates v1.11.0
  [de0858da] Printf v1.11.0
  [8dfed614] Test v1.11.0

GPU: AMD INSTINCT MI200
MPI: ROCm-aware Cray MPI implementation (tested with the ROCm-aware (AMDGPU) MPI multi-GPU test on GitHub)


Hi!

Unlikely that it will help, but can you try with Zygote@0.7.4?

Surprisingly, the NaN during training also disappears when I print the optimizer's state after Optimisers.setup.

This looks like a synchronization issue. Can you also try adding an explicit AMDGPU.synchronize() instead of printing?
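Something like this, right where the println workaround sits (just a sketch against your reproducer):

opt_state = Optimisers.setup(opt, model)
AMDGPU.synchronize()   # instead of the commented-out println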


Also, do the distributed Flux.jl tests pass for you?

Updating to Zygote@0.7.4 did not change anything, but indeed using AMDGPU.synchronize() instead of printing also avoids the NaN issue. The loss values still diverge between ranks, unfortunately. I just tried Lux.jl, and there the loss values remain consistent.

As far as I can tell, there are no Flux distributed tests currently, but the one that I added in the PR passes.

There are some tests. Can you try uncommenting this line and see if the tests pass?

The test script looks for files ending with _distributedtest.jl:

But there are none currently in Flux.jl/test/ext_distributed (at commit 9147e84bb1cde1dd1e789107422bb98bfd8a07b9).

With ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true", one simply gets:

Test Summary: | Total   Time
Flux.jl       |     0  25.3s
  Distributed |     0   0.9s
     Testing Flux tests passed 
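(The summary above came from roughly the following invocation; sketch:)

ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true"
using Pkg
Pkg.test("Flux")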

However, when calling AMDGPU.synchronize() before Optimisers.update, one now gets consistent loss values! 🎉

rank 0: losses Float32[2.040605, 2.0056882, 1.971421, 1.9378042, 1.9048386, 1.8725222, 1.840851, 1.8098198, 1.7794223, 1.7496514, 1.7204987, 1.6919537, 1.6640061, 1.6366463, 1.609862, 1.5836416, 1.5579731, 1.5328445, 1.5082449, 1.4841626, 1.4605856, 1.4375019, 1.414902, 1.3927755, 1.3711121, 1.3499002, 1.3291311, 1.3087943, 1.2888799, 1.2693791, 1.2502838, 1.2315855, 1.2132759, 1.1953462, 1.1777894, 1.1605976, 1.1437626, 1.1272769, 1.1111336, 1.0953257, 1.079846, 1.0646877, 1.0498438, 1.0353076, 1.0210719, 1.0071306, 0.99347717, 0.9801054, 0.96700954, 0.9541833, 0.9416207, 0.9293158, 0.9172623, 0.9054549, 0.89388776, 0.88255584, 0.87145394, 0.86057615, 0.84991765, 0.8394736, 0.8292386, 0.81920826, 0.8093777, 0.7997422, 0.79029727, 0.78103864, 0.77196205, 0.7630632, 0.75433797, 0.7457824, 0.73739254, 0.7291644, 0.721094, 0.71317744, 0.7054113, 0.697792, 0.69031584, 0.68297946, 0.67577934, 0.66871244, 0.6617756, 0.65496576, 0.6482798, 0.64171505, 0.63526875, 0.6289377, 0.62271947, 0.6166111, 0.6106102, 0.60471416, 0.5989205, 0.59322685, 0.58763087, 0.58212996, 0.57672226, 0.5714054, 0.56617725, 0.56103563, 0.5559785, 0.5510039]
rank 1: losses Float32[2.0405662, 2.005643, 1.97137, 1.9377482, 1.904777, 1.8724544, 1.8407779, 1.8097407, 1.7793366, 1.7495592, 1.7203997, 1.6918478, 1.6638947, 1.6365283, 1.6097383, 1.5835133, 1.5578407, 1.5327085, 1.5081055, 1.4840207, 1.4604423, 1.4373589, 1.4147593, 1.3926326, 1.3709689, 1.3497577, 1.3289893, 1.308654, 1.2887421, 1.2692459, 1.250155, 1.2314608, 1.213156, 1.1952326, 1.1776822, 1.160497, 1.1436689, 1.1271902, 1.111054, 1.095253, 1.0797796, 1.0646269, 1.0497882, 1.0352571, 1.0210273, 1.0070922, 0.99344563, 0.9800808, 0.9669918, 0.9541714, 0.94161445, 0.92931473, 0.91726625, 0.90546346, 0.893901, 0.8825737, 0.87147635, 0.8606035, 0.8499501, 0.839511, 0.82928115, 0.819256, 0.80943066, 0.7998004, 0.7903605, 0.7811067, 0.7720345, 0.76313984, 0.75441825, 0.745866, 0.7374791, 0.72925365, 0.7211859, 0.71327186, 0.705508, 0.6978909, 0.690417, 0.68308264, 0.6758847, 0.66881984, 0.6618851, 0.6550772, 0.6483929, 0.64182967, 0.63538444, 0.6290543, 0.6228367, 0.616729, 0.6107286, 0.60483295, 0.5990397, 0.59334624, 0.5877504, 0.58224976, 0.5768423, 0.5715256, 0.56629753, 0.56115603, 0.5560991, 0.5511245]
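For reference, the inner loop of the reproducer now looks roughly like this (sketch):

for (j, x_batch) in enumerate(dl)
    val, grads = Flux.withgradient(model) do m
        loss(x_batch, m(x_batch))
    end
    push!(losses, val)
    AMDGPU.synchronize()   # wait for the gradient kernels before the distributed update
    opt_state, model = Optimisers.update(opt_state, model, grads[1])
end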

AMDGPU.synchronize() shouldn't be needed in the normal case. Can you please open an issue on Flux.jl describing that explicit synchronization is needed for things to work correctly, so that we can keep track of it?

Good idea, I just opened an issue here: ROCM-Aware MPI requires AMDGPU.synchronize() (FluxML/Flux.jl#2591)