set the correct flop count for xgemm
parent
ce44c3adb5
commit
aec45ea637
|
@ -159,7 +159,13 @@ TunerSettings XgemmGetTunerSettings(const int V, const Arguments<T> &args) {
|
|||
}
|
||||
|
||||
// Describes how to compute the performance metrics
|
||||
settings.metric_amount = 2 * args.m * args.n * args.k;
|
||||
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||
// complex flops
|
||||
settings.metric_amount = args.m * args.n * (8 * args.k - 2);
|
||||
} else {
|
||||
// scalar flops
|
||||
settings.metric_amount = args.m * args.n * (2 * args.k - 1);
|
||||
}
|
||||
settings.performance_unit = "GFLOPS";
|
||||
|
||||
return settings;
|
||||
|
|
|
@ -193,7 +193,13 @@ class TestXgemm {
|
|||
|
||||
// Describes how to compute performance metrics
|
||||
static size_t GetFlops(const Arguments<T> &args) {
|
||||
return 2 * args.m * args.n * args.k;
|
||||
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||
// complex flops
|
||||
return args.m * args.n * (8 * args.k - 2);
|
||||
} else {
|
||||
// scalar flops
|
||||
return args.m * args.n * (2 * args.k - 1);
|
||||
}
|
||||
}
|
||||
static size_t GetBytes(const Arguments<T> &args) {
|
||||
return (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
|
||||
|
|
|
@ -169,7 +169,14 @@ class TestXsymm {
|
|||
|
||||
// Describes how to compute performance metrics
|
||||
static size_t GetFlops(const Arguments<T> &args) {
|
||||
return 2 * args.m * args.n * args.m;
|
||||
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||
// complex flops
|
||||
return 8 * args.m * args.n * args.m;
|
||||
} else {
|
||||
// scalar flops
|
||||
return 2 * args.m * args.n * args.m;
|
||||
}
|
||||
|
||||
}
|
||||
static size_t GetBytes(const Arguments<T> &args) {
|
||||
return (args.m*args.m + args.m*args.n + 2*args.m*args.n) * sizeof(T);
|
||||
|
|
|
@ -153,7 +153,13 @@ class TestXsyrk {
|
|||
|
||||
// Describes how to compute performance metrics
|
||||
static size_t GetFlops(const Arguments<T> &args) {
|
||||
return args.n * args.n * args.k;
|
||||
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||
// complex flops
|
||||
return 4 * args.n * args.n * args.k;
|
||||
} else {
|
||||
// scalar flops
|
||||
return args.n * args.n * args.k;
|
||||
}
|
||||
}
|
||||
static size_t GetBytes(const Arguments<T> &args) {
|
||||
return (args.n*args.k + args.n*args.n) * sizeof(T);
|
||||
|
|
|
@ -162,7 +162,13 @@ class TestXtrmm {
|
|||
// Describes how to compute performance metrics
|
||||
static size_t GetFlops(const Arguments<T> &args) {
|
||||
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
||||
return args.m * args.n * k;
|
||||
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||
// complex flops
|
||||
return 4 * args.m * args.n * k;
|
||||
} else {
|
||||
// scalar flops
|
||||
return args.m * args.n * k;
|
||||
}
|
||||
}
|
||||
static size_t GetBytes(const Arguments<T> &args) {
|
||||
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
||||
|
|
|
@ -173,7 +173,13 @@ class TestXtrsm {
|
|||
// Describes how to compute performance metrics
|
||||
static size_t GetFlops(const Arguments<T> &args) {
|
||||
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
||||
return args.m * args.n * k;
|
||||
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||
// complex flops
|
||||
return 4 * args.m * args.n * k;
|
||||
} else {
|
||||
// scalar flops
|
||||
return args.m * args.n * k;
|
||||
}
|
||||
}
|
||||
static size_t GetBytes(const Arguments<T> &args) {
|
||||
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
||||
|
|
|
@ -217,7 +217,13 @@ class TestXgemmBatched {
|
|||
|
||||
// Describes how to compute performance metrics
|
||||
static size_t GetFlops(const Arguments<T> &args) {
|
||||
return args.batch_count * (2 * args.m * args.n * args.k);
|
||||
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||
// complex flops
|
||||
return args.batch_count * args.m * args.n * (8 * args.k - 2);
|
||||
} else {
|
||||
// scalar flops
|
||||
return args.batch_count * args.m * args.n * (2 * args.k - 1);
|
||||
}
|
||||
}
|
||||
static size_t GetBytes(const Arguments<T> &args) {
|
||||
return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
|
||||
|
|
|
@ -204,7 +204,13 @@ public:
|
|||
|
||||
// Describes how to compute performance metrics
|
||||
static size_t GetFlops(const Arguments<T> &args) {
|
||||
return args.batch_count * (2 * args.m * args.n * args.k);
|
||||
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||
// complex flops
|
||||
return args.batch_count * args.m * args.n * (8 * args.k - 2);
|
||||
} else {
|
||||
// scalar flops
|
||||
return args.batch_count * args.m * args.n * (2 * args.k - 1);
|
||||
}
|
||||
}
|
||||
static size_t GetBytes(const Arguments<T> &args) {
|
||||
return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
|
||||
|
|
Loading…
Reference in New Issue