Fixed an issue where complex conjugation was not performed in certain cases when transposing

pull/309/head
Cedric Nugteren 2018-07-31 21:49:37 +02:00
parent 391e5757bd
commit 503ab74f02
2 changed files with 7 additions and 3 deletions

View File

@ -2,6 +2,7 @@
Development (next version) Development (next version)
- Added support for shuffle instructions for NVIDIA GPUs (thanks to 'tyler-utah') - Added support for shuffle instructions for NVIDIA GPUs (thanks to 'tyler-utah')
- The tuners now check beforehand on invalid local thread sizes and skip those completely - The tuners now check beforehand on invalid local thread sizes and skip those completely
- Fixed an issue with conjugate transpose not being executed in certain cases for, among others, XOMATCOPY
- Fixed an issue with AMD GPUs and the new GEMMK == 1 kernel - Fixed an issue with AMD GPUs and the new GEMMK == 1 kernel
- Various minor fixes and enhancements - Various minor fixes and enhancements

View File

@ -76,6 +76,7 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device,
// Determines the right kernel // Determines the right kernel
auto kernel_name = std::string{}; auto kernel_name = std::string{};
auto pad_kernel = false;
if (do_transpose) { if (do_transpose) {
if (use_fast_kernel && if (use_fast_kernel &&
IsMultiple(src_ld, db["TRA_WPT"]) && IsMultiple(src_ld, db["TRA_WPT"]) &&
@ -85,7 +86,8 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device,
} }
else { else {
use_fast_kernel = false; use_fast_kernel = false;
kernel_name = (do_pad) ? "TransposePadMatrix" : "TransposeMatrix"; pad_kernel = (do_pad || do_conjugate);
kernel_name = (pad_kernel) ? "TransposePadMatrix" : "TransposeMatrix";
} }
} }
else { else {
@ -97,7 +99,8 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device,
} }
else { else {
use_fast_kernel = false; use_fast_kernel = false;
kernel_name = (do_pad) ? "CopyPadMatrix" : "CopyMatrix"; pad_kernel = do_pad;
kernel_name = (pad_kernel) ? "CopyPadMatrix" : "CopyMatrix";
} }
} }
@ -123,7 +126,7 @@ void PadCopyTransposeMatrix(Queue &queue, const Device &device,
kernel.SetArgument(8, static_cast<int>(dest_offset)); kernel.SetArgument(8, static_cast<int>(dest_offset));
kernel.SetArgument(9, dest()); kernel.SetArgument(9, dest());
kernel.SetArgument(10, GetRealArg(alpha)); kernel.SetArgument(10, GetRealArg(alpha));
if (do_pad) { if (pad_kernel) {
kernel.SetArgument(11, static_cast<int>(do_conjugate)); kernel.SetArgument(11, static_cast<int>(do_conjugate));
} }
else { else {