diff --git a/src/database/database.cpp b/src/database/database.cpp index 34c44a29..2696fb9b 100644 --- a/src/database/database.cpp +++ b/src/database/database.cpp @@ -21,6 +21,7 @@ #include "database/kernels/xgemv_fast_rot.hpp" #include "database/kernels/xger.hpp" #include "database/kernels/xgemm.hpp" +#include "database/kernels/xgemm_direct.hpp" #include "database/kernels/copy.hpp" #include "database/kernels/pad.hpp" #include "database/kernels/transpose.hpp" @@ -38,6 +39,7 @@ const std::vector Database::database = { XgemvFastRotHalf, XgemvFastRotSingle, XgemvFastRotDouble, XgemvFastRotComplexSingle, XgemvFastRotComplexDouble, XgerHalf, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble, XgemmHalf, XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble, + XgemmDirectHalf, XgemmDirectSingle, XgemmDirectDouble, XgemmDirectComplexSingle, XgemmDirectComplexDouble, CopyHalf, CopySingle, CopyDouble, CopyComplexSingle, CopyComplexDouble, PadHalf, PadSingle, PadDouble, PadComplexSingle, PadComplexDouble, TransposeHalf, TransposeSingle, TransposeDouble, TransposeComplexSingle, TransposeComplexDouble, diff --git a/src/database/database.hpp b/src/database/database.hpp index a6ab49c5..7c0afb46 100644 --- a/src/database/database.hpp +++ b/src/database/database.hpp @@ -75,6 +75,7 @@ class Database { static const DatabaseEntry XgemvFastRotHalf, XgemvFastRotSingle, XgemvFastRotDouble, XgemvFastRotComplexSingle, XgemvFastRotComplexDouble; static const DatabaseEntry XgerHalf, XgerSingle, XgerDouble, XgerComplexSingle, XgerComplexDouble; static const DatabaseEntry XgemmHalf, XgemmSingle, XgemmDouble, XgemmComplexSingle, XgemmComplexDouble; + static const DatabaseEntry XgemmDirectHalf, XgemmDirectSingle, XgemmDirectDouble, XgemmDirectComplexSingle, XgemmDirectComplexDouble; static const DatabaseEntry CopyHalf, CopySingle, CopyDouble, CopyComplexSingle, CopyComplexDouble; static const DatabaseEntry PadHalf, PadSingle, PadDouble, PadComplexSingle, PadComplexDouble; static const DatabaseEntry TransposeHalf, TransposeSingle, TransposeDouble, TransposeComplexSingle, TransposeComplexDouble; diff --git a/src/database/kernels/xgemm_direct.hpp b/src/database/kernels/xgemm_direct.hpp new file mode 100644 index 00000000..76055ef2 --- /dev/null +++ b/src/database/kernels/xgemm_direct.hpp @@ -0,0 +1,76 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Database generator +// +// This file populates the database with best-found tuning parameters for the 'Xgemm' kernels. +// +// ================================================================================================= + +namespace clblast { +// ================================================================================================= + +const Database::DatabaseEntry Database::XgemmDirectHalf = { + "XgemmDirect", Precision::kHalf, { + { // Default + kDeviceTypeAll, "default", { + { "default", { {"KWGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"MWGD",32}, {"NDIMBD",8}, {"NDIMCD",8}, {"NWGD",32}, {"VWMD",1}, {"VWND",1} } }, + } + }, + } +}; + +// ================================================================================================= + +const Database::DatabaseEntry Database::XgemmDirectSingle = { + "XgemmDirect", Precision::kSingle, { + { // Default + kDeviceTypeAll, "default", { + { "default", { {"KWGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"MWGD",32}, {"NDIMBD",8}, {"NDIMCD",8}, {"NWGD",32}, {"VWMD",1}, {"VWND",1} } }, + } + }, + } +}; + +// ================================================================================================= + +const Database::DatabaseEntry Database::XgemmDirectComplexSingle = { + "XgemmDirect", Precision::kComplexSingle, { + { // Default + kDeviceTypeAll, "default", { + { "default", { {"KWGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"MWGD",32}, {"NDIMBD",8}, {"NDIMCD",8}, {"NWGD",32}, {"VWMD",1}, {"VWND",1} } }, + } + }, + } +}; + +// ================================================================================================= + +const Database::DatabaseEntry Database::XgemmDirectDouble = { + "XgemmDirect", Precision::kDouble, { + { // Default + kDeviceTypeAll, "default", { + { "default", { {"KWGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"MWGD",32}, {"NDIMBD",8}, {"NDIMCD",8}, {"NWGD",32}, {"VWMD",1}, {"VWND",1} } }, + } + }, + } +}; + +// ================================================================================================= + +const Database::DatabaseEntry Database::XgemmDirectComplexDouble = { + "XgemmDirect", Precision::kComplexDouble, { + { // Default + kDeviceTypeAll, "default", { + { "default", { {"KWGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"MWGD",32}, {"NDIMBD",8}, {"NDIMCD",8}, {"NWGD",32}, {"VWMD",1}, {"VWND",1} } }, + } + }, + } +}; + +// ================================================================================================= +} // namespace clblast diff --git a/src/kernels/level3/xgemm_direct.opencl b/src/kernels/level3/xgemm_direct.opencl index fb5972ba..801887dd 100644 --- a/src/kernels/level3/xgemm_direct.opencl +++ b/src/kernels/level3/xgemm_direct.opencl @@ -16,68 +16,140 @@ // literal). Comment-out this line for syntax-highlighting when developing. R"( +// Parameters set by the tuner or by the database. Here they are given a basic default value in case +// this kernel file is used outside of the CLBlast library. Note that all parameters here have a +// suffix 'D' to denote that they are for the 'direct' version of the GEMM kernel. +#ifndef MWGD + #define MWGD 8 // Tile-size in dimension M (e.g. 64, 128) +#endif +#ifndef NWGD + #define NWGD 8 // Tile-size in dimension N (e.g. 64, 128) +#endif +#ifndef KWGD + #define KWGD 8 // Tile-size in dimension K (e.g. 8, 16) +#endif +#ifndef MDIMCD + #define MDIMCD 8 // Threads per workgroup in M-dimension (e.g. 8, 16, 32) +#endif +#ifndef NDIMCD + #define NDIMCD 8 // Threads per workgroup in N-dimension (e.g. 8, 16, 32) +#endif +#ifndef MDIMAD + #define MDIMAD 8 // Re-shaped tile dimension of matrix A: KDIMAD * MDIMAD +#endif +#ifndef NDIMBD + #define NDIMBD 8 // Re-shaped tile dimension of matrix B: KDIMBD * NDIMBD +#endif +#ifndef KWID + #define KWID 1 // Unroll factor of the KWGD loop (smaller or equal than KWGD) +#endif +#ifndef VWMD + #define VWMD 1 // Vector width of matrices A and C +#endif +#ifndef VWND + #define VWND 1 // Vector width of matrix B +#endif + +// Helper parameters based on the above tuning parameters +#define MWID (MWGD/MDIMCD) // Work per work-item (M-dimension) +#define NWID (NWGD/NDIMCD) // Work per work-item (N-dimension) +#define KDIMAD ((MDIMCD*NDIMCD)/(MDIMAD)) // Re-shaped tile dimension of matrix A: KDIMAD * MDIMAD +#define KDIMBD ((MDIMCD*NDIMCD)/(NDIMBD)) // Re-shaped tile dimension of matrix B: KDIMBD * NDIMBD +#define MWAD (MWGD/MDIMAD) // Amount of loads-per-thread for matrix A (M-dimension) +#define KWAD (KWGD/KDIMAD) // Amount of loads-per-thread for matrix A (K-dimension) +#define KWBD (KWGD/KDIMBD) // Amount of loads-per-thread for matrix B (K-dimension) +#define NWBD (NWGD/NDIMBD) // Amount of loads-per-thread for matrix B (N-dimension) + +// ================================================================================================= + +// Data-widths in dimension M +#if VWMD == 1 + typedef real realMD; +#elif VWMD == 2 + typedef real2 realMD; +#elif VWMD == 4 + typedef real4 realMD; +#elif VWMD == 8 + typedef real8 realMD; +#elif VWMD == 16 + typedef real16 realMD; +#endif + +// Data-widths in dimension N +#if VWND == 1 + typedef real realND; +#elif VWND == 2 + typedef real2 realND; +#elif VWND == 4 + typedef real4 realND; +#elif VWND == 8 + typedef real8 realND; +#elif VWND == 16 + typedef real16 realND; +#endif + // ================================================================================================= // Caches global off-chip memory into local (shared) memory on-chip. This function is specific for // caching the A input matrix. -inline void GlobalToLocalDirectA(const __global realM* restrict agm, __local real* alm, +inline void GlobalToLocalDirectA(const __global realMD* restrict agm, __local real* alm, const int a_ld, const int a_offset, const int tid, const int kwg, const int a_transpose, const int a_conjugate) { - const int la0 = tid % MDIMA; - const int la1 = tid / MDIMA; + const int la0 = tid % MDIMAD; + const int la1 = tid / MDIMAD; #pragma unroll - for (int mia=0; mia local (matrix A and B) GlobalToLocalDirectA(agm, alm, a_ld, a_offset, tid, kwg, a_transpose, a_conjugate); GlobalToLocalDirectB(bgm, blm, b_ld, b_offset, tid, kwg, b_transpose, b_conjugate); barrier(CLK_LOCAL_MEM_FENCE); - // Loops over all workitem tiles, unrolled by a factor KWI - for (int pwi=0; pwi private (matrix A) @@ -303,7 +375,7 @@ __kernel void XgemmDirect(const int kSizeM, const int kSizeN, const int kSizeK, // Loads A into register memory #pragma unroll - for (int mi=0; mi Xgemm::Xgemm(Queue &queue, EventPointer event, const std::string &name): - Routine(queue, event, name, {"Copy","Pad","Transpose","Padtranspose","Xgemm"}, PrecisionValue()) { + Routine(queue, event, name, {"Copy","Pad","Transpose","Padtranspose","Xgemm", "XgemmDirect"}, + PrecisionValue()) { source_string_ = #include "../../kernels/level3/level3.opencl" #include "../../kernels/level3/copy_fast.opencl" @@ -299,13 +300,13 @@ StatusCode Xgemm::GemmDirect(const size_t m, const size_t n, const size_t k, kernel.SetArgument(18, static_cast(b_conjugate)); // Computes the global and local thread sizes - const auto m_ceiled = Ceil(m, db_["MWG"]); - const auto n_ceiled = Ceil(n, db_["NWG"]); + const auto m_ceiled = Ceil(m, db_["MWGD"]); + const auto n_ceiled = Ceil(n, db_["NWGD"]); const auto global = std::vector{ - (m_ceiled * db_["MDIMC"]) / db_["MWG"], - (n_ceiled * db_["NDIMC"]) / db_["NWG"] + (m_ceiled * db_["MDIMCD"]) / db_["MWGD"], + (n_ceiled * db_["NDIMCD"]) / db_["NWGD"] }; - const auto local = std::vector{db_["MDIMC"], db_["NDIMC"]}; + const auto local = std::vector{db_["MDIMCD"], db_["NDIMCD"]}; // Launches the kernel auto status = RunKernel(kernel, queue_, device_, global, local, event_);