Plugged in the code of strided-batched-gemm into convgemm in preparation for a new kernel

pull/319/head
Cedric Nugteren 2018-05-13 21:01:46 +02:00
parent 4e6d30088d
commit ad8f1027ab
2 changed files with 75 additions and 17 deletions
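For context, here is a rough, self-contained sketch (not part of this commit) of how the convolution parameters map onto the strided-batched GEMM dimensions and strides that the new code below sets up. The concrete sizes, the no-padding/stride-1 simplification, and the derivation of output_h/output_w are assumptions for illustration only:

#include <cstddef>
#include <cstdio>

// Illustrative only: how the GEMM dimensions and strides in this commit relate to the
// convolution parameters. Padding, stride and dilation are assumed trivial here purely
// to keep the example short; the real routine derives output_h/output_w from the
// im2col parameters.
int main() {
  const std::size_t channels = 3, height = 32, width = 32;  // input feature map (assumed sizes)
  const std::size_t kernel_h = 3, kernel_w = 3;             // convolution filter (assumed sizes)
  const std::size_t num_kernels = 64, batch_count = 16;     // assumed sizes

  const std::size_t output_h = height - kernel_h + 1;       // 30 (no padding, stride 1)
  const std::size_t output_w = width - kernel_w + 1;        // 30
  const std::size_t num_patches = output_h * output_w;      // rows of the im2col matrix
  const std::size_t patch_size = kernel_h * kernel_w * channels;

  // Per batch: C (result) = 1 * A (col) * B (kernel) + 0 * C (result)
  const std::size_t m = num_patches;                        // 900
  const std::size_t n = num_kernels;                        // 64
  const std::size_t k = patch_size;                         // 27

  // Strides between consecutive batches, matching the expressions in the diff below
  const std::size_t col_stride = patch_size * num_patches;  // one im2col matrix per batch
  const std::size_t kernel_stride = 0;                      // the same filters for every batch
  const std::size_t result_stride = num_kernels * output_h * output_w;

  std::printf("batches=%zu m=%zu n=%zu k=%zu col_stride=%zu kernel_stride=%zu result_stride=%zu\n",
              batch_count, m, n, k, col_stride, kernel_stride, result_stride);
}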

@@ -105,7 +105,7 @@ void XgemmDirectBatchedTT(const int kSizeM, const int kSizeN, const int kSizeK,
 #endif
 // =================================================================================================
-#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
+#if defined(ROUTINE_GEMMSTRIDEDBATCHED) || defined(ROUTINE_CONVGEMM)
 // Direct version of the strided-batched GEMM kernel with [A, B] = [non-transposed, non-transposed]
 __kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))

@@ -13,7 +13,7 @@
 #include "routines/levelx/xconvgemm.hpp"
 #include "routines/levelx/xim2col.hpp"
-#include "routines/levelx/xgemmstridedbatched.hpp"
+#include "routines/level3/xgemm.hpp"
 #include <string>
 #include <vector>
@@ -24,9 +24,16 @@ namespace clblast {
 // Constructor: forwards to base class constructor
 template <typename T>
 Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &name):
-    Routine(queue, event, name, {"Copy"}, PrecisionValue<T>(), {}, {
-    #include "../../kernels/levelx/im2col.opencl"
-    }) {
+    Routine(queue, event, name, {"XgemmDirect"},
+            PrecisionValue<T>(), {}, {
+    #include "../../kernels/level3/level3.opencl"
+    , // separated in multiple parts to prevent C1091 in MSVC 2013
+    #include "../../kernels/level3/xgemm_direct_part1.opencl"
+    #include "../../kernels/level3/xgemm_direct_part2.opencl"
+    #include "../../kernels/level3/xgemm_direct_part3.opencl"
+    , // separated in multiple parts to prevent C1091 in MSVC 2013
+    #include "../../kernels/level3/xgemm_direct_batched.opencl"
+    }) {
 }
 // =================================================================================================
@@ -41,8 +48,13 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
                               const Buffer<T> &kernel_buffer, const size_t kernel_offset,
                               const Buffer<T> &result_buffer, const size_t result_offset) {
+  // Tests for a valid batch count
+  if (batch_count == 0) {
+    throw BLASError(StatusCode::kInvalidBatchCount);
+  }
   // Makes sure all dimensions are larger than zero
-  if ((channels == 0) || (height == 0) || (width == 0) || (num_kernels == 0) || (batch_count == 0)) {
+  if ((channels == 0) || (height == 0) || (width == 0) || (num_kernels == 0)) {
     throw BLASError(StatusCode::kInvalidDimension);
   }
@@ -80,7 +92,7 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
     im2col_event.WaitForCompletion();
   }
-  // GEMM: C (result) = alpha (1) * A (col) * B (kernel) + beta (0) * C (result)
+  // Strided batched GEMM: C (result) = alpha (1) * A (col) * B (kernel) + beta (0) * C (result)
   const auto m = num_patches;
   const auto n = num_kernels;
   const auto k = patch_size;
@@ -88,17 +100,63 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
   const auto kernel_ld = k;
   const auto result_ld = m;
   const auto col_stride = patch_size * num_patches;
-  const auto kernel_stride = size_t{0}; // applies the same kernel to all
+  const auto kernel_stride = size_t{0}; // applies the same kernel to all batches
   const auto result_stride = num_kernels * output_h * output_w;
-  auto gemm_event = Event();
-  auto gemm = XgemmStridedBatched<T>(queue_, gemm_event.pointer());
-  gemm.DoGemmStridedBatched(Layout::kColMajor, Transpose::kNo, Transpose::kNo,
-                            m, n, k, ConstantOne<T>(),
-                            col_buffer, 0, col_ld, col_stride,
-                            kernel_buffer, kernel_offset, kernel_ld, kernel_stride, ConstantZero<T>(),
-                            result_buffer, result_offset, result_ld, result_stride,
-                            batch_count);
-  gemm_event.WaitForCompletion();
+  // Computes the transpose/conjugate options and sets the a/b/c sizes based on that
+  bool col_do_transpose, kernel_do_transpose, result_do_transpose, col_conjugate, kernel_conjugate;
+  size_t col_one, col_two, kernel_one, kernel_two, result_one, result_two;
+  Xgemm<T>::ProcessArguments(Layout::kColMajor, Transpose::kNo, Transpose::kNo, m, n, k,
+                             col_one, col_two, kernel_one, kernel_two, result_one, result_two,
+                             col_do_transpose, kernel_do_transpose,
+                             result_do_transpose, col_conjugate, kernel_conjugate, 0);
+  // Tests the matrices for validity
+  for (auto batch = size_t{0}; batch < batch_count; ++batch) {
+    TestMatrixA(col_one, col_two, col_buffer, col_stride * batch, col_ld);
+    TestMatrixB(kernel_one, kernel_two, kernel_buffer, kernel_offset + kernel_stride * batch, kernel_ld);
+    TestMatrixC(result_one, result_two, result_buffer, result_offset + result_stride * batch, result_ld);
+  }
+  // Retrieves the proper XgemmDirect kernel from the compiled binary
+  const auto name = (col_do_transpose) ? (kernel_do_transpose ? "XgemmDirectStridedBatchedTT" : "XgemmDirectStridedBatchedTN") :
+                    (kernel_do_transpose ? "XgemmDirectStridedBatchedNT" : "XgemmDirectStridedBatchedNN");
+  auto kernel = Kernel(program_, name);
+  // Sets the kernel arguments
+  kernel.SetArgument(0, static_cast<int>(m));
+  kernel.SetArgument(1, static_cast<int>(n));
+  kernel.SetArgument(2, static_cast<int>(k));
+  kernel.SetArgument(3, GetRealArg(ConstantOne<T>()));
+  kernel.SetArgument(4, GetRealArg(ConstantZero<T>()));
+  kernel.SetArgument(5, col_buffer());
+  kernel.SetArgument(6, static_cast<int>(0));
+  kernel.SetArgument(7, static_cast<int>(col_ld));
+  kernel.SetArgument(8, static_cast<int>(col_stride));
+  kernel.SetArgument(9, kernel_buffer());
+  kernel.SetArgument(10, static_cast<int>(kernel_offset));
+  kernel.SetArgument(11, static_cast<int>(kernel_ld));
+  kernel.SetArgument(12, static_cast<int>(kernel_stride));
+  kernel.SetArgument(13, result_buffer());
+  kernel.SetArgument(14, static_cast<int>(result_offset));
+  kernel.SetArgument(15, static_cast<int>(result_ld));
+  kernel.SetArgument(16, static_cast<int>(result_stride));
+  kernel.SetArgument(17, static_cast<int>(result_do_transpose));
+  kernel.SetArgument(18, static_cast<int>(false));
+  kernel.SetArgument(19, static_cast<int>(false));
+  // Computes the global and local thread sizes
+  const auto m_ceiled = Ceil(m, db_["WGD"]);
+  const auto n_ceiled = Ceil(n, db_["WGD"]);
+  const auto global = std::vector<size_t>{
+      (m_ceiled * db_["MDIMCD"]) / db_["WGD"],
+      (n_ceiled * db_["NDIMCD"]) / db_["WGD"],
+      batch_count
+  };
+  const auto local = std::vector<size_t>{db_["MDIMCD"], db_["NDIMCD"], 1};
+  // Launches the kernel
+  RunKernel(kernel, queue_, device_, global, local, event_);
 }
 // =================================================================================================
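For reference, a minimal standalone sketch (not from the CLBlast sources) of the thread-size arithmetic used in the launch code above. WGD, MDIMCD and NDIMCD normally come from the routine's tuning database db_; the values below are assumptions chosen only to make the arithmetic concrete:

#include <cstddef>
#include <cstdio>
#include <vector>

// Rounds x up to the next multiple of y, mirroring the behaviour assumed for CLBlast's Ceil.
std::size_t Ceil(const std::size_t x, const std::size_t y) {
  return ((x + y - 1) / y) * y;
}

int main() {
  // Assumed tuning parameters; the real values are device-specific and come from db_.
  const std::size_t WGD = 32, MDIMCD = 8, NDIMCD = 8;
  const std::size_t m = 900, n = 64, batch_count = 16;  // sizes from the earlier sketch

  const auto m_ceiled = Ceil(m, WGD);                   // 928
  const auto n_ceiled = Ceil(n, WGD);                   // 64
  const auto global = std::vector<std::size_t>{
      (m_ceiled * MDIMCD) / WGD,                        // 232 work-items in dimension 0
      (n_ceiled * NDIMCD) / WGD,                        // 16 work-items in dimension 1
      batch_count                                       // one layer of work-groups per batch
  };
  const auto local = std::vector<std::size_t>{MDIMCD, NDIMCD, 1};

  std::printf("global = {%zu, %zu, %zu}, local = {%zu, %zu, %zu}\n",
              global[0], global[1], global[2], local[0], local[1], local[2]);
}

Each MDIMCD x NDIMCD work-group then covers one WGD x WGD tile of the per-batch result matrix, and the third NDRange dimension selects the batch, which is why batch_count appears directly as the third global size in the launch above.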