diff --git a/src/cache.hpp b/src/cache.hpp index f6a948b6..1c8c9d4c 100644 --- a/src/cache.hpp +++ b/src/cache.hpp @@ -80,8 +80,8 @@ extern template std::string BinaryCache::Get(const BinaryKeyRef &, bool *) const // The key struct for the cache of compiled OpenCL programs (context-dependent) // Order of fields: context, device_id, precision, routine_name (smaller fields first) -typedef std::tuple ProgramKey; -typedef std::tuple ProgramKeyRef; +typedef std::tuple ProgramKey; +typedef std::tuple ProgramKeyRef; typedef Cache ProgramCache; @@ -94,8 +94,8 @@ class Database; // The key struct for the cache of database maps. // Order of fields: platform_id, device_id, precision, kernel_name (smaller fields first) -typedef std::tuple DatabaseKey; -typedef std::tuple DatabaseKeyRef; +typedef std::tuple DatabaseKey; +typedef std::tuple DatabaseKeyRef; typedef Cache DatabaseCache; diff --git a/src/clblast.cpp b/src/clblast.cpp index 19d7ef0a..9f865a23 100644 --- a/src/clblast.cpp +++ b/src/clblast.cpp @@ -2492,7 +2492,7 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern // Retrieves the device name const auto device_cpp = Device(device); - const auto platform_id = device_cpp.Platform(); + const auto platform_id = device_cpp.PlatformID(); const auto device_name = GetDeviceName(device_cpp); // Retrieves the current database values to verify whether the new ones are complete diff --git a/src/clpp11.hpp b/src/clpp11.hpp index 7d348e18..97045644 100644 --- a/src/clpp11.hpp +++ b/src/clpp11.hpp @@ -59,34 +59,36 @@ namespace clblast { // ================================================================================================= // Represents a runtime error returned by an OpenCL API function -class CLError : public ErrorCode { +class CLCudaAPIError : public ErrorCode { public: - explicit CLError(cl_int status, const std::string &where): - ErrorCode(status, - where, - "OpenCL error: " + where + ": " + std::to_string(static_cast(status))) { + explicit CLCudaAPIError(cl_int status, const std::string &where): + ErrorCode(status, where, "OpenCL error: " + where + ": " + + std::to_string(static_cast(status))) { } static void Check(const cl_int status, const std::string &where) { if (status != CL_SUCCESS) { - throw CLError(status, where); + throw CLCudaAPIError(status, where); } } static void CheckDtor(const cl_int status, const std::string &where) { if (status != CL_SUCCESS) { - fprintf(stderr, "CLBlast: %s (ignoring)\n", CLError(status, where).what()); + fprintf(stderr, "CLBlast: %s (ignoring)\n", CLCudaAPIError(status, where).what()); } } }; +// Exception returned when building a program +using CLCudaAPIBuildError = CLCudaAPIError; + // ================================================================================================= // Error occurred in OpenCL -#define CheckError(call) CLError::Check(call, CLError::TrimCallString(#call)) +#define CheckError(call) CLCudaAPIError::Check(call, CLCudaAPIError::TrimCallString(#call)) -// Error occured in OpenCL (no-exception version for destructors) -#define CheckErrorDtor(call) CLError::CheckDtor(call, CLError::TrimCallString(#call)) +// Error occurred in OpenCL (no-exception version for destructors) +#define CheckErrorDtor(call) CLCudaAPIError::CheckDtor(call, CLCudaAPIError::TrimCallString(#call)) // ================================================================================================= @@ -142,6 +144,9 @@ using EventPointer = cl_event*; // ================================================================================================= +// Raw platform ID type +using RawPlatformID = cl_platform_id; + // C++11 version of 'cl_platform_id' class Platform { public: @@ -177,7 +182,7 @@ class Platform { } // Accessor to the private data-member - const cl_platform_id& operator()() const { return platform_; } + const RawPlatformID& operator()() const { return platform_; } private: cl_platform_id platform_; @@ -206,6 +211,9 @@ inline std::vector GetAllPlatforms() { // ================================================================================================= +// Raw device ID type +using RawDeviceID = cl_device_id; + // C++11 version of 'cl_device_id' class Device { public: @@ -230,7 +238,7 @@ class Device { } // Methods to retrieve device information - cl_platform_id Platform() const { return GetInfo(CL_DEVICE_PLATFORM); } + RawPlatformID PlatformID() const { return GetInfo(CL_DEVICE_PLATFORM); } std::string Version() const { return GetInfoString(CL_DEVICE_VERSION); } size_t VersionNumber() const { @@ -262,11 +270,19 @@ class Device { unsigned long LocalMemSize() const { return static_cast(GetInfo(CL_DEVICE_LOCAL_MEM_SIZE)); } + std::string Capabilities() const { return GetInfoString(CL_DEVICE_EXTENSIONS); } bool HasExtension(const std::string &extension) const { const auto extensions = Capabilities(); return extensions.find(extension) != std::string::npos; } + bool SupportsFP64() const { + return HasExtension("cl_khr_fp64"); + } + bool SupportsFP16() const { + if (Name() == "Mali-T628") { return true; } // supports fp16 but not cl_khr_fp16 officially + return HasExtension("cl_khr_fp16"); + } size_t CoreClock() const { return static_cast(GetInfo(CL_DEVICE_MAX_CLOCK_FREQUENCY)); @@ -330,9 +346,8 @@ class Device { std::string{"."} + std::to_string(GetInfo(CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV)); } - // Accessor to the private data-member - const cl_device_id& operator()() const { return device_; } + const RawDeviceID& operator()() const { return device_; } private: cl_device_id device_; @@ -366,6 +381,9 @@ class Device { // ================================================================================================= +// Raw context type +using RawContext = cl_context; + // C++11 version of 'cl_context' class Context { public: @@ -385,12 +403,12 @@ class Context { auto status = CL_SUCCESS; const cl_device_id dev = device(); *context_ = clCreateContext(nullptr, 1, &dev, nullptr, nullptr, &status); - CLError::Check(status, "clCreateContext"); + CLCudaAPIError::Check(status, "clCreateContext"); } // Accessor to the private data-member - const cl_context& operator()() const { return *context_; } - cl_context* pointer() const { return &(*context_); } + const RawContext& operator()() const { return *context_; } + RawContext* pointer() const { return &(*context_); } private: std::shared_ptr context_; }; @@ -400,9 +418,6 @@ using ContextPointer = cl_context*; // ================================================================================================= -// Enumeration of build statuses of the run-time compilation process -enum class BuildStatus { kSuccess, kError, kInvalid }; - // C++11 version of 'cl_program'. class Program { public: @@ -415,10 +430,10 @@ class Program { delete p; }) { const char *source_ptr = &source[0]; - size_t length = source.length(); + const auto length = source.length(); auto status = CL_SUCCESS; *program_ = clCreateProgramWithSource(context(), 1, &source_ptr, &length, &status); - CLError::Check(status, "clCreateProgramWithSource"); + CLCudaAPIError::Check(status, "clCreateProgramWithSource"); } // Binary-based constructor with memory management @@ -428,18 +443,18 @@ class Program { delete p; }) { const char *binary_ptr = &binary[0]; - size_t length = binary.length(); + const auto length = binary.length(); auto status1 = CL_SUCCESS; auto status2 = CL_SUCCESS; - const cl_device_id dev = device(); + const auto dev = device(); *program_ = clCreateProgramWithBinary(context(), 1, &dev, &length, reinterpret_cast(&binary_ptr), &status1, &status2); - CLError::Check(status1, "clCreateProgramWithBinary (binary status)"); - CLError::Check(status2, "clCreateProgramWithBinary"); + CLCudaAPIError::Check(status1, "clCreateProgramWithBinary (binary status)"); + CLCudaAPIError::Check(status2, "clCreateProgramWithBinary"); } - // Compiles the device program and returns whether or not there where any warnings/errors + // Compiles the device program and checks whether or not there are any warnings/errors void Build(const Device &device, std::vector &options) { options.push_back("-cl-std=CL1.1"); auto options_string = std::accumulate(options.begin(), options.end(), std::string{" "}); @@ -447,6 +462,11 @@ class Program { CheckError(clBuildProgram(*program_, 1, &dev, options_string.c_str(), nullptr, nullptr)); } + // Confirms whether a certain status code is an actual compilation error or warning + bool StatusIsCompilationWarningOrError(const cl_int status) const { + return (status == CL_BUILD_PROGRAM_FAILURE); + } + // Retrieves the warning/error message from the compiler (if any) std::string GetBuildInfo(const Device &device) const { auto bytes = size_t{0}; @@ -477,6 +497,9 @@ class Program { // ================================================================================================= +// Raw command-queue type +using RawCommandQueue = cl_command_queue; + // C++11 version of 'cl_command_queue' class Queue { public: @@ -495,7 +518,7 @@ class Queue { }) { auto status = CL_SUCCESS; *queue_ = clCreateCommandQueue(context(), device(), CL_QUEUE_PROFILING_ENABLE, &status); - CLError::Check(status, "clCreateCommandQueue"); + CLCudaAPIError::Check(status, "clCreateCommandQueue"); } // Synchronizes the queue @@ -523,7 +546,7 @@ class Queue { } // Accessor to the private data-member - const cl_command_queue& operator()() const { return *queue_; } + const RawCommandQueue& operator()() const { return *queue_; } private: std::shared_ptr queue_; }; @@ -587,7 +610,7 @@ class Buffer { if (access_ == BufferAccess::kWriteOnly) { flags = CL_MEM_WRITE_ONLY; } auto status = CL_SUCCESS; *buffer_ = clCreateBuffer(context(), flags, size*sizeof(T), nullptr, &status); - CLError::Check(status, "clCreateBuffer"); + CLCudaAPIError::Check(status, "clCreateBuffer"); } // As above, but now with read/write access as a default @@ -719,7 +742,7 @@ class Kernel { }) { auto status = CL_SUCCESS; *kernel_ = clCreateKernel(program(), name.c_str(), &status); - CLError::Check(status, "clCreateKernel"); + CLCudaAPIError::Check(status, "clCreateKernel"); } // Sets a kernel argument at the indicated position diff --git a/src/routine.cpp b/src/routine.cpp index b25eec56..aaa85fde 100644 --- a/src/routine.cpp +++ b/src/routine.cpp @@ -60,7 +60,6 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name, event_(event), context_(queue_.GetContext()), device_(queue_.GetDevice()), - platform_(device_.Platform()), db_(kernel_names) { InitDatabase(userDatabase); @@ -68,18 +67,19 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name, } void Routine::InitDatabase(const std::vector &userDatabase) { + const auto platform_id = device_.PlatformID(); for (const auto &kernel_name : kernel_names_) { // Queries the cache to see whether or not the kernel parameter database is already there bool has_db; - db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ platform_, device_(), precision_, kernel_name }, + db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ platform_id, device_(), precision_, kernel_name }, &has_db); if (has_db) { continue; } // Builds the parameter database for this device and routine set and stores it in the cache log_debug("Searching database for kernel '" + kernel_name + "'"); db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase); - DatabaseCache::Instance().Store(DatabaseKey{ platform_, device_(), precision_, kernel_name }, + DatabaseCache::Instance().Store(DatabaseKey{ platform_id, device_(), precision_, kernel_name }, Database{ db_(kernel_name) }); } } @@ -123,13 +123,13 @@ void Routine::InitProgram(std::initializer_list source) { // Otherwise, the kernel will be compiled and program will be built. Both the binary and the // program will be added to the cache. - // Inspects whether or not cl_khr_fp64 is supported in case of double precision + // Inspects whether or not FP64 is supported in case of double precision if ((precision_ == Precision::kDouble && !PrecisionSupported(device_)) || (precision_ == Precision::kComplexDouble && !PrecisionSupported(device_))) { throw RuntimeErrorCode(StatusCode::kNoDoublePrecision); } - // As above, but for cl_khr_fp16 (half precision) + // As above, but for FP16 (half precision) if (precision_ == Precision::kHalf && !PrecisionSupported(device_)) { throw RuntimeErrorCode(StatusCode::kNoHalfPrecision); } @@ -188,8 +188,8 @@ void Routine::InitProgram(std::initializer_list source) { program_ = Program(context_, source_string); try { program_.Build(device_, options); - } catch (const CLError &e) { - if (e.status() == CL_BUILD_PROGRAM_FAILURE) { + } catch (const CLCudaAPIBuildError &e) { + if (program_.StatusIsCompilationWarningOrError(e.status())) { fprintf(stdout, "OpenCL compiler error/warning: %s\n", program_.GetBuildInfo(device_).c_str()); } diff --git a/src/routine.hpp b/src/routine.hpp index e77e35ad..a8f1cb6a 100644 --- a/src/routine.hpp +++ b/src/routine.hpp @@ -75,7 +75,6 @@ class Routine { EventPointer event_; const Context context_; const Device device_; - const cl_platform_id platform_; // Compiled program (either retrieved from cache or compiled in slow path) Program program_; diff --git a/src/utilities/clblast_exceptions.cpp b/src/utilities/clblast_exceptions.cpp index 96f10860..32526215 100644 --- a/src/utilities/clblast_exceptions.cpp +++ b/src/utilities/clblast_exceptions.cpp @@ -55,7 +55,7 @@ StatusCode DispatchException() } catch (BLASError &e) { // no message is printed for invalid argument errors status = e.status(); - } catch (CLError &e) { + } catch (CLCudaAPIError &e) { message = e.what(); status = static_cast(e.status()); } catch (RuntimeErrorCode &e) { diff --git a/src/utilities/utilities.cpp b/src/utilities/utilities.cpp index 4b8d5a09..a5c1d45e 100644 --- a/src/utilities/utilities.cpp +++ b/src/utilities/utilities.cpp @@ -391,16 +391,9 @@ template <> Precision PrecisionValue() { return Precision::kComplexDoub // Returns false is this precision is not supported by the device template <> bool PrecisionSupported(const Device &) { return true; } template <> bool PrecisionSupported(const Device &) { return true; } -template <> bool PrecisionSupported(const Device &device) { - return device.HasExtension(kKhronosDoublePrecision); -} -template <> bool PrecisionSupported(const Device &device) { - return device.HasExtension(kKhronosDoublePrecision); -} -template <> bool PrecisionSupported(const Device &device) { - if (device.Name() == "Mali-T628") { return true; } // supports fp16 but not cl_khr_fp16 officially - return device.HasExtension(kKhronosHalfPrecision); -} +template <> bool PrecisionSupported(const Device &device) { return device.SupportsFP64(); } +template <> bool PrecisionSupported(const Device &device) { return device.SupportsFP64(); } +template <> bool PrecisionSupported(const Device &device) { return device.SupportsFP16(); } // ================================================================================================= diff --git a/src/utilities/utilities.hpp b/src/utilities/utilities.hpp index e45c606c..b2949c27 100644 --- a/src/utilities/utilities.hpp +++ b/src/utilities/utilities.hpp @@ -31,15 +31,13 @@ namespace clblast { // ================================================================================================= // Shorthands for half-precision -using half = cl_half; // based on the OpenCL type, which is actually an 'unsigned short' +using half = unsigned short; // the 'cl_half' OpenCL type is actually an 'unsigned short' // Shorthands for complex data-types using float2 = std::complex; using double2 = std::complex; // Khronos OpenCL extensions -const std::string kKhronosHalfPrecision = "cl_khr_fp16"; -const std::string kKhronosDoublePrecision = "cl_khr_fp64"; const std::string kKhronosAttributesAMD = "cl_amd_device_attribute_query"; const std::string kKhronosAttributesNVIDIA = "cl_nv_device_attribute_query"; diff --git a/test/diagnostics.cpp b/test/diagnostics.cpp index 99b936f8..af56cd30 100644 --- a/test/diagnostics.cpp +++ b/test/diagnostics.cpp @@ -85,7 +85,7 @@ void OpenCLDiagnostics(int argc, char *argv[]) { printf("* device.Name() %.4lf ms\n", TimeFunction(kNumRuns, [&](){device.Name();} )); printf("* device.Vendor() %.4lf ms\n", TimeFunction(kNumRuns, [&](){device.Vendor();} )); printf("* device.Version() %.4lf ms\n", TimeFunction(kNumRuns, [&](){device.Version();} )); - printf("* device.Platform() %.4lf ms\n", TimeFunction(kNumRuns, [&](){device.Platform();} )); + printf("* device.Platform() %.4lf ms\n", TimeFunction(kNumRuns, [&](){ device.PlatformID();} )); printf("* Buffer(context, 1024) %.4lf ms\n", TimeFunction(kNumRuns, [&](){Buffer(context, 1024);} )); printf("\n"); diff --git a/test/test_utilities.cpp b/test/test_utilities.cpp index b8fd94a9..579eb61c 100644 --- a/test/test_utilities.cpp +++ b/test/test_utilities.cpp @@ -88,7 +88,7 @@ void FloatToHalfBuffer(std::vector& result, const std::vector& sour } // As above, but now for OpenCL data-types instead of std::vectors -Buffer HalfToFloatBuffer(const Buffer& source, cl_command_queue queue_raw) { +Buffer HalfToFloatBuffer(const Buffer& source, RawCommandQueue queue_raw) { const auto size = source.GetSize() / sizeof(half); auto queue = Queue(queue_raw); auto context = queue.GetContext(); @@ -99,7 +99,7 @@ Buffer HalfToFloatBuffer(const Buffer& source, cl_command_queue que result.Write(queue, size, result_cpu); return result; } -void FloatToHalfBuffer(Buffer& result, const Buffer& source, cl_command_queue queue_raw) { +void FloatToHalfBuffer(Buffer& result, const Buffer& source, RawCommandQueue queue_raw) { const auto size = source.GetSize() / sizeof(float); auto queue = Queue(queue_raw); auto context = queue.GetContext(); diff --git a/test/test_utilities.hpp b/test/test_utilities.hpp index fc50a754..fe7a9cd2 100644 --- a/test/test_utilities.hpp +++ b/test/test_utilities.hpp @@ -89,8 +89,8 @@ std::vector HalfToFloatBuffer(const std::vector& source); void FloatToHalfBuffer(std::vector& result, const std::vector& source); // As above, but now for OpenCL data-types instead of std::vectors -Buffer HalfToFloatBuffer(const Buffer& source, cl_command_queue queue_raw); -void FloatToHalfBuffer(Buffer& result, const Buffer& source, cl_command_queue queue_raw); +Buffer HalfToFloatBuffer(const Buffer& source, RawCommandQueue queue_raw); +void FloatToHalfBuffer(Buffer& result, const Buffer& source, RawCommandQueue queue_raw); // ================================================================================================= } // namespace clblast