Fixed some failing tests for GEMM and batched GEMM routines
parent
f14e6f87d2
commit
93610a9cba
|
@ -59,13 +59,17 @@ void Xgemm<T>::DoGemm(const Layout layout,
|
|||
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld,
|
||||
const Buffer<T> &temp_buffer, const bool temp_buffer_provided) { // optional arguments
|
||||
|
||||
// Two methods to choose from, select which one to run
|
||||
const auto do_gemm_direct = UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]);
|
||||
const auto gemm_kernel_id = (do_gemm_direct) ? 0 : db_["GEMMK"];
|
||||
|
||||
// Computes the transpose/conjugate options and sets the a/b/c sizes based on that
|
||||
bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
|
||||
size_t a_one, a_two, b_one, b_two, c_one, c_two;
|
||||
ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
|
||||
a_one, a_two, b_one, b_two, c_one, c_two,
|
||||
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
|
||||
db_["GEMMK"]);
|
||||
gemm_kernel_id);
|
||||
|
||||
// Tests three matrices (A, B, C) for validity, first from a perspective of the OpenCL buffers and
|
||||
// their sizes, and then from a perspective of parameter values (e.g. m, n, k). Tests whether the
|
||||
|
@ -79,7 +83,6 @@ void Xgemm<T>::DoGemm(const Layout layout,
|
|||
TestMatrixC(c_one, c_two, c_buffer, c_offset, c_ld);
|
||||
|
||||
// Selects which version of GEMM to run
|
||||
const auto do_gemm_direct = UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]);
|
||||
if (do_gemm_direct) { // for small sizes (single kernel)
|
||||
GemmDirect(m, n, k, alpha,
|
||||
a_buffer, a_offset, a_ld, b_buffer, b_offset, b_ld, beta,
|
||||
|
|
|
@ -65,13 +65,17 @@ void XgemmBatched<T>::DoGemmBatched(const Layout layout, const Transpose a_trans
|
|||
throw BLASError(StatusCode::kInvalidBatchCount);
|
||||
}
|
||||
|
||||
// Two methods to choose from, select which one to run
|
||||
const auto do_gemm_direct = Xgemm<T>::UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]);
|
||||
const auto gemm_kernel_id = (do_gemm_direct) ? 0 : db_["GEMMK"];
|
||||
|
||||
// Computes the transpose/conjugate options and sets the a/b/c sizes based on that
|
||||
bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
|
||||
size_t a_one, a_two, b_one, b_two, c_one, c_two;
|
||||
Xgemm<T>::ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
|
||||
a_one, a_two, b_one, b_two, c_one, c_two,
|
||||
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
|
||||
db_["GEMMK"]);
|
||||
gemm_kernel_id);
|
||||
|
||||
// Tests the matrices for validity
|
||||
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
|
||||
|
@ -97,7 +101,6 @@ void XgemmBatched<T>::DoGemmBatched(const Layout layout, const Transpose a_trans
|
|||
}
|
||||
|
||||
// Selects which version of the batched GEMM to run
|
||||
const auto do_gemm_direct = Xgemm<T>::UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]);
|
||||
if (do_gemm_direct) { // single generic kernel
|
||||
BatchedGemmDirect(m, n, k, alphas_device,
|
||||
a_buffer, a_offsets_int, a_ld, b_buffer, b_offsets_int, b_ld,
|
||||
|
|
|
@ -61,13 +61,17 @@ void XgemmStridedBatched<T>::DoGemmStridedBatched(const Layout layout, const Tra
|
|||
throw BLASError(StatusCode::kInvalidBatchCount);
|
||||
}
|
||||
|
||||
// Two methods to choose from, select which one to run
|
||||
const auto do_gemm_direct = Xgemm<T>::UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]);
|
||||
const auto gemm_kernel_id = (do_gemm_direct) ? 0 : db_["GEMMK"];
|
||||
|
||||
// Computes the transpose/conjugate options and sets the a/b/c sizes based on that
|
||||
bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
|
||||
size_t a_one, a_two, b_one, b_two, c_one, c_two;
|
||||
Xgemm<T>::ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
|
||||
a_one, a_two, b_one, b_two, c_one, c_two,
|
||||
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
|
||||
db_["GEMMK"]);
|
||||
gemm_kernel_id);
|
||||
|
||||
// Tests the matrices for validity
|
||||
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
|
||||
|
@ -77,7 +81,6 @@ void XgemmStridedBatched<T>::DoGemmStridedBatched(const Layout layout, const Tra
|
|||
}
|
||||
|
||||
// Selects which version of the batched GEMM to run
|
||||
const auto do_gemm_direct = Xgemm<T>::UseDirectKernel(m, n, k, db_["XGEMM_MIN_INDIRECT_SIZE"]);;
|
||||
if (do_gemm_direct) { // single generic kernel
|
||||
BatchedGemmDirect(m, n, k, alpha,
|
||||
a_buffer, a_offset, a_ld, a_stride,
|
||||
|
|
|
@ -35,12 +35,12 @@ size_t RunOverrideTests(int argc, char *argv[], const bool silent, const std::st
|
|||
const auto kernel_name = std::string{"Xgemm"};
|
||||
const auto precision = PrecisionValue<T>();
|
||||
const auto valid_settings = std::vector<std::unordered_map<std::string,size_t>>{
|
||||
{ {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} },
|
||||
{ {"KWG",32}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",32}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",32}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} },
|
||||
{ {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} },
|
||||
{ {"GEMMK",0}, {"KREG",1}, {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} },
|
||||
{ {"GEMMK",0}, {"KREG",1}, {"KWG",32}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",32}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",32}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} },
|
||||
{ {"GEMMK",0}, {"KREG",1}, {"KWG",16}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0}, {"SB",0}, {"STRM",0}, {"STRN",0}, {"VWM",1}, {"VWN",1} },
|
||||
};
|
||||
const auto invalid_settings = std::vector<std::unordered_map<std::string,size_t>>{
|
||||
{ {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0} },
|
||||
{ {"GEMMK",0}, {"KREG",1}, {"KWI",2}, {"MDIMA",4}, {"MDIMC",4}, {"MWG",16}, {"NDIMB",4}, {"NDIMC",4}, {"NWG",16}, {"SA",0} },
|
||||
};
|
||||
|
||||
// Retrieves the arguments
|
||||
|
|
Loading…
Reference in New Issue