Merge branch 'development' into benchmarking

pull/148/head
Cedric Nugteren 2017-04-21 21:59:48 +02:00
commit 957aaae6ca
2 changed files with 51 additions and 0 deletions

View File

@ -29,6 +29,40 @@ VENDOR_TRANSLATION_TABLE = {
}
def remove_mismatched_arguments(database):
"""Checks for tuning results with mis-matched entries and removes them according to user preferences"""
kernel_attributes = clblast.DEVICE_TYPE_ATTRIBUTES + clblast.KERNEL_ATTRIBUTES + ["kernel"]
# For Python 2 and 3 compatibility
try:
user_input = raw_input
except NameError:
user_input = input
pass
# Check for mis-matched entries
for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
if len(group_by_arguments) != 1:
print("[database] WARNING: entries for a single kernel with multiple argument values " + str(kernel_group_name))
print("[database] Either quit now, or remove all but one of the argument combinations below:")
for index, (attribute_group_name, mismatching_entries) in enumerate(group_by_arguments):
print("[database] %d: %s" % (index, attribute_group_name))
for attribute_group_name, mismatching_entries in group_by_arguments:
response = user_input("[database] Remove entries corresponding to %s, [y/n]? " % str(attribute_group_name))
if response == "y":
for entry in mismatching_entries:
database["sections"].remove(entry)
print("[database] Removed %d entry/entries" % len(mismatching_entries))
# Sanity-check: all mis-matched entries should be removed
for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
if len(group_by_arguments) != 1:
print("[database] ERROR: entries for a single kernel with multiple argument values " + str(kernel_group_name))
assert len(group_by_arguments) == 1
def main(argv):
# Parses the command-line arguments
@ -76,6 +110,9 @@ def main(argv):
new_size = db.length(database)
print("with " + str(new_size - old_size) + " new items") # Newline printed here
# Checks for tuning results with mis-matched entries
remove_mismatched_arguments(database)
# Stores the modified database back to disk
if len(glob.glob(json_files)) >= 1:
io.save_database(database, database_filename)

View File

@ -5,6 +5,9 @@
# Author(s):
# Cedric Nugteren <www.cedricnugteren.nl>
import itertools
from operator import itemgetter
import clblast
@ -62,3 +65,14 @@ def combine_result(old_results, new_result):
# No match found: append a new result
old_results.append(new_result)
return old_results
def group_by(database, attributes):
"""Returns an list with the name of the group and the corresponding entries in the database"""
assert len(database) > 0
attributes = [a for a in attributes if a in database[0]]
database.sort(key=itemgetter(*attributes))
result = []
for key, data in itertools.groupby(database, key=itemgetter(*attributes)):
result.append((key, list(data)))
return result