From 890281f3e8b9c0523e69500d0860aa7085e7fbe1 Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Sat, 23 Sep 2017 17:50:44 +0200 Subject: [PATCH] Made database-caching no longer dependent on device name but on device/platform IDs --- src/cache.hpp | 6 +++--- src/clblast.cpp | 9 +++++---- src/clpp11.hpp | 1 + src/routine.cpp | 13 +++++++------ src/routine.hpp | 4 +--- test/correctness/tester.cpp | 2 +- 6 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/cache.hpp b/src/cache.hpp index ed693ea3..f6a948b6 100644 --- a/src/cache.hpp +++ b/src/cache.hpp @@ -93,9 +93,9 @@ extern template Program ProgramCache::Get(const ProgramKeyRef &, bool *) const; class Database; // The key struct for the cache of database maps. -// Order of fields: precision, device_name, kernel_name (smaller fields first) -typedef std::tuple DatabaseKey; -typedef std::tuple DatabaseKeyRef; +// Order of fields: platform_id, device_id, precision, kernel_name (smaller fields first) +typedef std::tuple DatabaseKey; +typedef std::tuple DatabaseKeyRef; typedef Cache DatabaseCache; diff --git a/src/clblast.cpp b/src/clblast.cpp index d44649bb..3983e5fc 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2492,11 +2492,12 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern // Retrieves the device name const auto device_cpp = Device(device); - const auto device_name = device_cpp.Name(); + const auto platform_id = device_cpp.Platform(); + const auto device_name = GetDeviceName(device_cpp); // Retrieves the current database values to verify whether the new ones are complete auto in_cache = false; - const auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision, device_name, kernel_name }, &in_cache); + const auto current_database = DatabaseCache::Instance().Get(DatabaseKeyRef{platform_id, device, precision, kernel_name}, &in_cache); if (!in_cache) { return StatusCode::kInvalidOverrideKernel; } for (const auto ¤t_param : current_database.GetParameterNames()) { if (parameters.find(current_param) == parameters.end()) { @@ -2530,8 +2531,8 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern const auto database = Database(device_cpp, kernel_name, precision, database_entries); // Removes the old database entry and stores the new one in the cache - DatabaseCache::Instance().Remove(DatabaseKey{ precision, device_name, kernel_name }); - DatabaseCache::Instance().Store(DatabaseKey{ precision, device_name, kernel_name }, Database(database)); + DatabaseCache::Instance().Remove(DatabaseKey{platform_id, device, precision, kernel_name}); + DatabaseCache::Instance().Store(DatabaseKey{platform_id, device, precision, kernel_name}, Database(database)); } catch (...) { return DispatchException(); } return StatusCode::kSuccess; diff --git a/src/clpp11.hpp b/src/clpp11.hpp index 7c1457b0..7d348e18 100644 --- a/src/clpp11.hpp +++ b/src/clpp11.hpp @@ -230,6 +230,7 @@ class Device { } // Methods to retrieve device information + cl_platform_id Platform() const { return GetInfo(CL_DEVICE_PLATFORM); } std::string Version() const { return GetInfoString(CL_DEVICE_VERSION); } size_t VersionNumber() const { diff --git a/src/routine.cpp b/src/routine.cpp index 758ffa0c..c305feb8 100644 --- a/src/routine.cpp +++ b/src/routine.cpp @@ -60,7 +60,7 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name, event_(event), context_(queue_.GetContext()), device_(queue_.GetDevice()), - device_name_(device_.Name()), + platform_(device_.Platform()), db_(kernel_names) { InitDatabase(userDatabase); @@ -72,13 +72,13 @@ void Routine::InitDatabase(const std::vector &userDatab // Queries the cache to see whether or not the kernel parameter database is already there bool has_db; - db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ precision_, device_name_, kernel_name }, + db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ platform_, device_(), precision_, kernel_name }, &has_db); if (has_db) { continue; } // Builds the parameter database for this device and routine set and stores it in the cache db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase); - DatabaseCache::Instance().Store(DatabaseKey{ precision_, device_name_, kernel_name }, + DatabaseCache::Instance().Store(DatabaseKey{ platform_, device_(), precision_, kernel_name }, Database{ db_(kernel_name) }); } } @@ -100,8 +100,9 @@ void Routine::InitProgram(std::initializer_list source) { // Queries the cache to see whether or not the binary (device-specific) is already there. If it // is, a program is created and stored in the cache + const auto device_name = GetDeviceName(device_); bool has_binary; - auto binary = BinaryCache::Instance().Get(BinaryKeyRef{ precision_, routine_name_, device_name_ }, + auto binary = BinaryCache::Instance().Get(BinaryKeyRef{ precision_, routine_name_, device_name }, &has_binary); if (has_binary) { program_ = Program(device_, context_, binary); @@ -171,7 +172,7 @@ void Routine::InitProgram(std::initializer_list source) { // Prints details of the routine to compile in case of debugging in verbose mode #ifdef VERBOSE printf("[DEBUG] Compiling routine '%s-%s' for device '%s'\n", - routine_name_.c_str(), ToString(precision_).c_str(), device_name_.c_str()); + routine_name_.c_str(), ToString(precision_).c_str(), device_name.c_str()); const auto start_time = std::chrono::steady_clock::now(); #endif @@ -188,7 +189,7 @@ void Routine::InitProgram(std::initializer_list source) { } // Store the compiled binary and program in the cache - BinaryCache::Instance().Store(BinaryKey{ precision_, routine_name_, device_name_ }, + BinaryCache::Instance().Store(BinaryKey{ precision_, routine_name_, device_name }, program_.GetIR()); ProgramCache::Instance().Store(ProgramKey{ context_(), device_(), precision_, routine_name_ }, diff --git a/src/routine.hpp b/src/routine.hpp index 5e2b4065..e77e35ad 100644 --- a/src/routine.hpp +++ b/src/routine.hpp @@ -75,9 +75,7 @@ class Routine { EventPointer event_; const Context context_; const Device device_; - - // OpenCL device properties - const std::string device_name_; + const cl_platform_id platform_; // Compiled program (either retrieved from cache or compiled in slow path) Program program_; diff --git a/test/correctness/tester.cpp b/test/correctness/tester.cpp index 9dbd8934..165aca35 100644 --- a/test/correctness/tester.cpp +++ b/test/correctness/tester.cpp @@ -141,7 +141,7 @@ Tester::Tester(const std::vector &arguments, const bool silent } // Prints the header - fprintf(stdout, "* Running on OpenCL device '%s'.\n", device_.Name().c_str()); + fprintf(stdout, "* Running on OpenCL device '%s'.\n", GetDeviceName(device_).c_str()); fprintf(stdout, "* Starting tests for the %s'%s'%s routine.", kPrintMessage.c_str(), name.c_str(), kPrintEnd.c_str());