Added a first version of a tuner for the GEMM direct kernel; collapsed MWGD, NWGD and KWGD into one WGD parameter

pull/108/head
Cedric Nugteren 2016-09-25 14:48:34 +02:00
parent 669f43aed6
commit 73d135c2ce
5 changed files with 292 additions and 106 deletions

View File

@@ -134,7 +134,8 @@ endif()
# ==================================================================================================
# Sets the supported routines and the used kernels. New routines and kernels should be added here.
set(KERNELS copy_fast copy_pad transpose_fast transpose_pad xaxpy xdot xger xgemm xgemv)
set(KERNELS copy_fast copy_pad transpose_fast transpose_pad xaxpy xdot xger
xgemm xgemm_direct xgemv)
set(SAMPLE_PROGRAMS_CPP sgemm)
set(SAMPLE_PROGRAMS_C sasum dgemv sgemm haxpy cache)
set(LEVEL1_ROUTINES xswap xscal xcopy xaxpy xdot xdotu xdotc xnrm2 xasum xamax)

View File

@@ -18,7 +18,7 @@ const Database::DatabaseEntry Database::XgemmDirectHalf = {
"XgemmDirect", Precision::kHalf, {
{ // Default
kDeviceTypeAll, "default", {
{ "default", { {"KWGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"MWGD",32}, {"NDIMBD",8}, {"NDIMCD",8}, {"NWGD",32}, {"VWMD",1}, {"VWND",1} } },
{ "default", { {"WGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"NDIMBD",8}, {"NDIMCD",8}, {"VWMD",1}, {"VWND",1} } },
}
},
}
@@ -30,7 +30,7 @@ const Database::DatabaseEntry Database::XgemmDirectSingle = {
"XgemmDirect", Precision::kSingle, {
{ // Default
kDeviceTypeAll, "default", {
{ "default", { {"KWGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"MWGD",32}, {"NDIMBD",8}, {"NDIMCD",8}, {"NWGD",32}, {"VWMD",1}, {"VWND",1} } },
{ "default", { {"WGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"NDIMBD",8}, {"NDIMCD",8}, {"VWMD",1}, {"VWND",1} } },
}
},
}
@@ -42,7 +42,7 @@ const Database::DatabaseEntry Database::XgemmDirectComplexSingle = {
"XgemmDirect", Precision::kComplexSingle, {
{ // Default
kDeviceTypeAll, "default", {
{ "default", { {"KWGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"MWGD",32}, {"NDIMBD",8}, {"NDIMCD",8}, {"NWGD",32}, {"VWMD",1}, {"VWND",1} } },
{ "default", { {"WGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"NDIMBD",8}, {"NDIMCD",8}, {"VWMD",1}, {"VWND",1} } },
}
},
}
@@ -54,7 +54,7 @@ const Database::DatabaseEntry Database::XgemmDirectDouble = {
"XgemmDirect", Precision::kDouble, {
{ // Default
kDeviceTypeAll, "default", {
{ "default", { {"KWGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"MWGD",32}, {"NDIMBD",8}, {"NDIMCD",8}, {"NWGD",32}, {"VWMD",1}, {"VWND",1} } },
{ "default", { {"WGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"NDIMBD",8}, {"NDIMCD",8}, {"VWMD",1}, {"VWND",1} } },
}
},
}
@@ -66,7 +66,7 @@ const Database::DatabaseEntry Database::XgemmDirectComplexDouble = {
"XgemmDirect", Precision::kComplexDouble, {
{ // Default
kDeviceTypeAll, "default", {
{ "default", { {"KWGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"MWGD",32}, {"NDIMBD",8}, {"NDIMCD",8}, {"NWGD",32}, {"VWMD",1}, {"VWND",1} } },
{ "default", { {"WGD",32}, {"KWID",2}, {"MDIMAD",8}, {"MDIMCD",8}, {"NDIMBD",8}, {"NDIMCD",8}, {"VWMD",1}, {"VWND",1} } },
}
},
}
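The per-kernel defaults above are consumed at kernel-compile time as preprocessor definitions. The following is a minimal standalone sketch of that translation, using a plain std::map as a stand-in for the database entry (a hypothetical helper, not CLBlast's actual Database class):

// Sketch: turn a parameter map like the "default" entry above into "#define NAME value" lines.
#include <cstdio>
#include <map>
#include <string>

int main() {
  // Hypothetical copy of the single-precision "default" entry
  const std::map<std::string, int> params = {
      {"WGD", 32}, {"KWID", 2}, {"MDIMAD", 8}, {"MDIMCD", 8},
      {"NDIMBD", 8}, {"NDIMCD", 8}, {"VWMD", 1}, {"VWND", 1}};
  std::string defines;
  for (const auto &p : params) {
    defines += "#define " + p.first + " " + std::to_string(p.second) + "\n";
  }
  std::printf("%s", defines.c_str());  // prepended to the kernel source before compilation
  return 0;
}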

View File

@@ -19,14 +19,8 @@ R"(
// Parameters set by the tuner or by the database. Here they are given a basic default value in case
// this kernel file is used outside of the CLBlast library. Note that all parameters here have a
// suffix 'D' to denote that they are for the 'direct' version of the GEMM kernel.
#ifndef MWGD
#define MWGD 8 // Tile-size in dimension M (e.g. 64, 128)
#endif
#ifndef NWGD
#define NWGD 8 // Tile-size in dimension N (e.g. 64, 128)
#endif
#ifndef KWGD
#define KWGD 8 // Tile-size in dimension K (e.g. 8, 16)
#ifndef WGD
#define WGD 8 // Tile-size in dimension M, N, and K (e.g. 8, 16, 32, 64)
#endif
#ifndef MDIMCD
#define MDIMCD 8 // Threads per workgroup in M-dimension (e.g. 8, 16, 32)
@@ -41,7 +35,7 @@ R"(
#define NDIMBD 8 // Re-shaped tile dimension of matrix B: KDIMBD * NDIMBD
#endif
#ifndef KWID
#define KWID 1 // Unroll factor of the KWGD loop (smaller or equal than KWGD)
#define KWID 1 // Unroll factor of the WGD loop (must be smaller than or equal to WGD)
#endif
#ifndef VWMD
#define VWMD 1 // Vector width of matrices A and C
@@ -51,14 +45,14 @@ R"(
#endif
// Helper parameters based on the above tuning parameters
#define MWID (MWGD/MDIMCD) // Work per work-item (M-dimension)
#define NWID (NWGD/NDIMCD) // Work per work-item (N-dimension)
#define MWID (WGD/MDIMCD) // Work per work-item (M-dimension)
#define NWID (WGD/NDIMCD) // Work per work-item (N-dimension)
#define KDIMAD ((MDIMCD*NDIMCD)/(MDIMAD)) // Re-shaped tile dimension of matrix A: KDIMAD * MDIMAD
#define KDIMBD ((MDIMCD*NDIMCD)/(NDIMBD)) // Re-shaped tile dimension of matrix B: KDIMBD * NDIMBD
#define MWAD (MWGD/MDIMAD) // Amount of loads-per-thread for matrix A (M-dimension)
#define KWAD (KWGD/KDIMAD) // Amount of loads-per-thread for matrix A (K-dimension)
#define KWBD (KWGD/KDIMBD) // Amount of loads-per-thread for matrix B (K-dimension)
#define NWBD (NWGD/NDIMBD) // Amount of loads-per-thread for matrix B (N-dimension)
#define MWAD (WGD/MDIMAD) // Amount of loads-per-thread for matrix A (M-dimension)
#define KWAD (WGD/KDIMAD) // Amount of loads-per-thread for matrix A (K-dimension)
#define KWBD (WGD/KDIMBD) // Amount of loads-per-thread for matrix B (K-dimension)
#define NWBD (WGD/NDIMBD) // Amount of loads-per-thread for matrix B (N-dimension)
// =================================================================================================
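To make the helper-parameter arithmetic above concrete, here is a small standalone sketch for one hypothetical configuration (WGD=32, MDIMCD=8, NDIMCD=8, MDIMAD=8, NDIMBD=8; these values are for illustration only, not taken from the database):

// Sketch: helper parameters derived from a single WGD tile size (hypothetical values).
#include <cstdio>

int main() {
  const int WGD = 32, MDIMCD = 8, NDIMCD = 8, MDIMAD = 8, NDIMBD = 8;
  const int MWID = WGD / MDIMCD;                  // 4: work per work-item in the M-dimension
  const int NWID = WGD / NDIMCD;                  // 4: work per work-item in the N-dimension
  const int KDIMAD = (MDIMCD * NDIMCD) / MDIMAD;  // 8: re-shaped tile of A is KDIMAD x MDIMAD
  const int KDIMBD = (MDIMCD * NDIMCD) / NDIMBD;  // 8: re-shaped tile of B is KDIMBD x NDIMBD
  const int MWAD = WGD / MDIMAD;                  // 4: loads per thread for A (M-dimension)
  const int KWAD = WGD / KDIMAD;                  // 4: loads per thread for A (K-dimension)
  const int KWBD = WGD / KDIMBD;                  // 4: loads per thread for B (K-dimension)
  const int NWBD = WGD / NDIMBD;                  // 4: loads per thread for B (N-dimension)
  const int lm_elements = 2 * WGD * WGD;          // 2048: alm plus blm, each WGD*WGD reals
  std::printf("MWID=%d NWID=%d KDIMAD=%d KDIMBD=%d MWAD=%d KWAD=%d KWBD=%d NWBD=%d local=%d\n",
              MWID, NWID, KDIMAD, KDIMBD, MWAD, KWAD, KWBD, NWBD, lm_elements);
  return 0;
}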
@@ -105,51 +99,51 @@ inline void GlobalToLocalDirectA(const __global realMD* restrict agm, __local re
// Computes the indices for the global memory
int mg = mia + la0*(MWAD/VWMD);
int kg = kia + la1*KWAD;
int idm = (a_transpose) ? mg + kwg/VWMD : mg + GetGroupID0()*(MWGD/VWMD);
int idk = (a_transpose) ? kg + GetGroupID0()*MWGD : kg + kwg;
int idm = (a_transpose) ? mg + kwg/VWMD : mg + GetGroupID0()*(WGD/VWMD);
int idk = (a_transpose) ? kg + GetGroupID0()*WGD : kg + kwg;
// Loads the data from global memory into the local memory
const realMD avec = agm[idk*(a_ld/VWMD) + idm + a_offset];
#if VWMD == 1
alm[kg*MWGD + mg] = avec;
alm[kg*WGD + mg] = avec;
#elif VWMD == 2
alm[kg*MWGD + mg*VWMD + 0] = avec.x;
alm[kg*MWGD + mg*VWMD + 1] = avec.y;
alm[kg*WGD + mg*VWMD + 0] = avec.x;
alm[kg*WGD + mg*VWMD + 1] = avec.y;
#elif VWMD == 4
alm[kg*MWGD + mg*VWMD + 0] = avec.x;
alm[kg*MWGD + mg*VWMD + 1] = avec.y;
alm[kg*MWGD + mg*VWMD + 2] = avec.z;
alm[kg*MWGD + mg*VWMD + 3] = avec.w;
alm[kg*WGD + mg*VWMD + 0] = avec.x;
alm[kg*WGD + mg*VWMD + 1] = avec.y;
alm[kg*WGD + mg*VWMD + 2] = avec.z;
alm[kg*WGD + mg*VWMD + 3] = avec.w;
#elif VWMD == 8
alm[kg*MWGD + mg*VWMD + 0] = avec.s0;
alm[kg*MWGD + mg*VWMD + 1] = avec.s1;
alm[kg*MWGD + mg*VWMD + 2] = avec.s2;
alm[kg*MWGD + mg*VWMD + 3] = avec.s3;
alm[kg*MWGD + mg*VWMD + 4] = avec.s4;
alm[kg*MWGD + mg*VWMD + 5] = avec.s5;
alm[kg*MWGD + mg*VWMD + 6] = avec.s6;
alm[kg*MWGD + mg*VWMD + 7] = avec.s7;
alm[kg*WGD + mg*VWMD + 0] = avec.s0;
alm[kg*WGD + mg*VWMD + 1] = avec.s1;
alm[kg*WGD + mg*VWMD + 2] = avec.s2;
alm[kg*WGD + mg*VWMD + 3] = avec.s3;
alm[kg*WGD + mg*VWMD + 4] = avec.s4;
alm[kg*WGD + mg*VWMD + 5] = avec.s5;
alm[kg*WGD + mg*VWMD + 6] = avec.s6;
alm[kg*WGD + mg*VWMD + 7] = avec.s7;
#elif VWMD == 16
alm[kg*MWGD + mg*VWMD + 0] = avec.s0;
alm[kg*MWGD + mg*VWMD + 1] = avec.s1;
alm[kg*MWGD + mg*VWMD + 2] = avec.s2;
alm[kg*MWGD + mg*VWMD + 3] = avec.s3;
alm[kg*MWGD + mg*VWMD + 4] = avec.s4;
alm[kg*MWGD + mg*VWMD + 5] = avec.s5;
alm[kg*MWGD + mg*VWMD + 6] = avec.s6;
alm[kg*MWGD + mg*VWMD + 7] = avec.s7;
alm[kg*MWGD + mg*VWMD + 8] = avec.s8;
alm[kg*MWGD + mg*VWMD + 9] = avec.s9;
alm[kg*MWGD + mg*VWMD + 10] = avec.sA;
alm[kg*MWGD + mg*VWMD + 11] = avec.sB;
alm[kg*MWGD + mg*VWMD + 12] = avec.sC;
alm[kg*MWGD + mg*VWMD + 13] = avec.sD;
alm[kg*MWGD + mg*VWMD + 14] = avec.sE;
alm[kg*MWGD + mg*VWMD + 15] = avec.sF;
alm[kg*WGD + mg*VWMD + 0] = avec.s0;
alm[kg*WGD + mg*VWMD + 1] = avec.s1;
alm[kg*WGD + mg*VWMD + 2] = avec.s2;
alm[kg*WGD + mg*VWMD + 3] = avec.s3;
alm[kg*WGD + mg*VWMD + 4] = avec.s4;
alm[kg*WGD + mg*VWMD + 5] = avec.s5;
alm[kg*WGD + mg*VWMD + 6] = avec.s6;
alm[kg*WGD + mg*VWMD + 7] = avec.s7;
alm[kg*WGD + mg*VWMD + 8] = avec.s8;
alm[kg*WGD + mg*VWMD + 9] = avec.s9;
alm[kg*WGD + mg*VWMD + 10] = avec.sA;
alm[kg*WGD + mg*VWMD + 11] = avec.sB;
alm[kg*WGD + mg*VWMD + 12] = avec.sC;
alm[kg*WGD + mg*VWMD + 13] = avec.sD;
alm[kg*WGD + mg*VWMD + 14] = avec.sE;
alm[kg*WGD + mg*VWMD + 15] = avec.sF;
#endif
if (a_conjugate) {
for (int vm=0; vm<VWMD; ++vm) {
COMPLEX_CONJUGATE(alm[kg*MWGD + mg*VWMD + vm]);
COMPLEX_CONJUGATE(alm[kg*WGD + mg*VWMD + vm]);
}
}
}
@@ -170,51 +164,51 @@ inline void GlobalToLocalDirectB(const __global realND* restrict bgm, __local re
// Computes the indices for the global memory
int ng = nib + lb0*(NWBD/VWND);
int kg = kib + lb1*KWBD;
int idn = (b_transpose) ? ng + kwg/VWND : ng + GetGroupID1()*(NWGD/VWND);
int idk = (b_transpose) ? kg + GetGroupID1()*NWGD : kg + kwg;
int idn = (b_transpose) ? ng + kwg/VWND : ng + GetGroupID1()*(WGD/VWND);
int idk = (b_transpose) ? kg + GetGroupID1()*WGD : kg + kwg;
// Loads the data from global memory into the local memory
const realMD bvec = bgm[idk*(b_ld/VWND) + idn + b_offset];
const realND bvec = bgm[idk*(b_ld/VWND) + idn + b_offset];
#if VWND == 1
blm[kg*NWGD + ng] = bvec;
blm[kg*WGD + ng] = bvec;
#elif VWND == 2
blm[kg*NWGD + ng*VWND + 0] = bvec.x;
blm[kg*NWGD + ng*VWND + 1] = bvec.y;
blm[kg*WGD + ng*VWND + 0] = bvec.x;
blm[kg*WGD + ng*VWND + 1] = bvec.y;
#elif VWND == 4
blm[kg*NWGD + ng*VWND + 0] = bvec.x;
blm[kg*NWGD + ng*VWND + 1] = bvec.y;
blm[kg*NWGD + ng*VWND + 2] = bvec.z;
blm[kg*NWGD + ng*VWND + 3] = bvec.w;
blm[kg*WGD + ng*VWND + 0] = bvec.x;
blm[kg*WGD + ng*VWND + 1] = bvec.y;
blm[kg*WGD + ng*VWND + 2] = bvec.z;
blm[kg*WGD + ng*VWND + 3] = bvec.w;
#elif VWND == 8
blm[kg*NWGD + ng*VWND + 0] = bvec.s0;
blm[kg*NWGD + ng*VWND + 1] = bvec.s1;
blm[kg*NWGD + ng*VWND + 2] = bvec.s2;
blm[kg*NWGD + ng*VWND + 3] = bvec.s3;
blm[kg*NWGD + ng*VWND + 4] = bvec.s4;
blm[kg*NWGD + ng*VWND + 5] = bvec.s5;
blm[kg*NWGD + ng*VWND + 6] = bvec.s6;
blm[kg*NWGD + ng*VWND + 7] = bvec.s7;
blm[kg*WGD + ng*VWND + 0] = bvec.s0;
blm[kg*WGD + ng*VWND + 1] = bvec.s1;
blm[kg*WGD + ng*VWND + 2] = bvec.s2;
blm[kg*WGD + ng*VWND + 3] = bvec.s3;
blm[kg*WGD + ng*VWND + 4] = bvec.s4;
blm[kg*WGD + ng*VWND + 5] = bvec.s5;
blm[kg*WGD + ng*VWND + 6] = bvec.s6;
blm[kg*WGD + ng*VWND + 7] = bvec.s7;
#elif VWND == 16
blm[kg*NWGD + ng*VWND + 0] = bvec.s0;
blm[kg*NWGD + ng*VWND + 1] = bvec.s1;
blm[kg*NWGD + ng*VWND + 2] = bvec.s2;
blm[kg*NWGD + ng*VWND + 3] = bvec.s3;
blm[kg*NWGD + ng*VWND + 4] = bvec.s4;
blm[kg*NWGD + ng*VWND + 5] = bvec.s5;
blm[kg*NWGD + ng*VWND + 6] = bvec.s6;
blm[kg*NWGD + ng*VWND + 7] = bvec.s7;
blm[kg*NWGD + ng*VWND + 8] = bvec.s8;
blm[kg*NWGD + ng*VWND + 9] = bvec.s9;
blm[kg*NWGD + ng*VWND + 10] = bvec.sA;
blm[kg*NWGD + ng*VWND + 11] = bvec.sB;
blm[kg*NWGD + ng*VWND + 12] = bvec.sC;
blm[kg*NWGD + ng*VWND + 13] = bvec.sD;
blm[kg*NWGD + ng*VWND + 14] = bvec.sE;
blm[kg*NWGD + ng*VWND + 15] = bvec.sF;
blm[kg*WGD + ng*VWND + 0] = bvec.s0;
blm[kg*WGD + ng*VWND + 1] = bvec.s1;
blm[kg*WGD + ng*VWND + 2] = bvec.s2;
blm[kg*WGD + ng*VWND + 3] = bvec.s3;
blm[kg*WGD + ng*VWND + 4] = bvec.s4;
blm[kg*WGD + ng*VWND + 5] = bvec.s5;
blm[kg*WGD + ng*VWND + 6] = bvec.s6;
blm[kg*WGD + ng*VWND + 7] = bvec.s7;
blm[kg*WGD + ng*VWND + 8] = bvec.s8;
blm[kg*WGD + ng*VWND + 9] = bvec.s9;
blm[kg*WGD + ng*VWND + 10] = bvec.sA;
blm[kg*WGD + ng*VWND + 11] = bvec.sB;
blm[kg*WGD + ng*VWND + 12] = bvec.sC;
blm[kg*WGD + ng*VWND + 13] = bvec.sD;
blm[kg*WGD + ng*VWND + 14] = bvec.sE;
blm[kg*WGD + ng*VWND + 15] = bvec.sF;
#endif
if (b_conjugate) {
for (int vn=0; vn<VWND; ++vn) {
COMPLEX_CONJUGATE(blm[kg*NWGD + ng*VWND + vn]);
COMPLEX_CONJUGATE(blm[kg*WGD + ng*VWND + vn]);
}
}
}
@@ -230,7 +224,7 @@ inline void LocalToPrivateDirectA(__local real* alm, real apm[MWID], const int k
#pragma unroll
for (int mi=0; mi<MWID; ++mi) {
const int mg = mi + get_local_id(0)*MWID;
const int index = (a_transpose) ? mg*KWGD + kg : kg*MWGD + mg;
const int index = (a_transpose) ? mg*WGD + kg : kg*WGD + mg;
apm[mi] = alm[index];
}
}
@@ -241,7 +235,7 @@ inline void LocalToPrivateDirectB(__local real* blm, real bpm[NWID], const int k
#pragma unroll
for (int ni=0; ni<NWID; ++ni) {
const int ng = ni + get_local_id(1)*NWID;
const int index = (b_transpose) ? ng*KWGD + kg : kg*NWGD + ng;
const int index = (b_transpose) ? ng*WGD + kg : kg*WGD + ng;
bpm[ni] = blm[index];
}
}
@@ -286,8 +280,8 @@ inline void StoreResultsDirect(__global real* cgm, real cpm[NWID][MWID],
for (int mi=0; mi<MWID; ++mi) {
int mg = mi + get_local_id(0)*MWID;
int ng = ni + get_local_id(1)*NWID;
int idm = mg + GetGroupID0() * MWGD;
int idn = ng + GetGroupID1() * NWGD;
int idm = mg + GetGroupID0() * WGD;
int idn = ng + GetGroupID1() * WGD;
// Determines the destination index
const int c_index = (c_transpose) ? idm*c_ld + idn : idn*c_ld + idm;
@@ -320,8 +314,8 @@ __kernel void XgemmDirect(const int kSizeM, const int kSizeN, const int kSizeK,
const __global real* restrict bgms = (const __global real* restrict) bgm;
// Allocates workgroup-private memory (local memory)
__local real alm[KWGD * MWGD];
__local real blm[KWGD * NWGD];
__local real alm[WGD * WGD];
__local real blm[WGD * WGD];
// Combined thread identifier (volatile to disable caching)
volatile int tid = get_local_id(0) + MDIMCD*get_local_id(1);
@@ -335,15 +329,15 @@ __kernel void XgemmDirect(const int kSizeM, const int kSizeN, const int kSizeK,
InitAccRegistersDirect(cpm);
// The faster version of GEMM is not allowed on the (incomplete) borders. Therefore, this section
// processes only the main parts: output blocks of MWGD by NWGD.
const int idm = get_local_id(0) * MWID + GetGroupID0() * MWGD;
const int idn = get_local_id(1) * NWID + GetGroupID1() * NWGD;
if ((idm < (kSizeM/MWGD)*MWGD) && (idn < (kSizeN/NWGD)*NWGD) &&
// processes only the main parts: output blocks of WGD by WGD.
const int idm = get_local_id(0) * MWID + GetGroupID0() * WGD;
const int idn = get_local_id(1) * NWID + GetGroupID1() * WGD;
if ((idm < (kSizeM/WGD)*WGD) && (idn < (kSizeN/WGD)*WGD) &&
(a_ld % VWMD == 0) && (b_ld % VWND == 0)) {
// Loops over all complete workgroup tiles
int kwg = 0;
for (; kwg < (kSizeK/KWGD) * KWGD; kwg+=KWGD) {
for (; kwg < (kSizeK/WGD) * WGD; kwg+=WGD) {
// Loads data: off-chip --> local (matrix A and B)
GlobalToLocalDirectA(agm, alm, a_ld, a_offset, tid, kwg, a_transpose, a_conjugate);
@@ -351,7 +345,7 @@ __kernel void XgemmDirect(const int kSizeM, const int kSizeN, const int kSizeK,
barrier(CLK_LOCAL_MEM_FENCE);
// Loops over all workitem tiles, unrolled by a factor KWID
for (int pwi=0; pwi<KWGD; pwi+=KWID) {
for (int pwi=0; pwi<WGD; pwi+=KWID) {
#pragma unroll
for (int pit=0; pit<KWID; ++pit) {
int kg = pwi + pit;

View File

@@ -300,11 +300,11 @@ StatusCode Xgemm<T>::GemmDirect(const size_t m, const size_t n, const size_t k,
kernel.SetArgument(18, static_cast<int>(b_conjugate));
// Computes the global and local thread sizes
const auto m_ceiled = Ceil(m, db_["MWGD"]);
const auto n_ceiled = Ceil(n, db_["NWGD"]);
const auto m_ceiled = Ceil(m, db_["WGD"]);
const auto n_ceiled = Ceil(n, db_["WGD"]);
const auto global = std::vector<size_t>{
(m_ceiled * db_["MDIMCD"]) / db_["MWGD"],
(n_ceiled * db_["NDIMCD"]) / db_["NWGD"]
(m_ceiled * db_["MDIMCD"]) / db_["WGD"],
(n_ceiled * db_["NDIMCD"]) / db_["WGD"]
};
const auto local = std::vector<size_t>{db_["MDIMCD"], db_["NDIMCD"]};
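To see the launch geometry numerically, a self-contained sketch assuming hypothetical sizes m = n = 1000 and the default WGD=32, MDIMCD=8, NDIMCD=8 (Ceil is re-defined locally to keep the sketch standalone; it simply rounds up to the next multiple):

// Sketch: global/local work sizes for the direct GEMM kernel (hypothetical sizes).
#include <cstddef>
#include <cstdio>

static size_t Ceil(const size_t x, const size_t y) { return ((x + y - 1) / y) * y; }

int main() {
  const size_t m = 1000, n = 1000;
  const size_t WGD = 32, MDIMCD = 8, NDIMCD = 8;
  const size_t m_ceiled = Ceil(m, WGD);               // 1024: padded up to a multiple of WGD
  const size_t n_ceiled = Ceil(n, WGD);               // 1024
  const size_t global_x = (m_ceiled * MDIMCD) / WGD;  // 256 work-items in the M-dimension
  const size_t global_y = (n_ceiled * NDIMCD) / WGD;  // 256 work-items in the N-dimension
  std::printf("global = {%zu, %zu}, local = {%zu, %zu}\n", global_x, global_y, MDIMCD, NDIMCD);
  return 0;
}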

View File

@@ -0,0 +1,191 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file uses the CLTune auto-tuner to tune the direct xgemm kernels. There are two variations:
// - V==1: This tests some limited set of tuning parameters exhaustively.
// - V==2: This tests a much larger set of tuning parameters by randomly sampling a subset.
//
// =================================================================================================
#include <string>
#include <vector>
#include "utilities.hpp"
#include "tuning/tuning.hpp"
namespace clblast {
// =================================================================================================
// See comment at top of file for a description of the class
template <typename T, int V>
class TuneXgemmDirect {
public:
// The representative kernel and the source code
static std::string KernelFamily() { return (V==1) ? "xgemm_direct_1" : "xgemm_direct_2"; }
static std::string KernelName() { return "XgemmDirect"; }
static std::string GetSources() {
return
#include "../src/kernels/common.opencl"
#include "../src/kernels/level3/xgemm_direct.opencl"
;
}
// The list of arguments relevant for this routine
static std::vector<std::string> GetOptions() {
return {kArgM, kArgN, kArgK, kArgAlpha, kArgBeta, kArgFraction};
}
// Tests for valid arguments
static void TestValidArguments(const Arguments<T> &) { }
// Sets the default values for the arguments
static size_t DefaultM() { return 128; }
static size_t DefaultN() { return 128; }
static size_t DefaultK() { return 128; }
static double DefaultFraction() { return (V==1) ? 1.0 : 16.0; } // test all or sample randomly
// Describes how to obtain the sizes of the buffers
static size_t GetSizeX(const Arguments<T> &) { return 1; } // N/A for this kernel
static size_t GetSizeY(const Arguments<T> &) { return 1; } // N/A for this kernel
static size_t GetSizeA(const Arguments<T> &args) { return args.m * args.k; }
static size_t GetSizeB(const Arguments<T> &args) { return args.n * args.k; }
static size_t GetSizeC(const Arguments<T> &args) { return args.m * args.n; }
static size_t GetSizeTemp(const Arguments<T> &) { return 1; } // N/A for this kernel
// Sets the tuning parameters and their possible values
static void SetParameters(cltune::Tuner &tuner, const size_t id) {
if (V==1) { // limited subset of tuning parameters - but explorable exhaustively
tuner.AddParameter(id, "WGD", {8, 16, 32});
tuner.AddParameter(id, "MDIMCD", {8, 16, 32});
tuner.AddParameter(id, "NDIMCD", {8, 16, 32});
tuner.AddParameter(id, "MDIMAD", {8, 16, 32});
tuner.AddParameter(id, "NDIMBD", {8, 16, 32});
tuner.AddParameter(id, "KWID", {2});
tuner.AddParameter(id, "VWMD", {1, 2, 4, 8});
tuner.AddParameter(id, "VWND", {1, 2, 4, 8});
} // a lot more tuning parameters - has to be sampled randomly, too much to test all
else {
tuner.AddParameter(id, "WGD", {8, 16, 32, 64, 128});
tuner.AddParameter(id, "MDIMCD", {8, 16, 32});
tuner.AddParameter(id, "NDIMCD", {8, 16, 32});
tuner.AddParameter(id, "MDIMAD", {8, 16, 32});
tuner.AddParameter(id, "NDIMBD", {8, 16, 32});
tuner.AddParameter(id, "KWID", {2, 8, 16});
tuner.AddParameter(id, "VWMD", {1, 2, 4, 8});
tuner.AddParameter(id, "VWND", {1, 2, 4, 8});
}
}
// Sets the constraints
static void SetConstraints(cltune::Tuner &tuner, const size_t id) {
auto MultipleOfX = [] (std::vector<size_t> v) { return IsMultiple(v[0], v[1]); };
auto MultipleOfXMulY = [] (std::vector<size_t> v) { return IsMultiple(v[0], v[1]*v[2]); };
auto MultipleOfXMulYDivZ = [] (std::vector<size_t> v) { return IsMultiple(v[0], (v[1]*v[2])/v[3]); };
// Requirement for unrolling the WGD loop
tuner.AddConstraint(id, MultipleOfX, {"WGD", "KWID"});
// Required for integer MWID and NWID
tuner.AddConstraint(id, MultipleOfXMulY, {"WGD", "MDIMCD", "VWMD"});
tuner.AddConstraint(id, MultipleOfXMulY, {"WGD", "NDIMCD", "VWND"});
// Required for integer MWAD and NWBD
tuner.AddConstraint(id, MultipleOfXMulY, {"WGD", "MDIMAD", "VWMD"});
tuner.AddConstraint(id, MultipleOfXMulY, {"WGD", "NDIMBD", "VWND"});
// WGD has to be a multiple of KDIMAD = ((MDIMCD*NDIMCD)/(MDIMAD)) and KDIMBD = (...)
tuner.AddConstraint(id, MultipleOfXMulYDivZ, {"WGD", "MDIMCD", "NDIMCD", "MDIMAD"});
tuner.AddConstraint(id, MultipleOfXMulYDivZ, {"WGD", "MDIMCD", "NDIMCD", "NDIMBD"});
// Extra constraints for variation 1 to limit the set of options significantly
if (V==1) {
auto IsEqual = [] (std::vector<size_t> v) { return v[0] == v[1]; };
tuner.AddConstraint(id, IsEqual, {"MDIMCD", "MDIMAD"});
tuner.AddConstraint(id, IsEqual, {"NDIMCD", "NDIMBD"});
}
}
// Sets the local memory size
static void SetLocalMemorySize(cltune::Tuner &tuner, const size_t id, const Arguments<T> &args) {
auto LocalMemorySize = [args] (std::vector<size_t> v) {
return ((v[0]*v[1] + v[2]*v[3])*GetBytes(args.precision));
};
tuner.SetLocalMemoryUsage(id, LocalMemorySize, {"WGD", "WGD", "WGD", "WGD"});
}
// Sets the base thread configuration
static std::vector<size_t> GlobalSize(const Arguments<T> &args) { return {args.m, args.n}; }
static std::vector<size_t> GlobalSizeRef(const Arguments<T> &args) { return GlobalSize(args); }
static std::vector<size_t> LocalSize() { return {1, 1}; }
static std::vector<size_t> LocalSizeRef() { return {8, 8}; }
// Transforms the thread configuration based on the parameters
using TransformVector = std::vector<std::vector<std::string>>;
static TransformVector MulLocal() { return {{"MDIMCD", "NDIMCD"}}; }
static TransformVector DivLocal() { return {}; }
static TransformVector MulGlobal() { return {{"MDIMCD", "NDIMCD"}}; }
static TransformVector DivGlobal() { return {{"WGD", "WGD"}}; }
// Sets the kernel's arguments
static void SetArguments(cltune::Tuner &tuner, const Arguments<T> &args,
std::vector<T> &, std::vector<T> &,
std::vector<T> &a_mat, std::vector<T> &b_mat, std::vector<T> &c_mat,
std::vector<T> &) {
tuner.AddArgumentScalar(static_cast<int>(args.m));
tuner.AddArgumentScalar(static_cast<int>(args.n));
tuner.AddArgumentScalar(static_cast<int>(args.k));
tuner.AddArgumentScalar(GetRealArg(args.alpha));
tuner.AddArgumentScalar(GetRealArg(args.beta));
tuner.AddArgumentInput(a_mat);
tuner.AddArgumentScalar(0); // a_offset
tuner.AddArgumentScalar(static_cast<int>(args.k)); // a_ld
tuner.AddArgumentInput(b_mat);
tuner.AddArgumentScalar(0); // b_offset
tuner.AddArgumentScalar(static_cast<int>(args.n)); // b_ld
tuner.AddArgumentOutput(c_mat);
tuner.AddArgumentScalar(0); // c_offset
tuner.AddArgumentScalar(static_cast<int>(args.n)); // c_ld
tuner.AddArgumentScalar(1); // a_do_transpose
tuner.AddArgumentScalar(1); // b_do_transpose
tuner.AddArgumentScalar(1); // c_do_transpose
tuner.AddArgumentScalar(0); // a_conjugate
tuner.AddArgumentScalar(0); // b_conjugate
}
// Describes how to compute the performance metrics
static size_t GetMetric(const Arguments<T> &args) {
return 2 * args.m * args.n * args.k;
}
static std::string PerformanceUnit() { return "GFLOPS"; }
};
// =================================================================================================
} // namespace clblast
// Shortcuts to the clblast namespace
using float2 = clblast::float2;
using double2 = clblast::double2;
// Function to tune a specific variation V (not within the clblast namespace)
template <int V>
void StartVariation(int argc, char *argv[]) {
switch(clblast::GetPrecision(argc, argv)) {
case clblast::Precision::kHalf: clblast::Tuner<clblast::TuneXgemmDirect<half,V>, half>(argc, argv); break;
case clblast::Precision::kSingle: clblast::Tuner<clblast::TuneXgemmDirect<float,V>, float>(argc, argv); break;
case clblast::Precision::kDouble: clblast::Tuner<clblast::TuneXgemmDirect<double,V>, double>(argc, argv); break;
case clblast::Precision::kComplexSingle: clblast::Tuner<clblast::TuneXgemmDirect<float2,V>, float2>(argc, argv); break;
case clblast::Precision::kComplexDouble: clblast::Tuner<clblast::TuneXgemmDirect<double2,V>, double2>(argc, argv); break;
}
}
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
StartVariation<1>(argc, argv);
StartVariation<2>(argc, argv);
return 0;
}
// =================================================================================================
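As a footnote to the SetConstraints section of the new tuner, here is a self-contained sketch that restates those constraints as plain predicates and checks one hypothetical configuration (IsMultiple is re-defined locally and the values are illustrative, not tuner output):

// Sketch: the XgemmDirect tuning constraints checked for one hypothetical configuration.
#include <cstdio>

static bool IsMultiple(const int a, const int b) { return (a % b) == 0; }

int main() {
  const int WGD = 32, MDIMCD = 8, NDIMCD = 8, MDIMAD = 8, NDIMBD = 8;
  const int KWID = 2, VWMD = 2, VWND = 2;
  const bool valid =
      IsMultiple(WGD, KWID) &&                        // the WGD loop can be unrolled by KWID
      IsMultiple(WGD, MDIMCD * VWMD) &&               // integer (vectorized) MWID
      IsMultiple(WGD, NDIMCD * VWND) &&               // integer (vectorized) NWID
      IsMultiple(WGD, MDIMAD * VWMD) &&               // integer (vectorized) loads for A
      IsMultiple(WGD, NDIMBD * VWND) &&               // integer (vectorized) loads for B
      IsMultiple(WGD, (MDIMCD * NDIMCD) / MDIMAD) &&  // WGD is a multiple of KDIMAD
      IsMultiple(WGD, (MDIMCD * NDIMCD) / NDIMBD);    // WGD is a multiple of KDIMBD
  std::printf("configuration %s\n", valid ? "accepted" : "rejected");
  return 0;
}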