Distributed Data Parallel training with 2 GPUs fails with Flux.jl on AMD GPUs

The test script looks for files whose names end with _distributedtest.jl:
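Roughly, that discovery amounts to the following (a minimal sketch assuming a readdir-based filter; the exact code in Flux.jl's test runner may differ):

```julia
# Sketch of the test-file discovery, assuming a readdir-based filter;
# the actual logic in Flux.jl's test runner may differ.
test_dir = joinpath(@__DIR__, "ext_distributed")
distributed_files = filter(readdir(test_dir)) do f
    endswith(f, "_distributedtest.jl")
end
# each matching file would then be included as part of the Distributed testset
```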

But there are none currently in Flux.jl/test/ext_distributed at commit 9147e84bb1cde1dd1e789107422bb98bfd8a07b9 (FluxML/Flux.jl on GitHub).

With ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true", one simply gets:

```
Test Summary: | Total   Time
Flux.jl       |     0  25.3s
  Distributed |     0   0.9s
     Testing Flux tests passed
```
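For reference, that run corresponds to something like this (a sketch assuming the standard Pkg.test workflow; the exact invocation may differ):

```julia
# Sketch of the invocation, assuming the standard Pkg.test workflow.
ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true"  # opt in to the MPI distributed tests
using Pkg
Pkg.test("Flux")
# With no *_distributedtest.jl files in test/ext_distributed,
# the Distributed testset reports 0 tests, as shown above.
```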

However, when calling AMDGPU.synchronize() before Optimisers.update, one now gets consistent loss values across the two ranks! (:tada:)

rank 0: losses Float32[2.040605, 2.0056882, 1.971421, 1.9378042, 1.9048386, 1.8725222, 1.840851, 1.8098198, 1.7794223, 1.7496514, 1.7204987, 1.6919537, 1.6640061, 1.6366463, 1.609862, 1.5836416, 1.5579731, 1.5328445, 1.5082449, 1.4841626, 1.4605856, 1.4375019, 1.414902, 1.3927755, 1.3711121, 1.3499002, 1.3291311, 1.3087943, 1.2888799, 1.2693791, 1.2502838, 1.2315855, 1.2132759, 1.1953462, 1.1777894, 1.1605976, 1.1437626, 1.1272769, 1.1111336, 1.0953257, 1.079846, 1.0646877, 1.0498438, 1.0353076, 1.0210719, 1.0071306, 0.99347717, 0.9801054, 0.96700954, 0.9541833, 0.9416207, 0.9293158, 0.9172623, 0.9054549, 0.89388776, 0.88255584, 0.87145394, 0.86057615, 0.84991765, 0.8394736, 0.8292386, 0.81920826, 0.8093777, 0.7997422, 0.79029727, 0.78103864, 0.77196205, 0.7630632, 0.75433797, 0.7457824, 0.73739254, 0.7291644, 0.721094, 0.71317744, 0.7054113, 0.697792, 0.69031584, 0.68297946, 0.67577934, 0.66871244, 0.6617756, 0.65496576, 0.6482798, 0.64171505, 0.63526875, 0.6289377, 0.62271947, 0.6166111, 0.6106102, 0.60471416, 0.5989205, 0.59322685, 0.58763087, 0.58212996, 0.57672226, 0.5714054, 0.56617725, 0.56103563, 0.5559785, 0.5510039]
rank 1: losses Float32[2.0405662, 2.005643, 1.97137, 1.9377482, 1.904777, 1.8724544, 1.8407779, 1.8097407, 1.7793366, 1.7495592, 1.7203997, 1.6918478, 1.6638947, 1.6365283, 1.6097383, 1.5835133, 1.5578407, 1.5327085, 1.5081055, 1.4840207, 1.4604423, 1.4373589, 1.4147593, 1.3926326, 1.3709689, 1.3497577, 1.3289893, 1.308654, 1.2887421, 1.2692459, 1.250155, 1.2314608, 1.213156, 1.1952326, 1.1776822, 1.160497, 1.1436689, 1.1271902, 1.111054, 1.095253, 1.0797796, 1.0646269, 1.0497882, 1.0352571, 1.0210273, 1.0070922, 0.99344563, 0.9800808, 0.9669918, 0.9541714, 0.94161445, 0.92931473, 0.91726625, 0.90546346, 0.893901, 0.8825737, 0.87147635, 0.8606035, 0.8499501, 0.839511, 0.82928115, 0.819256, 0.80943066, 0.7998004, 0.7903605, 0.7811067, 0.7720345, 0.76313984, 0.75441825, 0.745866, 0.7374791, 0.72925365, 0.7211859, 0.71327186, 0.705508, 0.6978909, 0.690417, 0.68308264, 0.6758847, 0.66881984, 0.6618851, 0.6550772, 0.6483929, 0.64182967, 0.63538444, 0.6290543, 0.6228367, 0.616729, 0.6107286, 0.60483295, 0.5990397, 0.59334624, 0.5877504, 0.58224976, 0.5768423, 0.5715256, 0.56629753, 0.56115603, 0.5560991, 0.5511245]
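For anyone hitting the same issue, the placement that made the difference looks roughly like this (a minimal sketch; the model, data, optimiser and step count are placeholders, and the MPI / distributed-optimiser setup from the actual script is omitted):

```julia
# Minimal sketch of where the synchronization goes; model, data and
# optimiser are placeholders, and the MPI / distributed-optimiser setup
# from the actual two-GPU script is omitted.
using Flux, Optimisers, AMDGPU

function train(model, opt_state, x, y; steps = 100)
    losses = Float32[]
    for _ in 1:steps
        loss, grads = Flux.withgradient(m -> Flux.mse(m(x), y), model)
        # Key placement: finish all queued GPU work before the parameter
        # update; without this the per-rank losses drift apart.
        AMDGPU.synchronize()
        opt_state, model = Optimisers.update(opt_state, model, grads[1])
        push!(losses, loss)
    end
    return losses
end

model     = Dense(4 => 2) |> gpu               # placeholder model on the GPU
x, y      = gpu(rand(Float32, 4, 8)), gpu(rand(Float32, 2, 8))
opt_state = Optimisers.setup(Optimisers.Descent(0.1f0), model)
losses    = train(model, opt_state, x, y)
```

In the actual two-GPU run, each rank would execute this loop on its own device and the script would be launched with MPI (e.g. via MPI.jl's mpiexecjl -n 2 julia script.jl).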