Added CUDA interface to get temporary-buffer size for GEMM routine

pull/238/head
Cedric Nugteren 2018-01-06 10:05:28 +01:00
parent 44431daecc
commit ce069545d4
5 changed files with 83 additions and 13 deletions

View File

@ -69,6 +69,7 @@ enum class StatusCode {
kInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small
// Custom additional status codes for CLBlast
kInsufficientMemoryTemp = -2050, // Temporary buffer provided to GEMM routine is too small
kInvalidBatchCount = -2049, // The batch count needs to be positive
kInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel
kMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel
@ -620,6 +621,17 @@ StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const T
// =================================================================================================
// Retrieves the required size of the temporary buffer for the GEMM kernel (optional).
// CUDA variant of the query: the target is identified by a CUdevice handle.
//   layout, a_transpose, b_transpose: data layout and transpose options of the GEMM problem
//   m, n, k:                          dimensions of the GEMM problem
//   *_offset, *_ld:                   offsets and leading dimensions of the A, B and C matrices
//   device:                           the CUDA device to query the temporary-buffer size for
//   temp_buffer_size:                 output argument receiving the required buffer size
// Returns a StatusCode describing whether the query succeeded.
template <typename T>
StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
const size_t a_offset, const size_t a_ld,
const size_t b_offset, const size_t b_ld,
const size_t c_offset, const size_t c_ld,
const CUdevice device, size_t& temp_buffer_size);
// =================================================================================================
// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on
// for the same device. This cache can be cleared to free up system memory or in case of debugging.
// Takes no arguments; the outcome is reported through the returned StatusCode.
StatusCode PUBLIC_API ClearCache();

View File

@ -46,8 +46,8 @@ FILES = [
"/include/clblast_cuda.h",
"/src/clblast_cuda.cpp",
]
HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 94, 21]
FOOTER_LINES = [36, 56, 27, 38, 6, 6, 6, 9, 2, 25, 3]
HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21]
FOOTER_LINES = [36, 56, 27, 38, 6, 6, 6, 9, 2, 36, 55]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 63

View File

@ -2345,7 +2345,7 @@ StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose,
const size_t a_offset, const size_t a_ld,
const size_t b_offset, const size_t b_ld,
const size_t c_offset, const size_t c_ld,
RawCommandQueue* queue, size_t& temp_buffer_size) {
cl_command_queue* queue, size_t& temp_buffer_size) {
try {
// Retrieves the tuning database
@ -2371,23 +2371,23 @@ StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose,
template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
const size_t, const size_t, RawCommandQueue*, size_t&);
const size_t, const size_t, cl_command_queue*, size_t&);
template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
const size_t, const size_t, RawCommandQueue*, size_t&);
const size_t, const size_t, cl_command_queue*, size_t&);
template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
const size_t, const size_t, RawCommandQueue*, size_t&);
const size_t, const size_t, cl_command_queue*, size_t&);
template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
const size_t, const size_t, RawCommandQueue*, size_t&);
const size_t, const size_t, cl_command_queue*, size_t&);
template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
const size_t, const size_t, RawCommandQueue*, size_t&);
const size_t, const size_t, cl_command_queue*, size_t&);
// =================================================================================================
} // namespace clblast

View File

@ -2436,5 +2436,57 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
const size_t,
const CUcontext, const CUdevice);
// =================================================================================================
// Queries how large the (optional) temporary buffer for the GEMM kernel has to be. The answer is
// written to 'temp_buffer_size' in bytes. Anything thrown along the way is handed over to
// DispatchException(), whose result is then returned instead of kSuccess.
template <typename T>
StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
                              const size_t m, const size_t n, const size_t k,
                              const size_t a_offset, const size_t a_ld,
                              const size_t b_offset, const size_t b_ld,
                              const size_t c_offset, const size_t c_ld,
                              const CUdevice device, size_t& temp_buffer_size) {
  try {

    // Sets up the tuning parameters of the GEMM kernels for the requested device
    const auto target_device = Device(device);
    const auto routine_kernels = std::vector<std::string>{"Xgemm", "GemmRoutine"};
    Databases parameters(routine_kernels);
    Routine::InitDatabase(target_device, routine_kernels, PrecisionValue<T>(), {}, parameters);

    // A zero element count is reported when the direct kernel is selected; otherwise the
    // number of elements is requested from the GEMM routine itself
    const auto is_direct = Xgemm<T>::UseDirectKernel(m, n, k, parameters["XGEMM_MIN_INDIRECT_SIZE"]);
    size_t num_elements = 0;
    if (!is_direct) {
      num_elements = Xgemm<T>::GetTempSize(layout, a_transpose, b_transpose, m, n, k,
                                           a_offset, a_ld, b_offset, b_ld, c_offset, c_ld,
                                           parameters["MWG"], parameters["NWG"], parameters["KWG"]);
    }

    // Converts the element count into a byte count before handing it back to the caller
    temp_buffer_size = num_elements * sizeof(T);
    return StatusCode::kSuccess;
  }
  catch (...) {
    return DispatchException();
  }
}
// Explicit instantiations of the CUDA GemmTempBufferSize query, one per element type:
// float, double, float2, double2 and half. Each is marked PUBLIC_API.
template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
const size_t, const size_t, const CUdevice, size_t&);
template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
const size_t, const size_t, const CUdevice, size_t&);
template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
const size_t, const size_t, const CUdevice, size_t&);
template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
const size_t, const size_t, const CUdevice, size_t&);
template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const size_t, const size_t, const size_t, const size_t,
const size_t, const size_t, const CUdevice, size_t&);
// =================================================================================================
} // namespace clblast

View File

@ -67,7 +67,6 @@ class TestXgemm {
args.c_size = GetSizeC(args);
// Optionally (V != 0) enforces indirect (V == 1) or direct (V == 2) kernels
auto queue_plain = queue();
if (V != 0) {
const auto device = queue.GetDevice();
const auto switch_threshold = (V == 1) ? size_t{0} : size_t{4096}; // large enough for tests
@ -78,9 +77,16 @@ class TestXgemm {
// Sets the size of the temporary buffer (optional argument to GEMM)
auto temp_buffer_size = size_t{0};
GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
&queue_plain, temp_buffer_size);
#ifdef OPENCL_API
auto queue_plain = queue();
GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
&queue_plain, temp_buffer_size);
#elif CUDA_API
GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
queue.GetDevice()(), temp_buffer_size);
#endif
args.ap_size = (temp_buffer_size + sizeof(T)) / sizeof(T); // + sizeof(T) to prevent zero
}
@ -117,7 +123,7 @@ class TestXgemm {
buffers.a_mat(), args.a_offset, args.a_ld,
buffers.b_mat(), args.b_offset, args.b_ld, args.beta,
buffers.c_mat(), args.c_offset, args.c_ld,
queue.GetContext()(), queue.GetDevice()());
queue.GetContext()(), queue.GetDevice()(), buffers.ap_mat()); // temp buffer
cuStreamSynchronize(queue());
#endif
return status;