Added first naive version of the batched AXPY routine

pull/141/head
Cedric Nugteren 2017-03-05 15:06:14 +01:00
parent cdf354f895
commit b114ea49a9
11 changed files with 342 additions and 69 deletions

View File

@ -159,7 +159,7 @@ set(LEVEL1_ROUTINES xswap xscal xcopy xaxpy xdot xdotu xdotc xnrm2 xasum xamax)
set(LEVEL2_ROUTINES xgemv xgbmv xhemv xhbmv xhpmv xsymv xsbmv xspmv xtrmv xtbmv xtpmv xtrsv
xger xgeru xgerc xher xhpr xher2 xhpr2 xsyr xspr xsyr2 xspr2)
set(LEVEL3_ROUTINES xgemm xsymm xhemm xsyrk xherk xsyr2k xher2k xtrmm xtrsm)
set(LEVELX_ROUTINES xomatcopy)
set(LEVELX_ROUTINES xomatcopy xaxpybatched)
set(ROUTINES ${LEVEL1_ROUTINES} ${LEVEL2_ROUTINES} ${LEVEL3_ROUTINES} ${LEVELX_ROUTINES})
set(PRECISIONS 32 64 3232 6464 16)

View File

@ -2913,8 +2913,8 @@ C++ API:
template <typename T>
StatusCode AxpyBatched(const size_t n,
const T *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
```
@ -2923,32 +2923,32 @@ C API:
```
CLBlastStatusCode CLBlastSaxpyBatched(const size_t n,
const float *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastDaxpyBatched(const size_t n,
const double *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastCaxpyBatched(const size_t n,
const cl_float2 *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastZaxpyBatched(const size_t n,
const cl_double2 *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastHaxpyBatched(const size_t n,
const cl_half *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event)
```
@ -2958,10 +2958,8 @@ Arguments to AXPYBATCHED:
* `const size_t n`: Integer size argument. This value must be positive.
* `const T *alphas`: Input scalar constants.
* `const cl_mem *x_buffers`: OpenCL buffers to store the input x vectors.
* `const size_t x_offset`: The offset in elements from the start of the input x vectors.
* `const size_t x_inc`: Stride/increment of the input x vectors. This value must be greater than 0.
* `cl_mem *y_buffers`: OpenCL buffers to store the output y vectors.
* `const size_t y_offset`: The offset in elements from the start of the output y vectors.
* `const size_t y_inc`: Stride/increment of the output y vectors. This value must be greater than 0.
* `const size_t batch_count`: Number of batches. This value must be positive.
* `cl_command_queue* queue`: Pointer to an OpenCL command queue associated with a context and device to execute the routine on.

View File

@ -97,6 +97,7 @@ enum class StatusCode {
kInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small
// Custom additional status codes for CLBlast
kInvalidBatchCount = -2049, // The batch count needs to be positive
kInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel
kMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel
kInvalidLocalMemUsage = -2046, // Not enough local memory available on this device
@ -613,8 +614,8 @@ StatusCode Omatcopy(const Layout layout, const Transpose a_transpose,
template <typename T>
StatusCode AxpyBatched(const size_t n,
const T *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event = nullptr);

View File

@ -96,6 +96,7 @@ typedef enum CLBlastStatusCode_ {
CLBlastInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small
// Custom additional status codes for CLBlast
CLBlastInvalidBatchCount = -2049, // The batch count needs to be positive
CLBlastInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel
CLBlastMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel
CLBlastInvalidLocalMemUsage = -2046, // Not enough local memory available on this device
@ -1330,32 +1331,32 @@ CLBlastStatusCode PUBLIC_API CLBlastHomatcopy(const CLBlastLayout layout, const
// Batched version of AXPY: SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED
CLBlastStatusCode PUBLIC_API CLBlastSaxpyBatched(const size_t n,
const float *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastDaxpyBatched(const size_t n,
const double *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastCaxpyBatched(const size_t n,
const cl_float2 *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastZaxpyBatched(const size_t n,
const cl_double2 *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastHaxpyBatched(const size_t n,
const cl_half *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event);

View File

@ -41,7 +41,7 @@ FILES = [
"/include/clblast_netlib_c.h",
"/src/clblast_netlib_c.cpp",
]
HEADER_LINES = [121, 76, 125, 23, 29, 41, 65, 32]
HEADER_LINES = [122, 76, 126, 23, 29, 41, 65, 32]
FOOTER_LINES = [25, 138, 27, 38, 6, 6, 9, 2]
HEADER_LINES_DOC = 0
FOOTER_LINES_DOC = 63

View File

@ -223,7 +223,7 @@ class Routine:
"""Retrieves a variable name for a specific input/output vector/matrix (e.g. 'x')"""
if name in self.inputs or name in self.outputs:
a = [name + "_buffer" + self.b_s()]
b = [name + "_offset"]
b = [name + "_offset"] if not self.batched else []
c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else []
return [", ".join(a + b + c)]
return []
@ -251,7 +251,7 @@ class Routine:
prefix = "const " if name in self.inputs else ""
if name in self.inputs or name in self.outputs:
a = [prefix + "cl_mem " + self.b_star() + name + "_buffer" + self.b_s()]
b = ["const size_t " + name + "_offset"]
b = ["const size_t " + name + "_offset"] if not self.batched else []
c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
return [", ".join(a + b + c)]
return []
@ -295,7 +295,7 @@ class Routine:
a = [name + "_buffers_cpp"]
else:
a = ["Buffer<" + buffer_type + ">(" + name + "_buffer)"]
b = [name + "_offset"]
b = [name + "_offset"] if not self.batched else []
c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else []
return [", ".join(a + b + c)]
return []
@ -337,7 +337,7 @@ class Routine:
prefix = "const " if (name in self.inputs) else ""
if (name in self.inputs) or (name in self.outputs):
a = [prefix + "cl_mem" + self.b_star()]
b = ["const size_t"]
b = ["const size_t"] if not self.batched else []
c = ["const size_t"] if (name not in self.buffers_without_ld_inc()) else []
return [", ".join(a + b + c)]
return []
@ -350,12 +350,13 @@ class Routine:
math_name = name.upper() + " matrix" + self.b_s() if (name in self.buffers_matrix()) else name + " vector" + self.b_s()
inc_ld_description = "Leading dimension " if (name in self.buffers_matrix()) else "Stride/increment "
a = ["`" + prefix + "cl_mem " + self.b_star() + name + "_buffer" + self.b_s() + "`: OpenCL buffer" + self.b_s() + " to store the " + inout + " " + math_name + "."]
b = ["`const size_t " + name + "_offset`: The offset in elements from the start of the " + inout + " " + math_name + "."]
b = []
if not self.batched:
b = ["`const size_t " + name + "_offset`: The offset in elements from the start of the " + inout + " " + math_name + "."]
c = []
if name not in self.buffers_without_ld_inc():
c = ["`const size_t " + name + "_" + self.postfix(name) + "`: " +
inc_ld_description + "of the " + inout + " " + math_name + ". This value must be greater than 0."]
else:
c = []
return a + b + c
return []

View File

@ -2178,8 +2178,8 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose,
template <typename T>
StatusCode AxpyBatched(const size_t n,
const T *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
try {
@ -2195,40 +2195,40 @@ StatusCode AxpyBatched(const size_t n,
}
routine.DoAxpyBatched(n,
alphas_cpp,
x_buffers_cpp, x_offset, x_inc,
y_buffers_cpp, y_offset, y_inc,
x_buffers_cpp, x_inc,
y_buffers_cpp, y_inc,
batch_count);
return StatusCode::kSuccess;
} catch (...) { return DispatchException(); }
}
template StatusCode PUBLIC_API AxpyBatched<float>(const size_t,
const float*,
const cl_mem*, const size_t, const size_t,
cl_mem*, const size_t, const size_t,
const cl_mem*, const size_t,
cl_mem*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API AxpyBatched<double>(const size_t,
const double*,
const cl_mem*, const size_t, const size_t,
cl_mem*, const size_t, const size_t,
const cl_mem*, const size_t,
cl_mem*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API AxpyBatched<float2>(const size_t,
const float2*,
const cl_mem*, const size_t, const size_t,
cl_mem*, const size_t, const size_t,
const cl_mem*, const size_t,
cl_mem*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API AxpyBatched<double2>(const size_t,
const double2*,
const cl_mem*, const size_t, const size_t,
cl_mem*, const size_t, const size_t,
const cl_mem*, const size_t,
cl_mem*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API AxpyBatched<half>(const size_t,
const half*,
const cl_mem*, const size_t, const size_t,
cl_mem*, const size_t, const size_t,
const cl_mem*, const size_t,
cl_mem*, const size_t,
const size_t,
cl_command_queue*, cl_event*);
// =================================================================================================

View File

@ -3450,8 +3450,8 @@ CLBlastStatusCode CLBlastHomatcopy(const CLBlastLayout layout, const CLBlastTran
// AXPY
CLBlastStatusCode CLBlastSaxpyBatched(const size_t n,
const float *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
auto alphas_cpp = std::vector<float>();
@ -3462,8 +3462,8 @@ CLBlastStatusCode CLBlastSaxpyBatched(const size_t n,
return static_cast<CLBlastStatusCode>(
clblast::AxpyBatched(n,
alphas_cpp.data(),
x_buffers, x_offset, x_inc,
y_buffers, y_offset, y_inc,
x_buffers, x_inc,
y_buffers, y_inc,
batch_count,
queue, event)
);
@ -3471,8 +3471,8 @@ CLBlastStatusCode CLBlastSaxpyBatched(const size_t n,
}
CLBlastStatusCode CLBlastDaxpyBatched(const size_t n,
const double *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
auto alphas_cpp = std::vector<double>();
@ -3483,8 +3483,8 @@ CLBlastStatusCode CLBlastDaxpyBatched(const size_t n,
return static_cast<CLBlastStatusCode>(
clblast::AxpyBatched(n,
alphas_cpp.data(),
x_buffers, x_offset, x_inc,
y_buffers, y_offset, y_inc,
x_buffers, x_inc,
y_buffers, y_inc,
batch_count,
queue, event)
);
@ -3492,8 +3492,8 @@ CLBlastStatusCode CLBlastDaxpyBatched(const size_t n,
}
CLBlastStatusCode CLBlastCaxpyBatched(const size_t n,
const cl_float2 *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
auto alphas_cpp = std::vector<float2>();
@ -3504,8 +3504,8 @@ CLBlastStatusCode CLBlastCaxpyBatched(const size_t n,
return static_cast<CLBlastStatusCode>(
clblast::AxpyBatched(n,
alphas_cpp.data(),
x_buffers, x_offset, x_inc,
y_buffers, y_offset, y_inc,
x_buffers, x_inc,
y_buffers, y_inc,
batch_count,
queue, event)
);
@ -3513,8 +3513,8 @@ CLBlastStatusCode CLBlastCaxpyBatched(const size_t n,
}
CLBlastStatusCode CLBlastZaxpyBatched(const size_t n,
const cl_double2 *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
auto alphas_cpp = std::vector<double2>();
@ -3525,8 +3525,8 @@ CLBlastStatusCode CLBlastZaxpyBatched(const size_t n,
return static_cast<CLBlastStatusCode>(
clblast::AxpyBatched(n,
alphas_cpp.data(),
x_buffers, x_offset, x_inc,
y_buffers, y_offset, y_inc,
x_buffers, x_inc,
y_buffers, y_inc,
batch_count,
queue, event)
);
@ -3534,8 +3534,8 @@ CLBlastStatusCode CLBlastZaxpyBatched(const size_t n,
}
CLBlastStatusCode CLBlastHaxpyBatched(const size_t n,
const cl_half *alphas,
const cl_mem *x_buffers, const size_t x_offset, const size_t x_inc,
cl_mem *y_buffers, const size_t y_offset, const size_t y_inc,
const cl_mem *x_buffers, const size_t x_inc,
cl_mem *y_buffers, const size_t y_inc,
const size_t batch_count,
cl_command_queue* queue, cl_event* event) {
auto alphas_cpp = std::vector<half>();
@ -3546,8 +3546,8 @@ CLBlastStatusCode CLBlastHaxpyBatched(const size_t n,
return static_cast<CLBlastStatusCode>(
clblast::AxpyBatched(n,
alphas_cpp.data(),
x_buffers, x_offset, x_inc,
y_buffers, y_offset, y_inc,
x_buffers, x_inc,
y_buffers, y_inc,
batch_count,
queue, event)
);

View File

@ -0,0 +1,59 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements the XaxpyBatched class (see the header for information about the class).
//
// =================================================================================================
#include "routines/levelx/xaxpybatched.hpp"
#include <string>
#include <vector>
namespace clblast {
// =================================================================================================
// Constructor: performs no work of its own, all three arguments are handed unchanged to the Xaxpy
// base class
template <typename T>
XaxpyBatched<T>::XaxpyBatched(Queue &queue, EventPointer event, const std::string &name)
    : Xaxpy<T>(queue, event, name) {}
// =================================================================================================
// The main routine: runs one AXPY per batch entry, pairing alphas[i] with x_buffers[i] and
// y_buffers[i]. The increments 'x_inc' and 'y_inc' are shared by all entries and every vector is
// addressed from offset 0. Throws BLASError(kInvalidBatchCount) when 'batch_count' is zero or when
// any of the three input containers does not hold exactly 'batch_count' entries.
template <typename T>
void XaxpyBatched<T>::DoAxpyBatched(const size_t n, const std::vector<T> &alphas,
                                    const std::vector<Buffer<T>> &x_buffers, const size_t x_inc,
                                    const std::vector<Buffer<T>> &y_buffers, const size_t y_inc,
                                    const size_t batch_count) {

  // All validation failures report the same status code, so they are tested together
  const auto sizes_match = (alphas.size() == batch_count) &&
                           (x_buffers.size() == batch_count) &&
                           (y_buffers.size() == batch_count);
  if (batch_count < 1 || !sizes_match) { throw BLASError(StatusCode::kInvalidBatchCount); }

  // Naive implementation: calls regular Axpy multiple times
  for (size_t batch_id = 0; batch_id != batch_count; ++batch_id) {
    const auto &x_buffer = x_buffers[batch_id];
    const auto &y_buffer = y_buffers[batch_id];
    DoAxpy(n, alphas[batch_id],
           x_buffer, 0, x_inc,
           y_buffer, 0, y_inc);
  }
}
// =================================================================================================
// Compiles the templated class: explicit instantiations for the five element types listed below,
// so that the member definitions above are emitted in this translation unit.
template class XaxpyBatched<half>;
template class XaxpyBatched<float>;
template class XaxpyBatched<double>;
template class XaxpyBatched<float2>;
template class XaxpyBatched<double2>;
// =================================================================================================
} // namespace clblast

View File

@ -0,0 +1,46 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements the XaxpyBatched routine. This is a non-blas batched version of AXPY.
//
// =================================================================================================
#ifndef CLBLAST_ROUTINES_XAXPYBATCHED_H_
#define CLBLAST_ROUTINES_XAXPYBATCHED_H_
#include <vector>
#include "routines/level1/xaxpy.hpp"
namespace clblast {
// =================================================================================================
// Batched AXPY routine (non-BLAS). Derives from the regular Xaxpy routine and adds a single entry
// point, DoAxpyBatched, which takes one alpha, one x buffer and one y buffer per batch entry.
template <typename T>
class XaxpyBatched: public Xaxpy<T> {
 public:

  // Makes the regular (non-batched) DoAxpy of the base class available by its unqualified name
  using Xaxpy<T>::DoAxpy;

  // Constructor: the routine name defaults to "AXPYBATCHED"
  XaxpyBatched(Queue &queue, EventPointer event, const std::string &name = "AXPYBATCHED");

  // Templated-precision implementation of the routine. The vectors 'alphas', 'x_buffers' and
  // 'y_buffers' are indexed per batch entry; 'x_inc' and 'y_inc' apply to every entry.
  void DoAxpyBatched(const size_t n, const std::vector<T> &alphas,
                     const std::vector<Buffer<T>> &x_buffers, const size_t x_inc,
                     const std::vector<Buffer<T>> &y_buffers, const size_t y_inc,
                     const size_t batch_count);
};
// =================================================================================================
} // namespace clblast
// CLBLAST_ROUTINES_XAXPYBATCHED_H_
#endif

View File

@ -0,0 +1,167 @@
// =================================================================================================
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
// width of 100 characters per line.
//
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file implements a class with static methods to describe the XaxpyBatched routine. Examples of
// such 'descriptions' are how to calculate the size of a buffer or how to run the routine. These
// static methods are used by the correctness tester and the performance tester.
//
// =================================================================================================
#ifndef CLBLAST_TEST_ROUTINES_XAXPYBATCHED_H_
#define CLBLAST_TEST_ROUTINES_XAXPYBATCHED_H_
#include <vector>
#include <string>
#include "utilities/utilities.hpp"
#ifdef CLBLAST_REF_CLBLAS
#include "test/wrapper_clblas.hpp"
#endif
#ifdef CLBLAST_REF_CBLAS
#include "test/wrapper_cblas.hpp"
#endif
namespace clblast {
// =================================================================================================
// Describes the XaxpyBatched routine through static methods only: which arguments it takes, how
// large its buffers are, how to run it through CLBlast and through the reference libraries, and how
// to retrieve and index its results. Sizes and metrics are per batch entry, not for the full batch.
template <typename T>
class TestXaxpyBatched {
 public:

  // Although it is a non-BLAS routine, it can still be tested against level-1 routines in a loop
  static size_t BLASLevel() { return 1; }

  // The list of arguments relevant for this routine
  static std::vector<std::string> GetOptions() {
    return {kArgN,
            kArgXInc, kArgYInc,
            kArgAlpha, kArgBatchCount};
  }

  // Helper to determine a different alpha value per batch: the base alpha plus the batch index
  static T GetAlpha(const T alpha_base, const size_t batch_id) {
    return alpha_base + Constant<T>(batch_id);
  }

  // Describes how to obtain the sizes of the buffers (per item, not for the full batch). There is
  // no offset term: the batched interface takes only an increment per vector.
  static size_t GetSizeX(const Arguments<T> &args) {
    return args.n * args.x_inc;
  }
  static size_t GetSizeY(const Arguments<T> &args) {
    return args.n * args.y_inc;
  }

  // Describes how to set the sizes of all the buffers (per item, not for the full batch)
  static void SetSizes(Arguments<T> &args) {
    args.x_size = GetSizeX(args);
    args.y_size = GetSizeY(args);
  }

  // Describes what the default values of the leading dimensions of the matrices are
  static size_t DefaultLDA(const Arguments<T> &) { return 1; } // N/A for this routine
  static size_t DefaultLDB(const Arguments<T> &) { return 1; } // N/A for this routine
  static size_t DefaultLDC(const Arguments<T> &) { return 1; } // N/A for this routine

  // Describes which transpose options are relevant for this routine
  using Transposes = std::vector<Transpose>;
  static Transposes GetATransposes(const Transposes &) { return {}; } // N/A for this routine
  static Transposes GetBTransposes(const Transposes &) { return {}; } // N/A for this routine

  // Describes how to prepare the input data
  static void PrepareData(const Arguments<T>&, Queue&, const int, std::vector<T>&,
                          std::vector<T>&, std::vector<T>&, std::vector<T>&, std::vector<T>&,
                          std::vector<T>&, std::vector<T>&) {} // N/A for this routine

  // Describes how to run the CLBlast routine: gathers one alpha and one x/y buffer handle per
  // batch entry, issues a single AxpyBatched call and, on success, waits for and releases its event
  static StatusCode RunRoutine(const Arguments<T> &args, std::vector<Buffers<T>> &buffers, Queue &queue) {
    auto queue_plain = queue();
    auto event = cl_event{};
    auto alphas = std::vector<T>();
    auto x_buffers = std::vector<cl_mem>();
    auto y_buffers = std::vector<cl_mem>();
    for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
      alphas.push_back(GetAlpha(args.alpha, batch));
      x_buffers.push_back(buffers[batch].x_vec());
      y_buffers.push_back(buffers[batch].y_vec());
    }
    auto status = AxpyBatched(args.n, alphas.data(),
                              x_buffers.data(), args.x_inc,
                              y_buffers.data(), args.y_inc,
                              args.batch_count,
                              &queue_plain, &event);
    if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
    return status;
  }

  // Describes how to run the clBLAS routine (for correctness/performance comparison): one regular
  // AXPY per batch entry. Returns the first non-success status, or kSuccess if all entries ran.
  #ifdef CLBLAST_REF_CLBLAS
    static StatusCode RunReference1(const Arguments<T> &args, std::vector<Buffers<T>> &buffers, Queue &queue) {
      auto queue_plain = queue();
      for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
        auto event = cl_event{};
        auto status = clblasXaxpy(args.n, GetAlpha(args.alpha, batch),
                                  buffers[batch].x_vec, 0, args.x_inc,
                                  buffers[batch].y_vec, 0, args.y_inc,
                                  1, &queue_plain, 0, nullptr, &event);

        // The status is checked first: after a failed call the event must not be waited upon
        if (static_cast<StatusCode>(status) != StatusCode::kSuccess) {
          return static_cast<StatusCode>(status);
        }

        // Waits for completion and releases the event (as RunRoutine does), so that the loop does
        // not leak one event per batch entry
        clWaitForEvents(1, &event);
        clReleaseEvent(event);
      }
      return StatusCode::kSuccess;
    }
  #endif

  // Describes how to run the CPU BLAS routine (for correctness/performance comparison): per batch
  // entry the x and y vectors are read to the host, updated there and y is written back
  #ifdef CLBLAST_REF_CBLAS
    static StatusCode RunReference2(const Arguments<T> &args, std::vector<Buffers<T>> &buffers, Queue &queue) {
      for (auto batch = size_t{0}; batch < args.batch_count; ++batch) {
        std::vector<T> x_vec_cpu(args.x_size, static_cast<T>(0));
        std::vector<T> y_vec_cpu(args.y_size, static_cast<T>(0));
        buffers[batch].x_vec.Read(queue, args.x_size, x_vec_cpu);
        buffers[batch].y_vec.Read(queue, args.y_size, y_vec_cpu);
        cblasXaxpy(args.n, GetAlpha(args.alpha, batch),
                   x_vec_cpu, 0, args.x_inc,
                   y_vec_cpu, 0, args.y_inc);
        buffers[batch].y_vec.Write(queue, args.y_size, y_vec_cpu);
      }
      return StatusCode::kSuccess;
    }
  #endif

  // Describes how to download the results of the computation (per item, not for the full batch)
  static std::vector<T> DownloadResult(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
    std::vector<T> result(args.y_size, static_cast<T>(0));
    buffers.y_vec.Read(queue, args.y_size, result);
    return result;
  }

  // Describes how to compute the indices of the result buffer (per item, not for the full batch)
  static size_t ResultID1(const Arguments<T> &args) { return args.n; }
  static size_t ResultID2(const Arguments<T> &) { return 1; } // N/A for this routine
  static size_t GetResultIndex(const Arguments<T> &args, const size_t id1, const size_t) {
    return id1 * args.y_inc;
  }

  // Describes how to compute performance metrics (per item, not for the full batch)
  static size_t GetFlops(const Arguments<T> &args) {
    return 2 * args.n;
  }
  static size_t GetBytes(const Arguments<T> &args) {
    return (3 * args.n) * sizeof(T);
  }
};
// =================================================================================================
} // namespace clblast
// CLBLAST_TEST_ROUTINES_XAXPYBATCHED_H_
#endif