Added argument checking for the GEMM tuner: expects m/n to be multiples of MWG/NWG
parent
7e69c422af
commit
d86ff75fa5
|
@ -118,7 +118,16 @@ TunerSettings XgemmGetTunerSettings(const int V, const Arguments<T> &args) {
|
|||
|
||||
// Tests for valid arguments
|
||||
template <typename T>
|
||||
void XgemmTestValidArguments(const int, const Arguments<T> &) { }
|
||||
void XgemmTestValidArguments(const int V, const Arguments<T> &args) {
|
||||
const auto mwg_max = (V == 1) ? 64 : 128;
|
||||
const auto nwg_max = (V == 1) ? 64 : 128;
|
||||
if (!IsMultiple(args.m, mwg_max)) {
|
||||
throw std::runtime_error("'Xgemm' kernel requires 'm' to be a multiple of MWG (max " + ToString(mwg_max) + ")");
|
||||
}
|
||||
if (!IsMultiple(args.n, nwg_max)) {
|
||||
throw std::runtime_error("'Xgemm' kernel requires 'n' to be a multiple of NWG (max " + ToString(nwg_max) + ")");
|
||||
}
|
||||
}
|
||||
std::vector<Constraint> XgemmSetConstraints(const int V) {
|
||||
auto constraints = std::vector<Constraint>();
|
||||
auto MultipleOfX = [] (std::vector<size_t> v) { return IsMultiple(v[0], v[1]); };
|
||||
|
|
Loading…
Reference in New Issue