Added support for conjugate transpose in GEMV
parent
d7a0d970e0
commit
7e176ccac9
|
@ -58,7 +58,8 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
|
|||
const int a_rotated,
|
||||
const __global real* restrict agm, const int a_offset, const int a_ld,
|
||||
const __global real* restrict xgm, const int x_offset, const int x_inc,
|
||||
__global real* ygm, const int y_offset, const int y_inc) {
|
||||
__global real* ygm, const int y_offset, const int y_inc,
|
||||
const int do_conjugate) {
|
||||
|
||||
// Local memory for the vector X
|
||||
__local real xlm[WGS1];
|
||||
|
@ -95,14 +96,18 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
|
|||
#pragma unroll
|
||||
for (int kl=0; kl<WGS1; ++kl) {
|
||||
const int k = kwg + kl;
|
||||
MultiplyAdd(acc[w], xlm[kl], agm[gid + a_ld*k + a_offset]);
|
||||
real value = agm[gid + a_ld*k + a_offset];
|
||||
if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
|
||||
MultiplyAdd(acc[w], xlm[kl], value);
|
||||
}
|
||||
}
|
||||
else { // Transposed
|
||||
#pragma unroll
|
||||
for (int kl=0; kl<WGS1; ++kl) {
|
||||
const int k = kwg + kl;
|
||||
MultiplyAdd(acc[w], xlm[kl], agm[k + a_ld*gid + a_offset]);
|
||||
real value = agm[k + a_ld*gid + a_offset];
|
||||
if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
|
||||
MultiplyAdd(acc[w], xlm[kl], value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -122,13 +127,17 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
|
|||
if (a_rotated == 0) { // Not rotated
|
||||
#pragma unroll
|
||||
for (int k=n_floor; k<n; ++k) {
|
||||
MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], agm[gid + a_ld*k + a_offset]);
|
||||
real value = agm[gid + a_ld*k + a_offset];
|
||||
if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
|
||||
MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], value);
|
||||
}
|
||||
}
|
||||
else { // Transposed
|
||||
#pragma unroll
|
||||
for (int k=n_floor; k<n; ++k) {
|
||||
MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], agm[k + a_ld*gid + a_offset]);
|
||||
real value = agm[k + a_ld*gid + a_offset];
|
||||
if (do_conjugate == 1) { COMPLEX_CONJUGATE(value); }
|
||||
MultiplyAdd(acc[w], xgm[k*x_inc + x_offset], value);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -159,12 +168,14 @@ __kernel void Xgemv(const int m, const int n, const real alpha, const real beta,
|
|||
// --> 'a_offset' is 0
|
||||
// --> 'a_ld' is a multiple of VW2
|
||||
// --> 'a_rotated' is 0
|
||||
// --> 'do_conjugate' is 0
|
||||
__attribute__((reqd_work_group_size(WGS2, 1, 1)))
|
||||
__kernel void XgemvFast(const int m, const int n, const real alpha, const real beta,
|
||||
const int a_rotated,
|
||||
const __global realVF* restrict agm, const int a_offset, const int a_ld,
|
||||
const __global real* restrict xgm, const int x_offset, const int x_inc,
|
||||
__global real* ygm, const int y_offset, const int y_inc) {
|
||||
__global real* ygm, const int y_offset, const int y_inc,
|
||||
const int do_conjugate) {
|
||||
// Local memory for the vector X
|
||||
__local real xlm[WGS2];
|
||||
|
||||
|
@ -265,12 +276,14 @@ __kernel void XgemvFast(const int m, const int n, const real alpha, const real b
|
|||
// --> 'a_offset' is 0
|
||||
// --> 'a_ld' is a multiple of VW3
|
||||
// --> 'a_rotated' is 1
|
||||
// --> 'do_conjugate' is 0
|
||||
__attribute__((reqd_work_group_size(WGS3, 1, 1)))
|
||||
__kernel void XgemvFastRot(const int m, const int n, const real alpha, const real beta,
|
||||
const int a_rotated,
|
||||
const __global realVFR* restrict agm, const int a_offset, const int a_ld,
|
||||
const __global real* restrict xgm, const int x_offset, const int x_inc,
|
||||
__global real* ygm, const int y_offset, const int y_inc) {
|
||||
__global real* ygm, const int y_offset, const int y_inc,
|
||||
const int do_conjugate) {
|
||||
// Local memory for the vector X
|
||||
__local real xlm[WGS3];
|
||||
|
||||
|
|
|
@ -54,13 +54,16 @@ StatusCode Xgemv<T>::DoGemv(const Layout layout, const Transpose a_transpose,
|
|||
auto a_two = (a_altlayout) ? m : n;
|
||||
|
||||
// Swap m and n if the matrix is transposed
|
||||
auto a_transposed = (a_transpose == Transpose::kYes);
|
||||
auto a_transposed = (a_transpose != Transpose::kNo);
|
||||
auto m_real = (a_transposed) ? n : m;
|
||||
auto n_real = (a_transposed) ? m : n;
|
||||
|
||||
// Determines whether the kernel needs to perform rotated access ('^' is the XOR operator)
|
||||
auto a_rotated = a_transposed ^ a_altlayout;
|
||||
|
||||
// In case of complex data-types, the transpose can also become a conjugate transpose
|
||||
auto a_conjugate = (a_transpose == Transpose::kConjugate);
|
||||
|
||||
// Tests the matrix and the vectors for validity
|
||||
auto status = TestMatrixA(a_one, a_two, a_buffer, a_offset, a_ld, sizeof(T));
|
||||
if (ErrorIn(status)) { return status; }
|
||||
|
@ -70,11 +73,11 @@ StatusCode Xgemv<T>::DoGemv(const Layout layout, const Transpose a_transpose,
|
|||
if (ErrorIn(status)) { return status; }
|
||||
|
||||
// Determines whether or not the fast-version can be used
|
||||
bool use_fast_kernel = (a_offset == 0) && (a_rotated == 0) &&
|
||||
bool use_fast_kernel = (a_offset == 0) && (a_rotated == 0) && (a_conjugate == 0) &&
|
||||
IsMultiple(m, db_["WGS2"]*db_["WPT2"]) &&
|
||||
IsMultiple(n, db_["WGS2"]) &&
|
||||
IsMultiple(a_ld, db_["VW2"]);
|
||||
bool use_fast_kernel_rot = (a_offset == 0) && (a_rotated == 1) &&
|
||||
bool use_fast_kernel_rot = (a_offset == 0) && (a_rotated == 1) && (a_conjugate == 0) &&
|
||||
IsMultiple(m, db_["WGS3"]*db_["WPT3"]) &&
|
||||
IsMultiple(n, db_["WGS3"]) &&
|
||||
IsMultiple(a_ld, db_["VW3"]);
|
||||
|
@ -115,6 +118,7 @@ StatusCode Xgemv<T>::DoGemv(const Layout layout, const Transpose a_transpose,
|
|||
kernel.SetArgument(11, y_buffer());
|
||||
kernel.SetArgument(12, static_cast<int>(y_offset));
|
||||
kernel.SetArgument(13, static_cast<int>(y_inc));
|
||||
kernel.SetArgument(14, static_cast<int>(a_conjugate));
|
||||
|
||||
// Launches the kernel
|
||||
auto global = std::vector<size_t>{global_size};
|
||||
|
|
|
@ -90,6 +90,7 @@ void XgemvTune(const Arguments<T> &args, const size_t variation,
|
|||
tuner.AddArgumentOutput(y_vec);
|
||||
tuner.AddArgumentScalar(0);
|
||||
tuner.AddArgumentScalar(1);
|
||||
tuner.AddArgumentScalar(0); // Conjugate transpose
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
|
Loading…
Reference in New Issue