Save model when training with Distributed Data Parallel

Hey all,
I want to train my machine learning model in parallel on multiple GPUs using DDP from Flux.
I got the code to run by adapting what is written on the Flux GPU Support page, but when I try to save the model I get the error
unsafe_store! at ./pointer.jl:146 [inlined]
I execute the code with mpiexec.
Now the question is: How do you save the model to disk?
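For context, this is roughly the save/load pattern from the Flux docs that I am trying to use (a minimal sketch; `model` here stands for the trained UNet and the file name is just a placeholder):

using Flux, JLD2

model_cpu = cpu(model)
jldsave("model.jld2", model_state = Flux.state(model_cpu))

# and later, to restore:
model_state = JLD2.load("model.jld2", "model_state")
Flux.loadmodel!(model, model_state)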

Here is the main() of my code, reduced to the most essential things for better readability (not a MWE):

#the model is a standard UNET architecture that is evaluated sparsely, i.e. only where a MASK .== 1

loss(model, x, y, mask) = Flux.Losses.mse(model(x)[mask], y[mask])

function main(unet, arg)
    model = unet |> gpu
    model = DistributedUtils.synchronize!!(backend, DistributedUtils.FluxDistributedModel(model); root=0)
    opt = DistributedUtils.DistributedOptimizer(backend, Optimisers.Adam())
    st_opt = Optimisers.setup(opt, model)
    st_opt = DistributedUtils.synchronize!!(backend, st_opt; root=0)


    train_years = [2015, 2016, 2017]
    path = "/path/to/files"


    losses = Float32[]
    val_filelist = readdir(join([path, "2018"]), join = true)
    @info "start training"

  
    for epoch in 1:10
        for yyyy in train_years
            filelist = readdir(join([path, yyyy]), join = true)
            #bc the dataset is way larger than memory I have to iterate over individual files
            for f in filelist
                # read all three arrays and close the file again
                X, Y, MASK = jldopen(f) do file
                    file["x_norm"], file["y"], file["mask"]
                end
                # drop the last samples from each file so the last batch does not have fewer samples than the number of GPUs (6)
                nbatch = size(X)[end] ÷ 6
                dataset = (X[:, :, :, :, 1:6*nbatch], Y[:, :, :, :, 1:6*nbatch], MASK[:, :, :, :, 1:6*nbatch])
                dataset = DistributedUtils.DistributedDataContainer(backend, dataset)

                train_loader = MLUtils.DataLoader(dataset; batchsize=1, partial=false, collate=true, shuffle=true, parallel=true) |> gpu

                for (x, y, mask) in train_loader
                    l, grad = Flux.withgradient(model) do m
                        loss(m, x, y, mask)
                    end
                    st_opt, model = Optimisers.update!(st_opt, model, grad[1])
                    
                end
                GC.gc(false)
                CUDA.reclaim()

                ###### validation after each training file ######
                batch_loss = Float32[]
                for f_val in val_filelist
                    x_val, y_val, mask_val = load_testdata(f_val)
                    n_val = size(x_val)[end] ÷ 6
                    val_dataset = (x_val[:, :, :, :, 1:6*n_val], y_val[:, :, :, :, 1:6*n_val], mask_val[:, :, :, :, 1:6*n_val])
                    val_dataset = DistributedUtils.DistributedDataContainer(backend, val_dataset)

                    val_loader = MLUtils.DataLoader(val_dataset; batchsize = 1, partial = false, collate = true, shuffle = true, parallel = true) |> gpu
                    for (x, y, mask) in val_loader
                        l = loss(model, x, y, mask)
                        push!(batch_loss, l)
                    end
                end
                GC.gc(false)
                CUDA.reclaim()

                push!(losses, mean(batch_loss))
                @info mean(batch_loss)
            end

            model_cpu = cpu(model)
            losses_cpu = cpu(losses)
            st_opt_cpu = cpu(st_opt)

            jldsave("path/to/save/model.jld2", model_state = Flux.state(model_cpu), loss = losses_cpu)
        end
    end
end


I have also tried

MPI.Init()
comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)

function main(unet, arg)
...

    if rank == 0
        jldsave("trained_model.jld2", model_state = Flux.state(cpu(model)))
    end


    MPI.Barrier(comm)
    MPI.Finalize()
end

with the result that my model does not train anymore; the loss does not decrease at all. Ideally, I would like to be able to save the model after every epoch.
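
What I have in mind is roughly the following (just a sketch of the structure I am after, not working code; the file name is a placeholder, and I am not sure whether guarding the save with a rank check like this interferes with the DistributedOptimizer):

for epoch in 1:10
    # ... the full training/validation loop from above, executed on every rank ...

    if MPI.Comm_rank(comm) == 0
        # only the root rank writes the checkpoint to disk
        jldsave("trained_model_epoch_$(epoch).jld2",
                model_state = Flux.state(cpu(model)),
                loss = losses)
    end
    MPI.Barrier(comm)  # keep the other ranks from running ahead while rank 0 writes
end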