Added CUDA interface to get temporary-buffer size for GEMM routine
parent
44431daecc
commit
ce069545d4
|
@ -69,6 +69,7 @@ enum class StatusCode {
|
|||
kInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small
|
||||
|
||||
// Custom additional status codes for CLBlast
|
||||
kInsufficientMemoryTemp = -2050, // Temporary buffer provided to GEMM routine is too small
|
||||
kInvalidBatchCount = -2049, // The batch count needs to be positive
|
||||
kInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel
|
||||
kMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel
|
||||
|
@ -620,6 +621,17 @@ StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const T
|
|||
|
||||
// =================================================================================================
|
||||
|
||||
// Retrieves the required size of the temporary buffer for the GEMM kernel (optional). The size is
// written to 'temp_buffer_size'; the CUDA implementation reports it in bytes. This variant takes
// the CUDA device for which the size is queried.
|
||||
template <typename T>
|
||||
StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
|
||||
const size_t m, const size_t n, const size_t k,
|
||||
const size_t a_offset, const size_t a_ld,
|
||||
const size_t b_offset, const size_t b_ld,
|
||||
const size_t c_offset, const size_t c_ld,
|
||||
const CUdevice device, size_t& temp_buffer_size);
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on
|
||||
// for the same device. This cache can be cleared to free up system memory or in case of debugging.
|
||||
StatusCode PUBLIC_API ClearCache();
|
||||
|
|
|
@ -46,8 +46,8 @@ FILES = [
|
|||
"/include/clblast_cuda.h",
|
||||
"/src/clblast_cuda.cpp",
|
||||
]
|
||||
HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 94, 21]
|
||||
FOOTER_LINES = [36, 56, 27, 38, 6, 6, 6, 9, 2, 25, 3]
|
||||
HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21]
|
||||
FOOTER_LINES = [36, 56, 27, 38, 6, 6, 6, 9, 2, 36, 55]
|
||||
HEADER_LINES_DOC = 0
|
||||
FOOTER_LINES_DOC = 63
|
||||
|
||||
|
|
|
@ -2345,7 +2345,7 @@ StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose,
|
|||
const size_t a_offset, const size_t a_ld,
|
||||
const size_t b_offset, const size_t b_ld,
|
||||
const size_t c_offset, const size_t c_ld,
|
||||
RawCommandQueue* queue, size_t& temp_buffer_size) {
|
||||
cl_command_queue* queue, size_t& temp_buffer_size) {
|
||||
try {
|
||||
|
||||
// Retrieves the tuning database
|
||||
|
@ -2371,23 +2371,23 @@ StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose,
|
|||
template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
|
||||
const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, RawCommandQueue*, size_t&);
|
||||
const size_t, const size_t, cl_command_queue*, size_t&);
|
||||
template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
|
||||
const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, RawCommandQueue*, size_t&);
|
||||
const size_t, const size_t, cl_command_queue*, size_t&);
|
||||
template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
|
||||
const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, RawCommandQueue*, size_t&);
|
||||
const size_t, const size_t, cl_command_queue*, size_t&);
|
||||
template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
|
||||
const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, RawCommandQueue*, size_t&);
|
||||
const size_t, const size_t, cl_command_queue*, size_t&);
|
||||
template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
|
||||
const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, RawCommandQueue*, size_t&);
|
||||
const size_t, const size_t, cl_command_queue*, size_t&);
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
||||
|
|
|
@ -2436,5 +2436,57 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
|
|||
const size_t,
|
||||
const CUcontext, const CUdevice);
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// Retrieves the required size of the temporary buffer for the GEMM kernel (optional)
|
||||
template <typename T>
|
||||
StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
|
||||
const size_t m, const size_t n, const size_t k,
|
||||
const size_t a_offset, const size_t a_ld,
|
||||
const size_t b_offset, const size_t b_ld,
|
||||
const size_t c_offset, const size_t c_ld,
|
||||
const CUdevice device, size_t& temp_buffer_size) {
|
||||
try {
|
||||
|
||||
// Retrieves the tuning database
|
||||
const auto device_cpp = Device(device);
|
||||
const auto kernel_names = std::vector<std::string>{"Xgemm", "GemmRoutine"};
|
||||
Databases db(kernel_names);
|
||||
Routine::InitDatabase(device_cpp, kernel_names, PrecisionValue<T>(), {}, db);
|
||||
|
||||
// Computes the buffer size
|
||||
if (Xgemm<T>::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) {
|
||||
temp_buffer_size = 0;
|
||||
}
|
||||
else {
|
||||
temp_buffer_size = Xgemm<T>::GetTempSize(layout, a_transpose, b_transpose, m, n, k,
|
||||
a_offset, a_ld, b_offset, b_ld, c_offset, c_ld,
|
||||
db["MWG"], db["NWG"], db["KWG"]);
|
||||
}
|
||||
temp_buffer_size *= sizeof(T); // translate from num-elements to bytes
|
||||
return StatusCode::kSuccess;
|
||||
} catch (...) { return DispatchException(); }
|
||||
}
|
||||
// Explicit instantiations of the CUDA-device variant of GemmTempBufferSize, one for each of the
// element types float, double, float2, double2 and half.
template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
|
||||
const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const CUdevice, size_t&);
|
||||
template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
|
||||
const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const CUdevice, size_t&);
|
||||
template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
|
||||
const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const CUdevice, size_t&);
|
||||
template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
|
||||
const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const CUdevice, size_t&);
|
||||
template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
|
||||
const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const size_t, const size_t,
|
||||
const size_t, const size_t, const CUdevice, size_t&);
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
||||
|
|
|
@ -67,7 +67,6 @@ class TestXgemm {
|
|||
args.c_size = GetSizeC(args);
|
||||
|
||||
// Optionally (V != 0) enforces indirect (V == 1) or direct (V == 2) kernels
|
||||
auto queue_plain = queue();
|
||||
if (V != 0) {
|
||||
const auto device = queue.GetDevice();
|
||||
const auto switch_threshold = (V == 1) ? size_t{0} : size_t{4096}; // large enough for tests
|
||||
|
@ -78,9 +77,16 @@ class TestXgemm {
|
|||
|
||||
// Sets the size of the temporary buffer (optional argument to GEMM)
|
||||
auto temp_buffer_size = size_t{0};
|
||||
GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
|
||||
args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
|
||||
&queue_plain, temp_buffer_size);
|
||||
#ifdef OPENCL_API
|
||||
auto queue_plain = queue();
|
||||
GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
|
||||
args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
|
||||
&queue_plain, temp_buffer_size);
|
||||
#elif CUDA_API
|
||||
GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
|
||||
args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
|
||||
queue.GetDevice()(), temp_buffer_size);
|
||||
#endif
|
||||
args.ap_size = (temp_buffer_size + sizeof(T)) / sizeof(T); // + sizeof(T) to prevent zero
|
||||
}
|
||||
|
||||
|
@ -117,7 +123,7 @@ class TestXgemm {
|
|||
buffers.a_mat(), args.a_offset, args.a_ld,
|
||||
buffers.b_mat(), args.b_offset, args.b_ld, args.beta,
|
||||
buffers.c_mat(), args.c_offset, args.c_ld,
|
||||
queue.GetContext()(), queue.GetDevice()());
|
||||
queue.GetContext()(), queue.GetDevice()(), buffers.ap_mat()); // temp buffer
|
||||
cuStreamSynchronize(queue());
|
||||
#endif
|
||||
return status;
|
||||
|
|
Loading…
Reference in New Issue