Specialised the GEMM direct kernel in four ways for transposing/non-transposing: NN, NT, TN, TT

This commit is contained in:
Cedric Nugteren 2016-10-02 17:59:05 +02:00
parent 61f489e370
commit d8827e908c
3 changed files with 82 additions and 26 deletions

View file

@ -51,16 +51,16 @@ inline void StoreResultsDirect(__global real* cgm, real cpm[NWID][MWID],
// =================================================================================================
// Main entry point of the kernel. This is the direct version without restrictions.
__attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
__kernel void XgemmDirect(const int kSizeM, const int kSizeN, const int kSizeK,
const real_arg arg_alpha,
const real_arg arg_beta,
const __global realMD* restrict agm, const int a_offset, const int a_ld,
const __global realND* restrict bgm, const int b_offset, const int b_ld,
__global real* cgm, const int c_offset, const int c_ld,
const int a_transpose, const int b_transpose, const int c_transpose,
const int a_conjugate, const int b_conjugate) {
// Main body of the kernel. This is the direct version without restrictions.
inline void XgemmDirect(const int kSizeM, const int kSizeN, const int kSizeK,
const real_arg arg_alpha,
const real_arg arg_beta,
const __global realMD* restrict agm, const int a_offset, const int a_ld,
const __global realND* restrict bgm, const int b_offset, const int b_ld,
__global real* cgm, const int c_offset, const int c_ld,
__local real* alm, __local real* blm,
const int a_transpose, const int b_transpose, const int c_transpose,
const int a_conjugate, const int b_conjugate) {
const real alpha = GetRealArg(arg_alpha);
const real beta = GetRealArg(arg_beta);
@ -68,10 +68,6 @@ __kernel void XgemmDirect(const int kSizeM, const int kSizeN, const int kSizeK,
const __global real* restrict agms = (const __global real* restrict) agm;
const __global real* restrict bgms = (const __global real* restrict) bgm;
// Allocates workgroup-private memory (local memory)
__local real alm[WGD * (WGD + PADA)];
__local real blm[WGD * (WGD + PADB)];
// Allocates workitem-private memory (registers)
real apm[MWID];
real bpm[NWID];
@ -201,6 +197,68 @@ __kernel void XgemmDirect(const int kSizeM, const int kSizeN, const int kSizeK,
// =================================================================================================
// Direct version of the GEMM kernel with [A, B] = [non-transposed, non-transposed]
__attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
__kernel void XgemmDirectNN(const int kSizeM, const int kSizeN, const int kSizeK,
const real_arg arg_alpha, const real_arg arg_beta,
const __global realMD* restrict agm, const int a_offset, const int a_ld,
const __global realND* restrict bgm, const int b_offset, const int b_ld,
__global real* cgm, const int c_offset, const int c_ld,
const int c_transpose, const int a_conjugate, const int b_conjugate) {
__local real alm[WGD * (WGD + PADA)];
__local real blm[WGD * (WGD + PADB)];
XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
agm, a_offset, a_ld, bgm, b_offset, b_ld, cgm, c_offset, c_ld,
alm, blm, 0, 0, c_transpose, a_conjugate, b_conjugate);
}
// Direct version of the GEMM kernel with [A, B] = [non-transposed, transposed]
__attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
__kernel void XgemmDirectNT(const int kSizeM, const int kSizeN, const int kSizeK,
const real_arg arg_alpha, const real_arg arg_beta,
const __global realMD* restrict agm, const int a_offset, const int a_ld,
const __global realND* restrict bgm, const int b_offset, const int b_ld,
__global real* cgm, const int c_offset, const int c_ld,
const int c_transpose, const int a_conjugate, const int b_conjugate) {
__local real alm[WGD * (WGD + PADA)];
__local real blm[WGD * (WGD + PADB)];
XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
agm, a_offset, a_ld, bgm, b_offset, b_ld, cgm, c_offset, c_ld,
alm, blm, 0, 1, c_transpose, a_conjugate, b_conjugate);
}
// Direct version of the GEMM kernel with [A, B] = [transposed, non-transposed]
__attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
__kernel void XgemmDirectTN(const int kSizeM, const int kSizeN, const int kSizeK,
const real_arg arg_alpha, const real_arg arg_beta,
const __global realMD* restrict agm, const int a_offset, const int a_ld,
const __global realND* restrict bgm, const int b_offset, const int b_ld,
__global real* cgm, const int c_offset, const int c_ld,
const int c_transpose, const int a_conjugate, const int b_conjugate) {
__local real alm[WGD * (WGD + PADA)];
__local real blm[WGD * (WGD + PADB)];
XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
agm, a_offset, a_ld, bgm, b_offset, b_ld, cgm, c_offset, c_ld,
alm, blm, 1, 0, c_transpose, a_conjugate, b_conjugate);
}
// Direct version of the GEMM kernel with [A, B] = [transposed, transposed]
__attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
__kernel void XgemmDirectTT(const int kSizeM, const int kSizeN, const int kSizeK,
const real_arg arg_alpha, const real_arg arg_beta,
const __global realMD* restrict agm, const int a_offset, const int a_ld,
const __global realND* restrict bgm, const int b_offset, const int b_ld,
__global real* cgm, const int c_offset, const int c_ld,
const int c_transpose, const int a_conjugate, const int b_conjugate) {
__local real alm[WGD * (WGD + PADA)];
__local real blm[WGD * (WGD + PADB)];
XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
agm, a_offset, a_ld, bgm, b_offset, b_ld, cgm, c_offset, c_ld,
alm, blm, 1, 1, c_transpose, a_conjugate, b_conjugate);
}
// =================================================================================================
// End of the C++11 raw string literal
)"

View file

@ -275,9 +275,11 @@ StatusCode Xgemm<T>::GemmDirect(const size_t m, const size_t n, const size_t k,
// Loads the program from the database
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
// Retrieves the XgemmDirect kernel from the compiled binary
// Retrieves the proper XgemmDirect kernel from the compiled binary
try {
auto kernel = Kernel(program, "XgemmDirect");
const auto name = (a_do_transpose) ? (b_do_transpose ? "XgemmDirectTT" : "XgemmDirectTN") :
(b_do_transpose ? "XgemmDirectNT" : "XgemmDirectNN");
auto kernel = Kernel(program, name);
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(m));
@ -294,11 +296,9 @@ StatusCode Xgemm<T>::GemmDirect(const size_t m, const size_t n, const size_t k,
kernel.SetArgument(11, c_buffer());
kernel.SetArgument(12, static_cast<int>(c_offset));
kernel.SetArgument(13, static_cast<int>(c_ld));
kernel.SetArgument(14, static_cast<int>(a_do_transpose));
kernel.SetArgument(15, static_cast<int>(b_do_transpose));
kernel.SetArgument(16, static_cast<int>(c_do_transpose));
kernel.SetArgument(17, static_cast<int>(a_conjugate));
kernel.SetArgument(18, static_cast<int>(b_conjugate));
kernel.SetArgument(14, static_cast<int>(c_do_transpose));
kernel.SetArgument(15, static_cast<int>(a_conjugate));
kernel.SetArgument(16, static_cast<int>(b_conjugate));
// Computes the global and local thread sizes
const auto m_ceiled = Ceil(m, db_["WGD"]);

View file

@ -29,7 +29,7 @@ class TuneXgemmDirect {
// The representative kernel and the source code
static std::string KernelFamily() { return (V==1) ? "xgemm_direct_1" : "xgemm_direct_2"; }
static std::string KernelName() { return "XgemmDirect"; }
static std::string KernelName() { return "XgemmDirectTN"; }
static std::string GetSources() {
return
#include "../src/kernels/common.opencl"
@ -50,8 +50,8 @@ class TuneXgemmDirect {
static size_t DefaultM() { return 256; }
static size_t DefaultN() { return 256; }
static size_t DefaultK() { return 256; }
static double DefaultFraction() { return (V==1) ? 1.0 : 16.0; } // test all or sample randomly
static size_t DefaultNumRuns() { return 10; } // run every kernel this many times for averaging
static double DefaultFraction() { return (V==1) ? 1.0 : 32.0; } // test all or sample randomly
static size_t DefaultNumRuns() { return 4; } // run every kernel this many times for averaging
// Describes how to obtain the sizes of the buffers
static size_t GetSizeX(const Arguments<T> &) { return 1; } // N/A for this kernel
@ -154,8 +154,6 @@ class TuneXgemmDirect {
tuner.AddArgumentOutput(c_mat);
tuner.AddArgumentScalar(0); // c_offset
tuner.AddArgumentScalar(static_cast<int>(args.n)); // c_ld
tuner.AddArgumentScalar(1); // a_do_transpose
tuner.AddArgumentScalar(0); // b_do_transpose
tuner.AddArgumentScalar(1); // c_do_transpose
tuner.AddArgumentScalar(0); // a_conjugate
tuner.AddArgumentScalar(0); // b_conjugate