set the correct flop count for xgemm

This commit is contained in:
JishinMaster 2021-03-07 21:44:20 +01:00
parent ce44c3adb5
commit aec45ea637
8 changed files with 57 additions and 8 deletions

View file

@ -159,7 +159,13 @@ TunerSettings XgemmGetTunerSettings(const int V, const Arguments<T> &args) {
} }
// Describes how to compute the performance metrics // Describes how to compute the performance metrics
settings.metric_amount = 2 * args.m * args.n * args.k; if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
settings.metric_amount = args.m * args.n * (8 * args.k - 2);
} else {
// scalar flops
settings.metric_amount = args.m * args.n * (2 * args.k - 1);
}
settings.performance_unit = "GFLOPS"; settings.performance_unit = "GFLOPS";
return settings; return settings;

View file

@ -193,7 +193,13 @@ class TestXgemm {
// Describes how to compute performance metrics // Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) { static size_t GetFlops(const Arguments<T> &args) {
return 2 * args.m * args.n * args.k; if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
return args.m * args.n * (8 * args.k - 2);
} else {
// scalar flops
return args.m * args.n * (2 * args.k - 1);
}
} }
static size_t GetBytes(const Arguments<T> &args) { static size_t GetBytes(const Arguments<T> &args) {
return (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T); return (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);

View file

@ -169,8 +169,15 @@ class TestXsymm {
// Describes how to compute performance metrics // Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) { static size_t GetFlops(const Arguments<T> &args) {
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
return 8 * args.m * args.n * args.m;
} else {
// scalar flops
return 2 * args.m * args.n * args.m; return 2 * args.m * args.n * args.m;
} }
}
static size_t GetBytes(const Arguments<T> &args) { static size_t GetBytes(const Arguments<T> &args) {
return (args.m*args.m + args.m*args.n + 2*args.m*args.n) * sizeof(T); return (args.m*args.m + args.m*args.n + 2*args.m*args.n) * sizeof(T);
} }

View file

@ -153,8 +153,14 @@ class TestXsyrk {
// Describes how to compute performance metrics // Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) { static size_t GetFlops(const Arguments<T> &args) {
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
return 4 * args.n * args.n * args.k;
} else {
// scalar flops
return args.n * args.n * args.k; return args.n * args.n * args.k;
} }
}
static size_t GetBytes(const Arguments<T> &args) { static size_t GetBytes(const Arguments<T> &args) {
return (args.n*args.k + args.n*args.n) * sizeof(T); return (args.n*args.k + args.n*args.n) * sizeof(T);
} }

View file

@ -162,8 +162,14 @@ class TestXtrmm {
// Describes how to compute performance metrics // Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) { static size_t GetFlops(const Arguments<T> &args) {
auto k = (args.side == Side::kLeft) ? args.m : args.n; auto k = (args.side == Side::kLeft) ? args.m : args.n;
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
return 4 * args.m * args.n * k;
} else {
// scalar flops
return args.m * args.n * k; return args.m * args.n * k;
} }
}
static size_t GetBytes(const Arguments<T> &args) { static size_t GetBytes(const Arguments<T> &args) {
auto k = (args.side == Side::kLeft) ? args.m : args.n; auto k = (args.side == Side::kLeft) ? args.m : args.n;
return (k*k + 2*args.m*args.n) * sizeof(T); return (k*k + 2*args.m*args.n) * sizeof(T);

View file

@ -173,8 +173,14 @@ class TestXtrsm {
// Describes how to compute performance metrics // Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) { static size_t GetFlops(const Arguments<T> &args) {
auto k = (args.side == Side::kLeft) ? args.m : args.n; auto k = (args.side == Side::kLeft) ? args.m : args.n;
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
return 4 * args.m * args.n * k;
} else {
// scalar flops
return args.m * args.n * k; return args.m * args.n * k;
} }
}
static size_t GetBytes(const Arguments<T> &args) { static size_t GetBytes(const Arguments<T> &args) {
auto k = (args.side == Side::kLeft) ? args.m : args.n; auto k = (args.side == Side::kLeft) ? args.m : args.n;
return (k*k + 2*args.m*args.n) * sizeof(T); return (k*k + 2*args.m*args.n) * sizeof(T);

View file

@ -217,7 +217,13 @@ class TestXgemmBatched {
// Describes how to compute performance metrics // Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) { static size_t GetFlops(const Arguments<T> &args) {
return args.batch_count * (2 * args.m * args.n * args.k); if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
return args.batch_count * args.m * args.n * (8 * args.k - 2);
} else {
// scalar flops
return args.batch_count * args.m * args.n * (2 * args.k - 1);
}
} }
static size_t GetBytes(const Arguments<T> &args) { static size_t GetBytes(const Arguments<T> &args) {
return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T); return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);

View file

@ -204,7 +204,13 @@ public:
// Describes how to compute performance metrics // Describes how to compute performance metrics
static size_t GetFlops(const Arguments<T> &args) { static size_t GetFlops(const Arguments<T> &args) {
return args.batch_count * (2 * args.m * args.n * args.k); if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
// complex flops
return args.batch_count * args.m * args.n * (8 * args.k - 2);
} else {
// scalar flops
return args.batch_count * args.m * args.n * (2 * args.k - 1);
}
} }
static size_t GetBytes(const Arguments<T> &args) { static size_t GetBytes(const Arguments<T> &args) {
return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T); return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);