Implemented the in-direct version of the strided-batched GEMM kernel
parent
13f0f6fc6e
commit
99a4df88a6
|
@ -10,6 +10,9 @@ Development (next version)
|
|||
- Improved compilation time by splitting the tuning database into multiple compilation units
|
||||
- Various minor fixes and enhancements
|
||||
- Added tuned parameters for various devices (see README)
|
||||
- Added a strided-batched (not part of the BLAS standard) routine, faster but less generic compared
|
||||
to the existing xGEMMBATCHED routines:
|
||||
* SGEMMSTRIDEDBATCHED/DGEMMSTRIDEDBATCHED/CGEMMSTRIDEDBATCHED/ZGEMMSTRIDEDBATCHED/HGEMMSTRIDEDBATCHED
|
||||
|
||||
Version 1.2.0
|
||||
- Fixed a bug in the TRSM/TRSV routines due to missing synchronisations after GEMM/GEMV calls
|
||||
|
|
|
@ -172,6 +172,45 @@ void CopyMatrixBatched(const int src_one, const int src_two,
|
|||
alpha, 0, 0, 0);
|
||||
}
|
||||
|
||||
#endif
|
||||
// =================================================================================================
|
||||
#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
|
||||
|
||||
// Strided-batched version of the above
|
||||
__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
|
||||
void CopyPadMatrixStridedBatched(const int src_one, const int src_two,
|
||||
const int src_ld, const int src_offset,
|
||||
const int src_stride, __global const real* restrict src,
|
||||
const int dest_one, const int dest_two,
|
||||
const int dest_ld, const int dest_offset,
|
||||
const int dest_stride, __global real* dest,
|
||||
const int do_conjugate) {
|
||||
const int batch = get_group_id(2);
|
||||
const int src_offset_batch = src_offset + src_stride * batch;
|
||||
const int dest_offset_batch = dest_offset + dest_stride * batch;
|
||||
real alpha; SetToOne(alpha);
|
||||
_CopyPadMatrix(src_one, src_two, src_ld, src_offset_batch, src,
|
||||
dest_one, dest_two, dest_ld, dest_offset_batch, dest,
|
||||
alpha, do_conjugate);
|
||||
}
|
||||
|
||||
// Strided-batched version of the above
|
||||
__kernel __attribute__((reqd_work_group_size(PAD_DIMX, PAD_DIMY, 1)))
|
||||
void CopyMatrixStridedBatched(const int src_one, const int src_two,
|
||||
const int src_ld, const int src_offset,
|
||||
const int src_stride, __global const real* restrict src,
|
||||
const int dest_one, const int dest_two,
|
||||
const int dest_ld, const int dest_offset,
|
||||
const int dest_stride, __global real* dest) {
|
||||
const int batch = get_group_id(2);
|
||||
const int src_offset_batch = src_offset + src_stride * batch;
|
||||
const int dest_offset_batch = dest_offset + dest_stride * batch;
|
||||
real alpha; SetToOne(alpha);
|
||||
_CopyMatrix(src_one, src_two, src_ld, src_offset_batch, src,
|
||||
dest_one, dest_two, dest_ld, dest_offset_batch, dest,
|
||||
alpha, 0, 0, 0);
|
||||
}
|
||||
|
||||
#endif
|
||||
// =================================================================================================
|
||||
|
||||
|
|
|
@ -229,6 +229,47 @@ void TransposeMatrixBatched(const int src_one, const int src_two,
|
|||
alpha, 0, 0, 0);
|
||||
}
|
||||
|
||||
#endif
|
||||
// =================================================================================================
|
||||
#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
|
||||
|
||||
// Strided-batched version of the above
|
||||
__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
|
||||
void TransposePadMatrixStridedBatched(const int src_one, const int src_two,
|
||||
const int src_ld, const int src_offset,
|
||||
const int src_stride, __global const real* restrict src,
|
||||
const int dest_one, const int dest_two,
|
||||
const int dest_ld, const int dest_offset,
|
||||
const int dest_stride, __global real* dest,
|
||||
const int do_conjugate) {
|
||||
const int batch = get_group_id(2);
|
||||
const int src_offset_batch = src_offset + src_stride * batch;
|
||||
const int dest_offset_batch = dest_offset + dest_stride * batch;
|
||||
real alpha; SetToOne(alpha);
|
||||
__local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
|
||||
_TransposePadMatrix(tile, src_one, src_two, src_ld, src_offset_batch, src,
|
||||
dest_one, dest_two, dest_ld, dest_offset_batch, dest,
|
||||
alpha, do_conjugate);
|
||||
}
|
||||
|
||||
// Strided-batched version of the above
|
||||
__kernel __attribute__((reqd_work_group_size(PADTRA_TILE, PADTRA_TILE, 1)))
|
||||
void TransposeMatrixStridedBatched(const int src_one, const int src_two,
|
||||
const int src_ld, const int src_offset,
|
||||
const int src_stride, __global const real* restrict src,
|
||||
const int dest_one, const int dest_two,
|
||||
const int dest_ld, const int dest_offset,
|
||||
const int dest_stride, __global real* dest) {
|
||||
const int batch = get_group_id(2);
|
||||
const int src_offset_batch = src_offset + src_stride * batch;
|
||||
const int dest_offset_batch = dest_offset + dest_stride * batch;
|
||||
real alpha; SetToOne(alpha);
|
||||
__local real tile[(PADTRA_WPT*PADTRA_TILE) * (PADTRA_WPT*PADTRA_TILE + PADTRA_PAD)];
|
||||
_TransposeMatrix(tile, src_one, src_two, src_ld, src_offset_batch, src,
|
||||
dest_one, dest_two, dest_ld, dest_offset_batch, dest,
|
||||
alpha, 0, 0, 0);
|
||||
}
|
||||
|
||||
#endif
|
||||
// =================================================================================================
|
||||
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
R"(
|
||||
|
||||
// =================================================================================================
|
||||
#if defined(ROUTINE_GEMMBATCHED)
|
||||
|
||||
// Main entry point of the kernel. This is the regular full version.
|
||||
__kernel __attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
|
||||
void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK,
|
||||
const __constant real_arg* arg_alphas,
|
||||
|
@ -58,6 +58,49 @@ void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK,
|
|||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
// =================================================================================================
|
||||
#if defined(ROUTINE_GEMMSTRIDEDBATCHED)
|
||||
|
||||
__kernel __attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
|
||||
void XgemmStridedBatched(const int kSizeM, const int kSizeN, const int kSizeK,
|
||||
const real_arg arg_alpha, const real_arg arg_beta,
|
||||
const __global realM* restrict agm, const int a_one, const int a_two,
|
||||
const __global realN* restrict bgm, const int b_one, const int b_two,
|
||||
__global realM* cgm, const int c_one, const int c_two) {
|
||||
const int batch = get_group_id(2);
|
||||
const real alpha = GetRealArg(arg_alpha);
|
||||
const real beta = GetRealArg(arg_beta);
|
||||
|
||||
// Sets the offsets
|
||||
const int a_offset = batch * a_one * a_two;
|
||||
const int b_offset = batch * b_one * b_two;
|
||||
const int c_offset = batch * c_one * c_two;
|
||||
const __global realM* restrict agm_ = &agm[a_offset / VWM];
|
||||
const __global realN* restrict bgm_ = &bgm[b_offset / VWN];
|
||||
__global realM* restrict cgm_ = &cgm[c_offset / VWM];
|
||||
|
||||
// Allocates workgroup-private memory (local memory)
|
||||
#if SA == 1
|
||||
__local realM alm[KWG * MWG/VWM];
|
||||
#endif
|
||||
#if SB == 1
|
||||
__local realN blm[KWG * NWG/VWN];
|
||||
#endif
|
||||
|
||||
// Computes the matrix-multiplication and stores the result in global memory
|
||||
#if SA == 1 && SB == 1
|
||||
XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm, blm);
|
||||
#elif SA == 1
|
||||
XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm);
|
||||
#elif SB == 1
|
||||
XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, blm);
|
||||
#else
|
||||
XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta);
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
// =================================================================================================
|
||||
|
||||
// End of the C++11 raw string literal
|
||||
|
|
|
@ -239,6 +239,72 @@ void PadCopyTransposeMatrixBatched(Queue &queue, const Device &device,
|
|||
}
|
||||
}
|
||||
|
||||
// Batched version of the above
|
||||
template <typename T>
|
||||
void PadCopyTransposeMatrixStridedBatched(Queue &queue, const Device &device,
|
||||
const Databases &db,
|
||||
EventPointer event, const std::vector<Event> &waitForEvents,
|
||||
const size_t src_one, const size_t src_two,
|
||||
const size_t src_ld, const size_t src_offset,
|
||||
const size_t src_stride, const Buffer<T> &src,
|
||||
const size_t dest_one, const size_t dest_two,
|
||||
const size_t dest_ld, const size_t dest_offset,
|
||||
const size_t dest_stride, const Buffer<T> &dest,
|
||||
const Program &program, const bool do_pad,
|
||||
const bool do_transpose, const bool do_conjugate,
|
||||
const size_t batch_count) {
|
||||
|
||||
// Determines the right kernel
|
||||
auto kernel_name = std::string{};
|
||||
if (do_transpose) {
|
||||
kernel_name = (do_pad) ? "TransposePadMatrixStridedBatched" : "TransposeMatrixStridedBatched";
|
||||
}
|
||||
else {
|
||||
kernel_name = (do_pad) ? "CopyPadMatrixStridedBatched" : "CopyMatrixStridedBatched";
|
||||
}
|
||||
|
||||
// Retrieves the kernel from the compiled binary
|
||||
auto kernel = Kernel(program, kernel_name);
|
||||
|
||||
// Sets the kernel arguments
|
||||
kernel.SetArgument(0, static_cast<int>(src_one));
|
||||
kernel.SetArgument(1, static_cast<int>(src_two));
|
||||
kernel.SetArgument(2, static_cast<int>(src_ld));
|
||||
kernel.SetArgument(3, static_cast<int>(src_offset));
|
||||
kernel.SetArgument(4, static_cast<int>(src_stride));
|
||||
kernel.SetArgument(5, src());
|
||||
kernel.SetArgument(6, static_cast<int>(dest_one));
|
||||
kernel.SetArgument(7, static_cast<int>(dest_two));
|
||||
kernel.SetArgument(8, static_cast<int>(dest_ld));
|
||||
kernel.SetArgument(9, static_cast<int>(dest_offset));
|
||||
kernel.SetArgument(10, static_cast<int>(dest_stride));
|
||||
kernel.SetArgument(11, dest());
|
||||
if (do_pad) {
|
||||
kernel.SetArgument(12, static_cast<int>(do_conjugate));
|
||||
}
|
||||
|
||||
// Launches the kernel and returns the error code. Uses global and local thread sizes based on
|
||||
// parameters in the database.
|
||||
if (do_transpose) {
|
||||
const auto global = std::vector<size_t>{
|
||||
Ceil(CeilDiv(dest_one, db["PADTRA_WPT"]), db["PADTRA_TILE"]),
|
||||
Ceil(CeilDiv(dest_two, db["PADTRA_WPT"]), db["PADTRA_TILE"]),
|
||||
batch_count
|
||||
};
|
||||
const auto local = std::vector<size_t>{db["PADTRA_TILE"], db["PADTRA_TILE"], 1};
|
||||
RunKernel(kernel, queue, device, global, local, event, waitForEvents);
|
||||
}
|
||||
else {
|
||||
const auto global = std::vector<size_t>{
|
||||
Ceil(CeilDiv(dest_one, db["PAD_WPTX"]), db["PAD_DIMX"]),
|
||||
Ceil(CeilDiv(dest_two, db["PAD_WPTY"]), db["PAD_DIMY"]),
|
||||
batch_count
|
||||
};
|
||||
const auto local = std::vector<size_t>{db["PAD_DIMX"], db["PAD_DIMY"], 1};
|
||||
RunKernel(kernel, queue, device, global, local, event, waitForEvents);
|
||||
}
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
||||
|
||||
|
|
|
@ -112,7 +112,7 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n,
|
|||
const size_t b_one, const size_t b_two,
|
||||
const size_t c_one, const size_t c_two,
|
||||
const size_t batch_count) {
|
||||
/* TODO
|
||||
|
||||
// Calculates the ceiled versions of m, n, and k
|
||||
const auto m_ceiled = Ceil(Ceil(m, db_["MWG"]), db_["VWM"]);
|
||||
const auto n_ceiled = Ceil(Ceil(n, db_["NWG"]), db_["VWN"]);
|
||||
|
@ -124,18 +124,10 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n,
|
|||
Xgemm<T>::CalculateInternalDimensions(m, n, k, db_["MWG"], db_["NWG"], db_["KWG"],
|
||||
a_one_i, a_two_i, b_one_i, b_two_i, c_one_i, c_two_i);
|
||||
|
||||
// Sets the "internal" offsets, i.e. the perfect offsets
|
||||
auto a_offsets_i = 0;//std::vector<int>(batch_count);
|
||||
auto b_offsets_i = 0;//std::vector<int>(batch_count);
|
||||
auto c_offsets_i = 0;//std::vector<int>(batch_count);
|
||||
|
||||
// Determines whether or not temporary matrices are needed
|
||||
auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && a_offsets == a_offsets_i &&
|
||||
!a_do_transpose && !a_conjugate;
|
||||
auto b_no_temp = b_one == b_one_i && b_two == b_two_i && b_ld == b_one && b_offsets == b_offsets_i &&
|
||||
!b_do_transpose && !b_conjugate;
|
||||
auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && c_offsets == c_offsets_i &&
|
||||
!c_do_transpose;
|
||||
auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && !a_do_transpose && !a_conjugate;
|
||||
auto b_no_temp = b_one == b_one_i && b_two == b_two_i && b_ld == b_one && !b_do_transpose && !b_conjugate;
|
||||
auto c_no_temp = c_one == c_one_i && c_two == c_two_i && c_ld == c_one && !c_do_transpose;
|
||||
|
||||
// Creates the temporary matrices
|
||||
const auto a_temp = (a_no_temp) ? a_buffer : Buffer<T>(context_, batch_count * a_one_i * a_two_i);
|
||||
|
@ -150,43 +142,31 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n,
|
|||
// to fill it up until it reaches a certain multiple of size (kernel parameter dependent). In
|
||||
// case nothing has to be done, these kernels can be skipped.
|
||||
if (!a_no_temp) {
|
||||
auto a_offsets_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
|
||||
auto a_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
|
||||
a_offsets_device.Write(queue_, batch_count, a_offsets);
|
||||
a_offsets_i_device.Write(queue_, batch_count, a_offsets_i);
|
||||
auto eventProcessA = Event();
|
||||
PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessA.pointer(), emptyEventList,
|
||||
a_one, a_two, a_ld, a_offsets_device, a_buffer,
|
||||
a_one_i, a_two_i, a_one_i, a_offsets_i_device, a_temp,
|
||||
program_, true, a_do_transpose, a_conjugate, batch_count);
|
||||
PadCopyTransposeMatrixStridedBatched(queue_, device_, db_, eventProcessA.pointer(), emptyEventList,
|
||||
a_one, a_two, a_ld, a_offset, a_stride, a_buffer,
|
||||
a_one_i, a_two_i, a_one_i, 0, a_one_i * a_two_i, a_temp,
|
||||
program_, true, a_do_transpose, a_conjugate, batch_count);
|
||||
eventWaitList.push_back(eventProcessA);
|
||||
}
|
||||
|
||||
// As above, but now for matrix B
|
||||
if (!b_no_temp) {
|
||||
auto b_offsets_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
|
||||
auto b_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
|
||||
b_offsets_device.Write(queue_, batch_count, b_offsets);
|
||||
b_offsets_i_device.Write(queue_, batch_count, b_offsets_i);
|
||||
auto eventProcessB = Event();
|
||||
PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessB.pointer(), emptyEventList,
|
||||
b_one, b_two, b_ld, b_offsets_device, b_buffer,
|
||||
b_one_i, b_two_i, b_one_i, b_offsets_i_device, b_temp,
|
||||
program_, true, b_do_transpose, b_conjugate, batch_count);
|
||||
PadCopyTransposeMatrixStridedBatched(queue_, device_, db_, eventProcessB.pointer(), emptyEventList,
|
||||
b_one, b_two, b_ld, b_offset, b_stride, b_buffer,
|
||||
b_one_i, b_two_i, b_one_i, 0, b_one_i * b_two_i, b_temp,
|
||||
program_, true, b_do_transpose, b_conjugate, batch_count);
|
||||
eventWaitList.push_back(eventProcessB);
|
||||
}
|
||||
|
||||
// As above, but now for matrix C
|
||||
auto c_offsets_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
|
||||
auto c_offsets_i_device = Buffer<int>(context_, BufferAccess::kReadWrite, batch_count);
|
||||
if (!c_no_temp) {
|
||||
c_offsets_device.Write(queue_, batch_count, c_offsets);
|
||||
c_offsets_i_device.Write(queue_, batch_count, c_offsets_i);
|
||||
auto eventProcessC = Event();
|
||||
PadCopyTransposeMatrixBatched(queue_, device_, db_, eventProcessC.pointer(), emptyEventList,
|
||||
c_one, c_two, c_ld, c_offsets_device, c_buffer,
|
||||
c_one_i, c_two_i, c_one_i, c_offsets_i_device, c_temp,
|
||||
program_, true, c_do_transpose, false, batch_count);
|
||||
PadCopyTransposeMatrixStridedBatched(queue_, device_, db_, eventProcessC.pointer(), emptyEventList,
|
||||
c_one, c_two, c_ld, c_offset, c_stride, c_buffer,
|
||||
c_one_i, c_two_i, c_one_i, 0, c_one_i * c_two_i, c_temp,
|
||||
program_, true, c_do_transpose, false, batch_count);
|
||||
eventWaitList.push_back(eventProcessC);
|
||||
}
|
||||
|
||||
|
@ -197,8 +177,8 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n,
|
|||
kernel.SetArgument(0, static_cast<int>(m_ceiled));
|
||||
kernel.SetArgument(1, static_cast<int>(n_ceiled));
|
||||
kernel.SetArgument(2, static_cast<int>(k_ceiled));
|
||||
kernel.SetArgument(3, alpha);
|
||||
kernel.SetArgument(4, beta);
|
||||
kernel.SetArgument(3, GetRealArg(alpha));
|
||||
kernel.SetArgument(4, GetRealArg(beta));
|
||||
kernel.SetArgument(5, a_temp());
|
||||
kernel.SetArgument(6, static_cast<int>(a_one_i));
|
||||
kernel.SetArgument(7, static_cast<int>(a_two_i));
|
||||
|
@ -225,12 +205,11 @@ void XgemmStridedBatched<T>::BatchedGemmIndirect(const size_t m, const size_t n,
|
|||
// Runs the post-processing kernel if needed
|
||||
if (!c_no_temp) {
|
||||
eventWaitList.push_back(eventKernel);
|
||||
PadCopyTransposeMatrixBatched(queue_, device_, db_, event_, eventWaitList,
|
||||
c_one_i, c_two_i, c_one_i, c_offsets_i_device, c_temp,
|
||||
c_one, c_two, c_ld, c_offsets_device, c_buffer,
|
||||
program_, false, c_do_transpose, false, batch_count);
|
||||
PadCopyTransposeMatrixStridedBatched(queue_, device_, db_, event_, eventWaitList,
|
||||
c_one_i, c_two_i, c_one_i, 0, c_one_i * c_two_i, c_temp,
|
||||
c_one, c_two, c_ld, c_offset, c_stride, c_buffer,
|
||||
program_, false, c_do_transpose, false, batch_count);
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
|
Loading…
Reference in New Issue