Fix xconvgemm kernel and enable ConvGemmMethod::kSingleKernel
parent
9819957768
commit
301dc280df
|
@ -11,7 +11,6 @@
|
|||
// uses parameters from the direct GEMM kernel. This is the part with the loads from memory (1/2).
|
||||
// This uses "CONVGEMM_WITH_IM2COL" as a switch to select between direct convgemm or first running
|
||||
// the im2col kernel to create a 'col' temporary matrix.
|
||||
// TODO: Currently only works with 'CONVGEMM_WITH_IM2COL' set
|
||||
//
|
||||
// =================================================================================================
|
||||
|
||||
|
@ -30,12 +29,17 @@ INLINE_FUNC real GlobalToPrivateCheckedImage(const __global real* restrict image
|
|||
const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w) {
|
||||
const int dilation_h, const int dilation_w,
|
||||
const bool kernel_flip) {
|
||||
|
||||
// Im2col indices
|
||||
const int kernel_2d_index = kwg % (kernel_h * kernel_w);
|
||||
const int kw_id = kernel_2d_index % kernel_w;
|
||||
const int kh_id = kernel_2d_index / kernel_w;
|
||||
const int kw_id = (kernel_flip)
|
||||
? kernel_w - kernel_2d_index % kernel_w - 1
|
||||
: kernel_2d_index % kernel_w;
|
||||
const int kh_id = (kernel_flip)
|
||||
? kernel_h - kernel_2d_index / kernel_w - 1
|
||||
: kernel_2d_index / kernel_w;
|
||||
const int c_id = kwg / (kernel_h * kernel_w);
|
||||
const int h_index = -pad_h + kh_id * dilation_h + stride_h * h_id;
|
||||
const int w_index = -pad_w + kw_id * dilation_w + stride_w * w_id;
|
||||
|
@ -55,14 +59,15 @@ INLINE_FUNC real GlobalToPrivateCheckedImage(const __global real* restrict image
|
|||
|
||||
// Loads global off-chip memory into local (shared) memory on-chip. This function is specific for
|
||||
// loading the image input tensor. This includes a bounds check.
|
||||
INLINE_FUNC real GlobalToLocalCheckedImage(const __global realMD* restrict imagegm, LOCAL_PTR real* alm,
|
||||
INLINE_FUNC real GlobalToLocalCheckedImage(const __global real* restrict imagegm, LOCAL_PTR real* alm,
|
||||
const int image_offset_batch,
|
||||
const int h_id, const int w_id, const int kwg,
|
||||
const int output_w, const int kwg,
|
||||
const int input_h, const int input_w, const int channels,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w) {
|
||||
const int dilation_h, const int dilation_w,
|
||||
const bool kernel_flip) {
|
||||
#if MDIMCD == MDIMAD
|
||||
const int la0 = get_local_id(0);
|
||||
const int la1 = get_local_id(1);
|
||||
|
@ -82,10 +87,17 @@ INLINE_FUNC real GlobalToLocalCheckedImage(const __global realMD* restrict image
|
|||
int idm = mg + GetGroupID0()*WGD;
|
||||
int idk = kg + kwg;
|
||||
|
||||
const int w_id = idm % output_w;
|
||||
const int h_id = idm / output_w;
|
||||
|
||||
// Im2col indices
|
||||
const int kernel_2d_index = idk % (kernel_h * kernel_w);
|
||||
const int kw_id = kernel_2d_index % kernel_w;
|
||||
const int kh_id = kernel_2d_index / kernel_w;
|
||||
const int kw_id = (kernel_flip)
|
||||
? kernel_w - kernel_2d_index % kernel_w - 1
|
||||
: kernel_2d_index % kernel_w;
|
||||
const int kh_id = (kernel_flip)
|
||||
? kernel_h - kernel_2d_index / kernel_w - 1
|
||||
: kernel_2d_index / kernel_w;
|
||||
const int c_id = idk / (kernel_h * kernel_w);
|
||||
const int h_index = -pad_h + kh_id * dilation_h + stride_h * h_id;
|
||||
const int w_index = -pad_w + kw_id * dilation_w + stride_w * w_id;
|
||||
|
@ -104,7 +116,8 @@ INLINE_FUNC real GlobalToLocalCheckedImage(const __global realMD* restrict image
|
|||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif // defined(ROUTINE_CONVGEMM) && !defined(CONVGEMM_WITH_IM2COL)
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// End of the C++11 raw string literal
|
||||
|
|
|
@ -11,7 +11,6 @@
|
|||
// uses parameters from the direct GEMM kernel. This part contains the main kernel (2/2).
|
||||
// This uses "CONVGEMM_WITH_IM2COL" as a switch to select between direct convgemm or first running
|
||||
// the im2col kernel to create a 'col' temporary matrix.
|
||||
// TODO: Currently only works with 'CONVGEMM_WITH_IM2COL' set
|
||||
//
|
||||
// =================================================================================================
|
||||
|
||||
|
@ -23,20 +22,25 @@ R"(
|
|||
#if defined(ROUTINE_CONVGEMM)
|
||||
|
||||
// ConvGEMM kernel
|
||||
#if defined(CONVGEMM_WITH_IM2COL)
|
||||
__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
|
||||
void Xconvgemm(const int num_patches, const int num_kernels, const int patch_size,
|
||||
const __global realND* restrict kernelgm, const int kernel_offset,
|
||||
__global real* resultgm, const int result_offset, const int result_stride,
|
||||
#if defined(CONVGEMM_WITH_IM2COL)
|
||||
const __global realMD* restrict colgm, const int col_offset, const int col_stride)
|
||||
#else
|
||||
const __global realMD* restrict imagegm, const int image_offset,
|
||||
const int input_h, const int input_w, const int channels,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int output_h, const int output_w)
|
||||
INLINE_FUNC void Xconvgemm(const int num_patches, const int num_kernels, const int patch_size,
|
||||
const __global realND* restrict kernelgm, const int kernel_offset,
|
||||
__global real* resultgm, const int result_offset, const int result_stride,
|
||||
const __global realMD* restrict imagegm, const int image_offset,
|
||||
const int input_h, const int input_w, const int channels,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int output_h, const int output_w,
|
||||
LOCAL_PTR real* alm, LOCAL_PTR real* blm,
|
||||
const bool kernel_flip)
|
||||
#endif
|
||||
{
|
||||
|
||||
|
@ -49,12 +53,16 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
|
|||
#endif
|
||||
const int result_offset_batch = result_offset + result_stride * batch;
|
||||
|
||||
#if defined(CONVGEMM_WITH_IM2COL)
|
||||
__local real alm[WGD * (WGD + PADA)];
|
||||
__local real blm[WGD * (WGD + PADB)];
|
||||
#endif
|
||||
|
||||
// Extra pointers to scalar versions of global memory
|
||||
#if defined(CONVGEMM_WITH_IM2COL)
|
||||
const __global real* restrict colgms = (const __global real* restrict) colgm;
|
||||
#else
|
||||
const __global real* restrict imagegms = (const __global real* restrict) imagegm;
|
||||
#endif
|
||||
const __global real* restrict kernelgms = (const __global real* restrict) kernelgm;
|
||||
|
||||
|
@ -100,10 +108,10 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
|
|||
GlobalToLocalScalarA(colgms, alm, num_patches, col_offset_batch, kwg, false, false);
|
||||
}
|
||||
#else
|
||||
GlobalToLocalCheckedImage(imagegm, alm, image_offset_batch, h_id, w_id, kwg,
|
||||
GlobalToLocalCheckedImage(imagegms, alm, image_offset_batch, output_w, kwg,
|
||||
input_h, input_w, channels, kernel_h, kernel_w,
|
||||
pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w);
|
||||
dilation_h, dilation_w, kernel_flip);
|
||||
#endif
|
||||
if (patch_size % VWND == 0 && kernel_offset % VWND == 0) {
|
||||
GlobalToLocalDirectB(kernelgm, blm, patch_size, kernel_offset, kwg, true, false);
|
||||
|
@ -151,10 +159,12 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
|
|||
#if defined(CONVGEMM_WITH_IM2COL)
|
||||
apd[_mi] = GlobalToPrivateDirectA(colgms, _mi, num_patches, col_offset_batch, idm, kwg, false, false);
|
||||
#else
|
||||
apd[_mi] = GlobalToPrivateCheckedImage(imagegm, image_offset_batch, h_id, w_id, kwg,
|
||||
const int w_id = (idm + _mi) % output_w;
|
||||
const int h_id = (idm + _mi) / output_w;
|
||||
apd[_mi] = GlobalToPrivateCheckedImage(imagegms, image_offset_batch, h_id, w_id, kwg,
|
||||
input_h, input_w, channels, kernel_h, kernel_w,
|
||||
pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w);
|
||||
dilation_h, dilation_w, kernel_flip);
|
||||
#endif
|
||||
}
|
||||
#pragma unroll
|
||||
|
@ -193,10 +203,10 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
|
|||
#if defined(CONVGEMM_WITH_IM2COL)
|
||||
GlobalToLocalCheckedA(colgms, alm, num_patches, col_offset_batch, kwg, false, false, num_patches, patch_size);
|
||||
#else
|
||||
GlobalToLocalCheckedImage(imagegm, alm, image_offset_batch, h_id, w_id, kwg,
|
||||
GlobalToLocalCheckedImage(imagegms, alm, image_offset_batch, output_w, kwg,
|
||||
input_h, input_w, channels, kernel_h, kernel_w,
|
||||
pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w);
|
||||
dilation_h, dilation_w, kernel_flip);
|
||||
#endif
|
||||
GlobalToLocalCheckedB(kernelgms, blm, patch_size, kernel_offset, kwg, true, false, num_kernels, patch_size);
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
@ -239,10 +249,12 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
|
|||
#if defined(CONVGEMM_WITH_IM2COL)
|
||||
apd[_mi] = GlobalToPrivateCheckedA(colgms, _mi, num_patches, col_offset_batch, idm, kwg, false, false, num_patches);
|
||||
#else
|
||||
apd[_mi] = GlobalToPrivateCheckedImage(imagegm, image_offset_batch, h_id, w_id, kwg,
|
||||
const int w_id = (idm + _mi) % output_w;
|
||||
const int h_id = (idm + _mi) / output_w;
|
||||
apd[_mi] = GlobalToPrivateCheckedImage(imagegms, image_offset_batch, h_id, w_id, kwg,
|
||||
input_h, input_w, channels, kernel_h, kernel_w,
|
||||
pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w);
|
||||
dilation_h, dilation_w, kernel_flip);
|
||||
#endif
|
||||
}
|
||||
#pragma unroll
|
||||
|
@ -272,7 +284,53 @@ void Xconvgemm(const int num_patches, const int num_kernels, const int patch_siz
|
|||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
#if !defined(CONVGEMM_WITH_IM2COL)
|
||||
__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
|
||||
void XconvgemmFlip(const int num_patches, const int num_kernels, const int patch_size,
|
||||
const __global realND* restrict kernelgm, const int kernel_offset,
|
||||
__global real* resultgm, const int result_offset, const int result_stride,
|
||||
const __global realMD* restrict imagegm, const int image_offset,
|
||||
const int input_h, const int input_w, const int channels,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int output_h, const int output_w) {
|
||||
const bool kernel_flip = true;
|
||||
__local real alm[WGD * (WGD + PADA)];
|
||||
__local real blm[WGD * (WGD + PADB)];
|
||||
Xconvgemm(num_patches, num_kernels, patch_size,
|
||||
kernelgm, kernel_offset, resultgm, result_offset, result_stride,
|
||||
imagegm, image_offset, input_h, input_w, channels, kernel_h, kernel_w,
|
||||
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
|
||||
output_h, output_w, alm, blm, kernel_flip);
|
||||
}
|
||||
|
||||
__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
|
||||
void XconvgemmNormal(const int num_patches, const int num_kernels, const int patch_size,
|
||||
const __global realND* restrict kernelgm, const int kernel_offset,
|
||||
__global real* resultgm, const int result_offset, const int result_stride,
|
||||
const __global realMD* restrict imagegm, const int image_offset,
|
||||
const int input_h, const int input_w, const int channels,
|
||||
const int kernel_h, const int kernel_w,
|
||||
const int pad_h, const int pad_w,
|
||||
const int stride_h, const int stride_w,
|
||||
const int dilation_h, const int dilation_w,
|
||||
const int output_h, const int output_w) {
|
||||
const bool kernel_flip = false;
|
||||
__local real alm[WGD * (WGD + PADA)];
|
||||
__local real blm[WGD * (WGD + PADB)];
|
||||
Xconvgemm(num_patches, num_kernels, patch_size,
|
||||
kernelgm, kernel_offset, resultgm, result_offset, result_stride,
|
||||
imagegm, image_offset, input_h, input_w, channels, kernel_h, kernel_w,
|
||||
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
|
||||
output_h, output_w, alm, blm, kernel_flip);
|
||||
}
|
||||
|
||||
#endif // !defined(CONVGEMM_WITH_IM2COL)
|
||||
|
||||
#endif // defined(ROUTINE_CONVGEMM)
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// End of the C++11 raw string literal
|
||||
|
|
|
@ -53,9 +53,6 @@ void Xconvgemm<T>::DoConvgemm(const KernelMode kernel_mode,
|
|||
const Buffer<T> &kernel_buffer, const size_t kernel_offset,
|
||||
const Buffer<T> &result_buffer, const size_t result_offset) {
|
||||
|
||||
// TODO: Implement single-kernel approach
|
||||
assert(method_ == ConvGemmMethod::kWithIm2Col);
|
||||
|
||||
// Tests for a valid batch count
|
||||
if (batch_count == 0) {
|
||||
throw BLASError(StatusCode::kInvalidBatchCount);
|
||||
|
@ -121,7 +118,12 @@ void Xconvgemm<T>::DoConvgemm(const KernelMode kernel_mode,
|
|||
}
|
||||
|
||||
// Retrieves the proper XgemmDirect kernel from the compiled binary
|
||||
auto kernel = Kernel(program_, "Xconvgemm");
|
||||
const std::string kernel_name = (method_ == ConvGemmMethod::kWithIm2Col)
|
||||
? "Xconvgemm"
|
||||
: (kernel_mode == KernelMode::kConvolution)
|
||||
? "XconvgemmFlip"
|
||||
: "XconvgemmNormal";
|
||||
auto kernel = Kernel(program_, kernel_name);
|
||||
|
||||
// Sets the kernel arguments
|
||||
kernel.SetArgument(0, static_cast<int>(num_patches));
|
||||
|
|
|
@ -29,7 +29,7 @@ class Xconvgemm: public Routine {
|
|||
// Constructor
|
||||
enum class ConvGemmMethod {kWithIm2Col, kSingleKernel};
|
||||
Xconvgemm(Queue &queue, EventPointer event, const std::string &name = "CONVGEMM",
|
||||
const ConvGemmMethod method = ConvGemmMethod::kWithIm2Col);
|
||||
const ConvGemmMethod method = ConvGemmMethod::kSingleKernel);
|
||||
|
||||
// Templated-precision implementation of the routine
|
||||
void DoConvgemm(const KernelMode kernel_mode,
|
||||
|
|
Loading…
Reference in New Issue