Routine, Cache: generalize, reduce amount of copying in fast path
Implement a generalized Cache<K, V>. Two variants are provided: the first is based on std::map and uses the C++14-specific transparent std::less<> together with the generalized std::map::find() to allow searching by a tuple of references. The second is based on std::vector with O(n) lookup, but remains C++11-compliant.
pull/132/head
parent
e943fe77d6
commit
5bcd92f297
136
src/cache.cpp
136
src/cache.cpp
|
@ -20,103 +20,73 @@
|
|||
namespace clblast {
|
||||
// =================================================================================================
|
||||
|
||||
// Stores the compiled binary or IR in the cache
|
||||
void StoreBinaryToCache(const std::string &binary, const std::string &device_name,
|
||||
const Precision &precision, const std::string &routine_name) {
|
||||
#ifdef VERBOSE
|
||||
printf("[DEBUG] Storing binary in cache\n");
|
||||
#endif
|
||||
binary_cache_mutex_.lock();
|
||||
binary_cache_.push_back(BinaryCache{binary, device_name, precision, routine_name});
|
||||
binary_cache_mutex_.unlock();
|
||||
}
|
||||
template <typename Key, typename Value>
|
||||
template <typename U>
|
||||
Value Cache<Key, Value>::Get(const U &key, bool *in_cache) const {
|
||||
std::lock_guard<std::mutex> lock(cache_mutex_);
|
||||
|
||||
// Stores the compiled program in the cache
|
||||
void StoreProgramToCache(const Program &program, const Context &context,
|
||||
const Precision &precision, const std::string &routine_name) {
|
||||
#ifdef VERBOSE
|
||||
printf("[DEBUG] Storing program in cache\n");
|
||||
#endif
|
||||
program_cache_mutex_.lock();
|
||||
program_cache_.push_back(ProgramCache{program, context(), precision, routine_name});
|
||||
program_cache_mutex_.unlock();
|
||||
}
|
||||
|
||||
// Queries the cache and retrieves a matching binary. Assumes that the match is available, throws
|
||||
// otherwise.
|
||||
const std::string& GetBinaryFromCache(const std::string &device_name, const Precision &precision,
|
||||
const std::string &routine_name) {
|
||||
#ifdef VERBOSE
|
||||
printf("[DEBUG] Retrieving binary from cache\n");
|
||||
#endif
|
||||
binary_cache_mutex_.lock();
|
||||
for (auto &cached_binary: binary_cache_) {
|
||||
if (cached_binary.MatchInCache(device_name, precision, routine_name)) {
|
||||
binary_cache_mutex_.unlock();
|
||||
return cached_binary.binary;
|
||||
#if __cplusplus >= 201402L
|
||||
// generalized std::map::find() of C++14
|
||||
auto it = cache_.find(key);
|
||||
#else
|
||||
// O(n) lookup in a vector
|
||||
auto it = std::find_if(cache_.begin(), cache_.end(), [&] (const std::pair<Key, Value> &pair) {
|
||||
return pair.first == key;
|
||||
});
|
||||
#endif
|
||||
if (it == cache_.end()) {
|
||||
if (in_cache) {
|
||||
*in_cache = false;
|
||||
}
|
||||
return Value();
|
||||
}
|
||||
binary_cache_mutex_.unlock();
|
||||
throw LogicError("GetBinaryFromCache: Expected binary in cache, but found none");
|
||||
|
||||
if (in_cache) {
|
||||
*in_cache = true;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Queries the cache and retrieves a matching program. Assumes that the match is available, throws
|
||||
// otherwise.
|
||||
const Program& GetProgramFromCache(const Context &context, const Precision &precision,
|
||||
const std::string &routine_name) {
|
||||
#ifdef VERBOSE
|
||||
printf("[DEBUG] Retrieving program from cache\n");
|
||||
#endif
|
||||
program_cache_mutex_.lock();
|
||||
for (auto &cached_program: program_cache_) {
|
||||
if (cached_program.MatchInCache(context(), precision, routine_name)) {
|
||||
program_cache_mutex_.unlock();
|
||||
return cached_program.program;
|
||||
}
|
||||
template <typename Key, typename Value>
|
||||
void Cache<Key, Value>::Store(Key &&key, Value &&value) {
|
||||
std::lock_guard<std::mutex> lock(cache_mutex_);
|
||||
|
||||
#if __cplusplus >= 201402L
|
||||
// emplace() into a map
|
||||
auto r = cache_.emplace(std::move(key), std::move(value));
|
||||
if (!r.second) {
|
||||
throw LogicError("Cache::Store: object already in cache");
|
||||
}
|
||||
program_cache_mutex_.unlock();
|
||||
throw LogicError("GetProgramFromCache: Expected program in cache, but found none");
|
||||
#else
|
||||
// emplace_back() into a vector
|
||||
cache_.emplace_back(std::move(key), std::move(value));
|
||||
#endif
|
||||
}
|
||||
|
||||
// Queries the cache to see whether or not the compiled kernel is already there
|
||||
bool BinaryIsInCache(const std::string &device_name, const Precision &precision,
|
||||
const std::string &routine_name) {
|
||||
binary_cache_mutex_.lock();
|
||||
for (auto &cached_binary: binary_cache_) {
|
||||
if (cached_binary.MatchInCache(device_name, precision, routine_name)) {
|
||||
binary_cache_mutex_.unlock();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
binary_cache_mutex_.unlock();
|
||||
return false;
|
||||
template <typename Key, typename Value>
|
||||
void Cache<Key, Value>::Invalidate() {
|
||||
std::lock_guard<std::mutex> lock(cache_mutex_);
|
||||
|
||||
cache_.clear();
|
||||
}
|
||||
|
||||
// Queries the cache to see whether or not the compiled kernel is already there
|
||||
bool ProgramIsInCache(const Context &context, const Precision &precision,
|
||||
const std::string &routine_name) {
|
||||
program_cache_mutex_.lock();
|
||||
for (auto &cached_program: program_cache_) {
|
||||
if (cached_program.MatchInCache(context(), precision, routine_name)) {
|
||||
program_cache_mutex_.unlock();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
program_cache_mutex_.unlock();
|
||||
return false;
|
||||
template <typename Key, typename Value>
|
||||
Cache<Key, Value> &Cache<Key, Value>::Instance() {
|
||||
return instance_;
|
||||
}
|
||||
|
||||
// Out-of-class definition of the static instance_ member that Instance() returns a reference to.
template <typename Key, typename Value>
|
||||
Cache<Key, Value> Cache<Key, Value>::instance_;
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// Clears the cache of stored binaries and programs
|
||||
void CacheClearAll() {
|
||||
binary_cache_mutex_.lock();
|
||||
binary_cache_.clear();
|
||||
binary_cache_mutex_.unlock();
|
||||
program_cache_mutex_.lock();
|
||||
program_cache_.clear();
|
||||
program_cache_mutex_.unlock();
|
||||
}
|
||||
// Explicit instantiation of the binary cache class and of its Get() member template for lookups
// by BinaryKeyRef.
template class Cache<BinaryKey, std::string>;
|
||||
template std::string BinaryCache::Get(const BinaryKeyRef &, bool *) const;
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// Explicit instantiation of the program cache class and of its Get() member template for lookups
// by ProgramKeyRef.
template class Cache<ProgramKey, Program>;
|
||||
template Program ProgramCache::Get(const ProgramKeyRef &, bool *) const;
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
||||
|
|
106
src/cache.hpp
106
src/cache.hpp
|
@ -15,81 +15,75 @@
|
|||
#define CLBLAST_CACHE_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
#include <map>
|
||||
|
||||
#include "utilities/utilities.hpp"
|
||||
|
||||
namespace clblast {
|
||||
// =================================================================================================
|
||||
|
||||
// The cache of compiled OpenCL binaries, along with some meta-data
|
||||
struct BinaryCache {
|
||||
std::string binary;
|
||||
std::string device_name;
|
||||
Precision precision;
|
||||
std::string routine_name_;
|
||||
// The generic thread-safe cache. We assume that the Key may be a heavyweight struct that is not
|
||||
// normally used by the caller, while the Value is either lightweight or ref-counted.
|
||||
// Hence, searching by non-Key is supported (if there is a corresponding operator<()), and
|
||||
// on Store() the Key instance is moved from the caller (because it will likely be constructed
|
||||
// as temporary at the time of Store()).
|
||||
template <typename Key, typename Value>
|
||||
class Cache {
|
||||
public:
|
||||
// Cached object is returned by-value to avoid racing with Invalidate().
|
||||
// Due to lack of std::optional<>, in case of a cache miss we return a default-constructed
|
||||
// Value and set the flag to false.
|
||||
template <typename U>
|
||||
Value Get(const U &key, bool *in_cache) const;
|
||||
|
||||
// Finds out whether the properties match
|
||||
bool MatchInCache(const std::string &ref_device, const Precision &ref_precision,
|
||||
const std::string &ref_routine) {
|
||||
return (device_name == ref_device &&
|
||||
precision == ref_precision &&
|
||||
routine_name_ == ref_routine);
|
||||
}
|
||||
};
|
||||
// We do not return a reference to the just-stored object, to avoid racing with Invalidate().
|
||||
// Caller is expected to store a temporary.
|
||||
void Store(Key &&key, Value &&value);
|
||||
void Invalidate();
|
||||
|
||||
// The actual cache, implemented as a vector of the above data-type, and its mutex
|
||||
static std::vector<BinaryCache> binary_cache_;
|
||||
static std::mutex binary_cache_mutex_;
|
||||
static Cache<Key, Value> &Instance();
|
||||
|
||||
private:
|
||||
#if __cplusplus >= 201402L
|
||||
// std::less<void> allows searching the cache by an object comparable with Key, without
|
||||
// constructing a temporary Key
|
||||
// (see http://en.cppreference.com/w/cpp/utility/functional/less_void,
|
||||
// http://www.open-std.org/JTC1/SC22/WG21/docs/papers/2013/n3657.htm,
|
||||
// http://stackoverflow.com/questions/10536788/avoiding-key-construction-for-stdmapfind)
|
||||
std::map<Key, Value, std::less<void>> cache_;
|
||||
#else
|
||||
std::vector<std::pair<Key, Value>> cache_;
|
||||
#endif
|
||||
mutable std::mutex cache_mutex_;
|
||||
|
||||
static Cache<Key, Value> instance_;
|
||||
}; // class Cache
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// The cache of compiled OpenCL programs, along with some meta-data
|
||||
struct ProgramCache {
|
||||
Program program;
|
||||
cl_context context;
|
||||
Precision precision;
|
||||
std::string routine_name_;
|
||||
// The key struct for the cache of compiled OpenCL binaries
|
||||
// Order of fields: precision, routine_name, device_name (smaller fields first)
|
||||
typedef std::tuple<Precision, std::string, std::string> BinaryKey;
|
||||
typedef std::tuple<const Precision &, const std::string &, const std::string &> BinaryKeyRef;
|
||||
|
||||
// Finds out whether the properties match
|
||||
bool MatchInCache(const cl_context ref_context, const Precision &ref_precision,
|
||||
const std::string &ref_routine) {
|
||||
return (context == ref_context &&
|
||||
precision == ref_precision &&
|
||||
routine_name_ == ref_routine);
|
||||
}
|
||||
};
|
||||
typedef Cache<BinaryKey, std::string> BinaryCache;
|
||||
|
||||
extern template class Cache<BinaryKey, std::string>;
|
||||
extern template std::string BinaryCache::Get(const BinaryKeyRef &, bool *) const;
|
||||
|
||||
// The actual cache, implemented as a vector of the above data-type, and its mutex
|
||||
static std::vector<ProgramCache> program_cache_;
|
||||
static std::mutex program_cache_mutex_;
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// Stores the compiled binary or program in the cache
|
||||
void StoreBinaryToCache(const std::string &binary, const std::string &device_name,
|
||||
const Precision &precision, const std::string &routine_name);
|
||||
void StoreProgramToCache(const Program &program, const Context &context,
|
||||
const Precision &precision, const std::string &routine_name);
|
||||
// The key struct for the cache of compiled OpenCL programs (context-dependent)
|
||||
// Order of fields: context, precision, routine_name (smaller fields first)
|
||||
typedef std::tuple<cl_context, Precision, std::string> ProgramKey;
|
||||
typedef std::tuple<const cl_context &, const Precision &, const std::string &> ProgramKeyRef;
|
||||
|
||||
// Queries the cache and retrieves a matching binary or program. Assumes that the match is
|
||||
// available, throws otherwise.
|
||||
const std::string& GetBinaryFromCache(const std::string &device_name, const Precision &precision,
|
||||
const std::string &routine_name);
|
||||
const Program& GetProgramFromCache(const Context &context, const Precision &precision,
|
||||
const std::string &routine_name);
|
||||
typedef Cache<ProgramKey, Program> ProgramCache;
|
||||
|
||||
// Queries the cache to see whether or not the compiled kernel is already there
|
||||
bool BinaryIsInCache(const std::string &device_name, const Precision &precision,
|
||||
const std::string &routine_name);
|
||||
bool ProgramIsInCache(const Context &context, const Precision &precision,
|
||||
const std::string &routine_name);
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// Clears the cache of stored binaries
|
||||
void CacheClearAll();
|
||||
extern template class Cache<ProgramKey, Program>;
|
||||
extern template Program ProgramCache::Get(const ProgramKeyRef &, bool *) const;
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
||||
|
|
|
@ -2165,7 +2165,8 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose,
|
|||
// Clears the caches of stored programs and binaries. Any exception raised while clearing is
// caught and handed to DispatchException(), whose result is returned; otherwise
// StatusCode::kSuccess is returned.
|
||||
StatusCode ClearCache() {
|
||||
try {
|
||||
CacheClearAll();
|
||||
ProgramCache::Instance().Invalidate();
|
||||
BinaryCache::Instance().Invalidate();
|
||||
} catch (...) { return DispatchException(); }
|
||||
return StatusCode::kSuccess;
|
||||
}
|
||||
|
|
|
@ -361,7 +361,7 @@ enum class BuildStatus { kSuccess, kError, kInvalid };
|
|||
// C++11 version of 'cl_program'.
|
||||
class Program {
|
||||
public:
|
||||
// Note that there is no constructor based on the regular OpenCL data-type because of extra state
|
||||
Program() = default;
|
||||
|
||||
// Source-based constructor with memory management
|
||||
explicit Program(const Context &context, const std::string &source):
|
||||
|
|
|
@ -36,7 +36,10 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
|
|||
db_(queue_, routines, precision_, userDatabase) {
|
||||
|
||||
// Queries the cache to see whether or not the program (context-specific) is already there
|
||||
if (ProgramIsInCache(context_, precision_, routine_name_)) { return; }
|
||||
bool has_program;
|
||||
program_ = ProgramCache::Instance().Get(ProgramKeyRef{ context_(), precision_, routine_name_ },
|
||||
&has_program);
|
||||
if (has_program) { return; }
|
||||
|
||||
// Sets the build options from an environmental variable (if set)
|
||||
auto options = std::vector<std::string>();
|
||||
|
@ -47,11 +50,14 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
|
|||
|
||||
// Queries the cache to see whether or not the binary (device-specific) is already there. If it
|
||||
// is, a program is created and stored in the cache
|
||||
if (BinaryIsInCache(device_name_, precision_, routine_name_)) {
|
||||
auto& binary = GetBinaryFromCache(device_name_, precision_, routine_name_);
|
||||
auto program = Program(device_, context_, binary);
|
||||
program.Build(device_, options);
|
||||
StoreProgramToCache(program, context_, precision_, routine_name_);
|
||||
bool has_binary;
|
||||
auto binary = BinaryCache::Instance().Get(BinaryKeyRef{ precision_, routine_name_, device_name_ },
|
||||
&has_binary);
|
||||
if (has_binary) {
|
||||
program_ = Program(device_, context_, binary);
|
||||
program_.Build(device_, options);
|
||||
ProgramCache::Instance().Store(ProgramKey{ context_(), precision_, routine_name_ },
|
||||
Program{ program_ });
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -111,21 +117,23 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
|
|||
#endif
|
||||
|
||||
// Compiles the kernel
|
||||
auto program = Program(context_, source_string);
|
||||
program_ = Program(context_, source_string);
|
||||
try {
|
||||
program.Build(device_, options);
|
||||
program_.Build(device_, options);
|
||||
} catch (const CLError &e) {
|
||||
if (e.status() == CL_BUILD_PROGRAM_FAILURE) {
|
||||
fprintf(stdout, "OpenCL compiler error/warning: %s\n",
|
||||
program.GetBuildInfo(device_).c_str());
|
||||
program_.GetBuildInfo(device_).c_str());
|
||||
}
|
||||
throw;
|
||||
}
|
||||
|
||||
// Store the compiled binary and program in the cache
|
||||
const auto binary = program.GetIR();
|
||||
StoreBinaryToCache(binary, device_name_, precision_, routine_name_);
|
||||
StoreProgramToCache(program, context_, precision_, routine_name_);
|
||||
BinaryCache::Instance().Store(BinaryKey{ precision_, routine_name_, device_name_ },
|
||||
program_.GetIR());
|
||||
|
||||
ProgramCache::Instance().Store(ProgramKey{ context_(), precision_, routine_name_ },
|
||||
Program{ program_ });
|
||||
|
||||
// Prints the elapsed compilation time in case of debugging in verbose mode
|
||||
#ifdef VERBOSE
|
||||
|
|
|
@ -57,6 +57,9 @@ class Routine {
|
|||
// OpenCL device properties
|
||||
const std::string device_name_;
|
||||
|
||||
// Compiled program (either retrieved from cache or compiled in slow path)
|
||||
Program program_;
|
||||
|
||||
// Connection to the database for all the device-specific parameters
|
||||
const Database db_;
|
||||
};
|
||||
|
|
|
@ -43,9 +43,8 @@ void Xamax<T>::DoAmax(const size_t n,
|
|||
TestVectorIndex(1, imax_buffer, imax_offset);
|
||||
|
||||
// Retrieves the Xamax kernels from the compiled binary
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel1 = Kernel(program, "Xamax");
|
||||
auto kernel2 = Kernel(program, "XamaxEpilogue");
|
||||
auto kernel1 = Kernel(program_, "Xamax");
|
||||
auto kernel2 = Kernel(program_, "XamaxEpilogue");
|
||||
|
||||
// Creates the buffer for intermediate values
|
||||
auto temp_size = 2*db_["WGS2"];
|
||||
|
|
|
@ -43,9 +43,8 @@ void Xasum<T>::DoAsum(const size_t n,
|
|||
TestVectorScalar(1, asum_buffer, asum_offset);
|
||||
|
||||
// Retrieves the Xasum kernels from the compiled binary
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel1 = Kernel(program, "Xasum");
|
||||
auto kernel2 = Kernel(program, "XasumEpilogue");
|
||||
auto kernel1 = Kernel(program_, "Xasum");
|
||||
auto kernel2 = Kernel(program_, "XasumEpilogue");
|
||||
|
||||
// Creates the buffer for intermediate values
|
||||
auto temp_size = 2*db_["WGS2"];
|
||||
|
|
|
@ -52,8 +52,7 @@ void Xaxpy<T>::DoAxpy(const size_t n, const T alpha,
|
|||
auto kernel_name = (use_fast_kernel) ? "XaxpyFast" : "Xaxpy";
|
||||
|
||||
// Retrieves the Xaxpy kernel from the compiled binary
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel = Kernel(program, kernel_name);
|
||||
auto kernel = Kernel(program_, kernel_name);
|
||||
|
||||
// Sets the kernel arguments
|
||||
if (use_fast_kernel) {
|
||||
|
|
|
@ -52,8 +52,7 @@ void Xcopy<T>::DoCopy(const size_t n,
|
|||
auto kernel_name = (use_fast_kernel) ? "XcopyFast" : "Xcopy";
|
||||
|
||||
// Retrieves the Xcopy kernel from the compiled binary
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel = Kernel(program, kernel_name);
|
||||
auto kernel = Kernel(program_, kernel_name);
|
||||
|
||||
// Sets the kernel arguments
|
||||
if (use_fast_kernel) {
|
||||
|
|
|
@ -46,9 +46,8 @@ void Xdot<T>::DoDot(const size_t n,
|
|||
TestVectorScalar(1, dot_buffer, dot_offset);
|
||||
|
||||
// Retrieves the Xdot kernels from the compiled binary
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel1 = Kernel(program, "Xdot");
|
||||
auto kernel2 = Kernel(program, "XdotEpilogue");
|
||||
auto kernel1 = Kernel(program_, "Xdot");
|
||||
auto kernel2 = Kernel(program_, "XdotEpilogue");
|
||||
|
||||
// Creates the buffer for intermediate values
|
||||
auto temp_size = 2*db_["WGS2"];
|
||||
|
|
|
@ -43,9 +43,8 @@ void Xnrm2<T>::DoNrm2(const size_t n,
|
|||
TestVectorScalar(1, nrm2_buffer, nrm2_offset);
|
||||
|
||||
// Retrieves the Xnrm2 kernels from the compiled binary
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel1 = Kernel(program, "Xnrm2");
|
||||
auto kernel2 = Kernel(program, "Xnrm2Epilogue");
|
||||
auto kernel1 = Kernel(program_, "Xnrm2");
|
||||
auto kernel2 = Kernel(program_, "Xnrm2Epilogue");
|
||||
|
||||
// Creates the buffer for intermediate values
|
||||
auto temp_size = 2*db_["WGS2"];
|
||||
|
|
|
@ -49,8 +49,7 @@ void Xscal<T>::DoScal(const size_t n, const T alpha,
|
|||
auto kernel_name = (use_fast_kernel) ? "XscalFast" : "Xscal";
|
||||
|
||||
// Retrieves the Xscal kernel from the compiled binary
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel = Kernel(program, kernel_name);
|
||||
auto kernel = Kernel(program_, kernel_name);
|
||||
|
||||
// Sets the kernel arguments
|
||||
if (use_fast_kernel) {
|
||||
|
|
|
@ -52,8 +52,7 @@ void Xswap<T>::DoSwap(const size_t n,
|
|||
auto kernel_name = (use_fast_kernel) ? "XswapFast" : "Xswap";
|
||||
|
||||
// Retrieves the Xswap kernel from the compiled binary
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel = Kernel(program, kernel_name);
|
||||
auto kernel = Kernel(program_, kernel_name);
|
||||
|
||||
// Sets the kernel arguments
|
||||
if (use_fast_kernel) {
|
||||
|
|
|
@ -122,8 +122,7 @@ void Xgemv<T>::MatVec(const Layout layout, const Transpose a_transpose,
|
|||
}
|
||||
|
||||
// Retrieves the Xgemv kernel from the compiled binary
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel = Kernel(program, kernel_name);
|
||||
auto kernel = Kernel(program_, kernel_name);
|
||||
|
||||
// Sets the kernel arguments
|
||||
kernel.SetArgument(0, static_cast<int>(m_real));
|
||||
|
|
|
@ -53,8 +53,7 @@ void Xger<T>::DoGer(const Layout layout,
|
|||
TestVectorY(n, y_buffer, y_offset, y_inc);
|
||||
|
||||
// Retrieves the kernel from the compiled binary
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel = Kernel(program, "Xger");
|
||||
auto kernel = Kernel(program_, "Xger");
|
||||
|
||||
// Sets the kernel arguments
|
||||
kernel.SetArgument(0, static_cast<int>(a_one));
|
||||
|
|
|
@ -67,8 +67,7 @@ void Xher<T,U>::DoHer(const Layout layout, const Triangle triangle,
|
|||
const auto matching_alpha = GetAlpha(alpha);
|
||||
|
||||
// Retrieves the kernel from the compiled binary
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel = Kernel(program, "Xher");
|
||||
auto kernel = Kernel(program_, "Xher");
|
||||
|
||||
// Sets the kernel arguments
|
||||
kernel.SetArgument(0, static_cast<int>(n));
|
||||
|
|
|
@ -54,8 +54,7 @@ void Xher2<T>::DoHer2(const Layout layout, const Triangle triangle,
|
|||
TestVectorY(n, y_buffer, y_offset, y_inc);
|
||||
|
||||
// Retrieves the kernel from the compiled binary
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel = Kernel(program, "Xher2");
|
||||
auto kernel = Kernel(program_, "Xher2");
|
||||
|
||||
// Sets the kernel arguments
|
||||
kernel.SetArgument(0, static_cast<int>(n));
|
||||
|
|
|
@ -150,9 +150,6 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
|
|||
const auto c_one_i = (c_want_rotated) ? n_ceiled : m_ceiled;
|
||||
const auto c_two_i = (c_want_rotated) ? m_ceiled : n_ceiled;
|
||||
|
||||
// Loads the program from the database
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
|
||||
// Determines whether or not temporary matrices are needed
|
||||
auto a_no_temp = a_one == a_one_i && a_two == a_two_i && a_ld == a_one && a_offset == 0 &&
|
||||
a_do_transpose == false && a_conjugate == false;
|
||||
|
@ -178,7 +175,7 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessA.pointer(), emptyEventList,
|
||||
a_one, a_two, a_ld, a_offset, a_buffer,
|
||||
a_one_i, a_two_i, a_one_i, 0, a_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, a_do_transpose, a_conjugate);
|
||||
eventWaitList.push_back(eventProcessA);
|
||||
}
|
||||
|
@ -189,7 +186,7 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessB.pointer(), emptyEventList,
|
||||
b_one, b_two, b_ld, b_offset, b_buffer,
|
||||
b_one_i, b_two_i, b_one_i, 0, b_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, b_do_transpose, b_conjugate);
|
||||
eventWaitList.push_back(eventProcessB);
|
||||
}
|
||||
|
@ -200,13 +197,13 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessC.pointer(), emptyEventList,
|
||||
c_one, c_two, c_ld, c_offset, c_buffer,
|
||||
c_one_i, c_two_i, c_one_i, 0, c_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, c_do_transpose, false);
|
||||
eventWaitList.push_back(eventProcessC);
|
||||
}
|
||||
|
||||
// Retrieves the Xgemm kernel from the compiled binary
|
||||
auto kernel = Kernel(program, "Xgemm");
|
||||
auto kernel = Kernel(program_, "Xgemm");
|
||||
|
||||
// Sets the kernel arguments
|
||||
kernel.SetArgument(0, static_cast<int>(m_ceiled));
|
||||
|
@ -236,7 +233,7 @@ void Xgemm<T>::GemmIndirect(const size_t m, const size_t n, const size_t k,
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, event_, eventWaitList,
|
||||
c_one_i, c_two_i, c_one_i, 0, c_temp,
|
||||
c_one, c_two, c_ld, c_offset, c_buffer,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
false, c_do_transpose, false);
|
||||
}
|
||||
}
|
||||
|
@ -255,13 +252,10 @@ void Xgemm<T>::GemmDirect(const size_t m, const size_t n, const size_t k,
|
|||
const bool a_do_transpose, const bool b_do_transpose, const bool c_do_transpose,
|
||||
const bool a_conjugate, const bool b_conjugate) {
|
||||
|
||||
// Loads the program from the database
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
|
||||
// Retrieves the proper XgemmDirect kernel from the compiled binary
|
||||
const auto name = (a_do_transpose) ? (b_do_transpose ? "XgemmDirectTT" : "XgemmDirectTN") :
|
||||
(b_do_transpose ? "XgemmDirectNT" : "XgemmDirectNN");
|
||||
auto kernel = Kernel(program, name);
|
||||
auto kernel = Kernel(program_, name);
|
||||
|
||||
// Sets the kernel arguments
|
||||
kernel.SetArgument(0, static_cast<int>(m));
|
||||
|
|
|
@ -58,8 +58,7 @@ void Xhemm<T>::DoHemm(const Layout layout, const Side side, const Triangle trian
|
|||
|
||||
// Creates a general matrix from the hermitian matrix to be able to run the regular Xgemm
|
||||
// routine afterwards
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel = Kernel(program, kernel_name);
|
||||
auto kernel = Kernel(program_, kernel_name);
|
||||
|
||||
// Sets the arguments for the hermitian-to-squared kernel
|
||||
kernel.SetArgument(0, static_cast<int>(k));
|
||||
|
|
|
@ -30,6 +30,7 @@ class Xhemm: public Xgemm<T> {
|
|||
using Xgemm<T>::queue_;
|
||||
using Xgemm<T>::context_;
|
||||
using Xgemm<T>::device_;
|
||||
using Xgemm<T>::program_;
|
||||
using Xgemm<T>::db_;
|
||||
using Xgemm<T>::DoGemm;
|
||||
|
||||
|
|
|
@ -81,9 +81,6 @@ void Xher2k<T,U>::DoHer2k(const Layout layout, const Triangle triangle, const Tr
|
|||
// Decides which kernel to run: the upper-triangular or lower-triangular version
|
||||
auto kernel_name = (triangle == Triangle::kUpper) ? "XgemmUpper" : "XgemmLower";
|
||||
|
||||
// Loads the program from the database
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
|
||||
// Determines whether or not temporary matrices are needed
|
||||
auto a1_no_temp = ab_one == n_ceiled && ab_two == k_ceiled && a_ld == n_ceiled && a_offset == 0 &&
|
||||
ab_rotated == false && ab_conjugate == false;
|
||||
|
@ -116,7 +113,7 @@ void Xher2k<T,U>::DoHer2k(const Layout layout, const Triangle triangle, const Tr
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessA1.pointer(), emptyEventList,
|
||||
ab_one, ab_two, a_ld, a_offset, a_buffer,
|
||||
n_ceiled, k_ceiled, n_ceiled, 0, a1_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, ab_rotated, ab_conjugate);
|
||||
eventWaitList.push_back(eventProcessA1);
|
||||
}
|
||||
|
@ -125,7 +122,7 @@ void Xher2k<T,U>::DoHer2k(const Layout layout, const Triangle triangle, const Tr
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessA2.pointer(), emptyEventList,
|
||||
ab_one, ab_two, a_ld, a_offset, a_buffer,
|
||||
n_ceiled, k_ceiled, n_ceiled, 0, a2_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, ab_rotated, !ab_conjugate);
|
||||
eventWaitList.push_back(eventProcessA2);
|
||||
}
|
||||
|
@ -134,7 +131,7 @@ void Xher2k<T,U>::DoHer2k(const Layout layout, const Triangle triangle, const Tr
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessB1.pointer(), emptyEventList,
|
||||
ab_one, ab_two, b_ld, b_offset, b_buffer,
|
||||
n_ceiled, k_ceiled, n_ceiled, 0, b1_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, ab_rotated, ab_conjugate);
|
||||
eventWaitList.push_back(eventProcessB1);
|
||||
}
|
||||
|
@ -143,7 +140,7 @@ void Xher2k<T,U>::DoHer2k(const Layout layout, const Triangle triangle, const Tr
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessB2.pointer(), emptyEventList,
|
||||
ab_one, ab_two, b_ld, b_offset, b_buffer,
|
||||
n_ceiled, k_ceiled, n_ceiled, 0, b2_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, ab_rotated, !ab_conjugate);
|
||||
eventWaitList.push_back(eventProcessB2);
|
||||
}
|
||||
|
@ -154,12 +151,12 @@ void Xher2k<T,U>::DoHer2k(const Layout layout, const Triangle triangle, const Tr
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessC.pointer(), emptyEventList,
|
||||
n, n, c_ld, c_offset, c_buffer,
|
||||
n_ceiled, n_ceiled, n_ceiled, 0, c_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, c_rotated, false);
|
||||
eventWaitList.push_back(eventProcessC);
|
||||
|
||||
// Retrieves the XgemmUpper or XgemmLower kernel from the compiled binary
|
||||
auto kernel = Kernel(program, kernel_name);
|
||||
auto kernel = Kernel(program_, kernel_name);
|
||||
|
||||
// Sets the kernel arguments
|
||||
kernel.SetArgument(0, static_cast<int>(n_ceiled));
|
||||
|
@ -201,7 +198,7 @@ void Xher2k<T,U>::DoHer2k(const Layout layout, const Triangle triangle, const Tr
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, event_, eventWaitList,
|
||||
n_ceiled, n_ceiled, n_ceiled, 0, c_temp,
|
||||
n, n, c_ld, c_offset, c_buffer,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
false, c_rotated, false, upper, lower, true);
|
||||
}
|
||||
|
||||
|
|
|
@ -79,9 +79,6 @@ void Xherk<T,U>::DoHerk(const Layout layout, const Triangle triangle, const Tran
|
|||
// Decides which kernel to run: the upper-triangular or lower-triangular version
|
||||
auto kernel_name = (triangle == Triangle::kUpper) ? "XgemmUpper" : "XgemmLower";
|
||||
|
||||
// Loads the program from the database
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
|
||||
// Determines whether or not temporary matrices are needed
|
||||
auto a_no_temp = a_one == n_ceiled && a_two == k_ceiled && a_ld == n_ceiled && a_offset == 0 &&
|
||||
a_rotated == false && a_conjugate == false;
|
||||
|
@ -109,7 +106,7 @@ void Xherk<T,U>::DoHerk(const Layout layout, const Triangle triangle, const Tran
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessA.pointer(), emptyEventList,
|
||||
a_one, a_two, a_ld, a_offset, a_buffer,
|
||||
n_ceiled, k_ceiled, n_ceiled, 0, a_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, a_rotated, a_conjugate);
|
||||
eventWaitList.push_back(eventProcessA);
|
||||
}
|
||||
|
@ -118,7 +115,7 @@ void Xherk<T,U>::DoHerk(const Layout layout, const Triangle triangle, const Tran
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessB.pointer(), emptyEventList,
|
||||
a_one, a_two, a_ld, a_offset, a_buffer,
|
||||
n_ceiled, k_ceiled, n_ceiled, 0, b_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, a_rotated, b_conjugate);
|
||||
eventWaitList.push_back(eventProcessB);
|
||||
}
|
||||
|
@ -129,12 +126,12 @@ void Xherk<T,U>::DoHerk(const Layout layout, const Triangle triangle, const Tran
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessC.pointer(), emptyEventList,
|
||||
n, n, c_ld, c_offset, c_buffer,
|
||||
n_ceiled, n_ceiled, n_ceiled, 0, c_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, c_rotated, false);
|
||||
eventWaitList.push_back(eventProcessC);
|
||||
|
||||
// Retrieves the XgemmUpper or XgemmLower kernel from the compiled binary
|
||||
auto kernel = Kernel(program, kernel_name);
|
||||
auto kernel = Kernel(program_, kernel_name);
|
||||
|
||||
// Sets the kernel arguments
|
||||
kernel.SetArgument(0, static_cast<int>(n_ceiled));
|
||||
|
@ -163,7 +160,7 @@ void Xherk<T,U>::DoHerk(const Layout layout, const Triangle triangle, const Tran
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, event_, eventWaitList,
|
||||
n_ceiled, n_ceiled, n_ceiled, 0, c_temp,
|
||||
n, n, c_ld, c_offset, c_buffer,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
false, c_rotated, false, upper, lower, true);
|
||||
}
|
||||
|
||||
|
|
|
@ -30,12 +30,12 @@ Xsymm<T>::Xsymm(Queue &queue, EventPointer event, const std::string &name):
|
|||
// The main routine
|
||||
template <typename T>
|
||||
void Xsymm<T>::DoSymm(const Layout layout, const Side side, const Triangle triangle,
|
||||
const size_t m, const size_t n,
|
||||
const T alpha,
|
||||
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
const T beta,
|
||||
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld) {
|
||||
const size_t m, const size_t n,
|
||||
const T alpha,
|
||||
const Buffer<T> &a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
const Buffer<T> &b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
const T beta,
|
||||
const Buffer<T> &c_buffer, const size_t c_offset, const size_t c_ld) {
|
||||
|
||||
// Makes sure all dimensions are larger than zero
|
||||
if ((m == 0) || (n == 0) ) { throw BLASError(StatusCode::kInvalidDimension); }
|
||||
|
@ -58,8 +58,7 @@ void Xsymm<T>::DoSymm(const Layout layout, const Side side, const Triangle trian
|
|||
|
||||
// Creates a general matrix from the symmetric matrix to be able to run the regular Xgemm
|
||||
// routine afterwards
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel = Kernel(program, kernel_name);
|
||||
auto kernel = Kernel(program_, kernel_name);
|
||||
|
||||
// Sets the arguments for the symmetric-to-squared kernel
|
||||
kernel.SetArgument(0, static_cast<int>(k));
|
||||
|
|
|
@ -32,6 +32,7 @@ class Xsymm: public Xgemm<T> {
|
|||
using Xgemm<T>::queue_;
|
||||
using Xgemm<T>::context_;
|
||||
using Xgemm<T>::device_;
|
||||
using Xgemm<T>::program_;
|
||||
using Xgemm<T>::db_;
|
||||
using Xgemm<T>::DoGemm;
|
||||
|
||||
|
|
|
@ -77,9 +77,6 @@ void Xsyr2k<T>::DoSyr2k(const Layout layout, const Triangle triangle, const Tran
|
|||
// Decides which kernel to run: the upper-triangular or lower-triangular version
|
||||
auto kernel_name = (triangle == Triangle::kUpper) ? "XgemmUpper" : "XgemmLower";
|
||||
|
||||
// Loads the program from the database
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
|
||||
// Determines whether or not temporary matrices are needed
|
||||
auto a_no_temp = ab_one == n_ceiled && ab_two == k_ceiled && a_ld == n_ceiled && a_offset == 0 &&
|
||||
ab_rotated == false;
|
||||
|
@ -103,7 +100,7 @@ void Xsyr2k<T>::DoSyr2k(const Layout layout, const Triangle triangle, const Tran
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessA.pointer(), emptyEventList,
|
||||
ab_one, ab_two, a_ld, a_offset, a_buffer,
|
||||
n_ceiled, k_ceiled, n_ceiled, 0, a_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, ab_rotated, false);
|
||||
eventWaitList.push_back(eventProcessA);
|
||||
}
|
||||
|
@ -112,7 +109,7 @@ void Xsyr2k<T>::DoSyr2k(const Layout layout, const Triangle triangle, const Tran
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessB.pointer(), emptyEventList,
|
||||
ab_one, ab_two, b_ld, b_offset, b_buffer,
|
||||
n_ceiled, k_ceiled, n_ceiled, 0, b_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, ab_rotated, false);
|
||||
eventWaitList.push_back(eventProcessB);
|
||||
}
|
||||
|
@ -123,12 +120,12 @@ void Xsyr2k<T>::DoSyr2k(const Layout layout, const Triangle triangle, const Tran
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessC.pointer(), emptyEventList,
|
||||
n, n, c_ld, c_offset, c_buffer,
|
||||
n_ceiled, n_ceiled, n_ceiled, 0, c_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, c_rotated, false);
|
||||
eventWaitList.push_back(eventProcessC);
|
||||
|
||||
// Retrieves the XgemmUpper or XgemmLower kernel from the compiled binary
|
||||
auto kernel = Kernel(program, kernel_name);
|
||||
auto kernel = Kernel(program_, kernel_name);
|
||||
|
||||
// Sets the kernel arguments
|
||||
kernel.SetArgument(0, static_cast<int>(n_ceiled));
|
||||
|
@ -168,7 +165,7 @@ void Xsyr2k<T>::DoSyr2k(const Layout layout, const Triangle triangle, const Tran
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, event_, eventWaitList,
|
||||
n_ceiled, n_ceiled, n_ceiled, 0, c_temp,
|
||||
n, n, c_ld, c_offset, c_buffer,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
false, c_rotated, false, upper, lower, false);
|
||||
}
|
||||
|
||||
|
|
|
@ -74,9 +74,6 @@ void Xsyrk<T>::DoSyrk(const Layout layout, const Triangle triangle, const Transp
|
|||
// Decides which kernel to run: the upper-triangular or lower-triangular version
|
||||
auto kernel_name = (triangle == Triangle::kUpper) ? "XgemmUpper" : "XgemmLower";
|
||||
|
||||
// Loads the program from the database
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
|
||||
// Determines whether or not temporary matrices are needed
|
||||
auto a_no_temp = a_one == n_ceiled && a_two == k_ceiled && a_ld == n_ceiled && a_offset == 0 &&
|
||||
a_rotated == false;
|
||||
|
@ -97,7 +94,7 @@ void Xsyrk<T>::DoSyrk(const Layout layout, const Triangle triangle, const Transp
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessA.pointer(), emptyEventList,
|
||||
a_one, a_two, a_ld, a_offset, a_buffer,
|
||||
n_ceiled, k_ceiled, n_ceiled, 0, a_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, a_rotated, false);
|
||||
eventWaitList.push_back(eventProcessA);
|
||||
}
|
||||
|
@ -108,12 +105,12 @@ void Xsyrk<T>::DoSyrk(const Layout layout, const Triangle triangle, const Transp
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, eventProcessC.pointer(), emptyEventList,
|
||||
n, n, c_ld, c_offset, c_buffer,
|
||||
n_ceiled, n_ceiled, n_ceiled, 0, c_temp,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
true, c_rotated, false);
|
||||
eventWaitList.push_back(eventProcessC);
|
||||
|
||||
// Retrieves the XgemmUpper or XgemmLower kernel from the compiled binary
|
||||
auto kernel = Kernel(program, kernel_name);
|
||||
auto kernel = Kernel(program_, kernel_name);
|
||||
|
||||
// Sets the kernel arguments
|
||||
kernel.SetArgument(0, static_cast<int>(n_ceiled));
|
||||
|
@ -142,7 +139,7 @@ void Xsyrk<T>::DoSyrk(const Layout layout, const Triangle triangle, const Transp
|
|||
PadCopyTransposeMatrix(queue_, device_, db_, event_, eventWaitList,
|
||||
n_ceiled, n_ceiled, n_ceiled, 0, c_temp,
|
||||
n, n, c_ld, c_offset, c_buffer,
|
||||
ConstantOne<T>(), program,
|
||||
ConstantOne<T>(), program_,
|
||||
false, c_rotated, false, upper, lower, false);
|
||||
}
|
||||
|
||||
|
|
|
@ -70,8 +70,7 @@ void Xtrmm<T>::DoTrmm(const Layout layout, const Side side, const Triangle trian
|
|||
|
||||
// Creates a general matrix from the triangular matrix to be able to run the regular Xgemm
|
||||
// routine afterwards
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
auto kernel = Kernel(program, kernel_name);
|
||||
auto kernel = Kernel(program_, kernel_name);
|
||||
|
||||
// Sets the arguments for the triangular-to-squared kernel
|
||||
kernel.SetArgument(0, static_cast<int>(k));
|
||||
|
|
|
@ -31,6 +31,7 @@ class Xtrmm: public Xgemm<T> {
|
|||
using Xgemm<T>::queue_;
|
||||
using Xgemm<T>::context_;
|
||||
using Xgemm<T>::device_;
|
||||
using Xgemm<T>::program_;
|
||||
using Xgemm<T>::db_;
|
||||
using Xgemm<T>::DoGemm;
|
||||
|
||||
|
|
|
@ -65,14 +65,11 @@ void Xomatcopy<T>::DoOmatcopy(const Layout layout, const Transpose a_transpose,
|
|||
TestMatrixA(a_one, a_two, a_buffer, a_offset, a_ld);
|
||||
TestMatrixB(b_one, b_two, b_buffer, b_offset, b_ld);
|
||||
|
||||
// Loads the program from the database
|
||||
const auto program = GetProgramFromCache(context_, PrecisionValue<T>(), routine_name_);
|
||||
|
||||
auto emptyEventList = std::vector<Event>();
|
||||
PadCopyTransposeMatrix(queue_, device_, db_, event_, emptyEventList,
|
||||
a_one, a_two, a_ld, a_offset, a_buffer,
|
||||
b_one, b_two, b_ld, b_offset, b_buffer,
|
||||
alpha, program, false, transpose, conjugate);
|
||||
alpha, program_, false, transpose, conjugate);
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
|
Loading…
Reference in New Issue