set the correct flop count for xgemm

pull/416/head
JishinMaster 2021-03-07 21:44:20 +01:00
parent ce44c3adb5
commit aec45ea637
8 changed files with 57 additions and 8 deletions

View File

@ -159,7 +159,13 @@ TunerSettings XgemmGetTunerSettings(const int V, const Arguments<T> &args) {
}
// Describes how to compute the performance metrics
settings.metric_amount = 2 * args.m * args.n * args.k;
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
settings.metric_amount = args.m * args.n * (8 * args.k - 2);
} else {
// scalar flops
settings.metric_amount = args.m * args.n * (2 * args.k - 1);
}
settings.performance_unit = "GFLOPS";
return settings;

View File

@ -193,7 +193,13 @@ class TestXgemm {
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
return 2 * args.m * args.n * args.k;
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
return args.m * args.n * (8 * args.k - 2);
} else {
// scalar flops
return args.m * args.n * (2 * args.k - 1);
}
}
static size_t GetBytes(const Arguments<T> &args) {
return (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);

View File

@ -169,7 +169,14 @@ class TestXsymm {
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
return 2 * args.m * args.n * args.m;
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
return 8 * args.m * args.n * args.m;
} else {
// scalar flops
return 2 * args.m * args.n * args.m;
}
}
static size_t GetBytes(const Arguments<T> &args) {
return (args.m*args.m + args.m*args.n + 2*args.m*args.n) * sizeof(T);

View File

@ -153,7 +153,13 @@ class TestXsyrk {
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
return args.n * args.n * args.k;
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
return 4 * args.n * args.n * args.k;
} else {
// scalar flops
return args.n * args.n * args.k;
}
}
static size_t GetBytes(const Arguments<T> &args) {
return (args.n*args.k + args.n*args.n) * sizeof(T);

View File

@ -162,7 +162,13 @@ class TestXtrmm {
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
auto k = (args.side == Side::kLeft) ? args.m : args.n;
return args.m * args.n * k;
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
return 4 * args.m * args.n * k;
} else {
// scalar flops
return args.m * args.n * k;
}
}
static size_t GetBytes(const Arguments<T> &args) {
auto k = (args.side == Side::kLeft) ? args.m : args.n;

View File

@ -173,7 +173,13 @@ class TestXtrsm {
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
auto k = (args.side == Side::kLeft) ? args.m : args.n;
return args.m * args.n * k;
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
return 4 * args.m * args.n * k;
} else {
// scalar flops
return args.m * args.n * k;
}
}
static size_t GetBytes(const Arguments<T> &args) {
auto k = (args.side == Side::kLeft) ? args.m : args.n;

View File

@ -217,7 +217,13 @@ class TestXgemmBatched {
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
return args.batch_count * (2 * args.m * args.n * args.k);
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
return args.batch_count * args.m * args.n * (8 * args.k - 2);
} else {
// scalar flops
return args.batch_count * args.m * args.n * (2 * args.k - 1);
}
}
static size_t GetBytes(const Arguments<T> &args) {
return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);

View File

@ -204,7 +204,13 @@ public:
// Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) {
return args.batch_count * (2 * args.m * args.n * args.k);
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
return args.batch_count * args.m * args.n * (8 * args.k - 2);
} else {
// scalar flops
return args.batch_count * args.m * args.n * (2 * args.k - 1);
}
}
static size_t GetBytes(const Arguments<T> &args) {
return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);