I am using Flux (with PR #2589, "Fix missing imports in FluxMPIExt" by Alexander-Barth, FluxML/Flux.jl, to fix some MPI-related imports) on two AMD GPUs with Distributed Data Parallel (DDP). However, the training fails at the 2nd iteration.
Here is a minimal reproducer:
serial = get(ENV, "SERIAL", "false") == "true"

import AMDGPU
using Flux
using Optimisers
using Zygote
using Statistics
using Random

if !serial
    import MPI
end

Random.seed!(42)

function pprintln(backend, args...)
    MPI.Barrier(backend.comm)
    print("rank ", DistributedUtils.local_rank(backend), ": ")
    println(args...)
end
pprintln(::Nothing, args...) = println(args...)

AMDGPU.allowscalar(false)
@show Flux.MPI_ROCM_AWARE

if !serial
    const backend_type = MPIBackend
    DistributedUtils.initialize(backend_type)
    backend = DistributedUtils.get_distributed_backend(backend_type)
else
    backend = nothing
end

T = Float32
device = gpu

x = randn(T, 256, 256, 32, 16*2) |> device
channels = 2 .^ vcat(5:7, 6:-1:5)

model = Chain(
    [Conv((3,3), channels[i] => channels[i+1], pad=SamePad(), selu) for i in 1:length(channels)-1]...
)

losses = T[]
model = model |> device
loss(x, y) = mean((x - y).^2)

opt_s = Optimisers.Adam(1f-4)       # NaN in model weights at 2nd iteration
#opt_s = Optimisers.Descent(0.01f0) # ok

if !serial
    data = DistributedUtils.DistributedDataContainer(backend, x)
    model = DistributedUtils.synchronize!!(backend, DistributedUtils.FluxDistributedModel(model); root=0)
    opt = DistributedUtils.DistributedOptimizer(backend, opt_s)
else
    data = x
    opt = opt_s
end

opt_state = Optimisers.setup(opt, model)

# NaN disappears if this is uncommented
# println("opt_state after setup ", opt_state.layers[1].weight.state[1][1:1])

if !serial
    opt_state = DistributedUtils.synchronize!!(backend, opt_state; root=0)
end

dl = Flux.DataLoader(data, batchsize=16)

for i = 1:4
    global model, opt_state
    for (j, x_batch) in enumerate(dl)
        val, grads = Flux.withgradient(model) do m
            loss(x_batch, m(x_batch))
        end
        push!(losses, val)
        opt_state, model = Optimisers.update(opt_state, model, grads[1])
        pprintln(backend, "update ", i, " ", model.layers[1].weight[1:1])
    end
end

pprintln(backend, "losses ", losses)
Already at the 2nd mini-batch, the model weights on rank 1 become NaN with the Adam optimizer, even with a very small learning rate (1e-6):
Flux.MPI_ROCM_AWARE = true
Flux.MPI_ROCM_AWARE = true
rank 1: update 1 Float32[0.005779413]
rank 0: update 1 Float32[0.005779413]
rank 1: update 2 Float32[NaN]
rank 0: update 2 Float32[0.0056868508]
rank 1: update 3 Float32[NaN]
rank 0: update 3 Float32[0.005596291]
rank 1: update 4 Float32[NaN]
rank 0: update 4 Float32[0.0055066617]
rank 1: losses Float32[2.0405662, NaN, NaN, NaN]
rank 0: losses Float32[2.040605, 2.0056882, 1.9737886, 1.9429569]
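To narrow down where the non-finite values first appear on rank 1, I intend to check the local gradients on each rank before calling Optimisers.update. This is only a sketch: isfinite_tree is my own helper built on Functors.fcollect (Functors.jl is a dependency of Flux), not a Flux or DistributedUtils function.
using Functors: fcollect

# true if every array in a (possibly nested) parameter/gradient tree is finite
isfinite_tree(tree) = all(fcollect(tree)) do node
    node isa AbstractArray ? all(isfinite, Array(node)) : true
end

# inside the training loop, before Optimisers.update:
# pprintln(backend, "step ", i, " grads finite: ", isfinite_tree(grads[1]))
# pprintln(backend, "step ", i, " params finite: ", isfinite_tree(model))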
I do not have this issue when the code is run serially (setting the environment variable SERIAL=true):
Flux.MPI_ROCM_AWARE = true
update 1 Float32[0.005779413]
update 1 Float32[0.0056795366]
update 2 Float32[0.0055798953]
update 2 Float32[0.0054805833]
update 3 Float32[0.005381727]
update 3 Float32[0.005283449]
update 4 Float32[0.0051858225]
update 4 Float32[0.0050889943]
losses Float32[2.040605, 2.0056605, 1.9714396, 1.9378068, 1.9048955, 1.872551, 1.8409457, 1.8098731]
The NaN also disappears if I use the stateless Optimisers.Descent rule instead of Adam.
Surprisingly, the NaN during training disappears as well when I print the optimizer's state right after Optimisers.setup (the commented-out println in the reproducer above).
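Since a mere print seems to change the behaviour, I suspect the Adam state may not be identical on both ranks after synchronize!!. Here is the kind of check I plan to run (a sketch; max_state_diff is my own helper, reusing the opt_state.layers[1].weight.state[1] access from the commented println together with MPI.Bcast! from MPI.jl):
# compare the Adam first-moment array of the first Conv layer across ranks:
# broadcast rank 0's copy and report the largest absolute difference on this rank
function max_state_diff(backend, opt_state)
    st  = Array(opt_state.layers[1].weight.state[1])  # copy to the host
    ref = copy(st)
    MPI.Bcast!(ref, 0, backend.comm)                  # rank 0's version
    return maximum(abs.(st .- ref))
end

# pprintln(backend, "Adam state max diff after synchronize!!: ", max_state_diff(backend, opt_state))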
But even with this print statement, and letting the code run for 100 iterations, the loss values on the two ranks diverge notably after some time (0.5730838 for rank 0 versus 1.1246178 for rank 1).
rank 1: losses Float32[2.0405662, 2.0056431, 1.9737768, 1.9429692, 1.9128205, 1.8832046, 1.8541049, 1.8254741, 1.7973075, 1.7696016, 1.7423553, 1.7155664, 1.6892323, 1.6633472, 1.63791, 1.6129179, 1.5883675, 1.5642548, 1.5405741, 1.5173209, 1.4944896, 1.4720752, 1.4500718, 1.4285016, 1.4072559, 1.386411, 1.3659607, 1.3473692, 1.329339, 1.3116252, 1.2980983, 1.2847195, 1.2713451, 1.258033, 1.2468462, 1.2437806, 1.2413762, 1.2389867, 1.2366787, 1.2343965, 1.2320442, 1.2297226, 1.2274451, 1.2251523, 1.2229185, 1.2207301, 1.2185131, 1.2163403, 1.2142112, 1.2125139, 1.2104331, 1.2085547, 1.2068253, 1.2052276, 1.2037412, 1.201832, 1.200054, 1.1983342, 1.1966426, 1.1949832, 1.1933335, 1.1916914, 1.1900835, 1.1884861, 1.1868973, 1.1853313, 1.1837924, 1.1822832, 1.1806498, 1.1790379, 1.1774572, 1.1758871, 1.1743231, 1.1727754, 1.1712611, 1.1697775, 1.1681856, 1.1666263, 1.1646206, 1.1626871, 1.160771, 1.1589265, 1.1571369, 1.1554058, 1.1537917, 1.1519566, 1.1498919, 1.1479226, 1.1459588, 1.1438408, 1.1418104, 1.1397337, 1.1376653, 1.13569, 1.1338012, 1.1320505, 1.1300945, 1.1282241, 1.1264334, 1.1246178]
rank 0: losses Float32[2.040605, 2.0056882, 1.9737886, 1.9429556, 1.9127839, 1.8831487, 1.8540053, 1.8253373, 1.7971399, 1.7694099, 1.7421459, 1.7153467, 1.6890085, 1.6631256, 1.6376983, 1.6127218, 1.588191, 1.5641009, 1.5404449, 1.5172176, 1.4944139, 1.4720266, 1.450049, 1.4284747, 1.4072984, 1.386515, 1.3661177, 1.3460996, 1.3264556, 1.3071792, 1.288264, 1.2697039, 1.2514939, 1.2336284, 1.2161016, 1.198908, 1.1820419, 1.1654985, 1.149272, 1.1333566, 1.1177475, 1.1024392, 1.0874268, 1.072705, 1.0582687, 1.0441121, 1.0302305, 1.0166179, 1.0032698, 0.9901813, 0.97734797, 0.9647645, 0.95242596, 0.94032794, 0.92846537, 0.91683334, 0.9054272, 0.8942421, 0.8832733, 0.87251693, 0.8619681, 0.8516228, 0.84147656, 0.8315253, 0.8217648, 0.8121911, 0.8028003, 0.7935883, 0.7845508, 0.7756847, 0.76698613, 0.7584513, 0.7500768, 0.74185914, 0.7337947, 0.72588044, 0.7181125, 0.7104877, 0.70300275, 0.6956547, 0.6884402, 0.68135583, 0.674399, 0.6675668, 0.6608563, 0.65426487, 0.6477898, 0.6414286, 0.63517886, 0.6290378, 0.623003, 0.6170719, 0.611242, 0.6055112, 0.5998771, 0.5943377, 0.5888909, 0.58353436, 0.5782659, 0.5730838]
Is there something wrong with my Julia code above?
If not, I am wondering whether there is a synchronization issue between the different MPI processes, or whether some uninitialized memory is being used.
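To quantify a possible synchronization problem, I was also going to measure how far the parameters drift apart between the ranks after each update. Again only a sketch: param_drift is my own helper; Optimisers.trainables collects the trainable arrays of the model.
# after each Optimisers.update: broadcast rank 0's parameters and report the
# largest absolute difference seen on this rank (it should stay 0 if DDP keeps
# the replicas in sync)
function param_drift(backend, model)
    drift = 0f0
    for p in Optimisers.trainables(model)
        local_p = Array(p)
        ref = copy(local_p)
        MPI.Bcast!(ref, 0, backend.comm)
        drift = max(drift, maximum(abs.(local_p .- ref)))
    end
    return drift
end

# pprintln(backend, "max parameter drift: ", param_drift(backend, model))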
My environment:
julia 1.11.2
⌃ [21141c5a] AMDGPU v1.2.2
[0a1fb500] BlockDiagonals v0.1.42
[052768ef] CUDA v5.6.1
⌃ [13f3f980] CairoMakie v0.12.18
⌃ [b0b7db55] ComponentArrays v0.15.22
[efc8151c] DIVAnd v2.7.12
[cf87cc76] DataAssim v0.4.1
[8bb1440f] DelimitedFiles v1.9.1
[4e2335b7] FlowMatching v0.1.0 `..`
[587475ba] Flux v0.16.3 `~/.julia/dev/Flux`
[db073c08] GeoMakie v0.7.10
[033835bb] JLD2 v0.5.11
[f1d291b0] MLUtils v0.4.7
[da04e1cc] MPI v0.20.22
[3da0fdf6] MPIPreferences v0.1.11
[85f8d34a] NCDatasets v0.14.6
[3bd65402] Optimisers v0.4.4
[21216c6a] Preferences v1.4.3
[10745b16] Statistics v1.11.1
⌃ [e88e6eb3] Zygote v0.7.3
[02a925ec] cuDNN v1.4.1
[ade2ca70] Dates v1.11.0
[de0858da] Printf v1.11.0
[8dfed614] Test v1.11.0
GPU: AMD INSTINCT MI200
MPI: ROCm-aware Cray MPI implementation (verified with the "ROCm-aware (AMDGPU) MPI multi-GPU test" from GitHub)
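For reference, MPI.jl is set up to use the system MPI with MPIPreferences, roughly as below (only indicative; the library is auto-detected from the loaded Cray modules):
using MPIPreferences
# select the system MPI (ROCm-aware Cray MPICH) instead of the JLL-provided one
MPIPreferences.use_system_binary()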