Created a dedicated convgemm GEMM kernel as a copy of the batched direct gemm kernel

pull/319/head
Cedric Nugteren 2018-05-13 22:10:06 +02:00
parent ad8f1027ab
commit 0cb9580042
3 changed files with 250 additions and 49 deletions
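
For context, the new Xconvgemm kernel performs one GEMM per batch: the im2col ("col") matrix of size num_patches x patch_size is multiplied by the kernel-weights matrix of size patch_size x num_kernels to produce a num_patches x num_kernels result, with alpha = 1 and beta = 0. Below is a minimal host-side reference sketch of that computation, assuming the column-major layout, leading dimensions and strides set up in the host code further down this page; the function name and signature are illustrative only, not part of the commit.

#include <cstddef>

// Reference loop for the per-batch GEMM computed by Xconvgemm (illustrative sketch only).
// A = col (num_patches x patch_size), B = kernel weights (patch_size x num_kernels),
// C = result (num_patches x num_kernels); all column-major, alpha = 1, beta = 0.
void ConvgemmReference(const size_t num_patches, const size_t num_kernels, const size_t patch_size,
                       const size_t batch_count,
                       const float* col, const size_t col_stride,        // ld = num_patches
                       const float* kernel, const size_t kernel_offset,  // ld = patch_size, shared by all batches
                       float* result, const size_t result_stride) {      // ld = num_patches
  for (size_t batch = 0; batch < batch_count; ++batch) {
    for (size_t n = 0; n < num_kernels; ++n) {
      for (size_t m = 0; m < num_patches; ++m) {
        float acc = 0.0f;
        for (size_t k = 0; k < patch_size; ++k) {
          acc += col[batch * col_stride + m + k * num_patches] *
                 kernel[kernel_offset + k + n * patch_size];
        }
        result[batch * result_stride + m + n * num_patches] = acc;
      }
    }
  }
}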


@@ -0,0 +1,228 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file contains an implementation of 3D convolution on a 4D image using GEMM kernels. It
// uses parameters from the direct GEMM kernel.
//
// =================================================================================================
// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
// literal). Comment-out this line for syntax-highlighting when developing.
R"(
// =================================================================================================
#if defined(ROUTINE_CONVGEMM)
// ConvGEMM kernel
__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
void Xconvgemm(const int num_patches, const int num_kernels, const int patch_size,
const __global realMD* restrict colgm, const int col_offset, const int col_stride,
const __global realND* restrict kernelgm, const int kernel_offset,
__global real* resultgm, const int result_offset, const int result_stride) {
// Batch offsets
const int batch = get_group_id(2);
const int col_offset_batch = col_offset + col_stride * batch;
const int result_offset_batch = result_offset + result_stride * batch;
__local real alm[WGD * (WGD + PADA)];
__local real blm[WGD * (WGD + PADB)];
// Extra pointers to scalar versions of global memory
const __global real* restrict colgms = (const __global real* restrict) colgm;
const __global real* restrict kernelgms = (const __global real* restrict) kernelgm;
// Allocates workitem-private memory (registers)
#pragma promote_to_registers
real apd[MWID];
#pragma promote_to_registers
real bpd[NWID];
#pragma promote_to_registers
real cpd[NWID * MWID];
// Initializes the accumulation registers
#pragma unroll
for (int _mi = 0; _mi < MWID; _mi += 1) {
#pragma unroll
for (int _ni = 0; _ni < NWID; _ni += 1) {
SetToZero(cpd[_ni * MWID + _mi]);
}
}
// The faster version of GEMM is not allowed on the (incomplete) borders. Therefore, this section
// processes only the main parts: output blocks of WGD by WGD.
const int idm = get_local_id(0) * MWID + GetGroupID0() * WGD;
const int idn = get_local_id(1) * NWID + GetGroupID1() * WGD;
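// Illustrative note (not part of this commit): the branch below covers only complete WGD x WGD
// output blocks. For example, with WGD == 32 and num_patches == 100, work-groups whose idm
// falls below (100/32)*32 == 96 take this fast path; the remaining 4 patches are handled by
// the checked else-branch further down.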
if ((idm < (num_patches/WGD)*WGD) && (idn < (num_kernels/WGD)*WGD)) {
// Loops over all complete workgroup tiles (K-dimension)
int kwg = 0;
for (; kwg < (patch_size/WGD) * WGD; kwg += WGD) {
// Loads data: off-chip --> local (matrix A and B)
if (num_patches % VWMD == 0 && col_offset_batch % VWMD == 0) {
GlobalToLocalDirectA(colgm, alm, num_patches, col_offset_batch, kwg, false, false);
}
else {
GlobalToLocalScalarA(colgms, alm, num_patches, col_offset_batch, kwg, false, false);
}
if (patch_size % VWND == 0 && kernel_offset % VWND == 0) {
GlobalToLocalDirectB(kernelgm, blm, patch_size, kernel_offset, kwg, true, false);
}
else {
GlobalToLocalScalarB(kernelgms, blm, patch_size, kernel_offset, kwg, true, false);
}
barrier(CLK_LOCAL_MEM_FENCE);
// Loops over all workitem tiles, unrolled by a factor KWID
for (int pwi = 0; pwi < WGD; pwi += KWID) {
#pragma unroll
for (int _pit = 0; _pit < KWID; _pit += 1) {
int kg = pwi + _pit;
// Loads data: local --> private (matrix A and B)
#pragma unroll
for (int _mi = 0; _mi < MWID; _mi += 1) {
apd[_mi] = LocalToPrivateDirectA(alm, _mi, kg, false);
}
#pragma unroll
for (int _ni = 0; _ni < NWID; _ni += 1) {
bpd[_ni] = LocalToPrivateDirectB(blm, _ni, kg, true);
}
// Performs the accumulation (Cpmd += Apmd * Bpmd)
#pragma unroll
for (int _ni = 0; _ni < NWID; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWID; _mi += 1) {
MultiplyAdd(cpd[_ni * MWID + _mi], apd[_mi], bpd[_ni]);
}
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
// Loop over the remaining part (incomplete tile in K-dimension)
for (; kwg < patch_size; ++kwg) {
// Loads data: off-chip --> private (matrix A and B)
#pragma unroll
for (int _mi = 0; _mi < MWID; _mi += 1) {
apd[_mi] = GlobalToPrivateDirectA(colgms, _mi, num_patches, col_offset_batch, idm, kwg, false, false);
}
#pragma unroll
for (int _ni = 0; _ni < NWID; _ni += 1) {
bpd[_ni] = GlobalToPrivateDirectB(kernelgms, _ni, patch_size, kernel_offset, idn, kwg, true, false);
}
// Performs the accumulation (Cpmd += Apmd * Bpmd)
#pragma unroll
for (int _ni = 0; _ni < NWID; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWID; _mi += 1) {
MultiplyAdd(cpd[_ni * MWID + _mi], apd[_mi], bpd[_ni]);
}
}
}
// Stores a tile of results
#pragma unroll
for (int _ni = 0; _ni < NWID; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWID; _mi += 1) {
StoreResultsDirect(resultgm, cpd[_ni * MWID + _mi], _mi, _ni, idm, idn,
ONE, ZERO, num_patches, result_offset_batch, false);
}
}
}
// Simple but slower version for the parts on the edge (incomplete tiles in M and N-dimensions)
else {
// Loops over all complete workgroup tiles (K-dimension)
int kwg = 0;
for (; kwg < (patch_size/WGD) * WGD; kwg+=WGD) {
// Loads data: off-chip --> local (matrix A and B)
GlobalToLocalCheckedA(colgms, alm, num_patches, col_offset_batch, kwg, false, false, num_patches, patch_size);
GlobalToLocalCheckedB(kernelgms, blm, patch_size, kernel_offset, kwg, true, false, num_kernels, patch_size);
barrier(CLK_LOCAL_MEM_FENCE);
// Loops over all workitem tiles, unrolled by a factor KWID
for (int pwi = 0; pwi < WGD; pwi += KWID) {
#pragma unroll
for (int _pit = 0; _pit < KWID; _pit += 1) {
int kg = pwi + _pit;
// Loads data: local --> private (matrix A and B)
#pragma unroll
for (int _mi = 0; _mi < MWID; _mi += 1) {
apd[_mi] = LocalToPrivateDirectA(alm, _mi, kg, false);
}
#pragma unroll
for (int _ni = 0; _ni < NWID; _ni += 1) {
bpd[_ni] = LocalToPrivateDirectB(blm, _ni, kg, true);
}
// Performs the accumulation (C += A * B)
#pragma unroll
for (int _ni = 0; _ni < NWID; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWID; _mi += 1) {
MultiplyAdd(cpd[_ni * MWID + _mi], apd[_mi], bpd[_ni]);
}
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
// Loop over the remaining part (incomplete tile in K-dimension)
for (; kwg < patch_size; ++kwg) {
// Loads data: off-chip --> private (matrix A and B)
#pragma unroll
for (int _mi = 0; _mi < MWID; _mi += 1) {
apd[_mi] = GlobalToPrivateCheckedA(colgms, _mi, num_patches, col_offset_batch, idm, kwg, false, false, num_patches);
}
#pragma unroll
for (int _ni = 0; _ni < NWID; _ni += 1) {
bpd[_ni] = GlobalToPrivateCheckedB(kernelgms, _ni, patch_size, kernel_offset, idn, kwg, true, false, num_kernels);
}
// Performs the accumulation (C += A * B)
#pragma unroll
for (int _ni = 0; _ni < NWID; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWID; _mi += 1) {
MultiplyAdd(cpd[_ni * MWID + _mi], apd[_mi], bpd[_ni]);
}
}
}
// Stores a tile of results
#pragma unroll
for (int _ni = 0; _ni < NWID; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWID; _mi += 1) {
StoreResultsChecked(resultgm, cpd[_ni * MWID + _mi], _mi, _ni, idm, idn, num_patches, num_kernels,
ONE, ZERO, num_patches, result_offset_batch, false);
}
}
}
}
#endif
// =================================================================================================
// End of the C++11 raw string literal
)"
// =================================================================================================
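
For context, here is a sketch of the launch geometry this kernel presumably expects, pieced together from the reqd_work_group_size attribute above and the global-size computation at the bottom of this commit. The third global dimension (one layer of work-groups per batch, matching get_group_id(2) in the kernel) is an assumption, since that part of the host code is not shown here; the helper name and signature are illustrative only.

#include <cstddef>
#include <vector>

// Global NDRange for Xconvgemm: one MDIMCD x NDIMCD work-group per WGD x WGD output tile,
// with the batch count in the third dimension (assumed; see note above).
std::vector<size_t> ConvgemmGlobalSize(const size_t num_patches, const size_t num_kernels,
                                       const size_t batch_count, const size_t WGD,
                                       const size_t MDIMCD, const size_t NDIMCD) {
  const auto round_up = [](const size_t x, const size_t y) { return ((x + y - 1) / y) * y; };
  const auto m_ceiled = round_up(num_patches, WGD);  // round up to a multiple of WGD
  const auto n_ceiled = round_up(num_kernels, WGD);
  return {(m_ceiled * MDIMCD) / WGD, (n_ceiled * NDIMCD) / WGD, batch_count};
}
// The corresponding local size would be {MDIMCD, NDIMCD, 1}, as required by the
// reqd_work_group_size attribute on the kernel.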


@@ -105,7 +105,7 @@ void XgemmDirectBatchedTT(const int kSizeM, const int kSizeN, const int kSizeK,
#endif
// =================================================================================================
#if defined(ROUTINE_GEMMSTRIDEDBATCHED) || defined(ROUTINE_CONVGEMM)
#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
// Direct version of the strided-batched GEMM kernel with [A, B] = [non-transposed, non-transposed]
__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))


@@ -11,13 +11,12 @@
//
// =================================================================================================
#include "routines/levelx/xconvgemm.hpp"
#include "routines/levelx/xim2col.hpp"
#include "routines/level3/xgemm.hpp"
#include <string>
#include <vector>
#include "routines/levelx/xconvgemm.hpp"
#include "routines/levelx/xim2col.hpp"
namespace clblast {
// =================================================================================================
@@ -32,7 +31,7 @@ Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &nam
#include "../../kernels/level3/xgemm_direct_part2.opencl"
#include "../../kernels/level3/xgemm_direct_part3.opencl"
, // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_direct_batched.opencl"
#include "../../kernels/level3/xconvgemm.opencl"
}) {
}
@@ -93,61 +92,35 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
}
// Strided batched GEMM: C (result) = alpha (1) * A (col) * B (kernel) + beta (0) * C (result)
const auto m = num_patches;
const auto n = num_kernels;
const auto k = patch_size;
const auto col_ld = m;
const auto kernel_ld = k;
const auto result_ld = m;
const auto col_stride = patch_size * num_patches;
const auto kernel_stride = size_t{0}; // applies the same kernel to all batches
const auto result_stride = num_kernels * output_h * output_w;
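// Illustrative note (not part of this commit): col_stride advances by one full im2col matrix
// (patch_size * num_patches elements) per batch, kernel_stride is zero because the same weights
// are re-used for every batch, and result_stride advances by one output matrix per batch
// (num_kernels * output_h * output_w, which presumably equals num_kernels * num_patches).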
// Computes the transpose/conjugate options and sets the a/b/c sizes based on that
bool col_do_transpose, kernel_do_transpose, result_do_transpose, col_conjugate, kernel_conjugate;
size_t col_one, col_two, kernel_one, kernel_two, result_one, result_two;
Xgemm<T>::ProcessArguments(Layout::kColMajor, Transpose::kNo, Transpose::kNo, m, n, k,
col_one, col_two, kernel_one, kernel_two, result_one, result_two,
col_do_transpose, kernel_do_transpose,
result_do_transpose, col_conjugate, kernel_conjugate, 0);
// Tests the matrices for validity
TestMatrixB(patch_size, num_kernels, kernel_buffer, kernel_offset, patch_size);
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
TestMatrixA(col_one, col_two, col_buffer, col_stride * batch, col_ld);
TestMatrixB(kernel_one, kernel_two, kernel_buffer, kernel_offset + kernel_stride * batch, kernel_ld);
TestMatrixC(result_one, result_two, result_buffer, result_offset + result_stride * batch, result_ld);
TestMatrixA(num_patches, patch_size, col_buffer, col_stride * batch, num_patches);
TestMatrixC(num_patches, num_kernels, result_buffer, result_offset + result_stride * batch, num_patches);
}
// Retrieves the proper XgemmDirect kernel from the compiled binary
const auto name = (col_do_transpose) ? (kernel_do_transpose ? "XgemmDirectStridedBatchedTT" : "XgemmDirectStridedBatchedTN") :
(kernel_do_transpose ? "XgemmDirectStridedBatchedNT" : "XgemmDirectStridedBatchedNN");
auto kernel = Kernel(program_, name);
auto kernel = Kernel(program_, "Xconvgemm");
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(m));
kernel.SetArgument(1, static_cast<int>(n));
kernel.SetArgument(2, static_cast<int>(k));
kernel.SetArgument(3, GetRealArg(ConstantOne<T>()));
kernel.SetArgument(4, GetRealArg(ConstantZero<T>()));
kernel.SetArgument(5, col_buffer());
kernel.SetArgument(6, static_cast<int>(0));
kernel.SetArgument(7, static_cast<int>(col_ld));
kernel.SetArgument(8, static_cast<int>(col_stride));
kernel.SetArgument(9, kernel_buffer());
kernel.SetArgument(10, static_cast<int>(kernel_offset));
kernel.SetArgument(11, static_cast<int>(kernel_ld));
kernel.SetArgument(12, static_cast<int>(kernel_stride));
kernel.SetArgument(13, result_buffer());
kernel.SetArgument(14, static_cast<int>(result_offset));
kernel.SetArgument(15, static_cast<int>(result_ld));
kernel.SetArgument(16, static_cast<int>(result_stride));
kernel.SetArgument(17, static_cast<int>(result_do_transpose));
kernel.SetArgument(18, static_cast<int>(false));
kernel.SetArgument(19, static_cast<int>(false));
kernel.SetArgument(0, static_cast<int>(num_patches));
kernel.SetArgument(1, static_cast<int>(num_kernels));
kernel.SetArgument(2, static_cast<int>(patch_size));
kernel.SetArgument(3, col_buffer());
kernel.SetArgument(4, static_cast<int>(0));
kernel.SetArgument(5, static_cast<int>(col_stride));
kernel.SetArgument(6, kernel_buffer());
kernel.SetArgument(7, static_cast<int>(kernel_offset));
kernel.SetArgument(8, result_buffer());
kernel.SetArgument(9, static_cast<int>(result_offset));
kernel.SetArgument(10, static_cast<int>(result_stride));
// Computes the global and local thread sizes
const auto m_ceiled = Ceil(m, db_["WGD"]);
const auto n_ceiled = Ceil(n, db_["WGD"]);
const auto m_ceiled = Ceil(num_patches, db_["WGD"]);
const auto n_ceiled = Ceil(num_kernels, db_["WGD"]);
const auto global = std::vector<size_t>{
(m_ceiled * db_["MDIMCD"]) / db_["WGD"],
(n_ceiled * db_["NDIMCD"]) / db_["WGD"],