mirror of
https://github.com/CNugteren/CLBlast.git
synced 2024-07-04 21:36:57 +02:00
set the correct flop count for xgemm
This commit is contained in:
parent
ce44c3adb5
commit
aec45ea637
|
@ -159,7 +159,13 @@ TunerSettings XgemmGetTunerSettings(const int V, const Arguments<T> &args) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Describes how to compute the performance metrics
|
// Describes how to compute the performance metrics
|
||||||
settings.metric_amount = 2 * args.m * args.n * args.k;
|
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||||
|
// complex flops
|
||||||
|
settings.metric_amount = args.m * args.n * (8 * args.k - 2);
|
||||||
|
} else {
|
||||||
|
// scalar flops
|
||||||
|
settings.metric_amount = args.m * args.n * (2 * args.k - 1);
|
||||||
|
}
|
||||||
settings.performance_unit = "GFLOPS";
|
settings.performance_unit = "GFLOPS";
|
||||||
|
|
||||||
return settings;
|
return settings;
|
||||||
|
|
|
@ -193,7 +193,13 @@ class TestXgemm {
|
||||||
|
|
||||||
// Describes how to compute performance metrics
|
// Describes how to compute performance metrics
|
||||||
static size_t GetFlops(const Arguments<T> &args) {
|
static size_t GetFlops(const Arguments<T> &args) {
|
||||||
return 2 * args.m * args.n * args.k;
|
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||||
|
// complex flops
|
||||||
|
return args.m * args.n * (8 * args.k - 2);
|
||||||
|
} else {
|
||||||
|
// scalar flops
|
||||||
|
return args.m * args.n * (2 * args.k - 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
static size_t GetBytes(const Arguments<T> &args) {
|
static size_t GetBytes(const Arguments<T> &args) {
|
||||||
return (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
|
return (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
|
||||||
|
|
|
@ -169,7 +169,14 @@ class TestXsymm {
|
||||||
|
|
||||||
// Describes how to compute performance metrics
|
// Describes how to compute performance metrics
|
||||||
static size_t GetFlops(const Arguments<T> &args) {
|
static size_t GetFlops(const Arguments<T> &args) {
|
||||||
return 2 * args.m * args.n * args.m;
|
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||||
|
// complex flops
|
||||||
|
return 8 * args.m * args.n * args.m;
|
||||||
|
} else {
|
||||||
|
// scalar flops
|
||||||
|
return 2 * args.m * args.n * args.m;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
static size_t GetBytes(const Arguments<T> &args) {
|
static size_t GetBytes(const Arguments<T> &args) {
|
||||||
return (args.m*args.m + args.m*args.n + 2*args.m*args.n) * sizeof(T);
|
return (args.m*args.m + args.m*args.n + 2*args.m*args.n) * sizeof(T);
|
||||||
|
|
|
@ -153,7 +153,13 @@ class TestXsyrk {
|
||||||
|
|
||||||
// Describes how to compute performance metrics
|
// Describes how to compute performance metrics
|
||||||
static size_t GetFlops(const Arguments<T> &args) {
|
static size_t GetFlops(const Arguments<T> &args) {
|
||||||
return args.n * args.n * args.k;
|
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||||
|
// complex flops
|
||||||
|
return 4 * args.n * args.n * args.k;
|
||||||
|
} else {
|
||||||
|
// scalar flops
|
||||||
|
return args.n * args.n * args.k;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
static size_t GetBytes(const Arguments<T> &args) {
|
static size_t GetBytes(const Arguments<T> &args) {
|
||||||
return (args.n*args.k + args.n*args.n) * sizeof(T);
|
return (args.n*args.k + args.n*args.n) * sizeof(T);
|
||||||
|
|
|
@ -162,7 +162,13 @@ class TestXtrmm {
|
||||||
// Describes how to compute performance metrics
|
// Describes how to compute performance metrics
|
||||||
static size_t GetFlops(const Arguments<T> &args) {
|
static size_t GetFlops(const Arguments<T> &args) {
|
||||||
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
||||||
return args.m * args.n * k;
|
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||||
|
// complex flops
|
||||||
|
return 4 * args.m * args.n * k;
|
||||||
|
} else {
|
||||||
|
// scalar flops
|
||||||
|
return args.m * args.n * k;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
static size_t GetBytes(const Arguments<T> &args) {
|
static size_t GetBytes(const Arguments<T> &args) {
|
||||||
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
||||||
|
|
|
@ -173,7 +173,13 @@ class TestXtrsm {
|
||||||
// Describes how to compute performance metrics
|
// Describes how to compute performance metrics
|
||||||
static size_t GetFlops(const Arguments<T> &args) {
|
static size_t GetFlops(const Arguments<T> &args) {
|
||||||
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
||||||
return args.m * args.n * k;
|
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||||
|
// complex flops
|
||||||
|
return 4 * args.m * args.n * k;
|
||||||
|
} else {
|
||||||
|
// scalar flops
|
||||||
|
return args.m * args.n * k;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
static size_t GetBytes(const Arguments<T> &args) {
|
static size_t GetBytes(const Arguments<T> &args) {
|
||||||
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
auto k = (args.side == Side::kLeft) ? args.m : args.n;
|
||||||
|
|
|
@ -217,7 +217,13 @@ class TestXgemmBatched {
|
||||||
|
|
||||||
// Describes how to compute performance metrics
|
// Describes how to compute performance metrics
|
||||||
static size_t GetFlops(const Arguments<T> &args) {
|
static size_t GetFlops(const Arguments<T> &args) {
|
||||||
return args.batch_count * (2 * args.m * args.n * args.k);
|
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||||
|
// complex flops
|
||||||
|
return args.batch_count * args.m * args.n * (8 * args.k - 2);
|
||||||
|
} else {
|
||||||
|
// scalar flops
|
||||||
|
return args.batch_count * args.m * args.n * (2 * args.k - 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
static size_t GetBytes(const Arguments<T> &args) {
|
static size_t GetBytes(const Arguments<T> &args) {
|
||||||
return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
|
return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
|
||||||
|
|
|
@ -204,7 +204,13 @@ public:
|
||||||
|
|
||||||
// Describes how to compute performance metrics
|
// Describes how to compute performance metrics
|
||||||
static size_t GetFlops(const Arguments<T> &args) {
|
static size_t GetFlops(const Arguments<T> &args) {
|
||||||
return args.batch_count * (2 * args.m * args.n * args.k);
|
if((args.precision == Precision::kComplexSingle) || (args.precision == Precision::kComplexDouble)) {
|
||||||
|
// complex flops
|
||||||
|
return args.batch_count * args.m * args.n * (8 * args.k - 2);
|
||||||
|
} else {
|
||||||
|
// scalar flops
|
||||||
|
return args.batch_count * args.m * args.n * (2 * args.k - 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
static size_t GetBytes(const Arguments<T> &args) {
|
static size_t GetBytes(const Arguments<T> &args) {
|
||||||
return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
|
return args.batch_count * (args.m*args.k + args.k*args.n + 2*args.m*args.n) * sizeof(T);
|
||||||
|
|
Loading…
Reference in a new issue