Merge pull request #142 from CNugteren/gemm_batched

Added a first batched version of the GEMM routine
Cedric Nugteren 2017-03-19 18:27:40 +01:00 committed by GitHub
commit a21d903796
35 changed files with 1575 additions and 67 deletions


@@ -15,6 +15,7 @@ Development version (next release)
* STRSM/DTRSM/CTRSM/ZTRSM (experimental, un-optimized)
- Added batched (non-BLAS) routines:
* SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED (batched version of AXPY)
* SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED (batched version of GEMM)
Version 0.10.0
- Updated to version 8.0 of the CLCudaAPI C++11 OpenCL header


@@ -159,7 +159,7 @@ set(LEVEL1_ROUTINES xswap xscal xcopy xaxpy xdot xdotu xdotc xnrm2 xasum xamax)
set(LEVEL2_ROUTINES xgemv xgbmv xhemv xhbmv xhpmv xsymv xsbmv xspmv xtrmv xtbmv xtpmv xtrsv
xger xgeru xgerc xher xhpr xher2 xhpr2 xsyr xspr xsyr2 xspr2)
set(LEVEL3_ROUTINES xgemm xsymm xhemm xsyrk xherk xsyr2k xher2k xtrmm xtrsm)
set(LEVELX_ROUTINES xomatcopy xaxpybatched xgemmbatched)
set(ROUTINES ${LEVEL1_ROUTINES} ${LEVEL2_ROUTINES} ${LEVEL3_ROUTINES} ${LEVELX_ROUTINES})
set(PRECISIONS 32 64 3232 6464 16)


@@ -276,6 +276,13 @@ CLBlast supports almost all the Netlib BLAS routines plus a couple of extra non-
| xTRMM | ✔ | ✔ | ✔ | ✔ | ✔ |
| xTRSM | ✔ | ✔ | ✔ | ✔ | | (experimental, un-optimized)
Furthermore, there are also batched versions of BLAS routines available, processing multiple smaller computations in one go for better performance:

| Batched | S | D | C | Z | H |
| -------------|---|---|---|---|---|
| xAXPYBATCHED | ✔ | ✔ | ✔ | ✔ | ✔ |
| xGEMMBATCHED | ✔ | ✔ | ✔ | ✔ | ✔ |
In addition, some extra non-BLAS routines are also supported by CLBlast, classified as level-X. They are experimental and should be used with care:
| Level-X | S | D | C | Z | H |


@@ -2969,6 +2969,105 @@ Arguments to AXPYBATCHED:
xGEMMBATCHED: Batched version of GEMM
-------------
As GEMM, but multiple operations are batched together for better performance.
C++ API:
```
template <typename T>
StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
const T *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const T *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
```
C API:
```
CLBlastStatusCode CLBlastSgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const float *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const float *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastDgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const double *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const double *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastCgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_float2 *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const cl_float2 *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastZgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_double2 *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const cl_double2 *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastHgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_half *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const cl_half *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
```
Arguments to GEMMBATCHED:
* `const Layout layout`: Data-layout of the matrices, either `Layout::kRowMajor` (101) for row-major layout or `Layout::kColMajor` (102) for column-major data-layout.
* `const Transpose a_transpose`: Transposing the input matrix A, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.
* `const Transpose b_transpose`: Transposing the input matrix B, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.
* `const size_t m`: Integer size argument. This value must be positive.
* `const size_t n`: Integer size argument. This value must be positive.
* `const size_t k`: Integer size argument. This value must be positive.
* `const T *alphas`: Input scalar constants.
* `const cl_mem a_buffer`: OpenCL buffer to store the input A matrix.
* `const size_t *a_offsets`: The offsets in elements from the start of the input A matrix.
* `const size_t a_ld`: Leading dimension of the input A matrix. This value must be greater than 0.
* `const cl_mem b_buffer`: OpenCL buffer to store the input B matrix.
* `const size_t *b_offsets`: The offsets in elements from the start of the input B matrix.
* `const size_t b_ld`: Leading dimension of the input B matrix. This value must be greater than 0.
* `const T *betas`: Input scalar constants.
* `cl_mem c_buffer`: OpenCL buffer to store the output C matrix.
* `const size_t *c_offsets`: The offsets in elements from the start of the output C matrix.
* `const size_t c_ld`: Leading dimension of the output C matrix. This value must be greater than 0.
* `const size_t batch_count`: Number of batches. This value must be positive.
* `cl_command_queue* queue`: Pointer to an OpenCL command queue associated with a context and device to execute the routine on.
* `cl_event* event`: Pointer to an OpenCL event to be able to wait for completion of the routine's OpenCL kernel(s). This is an optional argument.
Requirements for GEMMBATCHED:
* When `transpose_a == Transpose::kNo`, then `a_ld` must be at least `m`, otherwise `a_ld` must be at least `k`.
* When `transpose_b == Transpose::kNo`, then `b_ld` must be at least `k`, otherwise `b_ld` must be at least `n`.
* The value of `c_ld` must be at least `m`.
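Example of use:

For illustration only (this example is not part of the commit), a minimal sketch of how the C++ `GemmBatched` interface above might be called. It assumes an already-created OpenCL `cl_command_queue queue` and three `cl_mem` buffers `a_mem`, `b_mem` and `c_mem`, each holding `batch_count` matrices packed back-to-back; all names and sizes are invented for the example.
```
#include <clblast.h>
#include <vector>

// Three independent 64x64 single-precision GEMMs, batched into a single call
const size_t batch_count = 3;
const size_t m = 64, n = 64, k = 64;

// One alpha/beta pair per batch
std::vector<float> alphas(batch_count, 1.0f);
std::vector<float> betas(batch_count, 0.0f);

// Matrices are packed back-to-back in each buffer: offsets are given in elements
std::vector<size_t> a_offsets, b_offsets, c_offsets;
for (size_t batch = 0; batch < batch_count; ++batch) {
  a_offsets.push_back(batch * m * k);
  b_offsets.push_back(batch * k * n);
  c_offsets.push_back(batch * m * n);
}

// Column-major without transposes, so a_ld >= m, b_ld >= k and c_ld >= m (see requirements above)
cl_event event = nullptr;
const auto status = clblast::GemmBatched(clblast::Layout::kColMajor,
                                         clblast::Transpose::kNo, clblast::Transpose::kNo,
                                         m, n, k,
                                         alphas.data(),
                                         a_mem, a_offsets.data(), m,
                                         b_mem, b_offsets.data(), k,
                                         betas.data(),
                                         c_mem, c_offsets.data(), m,
                                         batch_count,
                                         &queue, &event);
if (status == clblast::StatusCode::kSuccess) {
  clWaitForEvents(1, &event);
  clReleaseEvent(event);
}
```
The per-batch scalars and offsets are passed as plain host pointers; as the library code in this commit shows, they are copied into vectors internally before the kernels are launched.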
ClearCache: Resets the cache of compiled binaries (auxiliary function)
-------------


@@ -619,6 +619,18 @@ StatusCode AxpyBatched(const size_t n,
const size_t batch_count,
cl_command_queue* queue, cl_event* event = nullptr);
// Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
template <typename T>
StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
const T *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const T *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event = nullptr);
// =================================================================================================
// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on


@@ -1360,6 +1360,53 @@ CLBlastStatusCode PUBLIC_API CLBlastHaxpyBatched(const size_t n,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
// Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
CLBlastStatusCode PUBLIC_API CLBlastSgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const float *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const float *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastDgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const double *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const double *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastCgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_float2 *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const cl_float2 *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastZgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_double2 *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const cl_double2 *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastHgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_half *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const cl_half *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
// =================================================================================================
// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on


@@ -41,7 +41,7 @@ FILES = [
"/include/clblast_netlib_c.h",
"/src/clblast_netlib_c.cpp",
]
HEADER_LINES = [123, 76, 126, 23, 29, 41, 65, 32]
FOOTER_LINES = [25, 138, 27, 38, 6, 6, 9, 2]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 63
@@ -163,6 +163,7 @@ ROUTINES = [
Routine(True, True, False, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
# Batched routines:
Routine(True, True, True, "x", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Batched version of AXPY", "As AXPY, but multiple operations are batched together for better performance.", []),
Routine(True, True, True, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "Batched version of GEMM", "As GEMM, but multiple operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
]]


@@ -72,6 +72,7 @@
// Level-x includes (non-BLAS)
#include "routines/levelx/xomatcopy.hpp"
#include "routines/levelx/xaxpybatched.hpp"
#include "routines/levelx/xgemmbatched.hpp"
namespace clblast {
@@ -2231,6 +2232,89 @@ template StatusCode PUBLIC_API AxpyBatched<half>(const size_t,
cl_mem, const size_t*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
// Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
template <typename T>
StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
const T *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const T *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
try {
auto queue_cpp = Queue(*queue);
auto routine = XgemmBatched<T>(queue_cpp, event);
auto alphas_cpp = std::vector<T>();
auto betas_cpp = std::vector<T>();
auto a_offsets_cpp = std::vector<size_t>();
auto b_offsets_cpp = std::vector<size_t>();
auto c_offsets_cpp = std::vector<size_t>();
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
alphas_cpp.push_back(alphas[batch]);
betas_cpp.push_back(betas[batch]);
a_offsets_cpp.push_back(a_offsets[batch]);
b_offsets_cpp.push_back(b_offsets[batch]);
c_offsets_cpp.push_back(c_offsets[batch]);
}
routine.DoGemmBatched(layout, a_transpose, b_transpose,
m, n, k,
alphas_cpp,
Buffer<T>(a_buffer), a_offsets_cpp, a_ld,
Buffer<T>(b_buffer), b_offsets_cpp, b_ld,
betas_cpp,
Buffer<T>(c_buffer), c_offsets_cpp, c_ld,
batch_count);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API GemmBatched<float>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const float*,
const cl_mem, const size_t*, const size_t,
const cl_mem, const size_t*, const size_t,
const float*,
cl_mem, const size_t*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API GemmBatched<double>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double*,
const cl_mem, const size_t*, const size_t,
const cl_mem, const size_t*, const size_t,
const double*,
cl_mem, const size_t*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API GemmBatched<float2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const float2*,
const cl_mem, const size_t*, const size_t,
const cl_mem, const size_t*, const size_t,
const float2*,
cl_mem, const size_t*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API GemmBatched<double2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double2*,
const cl_mem, const size_t*, const size_t,
const cl_mem, const size_t*, const size_t,
const double2*,
cl_mem, const size_t*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const half*,
const cl_mem, const size_t*, const size_t,
const cl_mem, const size_t*, const size_t,
const half*,
cl_mem, const size_t*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
// =================================================================================================
// Clears the cache of stored binaries


@@ -3554,6 +3554,163 @@ CLBlastStatusCode CLBlastHaxpyBatched(const size_t n,
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
// GEMM
CLBlastStatusCode CLBlastSgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const float *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const float *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
auto alphas_cpp = std::vector<float>();
auto betas_cpp = std::vector<float>();
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
alphas_cpp.push_back(alphas[batch]);
betas_cpp.push_back(betas[batch]);
}
try {
return static_cast<CLBlastStatusCode>(
clblast::GemmBatched(static_cast<clblast::Layout>(layout),
static_cast<clblast::Transpose>(a_transpose),
static_cast<clblast::Transpose>(b_transpose),
m, n, k,
alphas_cpp.data(),
a_buffer, a_offsets, a_ld,
b_buffer, b_offsets, b_ld,
betas_cpp.data(),
c_buffer, c_offsets, c_ld,
batch_count,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastDgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const double *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const double *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
auto alphas_cpp = std::vector<double>();
auto betas_cpp = std::vector<double>();
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
alphas_cpp.push_back(alphas[batch]);
betas_cpp.push_back(betas[batch]);
}
try {
return static_cast<CLBlastStatusCode>(
clblast::GemmBatched(static_cast<clblast::Layout>(layout),
static_cast<clblast::Transpose>(a_transpose),
static_cast<clblast::Transpose>(b_transpose),
m, n, k,
alphas_cpp.data(),
a_buffer, a_offsets, a_ld,
b_buffer, b_offsets, b_ld,
betas_cpp.data(),
c_buffer, c_offsets, c_ld,
batch_count,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastCgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_float2 *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const cl_float2 *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
auto alphas_cpp = std::vector<float2>();
auto betas_cpp = std::vector<float2>();
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
alphas_cpp.push_back(float2{alphas[batch].s[0], alphas[batch].s[1]});
betas_cpp.push_back(float2{betas[batch].s[0], betas[batch].s[1]});
}
try {
return static_cast<CLBlastStatusCode>(
clblast::GemmBatched(static_cast<clblast::Layout>(layout),
static_cast<clblast::Transpose>(a_transpose),
static_cast<clblast::Transpose>(b_transpose),
m, n, k,
alphas_cpp.data(),
a_buffer, a_offsets, a_ld,
b_buffer, b_offsets, b_ld,
betas_cpp.data(),
c_buffer, c_offsets, c_ld,
batch_count,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastZgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_double2 *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const cl_double2 *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
auto alphas_cpp = std::vector<double2>();
auto betas_cpp = std::vector<double2>();
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
alphas_cpp.push_back(double2{alphas[batch].s[0], alphas[batch].s[1]});
betas_cpp.push_back(double2{betas[batch].s[0], betas[batch].s[1]});
}
try {
return static_cast<CLBlastStatusCode>(
clblast::GemmBatched(static_cast<clblast::Layout>(layout),
static_cast<clblast::Transpose>(a_transpose),
static_cast<clblast::Transpose>(b_transpose),
m, n, k,
alphas_cpp.data(),
a_buffer, a_offsets, a_ld,
b_buffer, b_offsets, b_ld,
betas_cpp.data(),
c_buffer, c_offsets, c_ld,
batch_count,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastHgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_half *alphas,
const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
const cl_half *betas,
cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
auto alphas_cpp = std::vector<half>();
auto betas_cpp = std::vector<half>();
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
alphas_cpp.push_back(alphas[batch]);
betas_cpp.push_back(betas[batch]);
}
try {
return static_cast<CLBlastStatusCode>(
clblast::GemmBatched(static_cast<clblast::Layout>(layout),
static_cast<clblast::Transpose>(a_transpose),
static_cast<clblast::Transpose>(b_transpose),
m, n, k,
alphas_cpp.data(),
a_buffer, a_offsets, a_ld,
b_buffer, b_offsets, b_ld,
betas_cpp.data(),
c_buffer, c_offsets, c_ld,
batch_count,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
// =================================================================================================
// Clears the cache of stored binaries


@@ -24,16 +24,14 @@ R"(
// Copies a matrix from source to destination. The output is padded with zero values in case the
// destination matrix dimensions are larger than the source matrix dimensions. Additionally, the ld
// value and offset can be different.
inline void _CopyPadMatrix(const int src_one, const int src_two,
const int src_ld, const int src_offset,
__global const real* restrict src,
const int dest_one, const int dest_two,
const int dest_ld, const int dest_offset,
__global real* dest,
const real alpha,
const int do_conjugate) {
// Loops over the work per thread in both dimensions
#pragma unroll
@@ -60,22 +58,36 @@ void CopyPadMatrix(const int src_one, const int src_two,
}
}
// Interface to the above function
__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
void CopyPadMatrix(const int src_one, const int src_two,
const int src_ld, const int src_offset,
__global const real* restrict src,
const int dest_one, const int dest_two,
const int dest_ld, const int dest_offset,
__global real* dest,
const real_arg arg_alpha,
const int do_conjugate) {
const real alpha = GetRealArg(arg_alpha);
_CopyPadMatrix(src_one, src_two, src_ld, src_offset, src,
dest_one, dest_two, dest_ld, dest_offset, dest,
alpha, do_conjugate);
}
// =================================================================================================
// Same as above, but now un-pads a matrix. This kernel reads data from a padded source matrix, but
// writes only the actual data back to the destination matrix. Again, the ld value and offset can
// be different.
inline void _CopyMatrix(const int src_one, const int src_two,
const int src_ld, const int src_offset,
__global const real* restrict src,
const int dest_one, const int dest_two,
const int dest_ld, const int dest_offset,
__global real* dest,
const real alpha,
const int upper, const int lower,
const int diagonal_imag_zero) {
// Loops over the work per thread in both dimensions
#pragma unroll
@@ -105,6 +117,62 @@ void CopyMatrix(const int src_one, const int src_two,
}
}
// Interface to the above function
__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
void CopyMatrix(const int src_one, const int src_two,
const int src_ld, const int src_offset,
__global const real* restrict src,
const int dest_one, const int dest_two,
const int dest_ld, const int dest_offset,
__global real* dest,
const real_arg arg_alpha,
const int upper, const int lower,
const int diagonal_imag_zero) {
const real alpha = GetRealArg(arg_alpha);
_CopyMatrix(src_one, src_two, src_ld, src_offset, src,
dest_one, dest_two, dest_ld, dest_offset, dest,
alpha, upper, lower, diagonal_imag_zero);
}
// =================================================================================================
#if defined(ROUTINE_GEMMBATCHED)
// Batched version of the above
__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
void CopyPadMatrixBatched(const int src_one, const int src_two,
const int src_ld, const __constant int* src_offsets,
__global const real* restrict src,
const int dest_one, const int dest_two,
const int dest_ld, const __constant int* dest_offsets,
__global real* dest,
const int do_conjugate) {
const int batch = get_group_id(2);
const int src_offset = src_offsets[batch];
const int dest_offset = dest_offsets[batch];
real alpha; SetToOne(alpha);
_CopyPadMatrix(src_one, src_two, src_ld, src_offset, src,
dest_one, dest_two, dest_ld, dest_offset, dest,
alpha, do_conjugate);
}
// Batched version of the above
__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
void CopyMatrixBatched(const int src_one, const int src_two,
const int src_ld, const __constant int* src_offsets,
__global const real* restrict src,
const int dest_one, const int dest_two,
const int dest_ld, const __constant int* dest_offsets,
__global real* dest) {
const int batch = get_group_id(2);
const int src_offset = src_offsets[batch];
const int dest_offset = dest_offsets[batch];
real alpha; SetToOne(alpha);
_CopyMatrix(src_one, src_two, src_ld, src_offset, src,
dest_one, dest_two, dest_ld, dest_offset, dest,
alpha, 0, 0, 0);
}
#endif
// =================================================================================================
// End of the C++11 raw string literal


@@ -24,19 +24,15 @@ R"(
// Transposes a matrix from source to destination. The output is padded with zero values in case the
// destination matrix dimensions are larger than the transposed source matrix dimensions.
inline void _TransposePadMatrix(__local real* tile,
const int src_one, const int src_two,
const int src_ld, const int src_offset,
__global const real* restrict src,
const int dest_one, const int dest_two,
const int dest_ld, const int dest_offset,
__global real* dest,
const real alpha,
const int do_conjugate) {
// Loop over the work per thread
#pragma unroll
@@ -56,7 +52,9 @@ void TransposePadMatrix(const int src_one, const int src_two,
if (id_src_two < src_two && id_src_one < src_one) {
value = src[id_src_two*src_ld + id_src_one + src_offset];
}
const int tile_id0 = get_local_id(0)*PADTRA_WPT + w_one;
const int tile_id1 = get_local_id(1)*PADTRA_WPT + w_two;
tile[tile_id1 * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD) + tile_id0] = value;
}
}
@@ -75,7 +73,9 @@ void TransposePadMatrix(const int src_one, const int src_two,
// Stores the transposed value in the destination matrix
if ((id_dest_one < dest_one) && (id_dest_two < dest_two)) {
const int tile_id0 = get_local_id(1)*PADTRA_WPT + w_one;
const int tile_id1 = get_local_id(0)*PADTRA_WPT + w_two;
real value = tile[tile_id1 * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD) + tile_id0];
if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
Multiply(dest[id_dest_two*dest_ld + id_dest_one + dest_offset], alpha, value);
}
@@ -83,25 +83,38 @@ void TransposePadMatrix(const int src_one, const int src_two,
}
}
// Interface to the above function
__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
void TransposePadMatrix(const int src_one, const int src_two,
const int src_ld, const int src_offset,
__global const real* restrict src,
const int dest_one, const int dest_two,
const int dest_ld, const int dest_offset,
__global real* dest,
const real_arg arg_alpha,
const int do_conjugate) {
const real alpha = GetRealArg(arg_alpha);
__local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
_TransposePadMatrix(tile, src_one, src_two, src_ld, src_offset, src,
dest_one, dest_two, dest_ld, dest_offset, dest,
alpha, do_conjugate);
}
// =================================================================================================
// Transposes a matrix, while considering possible padding in the source matrix. Data is read from a
// padded source matrix, but only the actual data is written back to the transposed destination
// matrix. This kernel optionally checks for upper/lower triangular matrices.
inline void _TransposeMatrix(__local real* tile,
const int src_one, const int src_two,
const int src_ld, const int src_offset,
__global const real* restrict src,
const int dest_one, const int dest_two,
const int dest_ld, const int dest_offset,
__global real* dest,
const real alpha,
const int upper, const int lower,
const int diagonal_imag_zero) {
// Loop over the work per thread
#pragma unroll
@@ -117,7 +130,9 @@ void TransposeMatrix(const int src_one, const int src_two,
// Loads data into the local memory if the thread IDs are within bounds of the source matrix.
if ((id_src_one < src_one) && (id_src_two < src_two)) {
real value = src[id_src_two*src_ld + id_src_one + src_offset];
const int tile_id0 = get_local_id(0)*PADTRA_WPT + w_one;
const int tile_id1 = get_local_id(1)*PADTRA_WPT + w_two;
tile[tile_id1 * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD) + tile_id0] = value;
}
}
}
@@ -145,7 +160,9 @@ void TransposeMatrix(const int src_one, const int src_two,
// Stores the transposed value in the destination matrix
if ((id_dest_one < dest_one) && (id_dest_two < dest_two)) {
const int tile_id0 = get_local_id(1)*PADTRA_WPT + w_one;
const int tile_id1 = get_local_id(0)*PADTRA_WPT + w_two;
real value = tile[tile_id1 * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD) + tile_id0];
if (diagonal_imag_zero == 1 && id_dest_one == id_dest_two) { ImagToZero(value); }
Multiply(dest[id_dest_two*dest_ld + id_dest_one + dest_offset], alpha, value);
}
@@ -154,6 +171,65 @@ void TransposeMatrix(const int src_one, const int src_two,
}
}
// Interface to the above function
__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
void TransposeMatrix(const int src_one, const int src_two,
const int src_ld, const int src_offset,
__global const real* restrict src,
const int dest_one, const int dest_two,
const int dest_ld, const int dest_offset,
__global real* dest,
const real_arg arg_alpha,
const int upper, const int lower,
const int diagonal_imag_zero) {
const real alpha = GetRealArg(arg_alpha);
__local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
_TransposeMatrix(tile, src_one, src_two, src_ld, src_offset, src,
dest_one, dest_two, dest_ld, dest_offset, dest,
alpha, upper, lower, diagonal_imag_zero);
}
// =================================================================================================
#if defined(ROUTINE_GEMMBATCHED)
// Batched version of the above
__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
void TransposePadMatrixBatched(const int src_one, const int src_two,
const int src_ld, const __constant int* src_offsets,
__global const real* restrict src,
const int dest_one, const int dest_two,
const int dest_ld, const __constant int* dest_offsets,
__global real* dest,
const int do_conjugate) {
const int batch = get_group_id(2);
const int src_offset = src_offsets[batch];
const int dest_offset = dest_offsets[batch];
real alpha; SetToOne(alpha);
__local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
_TransposePadMatrix(tile, src_one, src_two, src_ld, src_offset, src,
dest_one, dest_two, dest_ld, dest_offset, dest,
alpha, do_conjugate);
}
// Batched version of the above
__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
void TransposeMatrixBatched(const int src_one, const int src_two,
const int src_ld, const __constant int* src_offsets,
__global const real* restrict src,
const int dest_one, const int dest_two,
const int dest_ld, const __constant int* dest_offsets,
__global real* dest) {
const int batch = get_group_id(2);
const int src_offset = src_offsets[batch];
const int dest_offset = dest_offsets[batch];
real alpha; SetToOne(alpha);
__local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
_TransposeMatrix(tile, src_one, src_two, src_ld, src_offset, src,
dest_one, dest_two, dest_ld, dest_offset, dest,
alpha, 0, 0, 0);
}
#endif
// =================================================================================================
// End of the C++11 raw string literal


@@ -0,0 +1,70 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file contains the batched version of the non-direct GEMM kernel. See part 1 for information
// about the non-batched version of the kernel.
//
// =================================================================================================
// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
// literal). Comment-out this line for syntax-highlighting when developing.
R"(
// =================================================================================================
// Main entry point of the kernel. This is the regular full version.
__kernel __attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK,
const __constant real_arg* arg_alphas,
const __constant real_arg* arg_betas,
const __global realM* restrict agm, const int a_one, const int a_two,
const __global realN* restrict bgm, const int b_one, const int b_two,
__global realM* cgm, const int c_one, const int c_two) {
const int batch = get_group_id(2);
const real alpha = GetRealArg(arg_alphas[batch]);
const real beta = GetRealArg(arg_betas[batch]);
// Sets the offsets
const int a_offset = batch * a_one * a_two;
const int b_offset = batch * b_one * b_two;
const int c_offset = batch * c_one * c_two;
const __global realM* restrict agm_ = &agm[a_offset / VWM];
const __global realN* restrict bgm_ = &bgm[b_offset / VWN];
__global realM* restrict cgm_ = &cgm[c_offset / VWM];
// Allocates workgroup-private memory (local memory)
#if SA == 1
__local realM alm[KWG * MWG/VWM];
#endif
#if SB == 1
__local realN blm[KWG * NWG/VWN];
#endif
// Computes the matrix-multiplication and stores the result in register memory
realM cpm[NWI][MWI/VWM];
#if SA == 1 && SB == 1
XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm, alm, blm);
#elif SA == 1
XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm, alm);
#elif SB == 1
XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm, blm);
#else
XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm);
#endif
// Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
StoreResults(cgm_, cpm, kSizeM, alpha, beta);
}
// =================================================================================================
// End of the C++11 raw string literal
)"
// =================================================================================================


@@ -0,0 +1,110 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file contains the batched version of the direct GEMM kernels. See part 1 for information
// about the non-batched version of the kernel.
//
// =================================================================================================
// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
// literal). Comment-out this line for syntax-highlighting when developing.
R"(
// =================================================================================================
// Direct version of the batched GEMM kernel with [A, B] = [non-transposed, non-transposed]
__attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
__kernel void XgemmDirectBatchedNN(const int kSizeM, const int kSizeN, const int kSizeK,
const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
__global real* cgm, const __constant int* c_offsets, const int c_ld,
const int c_transpose, const int a_conjugate, const int b_conjugate) {
const int batch = get_group_id(2);
const real_arg arg_alpha = arg_alphas[batch];
const real_arg arg_beta = arg_betas[batch];
const int a_offset = a_offsets[batch];
const int b_offset = b_offsets[batch];
const int c_offset = c_offsets[batch];
__local real alm[WGD * (WGD + PADA)];
__local real blm[WGD * (WGD + PADB)];
XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
agm, a_offset, a_ld, bgm, b_offset, b_ld, cgm, c_offset, c_ld,
alm, blm, 0, 0, c_transpose, a_conjugate, b_conjugate);
}
// Direct version of the batched GEMM kernel with [A, B] = [non-transposed, transposed]
__attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
__kernel void XgemmDirectBatchedNT(const int kSizeM, const int kSizeN, const int kSizeK,
const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
__global real* cgm, const __constant int* c_offsets, const int c_ld,
const int c_transpose, const int a_conjugate, const int b_conjugate) {
const int batch = get_group_id(2);
const real_arg arg_alpha = arg_alphas[batch];
const real_arg arg_beta = arg_betas[batch];
const int a_offset = a_offsets[batch];
const int b_offset = b_offsets[batch];
const int c_offset = c_offsets[batch];
__local real alm[WGD * (WGD + PADA)];
__local real blm[WGD * (WGD + PADB)];
XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
agm, a_offset, a_ld, bgm, b_offset, b_ld, cgm, c_offset, c_ld,
alm, blm, 0, 1, c_transpose, a_conjugate, b_conjugate);
}
// Direct version of the batched GEMM kernel with [A, B] = [transposed, non-transposed]
__attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
__kernel void XgemmDirectBatchedTN(const int kSizeM, const int kSizeN, const int kSizeK,
const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
__global real* cgm, const __constant int* c_offsets, const int c_ld,
const int c_transpose, const int a_conjugate, const int b_conjugate) {
const int batch = get_group_id(2);
const real_arg arg_alpha = arg_alphas[batch];
const real_arg arg_beta = arg_betas[batch];
const int a_offset = a_offsets[batch];
const int b_offset = b_offsets[batch];
const int c_offset = c_offsets[batch];
__local real alm[WGD * (WGD + PADA)];
__local real blm[WGD * (WGD + PADB)];
XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
agm, a_offset, a_ld, bgm, b_offset, b_ld, cgm, c_offset, c_ld,
alm, blm, 1, 0, c_transpose, a_conjugate, b_conjugate);
}
// Direct version of the batched GEMM kernel with [A, B] = [transposed, transposed]
__attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
__kernel void XgemmDirectBatchedTT(const int kSizeM, const int kSizeN, const int kSizeK,
const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
__global real* cgm, const __constant int* c_offsets, const int c_ld,
const int c_transpose, const int a_conjugate, const int b_conjugate) {
const int batch = get_group_id(2);
const real_arg arg_alpha = arg_alphas[batch];
const real_arg arg_beta = arg_betas[batch];
const int a_offset = a_offsets[batch];
const int b_offset = b_offsets[batch];
const int c_offset = c_offsets[batch];
__local real alm[WGD * (WGD + PADA)];
__local real blm[WGD * (WGD + PADB)];
XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
agm, a_offset, a_ld, bgm, b_offset, b_ld, cgm, c_offset, c_ld,
alm, blm, 1, 1, c_transpose, a_conjugate, b_conjugate);
}
// =================================================================================================
// End of the C++11 raw string literal
)"
// =================================================================================================


@@ -196,6 +196,70 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device,
}
}
// Batched version of the above
template <typename T>
void PadCopyTransposeMatrixBatched(Queue &queue, const Device &device,
const Databases &db,
EventPointer event, const std::vector<Event> &waitForEvents,
const size_t src_one, const size_t src_two,
const size_t src_ld, const Buffer<int> &src_offsets,
const Buffer<T> &src,
const size_t dest_one, const size_t dest_two,
const size_t dest_ld, const Buffer<int> &dest_offsets,
const Buffer<T> &dest,
const Program &program, const bool do_pad,
const bool do_transpose, const bool do_conjugate,
const size_t batch_count) {
// Determines the right kernel
auto kernel_name = std::string{};
if (do_transpose) {
kernel_name = (do_pad) ? "TransposePadMatrixBatched" : "TransposeMatrixBatched";
}
else {
kernel_name = (do_pad) ? "CopyPadMatrixBatched" : "CopyMatrixBatched";
}
// Retrieves the kernel from the compiled binary
auto kernel = Kernel(program, kernel_name);
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(src_one));
kernel.SetArgument(1, static_cast<int>(src_two));
kernel.SetArgument(2, static_cast<int>(src_ld));
kernel.SetArgument(3, src_offsets());
kernel.SetArgument(4, src());
kernel.SetArgument(5, static_cast<int>(dest_one));
kernel.SetArgument(6, static_cast<int>(dest_two));
kernel.SetArgument(7, static_cast<int>(dest_ld));
kernel.SetArgument(8, dest_offsets());
kernel.SetArgument(9, dest());
if (do_pad) {
kernel.SetArgument(10, static_cast<int>(do_conjugate));
}
// Launches the kernel and returns the error code. Uses global and local thread sizes based on
// parameters in the database.
if (do_transpose) {
const auto global = std::vector<size_t>{
Ceil(CeilDiv(dest_one, db["PADTRA_WPT"]), db["PADTRA_TILE"]),
Ceil(CeilDiv(dest_two, db["PADTRA_WPT"]), db["PADTRA_TILE"]),
batch_count
};
const auto local = std::vector<size_t>{db["PADTRA_TILE"], db["PADTRA_TILE"], 1};
RunKernel(kernel, queue, device, global, local, event, waitForEvents);
}
else {
const auto global = std::vector<size_t>{
Ceil(CeilDiv(dest_one, db["PAD_WPTX"]), db["PAD_DIMX"]),
Ceil(CeilDiv(dest_two, db["PAD_WPTY"]), db["PAD_DIMY"]),
batch_count
};
const auto local = std::vector<size_t>{db["PAD_DIMX"], db["PAD_DIMY"], 1};
RunKernel(kernel, queue, device, global, local, event, waitForEvents);
}
}
// =================================================================================================
} // namespace clblast


@@ -104,13 +104,13 @@ void Xgemm<T>::DoGemm(const Layout layout,
// Selects which version of GEMM to run
const auto do_gemm_direct = (m * n * k < db_["XGEMM_MIN_INDIRECT_SIZE"]);
if (do_gemm_direct) { // for small sizes (single kernel)
GemmDirect(m, n, k, alpha,
a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
c_buffer, c_offset, c_ld,
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
}
else { // for larger sizes (pre/post-processing plus a very fast kernel)
GemmIndirect(m, n, k, alpha,
a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
c_buffer, c_offset, c_ld,
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,


@@ -0,0 +1,350 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements the XgemmBatched class (see the header for information about the class).
//
// =================================================================================================
#include "routines/levelx/xgemmbatched.hpp"
#include <string>
#include <vector>
namespace clblast {
// =================================================================================================
// Constructor: forwards to base class constructor
template <typename T>
XgemmBatched<T>::XgemmBatched(Queue &queue, EventPointer event, const std::string &name):
Routine(queue, event, name,
{"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","KernelSelection"},
PrecisionValue<T>(), {}, {
#include "../../kernels/level3/level3.opencl"
#include "../../kernels/level3/copy_fast.opencl"
#include "../../kernels/level3/copy_pad.opencl"
#include "../../kernels/level3/transpose_fast.opencl"
#include "../../kernels/level3/transpose_pad.opencl"
, // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_direct_part1.opencl"
#include "../../kernels/level3/xgemm_direct_part2.opencl"
#include "../../kernels/level3/xgemm_direct_part3.opencl"
, // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_part1.opencl"
#include "../../kernels/level3/xgemm_part2.opencl"
#include "../../kernels/level3/xgemm_part3.opencl"
, // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_batched.opencl"
#include "../../kernels/level3/xgemm_direct_batched.opencl"
}) {
}
// =================================================================================================
// The main routine
template <typename T>
void XgemmBatched<T>::DoGemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
const std::vector<T> &alphas,
const Buffer<T> & a_buffer, const std::vector<size_t> &a_offsets, const size_t a_ld,
const Buffer<T> & b_buffer, const std::vector<size_t> &b_offsets, const size_t b_ld,
const std::vector<T> &betas,
const Buffer<T> & c_buffer, const std::vector<size_t> &c_offsets, const size_t c_ld,
const size_t batch_count) {
// Tests for a valid batch count
if ((batch_count < 1) || (alphas.size() != batch_count) || (betas.size() != batch_count) ||
(a_offsets.size() != batch_count) || (b_offsets.size() != batch_count) || (c_offsets.size() != batch_count)) {
throw BLASError(StatusCode::kInvalidBatchCount);
}
// Makes sure all dimensions are larger than zero
if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
// Computes whether or not the matrices are transposed in memory. See GEMM routine for details.
const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) ||
(layout == Layout::kRowMajor && a_transpose == Transpose::kNo);
const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) ||
(layout == Layout::kRowMajor && b_transpose == Transpose::kNo);
const auto c_rotated = (layout == Layout::kRowMajor);
static const auto a_want_rotated = false;
static const auto b_want_rotated = true;
static const auto c_want_rotated = false;
const auto a_do_transpose = a_rotated != a_want_rotated;
const auto b_do_transpose = b_rotated != b_want_rotated;
const auto c_do_transpose = c_rotated != c_want_rotated;
// In case of complex data-types, the transpose can also become a conjugate transpose
const auto a_conjugate = (a_transpose == Transpose::kConjugate);
const auto b_conjugate = (b_transpose == Transpose::kConjugate);
// Computes the first and second dimensions of the 3 matrices taking into account whether the
// matrices are rotated or not
const auto a_one = (a_rotated) ? k : m;
const auto a_two = (a_rotated) ? m : k;
const auto b_one = (b_rotated) ? n : k;
const auto b_two = (b_rotated) ? k : n;
const auto c_one = (c_rotated) ? n : m;
const auto c_two = (c_rotated) ? m : n;
// Tests the matrices for validity
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
TestMatrixA(a_one, a_two, a_buffer, a_offsets[batch], a_ld);
TestMatrixB(b_one, b_two, b_buffer, b_offsets[batch], b_ld);
TestMatrixC(c_one, c_two, c_buffer, c_offsets[batch], c_ld);
}
// Upload the scalar arguments to the device
auto alphas_device = Buffer<T>(context_, BufferAccess::kReadOnly, batch_count);
auto betas_device = Buffer<T>(context_, BufferAccess::kReadOnly, batch_count);
alphas_device.Write(queue_, batch_count, alphas);
betas_device.Write(queue_, batch_count, betas);
// Converts the offsets to integers
std::vector<int> a_offsets_int(a_offsets.begin(), a_offsets.end());
std::vector<int> b_offsets_int(b_offsets.begin(), b_offsets.end());
std::vector<int> c_offsets_int(c_offsets.begin(), c_offsets.end());
// Selects which version of the batched GEMM to run
const auto do_gemm_direct = true;
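// Note: this is hard-coded to true for now, so only the direct kernel below is used in this version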
if (do_gemm_direct) { // single generic kernel
BatchedGemmDirect(m, n, k, alphas_device,
a_buffer, a_offsets_int, a_ld, b_buffer, b_offsets_int, b_ld,
betas_device, c_buffer, c_offsets_int, c_ld,
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
batch_count);
}
else { // pre/post-processing plus a very fast kernel
BatchedGemmIndirect(m, n, k, alphas_device,
a_buffer, a_offsets_int, a_ld, b_buffer, b_offsets_int, b_ld,
betas_device, c_buffer, c_offsets_int, c_ld,
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
a_one, a_two, a_want_rotated,
b_one, b_two, b_want_rotated,
c_one, c_two, c_want_rotated,
batch_count);
}
}
// =================================================================================================
// The indirect version of batched GEMM. This uses the faster but non-general kernel. It has specific
// requirements, but several pre- and post-processing kernels take care of those. However, the
// overhead of these extra kernels might not be ideal for certain devices/arguments.
template <typename T>
void XgemmBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, const size_t k,
const Buffer<T> &alphas,
const Buffer<T> &a_buffer, const std::vector<int> &a_offsets, const size_t a_ld,
const Buffer<T> &b_buffer, const std::vector<int> &b_offsets, const size_t b_ld,
const Buffer<T> &betas,
const Buffer<T> &c_buffer, const std::vector<int> &c_offsets, const size_t c_ld,
const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
const bool a_conjugate, const bool b_conjugate,
const size_t a_one, const size_t a_two, const bool a_want_rotated,
const size_t b_one, const size_t b_two, const bool b_want_rotated,
const size_t c_one, const size_t c_two, const bool c_want_rotated,
const size_t batch_count) {
// Calculates the ceiled versions of m, n, and k
const auto m_ceiled = Ceil(Ceil(m, db_["MWG"]), db_["VWM"]);
const auto n_ceiled = Ceil(Ceil(n, db_["NWG"]), db_["VWN"]);
const auto k_ceiled = Ceil(Ceil(k, db_["KWG"]), db_["VWM"]);
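// Example with hypothetical values: for m = 1000 with MWG = 64 and VWM = 4, this gives
// m_ceiled = Ceil(Ceil(1000, 64), 4) = Ceil(1024, 4) = 1024, i.e. m rounded up to a multiple of both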
// Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account
// whether the matrices need to be rotated or not for the kernel.
const auto a_one_i = (a_want_rotated) ? k_ceiled : m_ceiled;
const auto a_two_i = (a_want_rotated) ? m_ceiled : k_ceiled;
const auto b_one_i = (b_want_rotated) ? n_ceiled : k_ceiled;
const auto b_two_i = (b_want_rotated) ? k_ceiled : n_ceiled;
const auto c_one_i = (c_want_rotated) ? n_ceiled : m_ceiled;
const auto c_two_i = (c_want_rotated) ? m_ceiled : n_ceiled;
// Sets the "internal" offsets, i.e. the perfect offsets
auto a_offsets_i = std::vector<int>(batch_count);
auto b_offsets_i = std::vector<int>(batch_count);
auto c_offsets_i = std::vector<int>(batch_count);
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
a_offsets_i[batch] = batch * a_one_i * a_two_i;
b_offsets_i[batch] = batch * b_one_i * b_two_i;
c_offsets_i[batch] = batch * c_one_i * c_two_i;
}
// Determines whether or not temporary matrices are needed
auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && a_offsets == a_offsets_i &&
a_do_transpose == false && a_conjugate == false;
auto b_no_temp = b_one == b_one_i && b_two == b_two_i && b_ld == b_one && b_offsets == b_offsets_i &&
b_do_transpose == false && b_conjugate == false;
auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && c_offsets == c_offsets_i &&
c_do_transpose == false;
// Creates the temporary matrices
const auto a_temp = (a_no_temp) ? a_buffer : Buffer<T>(context_, batch_count * a_one_i * a_two_i);
const auto b_temp = (b_no_temp) ? b_buffer : Buffer<T>(context_, batch_count * b_one_i * b_two_i);
const auto c_temp = (c_no_temp) ? c_buffer : Buffer<T>(context_, batch_count * c_one_i * c_two_i);
// Events of all kernels (including pre/post processing kernels)
auto eventWaitList = std::vector<Event>();
auto emptyEventList = std::vector<Event>();
// Runs the pre-processing kernel for matrix A. This transposes the matrix if needed, and also pads
// it with zeros until its dimensions reach the required multiple (dependent on the kernel parameters).
// If nothing has to be done, this kernel can be skipped.
if (!a_no_temp) {
auto a_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
auto a_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
a_offsets_device.Write(queue_, batch_count, a_offsets);
a_offsets_i_device.Write(queue_, batch_count, a_offsets_i);
auto eventProcessA = Event();
PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessA.pointer(), emptyEventList,
a_one, a_two, a_ld, a_offsets_device, a_buffer,
a_one_i, a_two_i, a_one_i, a_offsets_i_device, a_temp,
program_, true, a_do_transpose, a_conjugate, batch_count);
eventWaitList.push_back(eventProcessA);
}
// As above, but now for matrix B
if (!b_no_temp) {
auto b_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
auto b_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
b_offsets_device.Write(queue_, batch_count, b_offsets);
b_offsets_i_device.Write(queue_, batch_count, b_offsets_i);
auto eventProcessB = Event();
PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessB.pointer(), emptyEventList,
b_one, b_two, b_ld, b_offsets_device, b_buffer,
b_one_i, b_two_i, b_one_i, b_offsets_i_device, b_temp,
program_, true, b_do_transpose, b_conjugate, batch_count);
eventWaitList.push_back(eventProcessB);
}
// As above, but now for matrix C
auto c_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
auto c_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
if (!c_no_temp) {
c_offsets_device.Write(queue_, batch_count, c_offsets);
c_offsets_i_device.Write(queue_, batch_count, c_offsets_i);
auto eventProcessC = Event();
PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessC.pointer(), emptyEventList,
c_one, c_two, c_ld, c_offsets_device, c_buffer,
c_one_i, c_two_i, c_one_i, c_offsets_i_device, c_temp,
program_, true, c_do_transpose, false, batch_count);
eventWaitList.push_back(eventProcessC);
}
// Retrieves the XgemmBatched kernel from the compiled binary
auto kernel = Kernel(program_, "XgemmBatched");
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(m_ceiled));
kernel.SetArgument(1, static_cast<int>(n_ceiled));
kernel.SetArgument(2, static_cast<int>(k_ceiled));
kernel.SetArgument(3, alphas());
kernel.SetArgument(4, betas());
kernel.SetArgument(5, a_temp());
kernel.SetArgument(6, static_cast<int>(a_one_i));
kernel.SetArgument(7, static_cast<int>(a_two_i));
kernel.SetArgument(8, b_temp());
kernel.SetArgument(9, static_cast<int>(b_one_i));
kernel.SetArgument(10, static_cast<int>(b_two_i));
kernel.SetArgument(11, c_temp());
kernel.SetArgument(12, static_cast<int>(c_one_i));
kernel.SetArgument(13, static_cast<int>(c_two_i));
// Computes the global and local thread sizes
const auto global = std::vector<size_t>{
(c_one_i * db_["MDIMC"]) / db_["MWG"],
(c_two_i * db_["NDIMC"]) / db_["NWG"],
batch_count
};
const auto local = std::vector<size_t>{db_["MDIMC"], db_["NDIMC"], 1};
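// (one work-group of MDIMC x NDIMC threads computes an MWG x NWG tile of C; the third dimension indexes the batch)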
// Launches the kernel
auto eventKernel = Event();
auto eventPointer = eventKernel.pointer();
RunKernel(kernel, queue_, device_, global, local, eventPointer, eventWaitList);
// Runs the post-processing kernel if needed
if (!c_no_temp) {
eventWaitList.push_back(eventKernel);
PadCopyTransposeMatrixBatched(queue_, device_, db_, event_, eventWaitList,
c_one_i, c_two_i, c_one_i, c_offsets_i_device, c_temp,
c_one, c_two, c_ld, c_offsets_device, c_buffer,
program_, false, c_do_transpose, false, batch_count);
}
}
// =================================================================================================
// The direct version of batched GEMM, requiring just one kernel and no pre- or post-processing kernels.
template <typename T>
void XgemmBatched<T>::BatchedGemmDirect(const size_t m, const size_t n, const size_t k,
const Buffer<T> &alphas,
const Buffer<T> &a_buffer, const std::vector<int> &a_offsets, const size_t a_ld,
const Buffer<T> &b_buffer, const std::vector<int> &b_offsets, const size_t b_ld,
const Buffer<T> &betas,
const Buffer<T> &c_buffer, const std::vector<int> &c_offsets, const size_t c_ld,
const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
const bool a_conjugate, const bool b_conjugate,
const size_t batch_count) {
// Uploads the offsets to the device
auto a_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
auto b_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
auto c_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
a_offsets_device.Write(queue_, batch_count, a_offsets);
b_offsets_device.Write(queue_, batch_count, b_offsets);
c_offsets_device.Write(queue_, batch_count, c_offsets);
// Retrieves the proper XgemmDirectBatched kernel from the compiled binary
const auto name = (a_do_transpose) ? (b_do_transpose ? "XgemmDirectBatchedTT" : "XgemmDirectBatchedTN") :
(b_do_transpose ? "XgemmDirectBatchedNT" : "XgemmDirectBatchedNN");
auto kernel = Kernel(program_, name);
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(m));
kernel.SetArgument(1, static_cast<int>(n));
kernel.SetArgument(2, static_cast<int>(k));
kernel.SetArgument(3, alphas());
kernel.SetArgument(4, betas());
kernel.SetArgument(5, a_buffer());
kernel.SetArgument(6, a_offsets_device());
kernel.SetArgument(7, static_cast<int>(a_ld));
kernel.SetArgument(8, b_buffer());
kernel.SetArgument(9, b_offsets_device());
kernel.SetArgument(10, static_cast<int>(b_ld));
kernel.SetArgument(11, c_buffer());
kernel.SetArgument(12, c_offsets_device());
kernel.SetArgument(13, static_cast<int>(c_ld));
kernel.SetArgument(14, static_cast<int>(c_do_transpose));
kernel.SetArgument(15, static_cast<int>(a_conjugate));
kernel.SetArgument(16, static_cast<int>(b_conjugate));
// Computes the global and local thread sizes
const auto m_ceiled = Ceil(m, db_["WGD"]);
const auto n_ceiled = Ceil(n, db_["WGD"]);
const auto global = std::vector<size_t>{
(m_ceiled * db_["MDIMCD"]) / db_["WGD"],
(n_ceiled * db_["NDIMCD"]) / db_["WGD"],
batch_count
};
const auto local = std::vector<size_t>{db_["MDIMCD"], db_["NDIMCD"], 1};
// Launches the kernel
RunKernel(kernel, queue_, device_, global, local, event_);
}
// =================================================================================================
// Compiles the templated class
template class XgemmBatched<half>;
template class XgemmBatched<float>;
template class XgemmBatched<double>;
template class XgemmBatched<float2>;
template class XgemmBatched<double2>;
// =================================================================================================
} // namespace clblast
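For illustration, below is a minimal host-side sketch of calling the new routine through the public GemmBatched interface, mirroring the call made in the test routine further down. All sizes, buffer names and the packed row-major layout are hypothetical, and buffer creation plus error handling are omitted:
```
#include <vector>
#include <clblast.h>

// Hypothetical example: batch_count packed row-major GEMMs of size m x n x k,
// stored back-to-back in the single cl_mem buffers a_mem, b_mem and c_mem.
void ExampleGemmBatched(cl_mem a_mem, cl_mem b_mem, cl_mem c_mem, cl_command_queue queue) {
  const size_t batch_count = 4;
  const size_t m = 64, n = 64, k = 64;
  std::vector<float> alphas(batch_count, 1.0f);  // one alpha per batch instance
  std::vector<float> betas(batch_count, 0.0f);   // one beta per batch instance
  std::vector<size_t> a_offsets(batch_count), b_offsets(batch_count), c_offsets(batch_count);
  for (size_t batch = 0; batch < batch_count; ++batch) {
    a_offsets[batch] = batch * m * k;  // the A matrices are packed back-to-back
    b_offsets[batch] = batch * k * n;
    c_offsets[batch] = batch * m * n;
  }
  auto event = cl_event{};
  const auto status = clblast::GemmBatched(clblast::Layout::kRowMajor,
                                           clblast::Transpose::kNo, clblast::Transpose::kNo,
                                           m, n, k, alphas.data(),
                                           a_mem, a_offsets.data(), k,
                                           b_mem, b_offsets.data(), n,
                                           betas.data(),
                                           c_mem, c_offsets.data(), n,
                                           batch_count, &queue, &event);
  if (status == clblast::StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
}
```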

View file

@ -0,0 +1,72 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements the XgemmBatched routine. This is a non-BLAS batched version of GEMM.
//
// =================================================================================================
#ifndef CLBLAST_ROUTINES_XGEMMBATCHED_H_
#define CLBLAST_ROUTINES_XGEMMBATCHED_H_
#include <vector>
#include "routine.hpp"
namespace clblast {
// =================================================================================================
// See comment at top of file for a description of the class
template <typename T>
class XgemmBatched: public Routine {
public:
// Constructor
XgemmBatched(Queue &queue, EventPointer event, const std::string &name = "GEMMBATCHED");
// Templated-precision implementation of the routine
void DoGemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
const std::vector<T> &alphas,
const Buffer<T> & a_buffer, const std::vector<size_t> &a_offsets, const size_t a_ld,
const Buffer<T> & b_buffer, const std::vector<size_t> &b_offsets, const size_t b_ld,
const std::vector<T> &betas,
const Buffer<T> & c_buffer, const std::vector<size_t> &c_offsets, const size_t c_ld,
const size_t batch_count);
// Indirect version of batched GEMM (with pre- and post-processing kernels)
void BatchedGemmIndirect(const size_t m, const size_t n, const size_t k,
const Buffer<T> &alphas,
const Buffer<T> &a_buffer, const std::vector<int> &a_offsets, const size_t a_ld,
const Buffer<T> &b_buffer, const std::vector<int> &b_offsets, const size_t b_ld,
const Buffer<T> &betas,
const Buffer<T> &c_buffer, const std::vector<int> &c_offsets, const size_t c_ld,
const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
const bool a_conjugate, const bool b_conjugate,
const size_t a_one, const size_t a_two, const bool a_want_rotated,
const size_t b_one, const size_t b_two, const bool b_want_rotated,
const size_t c_one, const size_t c_two, const bool c_want_rotated,
const size_t batch_count);
// Direct version of batched GEMM (no pre- or post-processing kernels)
void BatchedGemmDirect(const size_t m, const size_t n, const size_t k,
const Buffer<T> &alphas,
const Buffer<T> &a_buffer, const std::vector<int> &a_offsets, const size_t a_ld,
const Buffer<T> &b_buffer, const std::vector<int> &b_offsets, const size_t b_ld,
const Buffer<T> &betas,
const Buffer<T> &c_buffer, const std::vector<int> &c_offsets, const size_t c_ld,
const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
const bool a_conjugate, const bool b_conjugate,
const size_t batch_count);
};
// =================================================================================================
} // namespace clblast
// CLBLAST_ROUTINES_XGEMMBATCHED_H_
#endif

View file

@ -46,6 +46,7 @@ class TuneCopy {
static size_t DefaultM() { return 1024; }
static size_t DefaultN() { return 1024; }
static size_t DefaultK() { return 1; } // N/A for this kernel
static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
static double DefaultFraction() { return 1.0; } // N/A for this kernel
static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

View file

@ -46,6 +46,7 @@ class TunePad {
static size_t DefaultM() { return 1024; }
static size_t DefaultN() { return 1024; }
static size_t DefaultK() { return 1; } // N/A for this kernel
static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
static double DefaultFraction() { return 1.0; } // N/A for this kernel
static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

View file

@ -46,6 +46,7 @@ class TuneTranspose {
static size_t DefaultM() { return 1024; }
static size_t DefaultN() { return 1024; }
static size_t DefaultK() { return 1; } // N/A for this kernel
static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
static double DefaultFraction() { return 1.0; } // N/A for this kernel
static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

View file

@ -46,6 +46,7 @@ class TunePadTranspose {
static size_t DefaultM() { return 1024; }
static size_t DefaultN() { return 1024; }
static size_t DefaultK() { return 1; } // N/A for this kernel
static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
static double DefaultFraction() { return 1.0; } // N/A for this kernel
static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

View file

@ -50,6 +50,7 @@ class TuneXaxpy {
static size_t DefaultM() { return 1; } // N/A for this kernel
static size_t DefaultN() { return 4096*1024; }
static size_t DefaultK() { return 1; } // N/A for this kernel
static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
static double DefaultFraction() { return 1.0; } // N/A for this kernel
static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

View file

@ -46,6 +46,7 @@ class TuneXdot {
static size_t DefaultM() { return 1; } // N/A for this kernel
static size_t DefaultN() { return 2*1024*1024; }
static size_t DefaultK() { return 1; } // N/A for this kernel
static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
static double DefaultFraction() { return 1.0; } // N/A for this kernel
static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

View file

@ -51,6 +51,7 @@ class TuneXgemm {
static size_t DefaultM() { return 1024; }
static size_t DefaultN() { return 1024; }
static size_t DefaultK() { return 1024; }
static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
static double DefaultFraction() { return (V==1) ? 1.0 : 512.0; } // test all or sample randomly
static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

View file

@ -51,6 +51,7 @@ class TuneXgemmDirect {
static size_t DefaultM() { return 256; }
static size_t DefaultN() { return 256; }
static size_t DefaultK() { return 256; }
static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
static double DefaultFraction() { return (V==1) ? 1.0 : 32.0; } // test all or sample randomly
static size_t DefaultNumRuns() { return 4; } // run every kernel this many times for averaging

View file

@ -49,6 +49,7 @@ class TuneXgemv {
static size_t DefaultM() { return 2048; }
static size_t DefaultN() { return 2048; }
static size_t DefaultK() { return 1; } // N/A for this kernel
static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
static double DefaultFraction() { return 1.0; } // N/A for this kernel
static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

View file

@ -46,6 +46,7 @@ class TuneXger {
static size_t DefaultM() { return 1024; }
static size_t DefaultN() { return 1024; }
static size_t DefaultK() { return 1; } // N/A for this kernel
static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
static double DefaultFraction() { return 1.0; } // N/A for this kernel
static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

View file

@ -47,6 +47,7 @@ void Tuner(int argc, char* argv[]) {
if (o == kArgAlpha) { args.alpha = GetArgument(command_line_args, help, kArgAlpha, GetScalar<T>()); }
if (o == kArgBeta) { args.beta = GetArgument(command_line_args, help, kArgBeta, GetScalar<T>()); }
if (o == kArgFraction) { args.fraction = GetArgument(command_line_args, help, kArgFraction, C::DefaultFraction()); }
if (o == kArgBatchCount) { args.batch_count = GetArgument(command_line_args, help, kArgBatchCount, C::DefaultBatchCount()); }
}
const auto num_runs = GetArgument(command_line_args, help, kArgNumRuns, C::DefaultNumRuns());
@ -158,6 +159,7 @@ void Tuner(int argc, char* argv[]) {
if (o == kArgK) { metadata.push_back({"arg_k", std::to_string(args.k)}); }
if (o == kArgAlpha) { metadata.push_back({"arg_alpha", ToString(args.alpha)}); }
if (o == kArgBeta) { metadata.push_back({"arg_beta", ToString(args.beta)}); }
if (o == kArgBatchCount) { metadata.push_back({"arg_batch_count", ToString(args.batch_count)}); }
}
tuner.PrintJSON("clblast_"+C::KernelFamily()+"_"+precision_string+".json", metadata);
}

View file

@ -0,0 +1,30 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// =================================================================================================
#include "test/correctness/testblas.hpp"
#include "test/routines/levelx/xgemmbatched.hpp"
// Shortcuts to the clblast namespace
using float2 = clblast::float2;
using double2 = clblast::double2;
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
auto errors = size_t{0};
errors += clblast::RunTests<clblast::TestXgemmBatched<float>, float, float>(argc, argv, false, "SGEMMBATCHED");
errors += clblast::RunTests<clblast::TestXgemmBatched<double>, double, double>(argc, argv, true, "DGEMMBATCHED");
errors += clblast::RunTests<clblast::TestXgemmBatched<float2>, float2, float2>(argc, argv, true, "CGEMMBATCHED");
errors += clblast::RunTests<clblast::TestXgemmBatched<double2>, double2, double2>(argc, argv, true, "ZGEMMBATCHED");
errors += clblast::RunTests<clblast::TestXgemmBatched<half>, half, half>(argc, argv, true, "HGEMMBATCHED");
if (errors > 0) { return 1; } else { return 0; }
}
// =================================================================================================

View file

@ -21,6 +21,8 @@
namespace clblast {
// =================================================================================================
template <typename T, typename U> const auto TestBlas<T,U>::kSeed = 42; // fixed seed for reproducibility
// Test settings for the regular test. Append to these lists in case more tests are required.
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kVectorDims = { 7, 93, 4096 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kIncrements = { 1, 2, 7 };

View file

@ -30,7 +30,7 @@ namespace clblast {
template <typename T, typename U>
class TestBlas: public Tester<T,U> {
public:
static const int kSeed;
// Uses several variables from the Tester class
using Tester<T,U>::context_;

View file

@ -24,6 +24,8 @@
namespace clblast {
// =================================================================================================
template <typename T, typename U> const int Client<T,U>::kSeed = 42; // fixed seed for reproducibility
// Constructor
template <typename T, typename U>
Client<T,U>::Client(const Routine run_routine,

View file

@ -40,7 +40,7 @@ namespace clblast {
template <typename T, typename U>
class Client {
public:
static const int kSeed;
// Shorthand for the routine-specific functions passed to the tester
using Routine = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>;

View file

@ -0,0 +1,37 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// =================================================================================================
#include "test/performance/client.hpp"
#include "test/routines/levelx/xgemmbatched.hpp"
// Shortcuts to the clblast namespace
using float2 = clblast::float2;
using double2 = clblast::double2;
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args, clblast::Precision::kSingle)) {
case clblast::Precision::kHalf:
clblast::RunClient<clblast::TestXgemmBatched<half>, half, half>(argc, argv); break;
case clblast::Precision::kSingle:
clblast::RunClient<clblast::TestXgemmBatched<float>, float, float>(argc, argv); break;
case clblast::Precision::kDouble:
clblast::RunClient<clblast::TestXgemmBatched<double>, double, double>(argc, argv); break;
case clblast::Precision::kComplexSingle:
clblast::RunClient<clblast::TestXgemmBatched<float2>, float2, float2>(argc, argv); break;
case clblast::Precision::kComplexDouble:
clblast::RunClient<clblast::TestXgemmBatched<double2>, double2, double2>(argc, argv); break;
}
return 0;
}
// =================================================================================================

View file

@ -0,0 +1,207 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements a class with static methods to describe the XgemmBatched routine. Examples of
// such 'descriptions' are how to calculate the size of a buffer or how to run the routine. These
// static methods are used by the correctness tester and the performance tester.
//
// =================================================================================================
#ifndef CLBLAST_TEST_ROUTINES_XGEMMBATCHED_H_
#define CLBLAST_TEST_ROUTINES_XGEMMBATCHED_H_
#include <vector>
#include <string>
#ifdef CLBLAST_REF_CLBLAS
#include "test/wrapper_clblas.hpp"
#endif
#ifdef CLBLAST_REF_CBLAS
#include "test/wrapper_cblas.hpp"
#endif
namespace clblast {
// =================================================================================================
// See comment at top of file for a description of the class
template <typename T>
class TestXgemmBatched {
public:
// Although it is a non-BLAS routine, it can still be tested against level-3 routines in a loop
static size_t BLASLevel() { return 3; }
// The list of arguments relevant for this routine
static std::vector<std::string> GetOptions() {
return {kArgM, kArgN, kArgK,
kArgLayout, kArgATransp, kArgBTransp,
kArgALeadDim, kArgBLeadDim, kArgCLeadDim,
kArgAOffset, kArgBOffset, kArgCOffset,
kArgBatchCount, kArgAlpha, kArgBeta};
}
// Helper for the sizes per batch
static size_t PerBatchSizeA(const Arguments<T> &args) {
auto a_rotated = (args.layout == Layout::kColMajor && args.a_transpose != Transpose::kNo) ||
(args.layout == Layout::kRowMajor && args.a_transpose == Transpose::kNo);
auto a_two = (a_rotated) ? args.m : args.k;
return a_two * args.a_ld;
}
static size_t PerBatchSizeB(const Arguments<T> &args) {
auto b_rotated = (args.layout == Layout::kColMajor && args.b_transpose != Transpose::kNo) ||
(args.layout == Layout::kRowMajor && args.b_transpose == Transpose::kNo);
auto b_two = (b_rotated) ? args.k : args.n;
return b_two * args.b_ld;
}
static size_t PerBatchSizeC(const Arguments<T> &args) {
auto c_rotated = (args.layout == Layout::kRowMajor);
auto c_two = (c_rotated) ? args.m : args.n;
return c_two * args.c_ld;
}
// Describes how to obtain the sizes of the buffers
static size_t GetSizeA(const Arguments<T> &args) {
return PerBatchSizeA(args) * args.batch_count + args.a_offset;
}
static size_t GetSizeB(const Arguments<T> &args) {
return PerBatchSizeB(args) * args.batch_count + args.b_offset;
}
static size_t GetSizeC(const Arguments<T> &args) {
return PerBatchSizeC(args) * args.batch_count + args.c_offset;
}
// Describes how to set the sizes of all the buffers
static void SetSizes(Arguments<T> &args) {
args.a_size = GetSizeA(args);
args.b_size = GetSizeB(args);
args.c_size = GetSizeC(args);
// Also sets the batch-related variables
args.a_offsets = std::vector<size_t>(args.batch_count);
args.b_offsets = std::vector<size_t>(args.batch_count);
args.c_offsets = std::vector<size_t>(args.batch_count);
args.alphas = std::vector<T>(args.batch_count);
args.betas = std::vector<T>(args.batch_count);
for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
args.a_offsets[batch] = batch * PerBatchSizeA(args) + args.a_offset;
args.b_offsets[batch] = batch * PerBatchSizeB(args) + args.b_offset;
args.c_offsets[batch] = batch * PerBatchSizeC(args) + args.c_offset;
args.alphas[batch] = args.alpha + Constant<T>(batch);
args.betas[batch] = args.beta + Constant<T>(batch);
}
}
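// Note: the per-batch offsets, alphas and betas set above are all distinct, so the batched arguments are genuinely exercised by the tests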
// Describes what the default values of the leading dimensions of the matrices are
static size_t DefaultLDA(const Arguments<T> &args) { return args.k; }
static size_t DefaultLDB(const Arguments<T> &args) { return args.n; }
static size_t DefaultLDC(const Arguments<T> &args) { return args.n; }
// Describes which transpose options are relevant for this routine
using Transposes = std::vector<Transpose>;
static Transposes GetATransposes(const Transposes &all) { return all; }
static Transposes GetBTransposes(const Transposes &all) { return all; }
// Describes how to prepare the input data
static void PrepareData(const Arguments<T>&, Queue&, const int, std::vector<T>&,
std::vector<T>&, std::vector<T>&, std::vector<T>&, std::vector<T>&,
std::vector<T>&, std::vector<T>&) {} // N/A for this routine
// Describes how to run the CLBlast routine
static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
auto queue_plain = queue();
auto event = cl_event{};
auto status = GemmBatched(args.layout, args.a_transpose, args.b_transpose,
args.m, args.n, args.k, args.alphas.data(),
buffers.a_mat(), args.a_offsets.data(), args.a_ld,
buffers.b_mat(), args.b_offsets.data(), args.b_ld, args.betas.data(),
buffers.c_mat(), args.c_offsets.data(), args.c_ld,
args.batch_count,
&queue_plain, &event);
if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
return status;
}
// Describes how to run the clBLAS routine (for correctness/performance comparison)
#ifdef CLBLAST_REF_CLBLAS
static StatusCode RunReference1(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
auto queue_plain = queue();
for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
auto event = cl_event{};
auto status = clblasXgemm(convertToCLBLAS(args.layout),
convertToCLBLAS(args.a_transpose),
convertToCLBLAS(args.b_transpose),
args.m, args.n, args.k, args.alphas[batch],
buffers.a_mat, args.a_offsets[batch], args.a_ld,
buffers.b_mat, args.b_offsets[batch], args.b_ld, args.betas[batch],
buffers.c_mat, args.c_offsets[batch], args.c_ld,
1, &queue_plain, 0, nullptr, &event);
clWaitForEvents(1, &event);
if (static_cast<StatusCode>(status) != StatusCode::kSuccess) {
return static_cast<StatusCode>(status);
}
}
return StatusCode::kSuccess;
}
#endif
// Describes how to run the CPU BLAS routine (for correctness/performance comparison)
#ifdef CLBLAST_REF_CBLAS
static StatusCode RunReference2(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
std::vector<T> a_mat_cpu(args.a_size, static_cast<T>(0));
std::vector<T> b_mat_cpu(args.b_size, static_cast<T>(0));
std::vector<T> c_mat_cpu(args.c_size, static_cast<T>(0));
buffers.a_mat.Read(queue, args.a_size, a_mat_cpu);
buffers.b_mat.Read(queue, args.b_size, b_mat_cpu);
buffers.c_mat.Read(queue, args.c_size, c_mat_cpu);
for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
cblasXgemm(convertToCBLAS(args.layout),
convertToCBLAS(args.a_transpose),
convertToCBLAS(args.b_transpose),
args.m, args.n, args.k, args.alphas[batch],
a_mat_cpu, args.a_offsets[batch], args.a_ld,
b_mat_cpu, args.b_offsets[batch], args.b_ld, args.betas[batch],
c_mat_cpu, args.c_offsets[batch], args.c_ld);
}
buffers.c_mat.Write(queue, args.c_size, c_mat_cpu);
return StatusCode::kSuccess;
}
#endif
// Describes how to download the results of the computation (more importantly: which buffer)
static std::vector<T> DownloadResult(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
std::vector<T> result(args.c_size, static_cast<T>(0));
buffers.c_mat.Read(queue, args.c_size, result);
return result;
}
// Describes how to compute the indices of the result buffer
static size_t ResultID1(const Arguments<T> &args) { return args.m; }
static size_t ResultID2(const Arguments<T> &args) { return args.n * args.batch_count; }
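// (the second result index runs over n * batch_count values and encodes id2_3 = id3 * n + id2, unfolded below)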
static size_t GetResultIndex(const Arguments<T> &args, const size_t id1, const size_t id2_3) {
const size_t id2 = id2_3 % args.n;
const size_t id3 = id2_3 / args.n;
return (args.layout == Layout::kRowMajor) ?
id1*args.c_ld + id2 + args.c_offsets[id3]:
id2*args.c_ld + id1 + args.c_offsets[id3];
}
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
return args.batch_count * (2 * args.m * args.n * args.k);
}
static size_t GetBytes(const Arguments<T> &args) {
return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
}
};
// =================================================================================================
} // namespace clblast
// CLBLAST_TEST_ROUTINES_XGEMMBATCHED_H_
#endif