# Mirror of https://github.com/CNugteren/CLBlast.git (synced 2024-08-20 20:12:26 +02:00)
# Python, executable file, 114 lines, 4.5 KiB
#!/usr/bin/env python

# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
# PEP8 Python style guide and uses a max-width of 120 characters per line.
#
# Author(s):
#   Cedric Nugteren <www.cedricnugteren.nl>

import sys
|
|
import os.path
|
|
import glob
|
|
import argparse
|
|
|
|
import pandas as pd
|
|
|
|
import database.io as io
|
|
import database.db as db
|
|
import database.clblast as clblast
|
|
import database.bests as bests
|
|
import database.defaults as defaults
|
|
|
|
# Server holding a reference copy of the tuning database; main() downloads it when no local copy exists
DATABASE_SERVER_URL = "http://www.cedricnugteren.nl/tuning/clblast.db"

# Short vendor name paired with every OpenCL vendor string that maps onto it
_VENDOR_NAME_ALIASES = (
    ("Intel", ("GenuineIntel", "Intel(R) Corporation")),
    ("AMD", ("Advanced Micro Devices, Inc.",)),
    ("NVIDIA", ("NVIDIA Corporation",)),
)

# OpenCL vendor names and their short name, nested under the "device_vendor" column they apply to
# (presumably the nested {column: {old: new}} layout expected by db.find_and_replace -- confirm there)
VENDOR_TRANSLATION_TABLE = {
    "device_vendor": {
        reported: short for short, reported_names in _VENDOR_NAME_ALIASES for reported in reported_names
    },
}


def main(argv):
    """Merge tuning results from JSON files into the database and regenerate the C++ database sources.

    Downloads the database when no local copy exists. Merges every ``*.json`` file in ``source_folder`` into
    the database, and saves it back to disk when at least one JSON file was found. Then computes the best and
    the default results and writes them out as the C++ database.

    :param argv: command-line arguments without the program name (e.g. ``sys.argv[1:]``)
    :raises RuntimeError: if ``clblast_root`` does not contain ``include/clblast.h``
    """

    # Parses the command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("source_folder", help="The folder with JSON files to parse to add to the database")
    parser.add_argument("clblast_root", help="Root of the CLBlast sources")
    parser.add_argument("-v", "--verbose", action="store_true", help="Increase verbosity of the script")
    cl_args = parser.parse_args(argv)

    # Parses the path arguments
    database_filename = os.path.join(cl_args.clblast_root, "scripts", "database", "database.db")
    json_files = os.path.join(cl_args.source_folder, "*.json")
    cpp_database_path = os.path.join(cl_args.clblast_root, "src", "database", "kernels")

    # Checks whether the command-line arguments are valid
    clblast_header = os.path.join(cl_args.clblast_root, "include", "clblast.h")  # Not used but just for validation
    if not os.path.isfile(clblast_header):
        raise RuntimeError("The path '" + cl_args.clblast_root + "' does not point to the root of the CLBlast library")

    # Scans the source folder exactly once. The warning, the merge loop and the save decision all use this list,
    # so they cannot disagree if files change mid-run. Sorting makes the merge order deterministic instead of
    # depending on the file system's listing order.
    json_file_list = sorted(glob.glob(json_files))
    if not json_file_list:
        print("[database] The path '" + cl_args.source_folder + "' does not contain any JSON files")

    # Pandas options
    pd.set_option('display.width', 1000)
    if cl_args.verbose:
        print("[database] Using pandas version " + pd.__version__)

    # Downloads the database if a local copy is not present
    if not os.path.isfile(database_filename):
        io.download_database(database_filename, DATABASE_SERVER_URL)

    # Loads the database from disk
    database = io.load_database(database_filename)

    # Loops over all JSON files in the supplied folder
    for file_json in json_file_list:

        # Loads the newly imported data. The progress message has no newline, so flush it explicitly to make it
        # visible before the (possibly slow) load rather than only when the newline is printed below.
        sys.stdout.write("[database] Processing '" + file_json + "' ")  # No newline printed
        sys.stdout.flush()
        imported_data = io.load_json_to_pandas(file_json)

        # Fixes the problem that some vendors use multiple different names
        imported_data = db.find_and_replace(imported_data, VENDOR_TRANSLATION_TABLE)

        # Adds the new data to the database; the size difference after de-duplication is the number of new rows
        old_size = len(database.index)
        database = db.concatenate_database(database, imported_data)
        database = db.remove_duplicates(database)
        new_size = len(database.index)
        print("with " + str(new_size - old_size) + " new items")  # Newline printed here

    # Stores the modified database back to disk (only when something could have been merged)
    if json_file_list:
        io.save_database(database, database_filename)

    # Optional: update the database here. Default is disabled, code below is just an example
    if False:  # TODO: Use command-line arguments to enable updates in a flexible way
        database = db.update_database(database,
                                      ((database["kernel"] == "CopyMatrixFast") &
                                       (database["precision"] == "3232")),
                                      "arg_alpha", "2+0.5i")
        io.save_database(database, database_filename)

    # Retrieves the best performing results
    print("[database] Calculating the best results per device/kernel...")
    database_best_results = bests.get_best_results(database)

    # Determines the defaults for other vendors and per vendor
    print("[database] Calculating the default values...")
    database_defaults = defaults.calculate_defaults(database)
    database_best_results = db.concatenate_database(database_best_results, database_defaults)

    # Outputs the database as a C++ database
    print("[database] Producing a C++ database in '" + cpp_database_path + "'...")
    clblast.print_cpp_database(database_best_results, cpp_database_path)

    print("[database] All done")


if __name__ == "__main__":
    # Drops the program name so that main() only sees the user-supplied arguments
    main(argv=sys.argv[1:])