Made GEMM rotation expectations kernel-specific
parent
0f49dd24e5
commit
0dff7f1ac4
|
@ -2490,7 +2490,8 @@ StatusCode GemmTempBufferSize(const Layout layout, const Transpose a_transpose,
|
|||
else {
|
||||
temp_buffer_size = Xgemm<T>::GetTempSize(layout, a_transpose, b_transpose, m, n, k,
|
||||
a_offset, a_ld, b_offset, b_ld, c_offset, c_ld,
|
||||
db["MWG"], db["NWG"], db["KWG"]);
|
||||
db["MWG"], db["NWG"], db["KWG"] * db["KREG"],
|
||||
db["GEMMK"]);
|
||||
}
|
||||
temp_buffer_size *= sizeof(T); // translate from num-elements to bytes
|
||||
return StatusCode::kSuccess;
|
||||
|
|
|
@ -19,11 +19,6 @@
|
|||
namespace clblast {
|
||||
// =================================================================================================
|
||||
|
||||
// Defines the assumptions of the GEMM kernels
|
||||
template <typename T> const bool Xgemm<T>::a_want_rotated_ = false;
|
||||
template <typename T> const bool Xgemm<T>::b_want_rotated_ = true;
|
||||
template <typename T> const bool Xgemm<T>::c_want_rotated_ = false;
|
||||
|
||||
// Constructor: forwards to base class constructor
|
||||
template <typename T>
|
||||
Xgemm<T>::Xgemm(Queue &queue, EventPointer event, const std::string &name):
|
||||
|
@ -69,7 +64,8 @@ void Xgemm<T>::DoGemm(const Layout layout,
|
|||
size_t a_one, a_two, b_one, b_two, c_one, c_two;
|
||||
ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
|
||||
a_one, a_two, b_one, b_two, c_one, c_two,
|
||||
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
|
||||
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
|
||||
db_["GEMMK"]);
|
||||
|
||||
// Tests three matrices (A, B, C) for validity, first from a perspective of the OpenCL buffers and
|
||||
// their sizes, and then from a perspective of parameter values (e.g. m, n, k). Tests whether the
|
||||
|
@ -122,13 +118,14 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
|
|||
// Calculates the ceiled versions of m, n, and k
|
||||
const auto m_ceiled = Ceil(m, db_["MWG"]);
|
||||
const auto n_ceiled = Ceil(n, db_["NWG"]);
|
||||
const auto k_ceiled = Ceil(k, db_["KWG"]);
|
||||
const auto k_ceiled = Ceil(k, db_["KWG"] * db_["KREG"]);
|
||||
|
||||
// Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account
|
||||
// whether the matrices need to be rotated or not for the kernel.
|
||||
size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i;
|
||||
CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"],
|
||||
a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i);
|
||||
CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"] * db_["KREG"],
|
||||
a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i,
|
||||
db_["GEMMK"]);
|
||||
|
||||
// Determines whether or not temporary matrices are needed
|
||||
auto a_no_temp = NoTempBuffer(a_one, a_one_i, a_two, a_two_i, a_ld, a_offset, a_do_transpose, a_conjugate);
|
||||
|
|
|
@ -25,9 +25,9 @@ class Xgemm: public Routine {
|
|||
public:
|
||||
|
||||
// Defines the assumptions of the GEMM kernels
|
||||
static const bool a_want_rotated_;
|
||||
static const bool b_want_rotated_;
|
||||
static const bool c_want_rotated_;
|
||||
static const bool a_want_rotated_(const size_t gemm_kernel_id) { return gemm_kernel_id == 1; }
|
||||
static const bool b_want_rotated_(const size_t gemm_kernel_id) { return true; }
|
||||
static const bool c_want_rotated_(const size_t gemm_kernel_id) { return gemm_kernel_id == 1; }
|
||||
|
||||
// Computes the size of the temporary GEMM buffer based on user-arguments
|
||||
static size_t GetTempSize(const Layout layout, const Transpose a_transpose, const Transpose b_transpose,
|
||||
|
@ -35,20 +35,23 @@ class Xgemm: public Routine {
|
|||
const size_t a_offset, const size_t a_ld,
|
||||
const size_t b_offset, const size_t b_ld,
|
||||
const size_t c_offset, const size_t c_ld,
|
||||
const size_t mwg, const size_t nwg, const size_t kwg) {
|
||||
const size_t mwg, const size_t nwg, const size_t kwg,
|
||||
const size_t gemm_kernel_id) {
|
||||
|
||||
// Computes the transpose/conjugate options and sets the a/b/c sizes based on that
|
||||
bool a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate;
|
||||
size_t a_one, a_two, b_one, b_two, c_one, c_two;
|
||||
ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
|
||||
a_one, a_two, b_one, b_two, c_one, c_two,
|
||||
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
|
||||
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
|
||||
gemm_kernel_id);
|
||||
|
||||
// Computes the first and second "internal" (ceiled) dimensions of the 3 matrices taking into account
|
||||
// whether the matrices need to be rotated or not for the kernel.
|
||||
size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i;
|
||||
CalculateInternalDimensions(m, n, k, mwg, nwg, kwg,
|
||||
a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i);
|
||||
a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i,
|
||||
gemm_kernel_id);
|
||||
|
||||
// Determines whether or not temporary matrices are needed
|
||||
auto a_no_temp = NoTempBuffer(a_one, a_one_i, a_two, a_two_i, a_ld, a_offset, a_do_transpose, a_conjugate);
|
||||
|
@ -79,7 +82,8 @@ class Xgemm: public Routine {
|
|||
size_t& a_one, size_t& a_two, size_t& b_one,
|
||||
size_t& b_two, size_t& c_one, size_t& c_two,
|
||||
bool& a_do_transpose, bool& b_do_transpose, bool& c_do_transpose,
|
||||
bool& a_conjugate, bool& b_conjugate) {
|
||||
bool& a_conjugate, bool& b_conjugate,
|
||||
const size_t gemm_kernel_id) {
|
||||
|
||||
// Makes sure all dimensions are larger than zero
|
||||
if ((m == 0) || (n == 0) || (k == 0)) { throw BLASError(StatusCode::kInvalidDimension); }
|
||||
|
@ -94,9 +98,9 @@ class Xgemm: public Routine {
|
|||
const auto b_rotated = (layout == Layout::kColMajor && b_transpose != Transpose::kNo) ||
|
||||
(layout == Layout::kRowMajor && b_transpose == Transpose::kNo);
|
||||
const auto c_rotated = (layout == Layout::kRowMajor);
|
||||
a_do_transpose = a_rotated != a_want_rotated_;
|
||||
b_do_transpose = b_rotated != b_want_rotated_;
|
||||
c_do_transpose = c_rotated != c_want_rotated_;
|
||||
a_do_transpose = a_rotated != a_want_rotated_(gemm_kernel_id);
|
||||
b_do_transpose = b_rotated != b_want_rotated_(gemm_kernel_id);
|
||||
c_do_transpose = c_rotated != c_want_rotated_(gemm_kernel_id);
|
||||
|
||||
// In case of complex data-types, the transpose can also become a conjugate transpose
|
||||
a_conjugate = (a_transpose == Transpose::kConjugate);
|
||||
|
@ -136,16 +140,17 @@ class Xgemm: public Routine {
|
|||
static void CalculateInternalDimensions(const size_t m, const size_t n, const size_t k,
|
||||
const size_t mwg, const size_t nwg, const size_t kwg,
|
||||
size_t& a_one_i, size_t& a_two_i, size_t& b_one_i,
|
||||
size_t& b_two_i, size_t& c_one_i, size_t& c_two_i) {
|
||||
size_t& b_two_i, size_t& c_one_i, size_t& c_two_i,
|
||||
const size_t gemm_kernel_id) {
|
||||
const auto m_ceiled = Ceil(m, mwg);
|
||||
const auto n_ceiled = Ceil(n, nwg);
|
||||
const auto k_ceiled = Ceil(k, kwg);
|
||||
a_one_i = (a_want_rotated_) ? k_ceiled : m_ceiled;
|
||||
a_two_i = (a_want_rotated_) ? m_ceiled : k_ceiled;
|
||||
b_one_i = (b_want_rotated_) ? n_ceiled : k_ceiled;
|
||||
b_two_i = (b_want_rotated_) ? k_ceiled : n_ceiled;
|
||||
c_one_i = (c_want_rotated_) ? n_ceiled : m_ceiled;
|
||||
c_two_i = (c_want_rotated_) ? m_ceiled : n_ceiled;
|
||||
a_one_i = (a_want_rotated_(gemm_kernel_id)) ? k_ceiled : m_ceiled;
|
||||
a_two_i = (a_want_rotated_(gemm_kernel_id)) ? m_ceiled : k_ceiled;
|
||||
b_one_i = (b_want_rotated_(gemm_kernel_id)) ? n_ceiled : k_ceiled;
|
||||
b_two_i = (b_want_rotated_(gemm_kernel_id)) ? k_ceiled : n_ceiled;
|
||||
c_one_i = (c_want_rotated_(gemm_kernel_id)) ? n_ceiled : m_ceiled;
|
||||
c_two_i = (c_want_rotated_(gemm_kernel_id)) ? m_ceiled : n_ceiled;
|
||||
}
|
||||
|
||||
// Constructor
|
||||
|
|
|
@ -70,7 +70,8 @@ void XgemmBatched<T>::DoGemmBatched(const Layout layout, const Transpose a_trans
|
|||
size_t a_one, a_two, b_one, b_two, c_one, c_two;
|
||||
Xgemm<T>::ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
|
||||
a_one, a_two, b_one, b_two, c_one, c_two,
|
||||
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
|
||||
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
|
||||
db_["GEMMK"]);
|
||||
|
||||
// Tests the matrices for validity
|
||||
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
|
||||
|
@ -141,7 +142,8 @@ void XgemmBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n, const
|
|||
// whether the matrices need to be rotated or not for the kernel.
|
||||
size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i;
|
||||
Xgemm<T>::CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"],
|
||||
a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i);
|
||||
a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i,
|
||||
db_["GEMMK"]);
|
||||
|
||||
// Sets the "internal" offsets, i.e. the perfect offsets
|
||||
auto a_offsets_i = std::vector<int>(batch_count);
|
||||
|
|
|
@ -66,7 +66,8 @@ void XgemmStridedBatched<T>::DoGemmStridedBatched(const Layout layout, const Tra
|
|||
size_t a_one, a_two, b_one, b_two, c_one, c_two;
|
||||
Xgemm<T>::ProcessArguments(layout, a_transpose, b_transpose, m, n, k,
|
||||
a_one, a_two, b_one, b_two, c_one, c_two,
|
||||
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate);
|
||||
a_do_transpose, b_do_transpose, c_do_transpose, a_conjugate, b_conjugate,
|
||||
db_["GEMMK"]);
|
||||
|
||||
// Tests the matrices for validity
|
||||
for (auto batch = size_t{0}; batch < batch_count; ++batch) {
|
||||
|
@ -122,7 +123,8 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n,
|
|||
// whether the matrices need to be rotated or not for the kernel.
|
||||
size_t a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i;
|
||||
Xgemm<T>::CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"],
|
||||
a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i);
|
||||
a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i,
|
||||
db_["GEMMK"]);
|
||||
|
||||
// Determines whether or not temporary matrices are needed
|
||||
auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && !a_do_transpose && !a_conjugate;
|
||||
|
|
Loading…
Reference in New Issue