From ce069545d4b9ac32a094117de75919087a7bc21e Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Sat, 6 Jan 2018 10:05:28 +0100 Subject: [PATCH] Added CUDA interface to get temporary-buffer size for GEMM routine --- include/clblast_cuda.h | 12 ++++++++ scripts/generator/generator.py | 4 +-- src/clblast.cpp | 12 ++++---- src/clblast_cuda.cpp | 52 ++++++++++++++++++++++++++++++++++ test/routines/level3/xgemm.hpp | 16 +++++++---- 5 files changed, 83 insertions(+), 13 deletions(-) diff --git a/include/clblast_cuda.h b/include/clblast_cuda.h index 0f510981..e0d1d638 100644 --- a/include/clblast_cuda.h +++ b/include/clblast_cuda.h @@ -69,6 +69,7 @@ enum class StatusCode { kInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small // Custom additional status codes for CLBlast + kInsufficientMemoryTemp = -2050, // Temporary buffer provided to GEMM routine is too small kInvalidBatchCount = -2049, // The batch count needs to be positive kInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel kMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel @@ -620,6 +621,17 @@ StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const T // ================================================================================================= +// Retrieves the required size of the temporary buffer for the GEMM kernel (optional) +template +StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + const CUdevice device, size_t& temp_buffer_size); + +// ================================================================================================= + // CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on // for the same device. This cache can be cleared to free up system memory or in case of debugging. StatusCode PUBLIC_API ClearCache(); diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index e0c26140..5fbce2c4 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -46,8 +46,8 @@ FILES = [ "/include/clblast_cuda.h", "/src/clblast_cuda.cpp", ] -HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 94, 21] -FOOTER_LINES = [36, 56, 27, 38, 6, 6, 6, 9, 2, 25, 3] +HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21] +FOOTER_LINES = [36, 56, 27, 38, 6, 6, 6, 9, 2, 36, 55] HEADER_LINES_DOC = 0 FOOTER_LINES_DOC = 63 diff --git a/src/clblast.cpp b/src/clblast.cpp index 461cf31f..f5e2f1be 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2345,7 +2345,7 @@ StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const size_t a_offset, const size_t a_ld, const size_t b_offset, const size_t b_ld, const size_t c_offset, const size_t c_ld, - RawCommandQueue* queue, size_t& temp_buffer_size) { + cl_command_queue* queue, size_t& temp_buffer_size) { try { // Retrieves the tuning database @@ -2371,23 +2371,23 @@ StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, - const size_t, const size_t, RawCommandQueue*, size_t&); + const size_t, const size_t, cl_command_queue*, size_t&); template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, - const size_t, const size_t, RawCommandQueue*, size_t&); + const size_t, const size_t, cl_command_queue*, size_t&); template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, - const size_t, const size_t, RawCommandQueue*, size_t&); + const size_t, const size_t, cl_command_queue*, size_t&); template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, - const size_t, const size_t, RawCommandQueue*, size_t&); + const size_t, const size_t, cl_command_queue*, size_t&); template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, - const size_t, const size_t, RawCommandQueue*, size_t&); + const size_t, const size_t, cl_command_queue*, size_t&); // ================================================================================================= } // namespace clblast diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp index 187443eb..21514c74 100644 --- a/src/clblast_cuda.cpp +++ b/src/clblast_cuda.cpp @@ -2436,5 +2436,57 @@ template StatusCode PUBLIC_API GemmBatched(const Layout, const Transpose, const size_t, const CUcontext, const CUdevice); +// ================================================================================================= + +// Retrieves the required size of the temporary buffer for the GEMM kernel (optional) +template +StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose, + const size_t m, const size_t n, const size_t k, + const size_t a_offset, const size_t a_ld, + const size_t b_offset, const size_t b_ld, + const size_t c_offset, const size_t c_ld, + const CUdevice device, size_t& temp_buffer_size) { + try { + + // Retrieves the tuning database + const auto device_cpp = Device(device); + const auto kernel_names = std::vector{"Xgemm", "GemmRoutine"}; + Databases db(kernel_names); + Routine::InitDatabase(device_cpp, kernel_names, PrecisionValue(), {}, db); + + // Computes the buffer size + if (Xgemm::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) { + temp_buffer_size = 0; + } + else { + temp_buffer_size = Xgemm::GetTempSize(layout, a_transpose, b_transpose, m, n, k, + a_offset, a_ld, b_offset, b_ld, c_offset, c_ld, + db["MWG"], db["NWG"], db["KWG"]); + } + temp_buffer_size *= sizeof(T); // translate from num-elements to bytes + return StatusCode::kSuccess; + } catch (...) { return DispatchException(); } +} +template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, const CUdevice, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, const CUdevice, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, const CUdevice, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, const CUdevice, size_t&); +template StatusCode PUBLIC_API GemmTempBufferSize(const Layout, const Transpose, const Transpose, + const size_t, const size_t, const size_t, + const size_t, const size_t, const size_t, const size_t, + const size_t, const size_t, const CUdevice, size_t&); + // ================================================================================================= } // namespace clblast diff --git a/test/routines/level3/xgemm.hpp b/test/routines/level3/xgemm.hpp index a74cab2d..4cfa9c83 100644 --- a/test/routines/level3/xgemm.hpp +++ b/test/routines/level3/xgemm.hpp @@ -67,7 +67,6 @@ class TestXgemm { args.c_size = GetSizeC(args); // Optionally (V != 0) enforces indirect (V == 1) or direct (V == 2) kernels - auto queue_plain = queue(); if (V != 0) { const auto device = queue.GetDevice(); const auto switch_threshold = (V == 1) ? size_t{0} : size_t{4096}; // large enough for tests @@ -78,9 +77,16 @@ class TestXgemm { // Sets the size of the temporary buffer (optional argument to GEMM) auto temp_buffer_size = size_t{0}; - GemmTempBufferSize(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k, - args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld, - &queue_plain, temp_buffer_size); + #ifdef OPENCL_API + auto queue_plain = queue(); + GemmTempBufferSize(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k, + args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld, + &queue_plain, temp_buffer_size); + #elif CUDA_API + GemmTempBufferSize(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k, + args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld, + queue.GetDevice()(), temp_buffer_size); + #endif args.ap_size = (temp_buffer_size + sizeof(T)) / sizeof(T); // + sizeof(T) to prevent zero } @@ -117,7 +123,7 @@ class TestXgemm { buffers.a_mat(), args.a_offset, args.a_ld, buffers.b_mat(), args.b_offset, args.b_ld, args.beta, buffers.c_mat(), args.c_offset, args.c_ld, - queue.GetContext()(), queue.GetDevice()()); + queue.GetContext()(), queue.GetDevice()(), buffers.ap_mat()); // temp buffer cuStreamSynchronize(queue()); #endif return status;