Mirror of https://github.com/CNugteren/CLBlast.git (synced 2024-08-27 07:17:00 +02:00)

Merge pull request #142 from CNugteren/gemm_batched
Added a first batched version of the GEMM routine
Commit a21d903796
@@ -15,6 +15,7 @@ Development version (next release)
   * STRSM/DTRSM/CTRSM/ZTRSM (experimental, un-optimized)
 - Added batched (non-BLAS) routines:
   * SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED (batched version of AXPY)
+  * SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED (batched version of GEMM)

 Version 0.10.0
 - Updated to version 8.0 of the CLCudaAPI C++11 OpenCL header
@@ -159,7 +159,7 @@ set(LEVEL1_ROUTINES xswap xscal xcopy xaxpy xdot xdotu xdotc xnrm2 xasum xamax)
 set(LEVEL2_ROUTINES xgemv xgbmv xhemv xhbmv xhpmv xsymv xsbmv xspmv xtrmv xtbmv xtpmv xtrsv
     xger xgeru xgerc xher xhpr xher2 xhpr2 xsyr xspr xsyr2 xspr2)
 set(LEVEL3_ROUTINES xgemm xsymm xhemm xsyrk xherk xsyr2k xher2k xtrmm xtrsm)
-set(LEVELX_ROUTINES xomatcopy xaxpybatched)
+set(LEVELX_ROUTINES xomatcopy xaxpybatched xgemmbatched)
 set(ROUTINES ${LEVEL1_ROUTINES} ${LEVEL2_ROUTINES} ${LEVEL3_ROUTINES} ${LEVELX_ROUTINES})
 set(PRECISIONS 32 64 3232 6464 16)
@@ -276,6 +276,13 @@ CLBlast supports almost all the Netlib BLAS routines plus a couple of extra non-
 | xTRMM | ✔ | ✔ | ✔ | ✔ | ✔ |
 | xTRSM | ✔ | ✔ | ✔ | ✔ |   | (experimental, un-optimized)
+
+Furthermore, there are also batched versions of BLAS routines available, processing multiple smaller computations in one go for better performance:
+
+| Batched       | S | D | C | Z | H |
+| -------------|---|---|---|---|---|
+| xAXPYBATCHED | ✔ | ✔ | ✔ | ✔ | ✔ |
+| xGEMMBATCHED | ✔ | ✔ | ✔ | ✔ | ✔ |

 In addition, some extra non-BLAS routines are also supported by CLBlast, classified as level-X. They are experimental and should be used with care:

 | Level-X | S | D | C | Z | H |
@@ -2969,6 +2969,105 @@ Arguments to AXPYBATCHED: (new section added below)

xGEMMBATCHED: Batched version of GEMM
-------------

As GEMM, but multiple operations are batched together for better performance.

C++ API:
```
template <typename T>
StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
                       const size_t m, const size_t n, const size_t k,
                       const T *alphas,
                       const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                       const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
                       const T *betas,
                       cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                       const size_t batch_count,
                       cl_command_queue* queue, cl_event* event)
```

C API:
```
CLBlastStatusCode CLBlastSgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                      const size_t m, const size_t n, const size_t k, const float *alphas,
                                      const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                      const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const float *betas,
                                      cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                      const size_t batch_count, cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastDgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                      const size_t m, const size_t n, const size_t k, const double *alphas,
                                      const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                      const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const double *betas,
                                      cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                      const size_t batch_count, cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastCgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                      const size_t m, const size_t n, const size_t k, const cl_float2 *alphas,
                                      const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                      const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const cl_float2 *betas,
                                      cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                      const size_t batch_count, cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastZgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                      const size_t m, const size_t n, const size_t k, const cl_double2 *alphas,
                                      const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                      const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const cl_double2 *betas,
                                      cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                      const size_t batch_count, cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastHgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                      const size_t m, const size_t n, const size_t k, const cl_half *alphas,
                                      const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                      const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const cl_half *betas,
                                      cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                      const size_t batch_count, cl_command_queue* queue, cl_event* event)
```

Arguments to GEMMBATCHED:

* `const Layout layout`: Data-layout of the matrices, either `Layout::kRowMajor` (101) for row-major layout or `Layout::kColMajor` (102) for column-major data-layout.
* `const Transpose a_transpose`: Transposing the input matrix A, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.
* `const Transpose b_transpose`: Transposing the input matrix B, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.
* `const size_t m`: Integer size argument. This value must be positive.
* `const size_t n`: Integer size argument. This value must be positive.
* `const size_t k`: Integer size argument. This value must be positive.
* `const T *alphas`: Input scalar constants.
* `const cl_mem a_buffer`: OpenCL buffer to store the input A matrix.
* `const size_t *a_offsets`: The offsets in elements from the start of the input A matrix.
* `const size_t a_ld`: Leading dimension of the input A matrix. This value must be greater than 0.
* `const cl_mem b_buffer`: OpenCL buffer to store the input B matrix.
* `const size_t *b_offsets`: The offsets in elements from the start of the input B matrix.
* `const size_t b_ld`: Leading dimension of the input B matrix. This value must be greater than 0.
* `const T *betas`: Input scalar constants.
* `cl_mem c_buffer`: OpenCL buffer to store the output C matrix.
* `const size_t *c_offsets`: The offsets in elements from the start of the output C matrix.
* `const size_t c_ld`: Leading dimension of the output C matrix. This value must be greater than 0.
* `const size_t batch_count`: Number of batches. This value must be positive.
* `cl_command_queue* queue`: Pointer to an OpenCL command queue associated with a context and device to execute the routine on.
* `cl_event* event`: Pointer to an OpenCL event to be able to wait for completion of the routine's OpenCL kernel(s). This is an optional argument.

Requirements for GEMMBATCHED:

* When `transpose_a == Transpose::kNo`, then `a_ld` must be at least `m`, otherwise `a_ld` must be at least `k`.
* When `transpose_b == Transpose::kNo`, then `b_ld` must be at least `k`, otherwise `b_ld` must be at least `n`.
* The value of `c_ld` must be at least `m`.
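As a usage illustration of the C++ API above — a minimal sketch, not taken from the CLBlast sources: it assumes each operand's batches are packed back-to-back in one column-major OpenCL buffer, and that the buffers and queue were created beforehand; the function name is made up.

```
#include <vector>
#include <clblast.h>

// Sketch only: buffers hold 'batch_count' column-major matrices packed back-to-back.
clblast::StatusCode RunGemmBatchedSketch(cl_command_queue queue,
                                         cl_mem a_buffer, cl_mem b_buffer, cl_mem c_buffer) {
  const size_t m = 4, n = 4, k = 4, batch_count = 3;
  std::vector<float> alphas(batch_count, 1.0f);   // one alpha per batch
  std::vector<float> betas(batch_count, 0.0f);    // one beta per batch
  std::vector<size_t> a_offsets(batch_count), b_offsets(batch_count), c_offsets(batch_count);
  for (size_t b = 0; b < batch_count; ++b) {      // element offsets of each batch
    a_offsets[b] = b * m * k;
    b_offsets[b] = b * k * n;
    c_offsets[b] = b * m * n;
  }
  return clblast::GemmBatched<float>(clblast::Layout::kColMajor,
                                     clblast::Transpose::kNo, clblast::Transpose::kNo,
                                     m, n, k, alphas.data(),
                                     a_buffer, a_offsets.data(), m,   // a_ld >= m (A not transposed)
                                     b_buffer, b_offsets.data(), k,   // b_ld >= k (B not transposed)
                                     betas.data(),
                                     c_buffer, c_offsets.data(), m,   // c_ld >= m
                                     batch_count, &queue);            // event defaults to nullptr
}
```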

ClearCache: Resets the cache of compiled binaries (auxiliary function)
-------------
@@ -619,6 +619,18 @@ StatusCode AxpyBatched(const size_t n,
                       const size_t batch_count,
                       cl_command_queue* queue, cl_event* event = nullptr);

// Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
template <typename T>
StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
                       const size_t m, const size_t n, const size_t k,
                       const T *alphas,
                       const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                       const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
                       const T *betas,
                       cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                       const size_t batch_count,
                       cl_command_queue* queue, cl_event* event = nullptr);

// =================================================================================================

// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on
@@ -1360,6 +1360,53 @@ CLBlastStatusCode PUBLIC_API CLBlastHaxpyBatched(const size_t n,
                                                 const size_t batch_count,
                                                 cl_command_queue* queue, cl_event* event);

// Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
CLBlastStatusCode PUBLIC_API CLBlastSgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                                 const size_t m, const size_t n, const size_t k, const float *alphas,
                                                 const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                                 const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const float *betas,
                                                 cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                                 const size_t batch_count, cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastDgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                                 const size_t m, const size_t n, const size_t k, const double *alphas,
                                                 const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                                 const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const double *betas,
                                                 cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                                 const size_t batch_count, cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastCgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                                 const size_t m, const size_t n, const size_t k, const cl_float2 *alphas,
                                                 const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                                 const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const cl_float2 *betas,
                                                 cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                                 const size_t batch_count, cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastZgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                                 const size_t m, const size_t n, const size_t k, const cl_double2 *alphas,
                                                 const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                                 const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const cl_double2 *betas,
                                                 cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                                 const size_t batch_count, cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastHgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                                 const size_t m, const size_t n, const size_t k, const cl_half *alphas,
                                                 const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                                 const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const cl_half *betas,
                                                 cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                                 const size_t batch_count, cl_command_queue* queue, cl_event* event);

// =================================================================================================

// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on
@@ -41,7 +41,7 @@ FILES = [
   "/include/clblast_netlib_c.h",
   "/src/clblast_netlib_c.cpp",
 ]
-HEADER_LINES = [122, 76, 126, 23, 29, 41, 65, 32]
+HEADER_LINES = [123, 76, 126, 23, 29, 41, 65, 32]
 FOOTER_LINES = [25, 138, 27, 38, 6, 6, 9, 2]
 HEADER_LINES_DOC = 0
 FOOTER_LINES_DOC = 63
@@ -163,6 +163,7 @@ ROUTINES = [
   Routine(True, True, False, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
   # Batched routines:
   Routine(True, True, True, "x", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Batched version of AXPY", "As AXPY, but multiple operations are batched together for better performance.", []),
+  Routine(True, True, True, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "Batched version of GEMM", "As GEMM, but multiple operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
 ]]
@@ -72,6 +72,7 @@
 // Level-x includes (non-BLAS)
 #include "routines/levelx/xomatcopy.hpp"
 #include "routines/levelx/xaxpybatched.hpp"
+#include "routines/levelx/xgemmbatched.hpp"

 namespace clblast {
@@ -2231,6 +2232,89 @@ template StatusCode PUBLIC_API AxpyBatched<half>(const size_t,
                                                 cl_mem, const size_t*, const size_t,
                                                 const size_t,
                                                 cl_command_queue*, cl_event*);

// Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
template <typename T>
StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
                       const size_t m, const size_t n, const size_t k,
                       const T *alphas,
                       const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                       const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld,
                       const T *betas,
                       cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                       const size_t batch_count,
                       cl_command_queue* queue, cl_event* event) {
  try {
    auto queue_cpp = Queue(*queue);
    auto routine = XgemmBatched<T>(queue_cpp, event);
    auto alphas_cpp = std::vector<T>();
    auto betas_cpp = std::vector<T>();
    auto a_offsets_cpp = std::vector<size_t>();
    auto b_offsets_cpp = std::vector<size_t>();
    auto c_offsets_cpp = std::vector<size_t>();
    for (auto batch = size_t{0}; batch < batch_count; ++batch) {
      alphas_cpp.push_back(alphas[batch]);
      betas_cpp.push_back(betas[batch]);
      a_offsets_cpp.push_back(a_offsets[batch]);
      b_offsets_cpp.push_back(b_offsets[batch]);
      c_offsets_cpp.push_back(c_offsets[batch]);
    }
    routine.DoGemmBatched(layout, a_transpose, b_transpose,
                          m, n, k,
                          alphas_cpp,
                          Buffer<T>(a_buffer), a_offsets_cpp, a_ld,
                          Buffer<T>(b_buffer), b_offsets_cpp, b_ld,
                          betas_cpp,
                          Buffer<T>(c_buffer), c_offsets_cpp, c_ld,
                          batch_count);
    return StatusCode::kSuccess;
  } catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API GemmBatched<float>(const Layout, const Transpose, const Transpose,
                                                  const size_t, const size_t, const size_t, const float*,
                                                  const cl_mem, const size_t*, const size_t,
                                                  const cl_mem, const size_t*, const size_t, const float*,
                                                  cl_mem, const size_t*, const size_t, const size_t,
                                                  cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API GemmBatched<double>(const Layout, const Transpose, const Transpose,
                                                   const size_t, const size_t, const size_t, const double*,
                                                   const cl_mem, const size_t*, const size_t,
                                                   const cl_mem, const size_t*, const size_t, const double*,
                                                   cl_mem, const size_t*, const size_t, const size_t,
                                                   cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API GemmBatched<float2>(const Layout, const Transpose, const Transpose,
                                                   const size_t, const size_t, const size_t, const float2*,
                                                   const cl_mem, const size_t*, const size_t,
                                                   const cl_mem, const size_t*, const size_t, const float2*,
                                                   cl_mem, const size_t*, const size_t, const size_t,
                                                   cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API GemmBatched<double2>(const Layout, const Transpose, const Transpose,
                                                    const size_t, const size_t, const size_t, const double2*,
                                                    const cl_mem, const size_t*, const size_t,
                                                    const cl_mem, const size_t*, const size_t, const double2*,
                                                    cl_mem, const size_t*, const size_t, const size_t,
                                                    cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose, const Transpose,
                                                 const size_t, const size_t, const size_t, const half*,
                                                 const cl_mem, const size_t*, const size_t,
                                                 const cl_mem, const size_t*, const size_t, const half*,
                                                 cl_mem, const size_t*, const size_t, const size_t,
                                                 cl_command_queue*, cl_event*);

// =================================================================================================

// Clears the cache of stored binaries
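The entry point above only copies the per-batch scalars and offsets into std::vectors before handing them to the XgemmBatched routine; what the offsets mean is entirely up to the caller. A small hypothetical helper (not part of CLBlast) for the common case of equally-sized matrices packed contiguously into one buffer:

```
#include <cstddef>
#include <vector>

// Hypothetical helper: element offsets of batch_count equally-sized matrices stored
// contiguously in a single buffer, each matrix holding rows*cols elements.
std::vector<size_t> PackedBatchOffsets(const size_t rows, const size_t cols,
                                       const size_t batch_count) {
  std::vector<size_t> offsets(batch_count);
  for (size_t batch = 0; batch < batch_count; ++batch) {
    offsets[batch] = batch * rows * cols;   // batch 0 at 0, batch 1 one matrix further, ...
  }
  return offsets;
}
// Example: rows=4, cols=4, batch_count=3 gives offsets {0, 16, 32}.
```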
@@ -3554,6 +3554,163 @@ CLBlastStatusCode CLBlastHaxpyBatched(const size_t n,
  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}

// GEMM
CLBlastStatusCode CLBlastSgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                      const size_t m, const size_t n, const size_t k, const float *alphas,
                                      const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                      const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const float *betas,
                                      cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                      const size_t batch_count, cl_command_queue* queue, cl_event* event) {
  auto alphas_cpp = std::vector<float>();
  auto betas_cpp = std::vector<float>();
  for (auto batch = size_t{0}; batch < batch_count; ++batch) {
    alphas_cpp.push_back(alphas[batch]);
    betas_cpp.push_back(betas[batch]);
  }
  try {
    return static_cast<CLBlastStatusCode>(
      clblast::GemmBatched(static_cast<clblast::Layout>(layout),
                           static_cast<clblast::Transpose>(a_transpose), static_cast<clblast::Transpose>(b_transpose),
                           m, n, k, alphas_cpp.data(),
                           a_buffer, a_offsets, a_ld, b_buffer, b_offsets, b_ld, betas_cpp.data(),
                           c_buffer, c_offsets, c_ld, batch_count, queue, event));
  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastDgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                      const size_t m, const size_t n, const size_t k, const double *alphas,
                                      const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                      const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const double *betas,
                                      cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                      const size_t batch_count, cl_command_queue* queue, cl_event* event) {
  auto alphas_cpp = std::vector<double>();
  auto betas_cpp = std::vector<double>();
  for (auto batch = size_t{0}; batch < batch_count; ++batch) {
    alphas_cpp.push_back(alphas[batch]);
    betas_cpp.push_back(betas[batch]);
  }
  try {
    return static_cast<CLBlastStatusCode>(
      clblast::GemmBatched(static_cast<clblast::Layout>(layout),
                           static_cast<clblast::Transpose>(a_transpose), static_cast<clblast::Transpose>(b_transpose),
                           m, n, k, alphas_cpp.data(),
                           a_buffer, a_offsets, a_ld, b_buffer, b_offsets, b_ld, betas_cpp.data(),
                           c_buffer, c_offsets, c_ld, batch_count, queue, event));
  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastCgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                      const size_t m, const size_t n, const size_t k, const cl_float2 *alphas,
                                      const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                      const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const cl_float2 *betas,
                                      cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                      const size_t batch_count, cl_command_queue* queue, cl_event* event) {
  auto alphas_cpp = std::vector<float2>();
  auto betas_cpp = std::vector<float2>();
  for (auto batch = size_t{0}; batch < batch_count; ++batch) {
    alphas_cpp.push_back(float2{alphas[batch].s[0], alphas[batch].s[1]});
    betas_cpp.push_back(float2{betas[batch].s[0], betas[batch].s[1]});
  }
  try {
    return static_cast<CLBlastStatusCode>(
      clblast::GemmBatched(static_cast<clblast::Layout>(layout),
                           static_cast<clblast::Transpose>(a_transpose), static_cast<clblast::Transpose>(b_transpose),
                           m, n, k, alphas_cpp.data(),
                           a_buffer, a_offsets, a_ld, b_buffer, b_offsets, b_ld, betas_cpp.data(),
                           c_buffer, c_offsets, c_ld, batch_count, queue, event));
  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastZgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                      const size_t m, const size_t n, const size_t k, const cl_double2 *alphas,
                                      const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                      const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const cl_double2 *betas,
                                      cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                      const size_t batch_count, cl_command_queue* queue, cl_event* event) {
  auto alphas_cpp = std::vector<double2>();
  auto betas_cpp = std::vector<double2>();
  for (auto batch = size_t{0}; batch < batch_count; ++batch) {
    alphas_cpp.push_back(double2{alphas[batch].s[0], alphas[batch].s[1]});
    betas_cpp.push_back(double2{betas[batch].s[0], betas[batch].s[1]});
  }
  try {
    return static_cast<CLBlastStatusCode>(
      clblast::GemmBatched(static_cast<clblast::Layout>(layout),
                           static_cast<clblast::Transpose>(a_transpose), static_cast<clblast::Transpose>(b_transpose),
                           m, n, k, alphas_cpp.data(),
                           a_buffer, a_offsets, a_ld, b_buffer, b_offsets, b_ld, betas_cpp.data(),
                           c_buffer, c_offsets, c_ld, batch_count, queue, event));
  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastHgemmBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
                                      const size_t m, const size_t n, const size_t k, const cl_half *alphas,
                                      const cl_mem a_buffer, const size_t *a_offsets, const size_t a_ld,
                                      const cl_mem b_buffer, const size_t *b_offsets, const size_t b_ld, const cl_half *betas,
                                      cl_mem c_buffer, const size_t *c_offsets, const size_t c_ld,
                                      const size_t batch_count, cl_command_queue* queue, cl_event* event) {
  auto alphas_cpp = std::vector<half>();
  auto betas_cpp = std::vector<half>();
  for (auto batch = size_t{0}; batch < batch_count; ++batch) {
    alphas_cpp.push_back(alphas[batch]);
    betas_cpp.push_back(betas[batch]);
  }
  try {
    return static_cast<CLBlastStatusCode>(
      clblast::GemmBatched(static_cast<clblast::Layout>(layout),
                           static_cast<clblast::Transpose>(a_transpose), static_cast<clblast::Transpose>(b_transpose),
                           m, n, k, alphas_cpp.data(),
                           a_buffer, a_offsets, a_ld, b_buffer, b_offsets, b_ld, betas_cpp.data(),
                           c_buffer, c_offsets, c_ld, batch_count, queue, event));
  } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}

// =================================================================================================

// Clears the cache of stored binaries
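For the complex wrappers, each cl_float2/cl_double2 scalar is converted element-wise through its .s[0]/.s[1] members, as shown above. A hedged sketch of a caller filling those arrays for CLBlastCgemmBatched (buffers and queue are assumed to exist; the function name and sizes are illustrative):

```
#include <vector>
#include <clblast_c.h>

// Sketch only: packed column-major batches, buffers and queue created elsewhere.
CLBlastStatusCode RunCgemmBatchedSketch(cl_command_queue queue,
                                        cl_mem a_buffer, cl_mem b_buffer, cl_mem c_buffer) {
  const size_t m = 8, n = 8, k = 8, batch_count = 4;
  std::vector<cl_float2> alphas(batch_count), betas(batch_count);
  std::vector<size_t> a_offsets(batch_count), b_offsets(batch_count), c_offsets(batch_count);
  for (size_t b = 0; b < batch_count; ++b) {
    alphas[b].s[0] = 1.0f; alphas[b].s[1] = 0.0f;   // real and imaginary parts
    betas[b].s[0] = 0.0f;  betas[b].s[1] = 0.0f;
    a_offsets[b] = b * m * k;
    b_offsets[b] = b * k * n;
    c_offsets[b] = b * m * n;
  }
  return CLBlastCgemmBatched(CLBlastLayoutColMajor, CLBlastTransposeNo, CLBlastTransposeNo,
                             m, n, k, alphas.data(),
                             a_buffer, a_offsets.data(), m,
                             b_buffer, b_offsets.data(), k,
                             betas.data(),
                             c_buffer, c_offsets.data(), m,
                             batch_count, &queue, nullptr);   // no event requested
}
```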
@@ -24,16 +24,14 @@ R"(
 // Copies a matrix from source to destination. The output is padded with zero values in case the
 // destination matrix dimensions are larger than the source matrix dimensions. Additionally, the ld
 // value and offset can be different.
-__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
-void CopyPadMatrix(const int src_one, const int src_two,
+inline void _CopyPadMatrix(const int src_one, const int src_two,
                    const int src_ld, const int src_offset,
                    __global const real* restrict src,
                    const int dest_one, const int dest_two,
                    const int dest_ld, const int dest_offset,
                    __global real* dest,
-                   const real_arg arg_alpha,
+                   const real alpha,
                    const int do_conjugate) {
-  const real alpha = GetRealArg(arg_alpha);

   // Loops over the work per thread in both dimensions
   #pragma unroll
@@ -60,22 +58,36 @@ void CopyPadMatrix(const int src_one, const int src_two,
     }
   }
 }

-// =================================================================================================
+// Interface to the above function

-// Same as above, but now un-pads a matrix. This kernel reads data from a padded source matrix, but
-// writes only the actual data back to the destination matrix. Again, the ld value and offset can
-// be different.
 __kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
-void CopyMatrix(const int src_one, const int src_two,
+void CopyPadMatrix(const int src_one, const int src_two,
                 const int src_ld, const int src_offset,
                 __global const real* restrict src,
                 const int dest_one, const int dest_two,
                 const int dest_ld, const int dest_offset,
                 __global real* dest,
                 const real_arg arg_alpha,
+                const int do_conjugate) {
+  const real alpha = GetRealArg(arg_alpha);
+  _CopyPadMatrix(src_one, src_two, src_ld, src_offset, src,
+                 dest_one, dest_two, dest_ld, dest_offset, dest,
+                 alpha, do_conjugate);
+}
+
+// =================================================================================================
+
+// Same as above, but now un-pads a matrix. This kernel reads data from a padded source matrix, but
+// writes only the actual data back to the destination matrix. Again, the ld value and offset can
+// be different.
+inline void _CopyMatrix(const int src_one, const int src_two,
+                        const int src_ld, const int src_offset,
+                        __global const real* restrict src,
+                        const int dest_one, const int dest_two,
+                        const int dest_ld, const int dest_offset,
+                        __global real* dest,
+                        const real alpha,
                 const int upper, const int lower,
                 const int diagonal_imag_zero) {
-  const real alpha = GetRealArg(arg_alpha);

   // Loops over the work per thread in both dimensions
   #pragma unroll
@@ -105,6 +117,62 @@ void CopyMatrix(const int src_one, const int src_two,
     }
   }
 }
+
+// Interface to the above function
+__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
+void CopyMatrix(const int src_one, const int src_two,
+                const int src_ld, const int src_offset, __global const real* restrict src,
+                const int dest_one, const int dest_two,
+                const int dest_ld, const int dest_offset, __global real* dest,
+                const real_arg arg_alpha,
+                const int upper, const int lower, const int diagonal_imag_zero) {
+  const real alpha = GetRealArg(arg_alpha);
+  _CopyMatrix(src_one, src_two, src_ld, src_offset, src,
+              dest_one, dest_two, dest_ld, dest_offset, dest,
+              alpha, upper, lower, diagonal_imag_zero);
+}
+
+// =================================================================================================
+#if defined(ROUTINE_GEMMBATCHED)
+
+// Batched version of the above
+__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
+void CopyPadMatrixBatched(const int src_one, const int src_two,
+                          const int src_ld, const __constant int* src_offsets, __global const real* restrict src,
+                          const int dest_one, const int dest_two,
+                          const int dest_ld, const __constant int* dest_offsets, __global real* dest,
+                          const int do_conjugate) {
+  const int batch = get_group_id(2);
+  const int src_offset = src_offsets[batch];
+  const int dest_offset = dest_offsets[batch];
+  real alpha; SetToOne(alpha);
+  _CopyPadMatrix(src_one, src_two, src_ld, src_offset, src,
+                 dest_one, dest_two, dest_ld, dest_offset, dest,
+                 alpha, do_conjugate);
+}
+
+// Batched version of the above
+__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
+void CopyMatrixBatched(const int src_one, const int src_two,
+                       const int src_ld, const __constant int* src_offsets, __global const real* restrict src,
+                       const int dest_one, const int dest_two,
+                       const int dest_ld, const __constant int* dest_offsets, __global real* dest) {
+  const int batch = get_group_id(2);
+  const int src_offset = src_offsets[batch];
+  const int dest_offset = dest_offsets[batch];
+  real alpha; SetToOne(alpha);
+  _CopyMatrix(src_one, src_two, src_ld, src_offset, src,
+              dest_one, dest_two, dest_ld, dest_offset, dest,
+              alpha, 0, 0, 0);
+}
+
+#endif
 // =================================================================================================

 // End of the C++11 raw string literal
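The new *Batched copy kernels select their per-batch offset with get_group_id(2), so the host is expected to pass the offsets as a small integer buffer and to launch a three-dimensional grid whose third dimension equals the batch count. A rough host-side sketch with plain OpenCL calls (the function name, argument index and work-group sizes are illustrative only, not the exact values CLBlast uses internally):

```
#include <vector>
#define CL_TARGET_OPENCL_VERSION 120
#include <CL/cl.h>

// Illustrative only: enqueue some batched kernel over a 3D grid, one Z-slice per batch.
void LaunchBatchedKernelSketch(cl_context context, cl_command_queue queue, cl_kernel kernel,
                               const std::vector<cl_int>& offsets, size_t batch_count,
                               size_t global_x, size_t global_y, cl_uint offsets_arg_index) {
  // Per-batch offsets go into a small device buffer (read as __constant int* by the kernel)
  cl_mem offsets_buffer = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
                                         offsets.size() * sizeof(cl_int),
                                         const_cast<cl_int*>(offsets.data()), nullptr);
  clSetKernelArg(kernel, offsets_arg_index, sizeof(cl_mem), &offsets_buffer);

  // The kernel reads its batch index with get_group_id(2), so dimension 2 spans the batches
  const size_t global[3] = {global_x, global_y, batch_count};
  const size_t local[3]  = {8, 8, 1};   // must divide the global sizes (PAD_DIMX/PAD_DIMY in CLBlast)
  clEnqueueNDRangeKernel(queue, kernel, 3, nullptr, global, local, 0, nullptr, nullptr);
  clReleaseMemObject(offsets_buffer);   // the runtime keeps it alive until the command completes
}
```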
@@ -24,19 +24,15 @@ R"(

 // Transposes a matrix from source to destination. The output is padded with zero values in case the
 // destination matrix dimensions are larger than the transposed source matrix dimensions.
-__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
-void TransposePadMatrix(const int src_one, const int src_two,
+inline void _TransposePadMatrix(__local real* tile,
+                                const int src_one, const int src_two,
                         const int src_ld, const int src_offset,
                         __global const real* restrict src,
                         const int dest_one, const int dest_two,
                         const int dest_ld, const int dest_offset,
                         __global real* dest,
-                        const real_arg arg_alpha,
+                        const real alpha,
                         const int do_conjugate) {
-  const real alpha = GetRealArg(arg_alpha);
-
-  // Local memory to store a tile of the matrix (for coalescing)
-  __local real tile[PADTRA_WPT*PADTRA_TILE][PADTRA_WPT*PADTRA_TILE + PADTRA_PAD];

   // Loop over the work per thread
   #pragma unroll
@@ -56,7 +52,9 @@ void TransposePadMatrix(const int src_one, const int src_two,
       if (id_src_two < src_two && id_src_one < src_one) {
         value = src[id_src_two*src_ld + id_src_one + src_offset];
       }
-      tile[get_local_id(1)*PADTRA_WPT + w_two][get_local_id(0)*PADTRA_WPT + w_one] = value;
+      const int tile_id0 = get_local_id(0)*PADTRA_WPT + w_one;
+      const int tile_id1 = get_local_id(1)*PADTRA_WPT + w_two;
+      tile[tile_id1 * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD) + tile_id0] = value;
     }
   }
@@ -75,7 +73,9 @@ void TransposePadMatrix(const int src_one, const int src_two,

       // Stores the transposed value in the destination matrix
       if ((id_dest_one < dest_one) && (id_dest_two < dest_two)) {
-        real value = tile[get_local_id(0)*PADTRA_WPT + w_two][get_local_id(1)*PADTRA_WPT + w_one];
+        const int tile_id0 = get_local_id(1)*PADTRA_WPT + w_one;
+        const int tile_id1 = get_local_id(0)*PADTRA_WPT + w_two;
+        real value = tile[tile_id1 * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD) + tile_id0];
         if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
         Multiply(dest[id_dest_two*dest_ld + id_dest_one + dest_offset], alpha, value);
       }
@@ -83,25 +83,38 @@ void TransposePadMatrix(const int src_one, const int src_two,
     }
   }
 }

-// =================================================================================================
+// Interface to the above function

-// Transposes a matrix, while considering possible padding in the source matrix. Data is read from a
-// padded source matrix, but only the actual data is written back to the transposed destination
-// matrix. This kernel optionally checks for upper/lower triangular matrices.
 __kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
-void TransposeMatrix(const int src_one, const int src_two,
+void TransposePadMatrix(const int src_one, const int src_two,
                      const int src_ld, const int src_offset,
                      __global const real* restrict src,
                      const int dest_one, const int dest_two,
                      const int dest_ld, const int dest_offset,
                      __global real* dest,
                      const real_arg arg_alpha,
+                     const int do_conjugate) {
+  const real alpha = GetRealArg(arg_alpha);
+  __local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
+  _TransposePadMatrix(tile, src_one, src_two, src_ld, src_offset, src,
+                      dest_one, dest_two, dest_ld, dest_offset, dest,
+                      alpha, do_conjugate);
+}
+
+// =================================================================================================
+
+// Transposes a matrix, while considering possible padding in the source matrix. Data is read from a
+// padded source matrix, but only the actual data is written back to the transposed destination
+// matrix. This kernel optionally checks for upper/lower triangular matrices.
+inline void _TransposeMatrix(__local real* tile,
+                             const int src_one, const int src_two,
+                             const int src_ld, const int src_offset,
+                             __global const real* restrict src,
+                             const int dest_one, const int dest_two,
+                             const int dest_ld, const int dest_offset,
+                             __global real* dest,
+                             const real alpha,
                      const int upper, const int lower,
                      const int diagonal_imag_zero) {
-  const real alpha = GetRealArg(arg_alpha);
-
-  // Local memory to store a tile of the matrix (for coalescing)
-  __local real tile[PADTRA_WPT*PADTRA_TILE][PADTRA_WPT*PADTRA_TILE + PADTRA_PAD];

   // Loop over the work per thread
   #pragma unroll
@@ -117,7 +130,9 @@ void TransposeMatrix(const int src_one, const int src_two,
       // Loads data into the local memory if the thread IDs are within bounds of the source matrix.
       if ((id_src_one < src_one) && (id_src_two < src_two)) {
         real value = src[id_src_two*src_ld + id_src_one + src_offset];
-        tile[get_local_id(1)*PADTRA_WPT + w_two][get_local_id(0)*PADTRA_WPT + w_one] = value;
+        const int tile_id0 = get_local_id(0)*PADTRA_WPT + w_one;
+        const int tile_id1 = get_local_id(1)*PADTRA_WPT + w_two;
+        tile[tile_id1 * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD) + tile_id0] = value;
       }
     }
   }
@@ -145,7 +160,9 @@ void TransposeMatrix(const int src_one, const int src_two,

       // Stores the transposed value in the destination matrix
       if ((id_dest_one < dest_one) && (id_dest_two < dest_two)) {
-        real value = tile[get_local_id(0)*PADTRA_WPT + w_two][get_local_id(1)*PADTRA_WPT + w_one];
+        const int tile_id0 = get_local_id(1)*PADTRA_WPT + w_one;
+        const int tile_id1 = get_local_id(0)*PADTRA_WPT + w_two;
+        real value = tile[tile_id1 * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD) + tile_id0];
         if (diagonal_imag_zero == 1 && id_dest_one == id_dest_two) { ImagToZero(value); }
         Multiply(dest[id_dest_two*dest_ld + id_dest_one + dest_offset], alpha, value);
       }
@@ -154,6 +171,65 @@ void TransposeMatrix(const int src_one, const int src_two,
     }
   }
 }
+
+// Interface to the above function
+__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
+void TransposeMatrix(const int src_one, const int src_two,
+                     const int src_ld, const int src_offset, __global const real* restrict src,
+                     const int dest_one, const int dest_two,
+                     const int dest_ld, const int dest_offset, __global real* dest,
+                     const real_arg arg_alpha,
+                     const int upper, const int lower, const int diagonal_imag_zero) {
+  const real alpha = GetRealArg(arg_alpha);
+  __local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
+  _TransposeMatrix(tile, src_one, src_two, src_ld, src_offset, src,
+                   dest_one, dest_two, dest_ld, dest_offset, dest,
+                   alpha, upper, lower, diagonal_imag_zero);
+}
+
+// =================================================================================================
+#if defined(ROUTINE_GEMMBATCHED)
+
+// Batched version of the above
+__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
+void TransposePadMatrixBatched(const int src_one, const int src_two,
+                               const int src_ld, const __constant int* src_offsets, __global const real* restrict src,
+                               const int dest_one, const int dest_two,
+                               const int dest_ld, const __constant int* dest_offsets, __global real* dest,
+                               const int do_conjugate) {
+  const int batch = get_group_id(2);
+  const int src_offset = src_offsets[batch];
+  const int dest_offset = dest_offsets[batch];
+  real alpha; SetToOne(alpha);
+  __local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
+  _TransposePadMatrix(tile, src_one, src_two, src_ld, src_offset, src,
+                      dest_one, dest_two, dest_ld, dest_offset, dest,
+                      alpha, do_conjugate);
+}
+
+// Batched version of the above
+__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
+void TransposeMatrixBatched(const int src_one, const int src_two,
+                            const int src_ld, const __constant int* src_offsets, __global const real* restrict src,
+                            const int dest_one, const int dest_two,
+                            const int dest_ld, const __constant int* dest_offsets, __global real* dest) {
+  const int batch = get_group_id(2);
+  const int src_offset = src_offsets[batch];
+  const int dest_offset = dest_offsets[batch];
+  real alpha; SetToOne(alpha);
+  __local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
+  _TransposeMatrix(tile, src_one, src_two, src_ld, src_offset, src,
+                   dest_one, dest_two, dest_ld, dest_offset, dest,
+                   alpha, 0, 0, 0);
+}
+
+#endif
 // =================================================================================================

 // End of the C++11 raw string literal
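Since the local tile is now handed to the inline _TransposePadMatrix/_TransposeMatrix helpers as a plain pointer, the former two-dimensional tile[row][col] accesses become a single flattened index with an explicit, padded row stride. A small C++ analogue of that mapping (the constants merely stand in for PADTRA_WPT, PADTRA_TILE and PADTRA_PAD; values are examples):

```
#include <cassert>
#include <vector>

int main() {
  // Stand-ins for the kernel's compile-time parameters
  const int PADTRA_WPT = 2, PADTRA_TILE = 8, PADTRA_PAD = 1;
  const int rows = PADTRA_WPT * PADTRA_TILE;                      // 16
  const int row_stride = PADTRA_WPT * PADTRA_TILE + PADTRA_PAD;   // 17 (padding reduces bank conflicts)

  // tile[tile_id1][tile_id0] in the old kernel == flat[tile_id1 * row_stride + tile_id0] now
  std::vector<float> flat(rows * row_stride, 0.0f);
  const int tile_id0 = 3, tile_id1 = 5;
  flat[tile_id1 * row_stride + tile_id0] = 42.0f;
  assert(flat[5 * 17 + 3] == 42.0f);
  return 0;
}
```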
src/kernels/level3/xgemm_batched.opencl (new file, 70 lines)
@@ -0,0 +1,70 @@

// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
//   Cedric Nugteren <www.cedricnugteren.nl>
//
// This file contains the batched version of the non-direct GEMM kernel. See part 1 for information
// about the non-batched version of the kernel.
//
// =================================================================================================

// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
// literal). Comment-out this line for syntax-highlighting when developing.
R"(

// =================================================================================================

// Main entry point of the kernel. This is the regular full version.
__kernel __attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK,
                  const __constant real_arg* arg_alphas,
                  const __constant real_arg* arg_betas,
                  const __global realM* restrict agm, const int a_one, const int a_two,
                  const __global realN* restrict bgm, const int b_one, const int b_two,
                  __global realM* cgm, const int c_one, const int c_two) {
  const int batch = get_group_id(2);
  const real alpha = GetRealArg(arg_alphas[batch]);
  const real beta = GetRealArg(arg_betas[batch]);

  // Sets the offsets
  const int a_offset = batch * a_one * a_two;
  const int b_offset = batch * b_one * b_two;
  const int c_offset = batch * c_one * c_two;
  const __global realM* restrict agm_ = &agm[a_offset / VWM];
  const __global realN* restrict bgm_ = &bgm[b_offset / VWN];
  __global realM* restrict cgm_ = &cgm[c_offset / VWM];

  // Allocates workgroup-private memory (local memory)
  #if SA == 1
    __local realM alm[KWG * MWG/VWM];
  #endif
  #if SB == 1
    __local realN blm[KWG * NWG/VWN];
  #endif

  // Computes the matrix-multiplication and stores the result in register memory
  realM cpm[NWI][MWI/VWM];
  #if SA == 1 && SB == 1
    XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm, alm, blm);
  #elif SA == 1
    XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm, alm);
  #elif SB == 1
    XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm, blm);
  #else
    XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, cpm);
  #endif

  // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
  StoreResults(cgm_, cpm, kSizeM, alpha, beta);
}

// =================================================================================================

// End of the C++11 raw string literal
)"

// =================================================================================================
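Here the operands are the padded temporaries produced by the copy/transpose kernels, so batch number 'batch' simply starts a_one*a_two elements further into the buffer; and because agm is typed as a vector of VWM elements, the pointer is advanced by the element offset divided by VWM. A C++ analogue of that address arithmetic (the sizes and VWM are example values; the padded dimensions are multiples of the vector width, so the division is exact):

```
#include <cassert>
#include <cstddef>

int main() {
  // Example values standing in for the kernel's a_one, a_two and VWM
  const size_t a_one = 64, a_two = 64, VWM = 4;
  const size_t batch = 2;

  const size_t element_offset = batch * a_one * a_two;  // start of this batch in scalar elements
  const size_t vector_offset  = element_offset / VWM;   // same position counted in realM vectors

  assert(element_offset % VWM == 0);   // exact because the padded sizes are multiples of VWM
  assert(vector_offset == 2048);       // 2 * 64 * 64 / 4
  return 0;
}
```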
110
src/kernels/level3/xgemm_direct_batched.opencl
Normal file
110
src/kernels/level3/xgemm_direct_batched.opencl
Normal file
|
@ -0,0 +1,110 @@

// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
//   Cedric Nugteren <www.cedricnugteren.nl>
//
// This file contains the batched version of the direct GEMM kernels. See part 1 for information
// about the non-batched version of the kernel.
//
// =================================================================================================

// Enables loading of this file using the C++ pre-processor's #include (C++11 standard raw string
// literal). Comment-out this line for syntax-highlighting when developing.
R"(

// =================================================================================================

// Direct version of the batched GEMM kernel with [A, B] = [non-transposed, non-transposed]
__attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
__kernel void XgemmDirectBatchedNN(const int kSizeM, const int kSizeN, const int kSizeK,
                                   const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
                                   const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
                                   const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
                                   __global real* cgm, const __constant int* c_offsets, const int c_ld,
                                   const int c_transpose, const int a_conjugate, const int b_conjugate) {
  const int batch = get_group_id(2);
  const real_arg arg_alpha = arg_alphas[batch];
  const real_arg arg_beta = arg_betas[batch];
  const int a_offset = a_offsets[batch];
  const int b_offset = b_offsets[batch];
  const int c_offset = c_offsets[batch];
  __local real alm[WGD * (WGD + PADA)];
  __local real blm[WGD * (WGD + PADB)];
  XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
              agm, a_offset, a_ld, bgm, b_offset, b_ld, cgm, c_offset, c_ld,
              alm, blm, 0, 0, c_transpose, a_conjugate, b_conjugate);
}

// Direct version of the batched GEMM kernel with [A, B] = [non-transposed, transposed]
__attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
__kernel void XgemmDirectBatchedNT(const int kSizeM, const int kSizeN, const int kSizeK,
                                   const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
                                   const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
                                   const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
                                   __global real* cgm, const __constant int* c_offsets, const int c_ld,
                                   const int c_transpose, const int a_conjugate, const int b_conjugate) {
  const int batch = get_group_id(2);
  const real_arg arg_alpha = arg_alphas[batch];
  const real_arg arg_beta = arg_betas[batch];
  const int a_offset = a_offsets[batch];
  const int b_offset = b_offsets[batch];
  const int c_offset = c_offsets[batch];
  __local real alm[WGD * (WGD + PADA)];
  __local real blm[WGD * (WGD + PADB)];
  XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
              agm, a_offset, a_ld, bgm, b_offset, b_ld, cgm, c_offset, c_ld,
              alm, blm, 0, 1, c_transpose, a_conjugate, b_conjugate);
}

// Direct version of the batched GEMM kernel with [A, B] = [transposed, non-transposed]
__attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
__kernel void XgemmDirectBatchedTN(const int kSizeM, const int kSizeN, const int kSizeK,
                                   const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
                                   const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
                                   const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
                                   __global real* cgm, const __constant int* c_offsets, const int c_ld,
                                   const int c_transpose, const int a_conjugate, const int b_conjugate) {
  const int batch = get_group_id(2);
  const real_arg arg_alpha = arg_alphas[batch];
  const real_arg arg_beta = arg_betas[batch];
  const int a_offset = a_offsets[batch];
  const int b_offset = b_offsets[batch];
  const int c_offset = c_offsets[batch];
  __local real alm[WGD * (WGD + PADA)];
  __local real blm[WGD * (WGD + PADB)];
  XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
              agm, a_offset, a_ld, bgm, b_offset, b_ld, cgm, c_offset, c_ld,
              alm, blm, 1, 0, c_transpose, a_conjugate, b_conjugate);
}

// Direct version of the batched GEMM kernel with [A, B] = [transposed, transposed]
__attribute__((reqd_work_group_size(MDIMCD, NDIMCD, 1)))
__kernel void XgemmDirectBatchedTT(const int kSizeM, const int kSizeN, const int kSizeK,
                                   const __constant real_arg* arg_alphas, const __constant real_arg* arg_betas,
                                   const __global realMD* restrict agm, const __constant int* a_offsets, const int a_ld,
                                   const __global realND* restrict bgm, const __constant int* b_offsets, const int b_ld,
                                   __global real* cgm, const __constant int* c_offsets, const int c_ld,
                                   const int c_transpose, const int a_conjugate, const int b_conjugate) {
  const int batch = get_group_id(2);
  const real_arg arg_alpha = arg_alphas[batch];
  const real_arg arg_beta = arg_betas[batch];
  const int a_offset = a_offsets[batch];
  const int b_offset = b_offsets[batch];
  const int c_offset = c_offsets[batch];
  __local real alm[WGD * (WGD + PADA)];
  __local real blm[WGD * (WGD + PADB)];
  XgemmDirect(kSizeM, kSizeN, kSizeK, arg_alpha, arg_beta,
              agm, a_offset, a_ld, bgm, b_offset, b_ld, cgm, c_offset, c_ld,
              alm, blm, 1, 1, c_transpose, a_conjugate, b_conjugate);
}

// =================================================================================================

// End of the C++11 raw string literal
)"

// =================================================================================================
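Each of the four batched direct kernels above reads its per-batch scalars and offsets from `get_group_id(2)`, so the host only needs a single 3-D launch with `batch_count` as the third NDRange dimension. Below is a minimal host-side sketch of such a launch with the raw OpenCL C API; the tuning values WGD=32 and MDIMCD=NDIMCD=8 are illustrative assumptions, and CLBlast itself goes through its own `Kernel`/`RunKernel` wrappers instead (see `BatchedGemmDirect` further down in this diff).

```
// Sketch: launching XgemmDirectBatchedNN with batch_count as the 3rd NDRange dimension.
// Assumes 'program' contains the compiled kernels and all buffers are already filled.
#include <CL/cl.h>
#include <cstddef>

void LaunchXgemmDirectBatchedNN(cl_program program, cl_command_queue queue,
                                cl_mem alphas, cl_mem betas,
                                cl_mem a, cl_mem a_offsets, int a_ld,
                                cl_mem b, cl_mem b_offsets, int b_ld,
                                cl_mem c, cl_mem c_offsets, int c_ld,
                                int m, int n, int k, size_t batch_count) {
  cl_int err = CL_SUCCESS;
  cl_kernel kernel = clCreateKernel(program, "XgemmDirectBatchedNN", &err);

  int c_transpose = 0, a_conjugate = 0, b_conjugate = 0;
  clSetKernelArg(kernel, 0, sizeof(int), &m);
  clSetKernelArg(kernel, 1, sizeof(int), &n);
  clSetKernelArg(kernel, 2, sizeof(int), &k);
  clSetKernelArg(kernel, 3, sizeof(cl_mem), &alphas);
  clSetKernelArg(kernel, 4, sizeof(cl_mem), &betas);
  clSetKernelArg(kernel, 5, sizeof(cl_mem), &a);
  clSetKernelArg(kernel, 6, sizeof(cl_mem), &a_offsets);
  clSetKernelArg(kernel, 7, sizeof(int), &a_ld);
  clSetKernelArg(kernel, 8, sizeof(cl_mem), &b);
  clSetKernelArg(kernel, 9, sizeof(cl_mem), &b_offsets);
  clSetKernelArg(kernel, 10, sizeof(int), &b_ld);
  clSetKernelArg(kernel, 11, sizeof(cl_mem), &c);
  clSetKernelArg(kernel, 12, sizeof(cl_mem), &c_offsets);
  clSetKernelArg(kernel, 13, sizeof(int), &c_ld);
  clSetKernelArg(kernel, 14, sizeof(int), &c_transpose);
  clSetKernelArg(kernel, 15, sizeof(int), &a_conjugate);
  clSetKernelArg(kernel, 16, sizeof(int), &b_conjugate);

  // Example tuning values (assumption): WGD=32, MDIMCD=NDIMCD=8. One work-group covers a WGD x WGD tile.
  const size_t WGD = 32, MDIMCD = 8, NDIMCD = 8;
  const size_t m_ceiled = ((m + WGD - 1) / WGD) * WGD;
  const size_t n_ceiled = ((n + WGD - 1) / WGD) * WGD;
  const size_t global[3] = {(m_ceiled * MDIMCD) / WGD, (n_ceiled * NDIMCD) / WGD, batch_count};
  const size_t local[3]  = {MDIMCD, NDIMCD, 1};
  clEnqueueNDRangeKernel(queue, kernel, 3, nullptr, global, local, 0, nullptr, nullptr);
  clReleaseKernel(kernel);
}
```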
@ -196,6 +196,70 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device,
  }
}

// Batched version of the above
template <typename T>
void PadCopyTransposeMatrixBatched(Queue &queue, const Device &device,
                                   const Databases &db,
                                   EventPointer event, const std::vector<Event> &waitForEvents,
                                   const size_t src_one, const size_t src_two,
                                   const size_t src_ld, const Buffer<int> &src_offsets,
                                   const Buffer<T> &src,
                                   const size_t dest_one, const size_t dest_two,
                                   const size_t dest_ld, const Buffer<int> &dest_offsets,
                                   const Buffer<T> &dest,
                                   const Program &program, const bool do_pad,
                                   const bool do_transpose, const bool do_conjugate,
                                   const size_t batch_count) {

  // Determines the right kernel
  auto kernel_name = std::string{};
  if (do_transpose) {
    kernel_name = (do_pad) ? "TransposePadMatrixBatched" : "TransposeMatrixBatched";
  }
  else {
    kernel_name = (do_pad) ? "CopyPadMatrixBatched" : "CopyMatrixBatched";
  }

  // Retrieves the kernel from the compiled binary
  auto kernel = Kernel(program, kernel_name);

  // Sets the kernel arguments
  kernel.SetArgument(0, static_cast<int>(src_one));
  kernel.SetArgument(1, static_cast<int>(src_two));
  kernel.SetArgument(2, static_cast<int>(src_ld));
  kernel.SetArgument(3, src_offsets());
  kernel.SetArgument(4, src());
  kernel.SetArgument(5, static_cast<int>(dest_one));
  kernel.SetArgument(6, static_cast<int>(dest_two));
  kernel.SetArgument(7, static_cast<int>(dest_ld));
  kernel.SetArgument(8, dest_offsets());
  kernel.SetArgument(9, dest());
  if (do_pad) {
    kernel.SetArgument(10, static_cast<int>(do_conjugate));
  }

  // Launches the kernel and returns the error code. Uses global and local thread sizes based on
  // parameters in the database.
  if (do_transpose) {
    const auto global = std::vector<size_t>{
      Ceil(CeilDiv(dest_one, db["PADTRA_WPT"]), db["PADTRA_TILE"]),
      Ceil(CeilDiv(dest_two, db["PADTRA_WPT"]), db["PADTRA_TILE"]),
      batch_count
    };
    const auto local = std::vector<size_t>{db["PADTRA_TILE"], db["PADTRA_TILE"], 1};
    RunKernel(kernel, queue, device, global, local, event, waitForEvents);
  }
  else {
    const auto global = std::vector<size_t>{
      Ceil(CeilDiv(dest_one, db["PAD_WPTX"]), db["PAD_DIMX"]),
      Ceil(CeilDiv(dest_two, db["PAD_WPTY"]), db["PAD_DIMY"]),
      batch_count
    };
    const auto local = std::vector<size_t>{db["PAD_DIMX"], db["PAD_DIMY"], 1};
    RunKernel(kernel, queue, device, global, local, event, waitForEvents);
  }
}

// =================================================================================================
} // namespace clblast
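Compared to the non-batched helper, the only changes are the offset buffers and the extra third dimension of the global size, set to `batch_count` so that each XY-plane of work-groups processes one batch entry. A quick numeric check of the rounding that determines the first two dimensions, assuming CLBlast's usual helper definitions `CeilDiv(x,y) = (x+y-1)/y` and `Ceil(x,y) = CeilDiv(x,y)*y` together with example database values `PAD_WPTX=1`, `PAD_DIMX=8`:

```
#include <cassert>
#include <cstddef>

// Assumed semantics of CLBlast's rounding helpers (utilities).
constexpr size_t CeilDiv(size_t x, size_t y) { return (x + y - 1) / y; }
constexpr size_t Ceil(size_t x, size_t y) { return CeilDiv(x, y) * y; }

int main() {
  // Example: dest_one = 123 elements, PAD_WPTX = 1 element per thread, PAD_DIMX = 8 threads per group.
  const size_t dest_one = 123, PAD_WPTX = 1, PAD_DIMX = 8;
  const size_t global_x = Ceil(CeilDiv(dest_one, PAD_WPTX), PAD_DIMX);
  assert(global_x == 128);  // rounded up to a multiple of the work-group size
  // With batch_count = 16 the launch becomes {128, ..., 16} against a local size of {8, ..., 1}.
  return 0;
}
```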
@ -104,13 +104,13 @@ void Xgemm<T>::DoGemm(const Layout layout,
  // Selects which version of GEMM to run
  const auto do_gemm_direct = (m * n * k < db_["XGEMM_MIN_INDIRECT_SIZE"]);
  if (do_gemm_direct) { // for small sizes (single kernel)
    return GemmDirect(m, n, k, alpha,
    GemmDirect(m, n, k, alpha,
               a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
               c_buffer, c_offset, c_ld,
               a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
  }
  else { // for larger sizes (pre/post-processing plus a very fast kernel)
    return GemmIndirect(m, n, k, alpha,
    GemmIndirect(m, n, k, alpha,
                 a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
                 c_buffer, c_offset, c_ld,
                 a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
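The `return` keywords disappear here because `DoGemm` (and the `GemmDirect`/`GemmIndirect` helpers it calls) now has a `void` return type: failures are reported by throwing, as the new `throw BLASError(...)` statements in `DoGemmBatched` below illustrate. A hedged sketch of how such exceptions could be turned back into a `StatusCode` at the public API boundary; the `status()` accessor and the `kUnknownError` fallback are assumptions, not code from this commit:

```
// Sketch only: converting an exception-based routine call into a StatusCode-returning API call.
template <typename RoutineCall>
clblast::StatusCode RunAndCatch(RoutineCall &&call) {
  try {
    call();                                      // e.g. a lambda invoking routine.DoGemmBatched(...)
    return clblast::StatusCode::kSuccess;
  } catch (const clblast::BLASError &e) {
    return e.status();                           // assumed accessor for the wrapped StatusCode
  } catch (...) {
    return clblast::StatusCode::kUnknownError;   // assumed catch-all status code
  }
}
```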
350 src/routines/levelx/xgemmbatched.cpp Normal file
@ -0,0 +1,350 @@

// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
//   Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements the XgemmBatched class (see the header for information about the class).
//
// =================================================================================================

#include "routines/levelx/xgemmbatched.hpp"

#include <string>
#include <vector>

namespace clblast {
// =================================================================================================

// Constructor: forwards to base class constructor
template <typename T>
XgemmBatched<T>::XgemmBatched(Queue &queue, EventPointer event, const std::string &name):
    Routine(queue, event, name,
            {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","KernelSelection"},
            PrecisionValue<T>(), {}, {
    #include "../../kernels/level3/level3.opencl"
    #include "../../kernels/level3/copy_fast.opencl"
    #include "../../kernels/level3/copy_pad.opencl"
    #include "../../kernels/level3/transpose_fast.opencl"
    #include "../../kernels/level3/transpose_pad.opencl"
    , // separated in multiple parts to prevent C1091 in MSVC 2013
    #include "../../kernels/level3/xgemm_direct_part1.opencl"
    #include "../../kernels/level3/xgemm_direct_part2.opencl"
    #include "../../kernels/level3/xgemm_direct_part3.opencl"
    , // separated in multiple parts to prevent C1091 in MSVC 2013
    #include "../../kernels/level3/xgemm_part1.opencl"
    #include "../../kernels/level3/xgemm_part2.opencl"
    #include "../../kernels/level3/xgemm_part3.opencl"
    , // separated in multiple parts to prevent C1091 in MSVC 2013
    #include "../../kernels/level3/xgemm_batched.opencl"
    #include "../../kernels/level3/xgemm_direct_batched.opencl"
    }) {
}

// =================================================================================================

// The main routine
template <typename T>
void XgemmBatched<T>::DoGemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
                                    const size_t m, const size_t n, const size_t k,
                                    const std::vector<T> &alphas,
                                    const Buffer<T> & a_buffer, const std::vector<size_t> &a_offsets, const size_t a_ld,
                                    const Buffer<T> & b_buffer, const std::vector<size_t> &b_offsets, const size_t b_ld,
                                    const std::vector<T> &betas,
                                    const Buffer<T> & c_buffer, const std::vector<size_t> &c_offsets, const size_t c_ld,
                                    const size_t batch_count) {

  // Tests for a valid batch count
  if ((batch_count < 1) || (alphas.size() != batch_count) || (betas.size() != batch_count) ||
      (a_offsets.size() != batch_count) || (b_offsets.size() != batch_count) || (c_offsets.size() != batch_count)) {
    throw BLASError(StatusCode::kInvalidBatchCount);
  }

  // Makes sure all dimensions are larger than zero
  if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); }

  // Computes whether or not the matrices are transposed in memory. See GEMM routine for details.
  const auto a_rotated = (layout == Layout::kColMajor && a_transpose != Transpose::kNo) ||
                         (layout == Layout::kRowMajor && a_transpose == Transpose::kNo);
  const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) ||
                         (layout == Layout::kRowMajor && b_transpose == Transpose::kNo);
  const auto c_rotated = (layout == Layout::kRowMajor);
  static const auto a_want_rotated = false;
  static const auto b_want_rotated = true;
  static const auto c_want_rotated = false;
  const auto a_do_transpose = a_rotated != a_want_rotated;
  const auto b_do_transpose = b_rotated != b_want_rotated;
  const auto c_do_transpose = c_rotated != c_want_rotated;

  // In case of complex data-types, the transpose can also become a conjugate transpose
  const auto a_conjugate = (a_transpose == Transpose::kConjugate);
  const auto b_conjugate = (b_transpose == Transpose::kConjugate);

  // Computes the first and second dimensions of the 3 matrices taking into account whether the
  // matrices are rotated or not
  const auto a_one = (a_rotated) ? k : m;
  const auto a_two = (a_rotated) ? m : k;
  const auto b_one = (b_rotated) ? n : k;
  const auto b_two = (b_rotated) ? k : n;
  const auto c_one = (c_rotated) ? n : m;
  const auto c_two = (c_rotated) ? m : n;

  // Tests the matrices for validity
  for (auto batch = size_t{0}; batch < batch_count; ++batch) {
    TestMatrixA(a_one, a_two, a_buffer, a_offsets[batch], a_ld);
    TestMatrixB(b_one, b_two, b_buffer, b_offsets[batch], b_ld);
    TestMatrixC(c_one, c_two, c_buffer, c_offsets[batch], c_ld);
  }

  // Upload the scalar arguments to the device
  auto alphas_device = Buffer<T>(context_, BufferAccess::kReadOnly, batch_count);
  auto betas_device = Buffer<T>(context_, BufferAccess::kReadOnly, batch_count);
  alphas_device.Write(queue_, batch_count, alphas);
  betas_device.Write(queue_, batch_count, betas);

  // Converts the offset to integers
  std::vector<int> a_offsets_int(a_offsets.begin(), a_offsets.end());
  std::vector<int> b_offsets_int(b_offsets.begin(), b_offsets.end());
  std::vector<int> c_offsets_int(c_offsets.begin(), c_offsets.end());

  // Selects which version of the batched GEMM to run
  const auto do_gemm_direct = true;
  if (do_gemm_direct) { // single generic kernel
    BatchedGemmDirect(m, n, k, alphas_device,
                      a_buffer, a_offsets_int, a_ld, b_buffer, b_offsets_int, b_ld,
                      betas_device, c_buffer, c_offsets_int, c_ld,
                      a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
                      batch_count);
  }
  else { // pre/post-processing plus a very fast kernel
    BatchedGemmIndirect(m, n, k, alphas_device,
                        a_buffer, a_offsets_int, a_ld, b_buffer, b_offsets_int, b_ld,
                        betas_device, c_buffer, c_offsets_int, c_ld,
                        a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
                        a_one, a_two, a_want_rotated,
                        b_one, b_two, b_want_rotated,
                        c_one, c_two, c_want_rotated,
                        batch_count);
  }
}

// =================================================================================================

// The indirect version of batched GEMM. This uses the faster but non-general kernel. It has specific
// requirements, but several pre and post-processing kernels take care of those. However, the
// overhead of these extra kernels might not be ideal for certain devices/arguments.
template <typename T>
void XgemmBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, const size_t k,
                                          const Buffer<T> &alphas,
                                          const Buffer<T> &a_buffer, const std::vector<int> &a_offsets, const size_t a_ld,
                                          const Buffer<T> &b_buffer, const std::vector<int> &b_offsets, const size_t b_ld,
                                          const Buffer<T> &betas,
                                          const Buffer<T> &c_buffer, const std::vector<int> &c_offsets, const size_t c_ld,
                                          const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
                                          const bool a_conjugate, const bool b_conjugate,
                                          const size_t a_one, const size_t a_two, const bool a_want_rotated,
                                          const size_t b_one, const size_t b_two, const bool b_want_rotated,
                                          const size_t c_one, const size_t c_two, const bool c_want_rotated,
                                          const size_t batch_count) {
  // Calculates the ceiled versions of m, n, and k
  const auto m_ceiled = Ceil(Ceil(m, db_["MWG"]), db_["VWM"]);
  const auto n_ceiled = Ceil(Ceil(n, db_["NWG"]), db_["VWN"]);
  const auto k_ceiled = Ceil(Ceil(k, db_["KWG"]), db_["VWM"]);

  // Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account
  // whether the matrices need to be rotated or not for the kernel.
  const auto a_one_i = (a_want_rotated) ? k_ceiled : m_ceiled;
  const auto a_two_i = (a_want_rotated) ? m_ceiled : k_ceiled;
  const auto b_one_i = (b_want_rotated) ? n_ceiled : k_ceiled;
  const auto b_two_i = (b_want_rotated) ? k_ceiled : n_ceiled;
  const auto c_one_i = (c_want_rotated) ? n_ceiled : m_ceiled;
  const auto c_two_i = (c_want_rotated) ? m_ceiled : n_ceiled;

  // Sets the "internal" offsets, i.e. the perfect offsets
  auto a_offsets_i = std::vector<int>(batch_count);
  auto b_offsets_i = std::vector<int>(batch_count);
  auto c_offsets_i = std::vector<int>(batch_count);
  for (auto batch = size_t{0}; batch < batch_count; ++batch) {
    a_offsets_i[batch] = batch * a_one_i * a_two_i;
    b_offsets_i[batch] = batch * b_one_i * b_two_i;
    c_offsets_i[batch] = batch * c_one_i * c_two_i;
  }

  // Determines whether or not temporary matrices are needed
  auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && a_offsets == a_offsets_i &&
                   a_do_transpose == false && a_conjugate == false;
  auto b_no_temp = b_one == b_one_i && b_two == b_two_i && b_ld == b_one && b_offsets == b_offsets_i &&
                   b_do_transpose == false && b_conjugate == false;
  auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && c_offsets == c_offsets_i &&
                   c_do_transpose == false;

  // Creates the temporary matrices
  const auto a_temp = (a_no_temp) ? a_buffer : Buffer<T>(context_, batch_count * a_one_i * a_two_i);
  const auto b_temp = (b_no_temp) ? b_buffer : Buffer<T>(context_, batch_count * b_one_i * b_two_i);
  const auto c_temp = (c_no_temp) ? c_buffer : Buffer<T>(context_, batch_count * c_one_i * c_two_i);

  // Events of all kernels (including pre/post processing kernels)
  auto eventWaitList = std::vector<Event>();
  auto emptyEventList = std::vector<Event>();

  // Runs the pre-processing kernel for matrix A. This transposes the matrix, but also pads zeros
  // to fill it up until it reaches a certain multiple of size (kernel parameter dependent). In
  // case nothing has to be done, these kernels can be skipped.
  if (!a_no_temp) {
    auto a_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
    auto a_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
    a_offsets_device.Write(queue_, batch_count, a_offsets);
    a_offsets_i_device.Write(queue_, batch_count, a_offsets_i);
    auto eventProcessA = Event();
    PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessA.pointer(), emptyEventList,
                                  a_one, a_two, a_ld, a_offsets_device, a_buffer,
                                  a_one_i, a_two_i, a_one_i, a_offsets_i_device, a_temp,
                                  program_, true, a_do_transpose, a_conjugate, batch_count);
    eventWaitList.push_back(eventProcessA);
  }

  // As above, but now for matrix B
  if (!b_no_temp) {
    auto b_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
    auto b_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
    b_offsets_device.Write(queue_, batch_count, b_offsets);
    b_offsets_i_device.Write(queue_, batch_count, b_offsets_i);
    auto eventProcessB = Event();
    PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessB.pointer(), emptyEventList,
                                  b_one, b_two, b_ld, b_offsets_device, b_buffer,
                                  b_one_i, b_two_i, b_one_i, b_offsets_i_device, b_temp,
                                  program_, true, b_do_transpose, b_conjugate, batch_count);
    eventWaitList.push_back(eventProcessB);
  }

  // As above, but now for matrix C
  auto c_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
  auto c_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
  if (!c_no_temp) {
    c_offsets_device.Write(queue_, batch_count, c_offsets);
    c_offsets_i_device.Write(queue_, batch_count, c_offsets_i);
    auto eventProcessC = Event();
    PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessC.pointer(), emptyEventList,
                                  c_one, c_two, c_ld, c_offsets_device, c_buffer,
                                  c_one_i, c_two_i, c_one_i, c_offsets_i_device, c_temp,
                                  program_, true, c_do_transpose, false, batch_count);
    eventWaitList.push_back(eventProcessC);
  }

  // Retrieves the Xgemm kernel from the compiled binary
  auto kernel = Kernel(program_, "XgemmBatched");

  // Sets the kernel arguments
  kernel.SetArgument(0, static_cast<int>(m_ceiled));
  kernel.SetArgument(1, static_cast<int>(n_ceiled));
  kernel.SetArgument(2, static_cast<int>(k_ceiled));
  kernel.SetArgument(3, alphas());
  kernel.SetArgument(4, betas());
  kernel.SetArgument(5, a_temp());
  kernel.SetArgument(6, static_cast<int>(a_one_i));
  kernel.SetArgument(7, static_cast<int>(a_two_i));
  kernel.SetArgument(8, b_temp());
  kernel.SetArgument(9, static_cast<int>(b_one_i));
  kernel.SetArgument(10, static_cast<int>(b_two_i));
  kernel.SetArgument(11, c_temp());
  kernel.SetArgument(12, static_cast<int>(c_one_i));
  kernel.SetArgument(13, static_cast<int>(c_two_i));

  // Computes the global and local thread sizes
  const auto global = std::vector<size_t>{
    (c_one_i * db_["MDIMC"]) / db_["MWG"],
    (c_two_i * db_["NDIMC"]) / db_["NWG"],
    batch_count
  };
  const auto local = std::vector<size_t>{db_["MDIMC"], db_["NDIMC"], 1};

  // Launches the kernel
  auto eventKernel = Event();
  auto eventPointer = eventKernel.pointer();
  RunKernel(kernel, queue_, device_, global, local, eventPointer, eventWaitList);

  // Runs the post-processing kernel if needed
  if (!c_no_temp) {
    eventWaitList.push_back(eventKernel);
    PadCopyTransposeMatrixBatched(queue_, device_, db_, event_, eventWaitList,
                                  c_one_i, c_two_i, c_one_i, c_offsets_i_device, c_temp,
                                  c_one, c_two, c_ld, c_offsets_device, c_buffer,
                                  program_, false, c_do_transpose, false, batch_count);
  }
}

// =================================================================================================

// The direct version of batched GEMM, requiring just one kernel, no pre or post-processing kernels.
template <typename T>
void XgemmBatched<T>::BatchedGemmDirect(const size_t m, const size_t n, const size_t k,
                                        const Buffer<T> &alphas,
                                        const Buffer<T> &a_buffer, const std::vector<int> &a_offsets, const size_t a_ld,
                                        const Buffer<T> &b_buffer, const std::vector<int> &b_offsets, const size_t b_ld,
                                        const Buffer<T> &betas,
                                        const Buffer<T> &c_buffer, const std::vector<int> &c_offsets, const size_t c_ld,
                                        const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
                                        const bool a_conjugate, const bool b_conjugate,
                                        const size_t batch_count) {

  // Uploads the offsets to the device
  auto a_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
  auto b_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
  auto c_offsets_device = Buffer<int>(context_, BufferAccess::kReadOnly, batch_count);
  a_offsets_device.Write(queue_, batch_count, a_offsets);
  b_offsets_device.Write(queue_, batch_count, b_offsets);
  c_offsets_device.Write(queue_, batch_count, c_offsets);

  // Retrieves the proper XgemmDirect kernel from the compiled binary
  const auto name = (a_do_transpose) ? (b_do_transpose ? "XgemmDirectBatchedTT" : "XgemmDirectBatchedTN") :
                                       (b_do_transpose ? "XgemmDirectBatchedNT" : "XgemmDirectBatchedNN");
  auto kernel = Kernel(program_, name);

  // Sets the kernel arguments
  kernel.SetArgument(0, static_cast<int>(m));
  kernel.SetArgument(1, static_cast<int>(n));
  kernel.SetArgument(2, static_cast<int>(k));
  kernel.SetArgument(3, alphas());
  kernel.SetArgument(4, betas());
  kernel.SetArgument(5, a_buffer());
  kernel.SetArgument(6, a_offsets_device());
  kernel.SetArgument(7, static_cast<int>(a_ld));
  kernel.SetArgument(8, b_buffer());
  kernel.SetArgument(9, b_offsets_device());
  kernel.SetArgument(10, static_cast<int>(b_ld));
  kernel.SetArgument(11, c_buffer());
  kernel.SetArgument(12, c_offsets_device());
  kernel.SetArgument(13, static_cast<int>(c_ld));
  kernel.SetArgument(14, static_cast<int>(c_do_transpose));
  kernel.SetArgument(15, static_cast<int>(a_conjugate));
  kernel.SetArgument(16, static_cast<int>(b_conjugate));

  // Computes the global and local thread sizes
  const auto m_ceiled = Ceil(m, db_["WGD"]);
  const auto n_ceiled = Ceil(n, db_["WGD"]);
  const auto global = std::vector<size_t>{
    (m_ceiled * db_["MDIMCD"]) / db_["WGD"],
    (n_ceiled * db_["NDIMCD"]) / db_["WGD"],
    batch_count
  };
  const auto local = std::vector<size_t>{db_["MDIMCD"], db_["NDIMCD"], 1};

  // Launches the kernel
  RunKernel(kernel, queue_, device_, global, local, event_);
}

// =================================================================================================

// Compiles the templated class
template class XgemmBatched<half>;
template class XgemmBatched<float>;
template class XgemmBatched<double>;
template class XgemmBatched<float2>;
template class XgemmBatched<double2>;

// =================================================================================================
} // namespace clblast
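From the user's perspective the new routine is reached through the `GemmBatched` entry point documented earlier: one alpha/beta pair and one offset per batch entry, with all matrices of a batch packed into a single buffer. A minimal sketch of such a call; the packing scheme and sizes are example assumptions, while the argument order follows the test wrapper at the bottom of this diff:

```
#include <vector>
#include <clblast.h>

// Sketch: three 64x64 single-precision GEMMs in one call (column-major, no transposes).
// Assumes 'queue' is a valid cl_command_queue and a_buf/b_buf/c_buf each hold 3*64*64 floats.
void ExampleGemmBatched(cl_command_queue queue, cl_mem a_buf, cl_mem b_buf, cl_mem c_buf) {
  const size_t m = 64, n = 64, k = 64, batch_count = 3;
  const std::vector<float> alphas = {1.0f, 2.0f, 3.0f};   // one scalar per batch entry
  const std::vector<float> betas  = {0.0f, 0.0f, 0.0f};
  std::vector<size_t> a_offsets(batch_count), b_offsets(batch_count), c_offsets(batch_count);
  for (size_t batch = 0; batch < batch_count; ++batch) {
    a_offsets[batch] = batch * m * k;                     // each matrix lives in its own slice
    b_offsets[batch] = batch * k * n;
    c_offsets[batch] = batch * m * n;
  }
  cl_event event = nullptr;
  const auto status = clblast::GemmBatched(clblast::Layout::kColMajor,
                                           clblast::Transpose::kNo, clblast::Transpose::kNo,
                                           m, n, k, alphas.data(),
                                           a_buf, a_offsets.data(), /*a_ld=*/m,
                                           b_buf, b_offsets.data(), /*b_ld=*/k,
                                           betas.data(),
                                           c_buf, c_offsets.data(), /*c_ld=*/m,
                                           batch_count, &queue, &event);
  if (status == clblast::StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
}
```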
72 src/routines/levelx/xgemmbatched.hpp Normal file
@ -0,0 +1,72 @@

// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
//   Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements the XgemmBatched routine. This is a non-blas batched version of GEMM.
//
// =================================================================================================

#ifndef CLBLAST_ROUTINES_XGEMMBATCHED_H_
#define CLBLAST_ROUTINES_XGEMMBATCHED_H_

#include <vector>

#include "routine.hpp"

namespace clblast {
// =================================================================================================

// See comment at top of file for a description of the class
template <typename T>
class XgemmBatched: public Routine {
 public:

  // Constructor
  XgemmBatched(Queue &queue, EventPointer event, const std::string &name = "GEMMBATCHED");

  // Templated-precision implementation of the routine
  void DoGemmBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
                     const size_t m, const size_t n, const size_t k,
                     const std::vector<T> &alphas,
                     const Buffer<T> & a_buffer, const std::vector<size_t> &a_offsets, const size_t a_ld,
                     const Buffer<T> & b_buffer, const std::vector<size_t> &b_offsets, const size_t b_ld,
                     const std::vector<T> &betas,
                     const Buffer<T> & c_buffer, const std::vector<size_t> &c_offsets, const size_t c_ld,
                     const size_t batch_count);

  // Indirect version of batched GEMM (with pre and post-processing kernels)
  void BatchedGemmIndirect(const size_t m, const size_t n, const size_t k,
                           const Buffer<T> &alphas,
                           const Buffer<T> &a_buffer, const std::vector<int> &a_offsets, const size_t a_ld,
                           const Buffer<T> &b_buffer, const std::vector<int> &b_offsets, const size_t b_ld,
                           const Buffer<T> &betas,
                           const Buffer<T> &c_buffer, const std::vector<int> &c_offsets, const size_t c_ld,
                           const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
                           const bool a_conjugate, const bool b_conjugate,
                           const size_t a_one, const size_t a_two, const bool a_want_rotated,
                           const size_t b_one, const size_t b_two, const bool b_want_rotated,
                           const size_t c_one, const size_t c_two, const bool c_want_rotated,
                           const size_t batch_count);

  // Direct version of batched GEMM (no pre and post-processing kernels)
  void BatchedGemmDirect(const size_t m, const size_t n, const size_t k,
                         const Buffer<T> &alphas,
                         const Buffer<T> &a_buffer, const std::vector<int> &a_offsets, const size_t a_ld,
                         const Buffer<T> &b_buffer, const std::vector<int> &b_offsets, const size_t b_ld,
                         const Buffer<T> &betas,
                         const Buffer<T> &c_buffer, const std::vector<int> &c_offsets, const size_t c_ld,
                         const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
                         const bool a_conjugate, const bool b_conjugate,
                         const size_t batch_count);
};

// =================================================================================================
} // namespace clblast

// CLBLAST_ROUTINES_XGEMMBATCHED_H_
#endif
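For completeness, a hedged sketch of how this class is meant to be driven internally, roughly what the top-level dispatch code does for every routine; the `WaitForCompletion()` helper on `Event` is an assumption about the event wrapper, and the column-major/no-transpose setup and leading dimensions are illustrative only:

```
#include <vector>
// #include "routines/levelx/xgemmbatched.hpp"   // internal CLBlast header, shown above

template <typename T>
void RunXgemmBatched(clblast::Queue &queue,
                     const clblast::Buffer<T> &a, const clblast::Buffer<T> &b,
                     const clblast::Buffer<T> &c,
                     const std::vector<T> &alphas, const std::vector<T> &betas,
                     const std::vector<size_t> &a_offsets, const std::vector<size_t> &b_offsets,
                     const std::vector<size_t> &c_offsets,
                     const size_t m, const size_t n, const size_t k, const size_t batch_count) {
  auto event = clblast::Event();
  auto routine = clblast::XgemmBatched<T>(queue, event.pointer());
  routine.DoGemmBatched(clblast::Layout::kColMajor,
                        clblast::Transpose::kNo, clblast::Transpose::kNo,
                        m, n, k, alphas,
                        a, a_offsets, /*a_ld=*/m,
                        b, b_offsets, /*b_ld=*/k,
                        betas,
                        c, c_offsets, /*c_ld=*/m,
                        batch_count);   // throws BLASError on invalid arguments
  event.WaitForCompletion();            // assumed helper; the C API instead waits on the returned cl_event
}
```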
@ -46,6 +46,7 @@ class TuneCopy {
  static size_t DefaultM() { return 1024; }
  static size_t DefaultN() { return 1024; }
  static size_t DefaultK() { return 1; } // N/A for this kernel
  static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
  static double DefaultFraction() { return 1.0; } // N/A for this kernel
  static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

@ -46,6 +46,7 @@ class TunePad {
  static size_t DefaultM() { return 1024; }
  static size_t DefaultN() { return 1024; }
  static size_t DefaultK() { return 1; } // N/A for this kernel
  static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
  static double DefaultFraction() { return 1.0; } // N/A for this kernel
  static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

@ -46,6 +46,7 @@ class TuneTranspose {
  static size_t DefaultM() { return 1024; }
  static size_t DefaultN() { return 1024; }
  static size_t DefaultK() { return 1; } // N/A for this kernel
  static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
  static double DefaultFraction() { return 1.0; } // N/A for this kernel
  static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

@ -46,6 +46,7 @@ class TunePadTranspose {
  static size_t DefaultM() { return 1024; }
  static size_t DefaultN() { return 1024; }
  static size_t DefaultK() { return 1; } // N/A for this kernel
  static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
  static double DefaultFraction() { return 1.0; } // N/A for this kernel
  static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

@ -50,6 +50,7 @@ class TuneXaxpy {
  static size_t DefaultM() { return 1; } // N/A for this kernel
  static size_t DefaultN() { return 4096*1024; }
  static size_t DefaultK() { return 1; } // N/A for this kernel
  static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
  static double DefaultFraction() { return 1.0; } // N/A for this kernel
  static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

@ -46,6 +46,7 @@ class TuneXdot {
  static size_t DefaultM() { return 1; } // N/A for this kernel
  static size_t DefaultN() { return 2*1024*1024; }
  static size_t DefaultK() { return 1; } // N/A for this kernel
  static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
  static double DefaultFraction() { return 1.0; } // N/A for this kernel
  static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

@ -51,6 +51,7 @@ class TuneXgemm {
  static size_t DefaultM() { return 1024; }
  static size_t DefaultN() { return 1024; }
  static size_t DefaultK() { return 1024; }
  static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
  static double DefaultFraction() { return (V==1) ? 1.0 : 512.0; } // test all or sample randomly
  static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

@ -51,6 +51,7 @@ class TuneXgemmDirect {
  static size_t DefaultM() { return 256; }
  static size_t DefaultN() { return 256; }
  static size_t DefaultK() { return 256; }
  static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
  static double DefaultFraction() { return (V==1) ? 1.0 : 32.0; } // test all or sample randomly
  static size_t DefaultNumRuns() { return 4; } // run every kernel this many times for averaging

@ -49,6 +49,7 @@ class TuneXgemv {
  static size_t DefaultM() { return 2048; }
  static size_t DefaultN() { return 2048; }
  static size_t DefaultK() { return 1; } // N/A for this kernel
  static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
  static double DefaultFraction() { return 1.0; } // N/A for this kernel
  static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

@ -46,6 +46,7 @@ class TuneXger {
  static size_t DefaultM() { return 1024; }
  static size_t DefaultN() { return 1024; }
  static size_t DefaultK() { return 1; } // N/A for this kernel
  static size_t DefaultBatchCount() { return 1; } // N/A for this kernel
  static double DefaultFraction() { return 1.0; } // N/A for this kernel
  static size_t DefaultNumRuns() { return 2; } // run every kernel this many times for averaging

@ -47,6 +47,7 @@ void Tuner(int argc, char* argv[]) {
    if (o == kArgAlpha) { args.alpha = GetArgument(command_line_args, help, kArgAlpha, GetScalar<T>()); }
    if (o == kArgBeta) { args.beta = GetArgument(command_line_args, help, kArgBeta, GetScalar<T>()); }
    if (o == kArgFraction) { args.fraction = GetArgument(command_line_args, help, kArgFraction, C::DefaultFraction()); }
    if (o == kArgBatchCount) { args.batch_count = GetArgument(command_line_args, help, kArgBatchCount, C::DefaultBatchCount()); }
  }
  const auto num_runs = GetArgument(command_line_args, help, kArgNumRuns, C::DefaultNumRuns());

@ -158,6 +159,7 @@ void Tuner(int argc, char* argv[]) {
    if (o == kArgK) { metadata.push_back({"arg_k", std::to_string(args.k)}); }
    if (o == kArgAlpha) { metadata.push_back({"arg_alpha", ToString(args.alpha)}); }
    if (o == kArgBeta) { metadata.push_back({"arg_beta", ToString(args.beta)}); }
    if (o == kArgBatchCount) { metadata.push_back({"arg_batch_count", ToString(args.batch_count)}); }
  }
  tuner.PrintJSON("clblast_"+C::KernelFamily()+"_"+precision_string+".json", metadata);
}
30 test/correctness/routines/levelx/xgemmbatched.cpp Normal file
@ -0,0 +1,30 @@

// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
//   Cedric Nugteren <www.cedricnugteren.nl>
//
// =================================================================================================

#include "test/correctness/testblas.hpp"
#include "test/routines/levelx/xgemmbatched.hpp"

// Shortcuts to the clblast namespace
using float2 = clblast::float2;
using double2 = clblast::double2;

// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
  auto errors = size_t{0};
  errors += clblast::RunTests<clblast::TestXgemmBatched<float>, float, float>(argc, argv, false, "SGEMMBATCHED");
  errors += clblast::RunTests<clblast::TestXgemmBatched<double>, double, double>(argc, argv, true, "DGEMMBATCHED");
  errors += clblast::RunTests<clblast::TestXgemmBatched<float2>, float2, float2>(argc, argv, true, "CGEMMBATCHED");
  errors += clblast::RunTests<clblast::TestXgemmBatched<double2>, double2, double2>(argc, argv, true, "ZGEMMBATCHED");
  errors += clblast::RunTests<clblast::TestXgemmBatched<half>, half, half>(argc, argv, true, "HGEMMBATCHED");
  if (errors > 0) { return 1; } else { return 0; }
}

// =================================================================================================
@ -21,6 +21,8 @@
namespace clblast {
// =================================================================================================

template <typename T, typename U> const auto TestBlas<T,U>::kSeed = 42; // fixed seed for reproducibility

// Test settings for the regular test. Append to these lists in case more tests are required.
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kVectorDims = { 7, 93, 4096 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kIncrements = { 1, 2, 7 };

@ -30,7 +30,7 @@ namespace clblast {
template <typename T, typename U>
class TestBlas: public Tester<T,U> {
 public:
  static constexpr auto kSeed = 42; // fixed seed for reproducibility
  static const int kSeed;

  // Uses several variables from the Tester class
  using Tester<T,U>::context_;

@ -24,6 +24,8 @@
namespace clblast {
// =================================================================================================

template <typename T, typename U> const int Client<T,U>::kSeed = 42; // fixed seed for reproducibility

// Constructor
template <typename T, typename U>
Client<T,U>::Client(const Routine run_routine,

@ -40,7 +40,7 @@ namespace clblast {
template <typename T, typename U>
class Client {
 public:
  static constexpr auto kSeed = 42; // fixed seed for reproducibility
  static const int kSeed;

  // Shorthand for the routine-specific functions passed to the tester
  using Routine = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>;
37 test/performance/routines/levelx/xgemmbatched.cpp Normal file
@ -0,0 +1,37 @@

// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
//   Cedric Nugteren <www.cedricnugteren.nl>
//
// =================================================================================================

#include "test/performance/client.hpp"
#include "test/routines/levelx/xgemmbatched.hpp"

// Shortcuts to the clblast namespace
using float2 = clblast::float2;
using double2 = clblast::double2;

// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
  const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
  switch(clblast::GetPrecision(command_line_args, clblast::Precision::kSingle)) {
    case clblast::Precision::kHalf:
      clblast::RunClient<clblast::TestXgemmBatched<half>, half, half>(argc, argv); break;
    case clblast::Precision::kSingle:
      clblast::RunClient<clblast::TestXgemmBatched<float>, float, float>(argc, argv); break;
    case clblast::Precision::kDouble:
      clblast::RunClient<clblast::TestXgemmBatched<double>, double, double>(argc, argv); break;
    case clblast::Precision::kComplexSingle:
      clblast::RunClient<clblast::TestXgemmBatched<float2>, float2, float2>(argc, argv); break;
    case clblast::Precision::kComplexDouble:
      clblast::RunClient<clblast::TestXgemmBatched<double2>, double2, double2>(argc, argv); break;
  }
  return 0;
}

// =================================================================================================
207 test/routines/levelx/xgemmbatched.hpp Normal file
@ -0,0 +1,207 @@

// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
//   Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements a class with static methods to describe the XgemmBatched routine. Examples of
// such 'descriptions' are how to calculate the size a of buffer or how to run the routine. These
// static methods are used by the correctness tester and the performance tester.
//
// =================================================================================================

#ifndef CLBLAST_TEST_ROUTINES_XGEMMBATCHED_H_
#define CLBLAST_TEST_ROUTINES_XGEMMBATCHED_H_

#include <vector>
#include <string>

#ifdef CLBLAST_REF_CLBLAS
  #include "test/wrapper_clblas.hpp"
#endif
#ifdef CLBLAST_REF_CBLAS
  #include "test/wrapper_cblas.hpp"
#endif

namespace clblast {
// =================================================================================================

// See comment at top of file for a description of the class
template <typename T>
class TestXgemmBatched {
 public:

  // Although it is a non-BLAS routine, it can still be tested against level-3 routines in a loop
  static size_t BLASLevel() { return 3; }

  // The list of arguments relevant for this routine
  static std::vector<std::string> GetOptions() {
    return {kArgM, kArgN, kArgK,
            kArgLayout, kArgATransp, kArgBTransp,
            kArgALeadDim, kArgBLeadDim, kArgCLeadDim,
            kArgAOffset, kArgBOffset, kArgCOffset,
            kArgBatchCount, kArgAlpha, kArgBeta};
  }

  // Helper for the sizes per batch
  static size_t PerBatchSizeA(const Arguments<T> &args) {
    auto a_rotated = (args.layout == Layout::kColMajor && args.a_transpose != Transpose::kNo) ||
                     (args.layout == Layout::kRowMajor && args.a_transpose == Transpose::kNo);
    auto a_two = (a_rotated) ? args.m : args.k;
    return a_two * args.a_ld;
  }
  static size_t PerBatchSizeB(const Arguments<T> &args) {
    auto b_rotated = (args.layout == Layout::kColMajor && args.b_transpose != Transpose::kNo) ||
                     (args.layout == Layout::kRowMajor && args.b_transpose == Transpose::kNo);
    auto b_two = (b_rotated) ? args.k : args.n;
    return b_two * args.b_ld;
  }
  static size_t PerBatchSizeC(const Arguments<T> &args) {
    auto c_rotated = (args.layout == Layout::kRowMajor);
    auto c_two = (c_rotated) ? args.m : args.n;
    return c_two * args.c_ld;
  }

  // Describes how to obtain the sizes of the buffers
  static size_t GetSizeA(const Arguments<T> &args) {
    return PerBatchSizeA(args) * args.batch_count + args.a_offset;
  }
  static size_t GetSizeB(const Arguments<T> &args) {
    return PerBatchSizeB(args) * args.batch_count + args.b_offset;
  }
  static size_t GetSizeC(const Arguments<T> &args) {
    return PerBatchSizeC(args) * args.batch_count + args.c_offset;
  }

  // Describes how to set the sizes of all the buffers
  static void SetSizes(Arguments<T> &args) {
    args.a_size = GetSizeA(args);
    args.b_size = GetSizeB(args);
    args.c_size = GetSizeC(args);

    // Also sets the batch-related variables
    args.a_offsets = std::vector<size_t>(args.batch_count);
    args.b_offsets = std::vector<size_t>(args.batch_count);
    args.c_offsets = std::vector<size_t>(args.batch_count);
    args.alphas = std::vector<T>(args.batch_count);
    args.betas = std::vector<T>(args.batch_count);
    for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
      args.a_offsets[batch] = batch * PerBatchSizeA(args) + args.a_offset;
      args.b_offsets[batch] = batch * PerBatchSizeB(args) + args.b_offset;
      args.c_offsets[batch] = batch * PerBatchSizeC(args) + args.c_offset;
      args.alphas[batch] = args.alpha + Constant<T>(batch);
      args.betas[batch] = args.beta + Constant<T>(batch);
    }
  }
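  // Worked example of the per-batch layout computed above (an illustration only, assuming a
  // column-major SGEMM without transposes, m = n = k = 64, a_ld = b_ld = c_ld = 64, zero base
  // offsets and batch_count = 3): PerBatchSizeA = k * a_ld = 4096, so a_offsets becomes
  // {0, 4096, 8192}, and the same holds for B and C. Each batch entry therefore occupies its own
  // contiguous 64x64 slice of the shared buffer, and GetSizeA returns 3 * 4096 elements. The
  // per-batch alphas/betas are also made distinct (alpha + batch) so the tests catch any mix-up
  // between batch entries.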
// Describes what the default values of the leading dimensions of the matrices are
|
||||||
|
static size_t DefaultLDA(const Arguments<T> &args) { return args.k; }
|
||||||
|
static size_t DefaultLDB(const Arguments<T> &args) { return args.n; }
|
||||||
|
static size_t DefaultLDC(const Arguments<T> &args) { return args.n; }
|
||||||
|
|
||||||
|
// Describes which transpose options are relevant for this routine
|
||||||
|
using Transposes = std::vector<Transpose>;
|
||||||
|
static Transposes GetATransposes(const Transposes &all) { return all; }
|
||||||
|
static Transposes GetBTransposes(const Transposes &all) { return all; }
|
||||||
|
|
||||||
|
// Describes how to prepare the input data
|
||||||
|
static void PrepareData(const Arguments<T>&, Queue&, const int, std::vector<T>&,
|
||||||
|
std::vector<T>&, std::vector<T>&, std::vector<T>&, std::vector<T>&,
|
||||||
|
std::vector<T>&, std::vector<T>&) {} // N/A for this routine
|
||||||
|
|
||||||
|
// Describes how to run the CLBlast routine
|
||||||
|
static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
|
||||||
|
auto queue_plain = queue();
|
||||||
|
auto event = cl_event{};
|
||||||
|
    auto status = GemmBatched(args.layout, args.a_transpose, args.b_transpose,
                              args.m, args.n, args.k, args.alphas.data(),
                              buffers.a_mat(), args.a_offsets.data(), args.a_ld,
                              buffers.b_mat(), args.b_offsets.data(), args.b_ld, args.betas.data(),
                              buffers.c_mat(), args.c_offsets.data(), args.c_ld,
                              args.batch_count,
                              &queue_plain, &event);
    if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
    return status;
  }

  // Describes how to run the clBLAS routine (for correctness/performance comparison)
  #ifdef CLBLAST_REF_CLBLAS
    static StatusCode RunReference1(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
      auto queue_plain = queue();
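      // The reference issues one regular clBLAS GEMM call per batch, re-using the per-batch
      // offsets and scalar values set up in SetSizes.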
      for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
        auto event = cl_event{};
        auto status = clblasXgemm(convertToCLBLAS(args.layout),
                                  convertToCLBLAS(args.a_transpose),
                                  convertToCLBLAS(args.b_transpose),
                                  args.m, args.n, args.k, args.alphas[batch],
                                  buffers.a_mat, args.a_offsets[batch], args.a_ld,
                                  buffers.b_mat, args.b_offsets[batch], args.b_ld, args.betas[batch],
                                  buffers.c_mat, args.c_offsets[batch], args.c_ld,
                                  1, &queue_plain, 0, nullptr, &event);
        clWaitForEvents(1, &event);
        if (static_cast<StatusCode>(status) != StatusCode::kSuccess) {
          return static_cast<StatusCode>(status);
        }
      }
      return StatusCode::kSuccess;
    }
  #endif

  // Describes how to run the CPU BLAS routine (for correctness/performance comparison)
  #ifdef CLBLAST_REF_CBLAS
    static StatusCode RunReference2(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
      std::vector<T> a_mat_cpu(args.a_size, static_cast<T>(0));
      std::vector<T> b_mat_cpu(args.b_size, static_cast<T>(0));
      std::vector<T> c_mat_cpu(args.c_size, static_cast<T>(0));
      buffers.a_mat.Read(queue, args.a_size, a_mat_cpu);
      buffers.b_mat.Read(queue, args.b_size, b_mat_cpu);
      buffers.c_mat.Read(queue, args.c_size, c_mat_cpu);
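      // The CPU reference also computes one GEMM per batch, operating on host copies of the
      // buffers; the result is written back to the device buffer afterwards.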
      for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
        cblasXgemm(convertToCBLAS(args.layout),
                   convertToCBLAS(args.a_transpose),
                   convertToCBLAS(args.b_transpose),
                   args.m, args.n, args.k, args.alphas[batch],
                   a_mat_cpu, args.a_offsets[batch], args.a_ld,
                   b_mat_cpu, args.b_offsets[batch], args.b_ld, args.betas[batch],
                   c_mat_cpu, args.c_offsets[batch], args.c_ld);
      }
      buffers.c_mat.Write(queue, args.c_size, c_mat_cpu);
      return StatusCode::kSuccess;
    }
  #endif

  // Describes how to download the results of the computation (more importantly: which buffer)
  static std::vector<T> DownloadResult(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
    std::vector<T> result(args.c_size, static_cast<T>(0));
    buffers.c_mat.Read(queue, args.c_size, result);
    return result;
  }

  // Describes how to compute the indices of the result buffer
  static size_t ResultID1(const Arguments<T> &args) { return args.m; }
  static size_t ResultID2(const Arguments<T> &args) { return args.n * args.batch_count; }
  static size_t GetResultIndex(const Arguments<T> &args, const size_t id1, const size_t id2_3) {
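    // The second result dimension packs the column and batch indices together (see ResultID2),
    // so it is split here into a column index (id2) and a batch index (id3).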
    const size_t id2 = id2_3 % args.n;
    const size_t id3 = id2_3 / args.n;
    return (args.layout == Layout::kRowMajor) ?
           id1*args.c_ld + id2 + args.c_offsets[id3]:
           id2*args.c_ld + id1 + args.c_offsets[id3];
  }

  // Describes how to compute performance metrics
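  // Each of the batch_count multiplications performs 2*m*n*k floating-point operations and
  // touches one m*k matrix, one k*n matrix and (read plus write) one m*n matrix.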
  static size_t GetFlops(const Arguments<T> &args) {
    return args.batch_count * (2 * args.m * args.n * args.k);
  }
  static size_t GetBytes(const Arguments<T> &args) {
    return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
  }
};

// =================================================================================================
} // namespace clblast

// CLBLAST_TEST_ROUTINES_XGEMMBATCHED_H_
#endif