Added symmetric matrix support for the ABC performance tester

This commit is contained in:
CNugteren 2015-06-26 08:10:23 +02:00
parent ff9f9fac57
commit 75f263ce3a
4 changed files with 16 additions and 15 deletions

View file

@ -239,14 +239,15 @@ template void ClientAC<double2>(int, char **, Routine2<double2>, const std::vect
// This is the matrix-matrix-matrix variant of the set-up/tear-down client routine. // This is the matrix-matrix-matrix variant of the set-up/tear-down client routine.
template <typename T> template <typename T>
void ClientABC(int argc, char *argv[], Routine3<T> client_routine, void ClientABC(int argc, char *argv[], Routine3<T> client_routine,
const std::vector<std::string> &options) { const std::vector<std::string> &options, const bool symmetric) {
// Function to determine how to find the default value of the leading dimension of matrix A // Function to determine how to find the default value of the leading dimension of matrix A
auto default_ld_a = [](const Arguments<T> args) { return args.m; }; auto default_ld_a = [&symmetric](const Arguments<T> args) { return (symmetric) ? args.n : args.m; };
// Simple command line argument parser with defaults // Simple command line argument parser with defaults
auto args = ParseArguments<T>(argc, argv, options, default_ld_a); auto args = ParseArguments<T>(argc, argv, options, default_ld_a);
if (args.print_help) { return; } if (args.print_help) { return; }
if (symmetric) { args.m = args.n; }
// Prints the header of the output table // Prints the header of the output table
PrintTableHeader(args.silent, options); PrintTableHeader(args.silent, options);
@ -314,10 +315,10 @@ void ClientABC(int argc, char *argv[], Routine3<T> client_routine,
} }
// Compiles the above function // Compiles the above function
template void ClientABC<float>(int, char **, Routine3<float>, const std::vector<std::string>&); template void ClientABC<float>(int, char **, Routine3<float>, const std::vector<std::string>&, const bool);
template void ClientABC<double>(int, char **, Routine3<double>, const std::vector<std::string>&); template void ClientABC<double>(int, char **, Routine3<double>, const std::vector<std::string>&, const bool);
template void ClientABC<float2>(int, char **, Routine3<float2>, const std::vector<std::string>&); template void ClientABC<float2>(int, char **, Routine3<float2>, const std::vector<std::string>&, const bool);
template void ClientABC<double2>(int, char **, Routine3<double2>, const std::vector<std::string>&); template void ClientABC<double2>(int, char **, Routine3<double2>, const std::vector<std::string>&, const bool);
// ================================================================================================= // =================================================================================================

View file

@ -56,7 +56,7 @@ void ClientAC(int argc, char *argv[], Routine2<T> client_routine,
const std::vector<std::string> &options); const std::vector<std::string> &options);
template <typename T> template <typename T>
void ClientABC(int argc, char *argv[], Routine3<T> client_routine, void ClientABC(int argc, char *argv[], Routine3<T> client_routine,
const std::vector<std::string> &options); const std::vector<std::string> &options, const bool symmetric);
// ================================================================================================= // =================================================================================================

View file

@ -96,10 +96,10 @@ void ClientXgemm(int argc, char *argv[]) {
kArgAlpha, kArgBeta}; kArgAlpha, kArgBeta};
switch(GetPrecision(argc, argv)) { switch(GetPrecision(argc, argv)) {
case Precision::kHalf: throw std::runtime_error("Unsupported precision mode"); case Precision::kHalf: throw std::runtime_error("Unsupported precision mode");
case Precision::kSingle: ClientABC<float>(argc, argv, PerformanceXgemm<float>, o); break; case Precision::kSingle: ClientABC<float>(argc, argv, PerformanceXgemm<float>, o, false); break;
case Precision::kDouble: ClientABC<double>(argc, argv, PerformanceXgemm<double>, o); break; case Precision::kDouble: ClientABC<double>(argc, argv, PerformanceXgemm<double>, o, false); break;
case Precision::kComplexSingle: ClientABC<float2>(argc, argv, PerformanceXgemm<float2>, o); break; case Precision::kComplexSingle: ClientABC<float2>(argc, argv, PerformanceXgemm<float2>, o, false); break;
case Precision::kComplexDouble: ClientABC<double2>(argc, argv, PerformanceXgemm<double2>, o); break; case Precision::kComplexDouble: ClientABC<double2>(argc, argv, PerformanceXgemm<double2>, o, false); break;
} }
} }

View file

@ -96,10 +96,10 @@ void ClientXsymm(int argc, char *argv[]) {
kArgAlpha, kArgBeta}; kArgAlpha, kArgBeta};
switch(GetPrecision(argc, argv)) { switch(GetPrecision(argc, argv)) {
case Precision::kHalf: throw std::runtime_error("Unsupported precision mode"); case Precision::kHalf: throw std::runtime_error("Unsupported precision mode");
case Precision::kSingle: ClientABC<float>(argc, argv, PerformanceXsymm<float>, o); break; case Precision::kSingle: ClientABC<float>(argc, argv, PerformanceXsymm<float>, o, false); break;
case Precision::kDouble: ClientABC<double>(argc, argv, PerformanceXsymm<double>, o); break; case Precision::kDouble: ClientABC<double>(argc, argv, PerformanceXsymm<double>, o, false); break;
case Precision::kComplexSingle: ClientABC<float2>(argc, argv, PerformanceXsymm<float2>, o); break; case Precision::kComplexSingle: ClientABC<float2>(argc, argv, PerformanceXsymm<float2>, o, false); break;
case Precision::kComplexDouble: ClientABC<double2>(argc, argv, PerformanceXsymm<double2>, o); break; case Precision::kComplexDouble: ClientABC<double2>(argc, argv, PerformanceXsymm<double2>, o, false); break;
} }
} }