Removed half-precision support from the TRSM routine; too unstable
parent
70d8c4bad7
commit
b7310036ed
|
@ -254,7 +254,7 @@ CLBlast supports almost all the Netlib BLAS routines plus a couple of extra non-
|
|||
| xSPR | ✔ | ✔ | - | - | ✔ |
|
||||
| xSYR2 | ✔ | ✔ | - | - | ✔ |
|
||||
| xSPR2 | ✔ | ✔ | - | - | ✔ |
|
||||
| xTRSV | ✔ | ✔ | ✔ | ✔ | ✔ | (experimental, un-optimized)
|
||||
| xTRSV | ✔ | ✔ | ✔ | ✔ | | (experimental, un-optimized)
|
||||
|
||||
| Level-3 | S | D | C | Z | H |
|
||||
| ---------|---|---|---|---|---|
|
||||
|
@ -266,7 +266,7 @@ CLBlast supports almost all the Netlib BLAS routines plus a couple of extra non-
|
|||
| xSYR2K | ✔ | ✔ | ✔ | ✔ | ✔ |
|
||||
| xHER2K | - | - | ✔ | ✔ | - |
|
||||
| xTRMM | ✔ | ✔ | ✔ | ✔ | ✔ |
|
||||
| xTRSM | ✔ | ✔ | ✔ | ✔ | ✔ | (experimental, un-optimized)
|
||||
| xTRSM | ✔ | ✔ | ✔ | ✔ | | (experimental, un-optimized)
|
||||
|
||||
In addition, some extra non-BLAS routines are also supported by CLBlast, classified as level-X. They are experimental and should be used with care:
|
||||
|
||||
|
|
|
@ -2807,12 +2807,6 @@ CLBlastStatusCode CLBlastZtrsm(const CLBlastLayout layout, const CLBlastSide sid
|
|||
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_command_queue* queue, cl_event* event)
|
||||
CLBlastStatusCode CLBlastHtrsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const cl_half alpha,
|
||||
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_command_queue* queue, cl_event* event)
|
||||
```
|
||||
|
||||
Arguments to TRSM:
|
||||
|
|
|
@ -583,7 +583,7 @@ StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, c
|
|||
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_command_queue* queue, cl_event* event = nullptr);
|
||||
|
||||
// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM
|
||||
// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM
|
||||
template <typename T>
|
||||
StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
|
|
|
@ -1258,7 +1258,7 @@ CLBlastStatusCode PUBLIC_API CLBlastHtrmm(const CLBlastLayout layout, const CLBl
|
|||
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_command_queue* queue, cl_event* event);
|
||||
|
||||
// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM
|
||||
// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM
|
||||
CLBlastStatusCode PUBLIC_API CLBlastStrsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const float alpha,
|
||||
|
@ -1283,12 +1283,6 @@ CLBlastStatusCode PUBLIC_API CLBlastZtrsm(const CLBlastLayout layout, const CLBl
|
|||
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_command_queue* queue, cl_event* event);
|
||||
CLBlastStatusCode PUBLIC_API CLBlastHtrsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const cl_half alpha,
|
||||
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_command_queue* queue, cl_event* event);
|
||||
|
||||
// =================================================================================================
|
||||
// Extra non-BLAS routines (level-X)
|
||||
|
|
|
@ -862,7 +862,7 @@ void PUBLIC_API cblas_ztrmm(const CLBlastLayout layout, const CLBlastSide side,
|
|||
const void* a, const int a_ld,
|
||||
void* b, const int b_ld);
|
||||
|
||||
// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM
|
||||
// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM
|
||||
void PUBLIC_API cblas_strsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
|
||||
const int m, const int n,
|
||||
const float alpha,
|
||||
|
|
|
@ -154,7 +154,7 @@ ROUTINES = [
|
|||
Routine(True, True, "3", "syr2k", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * B^T + alpha * B * A^T + beta * C_ or _C = alpha * A^T * B + alpha * B^T * A + beta * C_, in which _A_ and _B_ are general matrices and _A^T_ and _B^T_ are their transposed versions, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
|
||||
Routine(True, True, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], [ankab,bnkab,cn],["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "Same operation as xSYR2K, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
|
||||
Routine(True, True, "3", "trmm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Triangular matrix-matrix multiplication", "Performs the matrix product _B = alpha * A * B_ or _B = alpha * B * A_, in which _A_ is a unit or non-unit triangular matrix, _B_ (_m_ by _n_) is the general matrix to be updated, and _alpha_ is a scalar value.", [ald_side_m_n, bld_m]),
|
||||
Routine(True, True, "3", "trsm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Solves a triangular system of equations", "Solves the equation _A * X = alpha * B_ for the unknown _m_ by _n_ matrix X, in which _A_ is an _n_ by _n_ unit or non-unit triangular matrix and B is an _m_ by _n_ matrix. The matrix _B_ is overwritten by the solution _X_.", []),
|
||||
Routine(True, True, "3", "trsm", T, [S,D,C,Z], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], [amns,bmn], ["alpha"], "", "Solves a triangular system of equations", "Solves the equation _A * X = alpha * B_ for the unknown _m_ by _n_ matrix X, in which _A_ is an _n_ by _n_ unit or non-unit triangular matrix and B is an _m_ by _n_ matrix. The matrix _B_ is overwritten by the solution _X_.", []),
|
||||
],
|
||||
[ # Level X: extra routines (not part of BLAS)
|
||||
Routine(True, True, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
|
||||
|
|
|
@ -2075,7 +2075,7 @@ template StatusCode PUBLIC_API Trmm<half>(const Layout, const Side, const Triang
|
|||
cl_mem, const size_t, const size_t,
|
||||
cl_command_queue*, cl_event*);
|
||||
|
||||
// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM/HTRSM
|
||||
// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM
|
||||
template <typename T>
|
||||
StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, const Transpose a_transpose, const Diagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
|
@ -2118,12 +2118,6 @@ template StatusCode PUBLIC_API Trsm<double2>(const Layout, const Side, const Tri
|
|||
const cl_mem, const size_t, const size_t,
|
||||
cl_mem, const size_t, const size_t,
|
||||
cl_command_queue*, cl_event*);
|
||||
template StatusCode PUBLIC_API Trsm<half>(const Layout, const Side, const Triangle, const Transpose, const Diagonal,
|
||||
const size_t, const size_t,
|
||||
const half,
|
||||
const cl_mem, const size_t, const size_t,
|
||||
cl_mem, const size_t, const size_t,
|
||||
cl_command_queue*, cl_event*);
|
||||
|
||||
// =================================================================================================
|
||||
// Extra non-BLAS routines (level-X)
|
||||
|
@ -2178,7 +2172,6 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose,
|
|||
const cl_mem, const size_t, const size_t,
|
||||
cl_mem, const size_t, const size_t,
|
||||
cl_command_queue*, cl_event*);
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// Clears the cache of stored binaries
|
||||
|
|
|
@ -3349,27 +3349,6 @@ CLBlastStatusCode CLBlastZtrsm(const CLBlastLayout layout, const CLBlastSide sid
|
|||
);
|
||||
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
|
||||
}
|
||||
CLBlastStatusCode CLBlastHtrsm(const CLBlastLayout layout, const CLBlastSide side, const CLBlastTriangle triangle, const CLBlastTranspose a_transpose, const CLBlastDiagonal diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const cl_half alpha,
|
||||
const cl_mem a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
cl_mem b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_command_queue* queue, cl_event* event) {
|
||||
try {
|
||||
return static_cast<CLBlastStatusCode>(
|
||||
clblast::Trsm(static_cast<clblast::Layout>(layout),
|
||||
static_cast<clblast::Side>(side),
|
||||
static_cast<clblast::Triangle>(triangle),
|
||||
static_cast<clblast::Transpose>(a_transpose),
|
||||
static_cast<clblast::Diagonal>(diagonal),
|
||||
m, n,
|
||||
alpha,
|
||||
a_buffer, a_offset, a_ld,
|
||||
b_buffer, b_offset, b_ld,
|
||||
queue, event)
|
||||
);
|
||||
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
// Extra non-BLAS routines (level-X)
|
||||
|
|
|
@ -23,7 +23,6 @@ int main(int argc, char *argv[]) {
|
|||
errors += clblast::RunTests<clblast::TestXtrsm<double>, double, double>(argc, argv, true, "DTRSM");
|
||||
errors += clblast::RunTests<clblast::TestXtrsm<float2>, float2, float2>(argc, argv, true, "CTRSM");
|
||||
errors += clblast::RunTests<clblast::TestXtrsm<double2>, double2, double2>(argc, argv, true, "ZTRSM");
|
||||
errors += clblast::RunTests<clblast::TestXtrsm<half>, half, half>(argc, argv, true, "HTRSM");
|
||||
if (errors > 0) { return 1; } else { return 0; }
|
||||
}
|
||||
|
||||
|
|
|
@ -20,8 +20,7 @@ using double2 = clblast::double2;
|
|||
int main(int argc, char *argv[]) {
|
||||
const auto command_line_args = clblast::RetrieveCommandLineArguments(argc, argv);
|
||||
switch(clblast::GetPrecision(command_line_args, clblast::Precision::kSingle)) {
|
||||
case clblast::Precision::kHalf:
|
||||
clblast::RunClient<clblast::TestXtrsm<half>, half, half>(argc, argv); break;
|
||||
case clblast::Precision::kHalf: throw std::runtime_error("Unsupported precision mode");
|
||||
case clblast::Precision::kSingle:
|
||||
clblast::RunClient<clblast::TestXtrsm<float>, float, float>(argc, argv); break;
|
||||
case clblast::Precision::kDouble:
|
||||
|
|
|
@ -2103,20 +2103,6 @@ void cblasXtrsm(const CBLAS_ORDER layout, const CBLAS_SIDE side, const CBLAS_UPL
|
|||
reinterpret_cast<const double*>(&a_buffer[a_offset]), a_ld,
|
||||
reinterpret_cast<double*>(&b_buffer[b_offset]), b_ld);
|
||||
}
|
||||
void cblasXtrsm(const CBLAS_ORDER layout, const CBLAS_SIDE side, const CBLAS_UPLO triangle, const CBLAS_TRANSPOSE a_transpose, const CBLAS_DIAG diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const half alpha,
|
||||
const std::vector<half>& a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
std::vector<half>& b_buffer, const size_t b_offset, const size_t b_ld) {
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
auto b_buffer_bis = HalfToFloatBuffer(b_buffer);
|
||||
cblasXtrsm(layout, side, triangle, a_transpose, diagonal,
|
||||
m, n,
|
||||
HalfToFloat(alpha),
|
||||
a_buffer_bis, a_offset, a_ld,
|
||||
b_buffer_bis, b_offset, b_ld);
|
||||
FloatToHalfBuffer(b_buffer, b_buffer_bis);
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
||||
|
|
|
@ -2865,24 +2865,6 @@ clblasStatus clblasXtrsm(const clblasOrder layout, const clblasSide side, const
|
|||
b_buffer(), b_offset, b_ld,
|
||||
num_queues, queues, num_wait_events, wait_events, events);
|
||||
}
|
||||
clblasStatus clblasXtrsm(const clblasOrder layout, const clblasSide side, const clblasUplo triangle, const clblasTranspose a_transpose, const clblasDiag diagonal,
|
||||
const size_t m, const size_t n,
|
||||
const half alpha,
|
||||
const Buffer<half>& a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
Buffer<half>& b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
cl_uint num_queues, cl_command_queue *queues,
|
||||
cl_uint num_wait_events, const cl_event *wait_events, cl_event *events) {
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer, queues[0]);
|
||||
auto b_buffer_bis = HalfToFloatBuffer(b_buffer, queues[0]);
|
||||
auto status = clblasXtrsm(layout, side, triangle, a_transpose, diagonal,
|
||||
m, n,
|
||||
HalfToFloat(alpha),
|
||||
a_buffer_bis, a_offset, a_ld,
|
||||
b_buffer_bis, b_offset, b_ld,
|
||||
num_queues, queues, num_wait_events, wait_events, events);
|
||||
FloatToHalfBuffer(b_buffer, b_buffer_bis, queues[0]);
|
||||
return status;
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
||||
|
|
Loading…
Reference in New Issue