Added a lower/upper triangular version of the GEMM kernel

pull/8/head
CNugteren 2015-06-23 17:58:51 +02:00
parent 0a3831e6d1
commit 9fc38cdf5e
1 changed file with 244 additions and 135 deletions

View File

@ -127,6 +127,55 @@ R"(
// =================================================================================================
// Zeroes this thread's accumulation registers before the multiply-accumulate loop starts.
// The accumulator is a 2D register array of vectorised values; each vector component is
// cleared individually through the SetToZero macro (which handles complex types as well).
inline void InitAccRegisters(realM cpm[NWI][MWI/VWM]) {
  #pragma unroll
  for (int n=0; n<NWI; ++n) {
    #pragma unroll
    for (int m=0; m<MWI/VWM; ++m) {
      // Expand per vector width: realM has 1, 2, 4, 8 or 16 components
      #if VWM == 1
        SetToZero(cpm[n][m]);
      #elif VWM == 2
        SetToZero(cpm[n][m].x);
        SetToZero(cpm[n][m].y);
      #elif VWM == 4
        SetToZero(cpm[n][m].x);
        SetToZero(cpm[n][m].y);
        SetToZero(cpm[n][m].z);
        SetToZero(cpm[n][m].w);
      #elif VWM == 8
        SetToZero(cpm[n][m].s0);
        SetToZero(cpm[n][m].s1);
        SetToZero(cpm[n][m].s2);
        SetToZero(cpm[n][m].s3);
        SetToZero(cpm[n][m].s4);
        SetToZero(cpm[n][m].s5);
        SetToZero(cpm[n][m].s6);
        SetToZero(cpm[n][m].s7);
      #elif VWM == 16
        SetToZero(cpm[n][m].s0);
        SetToZero(cpm[n][m].s1);
        SetToZero(cpm[n][m].s2);
        SetToZero(cpm[n][m].s3);
        SetToZero(cpm[n][m].s4);
        SetToZero(cpm[n][m].s5);
        SetToZero(cpm[n][m].s6);
        SetToZero(cpm[n][m].s7);
        SetToZero(cpm[n][m].s8);
        SetToZero(cpm[n][m].s9);
        SetToZero(cpm[n][m].sA);
        SetToZero(cpm[n][m].sB);
        SetToZero(cpm[n][m].sC);
        SetToZero(cpm[n][m].sD);
        SetToZero(cpm[n][m].sE);
        SetToZero(cpm[n][m].sF);
      #endif
    }
  }
}
// =================================================================================================
// Caches global off-chip memory into local (shared) memory on-chip. This function is specific for
// caching the A input matrix.
#if SA == 1
@ -272,71 +321,6 @@ inline void LocalToPrivateB(__local realN* blm, realN bpm[NWI/VWN], const int kg
// =================================================================================================
// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication
// with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm
inline void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int kSizeM,
const real alpha, const real beta) {
#pragma unroll
for (int ni=0; ni<NWI; ++ni) {
#pragma unroll
for (int mi=0; mi<MWI/VWM; ++mi) {
// Position of this vector inside the workgroup tile; STRM/STRN select either a
// contiguous per-thread layout (0) or a strided/interleaved layout (1)
#if STRM == 0
int mg = mi + get_local_id(0)*(MWI/VWM);
#elif STRM == 1
int mg = get_local_id(0) + mi*MDIMC;
#endif
#if STRN == 0
int ng = ni + get_local_id(1)*NWI;
#elif STRN == 1
int ng = ni%VWN + get_local_id(1)*VWN + (ni/VWN)*VWN*NDIMC;
#endif
// Global coordinates: workgroup offset plus position within the tile
int idm = mg + get_group_id(0)*(MWG/VWM);
int idn = ng + get_group_id(1)*NWG;
// Linear index into C: the M dimension is vectorised by VWM, so the leading
// dimension of cgm is kSizeM/VWM (assumes kSizeM is a multiple of VWM)
int index = idn*(kSizeM/VWM) + idm;
// Read-modify-write: AXPBY computes alpha*cpm + beta*cval per component
realM cval = cgm[index];
#if VWM == 1
AXPBY(cgm[index], alpha, cpm[ni][mi], beta, cval);
#elif VWM == 2
AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x);
AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y);
#elif VWM == 4
AXPBY(cgm[index].x, alpha, cpm[ni][mi].x, beta, cval.x);
AXPBY(cgm[index].y, alpha, cpm[ni][mi].y, beta, cval.y);
AXPBY(cgm[index].z, alpha, cpm[ni][mi].z, beta, cval.z);
AXPBY(cgm[index].w, alpha, cpm[ni][mi].w, beta, cval.w);
#elif VWM == 8
AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0);
AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1);
AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2);
AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3);
AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4);
AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5);
AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6);
AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7);
#elif VWM == 16
AXPBY(cgm[index].s0, alpha, cpm[ni][mi].s0, beta, cval.s0);
AXPBY(cgm[index].s1, alpha, cpm[ni][mi].s1, beta, cval.s1);
AXPBY(cgm[index].s2, alpha, cpm[ni][mi].s2, beta, cval.s2);
AXPBY(cgm[index].s3, alpha, cpm[ni][mi].s3, beta, cval.s3);
AXPBY(cgm[index].s4, alpha, cpm[ni][mi].s4, beta, cval.s4);
AXPBY(cgm[index].s5, alpha, cpm[ni][mi].s5, beta, cval.s5);
AXPBY(cgm[index].s6, alpha, cpm[ni][mi].s6, beta, cval.s6);
AXPBY(cgm[index].s7, alpha, cpm[ni][mi].s7, beta, cval.s7);
AXPBY(cgm[index].s8, alpha, cpm[ni][mi].s8, beta, cval.s8);
AXPBY(cgm[index].s9, alpha, cpm[ni][mi].s9, beta, cval.s9);
AXPBY(cgm[index].sA, alpha, cpm[ni][mi].sA, beta, cval.sA);
AXPBY(cgm[index].sB, alpha, cpm[ni][mi].sB, beta, cval.sB);
AXPBY(cgm[index].sC, alpha, cpm[ni][mi].sC, beta, cval.sC);
AXPBY(cgm[index].sD, alpha, cpm[ni][mi].sD, beta, cval.sD);
AXPBY(cgm[index].sE, alpha, cpm[ni][mi].sE, beta, cval.sE);
AXPBY(cgm[index].sF, alpha, cpm[ni][mi].sF, beta, cval.sF);
#endif
}
}
}
// =================================================================================================
// The vectorised multiply-add function
inline realM MultiplyAddVector(realM cvec, const realM avec, const real bval) {
#if USE_VECTOR_MAD == 1
@ -432,77 +416,97 @@ inline void MultiplyAccumulate(realM cpm[NWI][MWI/VWM], realM apm[MWI/VWM], real
// =================================================================================================
// Main entry of the kernel. This function contains the basic skeleton, the functionality is
// provided by the inlined functions above
__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
__kernel void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
const real alpha, const real beta,
const __global realM* restrict agm,
const __global realN* restrict bgm,
__global realM* cgm) {
// Writes the per-thread accumulation registers Cpm back to the global array Cgm,
// applying the scaling factors on the way out: Cgm = alpha*Cpm + beta*Cgm.
inline void StoreResults(__global realM* cgm, realM cpm[NWI][MWI/VWM], const int kSizeM,
                         const real alpha, const real beta) {
  #pragma unroll
  for (int mi=0; mi<MWI/VWM; ++mi) {
    // Position of this vector inside the workgroup tile (M dimension); the STRM
    // setting selects a contiguous (0) or strided (1) per-thread layout
    #if STRM == 0
      int m_local = mi + get_local_id(0)*(MWI/VWM);
    #elif STRM == 1
      int m_local = get_local_id(0) + mi*MDIMC;
    #endif
    #pragma unroll
    for (int ni=0; ni<NWI; ++ni) {
      // Position inside the workgroup tile (N dimension)
      #if STRN == 0
        int n_local = ni + get_local_id(1)*NWI;
      #elif STRN == 1
        int n_local = ni%VWN + get_local_id(1)*VWN + (ni/VWN)*VWN*NDIMC;
      #endif
      // Global coordinates and the linear index into C; the M dimension is
      // vectorised by VWM, so the leading dimension of cgm is kSizeM/VWM
      int m_global = m_local + get_group_id(0)*(MWG/VWM);
      int n_global = n_local + get_group_id(1)*NWG;
      int c_index = n_global*(kSizeM/VWM) + m_global;
      // Read-modify-write: AXPBY computes alpha*cpm + beta*c_old per component
      realM c_old = cgm[c_index];
      #if VWM == 1
        AXPBY(cgm[c_index], alpha, cpm[ni][mi], beta, c_old);
      #elif VWM == 2
        AXPBY(cgm[c_index].x, alpha, cpm[ni][mi].x, beta, c_old.x);
        AXPBY(cgm[c_index].y, alpha, cpm[ni][mi].y, beta, c_old.y);
      #elif VWM == 4
        AXPBY(cgm[c_index].x, alpha, cpm[ni][mi].x, beta, c_old.x);
        AXPBY(cgm[c_index].y, alpha, cpm[ni][mi].y, beta, c_old.y);
        AXPBY(cgm[c_index].z, alpha, cpm[ni][mi].z, beta, c_old.z);
        AXPBY(cgm[c_index].w, alpha, cpm[ni][mi].w, beta, c_old.w);
      #elif VWM == 8
        AXPBY(cgm[c_index].s0, alpha, cpm[ni][mi].s0, beta, c_old.s0);
        AXPBY(cgm[c_index].s1, alpha, cpm[ni][mi].s1, beta, c_old.s1);
        AXPBY(cgm[c_index].s2, alpha, cpm[ni][mi].s2, beta, c_old.s2);
        AXPBY(cgm[c_index].s3, alpha, cpm[ni][mi].s3, beta, c_old.s3);
        AXPBY(cgm[c_index].s4, alpha, cpm[ni][mi].s4, beta, c_old.s4);
        AXPBY(cgm[c_index].s5, alpha, cpm[ni][mi].s5, beta, c_old.s5);
        AXPBY(cgm[c_index].s6, alpha, cpm[ni][mi].s6, beta, c_old.s6);
        AXPBY(cgm[c_index].s7, alpha, cpm[ni][mi].s7, beta, c_old.s7);
      #elif VWM == 16
        AXPBY(cgm[c_index].s0, alpha, cpm[ni][mi].s0, beta, c_old.s0);
        AXPBY(cgm[c_index].s1, alpha, cpm[ni][mi].s1, beta, c_old.s1);
        AXPBY(cgm[c_index].s2, alpha, cpm[ni][mi].s2, beta, c_old.s2);
        AXPBY(cgm[c_index].s3, alpha, cpm[ni][mi].s3, beta, c_old.s3);
        AXPBY(cgm[c_index].s4, alpha, cpm[ni][mi].s4, beta, c_old.s4);
        AXPBY(cgm[c_index].s5, alpha, cpm[ni][mi].s5, beta, c_old.s5);
        AXPBY(cgm[c_index].s6, alpha, cpm[ni][mi].s6, beta, c_old.s6);
        AXPBY(cgm[c_index].s7, alpha, cpm[ni][mi].s7, beta, c_old.s7);
        AXPBY(cgm[c_index].s8, alpha, cpm[ni][mi].s8, beta, c_old.s8);
        AXPBY(cgm[c_index].s9, alpha, cpm[ni][mi].s9, beta, c_old.s9);
        AXPBY(cgm[c_index].sA, alpha, cpm[ni][mi].sA, beta, c_old.sA);
        AXPBY(cgm[c_index].sB, alpha, cpm[ni][mi].sB, beta, c_old.sB);
        AXPBY(cgm[c_index].sC, alpha, cpm[ni][mi].sC, beta, c_old.sC);
        AXPBY(cgm[c_index].sD, alpha, cpm[ni][mi].sD, beta, c_old.sD);
        AXPBY(cgm[c_index].sE, alpha, cpm[ni][mi].sE, beta, c_old.sE);
        AXPBY(cgm[c_index].sF, alpha, cpm[ni][mi].sF, beta, c_old.sF);
      #endif
    }
  }
}
// =================================================================================================
// Main body of the matrix-multiplication algorithm. It calls the (inlined) functions above.
inline void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
const __global realM* restrict agm, const __global realN* restrict bgm,
__global realM* cgm, realM cpm[NWI][MWI/VWM],
#if SA == 1 && SB == 1
__local realM* alm, __local realN* blm
#elif SA == 1
__local realM* alm
#elif SB == 1
__local realN* blm
#endif
) {
// Allocates workitem-private memory (registers)
realM apm[MWI/VWM];
realN bpm[NWI/VWN];
// Combined thread identifier (volatile to disable caching)
#if SA == 1 || SB == 1
volatile int tid = get_local_id(0) + MDIMC*get_local_id(1);
#endif
// Allocates workgroup-private memory (local memory)
#if SA == 1
__local realM alm[KWG * MWG/VWM];
#endif
#if SB == 1
__local realN blm[KWG * NWG/VWN];
#endif
// Allocates workitem-private memory (registers)
realM apm[MWI/VWM];
realN bpm[NWI/VWN];
realM cpm[NWI][MWI/VWM];
// Initializes the accumulation registers
#pragma unroll
for (int mi=0; mi<MWI/VWM; ++mi) {
#pragma unroll
for (int ni=0; ni<NWI; ++ni) {
#if VWM == 1
SetToZero(cpm[ni][mi]);
#elif VWM == 2
SetToZero(cpm[ni][mi].x);
SetToZero(cpm[ni][mi].y);
#elif VWM == 4
SetToZero(cpm[ni][mi].x);
SetToZero(cpm[ni][mi].y);
SetToZero(cpm[ni][mi].z);
SetToZero(cpm[ni][mi].w);
#elif VWM == 8
SetToZero(cpm[ni][mi].s0);
SetToZero(cpm[ni][mi].s1);
SetToZero(cpm[ni][mi].s2);
SetToZero(cpm[ni][mi].s3);
SetToZero(cpm[ni][mi].s4);
SetToZero(cpm[ni][mi].s5);
SetToZero(cpm[ni][mi].s6);
SetToZero(cpm[ni][mi].s7);
#elif VWM == 16
SetToZero(cpm[ni][mi].s0);
SetToZero(cpm[ni][mi].s1);
SetToZero(cpm[ni][mi].s2);
SetToZero(cpm[ni][mi].s3);
SetToZero(cpm[ni][mi].s4);
SetToZero(cpm[ni][mi].s5);
SetToZero(cpm[ni][mi].s6);
SetToZero(cpm[ni][mi].s7);
SetToZero(cpm[ni][mi].s8);
SetToZero(cpm[ni][mi].s9);
SetToZero(cpm[ni][mi].sA);
SetToZero(cpm[ni][mi].sB);
SetToZero(cpm[ni][mi].sC);
SetToZero(cpm[ni][mi].sD);
SetToZero(cpm[ni][mi].sE);
SetToZero(cpm[ni][mi].sF);
#endif
}
}
InitAccRegisters(cpm);
// Loops over all workgroup tiles
for (int kwg=0; kwg<kSizeK; kwg+=KWG) {
@ -515,8 +519,6 @@ __kernel void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
#if SB == 1
GlobalToLocalB(bgm, blm, kSizeN, tid, kwg);
#endif
// Synchronizes all threads in a workgroup
#if SA == 1 || SB == 1
barrier(CLK_LOCAL_MEM_FENCE);
#endif
@ -552,19 +554,126 @@ __kernel void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
MultiplyAccumulate(cpm, apm, bpm);
}
}
// Synchronizes all threads in a workgroup
#if SA == 1 || SB == 1
barrier(CLK_LOCAL_MEM_FENCE);
#endif
}
}
// Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
// =================================================================================================
// Main entry point of the kernel: the regular (full-matrix) GEMM version.
__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
__kernel void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
                    const real alpha, const real beta,
                    const __global realM* restrict agm,
                    const __global realN* restrict bgm,
                    __global realM* cgm) {
  // Workgroup-local (shared) memory for the cached tiles of A and/or B,
  // only allocated when the corresponding caching option is enabled
  #if SA == 1
    __local realM alm[KWG * MWG/VWM];
  #endif
  #if SB == 1
    __local realN blm[KWG * NWG/VWN];
  #endif
  // Per-thread accumulators holding this thread's part of the C tile
  realM c_acc[NWI][MWI/VWM];
  // Runs the matrix-multiplication; the argument list depends on which inputs are cached
  #if SA == 1 && SB == 1
    XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, c_acc, alm, blm);
  #elif SA == 1
    XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, c_acc, alm);
  #elif SB == 1
    XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, c_acc, blm);
  #else
    XgemmBody(kSizeM, kSizeN, kSizeK, agm, bgm, cgm, c_acc);
  #endif
  // Writes the MWG * NWG result tile back, scaled by alpha and beta
  StoreResults(cgm, c_acc, kSizeM, alpha, beta);
}
// =================================================================================================
// Main entry point of the kernel: the upper-triangular version. C is square (N by N).
__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
__kernel void XgemmUpper(const int kSizeN, const int kSizeK,
                         const real alpha, const real beta,
                         const __global realM* restrict agm,
                         const __global realN* restrict bgm,
                         __global realM* cgm) {
  // Early-out for workgroups whose tile starts below the diagonal and thus cannot
  // contribute to the upper triangle (tiles straddling the diagonal are computed in full)
  if (get_group_id(1)*NWG < get_group_id(0)*MWG) {
    return;
  }
  // Workgroup-local (shared) memory for the cached tiles of A and/or B
  #if SA == 1
    __local realM alm[KWG * MWG/VWM];
  #endif
  #if SB == 1
    __local realN blm[KWG * NWG/VWN];
  #endif
  // Per-thread accumulators holding this thread's part of the C tile
  realM c_acc[NWI][MWI/VWM];
  // Runs the matrix-multiplication on the square N-by-N problem; the argument
  // list depends on which inputs are cached locally
  #if SA == 1 && SB == 1
    XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, c_acc, alm, blm);
  #elif SA == 1
    XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, c_acc, alm);
  #elif SB == 1
    XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, c_acc, blm);
  #else
    XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, c_acc);
  #endif
  // Writes the MWG * NWG result tile back, scaled by alpha and beta
  StoreResults(cgm, c_acc, kSizeN, alpha, beta);
}
// =================================================================================================
// Main entry point of the kernel: the lower-triangular version. C is square (N by N).
__attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
__kernel void XgemmLower(const int kSizeN, const int kSizeK,
                         const real alpha, const real beta,
                         const __global realM* restrict agm,
                         const __global realN* restrict bgm,
                         __global realM* cgm) {
  // Early-out for workgroups whose tile starts above the diagonal and thus cannot
  // contribute to the lower triangle (tiles straddling the diagonal are computed in full)
  if (get_group_id(1)*NWG > get_group_id(0)*MWG) {
    return;
  }
  // Workgroup-local (shared) memory for the cached tiles of A and/or B
  #if SA == 1
    __local realM alm[KWG * MWG/VWM];
  #endif
  #if SB == 1
    __local realN blm[KWG * NWG/VWN];
  #endif
  // Per-thread accumulators holding this thread's part of the C tile
  realM c_acc[NWI][MWI/VWM];
  // Runs the matrix-multiplication on the square N-by-N problem; the argument
  // list depends on which inputs are cached locally
  #if SA == 1 && SB == 1
    XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, c_acc, alm, blm);
  #elif SA == 1
    XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, c_acc, alm);
  #elif SB == 1
    XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, c_acc, blm);
  #else
    XgemmBody(kSizeN, kSizeN, kSizeK, agm, bgm, cgm, c_acc);
  #endif
  // Writes the MWG * NWG result tile back, scaled by alpha and beta
  StoreResults(cgm, c_acc, kSizeN, alpha, beta);
}
// =================================================================================================
// End of the C++11 raw string literal
)";