Merge branch 'development' into benchmarking
commit
957aaae6ca
|
@ -29,6 +29,40 @@ VENDOR_TRANSLATION_TABLE = {
|
|||
}
|
||||
|
||||
|
||||
def remove_mismatched_arguments(database):
|
||||
"""Checks for tuning results with mis-matched entries and removes them according to user preferences"""
|
||||
kernel_attributes = clblast.DEVICE_TYPE_ATTRIBUTES + clblast.KERNEL_ATTRIBUTES + ["kernel"]
|
||||
|
||||
# For Python 2 and 3 compatibility
|
||||
try:
|
||||
user_input = raw_input
|
||||
except NameError:
|
||||
user_input = input
|
||||
pass
|
||||
|
||||
# Check for mis-matched entries
|
||||
for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
|
||||
group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
|
||||
if len(group_by_arguments) != 1:
|
||||
print("[database] WARNING: entries for a single kernel with multiple argument values " + str(kernel_group_name))
|
||||
print("[database] Either quit now, or remove all but one of the argument combinations below:")
|
||||
for index, (attribute_group_name, mismatching_entries) in enumerate(group_by_arguments):
|
||||
print("[database] %d: %s" % (index, attribute_group_name))
|
||||
for attribute_group_name, mismatching_entries in group_by_arguments:
|
||||
response = user_input("[database] Remove entries corresponding to %s, [y/n]? " % str(attribute_group_name))
|
||||
if response == "y":
|
||||
for entry in mismatching_entries:
|
||||
database["sections"].remove(entry)
|
||||
print("[database] Removed %d entry/entries" % len(mismatching_entries))
|
||||
|
||||
# Sanity-check: all mis-matched entries should be removed
|
||||
for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
|
||||
group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
|
||||
if len(group_by_arguments) != 1:
|
||||
print("[database] ERROR: entries for a single kernel with multiple argument values " + str(kernel_group_name))
|
||||
assert len(group_by_arguments) == 1
|
||||
|
||||
|
||||
def main(argv):
|
||||
|
||||
# Parses the command-line arguments
|
||||
|
@ -76,6 +110,9 @@ def main(argv):
|
|||
new_size = db.length(database)
|
||||
print("with " + str(new_size - old_size) + " new items") # Newline printed here
|
||||
|
||||
# Checks for tuning results with mis-matched entries
|
||||
remove_mismatched_arguments(database)
|
||||
|
||||
# Stores the modified database back to disk
|
||||
if len(glob.glob(json_files)) >= 1:
|
||||
io.save_database(database, database_filename)
|
||||
|
|
|
@ -5,6 +5,9 @@
|
|||
# Author(s):
|
||||
# Cedric Nugteren <www.cedricnugteren.nl>
|
||||
|
||||
import itertools
|
||||
from operator import itemgetter
|
||||
|
||||
import clblast
|
||||
|
||||
|
||||
|
@ -62,3 +65,14 @@ def combine_result(old_results, new_result):
|
|||
# No match found: append a new result
|
||||
old_results.append(new_result)
|
||||
return old_results
|
||||
|
||||
|
||||
def group_by(database, attributes):
|
||||
"""Returns an list with the name of the group and the corresponding entries in the database"""
|
||||
assert len(database) > 0
|
||||
attributes = [a for a in attributes if a in database[0]]
|
||||
database.sort(key=itemgetter(*attributes))
|
||||
result = []
|
||||
for key, data in itertools.groupby(database, key=itemgetter(*attributes)):
|
||||
result.append((key, list(data)))
|
||||
return result
|
||||
|
|
Loading…
Reference in New Issue