I am using Flux on two AMD GPUs with Distributed Data Parallel (DDP), together with PR #2589 ("Fix missing imports in FluxMPIExt" by Alexander-Barth, FluxML/Flux.jl on GitHub) to fix some MPI-related imports. However, training fails at the 2nd iteration.
Here is a minimal reproducer:
serial = get(ENV,"SERIAL","false") == "true"
import AMDGPU
using Flux
using Optimisers
using Zygote
using Statistics
using Random
if !serial
    import MPI
end
Random.seed!(42)
function pprintln(backend,args...)
    MPI.Barrier(backend.comm)
    print("rank ",DistributedUtils.local_rank(backend),": ")
    println(args...)
end
pprintln(::Nothing,args...) = println(args...)
AMDGPU.allowscalar(false)
@show Flux.MPI_ROCM_AWARE
if !serial
    const backend_type = MPIBackend
    DistributedUtils.initialize(backend_type)
    backend = DistributedUtils.get_distributed_backend(backend_type)
else
    backend = nothing
end
T = Float32
device = gpu
x = randn(T,256,256,32,16*2) |> device
channels = 2 .^ vcat(5:7,6:-1:5)
model = Chain(
    [Conv((3,3),channels[i] => channels[i+1],pad=SamePad(),selu) for i in 1:length(channels)-1]...
)
losses = T[]
model = model |> device
loss(x,y) = mean((x-y).^2)
opt_s = Optimisers.Adam(1f-4) # NaN in model weights at 2nd iteration
#opt_s = Optimisers.Descent(0.01f0) # ok
if !serial
    data = DistributedUtils.DistributedDataContainer(backend, x)
    model = DistributedUtils.synchronize!!(backend, DistributedUtils.FluxDistributedModel(model); root=0)
    opt = DistributedUtils.DistributedOptimizer(backend, opt_s)
else
    data = x
    opt = opt_s
end
opt_state = Optimisers.setup(opt, model)
# NaN disappears if this is uncommented
# println("opt_state after setup ",opt_state.layers[1].weight.state[1][1:1])
if !serial
    opt_state = DistributedUtils.synchronize!!(backend, opt_state; root=0)
end
dl = Flux.DataLoader(data,batchsize=16)
for i = 1:4
    global model, opt_state
    for (j,x_batch) in enumerate(dl)
        val, grads = Flux.withgradient(model) do m
            loss(x_batch,m(x_batch))
        end
        push!(losses, val)
        opt_state, model = Optimisers.update(opt_state, model, grads[1])
        pprintln(backend,"update ",i," ",model.layers[1].weight[1:1])
    end
end
pprintln(backend,"losses ",losses)
Already at the 2nd mini-batch, the model weights on rank 1 become NaN with the Adam optimizer, even with a very small learning rate (1e-6):
Flux.MPI_ROCM_AWARE = true
Flux.MPI_ROCM_AWARE = true
rank 1: update 1 Float32[0.005779413]
rank 0: update 1 Float32[0.005779413]
rank 1: update 2 Float32[NaN]
rank 0: update 2 Float32[0.0056868508]
rank 1: update 3 Float32[NaN]
rank 0: update 3 Float32[0.005596291]
rank 1: update 4 Float32[NaN]
rank 0: update 4 Float32[0.0055066617]
rank 1: losses Float32[2.0405662, NaN, NaN, NaN]
rank 0: losses Float32[2.040605, 2.0056882, 1.9737886, 1.9429569]
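To narrow down where the non-finite values first appear, I can drop a per-layer check like the one below into the inner loop, right after withgradient (only a debugging sketch: check_finite is just a name for this helper, not a Flux function; it reuses model, grads[1] and pprintln from the reproducer, and it assumes Zygote's gradient NamedTuple for the Chain has the usual layers/weight layout):
# Hypothetical debugging helper, not part of the reproducer above:
# report, per rank, every layer whose weights or gradient contain a non-finite value.
function check_finite(backend, model, grad)
    for (k, layer) in enumerate(model.layers)
        w_ok = all(isfinite, Array(layer.weight))       # copy to host, then check
        g = grad.layers[k].weight
        g_ok = g === nothing || all(isfinite, Array(g))
        if !(w_ok && g_ok)
            pprintln(backend, "layer ", k, ": finite weight = ", w_ok, ", finite grad = ", g_ok)
        end
    end
end
# e.g. check_finite(backend, model, grads[1]) right after Flux.withgradient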
I do not have this issue when the code is run serially (setting the environment variable SERIAL=true):
Flux.MPI_ROCM_AWARE = true
update 1 Float32[0.005779413]
update 1 Float32[0.0056795366]
update 2 Float32[0.0055798953]
update 2 Float32[0.0054805833]
update 3 Float32[0.005381727]
update 3 Float32[0.005283449]
update 4 Float32[0.0051858225]
update 4 Float32[0.0050889943]
losses Float32[2.040605, 2.0056605, 1.9714396, 1.9378068, 1.9048955, 1.872551, 1.8409457, 1.8098731]
The issue also does not occur if I use the stateless Optimisers.Descent optimizer.
Surprisingly, the NaN during training also disappears when I print the optimizer's state after Optimisers.setup (the commented-out println above).
But even with this print statement, after letting the code run for 100 iterations, the losses of the two ranks diverge notably (0.5730838 for rank 0 vs. 1.1246178 for rank 1):
rank 1: losses Float32[2.0405662, 2.0056431, 1.9737768, 1.9429692, 1.9128205, 1.8832046, 1.8541049, 1.8254741, 1.7973075, 1.7696016, 1.7423553, 1.7155664, 1.6892323, 1.6633472, 1.63791, 1.6129179, 1.5883675, 1.5642548, 1.5405741, 1.5173209, 1.4944896, 1.4720752, 1.4500718, 1.4285016, 1.4072559, 1.386411, 1.3659607, 1.3473692, 1.329339, 1.3116252, 1.2980983, 1.2847195, 1.2713451, 1.258033, 1.2468462, 1.2437806, 1.2413762, 1.2389867, 1.2366787, 1.2343965, 1.2320442, 1.2297226, 1.2274451, 1.2251523, 1.2229185, 1.2207301, 1.2185131, 1.2163403, 1.2142112, 1.2125139, 1.2104331, 1.2085547, 1.2068253, 1.2052276, 1.2037412, 1.201832, 1.200054, 1.1983342, 1.1966426, 1.1949832, 1.1933335, 1.1916914, 1.1900835, 1.1884861, 1.1868973, 1.1853313, 1.1837924, 1.1822832, 1.1806498, 1.1790379, 1.1774572, 1.1758871, 1.1743231, 1.1727754, 1.1712611, 1.1697775, 1.1681856, 1.1666263, 1.1646206, 1.1626871, 1.160771, 1.1589265, 1.1571369, 1.1554058, 1.1537917, 1.1519566, 1.1498919, 1.1479226, 1.1459588, 1.1438408, 1.1418104, 1.1397337, 1.1376653, 1.13569, 1.1338012, 1.1320505, 1.1300945, 1.1282241, 1.1264334, 1.1246178]
rank 0: losses Float32[2.040605, 2.0056882, 1.9737886, 1.9429556, 1.9127839, 1.8831487, 1.8540053, 1.8253373, 1.7971399, 1.7694099, 1.7421459, 1.7153467, 1.6890085, 1.6631256, 1.6376983, 1.6127218, 1.588191, 1.5641009, 1.5404449, 1.5172176, 1.4944139, 1.4720266, 1.450049, 1.4284747, 1.4072984, 1.386515, 1.3661177, 1.3460996, 1.3264556, 1.3071792, 1.288264, 1.2697039, 1.2514939, 1.2336284, 1.2161016, 1.198908, 1.1820419, 1.1654985, 1.149272, 1.1333566, 1.1177475, 1.1024392, 1.0874268, 1.072705, 1.0582687, 1.0441121, 1.0302305, 1.0166179, 1.0032698, 0.9901813, 0.97734797, 0.9647645, 0.95242596, 0.94032794, 0.92846537, 0.91683334, 0.9054272, 0.8942421, 0.8832733, 0.87251693, 0.8619681, 0.8516228, 0.84147656, 0.8315253, 0.8217648, 0.8121911, 0.8028003, 0.7935883, 0.7845508, 0.7756847, 0.76698613, 0.7584513, 0.7500768, 0.74185914, 0.7337947, 0.72588044, 0.7181125, 0.7104877, 0.70300275, 0.6956547, 0.6884402, 0.68135583, 0.674399, 0.6675668, 0.6608563, 0.65426487, 0.6477898, 0.6414286, 0.63517886, 0.6290378, 0.623003, 0.6170719, 0.611242, 0.6055112, 0.5998771, 0.5943377, 0.5888909, 0.58353436, 0.5782659, 0.5730838]
Is there something wrong with my Julia code above?
If not, I am wondering whether there is a synchronization issue between the different MPI processes, or whether some uninitialized memory is being used.
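To test the synchronization hypothesis during the MPI run, I am thinking of a check along these lines after each Optimisers.update (again only a sketch: check_sync is a hypothetical helper, it reuses backend.comm and pprintln from the reproducer, and it assumes that opt_state.layers[1].weight.state[1] is Adam's first-moment array, as in the commented-out println above):
# Hypothetical debugging helper: compare this rank's copy of layer 1's weights and
# Adam first moment against rank 0's copy via a plain MPI broadcast. If the gradients
# are averaged consistently, both should be (nearly) identical on all ranks after
# every update, so a large difference would point to a synchronization problem.
function check_sync(backend, opt_state, model)
    for (label, a) in (("weight", model.layers[1].weight),
                       ("Adam m", opt_state.layers[1].weight.state[1]))
        local_copy = Array(a)                       # copy to host
        root_copy = copy(local_copy)
        MPI.Bcast!(root_copy, backend.comm; root=0) # overwrite with rank 0's values
        pprintln(backend, label, " max diff to rank 0: ",
                 maximum(abs.(local_copy .- root_copy)))
    end
end
# e.g. check_sync(backend, opt_state, model) after Optimisers.update (MPI run only)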
My environment:
julia 1.11.2
⌃ [21141c5a] AMDGPU v1.2.2
[0a1fb500] BlockDiagonals v0.1.42
[052768ef] CUDA v5.6.1
⌃ [13f3f980] CairoMakie v0.12.18
⌃ [b0b7db55] ComponentArrays v0.15.22
[efc8151c] DIVAnd v2.7.12
[cf87cc76] DataAssim v0.4.1
[8bb1440f] DelimitedFiles v1.9.1
[4e2335b7] FlowMatching v0.1.0 `..`
[587475ba] Flux v0.16.3 `~/.julia/dev/Flux`
[db073c08] GeoMakie v0.7.10
[033835bb] JLD2 v0.5.11
[f1d291b0] MLUtils v0.4.7
[da04e1cc] MPI v0.20.22
[3da0fdf6] MPIPreferences v0.1.11
[85f8d34a] NCDatasets v0.14.6
[3bd65402] Optimisers v0.4.4
[21216c6a] Preferences v1.4.3
[10745b16] Statistics v1.11.1
⌃ [e88e6eb3] Zygote v0.7.3
[02a925ec] cuDNN v1.4.1
[ade2ca70] Dates v1.11.0
[de0858da] Printf v1.11.0
[8dfed614] Test v1.11.0
GPU: AMD INSTINCT MI200
MPI: ROCm-aware Cray MPI implementation (verified with the "ROCm-aware (AMDGPU) MPI multi-GPU test" example from GitHub)