Hello!
I'm trying to do some parallel computation over the first two dimensions of a 3D array `A[j, k, t]`. To put things in context, the data are time series, where `j` and `k` are spatial grid points and `t` is the time dimension. Using `permutedims` is not an option here (climate datasets can be huge).
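Some background on why `permutedims` is costly: Julia arrays are column-major, so each time series `A[j, k, :]` is strided in memory rather than contiguous, and `permutedims` has to allocate a full copy of the array to change that. A small illustration with toy sizes (not the real dataset):

```julia
# Julia arrays are column-major: the first index varies fastest in memory.
A = zeros(3, 4, 5)                 # toy stand-in for A[j, k, t]

strides(A)                         # (1, 3, 12): stepping t jumps 3*4 = 12 elements

ts = view(A, 1, 1, :)              # one time series, no copy
strides(ts)                        # (12,): its elements are 12 apart in memory

# permutedims allocates a brand-new array with the same number of elements,
# which is what makes it impractical for very large climate datasets.
B = permutedims(A, (3, 1, 2))      # size (5, 3, 4), time is now the fastest index
```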
My objective is to compute the same function at each grid point, using the whole time series at a given grid point `(j, k)`. In other words, I want something like `eachindex`, but only over the spatial dimensions.
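For comparison, a serial loop over only the spatial grid can be written with `CartesianIndices` (the Julia 1.x name; `CartesianRange`, used further down, is the Julia 0.6 name). This is a minimal sketch; `process_point` is a hypothetical stand-in for the real per-point function:

```julia
# Hypothetical stand-in for the real per-point computation.
process_point(ts) = ts .+ 1.0

function serial_grid(A::AbstractArray{<:Real,3})
    dataout = similar(A, Float64)
    # Iterate over the first two (spatial) dimensions only
    for I in CartesianIndices((size(A, 1), size(A, 2)))
        ts = @view A[I, :]              # whole time series at grid point I, no copy
        dataout[I, :] = process_point(ts)
    end
    return dataout
end

A = randn(3, 3, 100)
out = serial_grid(A)
```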
My first attempt was to simply put `Threads.@threads` in front of one of the "spatial" loops. Here's an MWE:
```julia
using BenchmarkTools

function simplethread()
    A = randn(3, 3, 10000)
    dataout = fill(NaN, 3, 3, 10000)
    for k = 1:size(A, 2)
        Threads.@threads for j = 1:size(A, 1)
            val = somefunction(A[j,k,:])
            dataout[j,k,:] = val
        end
    end
    return dataout
end

function somefunction(datain) # (in the real case, I do `Polynomials.polyval` over the time series using a previous `polyfit`).
    dataout = datain .+ randn(1)
end

@benchmark simplethread()
```
```
BenchmarkTools.Trial:
  memory estimate:  2.76 MiB
  allocs estimate:  442
  --------------
  minimum time:     1.414 ms (0.00% GC)
  median time:      1.576 ms (0.00% GC)
  mean time:        2.046 ms (16.58% GC)
  maximum time:     231.659 ms (98.22% GC)
  --------------
  samples:          2438
  evals/sample:     1
```
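The comment in `somefunction` says the real per-point work is a polynomial fit (`polyfit`) evaluated over the time series (`polyval`). As a rough idea of that per-point workload, here is a sketch that swaps Polynomials.jl for a plain least-squares fit via a Vandermonde matrix and `\`; `polytrend` and its `deg` keyword are hypothetical names, not the Polynomials.jl API:

```julia
using LinearAlgebra

# Hypothetical stand-in for the real per-point work: least-squares fit of a
# degree-`deg` polynomial to one time series, then evaluate it at every t.
function polytrend(ts::AbstractVector{<:Real}; deg::Integer = 1)
    t = collect(1.0:length(ts))                 # time axis 1, 2, ..., n
    V = [ti^p for ti in t, p in 0:deg]          # Vandermonde matrix, n × (deg+1)
    coeffs = V \ collect(float.(ts))            # least-squares coefficients
    return V * coeffs                           # fitted values at each time step
end

ts = 2.0 .+ 0.5 .* (1:10)   # an exactly linear series
fitted = polytrend(ts)
```

Because the test series is exactly linear, a degree-1 fit should reproduce it up to rounding.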
Now I want to actually parallelize over every grid point, not just over one of the spatial dimensions. I tried using `CartesianRange`, but it does not work. Here's my attempt:
```julia
function cartesianthread()
    A = randn(3, 3, 10000)
    dataout = fill(NaN, 3, 3, 10000)
    R = CartesianRange(Base.front(indices(A)))
    Threads.@threads for r in R
        val = somefunction(A[r,:])
        dataout[r,:] = val
    end
    return dataout
end

cartesianthread()
```
```
ERROR: MethodError: no method matching unsafe_getindex(::CartesianRange{CartesianIndex{2}}, ::Int64)
Closest candidates are:
  unsafe_getindex(::StepRangeLen{T,#s45,#s44} where #s44<:Base.TwicePrecision where #s45<:Base.TwicePrecision, ::Integer) where T at twiceprecision.jl:193
  unsafe_getindex(::StepRangeLen{T,R,S} where S where R, ::Integer) where T at range.jl:505
  unsafe_getindex(::LinSpace, ::Integer) at range.jl:510
  ...
Stacktrace:
 [1] (::##342#threadsfor_fun#26{CartesianRange{CartesianIndex{2}},Array{Float64,3},Array{Float64,3}})(::Bool) at ./threadingconstructs.jl:63
 [2] macro expansion at ./threadingconstructs.jl:71 [inlined]
 [3] cartesianthread() at ./REPL[120]:7
```
Without the `Threads.@threads` macro, `cartesianthread()` returns the correct values.
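Judging from the stack trace, `Threads.@threads` splits its loop by integer position (the failing call is `unsafe_getindex(::CartesianRange, ::Int64)`), which `CartesianRange` in this Julia version does not support. One possible workaround, sketched here for Julia 1.x under that assumption, is to thread over the linear positions `1:length(R)` and convert each one back to a `CartesianIndex`; `perpoint` is a hypothetical stand-in for `somefunction`:

```julia
using Base.Threads

# Hypothetical per-point function (stand-in for `somefunction`).
perpoint(ts) = ts .+ 1.0

function cartesianthread_linear(A::AbstractArray{<:Real,3})
    dataout = fill(NaN, size(A))
    R = CartesianIndices((size(A, 1), size(A, 2)))   # spatial grid only
    # Thread over integer positions 1:length(R), which @threads can split,
    # and map each position back to its (j, k) grid point.
    @threads for n in 1:length(R)
        I = R[n]                                     # CartesianIndex(j, k)
        dataout[I, :] = perpoint(@view A[I, :])
    end
    return dataout
end

A = randn(3, 3, 1000)
out = cartesianthread_linear(A)
```

Each iteration writes only to its own slice `dataout[j, k, :]`, so the threads never touch the same memory and no locking is needed.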
Thanks for any help or hints about what can be done!