diff --git a/CMakeLists.txt b/CMakeLists.txt index 0f067efb..80eb9583 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -221,7 +221,7 @@ set(LEVEL1_ROUTINES xswap xscal xcopy xaxpy xdot xdotu xdotc xnrm2 xasum xamax) set(LEVEL2_ROUTINES xgemv xgbmv xhemv xhbmv xhpmv xsymv xsbmv xspmv xtrmv xtbmv xtpmv xtrsv xger xgeru xgerc xher xhpr xher2 xhpr2 xsyr xspr xsyr2 xspr2) set(LEVEL3_ROUTINES xgemm xsymm xhemm xsyrk xherk xsyr2k xher2k xtrmm xtrsm) -set(LEVELX_ROUTINES xhad xomatcopy xim2col xconvgemm xaxpybatched xgemmbatched xgemmstridedbatched) +set(LEVELX_ROUTINES xhad xomatcopy xim2col xcol2im xconvgemm xaxpybatched xgemmbatched xgemmstridedbatched) set(ROUTINES ${LEVEL1_ROUTINES} ${LEVEL2_ROUTINES} ${LEVEL3_ROUTINES} ${LEVELX_ROUTINES}) set(PRECISIONS 32 64 3232 6464 16) diff --git a/doc/api.md b/doc/api.md index 15bc0dcd..64b4a1c8 100644 --- a/doc/api.md +++ b/doc/api.md @@ -3072,6 +3072,66 @@ Arguments to IM2COL: +xCOL2IM: Col2im function (non-BLAS function) +------------- + +Performs the col2im algorithm, in which _col_ is the input matrix and _im_ is the output matrix. + +C++ API: +``` +template +StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event) +``` + +C API: +``` +CLBlastStatusCode CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event) +CLBlastStatusCode CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event) +CLBlastStatusCode CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event) +CLBlastStatusCode CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event) +CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event) +``` + +Arguments to COL2IM: + +* `const size_t channels`: Integer size argument. This value must be positive. +* `const size_t height`: Integer size argument. This value must be positive. +* `const size_t width`: Integer size argument. This value must be positive. +* `const size_t kernel_h`: Integer size argument. This value must be positive. +* `const size_t kernel_w`: Integer size argument. This value must be positive. +* `const size_t pad_h`: Integer size argument. This value must be positive. +* `const size_t pad_w`: Integer size argument. This value must be positive. +* `const size_t stride_h`: Integer size argument. This value must be positive. +* `const size_t stride_w`: Integer size argument. This value must be positive. +* `const size_t dilation_h`: Integer size argument. This value must be positive. +* `const size_t dilation_w`: Integer size argument. This value must be positive. +* `const cl_mem col_buffer`: OpenCL buffer to store the input col tensor. +* `const size_t col_offset`: The offset in elements from the start of the input col tensor. +* `cl_mem im_buffer`: OpenCL buffer to store the output im tensor. +* `const size_t im_offset`: The offset in elements from the start of the output im tensor. +* `cl_command_queue* queue`: Pointer to an OpenCL command queue associated with a context and device to execute the routine on. +* `cl_event* event`: Pointer to an OpenCL event to be able to wait for completion of the routine's OpenCL kernel(s). This is an optional argument. + + + xCONVGEMM: Batched convolution as GEMM (non-BLAS function) ------------- diff --git a/include/clblast.h b/include/clblast.h index 9a8988e7..27adf7fa 100644 --- a/include/clblast.h +++ b/include/clblast.h @@ -636,6 +636,13 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event = nullptr); +// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM +template +StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event = nullptr); + // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, diff --git a/include/clblast_c.h b/include/clblast_c.h index 2357182c..1c681bfe 100644 --- a/include/clblast_c.h +++ b/include/clblast_c.h @@ -1410,6 +1410,28 @@ CLBlastStatusCode PUBLIC_API CLBlastHim2col(const size_t channels, const size_t cl_mem col_buffer, const size_t col_offset, cl_command_queue* queue, cl_event* event); +// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM +CLBlastStatusCode PUBLIC_API CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event); +CLBlastStatusCode PUBLIC_API CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event); +CLBlastStatusCode PUBLIC_API CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event); +CLBlastStatusCode PUBLIC_API CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event); +CLBlastStatusCode PUBLIC_API CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event); + // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM CLBlastStatusCode PUBLIC_API CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, diff --git a/include/clblast_cuda.h b/include/clblast_cuda.h index 1bbd898e..58f9b74b 100644 --- a/include/clblast_cuda.h +++ b/include/clblast_cuda.h @@ -608,6 +608,13 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width CUdeviceptr col_buffer, const size_t col_offset, const CUcontext context, const CUdevice device); +// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM +template +StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const CUdeviceptr col_buffer, const size_t col_offset, + CUdeviceptr im_buffer, const size_t im_offset, + const CUcontext context, const CUdevice device); + // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, diff --git a/include/clblast_netlib_c.h b/include/clblast_netlib_c.h index b64b82eb..65545bfb 100644 --- a/include/clblast_netlib_c.h +++ b/include/clblast_netlib_c.h @@ -960,6 +960,20 @@ void PUBLIC_API cblas_zim2col(const int channels, const int height, const int wi const void* im, void* col); +// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM +void PUBLIC_API cblas_scol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, + const float* col, + float* im); +void PUBLIC_API cblas_dcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, + const double* col, + double* im); +void PUBLIC_API cblas_ccol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, + const void* col, + void* im); +void PUBLIC_API cblas_zcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, + const void* col, + void* im); + // ================================================================================================= #ifdef __cplusplus diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index c2637037..27107739 100755 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -181,6 +181,7 @@ ROUTINES = [ Routine(True, True, 0, False, "x", "had", T, [S,D,C,Z,H], ["n"], [], ["x","y"], ["z"], [xn,yn,zn], ["alpha","beta"], "", "Element-wise vector product (Hadamard)", "Performs the Hadamard element-wise product _z = alpha * x * y + beta * z_, in which _x_, _y_, and _z_ are vectors and _alpha_ and _beta_ are scalar constants.", []), Routine(True, True, 0, False, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]), Routine(True, True, 0, False, "x", "im2col", T, [S,D,C,Z,H], im2col_constants, [], ["im"], ["col"], [im,col], [""], "", "Im2col function (non-BLAS function)", "Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is the output matrix.", []), + Routine(True, True, 0, False, "x", "col2im", T, [S,D,C,Z,H], im2col_constants, [], ["col"], ["im"], [col,im], [""], "", "Col2im function (non-BLAS function)", "Performs the col2im algorithm, in which _col_ is the input matrix and _im_ is the output matrix.", []), Routine(True, True, 0, False, "x", "convgemm", T, [S,D,H], convgemm_constants, [], ["im","kernel"], ["result"], [imb,kernel,result],[""], "", "Batched convolution as GEMM (non-BLAS function)", "Integrates im2col and GEMM for batched 3D convolution, in which _im_ is the 4D input tensor (NCHW - batch-channelin-height-width), _kernel_ the 4D kernel weights tensor (KCHW - channelout-channelin-height-width), and _result_ the 4D output tensor (NCHW - batch-channelout-height-width).", []), # Batched routines: Routine(True, True, 1, False, "x", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Batched version of AXPY", "As AXPY, but multiple operations are batched together for better performance.", []), diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py index 7321349d..3b5a6b76 100644 --- a/scripts/generator/generator/routine.py +++ b/scripts/generator/generator/routine.py @@ -205,7 +205,7 @@ class Routine: def no_scalars(self): """Determines whether or not this routine has scalar arguments (alpha/beta)""" - return self.scalars == [] or self.name in ["im2col", "convgemm"] + return self.scalars == [] or self.name in ["im2col", "col2im", "convgemm"] def has_layout(self): """Determines whether the layout is an argument""" @@ -226,12 +226,14 @@ class Routine: """Determines which buffers go first (between alpha and beta) and which ones go after""" if self.level == "2b" or self.name == "had": return ["x", "y"] - return ["ap", "a", "b", "x", "im", "kernel"] + extra_buffer = "col" if self.name == "col2im" else "im" + return ["ap", "a", "b", "x", extra_buffer, "kernel"] def buffers_second(self): if self.level == "2b" or self.name == "had": return ["z", "ap", "a", "b", "c"] - return ["y", "c", "col", "result"] + extra_buffer = "im" if self.name == "col2im" else "col" + return ["y", "c", extra_buffer, "result"] def buffer(self, name): """Retrieves a variable name for a specific input/output vector/matrix (e.g. 'x')""" diff --git a/src/clblast.cpp b/src/clblast.cpp index 0cd2f843..e45f504a 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2252,6 +2252,42 @@ template StatusCode PUBLIC_API Im2col(const size_t, const size_t, const si cl_mem, const size_t, cl_command_queue*, cl_event*); +// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM +template +StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event) { + try { + auto queue_cpp = Queue(*queue); + auto routine = Xcol2im(queue_cpp, event); + routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + Buffer(col_buffer), col_offset, + Buffer(im_buffer), im_offset); + return StatusCode::kSuccess; + } catch (...) { return DispatchException(); } +} +template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const cl_mem, const size_t, + cl_mem, const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const cl_mem, const size_t, + cl_mem, const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const cl_mem, const size_t, + cl_mem, const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const cl_mem, const size_t, + cl_mem, const size_t, + cl_command_queue*, cl_event*); +template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const cl_mem, const size_t, + cl_mem, const size_t, + cl_command_queue*, cl_event*); + // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, diff --git a/src/clblast_c.cpp b/src/clblast_c.cpp index 72adb888..645a69b1 100644 --- a/src/clblast_c.cpp +++ b/src/clblast_c.cpp @@ -3679,6 +3679,73 @@ CLBlastStatusCode CLBlastHim2col(const size_t channels, const size_t height, con } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } } +// COL2IM +CLBlastStatusCode CLBlastScol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event) { + try { + return static_cast( + clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + col_buffer, col_offset, + im_buffer, im_offset, + queue, event) + ); + } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastDcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event) { + try { + return static_cast( + clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + col_buffer, col_offset, + im_buffer, im_offset, + queue, event) + ); + } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastCcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event) { + try { + return static_cast( + clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + col_buffer, col_offset, + im_buffer, im_offset, + queue, event) + ); + } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastZcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event) { + try { + return static_cast( + clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + col_buffer, col_offset, + im_buffer, im_offset, + queue, event) + ); + } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } +} +CLBlastStatusCode CLBlastHcol2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const cl_mem col_buffer, const size_t col_offset, + cl_mem im_buffer, const size_t im_offset, + cl_command_queue* queue, cl_event* event) { + try { + return static_cast( + clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + col_buffer, col_offset, + im_buffer, im_offset, + queue, event) + ); + } catch (...) { return static_cast(clblast::DispatchExceptionForC()); } +} + // CONVGEMM CLBlastStatusCode CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, const cl_mem im_buffer, const size_t im_offset, diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp index f14806cb..03d995ba 100644 --- a/src/clblast_cuda.cpp +++ b/src/clblast_cuda.cpp @@ -2350,6 +2350,44 @@ template StatusCode PUBLIC_API Im2col(const size_t, const size_t, const si CUdeviceptr, const size_t, const CUcontext, const CUdevice); +// Col2im function (non-BLAS function): SCOL2IM/DCOL2IM/CCOL2IM/ZCOL2IM/HCOL2IM +template +StatusCode Col2im(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, + const CUdeviceptr col_buffer, const size_t col_offset, + CUdeviceptr im_buffer, const size_t im_offset, + const CUcontext context, const CUdevice device) { + try { + const auto context_cpp = Context(context); + const auto device_cpp = Device(device); + auto queue_cpp = Queue(context_cpp, device_cpp); + auto routine = Xcol2im(queue_cpp, nullptr); + routine.DoCol2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + Buffer(col_buffer), col_offset, + Buffer(im_buffer), im_offset); + return StatusCode::kSuccess; + } catch (...) { return DispatchException(); } +} +template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const CUdeviceptr, const size_t, + CUdeviceptr, const size_t, + const CUcontext, const CUdevice); +template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const CUdeviceptr, const size_t, + CUdeviceptr, const size_t, + const CUcontext, const CUdevice); +template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const CUdeviceptr, const size_t, + CUdeviceptr, const size_t, + const CUcontext, const CUdevice); +template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const CUdeviceptr, const size_t, + CUdeviceptr, const size_t, + const CUcontext, const CUdevice); +template StatusCode PUBLIC_API Col2im(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, + const CUdeviceptr, const size_t, + CUdeviceptr, const size_t, + const CUcontext, const CUdevice); + // Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM template StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count, diff --git a/src/clblast_netlib_c.cpp b/src/clblast_netlib_c.cpp index dbc2ba57..22570535 100644 --- a/src/clblast_netlib_c.cpp +++ b/src/clblast_netlib_c.cpp @@ -4967,4 +4967,94 @@ void cblas_zim2col(const int channels, const int height, const int width, const col_buffer.Read(queue, col_size, reinterpret_cast(col)); } +// COL2IM +void cblas_scol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, + const float* col, + float* im) { + OPTIONAL_STATIC auto device = get_device(); + OPTIONAL_STATIC auto context = clblast::Context(device); + auto queue = clblast::Queue(context, device); + const auto col_size = height * width * channels; + const auto im_size = height * width * channels; + auto col_buffer = clblast::Buffer(context, col_size); + auto im_buffer = clblast::Buffer(context, im_size); + col_buffer.Write(queue, col_size, reinterpret_cast(col)); + im_buffer.Write(queue, im_size, reinterpret_cast(im)); + auto queue_cl = queue(); + auto s = clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + col_buffer(), 0, + im_buffer(), 0, + &queue_cl); + if (s != clblast::StatusCode::kSuccess) { + throw std::runtime_error("CLBlast returned with error code " + clblast::ToString(s)); + } + im_buffer.Read(queue, im_size, reinterpret_cast(im)); +} +void cblas_dcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, + const double* col, + double* im) { + OPTIONAL_STATIC auto device = get_device(); + OPTIONAL_STATIC auto context = clblast::Context(device); + auto queue = clblast::Queue(context, device); + const auto col_size = height * width * channels; + const auto im_size = height * width * channels; + auto col_buffer = clblast::Buffer(context, col_size); + auto im_buffer = clblast::Buffer(context, im_size); + col_buffer.Write(queue, col_size, reinterpret_cast(col)); + im_buffer.Write(queue, im_size, reinterpret_cast(im)); + auto queue_cl = queue(); + auto s = clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + col_buffer(), 0, + im_buffer(), 0, + &queue_cl); + if (s != clblast::StatusCode::kSuccess) { + throw std::runtime_error("CLBlast returned with error code " + clblast::ToString(s)); + } + im_buffer.Read(queue, im_size, reinterpret_cast(im)); +} +void cblas_ccol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, + const void* col, + void* im) { + OPTIONAL_STATIC auto device = get_device(); + OPTIONAL_STATIC auto context = clblast::Context(device); + auto queue = clblast::Queue(context, device); + const auto col_size = height * width * channels; + const auto im_size = height * width * channels; + auto col_buffer = clblast::Buffer(context, col_size); + auto im_buffer = clblast::Buffer(context, im_size); + col_buffer.Write(queue, col_size, reinterpret_cast(col)); + im_buffer.Write(queue, im_size, reinterpret_cast(im)); + auto queue_cl = queue(); + auto s = clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + col_buffer(), 0, + im_buffer(), 0, + &queue_cl); + if (s != clblast::StatusCode::kSuccess) { + throw std::runtime_error("CLBlast returned with error code " + clblast::ToString(s)); + } + im_buffer.Read(queue, im_size, reinterpret_cast(im)); +} +void cblas_zcol2im(const int channels, const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, + const void* col, + void* im) { + OPTIONAL_STATIC auto device = get_device(); + OPTIONAL_STATIC auto context = clblast::Context(device); + auto queue = clblast::Queue(context, device); + const auto col_size = height * width * channels; + const auto im_size = height * width * channels; + auto col_buffer = clblast::Buffer(context, col_size); + auto im_buffer = clblast::Buffer(context, im_size); + col_buffer.Write(queue, col_size, reinterpret_cast(col)); + im_buffer.Write(queue, im_size, reinterpret_cast(im)); + auto queue_cl = queue(); + auto s = clblast::Col2im(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + col_buffer(), 0, + im_buffer(), 0, + &queue_cl); + if (s != clblast::StatusCode::kSuccess) { + throw std::runtime_error("CLBlast returned with error code " + clblast::ToString(s)); + } + im_buffer.Read(queue, im_size, reinterpret_cast(im)); +} + // ================================================================================================= diff --git a/src/kernels/levelx/col2im.opencl b/src/kernels/levelx/col2im.opencl new file mode 100644 index 00000000..76917795 --- /dev/null +++ b/src/kernels/levelx/col2im.opencl @@ -0,0 +1,74 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// This file contains the col2im kernel, taken from: +// https://gist.github.com/vbkaisetsu/a98299df827f9a5245635f646c1d94be +// Credits go to https://github.com/vbkaisetsu +// +// ================================================================================================= + +// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string +// literal). Comment-out this line for syntax-highlighting when developing. +R"( + +// Work-group size parameters re-used from the 'copy' kernel +#ifndef COPY_DIMX + #define COPY_DIMX 8 // Local workgroup size in the first dimension (w) +#endif +#ifndef COPY_DIMY + #define COPY_DIMY 8 // Local workgroup size in the second dimension (h) +#endif + +// ================================================================================================= + +__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1))) +void col2im(const int input_h, const int input_w, const int channels, + const int output_h, const int output_w, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const __global real* restrict col_buffer, const int col_offset, + __global real *im_buffer, const int im_offset) { + const int x_x = get_global_id(0) + pad_w; + const int x_y = ((int) get_global_id(1)) % input_h + pad_h; + const int channel = ((int) get_global_id(1)) / input_h; + const int kernel_extent_w = (kernel_w - 1) * dilation_w + 1; + const int kernel_extent_h = (kernel_h - 1) * dilation_h + 1; + const int col_channel_shift = channel * kernel_w * kernel_h * output_h * output_w + col_offset; + const int x_channel_shift = channel * input_h * input_w + im_offset; + const int t_y_begin = (x_y < kernel_extent_h) ? 0 : (x_y - kernel_extent_h) / stride_h + 1; + const int t_y_end = min(x_y / stride_h + 1, output_h); + const int t_x_begin = (x_x < kernel_extent_w) ? 0 : (x_x - kernel_extent_w) / stride_w + 1; + const int t_x_end = min(x_x / stride_w + 1, output_w); + + if (x_x < input_w + pad_w && channel < channels) { + real val; + SetToZero(val); + for (int t_y = t_y_begin; t_y < t_y_end; ++t_y) { + for (int t_x = t_x_begin; t_x < t_x_end; ++t_x) { + int w_y = x_y - t_y * stride_h; + int w_x = x_x - t_x * stride_w; + if (w_y % dilation_h == 0 && w_x % dilation_w == 0) { + w_y /= dilation_h; + w_x /= dilation_w; + val += col_buffer[col_channel_shift + + (w_x + w_y * kernel_w) * output_h * output_w + + t_y * output_w + + t_x]; + } + } + } + im_buffer[x_channel_shift + (x_y - pad_h) * input_w + x_x - pad_w] = val; + } +} + +// ================================================================================================= + +// End of the C++11 raw string literal +)" + +// ================================================================================================= diff --git a/src/routines/levelx/xcol2im.cpp b/src/routines/levelx/xcol2im.cpp index e69de29b..8339c02c 100644 --- a/src/routines/levelx/xcol2im.cpp +++ b/src/routines/levelx/xcol2im.cpp @@ -0,0 +1,92 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren +// +// This file implements the Xcol2im class (see the header for information about the class). +// +// ================================================================================================= + +#include "routines/levelx/xcol2im.hpp" + +#include +#include + +namespace clblast { +// ================================================================================================= + +// Constructor: forwards to base class constructor +template +Xcol2im::Xcol2im(Queue &queue, EventPointer event, const std::string &name): + Routine(queue, event, name, {"Copy"}, PrecisionValue(), {}, { +#include "../../kernels/levelx/col2im.opencl" + }) { +} + +// ================================================================================================= + +// The main routine +template +void Xcol2im::DoCol2im(const size_t channels, const size_t height, const size_t width, + const size_t kernel_h, const size_t kernel_w, const size_t pad_h, + const size_t pad_w, const size_t stride_h, const size_t stride_w, + const size_t dilation_h, const size_t dilation_w, + const Buffer &col_buffer, const size_t col_offset, + const Buffer &im_buffer, const size_t im_offset) { + + // Makes sure all dimensions are larger than zero + if ((channels == 0) || (height == 0) || (width == 0)) { throw BLASError(StatusCode::kInvalidDimension); } + + // Sets the output height and width + const auto size_h = height + 2 * pad_h; + const auto padding_h = dilation_h * (kernel_h - 1) + 1; + const auto col_h = (size_h >= padding_h) ? (size_h - padding_h) / stride_h + 1 : 1; + const auto size_w = width + 2 * pad_w; + const auto padding_w = dilation_w * (kernel_w - 1) + 1; + const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1; + + // Retrieves the kernel from the compiled binary + auto kernel = Kernel(program_, "col2im"); + + // Sets the kernel arguments + kernel.SetArgument(0, static_cast(height)); + kernel.SetArgument(1, static_cast(width)); + kernel.SetArgument(2, static_cast(channels)); + kernel.SetArgument(3, static_cast(col_h)); + kernel.SetArgument(4, static_cast(col_w)); + kernel.SetArgument(5, static_cast(kernel_h)); + kernel.SetArgument(6, static_cast(kernel_w)); + kernel.SetArgument(7, static_cast(pad_h)); + kernel.SetArgument(8, static_cast(pad_w)); + kernel.SetArgument(9, static_cast(stride_h)); + kernel.SetArgument(10, static_cast(stride_w)); + kernel.SetArgument(11, static_cast(dilation_h)); + kernel.SetArgument(12, static_cast(dilation_w)); + kernel.SetArgument(13, col_buffer()); + kernel.SetArgument(14, static_cast(col_offset)); + kernel.SetArgument(15, im_buffer()); + kernel.SetArgument(16, static_cast(im_offset)); + + // Launches the kernel + const auto w_ceiled = Ceil(col_w, db_["COPY_DIMX"]); + const auto h_ceiled = Ceil(col_h, db_["COPY_DIMY"]); + const auto global = std::vector{w_ceiled, h_ceiled * channels}; + const auto local = std::vector{db_["COPY_DIMX"], db_["COPY_DIMY"]}; + RunKernel(kernel, queue_, device_, global, local, event_); +} + +// ================================================================================================= + +// Compiles the templated class +template class Xcol2im; +template class Xcol2im; +template class Xcol2im; +template class Xcol2im; +template class Xcol2im; + +// ================================================================================================= +} // namespace clblast diff --git a/src/routines/levelx/xcol2im.hpp b/src/routines/levelx/xcol2im.hpp new file mode 100644 index 00000000..86d68c45 --- /dev/null +++ b/src/routines/levelx/xcol2im.hpp @@ -0,0 +1,45 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren +// +// This file implements the Xcol2im routine. The precision is implemented using a template argument. +// Uses the tuning parameters from the regular copy kernel. +// +// ================================================================================================= + +#ifndef CLBLAST_ROUTINES_XCOL2IM_H_ +#define CLBLAST_ROUTINES_XCOL2IM_H_ + +#include "routine.hpp" + +namespace clblast { +// ================================================================================================= + +// See comment at top of file for a description of the class +template +class Xcol2im: public Routine { + public: + + // Constructor + Xcol2im(Queue &queue, EventPointer event, const std::string &name = "COL2IM"); + + // Templated-precision implementation of the routine + void DoCol2im(const size_t channels, const size_t height, const size_t width, + const size_t kernel_h, const size_t kernel_w, + const size_t pad_h, const size_t pad_w, + const size_t stride_h, const size_t stride_w, + const size_t dilation_h, const size_t dilation_w, + const Buffer &col_buffer, const size_t col_offset, + const Buffer &im_buffer, const size_t im_offset); +}; + +// ================================================================================================= +} // namespace clblast + +// CLBLAST_ROUTINES_XCOL2IM_H_ +#endif diff --git a/src/routines/routines.hpp b/src/routines/routines.hpp index e080ed47..95475470 100644 --- a/src/routines/routines.hpp +++ b/src/routines/routines.hpp @@ -70,6 +70,7 @@ #include "routines/levelx/xhad.hpp" #include "routines/levelx/xomatcopy.hpp" #include "routines/levelx/xim2col.hpp" +#include "routines/levelx/xcol2im.hpp" #include "routines/levelx/xconvgemm.hpp" #include "routines/levelx/xaxpybatched.hpp" #include "routines/levelx/xgemmbatched.hpp" diff --git a/test/correctness/routines/levelx/xcol2im.cpp b/test/correctness/routines/levelx/xcol2im.cpp new file mode 100644 index 00000000..306d6fcc --- /dev/null +++ b/test/correctness/routines/levelx/xcol2im.cpp @@ -0,0 +1,26 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren +// +// ================================================================================================= + +#include "test/correctness/testblas.hpp" +#include "test/routines/levelx/xcol2im.hpp" + +// Main function (not within the clblast namespace) +int main(int argc, char *argv[]) { + auto errors = size_t{0}; + errors += clblast::RunTests, float, float>(argc, argv, false, "SCOL2IM"); + errors += clblast::RunTests, double, double>(argc, argv, true, "DCOL2IM"); + errors += clblast::RunTests, clblast::float2, clblast::float2>(argc, argv, true, "CCOL2IM"); + errors += clblast::RunTests, clblast::double2, clblast::double2>(argc, argv, true, "ZCOL2IM"); + errors += clblast::RunTests, clblast::half, clblast::half>(argc, argv, true, "HCOL2IM"); + if (errors > 0) { return 1; } else { return 0; } +} + +// ================================================================================================= diff --git a/test/performance/routines/levelx/xcol2im.cpp b/test/performance/routines/levelx/xcol2im.cpp new file mode 100644 index 00000000..76a5be30 --- /dev/null +++ b/test/performance/routines/levelx/xcol2im.cpp @@ -0,0 +1,33 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren +// +// ================================================================================================= + +#include "test/performance/client.hpp" +#include "test/routines/levelx/xcol2im.hpp" + +// Main function (not within the clblast namespace) +int main(int argc, char *argv[]) { + const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv); + switch(clblast::GetPrecision(command_line_args, clblast::Precision::kSingle)) { + case clblast::Precision::kHalf: + clblast::RunClient, clblast::half, clblast::half>(argc, argv); break; + case clblast::Precision::kSingle: + clblast::RunClient, float, float>(argc, argv); break; + case clblast::Precision::kDouble: + clblast::RunClient, double, double>(argc, argv); break; + case clblast::Precision::kComplexSingle: + clblast::RunClient, clblast::float2, clblast::float2>(argc, argv); break; + case clblast::Precision::kComplexDouble: + clblast::RunClient, clblast::double2, clblast::double2>(argc, argv); break; + } + return 0; +} + +// ================================================================================================= diff --git a/test/routines/levelx/xcol2im.hpp b/test/routines/levelx/xcol2im.hpp new file mode 100644 index 00000000..7393c432 --- /dev/null +++ b/test/routines/levelx/xcol2im.hpp @@ -0,0 +1,195 @@ + +// ================================================================================================= +// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This +// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max- +// width of 100 characters per line. +// +// Author(s): +// Cedric Nugteren +// +// This file implements a class with static methods to describe the Xcol2im routine. Examples of +// such 'descriptions' are how to calculate the size a of buffer or how to run the routine. These +// static methods are used by the correctness tester and the performance tester. +// +// ================================================================================================= + +#ifndef CLBLAST_TEST_ROUTINES_XCOL2IM_H_ +#define CLBLAST_TEST_ROUTINES_XCOL2IM_H_ + +#include "test/routines/common.hpp" + +namespace clblast { +// ================================================================================================= + +// See comment at top of file for a description of the class +template +class TestXcol2im { +public: + + // The BLAS level: 4 for the extra routines + static size_t BLASLevel() { return 4; } + + // The list of arguments relevant for this routine + static std::vector GetOptions() { + return {kArgChannels, kArgHeight, kArgWidth, kArgKernelH, kArgKernelW, kArgPadH, kArgPadW, + kArgStrideH, kArgStrideW, kArgDilationH, kArgDilationW, + kArgAOffset, kArgBOffset}; + } + static std::vector BuffersIn() { return {kBufMatA, kBufMatB}; } // b = col + static std::vector BuffersOut() { return {kBufMatA}; } // a = im + + // Describes how to obtain the sizes of the buffers + static size_t ColHeight(const Arguments &args) { + const auto size = args.height + 2 * args.pad_h; + const auto padding = args.dilation_h * (args.kernel_h - 1) + 1; + if (size >= padding) { return (size - padding) / args.stride_h + 1; } + return 1; + } + static size_t ColWidth(const Arguments &args) { + const auto size = args.width + 2 * args.pad_w; + const auto padding = args.dilation_w * (args.kernel_w - 1) + 1; + if (size >= padding) { return (size - padding) / args.stride_w + 1; } + return 1; + } + static size_t NumPatches(const Arguments &args) { + return ColHeight(args) * ColWidth(args) * args.channels; + } + static size_t GetSizeA(const Arguments &args) { + return args.height * args.width * args.channels + args.a_offset; + } + static size_t GetSizeB(const Arguments &args) { + return args.kernel_w * args.kernel_h * NumPatches(args) + args.b_offset; + } + + // Describes how to set the sizes of all the buffers + static void SetSizes(Arguments &args, Queue&) { + args.a_size = GetSizeA(args); // im + args.b_size = GetSizeB(args); // col + } + + // Describes what the default values of the leading dimensions of the matrices are + static size_t DefaultLDA(const Arguments &) { return 1; } // N/A for this routine + static size_t DefaultLDB(const Arguments &) { return 1; } // N/A for this routine + static size_t DefaultLDC(const Arguments &) { return 1; } // N/A for this routine + + // Describes which transpose options are relevant for this routine + using Transposes = std::vector; + static Transposes GetATransposes(const Transposes &) { return {}; } // N/A for this routine + static Transposes GetBTransposes(const Transposes &) { return {}; } // N/A for this routine + + // Describes how to prepare the input data + static void PrepareData(const Arguments&, Queue&, const int, std::vector&, + std::vector&, std::vector&, std::vector&, std::vector&, + std::vector&, std::vector&) {} // N/A for this routine + + // Describes how to run the CLBlast routine + static StatusCode RunRoutine(const Arguments &args, Buffers &buffers, Queue &queue) { + #ifdef OPENCL_API + auto queue_plain = queue(); + auto event = cl_event{}; + auto status = Col2im(args.channels, args.height, args.width, + args.kernel_h, args.kernel_w, + args.pad_h, args.pad_w, + args.stride_h, args.stride_w, + args.dilation_h, args.dilation_w, + buffers.b_mat(), args.b_offset, // col + buffers.a_mat(), args.a_offset, // im + &queue_plain, &event); + if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); } + #elif CUDA_API + auto status = Col2im(args.channels, args.height, args.width, + args.kernel_h, args.kernel_w, + args.pad_h, args.pad_w, + args.stride_h, args.stride_w, + args.dilation_h, args.dilation_w, + buffers.b_mat(), args.b_offset, // col + buffers.a_mat(), args.a_offset, // im + queue.GetContext()(), queue.GetDevice()()); + cuStreamSynchronize(queue()); + #endif + return status; + } + + // Describes how to run a naive version of the routine (for correctness/performance comparison). + // Note that a proper clBLAS or CPU BLAS comparison is not available for non-BLAS routines. + static StatusCode RunReference1(const Arguments &args, Buffers &buffers, Queue &queue) { + auto buffers_host = BuffersHost(); + DeviceToHost(args, buffers, buffers_host, queue, BuffersIn()); + const auto status = RunReference(args, buffers_host); + HostToDevice(args, buffers, buffers_host, queue, BuffersOut()); + return status; + } + + static StatusCode RunReference2(const Arguments &args, BuffersHost &buffers_host, Queue&) { + return RunReference(args, buffers_host); + } + static StatusCode RunReference3(const Arguments &, BuffersCUDA &, Queue &) { + return StatusCode::kUnknownError; + } + + // Describes how to download the results of the computation (more importantly: which buffer) + static std::vector DownloadResult(const Arguments &args, Buffers &buffers, Queue &queue) { + std::vector result(args.a_size, static_cast(0)); + buffers.a_mat.Read(queue, args.a_size, result); + return result; + } + + // Describes how to compute the indices of the result buffer + static size_t ResultID1(const Arguments &args) { return args.height * args.width; } + static size_t ResultID2(const Arguments &args) { return args.channels; } + static size_t GetResultIndex(const Arguments &args, const size_t id1, const size_t id2) { + return id1 + args.height * args.width * id2 + args.a_offset; + } + + // Describes how to compute performance metrics + static size_t GetFlops(const Arguments &) { + return 1; + } + static size_t GetBytes(const Arguments &args) { + const auto im = args.channels * args.width * args.height; // possibly less with striding + const auto col = args.kernel_h * args.kernel_w * NumPatches(args); + return (im + col) * sizeof(T); + } +}; + +// ================================================================================================= + +template +StatusCode RunReference(const Arguments &args, BuffersHost &buffers_host) { + // Reference taken from im2col but swapped the input/output + const auto col_h = TestXcol2im::ColHeight(args); + const auto col_w = TestXcol2im::ColWidth(args); + for (auto c_id = size_t{0}; c_id < args.channels; ++c_id) { // image channels + for (auto kh_id = size_t{0}; kh_id < args.kernel_h; ++kh_id) { // kernel height + for (auto kw_id = size_t{0}; kw_id < args.kernel_w; ++kw_id) { // kernel width + for (auto h_id = size_t{0}; h_id < col_h; ++h_id) { // image height + for (auto w_id = size_t{0}; w_id < col_w; ++w_id) { // image width + + // Reads the input value + const auto kernel_index = kw_id + args.kernel_w * kh_id; + const auto patch_index = w_id + col_w * h_id; + const auto col_index = patch_index + kernel_index * col_w * col_h + + c_id * col_w * col_h * args.kernel_h * args.kernel_w; + const auto val = buffers_host.b_mat[col_index + args.b_offset]; + + // Sets the output value + const auto h_index = kh_id * args.dilation_h + args.stride_h * h_id - args.pad_h; + const auto w_index = kw_id * args.dilation_w + args.stride_w * w_id - args.pad_w; + if (h_index >= 0 && h_index < args.height && + w_index >= 0 && w_index < args.width) { + const auto im_index = w_index + args.width * (h_index + args.height * c_id); + buffers_host.a_mat[im_index + args.a_offset] = val; + } + } + } + } + } + } + return StatusCode::kSuccess; +} + +// ================================================================================================= +} // namespace clblast + +// CLBLAST_TEST_ROUTINES_XCOL2IM_H_ +#endif