I am practicing making my real code run in parallel. My difficulty is the coordination between the worker threads and the main process. To make it concrete, I am posting the serial code func_series(), which solves a linear system Ax = b inside a time loop, updating A and b at every iteration. A is block-diagonal, so the solution can be partitioned and each block solved independently. I wrote func_parallel() with two flags, sharedFlag_runTDIter and sharedFlag_finishTDIterThread, to synchronize the threads with the main loop: at each time step the threads read the updated A and b (from the main thread/memory) and write their part of the solution back (to the main thread/memory). However, I am not getting the same results as the serial version, which must come either from the synchronization or from a misuse of remotecall_fetch. Could you please guide me?
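To spell out the handshake I am trying to implement, here is a stripped-down sketch of the per-time-step protocol I have in mind between the main loop and one worker thread. It only illustrates the idea (I use Threads.Atomic{Bool} flags and a single thread here instead of the SharedArray flags of the real code below):

using Base.Threads

go   = Atomic{Bool}(false)    # main -> thread: "A and b are updated, start solving"
done = Atomic{Bool}(false)    # thread -> main: "the block solution has been written back"

t = Threads.@spawn for step in 1:2
    while !go[]               # wait for the main loop to release this time step
        yield()
    end
    go[] = false              # consume the signal
    # ... solve this thread's block of x = A \ b here ...
    done[] = true             # tell the main loop that the result is ready
end

for step in 1:2
    go[] = true               # release the thread for this time step
    while !done[]             # wait until the thread has written its block of x
        yield()
    end
    done[] = false
    # ... update A and b, record x ...
end
wait(t)

The full serial and parallel versions are below.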
using LinearAlgebra, SparseArrays, Distributed
using Base.Threads
addprocs(3);
@everywhere using SharedArrays
Result = [];
function func_series()
    Nthreads = 3;
    tmin = 1;
    tmax = 2;
    timeSim = tmin:tmax;
    A = sparse([1.0 2 0 0 0 0; 3 1 0 0 0 0; 0 0 3 1 0 0; 0 0 2 4 0 0; 0 0 0 0 9 3; 0 0 0 0 2 4]);
    b = [1.0, 1, 1, 1, 1, 1];
    x = zeros(length(b));
    x_record = zeros(length(x), size(timeSim, 1));
    indexRange = [1:2, 3:4, 5:6];                    # one diagonal block of A per thread
    workvec = [similar(x, 2), similar(x, 2), similar(x, 2)];
    for i in tmin:tmax
        for j = 1:Nthreads
            x[indexRange[j]] .= A[indexRange[j], indexRange[j]] \ b[indexRange[j]];   # solve block j
        end
        A.nzval .+= i;                               # update A and b for the next time step
        b .+= i;
        x_record[:, i] .= x;                         # record the solution of this time step
    end
    return x_record
end
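The partitioned solve in the inner loop works because A is block-diagonal, so each diagonal block can be solved on its own. A minimal check of that property (with a made-up block-diagonal matrix, just for illustration):

using LinearAlgebra, SparseArrays

A_demo = sparse([2.0 1 0 0; 1 3 0 0; 0 0 4 1; 0 0 1 5]);
b_demo = [1.0, 2, 3, 4];
blocks = [1:2, 3:4];

x_full  = A_demo \ b_demo;                    # solve the whole system at once
x_block = zeros(length(b_demo));
for r in blocks
    x_block[r] .= A_demo[r, r] \ b_demo[r];   # solve each diagonal block independently
end
@assert x_block ≈ x_full                      # the partitioned solve matches the full solve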
function func_parallel()
    Nthreads = 3;
    tmin = 1;
    tmax = 2;
    timeSim = tmin:tmax;
    A = sparse([1.0 2 0 0 0 0; 3 1 0 0 0 0; 0 0 3 1 0 0; 0 0 2 1 0 0; 0 0 0 0 1 3; 0 0 0 0 2 1]);
    A_remotecall = remotecall(() -> A, 1);
    b = [1.0, 1, 1, 1, 1, 1];
    b_remotecall = remotecall(() -> b, 1);
    x = zeros(length(b));
    x_record = zeros(length(x), size(timeSim, 1));
    sharedFlag_runTDIter = SharedArray{Bool}(1);
    sharedFlag_runTDIter[1] = false;
    sharedFlag_finishTDIterThread = SharedArray{Bool}(Nthreads);
    sharedFlag_finishTDIterThread .= false;
    indexRange = [1:2, 3:4, 5:6];
    workvec = [similar(x, 2), similar(x, 2), similar(x, 2)];
    thrs = [Threads.@spawn begin
        for _ in tmin:tmax
            while sharedFlag_runTDIter[j] == false   # busy-wait until the main loop releases this time step
                nothing
            end
            remotecall_fetch(x[indexRange[j]] .= fetch(A_remotecall)[indexRange[j], indexRange[j]] \ fetch(b_remotecall)[indexRange[j]]);
            sharedFlag_finishTDIterThread[j] = true; # inform the main thread that this block's result is ready
            sharedFlag_runTDIter[j] = false;         # keep this thread from starting the next time step before A and b are updated
        end
    end for j = 1:Nthreads];
    for i in tmin:tmax
        sharedFlag_runTDIter .= true;                # let the threads start solving x = A \ b
        if reduce(&, sharedFlag_finishTDIterThread) == true   # wait until all threads have finished solving x = A \ b
            A.nzval .+= i;
            b .+= i;
        end
        x_record[:, i] .= x;
    end
    #fetch.(thrs);
    return x_record
end
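For completeness, my understanding of the Distributed API I am trying to use is that remotecall returns a Future for a function run on a given worker, and remotecall_fetch(f, id, args...) runs f(args...) on worker id and waits for the value; in func_parallel I pass an assignment expression instead of a function, which may be part of my problem:

using Distributed
addprocs(1)

fut = remotecall(() -> rand(3), 2)   # schedule the closure on worker 2, get a Future back
v1  = fetch(fut)                     # retrieve its value later

v2 = remotecall_fetch(+, 2, 1, 2)    # run (+)(1, 2) on worker 2 and wait for the result (3)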
Result = func_series()
6×2 Matrix{Float64}:
0.2 0.25
0.4 0.5
0.3 0.428571
0.1 0.142857
0.0333333 0.0526316
0.233333 0.368421
Result = func_parallel()
6×2 Matrix{Float64}:
0.2 0.2
0.4 0.4
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0