Added API and tests for new GemmStridedBatched routine

pull/239/head
Cedric Nugteren 2018-01-07 14:27:15 +01:00
parent 0c48c6e6c4
commit 9fb2c61b25
17 changed files with 1182 additions and 80 deletions

View File

@ -202,7 +202,7 @@ set(LEVEL1_ROUTINES xswap xscal xcopy xaxpy xdot xdotu xdotc xnrm2 xasum xamax)
set(LEVEL2_ROUTINES xgemv xgbmv xhemv xhbmv xhpmv xsymv xsbmv xspmv xtrmv xtbmv xtpmv xtrsv set(LEVEL2_ROUTINES xgemv xgbmv xhemv xhbmv xhpmv xsymv xsbmv xspmv xtrmv xtbmv xtpmv xtrsv
xger xgeru xgerc xher xhpr xher2 xhpr2 xsyr xspr xsyr2 xspr2) xger xgeru xgerc xher xhpr xher2 xhpr2 xsyr xspr xsyr2 xspr2)
set(LEVEL3_ROUTINES xgemm xsymm xhemm xsyrk xherk xsyr2k xher2k xtrmm xtrsm) set(LEVEL3_ROUTINES xgemm xsymm xhemm xsyrk xherk xsyr2k xher2k xtrmm xtrsm)
set(LEVELX_ROUTINES xomatcopy xim2col xaxpybatched xgemmbatched) set(LEVELX_ROUTINES xomatcopy xim2col xaxpybatched xgemmbatched xgemmstridedbatched)
set(ROUTINES ${LEVEL1_ROUTINES} ${LEVEL2_ROUTINES} ${LEVEL3_ROUTINES} ${LEVELX_ROUTINES}) set(ROUTINES ${LEVEL1_ROUTINES} ${LEVEL2_ROUTINES} ${LEVEL3_ROUTINES} ${LEVELX_ROUTINES})
set(PRECISIONS 32 64 3232 6464 16) set(PRECISIONS 32 64 3232 6464 16)

View File

@ -3182,6 +3182,108 @@ Requirements for GEMMBATCHED:
xGEMMSTRIDEDBATCHED: StridedBatched version of GEMM
-------------
As GEMM, but multiple strided operations are batched together for better performance.
C++ API:
```
template <typename T>
StatusCode GemmStridedBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
const T alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const T beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
```
C API:
```
CLBlastStatusCode CLBlastSgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const float alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const float beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastDgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const double alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const double beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastCgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_float2 alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const cl_float2 beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastZgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_double2 alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const cl_double2 beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastHgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_half alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const cl_half beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
```
Arguments to GEMMSTRIDEDBATCHED:
* `const Layout layout`: Data-layout of the matrices, either `Layout::kRowMajor` (101) for row-major layout or `Layout::kColMajor` (102) for column-major data-layout.
* `const Transpose a_transpose`: Transposing the input matrix A, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.
* `const Transpose b_transpose`: Transposing the input matrix B, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.
* `const size_t m`: Integer size argument. This value must be positive.
* `const size_t n`: Integer size argument. This value must be positive.
* `const size_t k`: Integer size argument. This value must be positive.
* `const T alpha`: Input scalar constant.
* `const cl_mem a_buffer`: OpenCL buffer to store the input A matrix.
* `const size_t a_offset`: The offset in elements from the start of the input A matrix.
* `const size_t a_ld`: Leading dimension of the input A matrix. This value must be greater than 0.
* `const size_t a_stride`: The (fixed) stride between two batches of the A matrix.
* `const cl_mem b_buffer`: OpenCL buffer to store the input B matrix.
* `const size_t b_offset`: The offset in elements from the start of the input B matrix.
* `const size_t b_ld`: Leading dimension of the input B matrix. This value must be greater than 0.
* `const size_t b_stride`: The (fixed) stride between two batches of the B matrix.
* `const T beta`: Input scalar constant.
* `cl_mem c_buffer`: OpenCL buffer to store the output C matrix.
* `const size_t c_offset`: The offset in elements from the start of the output C matrix.
* `const size_t c_ld`: Leading dimension of the output C matrix. This value must be greater than 0.
* `const size_t c_stride`: The (fixed) stride between two batches of the C matrix.
* `const size_t batch_count`: Number of batches. This value must be positive.
* `cl_command_queue* queue`: Pointer to an OpenCL command queue associated with a context and device to execute the routine on.
* `cl_event* event`: Pointer to an OpenCL event to be able to wait for completion of the routine's OpenCL kernel(s). This is an optional argument.
Requirements for GEMMSTRIDEDBATCHED:
* When `transpose_a == Transpose::kNo`, then `a_ld` must be at least `m`, otherwise `a_ld` must be at least `k`.
* When `transpose_b == Transpose::kNo`, then `b_ld` must be at least `k`, otherwise `b_ld` must be at least `n`.
* The value of `c_ld` must be at least `m`.
ClearCache: Resets the cache of compiled binaries (auxiliary function) ClearCache: Resets the cache of compiled binaries (auxiliary function)
------------- -------------

View File

@ -647,6 +647,18 @@ StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const T
const size_t batch_count, const size_t batch_count,
cl_command_queue* queue, cl_event* event = nullptr); cl_command_queue* queue, cl_event* event = nullptr);
// StridedBatched version of GEMM: SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED
template <typename T>
StatusCode GemmStridedBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
const T alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const T beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event = nullptr);
// ================================================================================================= // =================================================================================================
// Retrieves the required size of the temporary buffer for the GEMM kernel (optional) // Retrieves the required size of the temporary buffer for the GEMM kernel (optional)

View File

@ -1451,6 +1451,53 @@ CLBlastStatusCode PUBLIC_API CLBlastHgemmBatched(const CLBlastLayout layout, con
const size_t batch_count, const size_t batch_count,
cl_command_queue* queue, cl_event* event); cl_command_queue* queue, cl_event* event);
// StridedBatched version of GEMM: SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED
CLBlastStatusCode PUBLIC_API CLBlastSgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const float alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const float beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastDgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const double alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const double beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastCgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_float2 alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const cl_float2 beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastZgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_double2 alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const cl_double2 beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastHgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_half alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const cl_half beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
// ================================================================================================= // =================================================================================================
// CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on // CLBlast stores binaries of compiled kernels into a cache in case the same kernel is used later on

View File

@ -619,6 +619,18 @@ StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const T
const size_t batch_count, const size_t batch_count,
const CUcontext context, const CUdevice device); const CUcontext context, const CUdevice device);
// StridedBatched version of GEMM: SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED
template <typename T>
StatusCode GemmStridedBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
const T alpha,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
const CUcontext context, const CUdevice device);
// ================================================================================================= // =================================================================================================
// Retrieves the required size of the temporary buffer for the GEMM kernel (optional) // Retrieves the required size of the temporary buffer for the GEMM kernel (optional)

View File

@ -109,71 +109,72 @@ col = "height * width * channels"
im2col_constants = ["channels", "height", "width", "kernel_h", "kernel_w", "pad_h", "pad_w", "stride_h", "stride_w", "dilation_h", "dilation_w"] im2col_constants = ["channels", "height", "width", "kernel_h", "kernel_w", "pad_h", "pad_w", "stride_h", "stride_w", "dilation_h", "dilation_w"]
ROUTINES = [ ROUTINES = [
[ # Level 1: vector-vector [ # Level 1: vector-vector
Routine(False, True, False, False, "1", "rotg", T, [S,D], [], [], [], ["sa","sb","sc","ss"], ["1","1","1","1"], [], "", "Generate givens plane rotation", "", []), Routine(False, True, 0, False, "1", "rotg", T, [S,D], [], [], [], ["sa","sb","sc","ss"], ["1","1","1","1"], [], "", "Generate givens plane rotation", "", []),
Routine(False, True, False, False, "1", "rotmg", T, [S,D], [], [], ["sy1"], ["sd1","sd2","sx1","sparam"], ["1","1","1","1","1"], [], "", "Generate modified givens plane rotation", "", []), Routine(False, True, 0, False, "1", "rotmg", T, [S,D], [], [], ["sy1"], ["sd1","sd2","sx1","sparam"], ["1","1","1","1","1"], [], "", "Generate modified givens plane rotation", "", []),
Routine(False, True, False, False, "1", "rot", T, [S,D], ["n"], [], [], ["x","y"], [xn,yn], ["cos","sin"],"", "Apply givens plane rotation", "", []), Routine(False, True, 0, False, "1", "rot", T, [S,D], ["n"], [], [], ["x","y"], [xn,yn], ["cos","sin"],"", "Apply givens plane rotation", "", []),
Routine(False, True, False, False, "1", "rotm", T, [S,D], ["n"], [], [], ["x","y","sparam"], [xn,yn,"1"], [], "", "Apply modified givens plane rotation", "", []), Routine(False, True, 0, False, "1", "rotm", T, [S,D], ["n"], [], [], ["x","y","sparam"], [xn,yn,"1"], [], "", "Apply modified givens plane rotation", "", []),
Routine(True, True, False, False, "1", "swap", T, [S,D,C,Z,H], ["n"], [], [], ["x","y"], [xn,yn], [], "", "Swap two vectors", "Interchanges _n_ elements of vectors _x_ and _y_.", []), Routine(True, True, 0, False, "1", "swap", T, [S,D,C,Z,H], ["n"], [], [], ["x","y"], [xn,yn], [], "", "Swap two vectors", "Interchanges _n_ elements of vectors _x_ and _y_.", []),
Routine(True, True, False, False, "1", "scal", T, [S,D,C,Z,H], ["n"], [], [], ["x"], [xn], ["alpha"], "", "Vector scaling", "Multiplies _n_ elements of vector _x_ by a scalar constant _alpha_.", []), Routine(True, True, 0, False, "1", "scal", T, [S,D,C,Z,H], ["n"], [], [], ["x"], [xn], ["alpha"], "", "Vector scaling", "Multiplies _n_ elements of vector _x_ by a scalar constant _alpha_.", []),
Routine(True, True, False, False, "1", "copy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], [], "", "Vector copy", "Copies the contents of vector _x_ into vector _y_.", []), Routine(True, True, 0, False, "1", "copy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], [], "", "Vector copy", "Copies the contents of vector _x_ into vector _y_.", []),
Routine(True, True, False, False, "1", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation _y = alpha * x + y_, in which _x_ and _y_ are vectors and _alpha_ is a scalar constant.", []), Routine(True, True, 0, False, "1", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation _y = alpha * x + y_, in which _x_ and _y_ are vectors and _alpha_ is a scalar constant.", []),
Routine(True, True, False, False, "1", "dot", T, [S,D,H], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two vectors", "Multiplies _n_ elements of the vectors _x_ and _y_ element-wise and accumulates the results. The sum is stored in the _dot_ buffer.", []), Routine(True, True, 0, False, "1", "dot", T, [S,D,H], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two vectors", "Multiplies _n_ elements of the vectors _x_ and _y_ element-wise and accumulates the results. The sum is stored in the _dot_ buffer.", []),
Routine(True, True, False, False, "1", "dotu", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two complex vectors", "See the regular xDOT routine.", []), Routine(True, True, 0, False, "1", "dotu", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two complex vectors", "See the regular xDOT routine.", []),
Routine(True, True, False, False, "1", "dotc", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two complex vectors, one conjugated", "See the regular xDOT routine.", []), Routine(True, True, 0, False, "1", "dotc", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [xn,yn,"1"], [], "n", "Dot product of two complex vectors, one conjugated", "See the regular xDOT routine.", []),
Routine(True, True, False, False, "1", "nrm2", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["nrm2"], [xn,"1"], [], "2*n", "Euclidian norm of a vector", "Accumulates the square of _n_ elements in the _x_ vector and takes the square root. The resulting L2 norm is stored in the _nrm2_ buffer.", []), Routine(True, True, 0, False, "1", "nrm2", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["nrm2"], [xn,"1"], [], "2*n", "Euclidian norm of a vector", "Accumulates the square of _n_ elements in the _x_ vector and takes the square root. The resulting L2 norm is stored in the _nrm2_ buffer.", []),
Routine(True, True, False, False, "1", "asum", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["asum"], [xn,"1"], [], "n", "Absolute sum of values in a vector", "Accumulates the absolute value of _n_ elements in the _x_ vector. The results are stored in the _asum_ buffer.", []), Routine(True, True, 0, False, "1", "asum", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["asum"], [xn,"1"], [], "n", "Absolute sum of values in a vector", "Accumulates the absolute value of _n_ elements in the _x_ vector. The results are stored in the _asum_ buffer.", []),
Routine(True, False, False, False, "1", "sum", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["sum"], [xn,"1"], [], "n", "Sum of values in a vector (non-BLAS function)", "Accumulates the values of _n_ elements in the _x_ vector. The results are stored in the _sum_ buffer. This routine is the non-absolute version of the xASUM BLAS routine.", []), Routine(True, False, 0, False, "1", "sum", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["sum"], [xn,"1"], [], "n", "Sum of values in a vector (non-BLAS function)", "Accumulates the values of _n_ elements in the _x_ vector. The results are stored in the _sum_ buffer. This routine is the non-absolute version of the xASUM BLAS routine.", []),
Routine(True, True, False, False, "1", "amax", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imax"], [xn,"1"], [], "2*n", "Index of absolute maximum value in a vector", "Finds the index of the maximum of the absolute values in the _x_ vector. The resulting integer index is stored in the _imax_ buffer.", []), Routine(True, True, 0, False, "1", "amax", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imax"], [xn,"1"], [], "2*n", "Index of absolute maximum value in a vector", "Finds the index of the maximum of the absolute values in the _x_ vector. The resulting integer index is stored in the _imax_ buffer.", []),
Routine(True, False, False, False, "1", "amin", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imin"], [xn,"1"], [], "2*n", "Index of absolute minimum value in a vector (non-BLAS function)", "Finds the index of the minimum of the absolute values in the _x_ vector. The resulting integer index is stored in the _imin_ buffer.", []), Routine(True, False, 0, False, "1", "amin", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imin"], [xn,"1"], [], "2*n", "Index of absolute minimum value in a vector (non-BLAS function)", "Finds the index of the minimum of the absolute values in the _x_ vector. The resulting integer index is stored in the _imin_ buffer.", []),
Routine(True, False, False, False, "1", "max", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imax"], [xn,"1"], [], "2*n", "Index of maximum value in a vector (non-BLAS function)", "Finds the index of the maximum of the values in the _x_ vector. The resulting integer index is stored in the _imax_ buffer. This routine is the non-absolute version of the IxAMAX BLAS routine.", []), Routine(True, False, 0, False, "1", "max", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imax"], [xn,"1"], [], "2*n", "Index of maximum value in a vector (non-BLAS function)", "Finds the index of the maximum of the values in the _x_ vector. The resulting integer index is stored in the _imax_ buffer. This routine is the non-absolute version of the IxAMAX BLAS routine.", []),
Routine(True, False, False, False, "1", "min", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imin"], [xn,"1"], [], "2*n", "Index of minimum value in a vector (non-BLAS function)", "Finds the index of the minimum of the values in the _x_ vector. The resulting integer index is stored in the _imin_ buffer. This routine is the non-absolute minimum version of the IxAMAX BLAS routine.", []), Routine(True, False, 0, False, "1", "min", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imin"], [xn,"1"], [], "2*n", "Index of minimum value in a vector (non-BLAS function)", "Finds the index of the minimum of the values in the _x_ vector. The resulting integer index is stored in the _imin_ buffer. This routine is the non-absolute minimum version of the IxAMAX BLAS routine.", []),
], ],
[ # Level 2: matrix-vector [ # Level 2: matrix-vector
Routine(True, True, False, False, "2a", "gemv", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a","x"], ["y"], [amn,xmn,ynm], ["alpha","beta"], "", "General matrix-vector multiplication", "Performs the operation _y = alpha * A * x + beta * y_, in which _x_ is an input vector, _y_ is an input and output vector, _A_ is an input matrix, and _alpha_ and _beta_ are scalars. The matrix _A_ can optionally be transposed before performing the operation.", [ald_m]), Routine(True, True, 0, False, "2a", "gemv", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a","x"], ["y"], [amn,xmn,ynm], ["alpha","beta"], "", "General matrix-vector multiplication", "Performs the operation _y = alpha * A * x + beta * y_, in which _x_ is an input vector, _y_ is an input and output vector, _A_ is an input matrix, and _alpha_ and _beta_ are scalars. The matrix _A_ can optionally be transposed before performing the operation.", [ald_m]),
Routine(True, True, False, False, "2a", "gbmv", T, [S,D,C,Z,H], ["m","n","kl","ku"], ["layout","a_transpose"], ["a","x"], ["y"], [amn,xmn,ynm], ["alpha","beta"], "", "General banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is banded instead.", [ald_kl_ku_one]), Routine(True, True, 0, False, "2a", "gbmv", T, [S,D,C,Z,H], ["m","n","kl","ku"], ["layout","a_transpose"], ["a","x"], ["y"], [amn,xmn,ynm], ["alpha","beta"], "", "General banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is banded instead.", [ald_kl_ku_one]),
Routine(True, True, False, False, "2a", "hemv", T, [C,Z], ["n"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Hermitian matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian matrix instead.", [ald_n]), Routine(True, True, 0, False, "2a", "hemv", T, [C,Z], ["n"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Hermitian matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian matrix instead.", [ald_n]),
Routine(True, True, False, False, "2a", "hbmv", T, [C,Z], ["n","k"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Hermitian banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian banded matrix instead.", [ald_k_one]), Routine(True, True, 0, False, "2a", "hbmv", T, [C,Z], ["n","k"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Hermitian banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian banded matrix instead.", [ald_k_one]),
Routine(True, True, False, False, "2a", "hpmv", T, [C,Z], ["n"], ["layout","triangle"], ["ap","x"], ["y"], [apn,xn,yn], ["alpha","beta"], "", "Hermitian packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []), Routine(True, True, 0, False, "2a", "hpmv", T, [C,Z], ["n"], ["layout","triangle"], ["ap","x"], ["y"], [apn,xn,yn], ["alpha","beta"], "", "Hermitian packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
Routine(True, True, False, False, "2a", "symv", T, [S,D,H], ["n"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Symmetric matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is symmetric instead.", [ald_n]), Routine(True, True, 0, False, "2a", "symv", T, [S,D,H], ["n"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Symmetric matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is symmetric instead.", [ald_n]),
Routine(True, True, False, False, "2a", "sbmv", T, [S,D,H], ["n","k"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Symmetric banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is symmetric and banded instead.", [ald_k_one]), Routine(True, True, 0, False, "2a", "sbmv", T, [S,D,H], ["n","k"], ["layout","triangle"], ["a","x"], ["y"], [an,xn,yn], ["alpha","beta"], "", "Symmetric banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is symmetric and banded instead.", [ald_k_one]),
Routine(True, True, False, False, "2a", "spmv", T, [S,D,H], ["n"], ["layout","triangle"], ["ap","x"], ["y"], [apn,xn,yn], ["alpha","beta"], "", "Symmetric packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []), Routine(True, True, 0, False, "2a", "spmv", T, [S,D,H], ["n"], ["layout","triangle"], ["ap","x"], ["y"], [apn,xn,yn], ["alpha","beta"], "", "Symmetric packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
Routine(True, True, False, False, "2a", "trmv", T, [S,D,C,Z,H], ["n"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "n", "Triangular matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is triangular instead.", [ald_n]), Routine(True, True, 0, False, "2a", "trmv", T, [S,D,C,Z,H], ["n"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "n", "Triangular matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is triangular instead.", [ald_n]),
Routine(True, True, False, False, "2a", "tbmv", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "n", "Triangular banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is triangular and banded instead.", [ald_k_one]), Routine(True, True, 0, False, "2a", "tbmv", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "n", "Triangular banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is triangular and banded instead.", [ald_k_one]),
Routine(True, True, False, False, "2a", "tpmv", T, [S,D,C,Z,H], ["n"], ["layout","triangle","a_transpose","diagonal"], ["ap"], ["x"], [apn,xn], [], "n", "Triangular packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is a triangular packed matrix instead and repreented as _AP_.", []), Routine(True, True, 0, False, "2a", "tpmv", T, [S,D,C,Z,H], ["n"], ["layout","triangle","a_transpose","diagonal"], ["ap"], ["x"], [apn,xn], [], "n", "Triangular packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is a triangular packed matrix instead and repreented as _AP_.", []),
Routine(True, True, False, False, "2a", "trsv", T, [S,D,C,Z], ["n"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "", "Solves a triangular system of equations", "", []), Routine(True, True, 0, False, "2a", "trsv", T, [S,D,C,Z], ["n"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "", "Solves a triangular system of equations", "", []),
Routine(False, True, False, False, "2a", "tbsv", T, [S,D,C,Z], ["n","k"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "", "Solves a banded triangular system of equations", "", [ald_k_one]), Routine(False, True, 0, False, "2a", "tbsv", T, [S,D,C,Z], ["n","k"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [an,xn], [], "", "Solves a banded triangular system of equations", "", [ald_k_one]),
Routine(False, True, False, False, "2a", "tpsv", T, [S,D,C,Z], ["n"], ["layout","triangle","a_transpose","diagonal"], ["ap"], ["x"], [apn,xn], [], "", "Solves a packed triangular system of equations", "", []), Routine(False, True, 0, False, "2a", "tpsv", T, [S,D,C,Z], ["n"], ["layout","triangle","a_transpose","diagonal"], ["ap"], ["x"], [apn,xn], [], "", "Solves a packed triangular system of equations", "", []),
# Level 2: matrix update # Level 2: matrix update
Routine(True, True, False, False, "2b", "ger", T, [S,D,H], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 matrix update", "Performs the operation _A = alpha * x * y^T + A_, in which _x_ is an input vector, _y^T_ is the transpose of the input vector _y_, _A_ is the matrix to be updated, and _alpha_ is a scalar value.", [ald_m]), Routine(True, True, 0, False, "2b", "ger", T, [S,D,H], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 matrix update", "Performs the operation _A = alpha * x * y^T + A_, in which _x_ is an input vector, _y^T_ is the transpose of the input vector _y_, _A_ is the matrix to be updated, and _alpha_ is a scalar value.", [ald_m]),
Routine(True, True, False, False, "2b", "geru", T, [C,Z], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 complex matrix update", "Same operation as xGER, but with complex data-types.", [ald_m]), Routine(True, True, 0, False, "2b", "geru", T, [C,Z], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 complex matrix update", "Same operation as xGER, but with complex data-types.", [ald_m]),
Routine(True, True, False, False, "2b", "gerc", T, [C,Z], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 complex conjugated matrix update", "Same operation as xGERU, but the update is done based on the complex conjugate of the input vectors.", [ald_m]), Routine(True, True, 0, False, "2b", "gerc", T, [C,Z], ["m","n"], ["layout"], ["x","y"], ["a"], [xm,yn,amn], ["alpha"], "", "General rank-1 complex conjugated matrix update", "Same operation as xGERU, but the update is done based on the complex conjugate of the input vectors.", [ald_m]),
Routine(True, True, False, False, "2b", "her", Tc, [Css,Zdd], ["n"], ["layout","triangle"], ["x"], ["a"], [xn,an], ["alpha"], "", "Hermitian rank-1 matrix update", "Performs the operation _A = alpha * x * x^T + A_, in which x is an input vector, x^T is the transpose of this vector, _A_ is the triangular Hermetian matrix to be updated, and alpha is a scalar value.", [ald_n]), Routine(True, True, 0, False, "2b", "her", Tc, [Css,Zdd], ["n"], ["layout","triangle"], ["x"], ["a"], [xn,an], ["alpha"], "", "Hermitian rank-1 matrix update", "Performs the operation _A = alpha * x * x^T + A_, in which x is an input vector, x^T is the transpose of this vector, _A_ is the triangular Hermetian matrix to be updated, and alpha is a scalar value.", [ald_n]),
Routine(True, True, False, False, "2b", "hpr", Tc, [Css,Zdd], ["n"], ["layout","triangle"], ["x"], ["ap"], [xn,apn], ["alpha"], "", "Hermitian packed rank-1 matrix update", "Same operation as xHER, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []), Routine(True, True, 0, False, "2b", "hpr", Tc, [Css,Zdd], ["n"], ["layout","triangle"], ["x"], ["ap"], [xn,apn], ["alpha"], "", "Hermitian packed rank-1 matrix update", "Same operation as xHER, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
Routine(True, True, False, False, "2b", "her2", T, [C,Z], ["n"], ["layout","triangle"], ["x","y"], ["a"], [xn,yn,an], ["alpha"], "", "Hermitian rank-2 matrix update", "Performs the operation _A = alpha * x * y^T + conj(alpha) * y * x^T + A_, in which _x_ is an input vector and _x^T_ its transpose, _y_ is an input vector and _y^T_ its transpose, _A_ is the triangular Hermetian matrix to be updated, _alpha_ is a scalar value and _conj(alpha)_ its complex conjugate.", [ald_n]), Routine(True, True, 0, False, "2b", "her2", T, [C,Z], ["n"], ["layout","triangle"], ["x","y"], ["a"], [xn,yn,an], ["alpha"], "", "Hermitian rank-2 matrix update", "Performs the operation _A = alpha * x * y^T + conj(alpha) * y * x^T + A_, in which _x_ is an input vector and _x^T_ its transpose, _y_ is an input vector and _y^T_ its transpose, _A_ is the triangular Hermetian matrix to be updated, _alpha_ is a scalar value and _conj(alpha)_ its complex conjugate.", [ald_n]),
Routine(True, True, False, False, "2b", "hpr2", T, [C,Z], ["n"], ["layout","triangle"], ["x","y"], ["ap"], [xn,yn,apn], ["alpha"], "", "Hermitian packed rank-2 matrix update", "Same operation as xHER2, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []), Routine(True, True, 0, False, "2b", "hpr2", T, [C,Z], ["n"], ["layout","triangle"], ["x","y"], ["ap"], [xn,yn,apn], ["alpha"], "", "Hermitian packed rank-2 matrix update", "Same operation as xHER2, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
Routine(True, True, False, False, "2b", "syr", T, [S,D,H], ["n"], ["layout","triangle"], ["x"], ["a"], [xn,an], ["alpha"], "", "Symmetric rank-1 matrix update", "Same operation as xHER, but matrix A is a symmetric matrix instead.", [ald_n]), Routine(True, True, 0, False, "2b", "syr", T, [S,D,H], ["n"], ["layout","triangle"], ["x"], ["a"], [xn,an], ["alpha"], "", "Symmetric rank-1 matrix update", "Same operation as xHER, but matrix A is a symmetric matrix instead.", [ald_n]),
Routine(True, True, False, False, "2b", "spr", T, [S,D,H], ["n"], ["layout","triangle"], ["x"], ["ap"], [xn,apn], ["alpha"], "", "Symmetric packed rank-1 matrix update", "Same operation as xSPR, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []), Routine(True, True, 0, False, "2b", "spr", T, [S,D,H], ["n"], ["layout","triangle"], ["x"], ["ap"], [xn,apn], ["alpha"], "", "Symmetric packed rank-1 matrix update", "Same operation as xSPR, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
Routine(True, True, False, False, "2b", "syr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["a"], [xn,yn,an], ["alpha"], "", "Symmetric rank-2 matrix update", "Same operation as xHER2, but matrix _A_ is a symmetric matrix instead.", [ald_n]), Routine(True, True, 0, False, "2b", "syr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["a"], [xn,yn,an], ["alpha"], "", "Symmetric rank-2 matrix update", "Same operation as xHER2, but matrix _A_ is a symmetric matrix instead.", [ald_n]),
Routine(True, True, False, False, "2b", "spr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["ap"], [xn,yn,apn], ["alpha"], "", "Symmetric packed rank-2 matrix update", "Same operation as xSPR2, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []), Routine(True, True, 0, False, "2b", "spr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["ap"], [xn,yn,apn], ["alpha"], "", "Symmetric packed rank-2 matrix update", "Same operation as xSPR2, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
], ],
[ # Level 3: matrix-matrix [ # Level 3: matrix-matrix
Routine(True, True, False, True, "3", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "General matrix-matrix multiplication", "Performs the matrix product _C = alpha * A * B + beta * C_, in which _A_ (_m_ by _k_) and _B_ (_k_ by _n_) are two general rectangular input matrices, _C_ (_m_ by _n_) is the matrix to be updated, and _alpha_ and _beta_ are scalar values. The matrices _A_ and/or _B_ can optionally be transposed before performing the operation.", [ald_transa_m_k, bld_transb_k_n, cld_m]), Routine(True, True, 0, True, "3", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "General matrix-matrix multiplication", "Performs the matrix product _C = alpha * A * B + beta * C_, in which _A_ (_m_ by _k_) and _B_ (_k_ by _n_) are two general rectangular input matrices, _C_ (_m_ by _n_) is the matrix to be updated, and _alpha_ and _beta_ are scalar values. The matrices _A_ and/or _B_ can optionally be transposed before performing the operation.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
Routine(True, True, False, False, "3", "symm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], [ammn,bmnn,cmn], ["alpha","beta"], "", "Symmetric matrix-matrix multiplication", "Same operation as xGEMM, but _A_ is symmetric instead. In case of `side == kLeft`, _A_ is a symmetric _m_ by _m_ matrix and _C = alpha * A * B + beta * C_ is performed. Otherwise, in case of `side == kRight`, _A_ is a symmtric _n_ by _n_ matrix and _C = alpha * B * A + beta * C_ is performed.", [ald_side_m_n, bld_m, cld_m]), Routine(True, True, 0, False, "3", "symm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], [ammn,bmnn,cmn], ["alpha","beta"], "", "Symmetric matrix-matrix multiplication", "Same operation as xGEMM, but _A_ is symmetric instead. In case of `side == kLeft`, _A_ is a symmetric _m_ by _m_ matrix and _C = alpha * A * B + beta * C_ is performed. Otherwise, in case of `side == kRight`, _A_ is a symmtric _n_ by _n_ matrix and _C = alpha * B * A + beta * C_ is performed.", [ald_side_m_n, bld_m, cld_m]),
Routine(True, True, False, False, "3", "hemm", T, [C,Z], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], [ammn,bmnn,cmn], ["alpha","beta"], "", "Hermitian matrix-matrix multiplication", "Same operation as xSYMM, but _A_ is an Hermitian matrix instead.", [ald_side_m_n, bld_m, cld_m]), Routine(True, True, 0, False, "3", "hemm", T, [C,Z], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], [ammn,bmnn,cmn], ["alpha","beta"], "", "Hermitian matrix-matrix multiplication", "Same operation as xSYMM, but _A_ is an Hermitian matrix instead.", [ald_side_m_n, bld_m, cld_m]),
Routine(True, True, False, False, "3", "syrk", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], [ank,cn], ["alpha","beta"], "", "Rank-K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * A^T + beta * C_ or _C = alpha * A^T * A + beta * C_, in which _A_ is a general matrix and _A^T_ is its transpose, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, cld_m]), Routine(True, True, 0, False, "3", "syrk", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], [ank,cn], ["alpha","beta"], "", "Rank-K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * A^T + beta * C_ or _C = alpha * A^T * A + beta * C_, in which _A_ is a general matrix and _A^T_ is its transpose, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, cld_m]),
Routine(True, True, False, False, "3", "herk", Tc, [Css,Zdd], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], [ank,cn], ["alpha","beta"], "", "Rank-K update of a hermitian matrix", "Same operation as xSYRK, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, cld_m]), Routine(True, True, 0, False, "3", "herk", Tc, [Css,Zdd], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], [ank,cn], ["alpha","beta"], "", "Rank-K update of a hermitian matrix", "Same operation as xSYRK, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, cld_m]),
Routine(True, True, False, False, "3", "syr2k", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * B^T + alpha * B * A^T + beta * C_ or _C = alpha * A^T * B + alpha * B^T * A + beta * C_, in which _A_ and _B_ are general matrices and _A^T_ and _B^T_ are their transposed versions, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, bld_trans_n_k, cld_n]), Routine(True, True, 0, False, "3", "syr2k", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * B^T + alpha * B * A^T + beta * C_ or _C = alpha * A^T * B + alpha * B^T * A + beta * C_, in which _A_ and _B_ are general matrices and _A^T_ and _B^T_ are their transposed versions, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
Routine(True, True, False, False, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "Same operation as xSYR2K, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, bld_trans_n_k, cld_n]), Routine(True, True, 0, False, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "Same operation as xSYR2K, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
Routine(True, True, False, False, "3", "trmm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Triangular matrix-matrix multiplication", "Performs the matrix product _B = alpha * A * B_ or _B = alpha * B * A_, in which _A_ is a unit or non-unit triangular matrix, _B_ (_m_ by _n_) is the general matrix to be updated, and _alpha_ is a scalar value.", [ald_side_m_n, bld_m]), Routine(True, True, 0, False, "3", "trmm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Triangular matrix-matrix multiplication", "Performs the matrix product _B = alpha * A * B_ or _B = alpha * B * A_, in which _A_ is a unit or non-unit triangular matrix, _B_ (_m_ by _n_) is the general matrix to be updated, and _alpha_ is a scalar value.", [ald_side_m_n, bld_m]),
Routine(True, True, False, False, "3", "trsm", T, [S,D,C,Z], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Solves a triangular system of equations", "Solves the equation _A * X = alpha * B_ for the unknown _m_ by _n_ matrix X, in which _A_ is an _n_ by _n_ unit or non-unit triangular matrix and B is an _m_ by _n_ matrix. The matrix _B_ is overwritten by the solution _X_.", []), Routine(True, True, 0, False, "3", "trsm", T, [S,D,C,Z], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Solves a triangular system of equations", "Solves the equation _A * X = alpha * B_ for the unknown _m_ by _n_ matrix X, in which _A_ is an _n_ by _n_ unit or non-unit triangular matrix and B is an _m_ by _n_ matrix. The matrix _B_ is overwritten by the solution _X_.", []),
], ],
[ # Level X: extra routines (not part of BLAS) [ # Level X: extra routines (not part of BLAS)
# Special routines: # Special routines:
Routine(True, True, False, False, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]), Routine(True, True, 0, False, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
Routine(True, True, False, False, "x", "im2col", T, [S,D,C,Z,H], im2col_constants, [], ["im"], ["col"], [im,col], [""], "", "Im2col function (non-BLAS function)", "Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is the output matrix.", []), Routine(True, True, 0, False, "x", "im2col", T, [S,D,C,Z,H], im2col_constants, [], ["im"], ["col"], [im,col], [""], "", "Im2col function (non-BLAS function)", "Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is the output matrix.", []),
# Batched routines: # Batched routines:
Routine(True, True, True, False, "x", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Batched version of AXPY", "As AXPY, but multiple operations are batched together for better performance.", []), Routine(True, True, 1, False, "x", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Batched version of AXPY", "As AXPY, but multiple operations are batched together for better performance.", []),
Routine(True, True, True, False, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "Batched version of GEMM", "As GEMM, but multiple operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]), Routine(True, True, 1, False, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "Batched version of GEMM", "As GEMM, but multiple operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
Routine(True, True, 2, False, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "StridedBatched version of GEMM", "As GEMM, but multiple strided operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
]] ]]
@ -223,10 +224,10 @@ def main(argv):
if i == 6: if i == 6:
body += cpp.wrapper_cublas(routine) body += cpp.wrapper_cublas(routine)
if i == 7: if i == 7:
if not routine.batched: if routine.batched == 0:
body += cpp.clblast_netlib_c_h(routine) body += cpp.clblast_netlib_c_h(routine)
if i == 8: if i == 8:
if not routine.batched: if routine.batched == 0:
body += cpp.clblast_netlib_c_cc(routine) body += cpp.clblast_netlib_c_cc(routine)
if i == 9: if i == 9:
body += cpp.clblast_h(routine, cuda=True) body += cpp.clblast_h(routine, cuda=True)

View File

@ -58,7 +58,7 @@ def clblast_cc(routine, cuda=False):
result += " auto queue_cpp = Queue(*queue);" + NL result += " auto queue_cpp = Queue(*queue);" + NL
event = "nullptr" if cuda else "event" event = "nullptr" if cuda else "event"
result += " auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, " + event + ");" + NL result += " auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, " + event + ");" + NL
if routine.batched: if routine.batched == 1:
result += " " + (NL + " ").join(routine.batched_transform_to_cpp()) + NL result += " " + (NL + " ").join(routine.batched_transform_to_cpp()) + NL
if routine.temp_buffer: if routine.temp_buffer:
null = "0" if cuda else "nullptr" null = "0" if cuda else "nullptr"
@ -110,7 +110,7 @@ def clblast_c_cc(routine):
template = "<" + flavour.template + ">" if routine.no_scalars() else "" template = "<" + flavour.template + ">" if routine.no_scalars() else ""
indent = " " * (16 + routine.length() + len(template)) indent = " " * (16 + routine.length() + len(template))
result += routine.routine_header_c(flavour, 27, "") + " {" + NL result += routine.routine_header_c(flavour, 27, "") + " {" + NL
if routine.batched: if routine.batched == 1:
result += " " + (NL + " ").join(routine.batched_transform_to_complex(flavour)) + NL result += " " + (NL + " ").join(routine.batched_transform_to_complex(flavour)) + NL
result += " try {" + NL result += " try {" + NL
result += " return static_cast<CLBlastStatusCode>(" + NL result += " return static_cast<CLBlastStatusCode>(" + NL
@ -388,7 +388,7 @@ def performance_test(routine, level_string):
found = False found = False
for flavour in routine.flavours: for flavour in routine.flavours:
if flavour.precision_name == precision: if flavour.precision_name == precision:
extra_template_argument = "0, " if routine.name == "gemm" and not routine.batched else "" extra_template_argument = "0, " if routine.name == "gemm" and routine.batched == 0 else ""
result += NL + " clblast::RunClient<clblast::TestX" + routine.plain_name() result += NL + " clblast::RunClient<clblast::TestX" + routine.plain_name()
result += flavour.test_template(extra_template_argument) result += flavour.test_template(extra_template_argument)
result += ">(argc, argv); break;" + NL result += ">(argc, argv); break;" + NL
@ -410,7 +410,7 @@ def correctness_test(routine, level_string):
result += "int main(int argc, char *argv[]) {" + NL result += "int main(int argc, char *argv[]) {" + NL
result += " auto errors = size_t{0};" + NL result += " auto errors = size_t{0};" + NL
not_first = "false" not_first = "false"
extra_template_arguments = ["1, ", "2, "] if routine.name == "gemm" and not routine.batched else [""] extra_template_arguments = ["1, ", "2, "] if routine.name == "gemm" and routine.batched == 0 else [""]
for extra_template_argument in extra_template_arguments: for extra_template_argument in extra_template_arguments:
for flavour in routine.flavours: for flavour in routine.flavours:
result += " errors += clblast::RunTests<clblast::TestX" + routine.plain_name() result += " errors += clblast::RunTests<clblast::TestX" + routine.plain_name()

View File

@ -12,12 +12,12 @@ import generator.convert as convert
class Routine: class Routine:
"""Class holding routine-specific information (e.g. name, which arguments, which precisions)""" """Class holding routine-specific information (e.g. name, which arguments, which precisions)"""
def __init__(self, implemented, has_tests, batched, temp_buffer, level, name, template, flavours, sizes, options, def __init__(self, implemented, has_tests, batched_strided, temp_buffer, level, name, template, flavours, sizes, options,
inputs, outputs, buffer_sizes, scalars, scratch, inputs, outputs, buffer_sizes, scalars, scratch,
description, details, requirements): description, details, requirements):
self.implemented = implemented self.implemented = implemented
self.has_tests = has_tests self.has_tests = has_tests
self.batched = batched self.batched = batched_strided
self.temp_buffer = temp_buffer self.temp_buffer = temp_buffer
self.level = level self.level = level
self.name = name self.name = name
@ -35,38 +35,42 @@ class Routine:
self.requirements = requirements self.requirements = requirements
def lowercase_name(self): def lowercase_name(self):
postfix = "batched" if self.batched else "" postfix = "strided" if self.batched == 2 else ""
postfix += "batched" if self.batched != 0 else ""
return self.name + postfix return self.name + postfix
def plain_name(self): def plain_name(self):
postfix = "Batched" if self.batched else "" postfix = "Strided" if self.batched == 2 else ""
postfix += "Batched" if self.batched != 0 else ""
return self.name + postfix return self.name + postfix
def capitalized_name(self): def capitalized_name(self):
postfix = "Batched" if self.batched else "" postfix = "Strided" if self.batched == 2 else ""
postfix += "Batched" if self.batched != 0 else ""
return self.name.capitalize() + postfix return self.name.capitalize() + postfix
def upper_name(self): def upper_name(self):
postfix = "BATCHED" if self.batched else "" postfix = "STRIDED" if self.batched == 2 else ""
postfix += "BATCHED" if self.batched != 0 else ""
return self.name.upper() + postfix return self.name.upper() + postfix
def b_star(self): def b_star(self):
return "*" if self.batched else "" return "*" if self.batched == 1 else ""
def b_s(self): def b_s(self):
return "s" if self.batched else "" return "s" if self.batched == 1 else ""
def batch_count_def(self): def batch_count_def(self):
return ["const size_t batch_count"] if self.batched else [] return ["const size_t batch_count"] if self.batched != 0 else []
def batch_count_list(self): def batch_count_list(self):
return ["batch_count"] if self.batched else [] return ["batch_count"] if self.batched != 0 else []
def batch_count_type(self): def batch_count_type(self):
return ["const size_t"] if self.batched else [] return ["const size_t"] if self.batched != 0 else []
def batch_count_doc(self): def batch_count_doc(self):
return ["`const size_t batch_count`: Number of batches. This value must be positive."] if self.batched else [] return ["`const size_t batch_count`: Number of batches. This value must be positive."] if self.batched != 0 else []
def batched_transform_to_cpp(self): def batched_transform_to_cpp(self):
result = [] result = []
@ -230,6 +234,8 @@ class Routine:
a = [name + "_buffer"] a = [name + "_buffer"]
b = [name + "_offset" + self.b_s()] b = [name + "_offset" + self.b_s()]
c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else [] c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else []
if self.batched == 2:
c += [name + "_stride"]
return [", ".join(a + b + c)] return [", ".join(a + b + c)]
return [] return []
@ -239,6 +245,8 @@ class Routine:
a = [name + "_buffer_bis"] a = [name + "_buffer_bis"]
b = [name + "_offset"] b = [name + "_offset"]
c = [name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else [] c = [name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
if self.batched == 2:
c += [name + "_stride"]
return [", ".join(a + b + c)] return [", ".join(a + b + c)]
return [] return []
@ -258,6 +266,8 @@ class Routine:
a = [prefix + "cl_mem " + name + "_buffer"] a = [prefix + "cl_mem " + name + "_buffer"]
b = ["const size_t " + self.b_star() + name + "_offset" + self.b_s()] b = ["const size_t " + self.b_star() + name + "_offset" + self.b_s()]
c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else [] c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
if self.batched == 2:
c += ["const size_t " + name + "_stride"]
return [", ".join(a + b + c)] return [", ".join(a + b + c)]
return [] return []
@ -307,8 +317,10 @@ class Routine:
if name in self.inputs or name in self.outputs: if name in self.inputs or name in self.outputs:
buffer_type = "unsigned int" if (name in self.index_buffers()) else self.template.buffer_type buffer_type = "unsigned int" if (name in self.index_buffers()) else self.template.buffer_type
a = ["Buffer<" + buffer_type + ">(" + name + "_buffer)"] a = ["Buffer<" + buffer_type + ">(" + name + "_buffer)"]
b = [name + "_offsets_cpp"] if self.batched else [name + "_offset"] b = [name + "_offsets_cpp"] if self.batched == 1 else [name + "_offset"]
c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else [] c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else []
if self.batched == 2:
c += [name + "_stride"]
return [", ".join(a + b + c)] return [", ".join(a + b + c)]
return [] return []
@ -375,6 +387,8 @@ class Routine:
a = [prefix + "cl_mem"] a = [prefix + "cl_mem"]
b = ["const size_t" + self.b_star()] b = ["const size_t" + self.b_star()]
c = ["const size_t"] if (name not in self.buffers_without_ld_inc()) else [] c = ["const size_t"] if (name not in self.buffers_without_ld_inc()) else []
if self.batched == 2:
c += ["const size_t"]
return [", ".join(a + b + c)] return [", ".join(a + b + c)]
return [] return []
@ -391,13 +405,15 @@ class Routine:
if name not in self.buffers_without_ld_inc(): if name not in self.buffers_without_ld_inc():
c = ["`const size_t " + name + "_" + self.postfix(name) + "`: " + c = ["`const size_t " + name + "_" + self.postfix(name) + "`: " +
inc_ld_description + "of the " + inout + " " + math_name + ". This value must be greater than 0."] inc_ld_description + "of the " + inout + " " + math_name + ". This value must be greater than 0."]
if self.batched == 2:
c += ["`const size_t " + name + "_stride`: The (fixed) stride between two batches of the " + name.upper() + " matrix."]
return a + b + c return a + b + c
return [] return []
def scalar(self, name): def scalar(self, name):
"""Retrieves the name of a scalar (alpha/beta)""" """Retrieves the name of a scalar (alpha/beta)"""
if name in self.scalars: if name in self.scalars:
if self.batched: if self.batched == 1:
return [name + "s_cpp"] return [name + "s_cpp"]
return [name] return [name]
return [] return []
@ -418,11 +434,11 @@ class Routine:
"""Retrieves the use of a scalar (alpha/beta)""" """Retrieves the use of a scalar (alpha/beta)"""
if name in self.scalars: if name in self.scalars:
if name == "alpha": if name == "alpha":
if self.batched: if self.batched == 1:
return ["alphas_cpp.data()"] return ["alphas_cpp.data()"]
return [flavour.use_alpha()] return [flavour.use_alpha()]
elif name == "beta": elif name == "beta":
if self.batched: if self.batched == 1:
return ["betas_cpp.data()"] return ["betas_cpp.data()"]
return [flavour.use_beta()] return [flavour.use_beta()]
return [name] return [name]
@ -866,7 +882,7 @@ class Routine:
if self.name in self.routines_scalar_no_return(): if self.name in self.routines_scalar_no_return():
routine_name += "_sub" routine_name += "_sub"
indent += " " indent += " "
if self.batched: if self.batched != 0:
routine_name += "batched" routine_name += "batched"
result = return_type + extra_qualifier + " cblas_" + flavour.name.lower() + routine_name + "(" result = return_type + extra_qualifier + " cblas_" + flavour.name.lower() + routine_name + "("
result += (",\n" + indent).join([a for a in self.arguments_def_netlib(flavour)]) + ")" result += (",\n" + indent).join([a for a in self.arguments_def_netlib(flavour)]) + ")"

View File

@ -2336,6 +2336,77 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
const size_t, const size_t,
cl_command_queue*, cl_event*); cl_command_queue*, cl_event*);
// StridedBatched version of GEMM: SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED
template <typename T>
StatusCode GemmStridedBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
const T alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const T beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
try {
auto queue_cpp = Queue(*queue);
auto routine = XgemmStridedBatched<T>(queue_cpp, event);
routine.DoGemmStridedBatched(layout, a_transpose, b_transpose,
m, n, k,
alpha,
Buffer<T>(a_buffer), a_offset, a_ld, a_stride,
Buffer<T>(b_buffer), b_offset, b_ld, b_stride,
beta,
Buffer<T>(c_buffer), c_offset, c_ld, c_stride,
batch_count);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API GemmStridedBatched<float>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const float,
const cl_mem, const size_t, const size_t, const size_t,
const cl_mem, const size_t, const size_t, const size_t,
const float,
cl_mem, const size_t, const size_t, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API GemmStridedBatched<double>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double,
const cl_mem, const size_t, const size_t, const size_t,
const cl_mem, const size_t, const size_t, const size_t,
const double,
cl_mem, const size_t, const size_t, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API GemmStridedBatched<float2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const float2,
const cl_mem, const size_t, const size_t, const size_t,
const cl_mem, const size_t, const size_t, const size_t,
const float2,
cl_mem, const size_t, const size_t, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API GemmStridedBatched<double2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double2,
const cl_mem, const size_t, const size_t, const size_t,
const cl_mem, const size_t, const size_t, const size_t,
const double2,
cl_mem, const size_t, const size_t, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API GemmStridedBatched<half>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const half,
const cl_mem, const size_t, const size_t, const size_t,
const cl_mem, const size_t, const size_t, const size_t,
const half,
cl_mem, const size_t, const size_t, const size_t,
const size_t,
cl_command_queue*, cl_event*);
// ================================================================================================= // =================================================================================================
// Retrieves the required size of the temporary buffer for the GEMM kernel (optional) // Retrieves the required size of the temporary buffer for the GEMM kernel (optional)

View File

@ -3846,6 +3846,133 @@ CLBlastStatusCode CLBlastHgemmBatched(const CLBlastLayout layout, const CLBlastT
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); } } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
} }
// GEMM
CLBlastStatusCode CLBlastSgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const float alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const float beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
clblast::GemmStridedBatched(static_cast<clblast::Layout>(layout),
static_cast<clblast::Transpose>(a_transpose),
static_cast<clblast::Transpose>(b_transpose),
m, n, k,
alpha,
a_buffer, a_offset, a_ld, a_stride,
b_buffer, b_offset, b_ld, b_stride,
beta,
c_buffer, c_offset, c_ld, c_stride,
batch_count,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastDgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const double alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const double beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
clblast::GemmStridedBatched(static_cast<clblast::Layout>(layout),
static_cast<clblast::Transpose>(a_transpose),
static_cast<clblast::Transpose>(b_transpose),
m, n, k,
alpha,
a_buffer, a_offset, a_ld, a_stride,
b_buffer, b_offset, b_ld, b_stride,
beta,
c_buffer, c_offset, c_ld, c_stride,
batch_count,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastCgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_float2 alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const cl_float2 beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
clblast::GemmStridedBatched(static_cast<clblast::Layout>(layout),
static_cast<clblast::Transpose>(a_transpose),
static_cast<clblast::Transpose>(b_transpose),
m, n, k,
float2{alpha.s[0], alpha.s[1]},
a_buffer, a_offset, a_ld, a_stride,
b_buffer, b_offset, b_ld, b_stride,
float2{beta.s[0], beta.s[1]},
c_buffer, c_offset, c_ld, c_stride,
batch_count,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastZgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_double2 alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const cl_double2 beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
clblast::GemmStridedBatched(static_cast<clblast::Layout>(layout),
static_cast<clblast::Transpose>(a_transpose),
static_cast<clblast::Transpose>(b_transpose),
m, n, k,
double2{alpha.s[0], alpha.s[1]},
a_buffer, a_offset, a_ld, a_stride,
b_buffer, b_offset, b_ld, b_stride,
double2{beta.s[0], beta.s[1]},
c_buffer, c_offset, c_ld, c_stride,
batch_count,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastHgemmStridedBatched(const CLBlastLayout layout, const CLBlastTranspose a_transpose, const CLBlastTranspose b_transpose,
const size_t m, const size_t n, const size_t k,
const cl_half alpha,
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const cl_half beta,
cl_mem c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
clblast::GemmStridedBatched(static_cast<clblast::Layout>(layout),
static_cast<clblast::Transpose>(a_transpose),
static_cast<clblast::Transpose>(b_transpose),
m, n, k,
alpha,
a_buffer, a_offset, a_ld, a_stride,
b_buffer, b_offset, b_ld, b_stride,
beta,
c_buffer, c_offset, c_ld, c_stride,
batch_count,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
// ================================================================================================= // =================================================================================================
// Clears the cache of stored binaries // Clears the cache of stored binaries

View File

@ -2436,6 +2436,79 @@ template StatusCode PUBLIC_API GemmBatched<half>(const Layout, const Transpose,
const size_t, const size_t,
const CUcontext, const CUdevice); const CUcontext, const CUdevice);
// StridedBatched version of GEMM: SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED
template <typename T>
StatusCode GemmStridedBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k,
const T alpha,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride,
const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count,
const CUcontext context, const CUdevice device) {
try {
const auto context_cpp = Context(context);
const auto device_cpp = Device(device);
auto queue_cpp = Queue(context_cpp, device_cpp);
auto routine = XgemmStridedBatched<T>(queue_cpp, nullptr);
routine.DoGemmStridedBatched(layout, a_transpose, b_transpose,
m, n, k,
alpha,
Buffer<T>(a_buffer), a_offset, a_ld, a_stride,
Buffer<T>(b_buffer), b_offset, b_ld, b_stride,
beta,
Buffer<T>(c_buffer), c_offset, c_ld, c_stride,
batch_count);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API GemmStridedBatched<float>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const float,
const CUdeviceptr, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t, const size_t,
const float,
CUdeviceptr, const size_t, const size_t, const size_t,
const size_t,
const CUcontext, const CUdevice);
template StatusCode PUBLIC_API GemmStridedBatched<double>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double,
const CUdeviceptr, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t, const size_t,
const double,
CUdeviceptr, const size_t, const size_t, const size_t,
const size_t,
const CUcontext, const CUdevice);
template StatusCode PUBLIC_API GemmStridedBatched<float2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const float2,
const CUdeviceptr, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t, const size_t,
const float2,
CUdeviceptr, const size_t, const size_t, const size_t,
const size_t,
const CUcontext, const CUdevice);
template StatusCode PUBLIC_API GemmStridedBatched<double2>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const double2,
const CUdeviceptr, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t, const size_t,
const double2,
CUdeviceptr, const size_t, const size_t, const size_t,
const size_t,
const CUcontext, const CUdevice);
template StatusCode PUBLIC_API GemmStridedBatched<half>(const Layout, const Transpose, const Transpose,
const size_t, const size_t, const size_t,
const half,
const CUdeviceptr, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t, const size_t, const size_t,
const half,
CUdeviceptr, const size_t, const size_t, const size_t,
const size_t,
const CUcontext, const CUdevice);
// ================================================================================================= // =================================================================================================
// Retrieves the required size of the temporary buffer for the GEMM kernel (optional) // Retrieves the required size of the temporary buffer for the GEMM kernel (optional)

View File

@ -0,0 +1,297 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements the XgemmStridedBatched class (see the header for information about the class).
//
// =================================================================================================
#include "routines/levelx/xgemmstridedbatched.hpp"
#include "routines/level3/xgemm.hpp"
#include <string>
#include <vector>
namespace clblast {
// =================================================================================================
// Constructor: forwards to base class constructor
template <typename T>
XgemmStridedBatched<T>::XgemmStridedBatched(Queue &queue, EventPointer event, const std::string &name):
Routine(queue, event, name, {"Copy","Pad","Transpose","Padtranspose","Xgemm","XgemmDirect","GemmRoutine"},
PrecisionValue<T>(), {}, {
#include "../../kernels/level3/level3.opencl"
#include "../../kernels/level3/copy_fast.opencl"
#include "../../kernels/level3/copy_pad.opencl"
#include "../../kernels/level3/transpose_fast.opencl"
#include "../../kernels/level3/transpose_pad.opencl"
, // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_direct_part1.opencl"
#include "../../kernels/level3/xgemm_direct_part2.opencl"
#include "../../kernels/level3/xgemm_direct_part3.opencl"
, // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_part1.opencl"
#include "../../kernels/level3/xgemm_part2.opencl"
#include "../../kernels/level3/xgemm_part3.opencl"
#include "../../kernels/level3/xgemm_part4.opencl"
, // separated in multiple parts to prevent C1091 in MSVC 2013
#include "../../kernels/level3/xgemm_batched.opencl"
#include "../../kernels/level3/xgemm_direct_batched.opencl"
}) {
}
// =================================================================================================
// The main routine
template <typename T>
void XgemmStridedBatched<T>::DoGemmStridedBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k, const T alpha,
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const T beta,
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count) {
// Tests for a valid batch count
if (batch_count < 1) {
throw BLASError(StatusCode::kInvalidBatchCount);
}
// Computes the transpose/conjugate options and sets the a/b/c sizes based on that
bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
size_t a_one, a_two, b_one, b_two, c_one, c_two;
Xgemm<T>::ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
a_one, a_two, b_one, b_two, c_one, c_two,
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
// Tests the matrices for validity
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
TestMatrixA(a_one, a_two, a_buffer, a_offset + a_stride * batch, a_ld);
TestMatrixB(b_one, b_two, b_buffer, b_offset + b_stride * batch, b_ld);
TestMatrixC(c_one, c_two, c_buffer, c_offset + c_stride * batch, c_ld);
}
// Selects which version of the batched GEMM to run
const auto do_gemm_direct = true;
if (do_gemm_direct) { // single generic kernel
BatchedGemmDirect(m, n, k, alpha,
a_buffer, a_offset, a_ld, a_stride,
b_buffer, b_offset, b_ld, b_stride, beta,
c_buffer, c_offset, c_ld, c_stride,
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
batch_count);
}
else { // pre/post-processing plus a very fast kernel
BatchedGemmIndirect(m, n, k, alpha,
a_buffer, a_offset, a_ld, a_stride,
b_buffer, b_offset, b_ld, b_stride, beta,
c_buffer, c_offset, c_ld, c_stride,
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
a_one, a_two, b_one, b_two, c_one, c_two, batch_count);
}
}
// =================================================================================================
// The indirect version of batched GEMM. This uses the faster but non-general kernel. It has specific
// requirements, but several pre and post-processing kernels take care of those. However, the
// overhead of these extra kernels might not be ideal for certain devices/arguments.
template <typename T>
void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, const size_t k, const T alpha,
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const T beta,
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
const bool a_conjugate, const bool b_conjugate,
const size_t a_one, const size_t a_two,
const size_t b_one, const size_t b_two,
const size_t c_one, const size_t c_two,
const size_t batch_count) {
// Calculates the ceiled versions of m, n, and k
const auto m_ceiled = Ceil(Ceil(m, db_["MWG"]), db_["VWM"]);
const auto n_ceiled = Ceil(Ceil(n, db_["NWG"]), db_["VWN"]);
const auto k_ceiled = Ceil(Ceil(k, db_["KWG"]), db_["VWM"]);
// Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account
// whether the matrices need to be rotated or not for the kernel.
size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i;
Xgemm<T>::CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"],
a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i);
/* TODO
// Sets the "internal" offsets, i.e. the perfect offsets
auto a_offsets_i = 0;//std::vector<int>(batch_count);
auto b_offsets_i = 0;//std::vector<int>(batch_count);
auto c_offsets_i = 0;//std::vector<int>(batch_count);
// Determines whether or not temporary matrices are needed
auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && a_offsets == a_offsets_i &&
!a_do_transpose && !a_conjugate;
auto b_no_temp = b_one == b_one_i && b_two == b_two_i && b_ld == b_one && b_offsets == b_offsets_i &&
!b_do_transpose && !b_conjugate;
auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && c_offsets == c_offsets_i &&
!c_do_transpose;
// Creates the temporary matrices
const auto a_temp = (a_no_temp) ? a_buffer : Buffer<T>(context_, batch_count * a_one_i * a_two_i);
const auto b_temp = (b_no_temp) ? b_buffer : Buffer<T>(context_, batch_count * b_one_i * b_two_i);
const auto c_temp = (c_no_temp) ? c_buffer : Buffer<T>(context_, batch_count * c_one_i * c_two_i);
// Events of all kernels (including pre/post processing kernels)
auto eventWaitList = std::vector<Event>();
auto emptyEventList = std::vector<Event>();
// Runs the pre-processing kernel for matrix A. This transposes the matrix, but also pads zeros
// to fill it up until it reaches a certain multiple of size (kernel parameter dependent). In
// case nothing has to be done, these kernels can be skipped.
if (!a_no_temp) {
auto a_offsets_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
auto a_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
a_offsets_device.Write(queue_, batch_count, a_offsets);
a_offsets_i_device.Write(queue_, batch_count, a_offsets_i);
auto eventProcessA = Event();
PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessA.pointer(), emptyEventList,
a_one, a_two, a_ld, a_offsets_device, a_buffer,
a_one_i, a_two_i, a_one_i, a_offsets_i_device, a_temp,
program_, true, a_do_transpose, a_conjugate, batch_count);
eventWaitList.push_back(eventProcessA);
}
// As above, but now for matrix B
if (!b_no_temp) {
auto b_offsets_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
auto b_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
b_offsets_device.Write(queue_, batch_count, b_offsets);
b_offsets_i_device.Write(queue_, batch_count, b_offsets_i);
auto eventProcessB = Event();
PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessB.pointer(), emptyEventList,
b_one, b_two, b_ld, b_offsets_device, b_buffer,
b_one_i, b_two_i, b_one_i, b_offsets_i_device, b_temp,
program_, true, b_do_transpose, b_conjugate, batch_count);
eventWaitList.push_back(eventProcessB);
}
// As above, but now for matrix C
auto c_offsets_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
auto c_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
if (!c_no_temp) {
c_offsets_device.Write(queue_, batch_count, c_offsets);
c_offsets_i_device.Write(queue_, batch_count, c_offsets_i);
auto eventProcessC = Event();
PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessC.pointer(), emptyEventList,
c_one, c_two, c_ld, c_offsets_device, c_buffer,
c_one_i, c_two_i, c_one_i, c_offsets_i_device, c_temp,
program_, true, c_do_transpose, false, batch_count);
eventWaitList.push_back(eventProcessC);
}
// Retrieves the Xgemm kernel from the compiled binary
auto kernel = Kernel(program_, "XgemmStridedBatched");
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(m_ceiled));
kernel.SetArgument(1, static_cast<int>(n_ceiled));
kernel.SetArgument(2, static_cast<int>(k_ceiled));
kernel.SetArgument(3, alpha);
kernel.SetArgument(4, beta);
kernel.SetArgument(5, a_temp());
kernel.SetArgument(6, static_cast<int>(a_one_i));
kernel.SetArgument(7, static_cast<int>(a_two_i));
kernel.SetArgument(8, b_temp());
kernel.SetArgument(9, static_cast<int>(b_one_i));
kernel.SetArgument(10, static_cast<int>(b_two_i));
kernel.SetArgument(11, c_temp());
kernel.SetArgument(12, static_cast<int>(c_one_i));
kernel.SetArgument(13, static_cast<int>(c_two_i));
// Computes the global and local thread sizes
const auto global = std::vector<size_t>{
(c_one_i * db_["MDIMC"]) / db_["MWG"],
(c_two_i * db_["NDIMC"]) / db_["NWG"],
batch_count
};
const auto local = std::vector<size_t>{db_["MDIMC"], db_["NDIMC"], 1};
// Launches the kernel
auto eventKernel = Event();
auto eventPointer = eventKernel.pointer();
RunKernel(kernel, queue_, device_, global, local, eventPointer, eventWaitList);
// Runs the post-processing kernel if needed
if (!c_no_temp) {
eventWaitList.push_back(eventKernel);
PadCopyTransposeMatrixBatched(queue_, device_, db_, event_, eventWaitList,
c_one_i, c_two_i, c_one_i, c_offsets_i_device, c_temp,
c_one, c_two, c_ld, c_offsets_device, c_buffer,
program_, false, c_do_transpose, false, batch_count);
}
*/
}
// =================================================================================================
// The direct version of batched GEMM, requiring just one kernel, no pre or post-processing kernels.
template <typename T>
void XgemmStridedBatched<T>::BatchedGemmDirect(const size_t m, const size_t n, const size_t k, const T alpha,
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const T beta,
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
const bool a_conjugate, const bool b_conjugate,
const size_t batch_count) {
/* TODO
// Retrieves the proper XgemmDirect kernel from the compiled binary
const auto name = (a_do_transpose) ? (b_do_transpose ? "XgemmDirectBatchedTT" : "XgemmDirectBatchedTN") :
(b_do_transpose ? "XgemmDirectBatchedNT" : "XgemmDirectBatchedNN");
auto kernel = Kernel(program_, name);
// Sets the kernel arguments
kernel.SetArgument(0, static_cast<int>(m));
kernel.SetArgument(1, static_cast<int>(n));
kernel.SetArgument(2, static_cast<int>(k));
kernel.SetArgument(3, alpha);
kernel.SetArgument(4, beta);
kernel.SetArgument(5, a_buffer());
kernel.SetArgument(6, a_offset);
kernel.SetArgument(7, static_cast<int>(a_ld));
kernel.SetArgument(8, b_buffer());
kernel.SetArgument(9, b_offset);
kernel.SetArgument(10, static_cast<int>(b_ld));
kernel.SetArgument(11, c_buffer());
kernel.SetArgument(12, c_offset);
kernel.SetArgument(13, static_cast<int>(c_ld));
kernel.SetArgument(14, static_cast<int>(c_do_transpose));
kernel.SetArgument(15, static_cast<int>(a_conjugate));
kernel.SetArgument(16, static_cast<int>(b_conjugate));
// Computes the global and local thread sizes
const auto m_ceiled = Ceil(m, db_["WGD"]);
const auto n_ceiled = Ceil(n, db_["WGD"]);
const auto global = std::vector<size_t>{
(m_ceiled * db_["MDIMCD"]) / db_["WGD"],
(n_ceiled * db_["NDIMCD"]) / db_["WGD"],
batch_count
};
const auto local = std::vector<size_t>{db_["MDIMCD"], db_["NDIMCD"], 1};
// Launches the kernel
RunKernel(kernel, queue_, device_, global, local, event_);
*/
}
// =================================================================================================
// Compiles the templated class
template class XgemmStridedBatched<half>;
template class XgemmStridedBatched<float>;
template class XgemmStridedBatched<double>;
template class XgemmStridedBatched<float2>;
template class XgemmStridedBatched<double2>;
// =================================================================================================
} // namespace clblast

View File

@ -0,0 +1,66 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements the XgemmStridedBatched routine. This is a non-blas batched version of GEMM.
//
// =================================================================================================
#ifndef CLBLAST_ROUTINES_XGEMMSTRIDEDBATCHED_H_
#define CLBLAST_ROUTINES_XGEMMSTRIDEDBATCHED_H_
#include <vector>
#include "routine.hpp"
namespace clblast {
// =================================================================================================
// See comment at top of file for a description of the class
template <typename T>
class XgemmStridedBatched: public Routine {
public:
// Constructor
XgemmStridedBatched(Queue &queue, EventPointer event, const std::string &name = "GEMMSTRIDEDBATCHED");
// Templated-precision implementation of the routine
void DoGemmStridedBatched(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
const size_t m, const size_t n, const size_t k, const T alpha,
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const T beta,
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const size_t batch_count);
// Indirect version of strided batched GEMM (with pre and post-processing kernels)
void BatchedGemmIndirect(const size_t m, const size_t n, const size_t k, const T alpha,
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const T beta,
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
const bool a_conjugate, const bool b_conjugate,
const size_t a_one, const size_t a_two,
const size_t b_one, const size_t b_two,
const size_t c_one, const size_t c_two,
const size_t batch_count);
// Direct version of strided batched GEMM (no pre and post-processing kernels)
void BatchedGemmDirect(const size_t m, const size_t n, const size_t k, const T alpha,
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld, const size_t a_stride,
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld, const size_t b_stride, const T beta,
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld, const size_t c_stride,
const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
const bool a_conjugate, const bool b_conjugate,
const size_t batch_count);
};
// =================================================================================================
} // namespace clblast
// CLBLAST_ROUTINES_XGEMMSTRIDEDBATCHED_H_
#endif

View File

@ -71,6 +71,7 @@
#include "routines/levelx/xim2col.hpp" #include "routines/levelx/xim2col.hpp"
#include "routines/levelx/xaxpybatched.hpp" #include "routines/levelx/xaxpybatched.hpp"
#include "routines/levelx/xgemmbatched.hpp" #include "routines/levelx/xgemmbatched.hpp"
#include "routines/levelx/xgemmstridedbatched.hpp"
// CLBLAST_ROUTINES_ROUTINES_H_ // CLBLAST_ROUTINES_ROUTINES_H_
#endif #endif

View File

@ -0,0 +1,26 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// =================================================================================================
#include "test/correctness/testblas.hpp"
#include "test/routines/levelx/xgemmstridedbatched.hpp"
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
auto errors = size_t{0};
errors += clblast::RunTests<clblast::TestXgemmStridedBatched<float>, float, float>(argc, argv, false, "SGEMMSTRIDEDBATCHED");
errors += clblast::RunTests<clblast::TestXgemmStridedBatched<double>, double, double>(argc, argv, true, "DGEMMSTRIDEDBATCHED");
errors += clblast::RunTests<clblast::TestXgemmStridedBatched<clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CGEMMSTRIDEDBATCHED");
errors += clblast::RunTests<clblast::TestXgemmStridedBatched<clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZGEMMSTRIDEDBATCHED");
errors += clblast::RunTests<clblast::TestXgemmStridedBatched<clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HGEMMSTRIDEDBATCHED");
if (errors > 0) { return 1; } else { return 0; }
}
// =================================================================================================

View File

@ -0,0 +1,33 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// =================================================================================================
#include "test/performance/client.hpp"
#include "test/routines/levelx/xgemmstridedbatched.hpp"
// Main function (not within the clblast namespace)
int main(int argc, char *argv[]) {
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
switch(clblast::GetPrecision(command_line_args, clblast::Precision::kSingle)) {
case clblast::Precision::kHalf:
clblast::RunClient<clblast::TestXgemmStridedBatched<clblast::half>, clblast::half, clblast::half>(argc, argv); break;
case clblast::Precision::kSingle:
clblast::RunClient<clblast::TestXgemmStridedBatched<float>, float, float>(argc, argv); break;
case clblast::Precision::kDouble:
clblast::RunClient<clblast::TestXgemmStridedBatched<double>, double, double>(argc, argv); break;
case clblast::Precision::kComplexSingle:
clblast::RunClient<clblast::TestXgemmStridedBatched<clblast::float2>, clblast::float2, clblast::float2>(argc, argv); break;
case clblast::Precision::kComplexDouble:
clblast::RunClient<clblast::TestXgemmStridedBatched<clblast::double2>, clblast::double2, clblast::double2>(argc, argv); break;
}
return 0;
}
// =================================================================================================

View File

@ -0,0 +1,218 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements a class with static methods to describe the XgemmStridedBatched routine. Examples of
// such 'descriptions' are how to calculate the size a of buffer or how to run the routine. These
// static methods are used by the correctness tester and the performance tester.
//
// =================================================================================================
#ifndef CLBLAST_TEST_ROUTINES_XGEMMSTRIDEDBATCHED_H_
#define CLBLAST_TEST_ROUTINES_XGEMMSTRIDEDBATCHED_H_
#include "test/routines/common.hpp"
namespace clblast {
// =================================================================================================
// See comment at top of file for a description of the class
template <typename T>
class TestXgemmStridedBatched {
public:
// Although it is a non-BLAS routine, it can still be tested against level-3 routines in a loop
static size_t BLASLevel() { return 3; }
// The list of arguments relevant for this routine
static std::vector<std::string> GetOptions() {
return {kArgM, kArgN, kArgK,
kArgLayout, kArgATransp, kArgBTransp,
kArgALeadDim, kArgBLeadDim, kArgCLeadDim,
kArgAOffset, kArgBOffset, kArgCOffset,
kArgBatchCount, kArgAlpha, kArgBeta};
}
static std::vector<std::string> BuffersIn() { return {kBufMatA, kBufMatB, kBufMatC}; }
static std::vector<std::string> BuffersOut() { return {kBufMatC}; }
// Helper for the sizes per batch
static size_t PerBatchSizeA(const Arguments<T> &args) {
auto a_rotated = (args.layout == Layout::kColMajor && args.a_transpose != Transpose::kNo) ||
(args.layout == Layout::kRowMajor && args.a_transpose == Transpose::kNo);
auto a_two = (a_rotated) ? args.m : args.k;
return a_two * args.a_ld;
}
static size_t PerBatchSizeB(const Arguments<T> &args) {
auto b_rotated = (args.layout == Layout::kColMajor && args.b_transpose != Transpose::kNo) ||
(args.layout == Layout::kRowMajor && args.b_transpose == Transpose::kNo);
auto b_two = (b_rotated) ? args.k : args.n;
return b_two * args.b_ld;
}
static size_t PerBatchSizeC(const Arguments<T> &args) {
auto c_rotated = (args.layout == Layout::kRowMajor);
auto c_two = (c_rotated) ? args.m : args.n;
return c_two * args.c_ld;
}
// Describes how to obtain the sizes of the buffers
static size_t GetSizeA(const Arguments<T> &args) {
return PerBatchSizeA(args) * args.batch_count + args.a_offset;
}
static size_t GetSizeB(const Arguments<T> &args) {
return PerBatchSizeB(args) * args.batch_count + args.b_offset;
}
static size_t GetSizeC(const Arguments<T> &args) {
return PerBatchSizeC(args) * args.batch_count + args.c_offset;
}
// Describes how to set the sizes of all the buffers
static void SetSizes(Arguments<T> &args, Queue&) {
args.a_size = GetSizeA(args);
args.b_size = GetSizeB(args);
args.c_size = GetSizeC(args);
}
// Describes what the default values of the leading dimensions of the matrices are
static size_t DefaultLDA(const Arguments<T> &args) { return args.k; }
static size_t DefaultLDB(const Arguments<T> &args) { return args.n; }
static size_t DefaultLDC(const Arguments<T> &args) { return args.n; }
// Describes which transpose options are relevant for this routine
using Transposes = std::vector<Transpose>;
static Transposes GetATransposes(const Transposes &all) { return all; }
static Transposes GetBTransposes(const Transposes &all) { return all; }
// Describes how to prepare the input data
static void PrepareData(const Arguments<T>&, Queue&, const int, std::vector<T>&,
std::vector<T>&, std::vector<T>&, std::vector<T>&, std::vector<T>&,
std::vector<T>&, std::vector<T>&) {} // N/A for this routine
// Describes how to run the CLBlast routine
static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
#ifdef OPENCL_API
auto queue_plain = queue();
auto event = cl_event{};
auto status = GemmStridedBatched(args.layout, args.a_transpose, args.b_transpose,
args.m, args.n, args.k, args.alpha,
buffers.a_mat(), args.a_offset, args.a_ld, PerBatchSizeA(args),
buffers.b_mat(), args.b_offset, args.b_ld, PerBatchSizeB(args), args.beta,
buffers.c_mat(), args.c_offset, args.c_ld, PerBatchSizeC(args),
args.batch_count,
&queue_plain, &event);
if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
#elif CUDA_API
auto status = GemmStridedBatched(args.layout, args.a_transpose, args.b_transpose,
args.m, args.n, args.k, args.alpha,
buffers.a_mat(), args.a_offset, args.a_ld, PerBatchSizeA(args),
buffers.b_mat(), args.b_offset, args.b_ld, PerBatchSizeB(args), args.beta,
buffers.c_mat(), args.c_offset, args.c_ld, PerBatchSizeC(args),
args.batch_count,
queue.GetContext()(), queue.GetDevice()());
cuStreamSynchronize(queue());
#endif
return status;
}
// Describes how to run the clBLAS routine (for correctness/performance comparison)
#ifdef CLBLAST_REF_CLBLAS
static StatusCode RunReference1(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
auto queue_plain = queue();
for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
const auto a_batch_offset = args.a_offset + PerBatchSizeA(args) * batch;
const auto b_batch_offset = args.c_offset + PerBatchSizeB(args) * batch;
const auto c_batch_offset = args.b_offset + PerBatchSizeC(args) * batch;
auto event = cl_event{};
auto status = clblasXgemm(convertToCLBLAS(args.layout),
convertToCLBLAS(args.a_transpose),
convertToCLBLAS(args.b_transpose),
args.m, args.n, args.k, args.alpha,
buffers.a_mat, a_batch_offset, args.a_ld,
buffers.b_mat, b_batch_offset, args.b_ld, args.beta,
buffers.c_mat, c_batch_offset, args.c_ld,
1, &queue_plain, 0, nullptr, &event);
clWaitForEvents(1, &event);
if (static_cast<StatusCode>(status) != StatusCode::kSuccess) {
return static_cast<StatusCode>(status);
}
}
return StatusCode::kSuccess;
}
#endif
// Describes how to run the CPU BLAS routine (for correctness/performance comparison)
#ifdef CLBLAST_REF_CBLAS
static StatusCode RunReference2(const Arguments<T> &args, BuffersHost<T> &buffers_host, Queue &) {
for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
const auto a_batch_offset = args.a_offset + PerBatchSizeA(args) * batch;
const auto b_batch_offset = args.c_offset + PerBatchSizeB(args) * batch;
const auto c_batch_offset = args.b_offset + PerBatchSizeC(args) * batch;
cblasXgemm(convertToCBLAS(args.layout),
convertToCBLAS(args.a_transpose),
convertToCBLAS(args.b_transpose),
args.m, args.n, args.k, args.alpha,
buffers_host.a_mat, a_batch_offset, args.a_ld,
buffers_host.b_mat, b_batch_offset, args.b_ld, args.beta,
buffers_host.c_mat, c_batch_offset, args.c_ld);
}
return StatusCode::kSuccess;
}
#endif
// Describes how to run the cuBLAS routine (for correctness/performance comparison)
#ifdef CLBLAST_REF_CUBLAS
static StatusCode RunReference3(const Arguments<T> &args, BuffersCUDA<T> &buffers, Queue &) {
for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
const auto a_batch_offset = args.a_offset + PerBatchSizeA(args) * batch;
const auto b_batch_offset = args.c_offset + PerBatchSizeB(args) * batch;
const auto c_batch_offset = args.b_offset + PerBatchSizeC(args) * batch;
auto status = cublasXgemm(reinterpret_cast<cublasHandle_t>(args.cublas_handle), args.layout,
convertToCUBLAS(args.a_transpose),
convertToCUBLAS(args.b_transpose),
args.m, args.n, args.k, args.alpha,
buffers.a_mat, a_batch_offset, args.a_ld,
buffers.b_mat, b_batch_offset, args.b_ld, args.beta,
buffers.c_mat, c_batch_offset, args.c_ld);
if (status != CUBLAS_STATUS_SUCCESS) { return StatusCode::kUnknownError; }
}
return StatusCode::kSuccess;
}
#endif
// Describes how to download the results of the computation (more importantly: which buffer)
static std::vector<T> DownloadResult(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
std::vector<T> result(args.c_size, static_cast<T>(0));
buffers.c_mat.Read(queue, args.c_size, result);
return result;
}
// Describes how to compute the indices of the result buffer
static size_t ResultID1(const Arguments<T> &args) { return args.m; }
static size_t ResultID2(const Arguments<T> &args) { return args.n * args.batch_count; }
static size_t GetResultIndex(const Arguments<T> &args, const size_t id1, const size_t id2_3) {
const size_t id2 = id2_3 % args.n;
const size_t id3 = id2_3 / args.n;
const auto c_batch_offset = args.c_offset + PerBatchSizeC(args) * id3;
return (args.layout == Layout::kRowMajor) ?
id1*args.c_ld + id2 + c_batch_offset:
id2*args.c_ld + id1 + c_batch_offset;
}
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
return args.batch_count * (2 * args.m * args.n * args.k);
}
static size_t GetBytes(const Arguments<T> &args) {
return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
}
};
// =================================================================================================
} // namespace clblast
// CLBLAST_TEST_ROUTINES_XGEMMSTRIDEDBATCHED_H_
#endif