Prepared for the addition of the TRSM triangular solver kernel
parent
6b533dda1c
commit
681a465b35
|
@ -158,7 +158,7 @@ endif()
|
|||
set(LEVEL1_ROUTINES xswap xscal xcopy xaxpy xdot xdotu xdotc xnrm2 xasum xamax)
|
||||
set(LEVEL2_ROUTINES xgemv xgbmv xhemv xhbmv xhpmv xsymv xsbmv xspmv xtrmv xtbmv xtpmv
|
||||
xger xgeru xgerc xher xhpr xher2 xhpr2 xsyr xspr xsyr2 xspr2)
|
||||
set(LEVEL3_ROUTINES xgemm xsymm xhemm xsyrk xherk xsyr2k xher2k xtrmm)
|
||||
set(LEVEL3_ROUTINES xgemm xsymm xhemm xsyrk xherk xsyr2k xher2k xtrmm xtrsm)
|
||||
set(LEVELX_ROUTINES xomatcopy)
|
||||
set(ROUTINES ${LEVEL1_ROUTINES} ${LEVEL2_ROUTINES} ${LEVEL3_ROUTINES} ${LEVELX_ROUTINES})
|
||||
set(PRECISIONS 32 64 3232 6464 16)
|
||||
|
|
|
@ -2708,6 +2708,77 @@ Requirements for TRMM:
|
|||
|
||||
|
||||
|
||||
xTRSM: Solves a triangular system of equations
|
||||
-------------
|
||||
|
||||
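Solves the equation _A * X = alpha * B_ (or _X * A = alpha * B_ when _A_ is on the right side) for the unknown _m_ by _n_ matrix _X_, in which _A_ is a unit or non-unit triangular matrix (_m_ by _m_ when on the left side, _n_ by _n_ when on the right side) and _B_ is an _m_ by _n_ matrix. The matrix _B_ is overwritten by the solution _X_.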
|
||||
|
||||
C++ API:
|
||||
```
|
||||
template <typename T>
|
||||
StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const T alpha,
|
||||
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_command_queue* queue, cl_event* event)
|
||||
```
|
||||
|
||||
C API:
|
||||
```
|
||||
CLBlastStatusCode CLBlastStrsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const float alpha,
|
||||
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_command_queue* queue, cl_event* event)
|
||||
CLBlastStatusCode CLBlastDtrsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const double alpha,
|
||||
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_command_queue* queue, cl_event* event)
|
||||
CLBlastStatusCode CLBlastCtrsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const cl_float2 alpha,
|
||||
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_command_queue* queue, cl_event* event)
|
||||
CLBlastStatusCode CLBlastZtrsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const cl_double2 alpha,
|
||||
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_command_queue* queue, cl_event* event)
|
||||
CLBlastStatusCode CLBlastHtrsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const cl_half alpha,
|
||||
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_command_queue* queue, cl_event* event)
|
||||
```
|
||||
|
||||
Arguments to TRSM:
|
||||
|
||||
* `const Layout layout`: Data-layout of the matrices, either `Layout::kRowMajor` (101) for row-major layout or `Layout::kColMajor` (102) for column-major data-layout.
|
||||
* `const Side side`: The position of the triangular matrix in the operation, either on the `Side::kLeft` (141) or `Side::kRight` (142).
|
||||
* `const Triangle triangle`: The part of the array of the triangular matrix to be used, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).
|
||||
* `const Transpose a_transpose`: Transposing the input matrix A, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.
|
||||
* `const Diagonal diagonal`: The property of the diagonal of the triangular matrix, either `Diagonal::kNonUnit` (131) for non-unit values on the diagonal or `Diagonal::kUnit` (132) for unit values on the diagonal.
|
||||
* `const size_t m`: Integer size argument. This value must be positive.
|
||||
* `const size_t n`: Integer size argument. This value must be positive.
|
||||
* `const T alpha`: Input scalar constant.
|
||||
* `const cl_mem a_buffer`: OpenCL buffer to store the input A matrix.
|
||||
* `const size_t a_offset`: The offset in elements from the start of the input A matrix.
|
||||
* `const size_t a_ld`: Leading dimension of the input A matrix. This value must be greater than 0.
|
||||
* `cl_mem b_buffer`: OpenCL buffer holding the input B matrix; it is overwritten with the output (the solution X).
|
||||
* `const size_t b_offset`: The offset in elements from the start of the output B matrix.
|
||||
* `const size_t b_ld`: Leading dimension of the output B matrix. This value must be greater than 0.
|
||||
* `cl_command_queue* queue`: Pointer to an OpenCL command queue associated with a context and device to execute the routine on.
|
||||
* `cl_event* event`: Pointer to an OpenCL event to be able to wait for completion of the routine's OpenCL kernel(s). This is an optional argument.
|
||||
|
||||
|
||||
|
||||
xOMATCOPY: Scaling and out-place transpose/copy (non-BLAS function)
|
||||
-------------
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ FILES = [
|
|||
"/include/clblast_netlib_c.h",
|
||||
"/src/clblast_netlib_c.cpp",
|
||||
]
|
||||
HEADER_LINES = [117, 73, 118, 22, 29, 41, 65, 32]
|
||||
HEADER_LINES = [117, 74, 118, 22, 29, 41, 65, 32]
|
||||
FOOTER_LINES = [17, 80, 19, 18, 6, 6, 9, 2]
|
||||
|
||||
# Different possibilities for requirements
|
||||
|
@ -154,7 +154,7 @@ ROUTINES = [
|
|||
Routine(True, True, "3", "syr2k", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * B^T + alpha * B * A^T + beta * C_ or _C = alpha * A^T * B + alpha * B^T * A + beta * C_, in which _A_ and _B_ are general matrices and _A^T_ and _B^T_ are their transposed versions, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
|
||||
Routine(True, True, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "Same operation as xSYR2K, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
|
||||
Routine(True, True, "3", "trmm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Triangular matrix-matrix multiplication", "Performs the matrix product _B = alpha * A * B_ or _B = alpha * B * A_, in which _A_ is a unit or non-unit triangular matrix, _B_ (_m_ by _n_) is the general matrix to be updated, and _alpha_ is a scalar value.", [ald_side_m_n, bld_m]),
|
||||
Routine(False, True, "3", "trsm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Solves a triangular system of equations", "", []),
|
||||
Routine(True, True, "3", "trsm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Solves a triangular system of equations", "Solves the equation _A * X = alpha * B_ for the unknown _m_ by _n_ matrix X, in which _A_ is an _n_ by _n_ unit or non-unit triangular matrix and B is an _m_ by _n_ matrix. The matrix _B_ is overwritten by the solution _X_.", []),
|
||||
],
|
||||
[ # Level X: extra routines (not part of BLAS)
|
||||
Routine(True, True, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
|
||||
|
|
|
@ -66,6 +66,7 @@
|
|||
#include "routines/level3/xsyr2k.hpp"
|
||||
#include "routines/level3/xher2k.hpp"
|
||||
#include "routines/level3/xtrmm.hpp"
|
||||
#include "routines/level3/xtrsm.hpp"
|
||||
|
||||
// Level-x includes (non-BLAS)
|
||||
#include "routines/levelx/xomatcopy.hpp"
|
||||
|
@ -2067,13 +2068,22 @@ template StatusCode PUBLIC_API Trmm<half>(const Layout, const Side, const Triang
|
|||
|
||||
// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM
|
||||
template <typename T>
|
||||
StatusCode Trsm(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
|
||||
const size_t, const size_t,
|
||||
const T,
|
||||
const cl_mem, const size_t, const size_t,
|
||||
cl_mem, const size_t, const size_t,
|
||||
cl_command_queue*, cl_event*) {
|
||||
return StatusCode::kNotImplemented;
|
||||
StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const T alpha,
|
||||
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_command_queue* queue, cl_event* event) {
|
||||
try {
|
||||
auto queue_cpp = Queue(*queue);
|
||||
auto routine = Xtrsm<T>(queue_cpp, event);
|
||||
routine.DoTrsm(layout, side, triangle, a_transpose, diagonal,
|
||||
m, n,
|
||||
alpha,
|
||||
Buffer<T>(a_buffer), a_offset, a_ld,
|
||||
Buffer<T>(b_buffer), b_offset, b_ld);
|
||||
return StatusCode::kSuccess;
|
||||
} catch (...) { return DispatchException(); }
|
||||
}
|
||||
template StatusCode PUBLIC_API Trsm<float>(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
|
||||
const size_t, const size_t,
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
|
||||
// =================================================================================================
|
||||
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
|
||||
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
|
||||
// width of 100 characters per line.
|
||||
//
|
||||
// Author(s):
|
||||
// Cedric Nugteren <www.cedricnugteren.nl>
|
||||
//
|
||||
// This file implements the Xtrsm class (see the header for information about the class).
|
||||
//
|
||||
// =================================================================================================
|
||||
|
||||
#include "routines/level3/xtrsm.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace clblast {
|
||||
// =================================================================================================
|
||||
|
||||
// Constructor: forwards to base class constructor
|
||||
template <typename T>
|
||||
Xtrsm<T>::Xtrsm(Queue &queue, EventPointer event, const std::string &name):
|
||||
Xgemm<T>(queue, event, name) {
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// The main routine
|
||||
template <typename T>
|
||||
void Xtrsm<T>::DoTrsm(const Layout layout, const Side side, const Triangle triangle,
|
||||
const Transpose a_transpose, const Diagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const T alpha,
|
||||
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld) {
|
||||
|
||||
// Makes sure all dimensions are larger than zero
|
||||
if ((m == 0) || (n == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
|
||||
|
||||
// Computes the k dimension. This is based on whether or not matrix is A (on the left)
|
||||
// or B (on the right) in the Xgemm routine.
|
||||
auto k = (side == Side::kLeft) ? m : n;
|
||||
|
||||
// Checks for validity of the triangular A matrix
|
||||
TestMatrixA(k, k, a_buffer, a_offset, a_ld);
|
||||
|
||||
// Checks for validity of the input/output B matrix
|
||||
const auto b_one = (layout == Layout::kRowMajor) ? n : m;
|
||||
const auto b_two = (layout == Layout::kRowMajor) ? m : n;
|
||||
TestMatrixB(b_one, b_two, b_buffer, b_offset, b_ld);
|
||||
|
||||
// Creates a copy of B to avoid overwriting input in GEMM while computing output
|
||||
const auto b_size = (b_ld * (b_two - 1) + b_one + b_offset);
|
||||
auto b_buffer_copy = Buffer<T>(context_, b_size);
|
||||
b_buffer.CopyTo(queue_, b_size, b_buffer_copy);
|
||||
|
||||
// TODO: Implement TRSM computation
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// Compiles the templated class
|
||||
template class Xtrsm<half>;
|
||||
template class Xtrsm<float>;
|
||||
template class Xtrsm<double>;
|
||||
template class Xtrsm<float2>;
|
||||
template class Xtrsm<double2>;
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
|
@ -0,0 +1,52 @@
|
|||
|
||||
// =================================================================================================
|
||||
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
|
||||
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
|
||||
// width of 100 characters per line.
|
||||
//
|
||||
// Author(s):
|
||||
// Cedric Nugteren <www.cedricnugteren.nl>
|
||||
//
|
||||
// This file implements the Xtrsm routine. The implementation is based on ??? (TODO).
|
||||
// Therefore, this class inherits from the Xgemm class.
|
||||
//
|
||||
// =================================================================================================
|
||||
|
||||
#ifndef CLBLAST_ROUTINES_XTRSM_H_
|
||||
#define CLBLAST_ROUTINES_XTRSM_H_
|
||||
|
||||
#include "routines/level3/xgemm.hpp"
|
||||
|
||||
namespace clblast {
|
||||
// =================================================================================================
|
||||
|
||||
// See comment at top of file for a description of the class
|
||||
template <typename T>
|
||||
class Xtrsm: public Xgemm<T> {
|
||||
public:
|
||||
|
||||
// Uses methods and variables the Xgemm routine
|
||||
using Xgemm<T>::routine_name_;
|
||||
using Xgemm<T>::queue_;
|
||||
using Xgemm<T>::context_;
|
||||
using Xgemm<T>::device_;
|
||||
using Xgemm<T>::db_;
|
||||
using Xgemm<T>::DoGemm;
|
||||
|
||||
// Constructor
|
||||
Xtrsm(Queue &queue, EventPointer event, const std::string &name = "TRSM");
|
||||
|
||||
// Templated-precision implementation of the routine
|
||||
void DoTrsm(const Layout layout, const Side side, const Triangle triangle,
|
||||
const Transpose a_transpose, const Diagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const T alpha,
|
||||
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld);
|
||||
};
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
||||
|
||||
// CLBLAST_ROUTINES_XTRSM_H_
|
||||
#endif
|
|
@ -0,0 +1,159 @@
|
|||
|
||||
// =================================================================================================
|
||||
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
|
||||
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
|
||||
// width of 100 characters per line.
|
||||
//
|
||||
// Author(s):
|
||||
// Cedric Nugteren <www.cedricnugteren.nl>
|
||||
//
|
||||
// This file implements a class with static methods to describe the Xtrsm routine. Examples of
|
||||
// such 'descriptions' are how to calculate the size of a buffer or how to run the routine. These
|
||||
// static methods are used by the correctness tester and the performance tester.
|
||||
//
|
||||
// =================================================================================================
|
||||
|
||||
#ifndef CLBLAST_TEST_ROUTINES_XTRSM_H_
|
||||
#define CLBLAST_TEST_ROUTINES_XTRSM_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#ifdef CLBLAST_REF_CLBLAS
|
||||
#include "test/wrapper_clblas.hpp"
|
||||
#endif
|
||||
#ifdef CLBLAST_REF_CBLAS
|
||||
#include "test/wrapper_cblas.hpp"
|
||||
#endif
|
||||
|
||||
namespace clblast {
|
||||
// =================================================================================================
|
||||
|
||||
// See comment at top of file for a description of the class
|
||||
template <typename T>
|
||||
class TestXtrsm {
|
||||
public:
|
||||
|
||||
// The BLAS level: 1, 2, or 3
|
||||
static size_t BLASLevel() { return 3; }
|
||||
|
||||
// The list of arguments relevant for this routine
|
||||
static std::vector<std::string> GetOptions() {
|
||||
return {kArgM, kArgN,
|
||||
kArgLayout, kArgSide, kArgTriangle, kArgATransp, kArgDiagonal,
|
||||
kArgALeadDim, kArgBLeadDim,
|
||||
kArgAOffset, kArgBOffset,
|
||||
kArgAlpha};
|
||||
}
|
||||
|
||||
// Describes how to obtain the sizes of the buffers
|
||||
static size_t GetSizeA(const Arguments<T> &args) {
|
||||
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
||||
return k * args.a_ld + args.a_offset;
|
||||
}
|
||||
static size_t GetSizeB(const Arguments<T> &args) {
|
||||
auto b_rotated = (args.layout == Layout::kRowMajor);
|
||||
auto b_two = (b_rotated) ? args.m : args.n;
|
||||
return b_two * args.b_ld + args.b_offset;
|
||||
}
|
||||
|
||||
// Describes how to set the sizes of all the buffers
|
||||
static void SetSizes(Arguments<T> &args) {
|
||||
args.a_size = GetSizeA(args);
|
||||
args.b_size = GetSizeB(args);
|
||||
}
|
||||
|
||||
// Describes what the default values of the leading dimensions of the matrices are
|
||||
static size_t DefaultLDA(const Arguments<T> &args) { return args.m; }
|
||||
static size_t DefaultLDB(const Arguments<T> &args) { return args.n; }
|
||||
static size_t DefaultLDC(const Arguments<T> &) { return 1; } // N/A for this routine
|
||||
|
||||
// Describes which transpose options are relevant for this routine
|
||||
using Transposes = std::vector<Transpose>;
|
||||
static Transposes GetATransposes(const Transposes &all) { return all; }
|
||||
static Transposes GetBTransposes(const Transposes &) { return {}; } // N/A for this routine
|
||||
|
||||
// Describes how to run the CLBlast routine
|
||||
static StatusCode RunRoutine(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
|
||||
auto queue_plain = queue();
|
||||
auto event = cl_event{};
|
||||
auto status = Trsm(args.layout, args.side, args.triangle, args.a_transpose, args.diagonal,
|
||||
args.m, args.n, args.alpha,
|
||||
buffers.a_mat(), args.a_offset, args.a_ld,
|
||||
buffers.b_mat(), args.b_offset, args.b_ld,
|
||||
&queue_plain, &event);
|
||||
if (status == StatusCode::kSuccess) { clWaitForEvents(1, &event); clReleaseEvent(event); }
|
||||
return status;
|
||||
}
|
||||
|
||||
// Describes how to run the clBLAS routine (for correctness/performance comparison)
|
||||
#ifdef CLBLAST_REF_CLBLAS
|
||||
static StatusCode RunReference1(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
|
||||
auto queue_plain = queue();
|
||||
auto event = cl_event{};
|
||||
auto status = clblasXtrsm(convertToCLBLAS(args.layout),
|
||||
convertToCLBLAS(args.side),
|
||||
convertToCLBLAS(args.triangle),
|
||||
convertToCLBLAS(args.a_transpose),
|
||||
convertToCLBLAS(args.diagonal),
|
||||
args.m, args.n, args.alpha,
|
||||
buffers.a_mat, args.a_offset, args.a_ld,
|
||||
buffers.b_mat, args.b_offset, args.b_ld,
|
||||
1, &queue_plain, 0, nullptr, &event);
|
||||
clWaitForEvents(1, &event);
|
||||
return static_cast<StatusCode>(status);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Describes how to run the CPU BLAS routine (for correctness/performance comparison)
|
||||
#ifdef CLBLAST_REF_CBLAS
|
||||
static StatusCode RunReference2(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
|
||||
std::vector<T> a_mat_cpu(args.a_size, static_cast<T>(0));
|
||||
std::vector<T> b_mat_cpu(args.b_size, static_cast<T>(0));
|
||||
buffers.a_mat.Read(queue, args.a_size, a_mat_cpu);
|
||||
buffers.b_mat.Read(queue, args.b_size, b_mat_cpu);
|
||||
cblasXtrsm(convertToCBLAS(args.layout),
|
||||
convertToCBLAS(args.side),
|
||||
convertToCBLAS(args.triangle),
|
||||
convertToCBLAS(args.a_transpose),
|
||||
convertToCBLAS(args.diagonal),
|
||||
args.m, args.n, args.alpha,
|
||||
a_mat_cpu, args.a_offset, args.a_ld,
|
||||
b_mat_cpu, args.b_offset, args.b_ld);
|
||||
buffers.b_mat.Write(queue, args.b_size, b_mat_cpu);
|
||||
return StatusCode::kSuccess;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Describes how to download the results of the computation (more importantly: which buffer)
|
||||
static std::vector<T> DownloadResult(const Arguments<T> &args, Buffers<T> &buffers, Queue &queue) {
|
||||
std::vector<T> result(args.b_size, static_cast<T>(0));
|
||||
buffers.b_mat.Read(queue, args.b_size, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Describes how to compute the indices of the result buffer
|
||||
static size_t ResultID1(const Arguments<T> &args) { return args.m; }
|
||||
static size_t ResultID2(const Arguments<T> &args) { return args.n; }
|
||||
static size_t GetResultIndex(const Arguments<T> &args, const size_t id1, const size_t id2) {
|
||||
return (args.layout == Layout::kRowMajor) ?
|
||||
id1*args.b_ld + id2 + args.b_offset:
|
||||
id2*args.b_ld + id1 + args.b_offset;
|
||||
}
|
||||
|
||||
// Describes how to compute performance metrics
|
||||
static size_t GetFlops(const Arguments<T> &args) {
|
||||
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
||||
return args.m * args.n * k;
|
||||
}
|
||||
static size_t GetBytes(const Arguments<T> &args) {
|
||||
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
||||
return (k*k + 2*args.m*args.n) * sizeof(T);
|
||||
}
|
||||
};
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
||||
|
||||
// CLBLAST_TEST_ROUTINES_XTRSM_H_
|
||||
#endif
|
Loading…
Reference in New Issue