mirror of
https://github.com/CNugteren/CLBlast.git
synced 2024-07-07 12:23:46 +02:00
Merge branch 'development' into benchmarking
This commit is contained in:
commit
957aaae6ca
|
@ -29,6 +29,40 @@ VENDOR_TRANSLATION_TABLE = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def remove_mismatched_arguments(database):
|
||||||
|
"""Checks for tuning results with mis-matched entries and removes them according to user preferences"""
|
||||||
|
kernel_attributes = clblast.DEVICE_TYPE_ATTRIBUTES + clblast.KERNEL_ATTRIBUTES + ["kernel"]
|
||||||
|
|
||||||
|
# For Python 2 and 3 compatibility
|
||||||
|
try:
|
||||||
|
user_input = raw_input
|
||||||
|
except NameError:
|
||||||
|
user_input = input
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Check for mis-matched entries
|
||||||
|
for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
|
||||||
|
group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
|
||||||
|
if len(group_by_arguments) != 1:
|
||||||
|
print("[database] WARNING: entries for a single kernel with multiple argument values " + str(kernel_group_name))
|
||||||
|
print("[database] Either quit now, or remove all but one of the argument combinations below:")
|
||||||
|
for index, (attribute_group_name, mismatching_entries) in enumerate(group_by_arguments):
|
||||||
|
print("[database] %d: %s" % (index, attribute_group_name))
|
||||||
|
for attribute_group_name, mismatching_entries in group_by_arguments:
|
||||||
|
response = user_input("[database] Remove entries corresponding to %s, [y/n]? " % str(attribute_group_name))
|
||||||
|
if response == "y":
|
||||||
|
for entry in mismatching_entries:
|
||||||
|
database["sections"].remove(entry)
|
||||||
|
print("[database] Removed %d entry/entries" % len(mismatching_entries))
|
||||||
|
|
||||||
|
# Sanity-check: all mis-matched entries should be removed
|
||||||
|
for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
|
||||||
|
group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
|
||||||
|
if len(group_by_arguments) != 1:
|
||||||
|
print("[database] ERROR: entries for a single kernel with multiple argument values " + str(kernel_group_name))
|
||||||
|
assert len(group_by_arguments) == 1
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
|
|
||||||
# Parses the command-line arguments
|
# Parses the command-line arguments
|
||||||
|
@ -76,6 +110,9 @@ def main(argv):
|
||||||
new_size = db.length(database)
|
new_size = db.length(database)
|
||||||
print("with " + str(new_size - old_size) + " new items") # Newline printed here
|
print("with " + str(new_size - old_size) + " new items") # Newline printed here
|
||||||
|
|
||||||
|
# Checks for tuning results with mis-matched entries
|
||||||
|
remove_mismatched_arguments(database)
|
||||||
|
|
||||||
# Stores the modified database back to disk
|
# Stores the modified database back to disk
|
||||||
if len(glob.glob(json_files)) >= 1:
|
if len(glob.glob(json_files)) >= 1:
|
||||||
io.save_database(database, database_filename)
|
io.save_database(database, database_filename)
|
||||||
|
|
|
@ -5,6 +5,9 @@
|
||||||
# Author(s):
|
# Author(s):
|
||||||
# Cedric Nugteren <www.cedricnugteren.nl>
|
# Cedric Nugteren <www.cedricnugteren.nl>
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
from operator import itemgetter
|
||||||
|
|
||||||
import clblast
|
import clblast
|
||||||
|
|
||||||
|
|
||||||
|
@ -62,3 +65,14 @@ def combine_result(old_results, new_result):
|
||||||
# No match found: append a new result
|
# No match found: append a new result
|
||||||
old_results.append(new_result)
|
old_results.append(new_result)
|
||||||
return old_results
|
return old_results
|
||||||
|
|
||||||
|
|
||||||
|
def group_by(database, attributes):
|
||||||
|
"""Returns an list with the name of the group and the corresponding entries in the database"""
|
||||||
|
assert len(database) > 0
|
||||||
|
attributes = [a for a in attributes if a in database[0]]
|
||||||
|
database.sort(key=itemgetter(*attributes))
|
||||||
|
result = []
|
||||||
|
for key, data in itertools.groupby(database, key=itemgetter(*attributes)):
|
||||||
|
result.append((key, list(data)))
|
||||||
|
return result
|
||||||
|
|
Loading…
Reference in a new issue