Hi,
To get a faster GPU sort + sortperm for a specific use case, I successfully wrapped CUB’s SortPairs method for sorting 64-bit integer keys (sort) and 32-bit integer values (sortperm). However, when comparing the execution time of only the CUDA code against that of the full ccall, I notice a significant overhead, at least on my Windows machine. This overhead does not appear on my Linux machine. Curiously, the overhead is also not there on Windows when I’m ccalling a simple addition-kernel C wrapper.
Here is the code I’m using:
cub_wrapper.cu:
#ifdef _WIN32
#define EXPORT_SYMBOL __declspec(dllexport)
#else
#define EXPORT_SYMBOL __attribute__((visibility("default")))
#endif
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdint>
// Basically cub::DeviceRadixSort::SortPairs, instantiated for 64-bit integer keys, 32-bit integer values
void sort_pairs_internal(int64_t* d_keys_in, int64_t* d_keys_out, int32_t* d_values_in, int32_t* d_values_out, int32_t num_items)
{
    void* d_temp_storage = nullptr;
    size_t temp_storage_bytes = 0;
    // Determine temporary device storage requirements (this call performs no actual sorting)
    cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, d_values_in, d_values_out, num_items);
    // Allocate necessary temporary storage
    cudaMalloc(&d_temp_storage, temp_storage_bytes);
    // Run actual sorting operation
    cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, d_values_in, d_values_out, num_items);
    // Free our temporary storage
    cudaFree(d_temp_storage);
}

// C wrapper
extern "C" EXPORT_SYMBOL void sort_pairs(int64_t* d_keys_in, int64_t* d_keys_out, int32_t* d_values_in, int32_t* d_values_out, int32_t num_items) {
    sort_pairs_internal(d_keys_in, d_keys_out, d_values_in, d_values_out, num_items);
}

// C wrapper with timer
extern "C" EXPORT_SYMBOL float timed_sort_pairs(int64_t* d_keys_in, int64_t* d_keys_out, int32_t* d_values_in, int32_t* d_values_out, int32_t num_items) {
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start);
    sort_pairs_internal(d_keys_in, d_keys_out, d_values_in, d_values_out, num_items);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float milliseconds = 0;
    cudaEventElapsedTime(&milliseconds, start, stop);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return milliseconds;
}
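
// (Sketch, not part of the original wrapper: exposing the last CUDA error code would let
// the Julia side verify that cudaMalloc / SortPairs did not fail silently on either
// platform before comparing timings. The name last_cuda_error is made up here.)
extern "C" EXPORT_SYMBOL int32_t last_cuda_error() {
    // cudaGetLastError returns and clears the most recent error recorded on the calling host thread
    return static_cast<int32_t>(cudaGetLastError());
}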
// Create the shared library using
// nvcc -Xcompiler -fPIC -shared -arch=sm_86 -o cub_wrapper.so cub_wrapper.cu
// (The -Xcompiler -fPIC flag will be ignored by cl.exe on Windows.)
cub_ccall_test.jl:
using Statistics
using Test
using CUDA
const lib = "./cub_wrapper.so"
N::Int32 = 2^24
keys_in = CUDA.rand(Int64, N)
keys_out = similar(keys_in)
values_in = CuArray{Int32}(undef, N)
values_in .= 1:N
values_out = similar(values_in)
function sort_pairs!(keys_in::CuArray{Int64}, keys_out::CuArray{Int64}, values_in::CuArray{Int32}, values_out::CuArray{Int32}, N::Int32)
    ccall((:sort_pairs, lib), Cvoid,
        (CuRef{Int64}, CuRef{Int64}, CuRef{Int32}, CuRef{Int32}, Int32),
        keys_in, keys_out, values_in, values_out, N
    )
end

function sort_pairs_time!(keys_in::CuArray{Int64}, keys_out::CuArray{Int64}, values_in::CuArray{Int32}, values_out::CuArray{Int32}, N::Int32)
    ccall((:timed_sort_pairs, lib), Cfloat,
        (CuRef{Int64}, CuRef{Int64}, CuRef{Int32}, CuRef{Int32}, Int32),
        keys_in, keys_out, values_in, values_out, N
    )
end
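
# Sketch of an alternative binding (not used below): CUDA.jl also allows passing CuArrays
# to ccall via CuPtr argument types, which makes the device-pointer nature of the
# arguments explicit. The name sort_pairs_ptr! is hypothetical; the call is otherwise the same.
function sort_pairs_ptr!(keys_in::CuArray{Int64}, keys_out::CuArray{Int64}, values_in::CuArray{Int32}, values_out::CuArray{Int32}, N::Int32)
    ccall((:sort_pairs, lib), Cvoid,
        (CuPtr{Int64}, CuPtr{Int64}, CuPtr{Int32}, CuPtr{Int32}, Int32),
        keys_in, keys_out, values_in, values_out, N
    )
end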
CUDA.@sync sort_pairs!(keys_in, keys_out, values_in, values_out, N)
@test all(keys_out .== sort(keys_in))
@test all(values_out .== sortperm(keys_in))
function benchmark_sort_pairs_jltime(nb_runs, keys_in, keys_out, values_in, values_out, N)
    times = Array{Float64}(undef, nb_runs)
    for i = 1:nb_runs
        stats = @timed CUDA.@sync sort_pairs!(keys_in, keys_out, values_in, values_out, N)
        times[i] = stats.time
    end
    return median(times)
end

# Could also use BenchmarkTools here, but we use the same approach as for ..._cutime below, for consistency.
function benchmark_sort_pairs_cutime(nb_runs, keys_in, keys_out, values_in, values_out, N)
    times = Array{Float32}(undef, nb_runs)
    for i = 1:nb_runs
        times[i] = sort_pairs_time!(keys_in, keys_out, values_in, values_out, N)
    end
    return median(times)
end
nb_runs = 100
println("GPU add, time including ccall:")
println("\t$(benchmark_sort_pairs_jltime(nb_runs, keys_in, keys_out, values_in, values_out, N) * 1000) ms")
println()
println("GPU add, time excluding ccall:")
println("\t$(benchmark_sort_pairs_cutime(nb_runs, keys_in, keys_out, values_in, values_out, N)) ms")
with outputs:
Windows (RTX 3070, Windows 10 22H2, Julia 1.10.0+0.x64.w64.mingw32, nvcc 11.8, cl 19.39.33522):
sort_pairs, time including ccall:
        17.63445 ms
sort_pairs, time excluding ccall:
        11.05888 ms
Linux (RTX 3080 Ti, Ubuntu 22.04.4 LTS, Julia 1.10.0+0.x64.linux.gnu, nvcc 11.8, g++ 11.4.0):
sort_pairs, time including ccall:
        7.131116499999999 ms
sort_pairs, time excluding ccall:
        7.1138716 ms
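
As a side note, the host-side number could also be cross-checked with BenchmarkTools instead of the manual @timed loop; a minimal sketch, reusing the arrays and sort_pairs! defined in the script above:

using BenchmarkTools
# @belapsed reports the minimum elapsed time over many samples, in seconds;
# interpolating with $ avoids measuring global-variable access
t_min = @belapsed CUDA.@sync sort_pairs!($keys_in, $keys_out, $values_in, $values_out, $N)
println("\t$(t_min * 1000) ms")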
Does anyone know why there is this difference in ccall overhead between my Windows and Linux machines, and how I can get rid of the overhead on the former?