Added CUDA interface to get temporary-buffer size for GEMM routine
parent
44431daecc
commit
ce069545d4
|
@ -69,6 +69,7 @@ enum class StatusCode {
|
||||||
kInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small
|
kInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small
|
||||||
|
|
||||||
// Custom additional status codes for CLBlast
|
// Custom additional status codes for CLBlast
|
||||||
|
kInsufficientMemoryTemp = -2050, // Temporary buffer provided to GEMM routine is too small
|
||||||
kInvalidBatchCount = -2049, // The batch count needs to be positive
|
kInvalidBatchCount = -2049, // The batch count needs to be positive
|
||||||
kInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel
|
kInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel
|
||||||
kMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel
|
kMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel
|
||||||
|
@ -620,6 +621,17 @@ StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const T
|
||||||
|
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
|
|
||||||
|
// Retrieves the required size of the temporary buffer for the GEMM kernel (optional).
// CUDA variant: takes the target CUdevice. The size is written to 'temp_buffer_size' in bytes and
// is zero when the direct GEMM kernel is selected. Returns a StatusCode describing the outcome.
|
||||||
|
template <typename T>
|
||||||
|
StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
|
||||||
|
const size_t m, const size_t n, const size_t k,
|
||||||
|
const size_t a_offset, const size_t a_ld,
|
||||||
|
const size_t b_offset, const size_t b_ld,
|
||||||
|
const size_t c_offset, const size_t c_ld,
|
||||||
|
const CUdevice device, size_t& temp_buffer_size);
|
||||||
|
|
||||||
|
// =================================================================================================
|
||||||
|
|
||||||
// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on
|
// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on
|
||||||
// for the same device. This cache can be cleared to free up system memory or in case of debugging.
|
// for the same device. This cache can be cleared to free up system memory or in case of debugging.
|
||||||
StatusCode PUBLIC_API ClearCache();
|
StatusCode PUBLIC_API ClearCache();
|
||||||
|
|
|
@ -46,8 +46,8 @@ FILES = [
|
||||||
"/include/clblast_cuda.h",
|
"/include/clblast_cuda.h",
|
||||||
"/src/clblast_cuda.cpp",
|
"/src/clblast_cuda.cpp",
|
||||||
]
|
]
|
||||||
HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 94, 21]
|
HEADER_LINES = [123, 21, 126, 24, 29, 41, 29, 65, 32, 95, 21]
|
||||||
FOOTER_LINES = [36, 56, 27, 38, 6, 6, 6, 9, 2, 25, 3]
|
FOOTER_LINES = [36, 56, 27, 38, 6, 6, 6, 9, 2, 36, 55]
|
||||||
HEADER_LINES_DOC = 0
|
HEADER_LINES_DOC = 0
|
||||||
FOOTER_LINES_DOC = 63
|
FOOTER_LINES_DOC = 63
|
||||||
|
|
||||||
|
|
|
@ -2345,7 +2345,7 @@ StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose,
|
||||||
const size_t a_offset, const size_t a_ld,
|
const size_t a_offset, const size_t a_ld,
|
||||||
const size_t b_offset, const size_t b_ld,
|
const size_t b_offset, const size_t b_ld,
|
||||||
const size_t c_offset, const size_t c_ld,
|
const size_t c_offset, const size_t c_ld,
|
||||||
RawCommandQueue* queue, size_t& temp_buffer_size) {
|
cl_command_queue* queue, size_t& temp_buffer_size) {
|
||||||
try {
|
try {
|
||||||
|
|
||||||
// Retrieves the tuning database
|
// Retrieves the tuning database
|
||||||
|
@ -2371,23 +2371,23 @@ StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose,
|
||||||
template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
|
template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
|
||||||
const size_t, const size_t, const size_t,
|
const size_t, const size_t, const size_t,
|
||||||
const size_t, const size_t, const size_t, const size_t,
|
const size_t, const size_t, const size_t, const size_t,
|
||||||
const size_t, const size_t, RawCommandQueue*, size_t&);
|
const size_t, const size_t, cl_command_queue*, size_t&);
|
||||||
template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
|
template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
|
||||||
const size_t, const size_t, const size_t,
|
const size_t, const size_t, const size_t,
|
||||||
const size_t, const size_t, const size_t, const size_t,
|
const size_t, const size_t, const size_t, const size_t,
|
||||||
const size_t, const size_t, RawCommandQueue*, size_t&);
|
const size_t, const size_t, cl_command_queue*, size_t&);
|
||||||
template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
|
template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
|
||||||
const size_t, const size_t, const size_t,
|
const size_t, const size_t, const size_t,
|
||||||
const size_t, const size_t, const size_t, const size_t,
|
const size_t, const size_t, const size_t, const size_t,
|
||||||
const size_t, const size_t, RawCommandQueue*, size_t&);
|
const size_t, const size_t, cl_command_queue*, size_t&);
|
||||||
template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
|
template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
|
||||||
const size_t, const size_t, const size_t,
|
const size_t, const size_t, const size_t,
|
||||||
const size_t, const size_t, const size_t, const size_t,
|
const size_t, const size_t, const size_t, const size_t,
|
||||||
const size_t, const size_t, RawCommandQueue*, size_t&);
|
const size_t, const size_t, cl_command_queue*, size_t&);
|
||||||
template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
|
template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
|
||||||
const size_t, const size_t, const size_t,
|
const size_t, const size_t, const size_t,
|
||||||
const size_t, const size_t, const size_t, const size_t,
|
const size_t, const size_t, const size_t, const size_t,
|
||||||
const size_t, const size_t, RawCommandQueue*, size_t&);
|
const size_t, const size_t, cl_command_queue*, size_t&);
|
||||||
|
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
} // namespace clblast
|
} // namespace clblast
|
||||||
|
|
|
@ -2436,5 +2436,57 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
|
||||||
const size_t,
|
const size_t,
|
||||||
const CUcontext, const CUdevice);
|
const CUcontext, const CUdevice);
|
||||||
|
|
||||||
|
// =================================================================================================
|
||||||
|
|
||||||
|
// Retrieves the required size of the temporary buffer for the GEMM kernel (optional).
// Looks up the tuning parameters for 'device' and writes the result, in bytes, to
// 'temp_buffer_size'. The size is zero when the direct GEMM kernel is selected.
// Returns StatusCode::kSuccess, or the status produced by DispatchException() if anything throws.
|
||||||
|
template <typename T>
|
||||||
|
StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
|
||||||
|
const size_t m, const size_t n, const size_t k,
|
||||||
|
const size_t a_offset, const size_t a_ld,
|
||||||
|
const size_t b_offset, const size_t b_ld,
|
||||||
|
const size_t c_offset, const size_t c_ld,
|
||||||
|
const CUdevice device, size_t& temp_buffer_size) {
|
||||||
|
try {
|
||||||
|
|
||||||
|
// Retrieves the tuning database: the raw CUdevice is wrapped in a Device object and the
|
// database is initialised for the "Xgemm" and "GemmRoutine" kernels using PrecisionValue<T>().
|
// NOTE(review): the meaning of the empty '{}' argument to InitDatabase is not visible here.
|
||||||
|
const auto device_cpp = Device(device);
|
||||||
|
const auto kernel_names = std::vector<std::string>{"Xgemm", "GemmRoutine"};
|
||||||
|
Databases db(kernel_names);
|
||||||
|
Routine::InitDatabase(device_cpp, kernel_names, PrecisionValue<T>(), {}, db);
|
||||||
|
|
||||||
|
// Computes the buffer size: zero if UseDirectKernel() reports that the direct kernel applies
|
// for this m/n/k and the XGEMM_MIN_INDIRECT_SIZE database value; otherwise the element count
|
// returned by GetTempSize() for the given layout, offsets, leading dimensions and the
|
// MWG/NWG/KWG database values.
|
||||||
|
if (Xgemm<T>::UseDirectKernel(m, n, k, db["XGEMM_MIN_INDIRECT_SIZE"])) {
|
||||||
|
temp_buffer_size = 0;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
temp_buffer_size = Xgemm<T>::GetTempSize(layout, a_transpose, b_transpose, m, n, k,
|
||||||
|
a_offset, a_ld, b_offset, b_ld, c_offset, c_ld,
|
||||||
|
db["MWG"], db["NWG"], db["KWG"]);
|
||||||
|
}
|
||||||
|
temp_buffer_size *= sizeof(T); // translate from num-elements to bytes
|
||||||
|
return StatusCode::kSuccess;
|
||||||
|
// Any exception thrown above is turned into a returned status by DispatchException()
|
} catch (...) { return DispatchException(); }
|
||||||
|
}
|
||||||
|
// Explicit instantiations of the CUdevice-based GemmTempBufferSize for the element types
// float, double, float2, double2 and half, exported through PUBLIC_API.
template StatusCode PUBLIC_API GemmTempBufferSize<float>(const Layout, const Transpose, const Transpose,
|
||||||
|
const size_t, const size_t, const size_t,
|
||||||
|
const size_t, const size_t, const size_t, const size_t,
|
||||||
|
const size_t, const size_t, const CUdevice, size_t&);
|
||||||
|
template StatusCode PUBLIC_API GemmTempBufferSize<double>(const Layout, const Transpose, const Transpose,
|
||||||
|
const size_t, const size_t, const size_t,
|
||||||
|
const size_t, const size_t, const size_t, const size_t,
|
||||||
|
const size_t, const size_t, const CUdevice, size_t&);
|
||||||
|
template StatusCode PUBLIC_API GemmTempBufferSize<float2>(const Layout, const Transpose, const Transpose,
|
||||||
|
const size_t, const size_t, const size_t,
|
||||||
|
const size_t, const size_t, const size_t, const size_t,
|
||||||
|
const size_t, const size_t, const CUdevice, size_t&);
|
||||||
|
template StatusCode PUBLIC_API GemmTempBufferSize<double2>(const Layout, const Transpose, const Transpose,
|
||||||
|
const size_t, const size_t, const size_t,
|
||||||
|
const size_t, const size_t, const size_t, const size_t,
|
||||||
|
const size_t, const size_t, const CUdevice, size_t&);
|
||||||
|
template StatusCode PUBLIC_API GemmTempBufferSize<half>(const Layout, const Transpose, const Transpose,
|
||||||
|
const size_t, const size_t, const size_t,
|
||||||
|
const size_t, const size_t, const size_t, const size_t,
|
||||||
|
const size_t, const size_t, const CUdevice, size_t&);
|
||||||
|
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
} // namespace clblast
|
} // namespace clblast
|
||||||
|
|
|
@ -67,7 +67,6 @@ class TestXgemm {
|
||||||
args.c_size = GetSizeC(args);
|
args.c_size = GetSizeC(args);
|
||||||
|
|
||||||
// Optionally (V != 0) enforces indirect (V == 1) or direct (V == 2) kernels
|
// Optionally (V != 0) enforces indirect (V == 1) or direct (V == 2) kernels
|
||||||
auto queue_plain = queue();
|
|
||||||
if (V != 0) {
|
if (V != 0) {
|
||||||
const auto device = queue.GetDevice();
|
const auto device = queue.GetDevice();
|
||||||
const auto switch_threshold = (V == 1) ? size_t{0} : size_t{4096}; // large enough for tests
|
const auto switch_threshold = (V == 1) ? size_t{0} : size_t{4096}; // large enough for tests
|
||||||
|
@ -78,9 +77,16 @@ class TestXgemm {
|
||||||
|
|
||||||
// Sets the size of the temporary buffer (optional argument to GEMM)
|
// Sets the size of the temporary buffer (optional argument to GEMM)
|
||||||
auto temp_buffer_size = size_t{0};
|
auto temp_buffer_size = size_t{0};
|
||||||
GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
|
#ifdef OPENCL_API
|
||||||
args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
|
auto queue_plain = queue();
|
||||||
&queue_plain, temp_buffer_size);
|
GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
|
||||||
|
args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
|
||||||
|
&queue_plain, temp_buffer_size);
|
||||||
|
#elif CUDA_API
|
||||||
|
GemmTempBufferSize<T>(args.layout, args.a_transpose, args.b_transpose, args.m, args.n, args.k,
|
||||||
|
args.a_offset, args.a_ld, args.b_offset, args.b_ld, args.c_offset, args.c_ld,
|
||||||
|
queue.GetDevice()(), temp_buffer_size);
|
||||||
|
#endif
|
||||||
args.ap_size = (temp_buffer_size + sizeof(T)) / sizeof(T); // + sizeof(T) to prevent zero
|
args.ap_size = (temp_buffer_size + sizeof(T)) / sizeof(T); // + sizeof(T) to prevent zero
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,7 +123,7 @@ class TestXgemm {
|
||||||
buffers.a_mat(), args.a_offset, args.a_ld,
|
buffers.a_mat(), args.a_offset, args.a_ld,
|
||||||
buffers.b_mat(), args.b_offset, args.b_ld, args.beta,
|
buffers.b_mat(), args.b_offset, args.b_ld, args.beta,
|
||||||
buffers.c_mat(), args.c_offset, args.c_ld,
|
buffers.c_mat(), args.c_offset, args.c_ld,
|
||||||
queue.GetContext()(), queue.GetDevice()());
|
queue.GetContext()(), queue.GetDevice()(), buffers.ap_mat()); // temp buffer
|
||||||
cuStreamSynchronize(queue());
|
cuStreamSynchronize(queue());
|
||||||
#endif
|
#endif
|
||||||
return status;
|
return status;
|
||||||
|
|
Loading…
Reference in New Issue