Plugged in the code of strided-batched-gemm into convgemm in preparation of a new kernel
parent
4e6d30088d
commit
ad8f1027ab
|
@ -105,7 +105,7 @@ void XgemmDirectBatchedTT(const int kSizeM, const int kSizeN, const int kSizeK,
|
|||
|
||||
#endif
|
||||
// =================================================================================================
|
||||
#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
|
||||
#if defined(ROUTINE_GEMMSTRIDEDBATCHED) || defined(ROUTINE_CONVGEMM)
|
||||
|
||||
// Direct version of the strided-batched GEMM kernel with [A, B] = [non-transposed, non-transposed]
|
||||
__kernel __attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
|
||||
#include "routines/levelx/xconvgemm.hpp"
|
||||
#include "routines/levelx/xim2col.hpp"
|
||||
#include "routines/levelx/xgemmstridedbatched.hpp"
|
||||
#include "routines/level3/xgemm.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -24,9 +24,16 @@ namespace clblast {
|
|||
// Constructor: forwards to base class constructor
|
||||
template <typename T>
|
||||
Xconvgemm<T>::Xconvgemm(Queue &queue, EventPointer event, const std::string &name):
|
||||
Routine(queue, event, name, {"Copy"}, PrecisionValue<T>(), {}, {
|
||||
#include "../../kernels/levelx/im2col.opencl"
|
||||
}) {
|
||||
Routine(queue, event, name, {"XgemmDirect"},
|
||||
PrecisionValue<T>(), {}, {
|
||||
#include "../../kernels/level3/level3.opencl"
|
||||
, // separated in multiple parts to prevent C1091 in MSVC 2013
|
||||
#include "../../kernels/level3/xgemm_direct_part1.opencl"
|
||||
#include "../../kernels/level3/xgemm_direct_part2.opencl"
|
||||
#include "../../kernels/level3/xgemm_direct_part3.opencl"
|
||||
, // separated in multiple parts to prevent C1091 in MSVC 2013
|
||||
#include "../../kernels/level3/xgemm_direct_batched.opencl"
|
||||
}) {
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
@ -41,8 +48,13 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
|
|||
const Buffer<T> &kernel_buffer, const size_t kernel_offset,
|
||||
const Buffer<T> &result_buffer, const size_t result_offset) {
|
||||
|
||||
// Tests for a valid batch count
|
||||
if (batch_count == 0) {
|
||||
throw BLASError(StatusCode::kInvalidBatchCount);
|
||||
}
|
||||
|
||||
// Makes sure all dimensions are larger than zero
|
||||
if ((channels == 0) || (height == 0) || (width == 0) || (num_kernels == 0) || (batch_count == 0)) {
|
||||
if ((channels == 0) || (height == 0) || (width == 0) || (num_kernels == 0)) {
|
||||
throw BLASError(StatusCode::kInvalidDimension);
|
||||
}
|
||||
|
||||
|
@ -80,7 +92,7 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
|
|||
im2col_event.WaitForCompletion();
|
||||
}
|
||||
|
||||
// GEMM: C (result) = alpha (1) * A (col) * B (kernel) + beta (0) * C (result)
|
||||
// Strided batched GEMM: C (result) = alpha (1) * A (col) * B (kernel) + beta (0) * C (result)
|
||||
const auto m = num_patches;
|
||||
const auto n = num_kernels;
|
||||
const auto k = patch_size;
|
||||
|
@ -88,17 +100,63 @@ void Xconvgemm<T>::DoConvgemm(const size_t channels, const size_t height, const
|
|||
const auto kernel_ld = k;
|
||||
const auto result_ld = m;
|
||||
const auto col_stride = patch_size * num_patches;
|
||||
const auto kernel_stride = size_t{0}; // applies the same kernel to all
|
||||
const auto kernel_stride = size_t{0}; // applies the same kernel to all batches
|
||||
const auto result_stride = num_kernels * output_h * output_w;
|
||||
auto gemm_event = Event();
|
||||
auto gemm = XgemmStridedBatched<T>(queue_, gemm_event.pointer());
|
||||
gemm.DoGemmStridedBatched(Layout::kColMajor, Transpose::kNo, Transpose::kNo,
|
||||
m, n, k, ConstantOne<T>(),
|
||||
col_buffer, 0, col_ld, col_stride,
|
||||
kernel_buffer, kernel_offset, kernel_ld, kernel_stride, ConstantZero<T>(),
|
||||
result_buffer, result_offset, result_ld, result_stride,
|
||||
batch_count);
|
||||
gemm_event.WaitForCompletion();
|
||||
|
||||
// Computes the transpose/conjugate options and sets the a/b/c sizes based on that
|
||||
bool col_do_transpose, kernel_do_transpose, result_do_transpose, col_conjugate, kernel_conjugate;
|
||||
size_t col_one, col_two, kernel_one, kernel_two, result_one, result_two;
|
||||
Xgemm<T>::ProcessArguments(Layout::kColMajor, Transpose::kNo, Transpose::kNo, m, n, k,
|
||||
col_one, col_two, kernel_one, kernel_two, result_one, result_two,
|
||||
col_do_transpose, kernel_do_transpose,
|
||||
result_do_transpose, col_conjugate, kernel_conjugate, 0);
|
||||
|
||||
// Tests the matrices for validity
|
||||
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
|
||||
TestMatrixA(col_one, col_two, col_buffer, col_stride * batch, col_ld);
|
||||
TestMatrixB(kernel_one, kernel_two, kernel_buffer, kernel_offset + kernel_stride * batch, kernel_ld);
|
||||
TestMatrixC(result_one, result_two, result_buffer, result_offset + result_stride * batch, result_ld);
|
||||
}
|
||||
|
||||
// Retrieves the proper XgemmDirect kernel from the compiled binary
|
||||
const auto name = (col_do_transpose) ? (kernel_do_transpose ? "XgemmDirectStridedBatchedTT" : "XgemmDirectStridedBatchedTN") :
|
||||
(kernel_do_transpose ? "XgemmDirectStridedBatchedNT" : "XgemmDirectStridedBatchedNN");
|
||||
auto kernel = Kernel(program_, name);
|
||||
|
||||
// Sets the kernel arguments
|
||||
kernel.SetArgument(0, static_cast<int>(m));
|
||||
kernel.SetArgument(1, static_cast<int>(n));
|
||||
kernel.SetArgument(2, static_cast<int>(k));
|
||||
kernel.SetArgument(3, GetRealArg(ConstantOne<T>()));
|
||||
kernel.SetArgument(4, GetRealArg(ConstantZero<T>()));
|
||||
kernel.SetArgument(5, col_buffer());
|
||||
kernel.SetArgument(6, static_cast<int>(0));
|
||||
kernel.SetArgument(7, static_cast<int>(col_ld));
|
||||
kernel.SetArgument(8, static_cast<int>(col_stride));
|
||||
kernel.SetArgument(9, kernel_buffer());
|
||||
kernel.SetArgument(10, static_cast<int>(kernel_offset));
|
||||
kernel.SetArgument(11, static_cast<int>(kernel_ld));
|
||||
kernel.SetArgument(12, static_cast<int>(kernel_stride));
|
||||
kernel.SetArgument(13, result_buffer());
|
||||
kernel.SetArgument(14, static_cast<int>(result_offset));
|
||||
kernel.SetArgument(15, static_cast<int>(result_ld));
|
||||
kernel.SetArgument(16, static_cast<int>(result_stride));
|
||||
kernel.SetArgument(17, static_cast<int>(result_do_transpose));
|
||||
kernel.SetArgument(18, static_cast<int>(false));
|
||||
kernel.SetArgument(19, static_cast<int>(false));
|
||||
|
||||
// Computes the global and local thread sizes
|
||||
const auto m_ceiled = Ceil(m, db_["WGD"]);
|
||||
const auto n_ceiled = Ceil(n, db_["WGD"]);
|
||||
const auto global = std::vector<size_t>{
|
||||
(m_ceiled * db_["MDIMCD"]) / db_["WGD"],
|
||||
(n_ceiled * db_["NDIMCD"]) / db_["WGD"],
|
||||
batch_count
|
||||
};
|
||||
const auto local = std::vector<size_t>{db_["MDIMCD"], db_["NDIMCD"], 1};
|
||||
|
||||
// Launches the kernel
|
||||
RunKernel(kernel, queue_, device_, global, local, event_);
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
|
Loading…
Reference in New Issue