Stochastic gradient descent

Hello, I am implementing a function `SGD` (stochastic gradient descent).
My implementation is below. I would like to parallelize the `for` loop over the mini-batches as efficiently as possible. Can anyone help?

"""
    SGD(nn::Network, training_data_x, training_data_y, epochs, batch_size, eta,
        lambda = 0.0; validation_data_x, validation_data_y,
        monitor_training_cost = true, monitor_validation_cost = true,
        monitor_training_accuracy = true, monitor_validation_accuracy = true)

Train `nn` with mini-batch stochastic gradient descent for `epochs` epochs and return it.

Each epoch draws a random permutation of the training samples, splits it into
consecutive mini-batches of at most `batch_size` samples and calls
`update!(nn, batch_x, batch_y, eta, lambda, n)` for every batch, where `n` is the total
number of training samples. After each epoch the enabled metrics are logged with `@info`
and appended to `nn.training_cost`, `nn.validation_cost`, `nn.training_accuracy` and
`nn.validation_accuracy`. All four are reset at the start of the call. Accuracies are
stored as `accuracy(nn, x, y) / length(x)`.

# Arguments
- `nn`: the network to train.
- `training_data_x`, `training_data_y`: training inputs and targets (equal length).
- `epochs`: number of passes over the training data.
- `batch_size`: maximum number of samples per mini-batch; must be positive.
- `eta`: forwarded to `update!` (presumably the learning rate).
- `lambda`: forwarded to `update!` and `total_cost` (presumably a regularization
  strength).

# Keywords
- `validation_data_x`, `validation_data_y`: optional validation set, empty by default.
  Validation cost/accuracy are only computed when `validation_data_x` is non-empty.
- `monitor_*`: enable computing and logging the corresponding metric after each epoch.

# Throws
- `ArgumentError` if `batch_size` is not positive.
- `DimensionMismatch` if the inputs and targets of a data set differ in length.
"""
function SGD(nn::Network,
        training_data_x::Vector{Vector{Float64}},
        training_data_y::Vector{Vector{Float64}},
        epochs::Int, batch_size::Int, eta::Float64,
        lambda::Float64 = 0.0;
        # Typed empty defaults: an untyped `[]` is a `Vector{Any}`, which fails the
        # keyword type assertion whenever the caller omits these arguments.
        validation_data_x::Vector{Vector{Float64}} = Vector{Float64}[],
        validation_data_y::Union{Vector{Int64}, Vector{Vector{Float64}}} = Int64[],
        monitor_training_cost = true,
        monitor_validation_cost = true,
        monitor_training_accuracy = true,
        monitor_validation_accuracy = true)
    batch_size > 0 ||
        throw(ArgumentError("batch_size must be positive, got $batch_size"))

    n_train = length(training_data_x)
    n_train == length(training_data_y) || throw(DimensionMismatch(
        "training data has $n_train inputs but $(length(training_data_y)) targets"))

    n_valid = length(validation_data_x)
    has_validation = n_valid > 0
    if has_validation && n_valid != length(validation_data_y)
        throw(DimensionMismatch(
            "validation data has $n_valid inputs but " *
            "$(length(validation_data_y)) targets"))
    end

    nn.training_cost = []
    nn.validation_cost = []
    nn.training_accuracy = []
    nn.validation_accuracy = []

    for epoch in 1:epochs
        # The batch loop is sequential on purpose. `update!` takes `nn` and, per the `!`
        # convention, presumably mutates it, so each batch starts from the state left by
        # the previous one. Running batches concurrently would race on `nn` and change
        # the algorithm. Parallelism belongs *inside* a batch, e.g. computing per-sample
        # gradients with `Threads.@threads` inside `update!` and summing them afterwards.
        perm = Random.randperm(n_train)
        for k in 1:batch_size:n_train
            batch = perm[k:min(k + batch_size - 1, n_train)]
            update!(nn, training_data_x[batch], training_data_y[batch],
                eta, lambda, n_train)
        end

        @info @Printf.sprintf("Epoch %d done", epoch)
        if monitor_training_cost
            push!(nn.training_cost,
                total_cost(nn, training_data_x, training_data_y, lambda))
            @info @Printf.sprintf("Cost on training data: %f",
                nn.training_cost[end])
        end
        # Skipped when there is no validation set: the metrics would be meaningless
        # (accuracy would be 0 / 0 = NaN).
        if monitor_validation_cost && has_validation
            push!(nn.validation_cost,
                total_cost(nn, validation_data_x, validation_data_y, lambda))
            @info @Printf.sprintf("Cost on validation data: %f",
                nn.validation_cost[end])
        end

        if monitor_training_accuracy
            correct = accuracy(nn, training_data_x, training_data_y)
            ratio = correct / n_train
            @info @Printf.sprintf(
                "Accuracy on training data: %5d / %5d = %5.1f%% correct",
                correct, n_train, 100 * ratio)
            push!(nn.training_accuracy, ratio)
        end

        if monitor_validation_accuracy && has_validation
            correct = accuracy(nn, validation_data_x, validation_data_y)
            ratio = correct / n_valid
            @info @Printf.sprintf(
                "Accuracy on validation data: %5d / %5d = %5.1f%% correct",
                correct, n_valid, 100 * ratio)
            push!(nn.validation_accuracy, ratio)
        end
    end
    return nn
end