Added argument checking for the GEMM tuner: expects m/n to be multiples of MWG/NWG

pull/282/head
Cedric Nugteren 2018-03-30 10:23:33 +02:00
parent 7e69c422af
commit d86ff75fa5
1 changed files with 10 additions and 1 deletions

View File

@ -118,7 +118,16 @@ TunerSettings XgemmGetTunerSettings(const int V, const Arguments<T> &args) {
// Tests for valid arguments
template <typename T>
void XgemmTestValidArguments(const int, const Arguments<T> &) { }
void XgemmTestValidArguments(const int V, const Arguments<T> &args) {
const auto mwg_max = (V == 1) ? 64 : 128;
const auto nwg_max = (V == 1) ? 64 : 128;
if (!IsMultiple(args.m, mwg_max)) {
throw std::runtime_error("'Xgemm' kernel requires 'm' to be a multiple of MWG (max " + ToString(mwg_max) + ")");
}
if (!IsMultiple(args.n, nwg_max)) {
throw std::runtime_error("'Xgemm' kernel requires 'n' to be a multiple of NWG (max " + ToString(nwg_max) + ")");
}
}
std::vector<Constraint> XgemmSetConstraints(const int V) {
auto constraints = std::vector<Constraint>();
auto MultipleOfX = [] (std::vector<size_t> v) { return IsMultiple(v[0], v[1]); };