diff --git a/src/cache.cpp b/src/cache.cpp index 6786eaa2..2b91abf1 100644 --- a/src/cache.cpp +++ b/src/cache.cpp @@ -20,103 +20,73 @@ namespace clblast { // ================================================================================================= -// Stores the compiled binary or IR in the cache -void StoreBinaryToCache(const std::string &binary, const std::string &device_name, - const Precision &precision, const std::string &routine_name) { - #ifdef VERBOSE - printf("[DEBUG] Storing binary in cache\n"); - #endif - binary_cache_mutex_.lock(); - binary_cache_.push_back(BinaryCache{binary, device_name, precision, routine_name}); - binary_cache_mutex_.unlock(); -} +template +template +Value Cache::Get(const U &key, bool *in_cache) const { + std::lock_guard lock(cache_mutex_); -// Stores the compiled program in the cache -void StoreProgramToCache(const Program &program, const Context &context, - const Precision &precision, const std::string &routine_name) { - #ifdef VERBOSE - printf("[DEBUG] Storing program in cache\n"); - #endif - program_cache_mutex_.lock(); - program_cache_.push_back(ProgramCache{program, context(), precision, routine_name}); - program_cache_mutex_.unlock(); -} - -// Queries the cache and retrieves a matching binary. Assumes that the match is available, throws -// otherwise. -const std::string& GetBinaryFromCache(const std::string &device_name, const Precision &precision, - const std::string &routine_name) { - #ifdef VERBOSE - printf("[DEBUG] Retrieving binary from cache\n"); - #endif - binary_cache_mutex_.lock(); - for (auto &cached_binary: binary_cache_) { - if (cached_binary.MatchInCache(device_name, precision, routine_name)) { - binary_cache_mutex_.unlock(); - return cached_binary.binary; +#if __cplusplus >= 201402L + // generalized std::map::find() of C++14 + auto it = cache_.find(key); +#else + // O(n) lookup in a vector + auto it = std::find_if(cache_.begin(), cache_.end(), [&] (const std::pair &pair) { + return pair.first == key; + }); +#endif + if (it == cache_.end()) { + if (in_cache) { + *in_cache = false; } + return Value(); } - binary_cache_mutex_.unlock(); - throw LogicError("GetBinaryFromCache: Expected binary in cache, but found none"); + + if (in_cache) { + *in_cache = true; + } + return it->second; } -// Queries the cache and retrieves a matching program. Assumes that the match is available, throws -// otherwise. -const Program& GetProgramFromCache(const Context &context, const Precision &precision, - const std::string &routine_name) { - #ifdef VERBOSE - printf("[DEBUG] Retrieving program from cache\n"); - #endif - program_cache_mutex_.lock(); - for (auto &cached_program: program_cache_) { - if (cached_program.MatchInCache(context(), precision, routine_name)) { - program_cache_mutex_.unlock(); - return cached_program.program; - } +template +void Cache::Store(Key &&key, Value &&value) { + std::lock_guard lock(cache_mutex_); + +#if __cplusplus >= 201402L + // emplace() into a map + auto r = cache_.emplace(std::move(key), std::move(value)); + if (!r.second) { + throw LogicError("Cache::Store: object already in cache"); } - program_cache_mutex_.unlock(); - throw LogicError("GetProgramFromCache: Expected program in cache, but found none"); +#else + // emplace_back() into a vector + cache_.emplace_back(std::move(key), std::move(value)); +#endif } -// Queries the cache to see whether or not the compiled kernel is already there -bool BinaryIsInCache(const std::string &device_name, const Precision &precision, - const std::string &routine_name) { - binary_cache_mutex_.lock(); - for (auto &cached_binary: binary_cache_) { - if (cached_binary.MatchInCache(device_name, precision, routine_name)) { - binary_cache_mutex_.unlock(); - return true; - } - } - binary_cache_mutex_.unlock(); - return false; +template +void Cache::Invalidate() { + std::lock_guard lock(cache_mutex_); + + cache_.clear(); } -// Queries the cache to see whether or not the compiled kernel is already there -bool ProgramIsInCache(const Context &context, const Precision &precision, - const std::string &routine_name) { - program_cache_mutex_.lock(); - for (auto &cached_program: program_cache_) { - if (cached_program.MatchInCache(context(), precision, routine_name)) { - program_cache_mutex_.unlock(); - return true; - } - } - program_cache_mutex_.unlock(); - return false; +template +Cache &Cache::Instance() { + return instance_; } +template +Cache Cache::instance_; + // ================================================================================================= -// Clears the cache of stored binaries and programs -void CacheClearAll() { - binary_cache_mutex_.lock(); - binary_cache_.clear(); - binary_cache_mutex_.unlock(); - program_cache_mutex_.lock(); - program_cache_.clear(); - program_cache_mutex_.unlock(); -} +template class Cache; +template std::string BinaryCache::Get(const BinaryKeyRef &, bool *) const; + +// ================================================================================================= + +template class Cache; +template Program ProgramCache::Get(const ProgramKeyRef &, bool *) const; // ================================================================================================= } // namespace clblast diff --git a/src/cache.hpp b/src/cache.hpp index 9ecb0f1e..25d6f076 100644 --- a/src/cache.hpp +++ b/src/cache.hpp @@ -15,81 +15,75 @@ #define CLBLAST_CACHE_H_ #include -#include #include +#include #include "utilities/utilities.hpp" namespace clblast { // ================================================================================================= -// The cache of compiled OpenCL binaries, along with some meta-data -struct BinaryCache { - std::string binary; - std::string device_name; - Precision precision; - std::string routine_name_; +// The generic thread-safe cache. We assume that the Key may be a heavyweight struct that is not +// normally used by the caller, while the Value is either lightweight or ref-counted. +// Hence, searching by non-Key is supported (if there is a corresponding operator<()), and +// on Store() the Key instance is moved from the caller (because it will likely be constructed +// as temporary at the time of Store()). +template +class Cache { +public: + // Cached object is returned by-value to avoid racing with Invalidate(). + // Due to lack of std::optional<>, in case of a cache miss we return a default-constructed + // Value and set the flag to false. + template + Value Get(const U &key, bool *in_cache) const; - // Finds out whether the properties match - bool MatchInCache(const std::string &ref_device, const Precision &ref_precision, - const std::string &ref_routine) { - return (device_name == ref_device && - precision == ref_precision && - routine_name_ == ref_routine); - } -}; + // We do not return references to just stored object to avoid racing with Invalidate(). + // Caller is expected to store a temporary. + void Store(Key &&key, Value &&value); + void Invalidate(); -// The actual cache, implemented as a vector of the above data-type, and its mutex -static std::vector binary_cache_; -static std::mutex binary_cache_mutex_; + static Cache &Instance(); + +private: +#if __cplusplus >= 201402L + // The std::less allows to search in cache by an object comparable with Key, without + // constructing a temporary Key + // (see http://en.cppreference.com/w/cpp/utility/functional/less_void, + // http://www.open-std.org/JTC1/SC22/WG21/docs/papers/2013/n3657.htm, + // http://stackoverflow.com/questions/10536788/avoiding-key-construction-for-stdmapfind) + std::map> cache_; +#else + std::vector> cache_; +#endif + mutable std::mutex cache_mutex_; + + static Cache instance_; +}; // class Cache // ================================================================================================= -// The cache of compiled OpenCL programs, along with some meta-data -struct ProgramCache { - Program program; - cl_context context; - Precision precision; - std::string routine_name_; +// The key struct for the cache of compiled OpenCL binaries +// Order of fields: precision, routine_name, device_name (smaller fields first) +typedef std::tuple BinaryKey; +typedef std::tuple BinaryKeyRef; - // Finds out whether the properties match - bool MatchInCache(const cl_context ref_context, const Precision &ref_precision, - const std::string &ref_routine) { - return (context == ref_context && - precision == ref_precision && - routine_name_ == ref_routine); - } -}; +typedef Cache BinaryCache; + +extern template class Cache; +extern template std::string BinaryCache::Get(const BinaryKeyRef &, bool *) const; -// The actual cache, implemented as a vector of the above data-type, and its mutex -static std::vector program_cache_; -static std::mutex program_cache_mutex_; // ================================================================================================= -// Stores the compiled binary or program in the cache -void StoreBinaryToCache(const std::string &binary, const std::string &device_name, - const Precision &precision, const std::string &routine_name); -void StoreProgramToCache(const Program &program, const Context &context, - const Precision &precision, const std::string &routine_name); +// The key struct for the cache of compiled OpenCL programs (context-dependent) +// Order of fields: context, precision, routine_name (smaller fields first) +typedef std::tuple ProgramKey; +typedef std::tuple ProgramKeyRef; -// Queries the cache and retrieves a matching binary or program. Assumes that the match is -// available, throws otherwise. -const std::string& GetBinaryFromCache(const std::string &device_name, const Precision &precision, - const std::string &routine_name); -const Program& GetProgramFromCache(const Context &context, const Precision &precision, - const std::string &routine_name); +typedef Cache ProgramCache; -// Queries the cache to see whether or not the compiled kernel is already there -bool BinaryIsInCache(const std::string &device_name, const Precision &precision, - const std::string &routine_name); -bool ProgramIsInCache(const Context &context, const Precision &precision, - const std::string &routine_name); - -// ================================================================================================= - -// Clears the cache of stored binaries -void CacheClearAll(); +extern template class Cache; +extern template Program ProgramCache::Get(const ProgramKeyRef &, bool *) const; // ================================================================================================= } // namespace clblast diff --git a/src/clblast.cpp b/src/clblast.cpp index e0f8add2..35f3f552 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2165,7 +2165,8 @@ template StatusCode PUBLIC_API Omatcopy(const Layout, const Transpose, // Clears the cache of stored binaries StatusCode ClearCache() { try { - CacheClearAll(); + ProgramCache::Instance().Invalidate(); + BinaryCache::Instance().Invalidate(); } catch (...) { return DispatchException(); } return StatusCode::kSuccess; } diff --git a/src/clpp11.hpp b/src/clpp11.hpp index c984661c..41af28da 100644 --- a/src/clpp11.hpp +++ b/src/clpp11.hpp @@ -361,7 +361,7 @@ enum class BuildStatus { kSuccess, kError, kInvalid }; // C++11 version of 'cl_program'. class Program { public: - // Note that there is no constructor based on the regular OpenCL data-type because of extra state + Program() = default; // Source-based constructor with memory management explicit Program(const Context &context, const std::string &source): diff --git a/src/routine.cpp b/src/routine.cpp index d5a6b589..75e4ea89 100644 --- a/src/routine.cpp +++ b/src/routine.cpp @@ -36,7 +36,10 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name, db_(queue_, routines, precision_, userDatabase) { // Queries the cache to see whether or not the program (context-specific) is already there - if (ProgramIsInCache(context_, precision_, routine_name_)) { return; } + bool has_program; + program_ = ProgramCache::Instance().Get(ProgramKeyRef{ context_(), precision_, routine_name_ }, + &has_program); + if (has_program) { return; } // Sets the build options from an environmental variable (if set) auto options = std::vector(); @@ -47,11 +50,14 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name, // Queries the cache to see whether or not the binary (device-specific) is already there. If it // is, a program is created and stored in the cache - if (BinaryIsInCache(device_name_, precision_, routine_name_)) { - auto& binary = GetBinaryFromCache(device_name_, precision_, routine_name_); - auto program = Program(device_, context_, binary); - program.Build(device_, options); - StoreProgramToCache(program, context_, precision_, routine_name_); + bool has_binary; + auto binary = BinaryCache::Instance().Get(BinaryKeyRef{ precision_, routine_name_, device_name_ }, + &has_binary); + if (has_binary) { + program_ = Program(device_, context_, binary); + program_.Build(device_, options); + ProgramCache::Instance().Store(ProgramKey{ context_(), precision_, routine_name_ }, + Program{ program_ }); return; } @@ -111,21 +117,23 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name, #endif // Compiles the kernel - auto program = Program(context_, source_string); + program_ = Program(context_, source_string); try { - program.Build(device_, options); + program_.Build(device_, options); } catch (const CLError &e) { if (e.status() == CL_BUILD_PROGRAM_FAILURE) { fprintf(stdout, "OpenCL compiler error/warning: %s\n", - program.GetBuildInfo(device_).c_str()); + program_.GetBuildInfo(device_).c_str()); } throw; } // Store the compiled binary and program in the cache - const auto binary = program.GetIR(); - StoreBinaryToCache(binary, device_name_, precision_, routine_name_); - StoreProgramToCache(program, context_, precision_, routine_name_); + BinaryCache::Instance().Store(BinaryKey{ precision_, routine_name_, device_name_ }, + program_.GetIR()); + + ProgramCache::Instance().Store(ProgramKey{ context_(), precision_, routine_name_ }, + Program{ program_ }); // Prints the elapsed compilation time in case of debugging in verbose mode #ifdef VERBOSE diff --git a/src/routine.hpp b/src/routine.hpp index 2d8b2415..8e9fd54d 100644 --- a/src/routine.hpp +++ b/src/routine.hpp @@ -57,6 +57,9 @@ class Routine { // OpenCL device properties const std::string device_name_; + // Compiled program (either retrieved from cache or compiled in slow path) + Program program_; + // Connection to the database for all the device-specific parameters const Database db_; }; diff --git a/src/routines/level1/xamax.cpp b/src/routines/level1/xamax.cpp index e9efa1a7..40a66517 100644 --- a/src/routines/level1/xamax.cpp +++ b/src/routines/level1/xamax.cpp @@ -43,9 +43,8 @@ void Xamax::DoAmax(const size_t n, TestVectorIndex(1, imax_buffer, imax_offset); // Retrieves the Xamax kernels from the compiled binary - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel1 = Kernel(program, "Xamax"); - auto kernel2 = Kernel(program, "XamaxEpilogue"); + auto kernel1 = Kernel(program_, "Xamax"); + auto kernel2 = Kernel(program_, "XamaxEpilogue"); // Creates the buffer for intermediate values auto temp_size = 2*db_["WGS2"]; diff --git a/src/routines/level1/xasum.cpp b/src/routines/level1/xasum.cpp index a242a5fa..b93b271c 100644 --- a/src/routines/level1/xasum.cpp +++ b/src/routines/level1/xasum.cpp @@ -43,9 +43,8 @@ void Xasum::DoAsum(const size_t n, TestVectorScalar(1, asum_buffer, asum_offset); // Retrieves the Xasum kernels from the compiled binary - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel1 = Kernel(program, "Xasum"); - auto kernel2 = Kernel(program, "XasumEpilogue"); + auto kernel1 = Kernel(program_, "Xasum"); + auto kernel2 = Kernel(program_, "XasumEpilogue"); // Creates the buffer for intermediate values auto temp_size = 2*db_["WGS2"]; diff --git a/src/routines/level1/xaxpy.cpp b/src/routines/level1/xaxpy.cpp index 5436c5b7..39f61ef4 100644 --- a/src/routines/level1/xaxpy.cpp +++ b/src/routines/level1/xaxpy.cpp @@ -52,8 +52,7 @@ void Xaxpy::DoAxpy(const size_t n, const T alpha, auto kernel_name = (use_fast_kernel) ? "XaxpyFast" : "Xaxpy"; // Retrieves the Xaxpy kernel from the compiled binary - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel = Kernel(program, kernel_name); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments if (use_fast_kernel) { diff --git a/src/routines/level1/xcopy.cpp b/src/routines/level1/xcopy.cpp index d86200c0..62889764 100644 --- a/src/routines/level1/xcopy.cpp +++ b/src/routines/level1/xcopy.cpp @@ -52,8 +52,7 @@ void Xcopy::DoCopy(const size_t n, auto kernel_name = (use_fast_kernel) ? "XcopyFast" : "Xcopy"; // Retrieves the Xcopy kernel from the compiled binary - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel = Kernel(program, kernel_name); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments if (use_fast_kernel) { diff --git a/src/routines/level1/xdot.cpp b/src/routines/level1/xdot.cpp index 9d718913..9f9c0590 100644 --- a/src/routines/level1/xdot.cpp +++ b/src/routines/level1/xdot.cpp @@ -46,9 +46,8 @@ void Xdot::DoDot(const size_t n, TestVectorScalar(1, dot_buffer, dot_offset); // Retrieves the Xdot kernels from the compiled binary - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel1 = Kernel(program, "Xdot"); - auto kernel2 = Kernel(program, "XdotEpilogue"); + auto kernel1 = Kernel(program_, "Xdot"); + auto kernel2 = Kernel(program_, "XdotEpilogue"); // Creates the buffer for intermediate values auto temp_size = 2*db_["WGS2"]; diff --git a/src/routines/level1/xnrm2.cpp b/src/routines/level1/xnrm2.cpp index 373820a4..aa341aff 100644 --- a/src/routines/level1/xnrm2.cpp +++ b/src/routines/level1/xnrm2.cpp @@ -43,9 +43,8 @@ void Xnrm2::DoNrm2(const size_t n, TestVectorScalar(1, nrm2_buffer, nrm2_offset); // Retrieves the Xnrm2 kernels from the compiled binary - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel1 = Kernel(program, "Xnrm2"); - auto kernel2 = Kernel(program, "Xnrm2Epilogue"); + auto kernel1 = Kernel(program_, "Xnrm2"); + auto kernel2 = Kernel(program_, "Xnrm2Epilogue"); // Creates the buffer for intermediate values auto temp_size = 2*db_["WGS2"]; diff --git a/src/routines/level1/xscal.cpp b/src/routines/level1/xscal.cpp index 0521b1e5..9bc096e5 100644 --- a/src/routines/level1/xscal.cpp +++ b/src/routines/level1/xscal.cpp @@ -49,8 +49,7 @@ void Xscal::DoScal(const size_t n, const T alpha, auto kernel_name = (use_fast_kernel) ? "XscalFast" : "Xscal"; // Retrieves the Xscal kernel from the compiled binary - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel = Kernel(program, kernel_name); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments if (use_fast_kernel) { diff --git a/src/routines/level1/xswap.cpp b/src/routines/level1/xswap.cpp index c9b97dc9..f046575f 100644 --- a/src/routines/level1/xswap.cpp +++ b/src/routines/level1/xswap.cpp @@ -52,8 +52,7 @@ void Xswap::DoSwap(const size_t n, auto kernel_name = (use_fast_kernel) ? "XswapFast" : "Xswap"; // Retrieves the Xswap kernel from the compiled binary - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel = Kernel(program, kernel_name); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments if (use_fast_kernel) { diff --git a/src/routines/level2/xgemv.cpp b/src/routines/level2/xgemv.cpp index 7b4c2e8f..9e9c2db4 100644 --- a/src/routines/level2/xgemv.cpp +++ b/src/routines/level2/xgemv.cpp @@ -122,8 +122,7 @@ void Xgemv::MatVec(const Layout layout, const Transpose a_transpose, } // Retrieves the Xgemv kernel from the compiled binary - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel = Kernel(program, kernel_name); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments kernel.SetArgument(0, static_cast(m_real)); diff --git a/src/routines/level2/xger.cpp b/src/routines/level2/xger.cpp index d16ebd11..9ec156a1 100644 --- a/src/routines/level2/xger.cpp +++ b/src/routines/level2/xger.cpp @@ -53,8 +53,7 @@ void Xger::DoGer(const Layout layout, TestVectorY(n, y_buffer, y_offset, y_inc); // Retrieves the kernel from the compiled binary - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel = Kernel(program, "Xger"); + auto kernel = Kernel(program_, "Xger"); // Sets the kernel arguments kernel.SetArgument(0, static_cast(a_one)); diff --git a/src/routines/level2/xher.cpp b/src/routines/level2/xher.cpp index 6c334e63..ba12a3ef 100644 --- a/src/routines/level2/xher.cpp +++ b/src/routines/level2/xher.cpp @@ -67,8 +67,7 @@ void Xher::DoHer(const Layout layout, const Triangle triangle, const auto matching_alpha = GetAlpha(alpha); // Retrieves the kernel from the compiled binary - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel = Kernel(program, "Xher"); + auto kernel = Kernel(program_, "Xher"); // Sets the kernel arguments kernel.SetArgument(0, static_cast(n)); diff --git a/src/routines/level2/xher2.cpp b/src/routines/level2/xher2.cpp index 11e2c871..a420e693 100644 --- a/src/routines/level2/xher2.cpp +++ b/src/routines/level2/xher2.cpp @@ -54,8 +54,7 @@ void Xher2::DoHer2(const Layout layout, const Triangle triangle, TestVectorY(n, y_buffer, y_offset, y_inc); // Retrieves the kernel from the compiled binary - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel = Kernel(program, "Xher2"); + auto kernel = Kernel(program_, "Xher2"); // Sets the kernel arguments kernel.SetArgument(0, static_cast(n)); diff --git a/src/routines/level3/xgemm.cpp b/src/routines/level3/xgemm.cpp index 0015b629..7bd388c1 100644 --- a/src/routines/level3/xgemm.cpp +++ b/src/routines/level3/xgemm.cpp @@ -150,9 +150,6 @@ void Xgemm::GemmIndirect(const size_t m, const size_t n, const size_t k, const auto c_one_i = (c_want_rotated) ? n_ceiled : m_ceiled; const auto c_two_i = (c_want_rotated) ? m_ceiled : n_ceiled; - // Loads the program from the database - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - // Determines whether or not temporary matrices are needed auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && a_offset == 0 && a_do_transpose == false && a_conjugate == false; @@ -178,7 +175,7 @@ void Xgemm::GemmIndirect(const size_t m, const size_t n, const size_t k, PadCopyTransposeMatrix(queue_, device_, db_, eventProcessA.pointer(), emptyEventList, a_one, a_two, a_ld, a_offset, a_buffer, a_one_i, a_two_i, a_one_i, 0, a_temp, - ConstantOne(), program, + ConstantOne(), program_, true, a_do_transpose, a_conjugate); eventWaitList.push_back(eventProcessA); } @@ -189,7 +186,7 @@ void Xgemm::GemmIndirect(const size_t m, const size_t n, const size_t k, PadCopyTransposeMatrix(queue_, device_, db_, eventProcessB.pointer(), emptyEventList, b_one, b_two, b_ld, b_offset, b_buffer, b_one_i, b_two_i, b_one_i, 0, b_temp, - ConstantOne(), program, + ConstantOne(), program_, true, b_do_transpose, b_conjugate); eventWaitList.push_back(eventProcessB); } @@ -200,13 +197,13 @@ void Xgemm::GemmIndirect(const size_t m, const size_t n, const size_t k, PadCopyTransposeMatrix(queue_, device_, db_, eventProcessC.pointer(), emptyEventList, c_one, c_two, c_ld, c_offset, c_buffer, c_one_i, c_two_i, c_one_i, 0, c_temp, - ConstantOne(), program, + ConstantOne(), program_, true, c_do_transpose, false); eventWaitList.push_back(eventProcessC); } // Retrieves the Xgemm kernel from the compiled binary - auto kernel = Kernel(program, "Xgemm"); + auto kernel = Kernel(program_, "Xgemm"); // Sets the kernel arguments kernel.SetArgument(0, static_cast(m_ceiled)); @@ -236,7 +233,7 @@ void Xgemm::GemmIndirect(const size_t m, const size_t n, const size_t k, PadCopyTransposeMatrix(queue_, device_, db_, event_, eventWaitList, c_one_i, c_two_i, c_one_i, 0, c_temp, c_one, c_two, c_ld, c_offset, c_buffer, - ConstantOne(), program, + ConstantOne(), program_, false, c_do_transpose, false); } } @@ -255,13 +252,10 @@ void Xgemm::GemmDirect(const size_t m, const size_t n, const size_t k, const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose, const bool a_conjugate, const bool b_conjugate) { - // Loads the program from the database - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - // Retrieves the proper XgemmDirect kernel from the compiled binary const auto name = (a_do_transpose) ? (b_do_transpose ? "XgemmDirectTT" : "XgemmDirectTN") : (b_do_transpose ? "XgemmDirectNT" : "XgemmDirectNN"); - auto kernel = Kernel(program, name); + auto kernel = Kernel(program_, name); // Sets the kernel arguments kernel.SetArgument(0, static_cast(m)); diff --git a/src/routines/level3/xhemm.cpp b/src/routines/level3/xhemm.cpp index e5b1502a..8629f3de 100644 --- a/src/routines/level3/xhemm.cpp +++ b/src/routines/level3/xhemm.cpp @@ -58,8 +58,7 @@ void Xhemm::DoHemm(const Layout layout, const Side side, const Triangle trian // Creates a general matrix from the hermitian matrix to be able to run the regular Xgemm // routine afterwards - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel = Kernel(program, kernel_name); + auto kernel = Kernel(program_, kernel_name); // Sets the arguments for the hermitian-to-squared kernel kernel.SetArgument(0, static_cast(k)); diff --git a/src/routines/level3/xhemm.hpp b/src/routines/level3/xhemm.hpp index 2385706e..7c011915 100644 --- a/src/routines/level3/xhemm.hpp +++ b/src/routines/level3/xhemm.hpp @@ -30,6 +30,7 @@ class Xhemm: public Xgemm { using Xgemm::queue_; using Xgemm::context_; using Xgemm::device_; + using Xgemm::program_; using Xgemm::db_; using Xgemm::DoGemm; diff --git a/src/routines/level3/xher2k.cpp b/src/routines/level3/xher2k.cpp index ee3bb8b8..2aed2781 100644 --- a/src/routines/level3/xher2k.cpp +++ b/src/routines/level3/xher2k.cpp @@ -81,9 +81,6 @@ void Xher2k::DoHer2k(const Layout layout, const Triangle triangle, const Tr // Decides which kernel to run: the upper-triangular or lower-triangular version auto kernel_name = (triangle == Triangle::kUpper) ? "XgemmUpper" : "XgemmLower"; - // Loads the program from the database - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - // Determines whether or not temporary matrices are needed auto a1_no_temp = ab_one == n_ceiled && ab_two == k_ceiled && a_ld == n_ceiled && a_offset == 0 && ab_rotated == false && ab_conjugate == false; @@ -116,7 +113,7 @@ void Xher2k::DoHer2k(const Layout layout, const Triangle triangle, const Tr PadCopyTransposeMatrix(queue_, device_, db_, eventProcessA1.pointer(), emptyEventList, ab_one, ab_two, a_ld, a_offset, a_buffer, n_ceiled, k_ceiled, n_ceiled, 0, a1_temp, - ConstantOne(), program, + ConstantOne(), program_, true, ab_rotated, ab_conjugate); eventWaitList.push_back(eventProcessA1); } @@ -125,7 +122,7 @@ void Xher2k::DoHer2k(const Layout layout, const Triangle triangle, const Tr PadCopyTransposeMatrix(queue_, device_, db_, eventProcessA2.pointer(), emptyEventList, ab_one, ab_two, a_ld, a_offset, a_buffer, n_ceiled, k_ceiled, n_ceiled, 0, a2_temp, - ConstantOne(), program, + ConstantOne(), program_, true, ab_rotated, !ab_conjugate); eventWaitList.push_back(eventProcessA2); } @@ -134,7 +131,7 @@ void Xher2k::DoHer2k(const Layout layout, const Triangle triangle, const Tr PadCopyTransposeMatrix(queue_, device_, db_, eventProcessB1.pointer(), emptyEventList, ab_one, ab_two, b_ld, b_offset, b_buffer, n_ceiled, k_ceiled, n_ceiled, 0, b1_temp, - ConstantOne(), program, + ConstantOne(), program_, true, ab_rotated, ab_conjugate); eventWaitList.push_back(eventProcessB1); } @@ -143,7 +140,7 @@ void Xher2k::DoHer2k(const Layout layout, const Triangle triangle, const Tr PadCopyTransposeMatrix(queue_, device_, db_, eventProcessB2.pointer(), emptyEventList, ab_one, ab_two, b_ld, b_offset, b_buffer, n_ceiled, k_ceiled, n_ceiled, 0, b2_temp, - ConstantOne(), program, + ConstantOne(), program_, true, ab_rotated, !ab_conjugate); eventWaitList.push_back(eventProcessB2); } @@ -154,12 +151,12 @@ void Xher2k::DoHer2k(const Layout layout, const Triangle triangle, const Tr PadCopyTransposeMatrix(queue_, device_, db_, eventProcessC.pointer(), emptyEventList, n, n, c_ld, c_offset, c_buffer, n_ceiled, n_ceiled, n_ceiled, 0, c_temp, - ConstantOne(), program, + ConstantOne(), program_, true, c_rotated, false); eventWaitList.push_back(eventProcessC); // Retrieves the XgemmUpper or XgemmLower kernel from the compiled binary - auto kernel = Kernel(program, kernel_name); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments kernel.SetArgument(0, static_cast(n_ceiled)); @@ -201,7 +198,7 @@ void Xher2k::DoHer2k(const Layout layout, const Triangle triangle, const Tr PadCopyTransposeMatrix(queue_, device_, db_, event_, eventWaitList, n_ceiled, n_ceiled, n_ceiled, 0, c_temp, n, n, c_ld, c_offset, c_buffer, - ConstantOne(), program, + ConstantOne(), program_, false, c_rotated, false, upper, lower, true); } diff --git a/src/routines/level3/xherk.cpp b/src/routines/level3/xherk.cpp index ae8e9324..d982859e 100644 --- a/src/routines/level3/xherk.cpp +++ b/src/routines/level3/xherk.cpp @@ -79,9 +79,6 @@ void Xherk::DoHerk(const Layout layout, const Triangle triangle, const Tran // Decides which kernel to run: the upper-triangular or lower-triangular version auto kernel_name = (triangle == Triangle::kUpper) ? "XgemmUpper" : "XgemmLower"; - // Loads the program from the database - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - // Determines whether or not temporary matrices are needed auto a_no_temp = a_one == n_ceiled && a_two == k_ceiled && a_ld == n_ceiled && a_offset == 0 && a_rotated == false && a_conjugate == false; @@ -109,7 +106,7 @@ void Xherk::DoHerk(const Layout layout, const Triangle triangle, const Tran PadCopyTransposeMatrix(queue_, device_, db_, eventProcessA.pointer(), emptyEventList, a_one, a_two, a_ld, a_offset, a_buffer, n_ceiled, k_ceiled, n_ceiled, 0, a_temp, - ConstantOne(), program, + ConstantOne(), program_, true, a_rotated, a_conjugate); eventWaitList.push_back(eventProcessA); } @@ -118,7 +115,7 @@ void Xherk::DoHerk(const Layout layout, const Triangle triangle, const Tran PadCopyTransposeMatrix(queue_, device_, db_, eventProcessB.pointer(), emptyEventList, a_one, a_two, a_ld, a_offset, a_buffer, n_ceiled, k_ceiled, n_ceiled, 0, b_temp, - ConstantOne(), program, + ConstantOne(), program_, true, a_rotated, b_conjugate); eventWaitList.push_back(eventProcessB); } @@ -129,12 +126,12 @@ void Xherk::DoHerk(const Layout layout, const Triangle triangle, const Tran PadCopyTransposeMatrix(queue_, device_, db_, eventProcessC.pointer(), emptyEventList, n, n, c_ld, c_offset, c_buffer, n_ceiled, n_ceiled, n_ceiled, 0, c_temp, - ConstantOne(), program, + ConstantOne(), program_, true, c_rotated, false); eventWaitList.push_back(eventProcessC); // Retrieves the XgemmUpper or XgemmLower kernel from the compiled binary - auto kernel = Kernel(program, kernel_name); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments kernel.SetArgument(0, static_cast(n_ceiled)); @@ -163,7 +160,7 @@ void Xherk::DoHerk(const Layout layout, const Triangle triangle, const Tran PadCopyTransposeMatrix(queue_, device_, db_, event_, eventWaitList, n_ceiled, n_ceiled, n_ceiled, 0, c_temp, n, n, c_ld, c_offset, c_buffer, - ConstantOne(), program, + ConstantOne(), program_, false, c_rotated, false, upper, lower, true); } diff --git a/src/routines/level3/xsymm.cpp b/src/routines/level3/xsymm.cpp index d7f771d1..969edfc8 100644 --- a/src/routines/level3/xsymm.cpp +++ b/src/routines/level3/xsymm.cpp @@ -30,12 +30,12 @@ Xsymm::Xsymm(Queue &queue, EventPointer event, const std::string &name): // The main routine template void Xsymm::DoSymm(const Layout layout, const Side side, const Triangle triangle, - const size_t m, const size_t n, - const T alpha, - const Buffer &a_buffer, const size_t a_offset, const size_t a_ld, - const Buffer &b_buffer, const size_t b_offset, const size_t b_ld, - const T beta, - const Buffer &c_buffer, const size_t c_offset, const size_t c_ld) { + const size_t m, const size_t n, + const T alpha, + const Buffer &a_buffer, const size_t a_offset, const size_t a_ld, + const Buffer &b_buffer, const size_t b_offset, const size_t b_ld, + const T beta, + const Buffer &c_buffer, const size_t c_offset, const size_t c_ld) { // Makes sure all dimensions are larger than zero if ((m == 0) || (n == 0) ) { throw BLASError(StatusCode::kInvalidDimension); } @@ -58,8 +58,7 @@ void Xsymm::DoSymm(const Layout layout, const Side side, const Triangle trian // Creates a general matrix from the symmetric matrix to be able to run the regular Xgemm // routine afterwards - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel = Kernel(program, kernel_name); + auto kernel = Kernel(program_, kernel_name); // Sets the arguments for the symmetric-to-squared kernel kernel.SetArgument(0, static_cast(k)); diff --git a/src/routines/level3/xsymm.hpp b/src/routines/level3/xsymm.hpp index ee965364..7a584560 100644 --- a/src/routines/level3/xsymm.hpp +++ b/src/routines/level3/xsymm.hpp @@ -32,6 +32,7 @@ class Xsymm: public Xgemm { using Xgemm::queue_; using Xgemm::context_; using Xgemm::device_; + using Xgemm::program_; using Xgemm::db_; using Xgemm::DoGemm; diff --git a/src/routines/level3/xsyr2k.cpp b/src/routines/level3/xsyr2k.cpp index cb0e0461..fdef43dc 100644 --- a/src/routines/level3/xsyr2k.cpp +++ b/src/routines/level3/xsyr2k.cpp @@ -77,9 +77,6 @@ void Xsyr2k::DoSyr2k(const Layout layout, const Triangle triangle, const Tran // Decides which kernel to run: the upper-triangular or lower-triangular version auto kernel_name = (triangle == Triangle::kUpper) ? "XgemmUpper" : "XgemmLower"; - // Loads the program from the database - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - // Determines whether or not temporary matrices are needed auto a_no_temp = ab_one == n_ceiled && ab_two == k_ceiled && a_ld == n_ceiled && a_offset == 0 && ab_rotated == false; @@ -103,7 +100,7 @@ void Xsyr2k::DoSyr2k(const Layout layout, const Triangle triangle, const Tran PadCopyTransposeMatrix(queue_, device_, db_, eventProcessA.pointer(), emptyEventList, ab_one, ab_two, a_ld, a_offset, a_buffer, n_ceiled, k_ceiled, n_ceiled, 0, a_temp, - ConstantOne(), program, + ConstantOne(), program_, true, ab_rotated, false); eventWaitList.push_back(eventProcessA); } @@ -112,7 +109,7 @@ void Xsyr2k::DoSyr2k(const Layout layout, const Triangle triangle, const Tran PadCopyTransposeMatrix(queue_, device_, db_, eventProcessB.pointer(), emptyEventList, ab_one, ab_two, b_ld, b_offset, b_buffer, n_ceiled, k_ceiled, n_ceiled, 0, b_temp, - ConstantOne(), program, + ConstantOne(), program_, true, ab_rotated, false); eventWaitList.push_back(eventProcessB); } @@ -123,12 +120,12 @@ void Xsyr2k::DoSyr2k(const Layout layout, const Triangle triangle, const Tran PadCopyTransposeMatrix(queue_, device_, db_, eventProcessC.pointer(), emptyEventList, n, n, c_ld, c_offset, c_buffer, n_ceiled, n_ceiled, n_ceiled, 0, c_temp, - ConstantOne(), program, + ConstantOne(), program_, true, c_rotated, false); eventWaitList.push_back(eventProcessC); // Retrieves the XgemmUpper or XgemmLower kernel from the compiled binary - auto kernel = Kernel(program, kernel_name); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments kernel.SetArgument(0, static_cast(n_ceiled)); @@ -168,7 +165,7 @@ void Xsyr2k::DoSyr2k(const Layout layout, const Triangle triangle, const Tran PadCopyTransposeMatrix(queue_, device_, db_, event_, eventWaitList, n_ceiled, n_ceiled, n_ceiled, 0, c_temp, n, n, c_ld, c_offset, c_buffer, - ConstantOne(), program, + ConstantOne(), program_, false, c_rotated, false, upper, lower, false); } diff --git a/src/routines/level3/xsyrk.cpp b/src/routines/level3/xsyrk.cpp index bd6c4b25..9588c28c 100644 --- a/src/routines/level3/xsyrk.cpp +++ b/src/routines/level3/xsyrk.cpp @@ -74,9 +74,6 @@ void Xsyrk::DoSyrk(const Layout layout, const Triangle triangle, const Transp // Decides which kernel to run: the upper-triangular or lower-triangular version auto kernel_name = (triangle == Triangle::kUpper) ? "XgemmUpper" : "XgemmLower"; - // Loads the program from the database - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - // Determines whether or not temporary matrices are needed auto a_no_temp = a_one == n_ceiled && a_two == k_ceiled && a_ld == n_ceiled && a_offset == 0 && a_rotated == false; @@ -97,7 +94,7 @@ void Xsyrk::DoSyrk(const Layout layout, const Triangle triangle, const Transp PadCopyTransposeMatrix(queue_, device_, db_, eventProcessA.pointer(), emptyEventList, a_one, a_two, a_ld, a_offset, a_buffer, n_ceiled, k_ceiled, n_ceiled, 0, a_temp, - ConstantOne(), program, + ConstantOne(), program_, true, a_rotated, false); eventWaitList.push_back(eventProcessA); } @@ -108,12 +105,12 @@ void Xsyrk::DoSyrk(const Layout layout, const Triangle triangle, const Transp PadCopyTransposeMatrix(queue_, device_, db_, eventProcessC.pointer(), emptyEventList, n, n, c_ld, c_offset, c_buffer, n_ceiled, n_ceiled, n_ceiled, 0, c_temp, - ConstantOne(), program, + ConstantOne(), program_, true, c_rotated, false); eventWaitList.push_back(eventProcessC); // Retrieves the XgemmUpper or XgemmLower kernel from the compiled binary - auto kernel = Kernel(program, kernel_name); + auto kernel = Kernel(program_, kernel_name); // Sets the kernel arguments kernel.SetArgument(0, static_cast(n_ceiled)); @@ -142,7 +139,7 @@ void Xsyrk::DoSyrk(const Layout layout, const Triangle triangle, const Transp PadCopyTransposeMatrix(queue_, device_, db_, event_, eventWaitList, n_ceiled, n_ceiled, n_ceiled, 0, c_temp, n, n, c_ld, c_offset, c_buffer, - ConstantOne(), program, + ConstantOne(), program_, false, c_rotated, false, upper, lower, false); } diff --git a/src/routines/level3/xtrmm.cpp b/src/routines/level3/xtrmm.cpp index ed810e72..02c295ac 100644 --- a/src/routines/level3/xtrmm.cpp +++ b/src/routines/level3/xtrmm.cpp @@ -70,8 +70,7 @@ void Xtrmm::DoTrmm(const Layout layout, const Side side, const Triangle trian // Creates a general matrix from the triangular matrix to be able to run the regular Xgemm // routine afterwards - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto kernel = Kernel(program, kernel_name); + auto kernel = Kernel(program_, kernel_name); // Sets the arguments for the triangular-to-squared kernel kernel.SetArgument(0, static_cast(k)); diff --git a/src/routines/level3/xtrmm.hpp b/src/routines/level3/xtrmm.hpp index 967bf132..e77b7214 100644 --- a/src/routines/level3/xtrmm.hpp +++ b/src/routines/level3/xtrmm.hpp @@ -31,6 +31,7 @@ class Xtrmm: public Xgemm { using Xgemm::queue_; using Xgemm::context_; using Xgemm::device_; + using Xgemm::program_; using Xgemm::db_; using Xgemm::DoGemm; diff --git a/src/routines/levelx/xomatcopy.cpp b/src/routines/levelx/xomatcopy.cpp index 875ca7d2..4ae8c056 100644 --- a/src/routines/levelx/xomatcopy.cpp +++ b/src/routines/levelx/xomatcopy.cpp @@ -65,14 +65,11 @@ void Xomatcopy::DoOmatcopy(const Layout layout, const Transpose a_transpose, TestMatrixA(a_one, a_two, a_buffer, a_offset, a_ld); TestMatrixB(b_one, b_two, b_buffer, b_offset, b_ld); - // Loads the program from the database - const auto program = GetProgramFromCache(context_, PrecisionValue(), routine_name_); - auto emptyEventList = std::vector(); PadCopyTransposeMatrix(queue_, device_, db_, event_, emptyEventList, a_one, a_two, a_ld, a_offset, a_buffer, b_one, b_two, b_ld, b_offset, b_buffer, - alpha, program, false, transpose, conjugate); + alpha, program_, false, transpose, conjugate); } // =================================================================================================