mirror of
https://github.com/CNugteren/CLBlast.git
synced 2024-07-15 10:55:42 +02:00
Merge pull request #198 from CNugteren/cuda_api_preparation
Cuda API preparation
This commit is contained in:
commit
2bb8402ec1
|
@ -80,8 +80,8 @@ extern template std::string BinaryCache::Get(const BinaryKeyRef &, bool *) const
|
||||||
|
|
||||||
// The key struct for the cache of compiled OpenCL programs (context-dependent)
|
// The key struct for the cache of compiled OpenCL programs (context-dependent)
|
||||||
// Order of fields: context, device_id, precision, routine_name (smaller fields first)
|
// Order of fields: context, device_id, precision, routine_name (smaller fields first)
|
||||||
typedef std::tuple<cl_context, cl_device_id, Precision, std::string> ProgramKey;
|
typedef std::tuple<RawContext, RawDeviceID, Precision, std::string> ProgramKey;
|
||||||
typedef std::tuple<const cl_context &, const cl_device_id &, const Precision &, const std::string &> ProgramKeyRef;
|
typedef std::tuple<const RawContext &, const RawDeviceID &, const Precision &, const std::string &> ProgramKeyRef;
|
||||||
|
|
||||||
typedef Cache<ProgramKey, Program> ProgramCache;
|
typedef Cache<ProgramKey, Program> ProgramCache;
|
||||||
|
|
||||||
|
@ -94,8 +94,8 @@ class Database;
|
||||||
|
|
||||||
// The key struct for the cache of database maps.
|
// The key struct for the cache of database maps.
|
||||||
// Order of fields: platform_id, device_id, precision, kernel_name (smaller fields first)
|
// Order of fields: platform_id, device_id, precision, kernel_name (smaller fields first)
|
||||||
typedef std::tuple<cl_platform_id, cl_device_id, Precision, std::string> DatabaseKey;
|
typedef std::tuple<RawPlatformID, RawDeviceID, Precision, std::string> DatabaseKey;
|
||||||
typedef std::tuple<const cl_platform_id &, const cl_device_id &, const Precision &, const std::string &> DatabaseKeyRef;
|
typedef std::tuple<const RawPlatformID &, const RawDeviceID &, const Precision &, const std::string &> DatabaseKeyRef;
|
||||||
|
|
||||||
typedef Cache<DatabaseKey, Database> DatabaseCache;
|
typedef Cache<DatabaseKey, Database> DatabaseCache;
|
||||||
|
|
||||||
|
|
|
@ -2492,7 +2492,7 @@ StatusCode OverrideParameters(const cl_device_id device, const std::string &kern
|
||||||
|
|
||||||
// Retrieves the device name
|
// Retrieves the device name
|
||||||
const auto device_cpp = Device(device);
|
const auto device_cpp = Device(device);
|
||||||
const auto platform_id = device_cpp.Platform();
|
const auto platform_id = device_cpp.PlatformID();
|
||||||
const auto device_name = GetDeviceName(device_cpp);
|
const auto device_name = GetDeviceName(device_cpp);
|
||||||
|
|
||||||
// Retrieves the current database values to verify whether the new ones are complete
|
// Retrieves the current database values to verify whether the new ones are complete
|
||||||
|
|
|
@ -59,34 +59,36 @@ namespace clblast {
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
|
|
||||||
// Represents a runtime error returned by an OpenCL API function
|
// Represents a runtime error returned by an OpenCL API function
|
||||||
class CLError : public ErrorCode<DeviceError, cl_int> {
|
class CLCudaAPIError : public ErrorCode<DeviceError, cl_int> {
|
||||||
public:
|
public:
|
||||||
explicit CLError(cl_int status, const std::string &where):
|
explicit CLCudaAPIError(cl_int status, const std::string &where):
|
||||||
ErrorCode(status,
|
ErrorCode(status, where, "OpenCL error: " + where + ": " +
|
||||||
where,
|
std::to_string(static_cast<int>(status))) {
|
||||||
"OpenCL error: " + where + ": " + std::to_string(static_cast<int>(status))) {
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void Check(const cl_int status, const std::string &where) {
|
static void Check(const cl_int status, const std::string &where) {
|
||||||
if (status != CL_SUCCESS) {
|
if (status != CL_SUCCESS) {
|
||||||
throw CLError(status, where);
|
throw CLCudaAPIError(status, where);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void CheckDtor(const cl_int status, const std::string &where) {
|
static void CheckDtor(const cl_int status, const std::string &where) {
|
||||||
if (status != CL_SUCCESS) {
|
if (status != CL_SUCCESS) {
|
||||||
fprintf(stderr, "CLBlast: %s (ignoring)\n", CLError(status, where).what());
|
fprintf(stderr, "CLBlast: %s (ignoring)\n", CLCudaAPIError(status, where).what());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Exception returned when building a program
|
||||||
|
using CLCudaAPIBuildError = CLCudaAPIError;
|
||||||
|
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
|
|
||||||
// Error occurred in OpenCL
|
// Error occurred in OpenCL
|
||||||
#define CheckError(call) CLError::Check(call, CLError::TrimCallString(#call))
|
#define CheckError(call) CLCudaAPIError::Check(call, CLCudaAPIError::TrimCallString(#call))
|
||||||
|
|
||||||
// Error occured in OpenCL (no-exception version for destructors)
|
// Error occurred in OpenCL (no-exception version for destructors)
|
||||||
#define CheckErrorDtor(call) CLError::CheckDtor(call, CLError::TrimCallString(#call))
|
#define CheckErrorDtor(call) CLCudaAPIError::CheckDtor(call, CLCudaAPIError::TrimCallString(#call))
|
||||||
|
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
|
|
||||||
|
@ -142,6 +144,9 @@ using EventPointer = cl_event*;
|
||||||
|
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
|
|
||||||
|
// Raw platform ID type
|
||||||
|
using RawPlatformID = cl_platform_id;
|
||||||
|
|
||||||
// C++11 version of 'cl_platform_id'
|
// C++11 version of 'cl_platform_id'
|
||||||
class Platform {
|
class Platform {
|
||||||
public:
|
public:
|
||||||
|
@ -177,7 +182,7 @@ class Platform {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accessor to the private data-member
|
// Accessor to the private data-member
|
||||||
const cl_platform_id& operator()() const { return platform_; }
|
const RawPlatformID& operator()() const { return platform_; }
|
||||||
private:
|
private:
|
||||||
cl_platform_id platform_;
|
cl_platform_id platform_;
|
||||||
|
|
||||||
|
@ -206,6 +211,9 @@ inline std::vector<Platform> GetAllPlatforms() {
|
||||||
|
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
|
|
||||||
|
// Raw device ID type
|
||||||
|
using RawDeviceID = cl_device_id;
|
||||||
|
|
||||||
// C++11 version of 'cl_device_id'
|
// C++11 version of 'cl_device_id'
|
||||||
class Device {
|
class Device {
|
||||||
public:
|
public:
|
||||||
|
@ -230,7 +238,7 @@ class Device {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Methods to retrieve device information
|
// Methods to retrieve device information
|
||||||
cl_platform_id Platform() const { return GetInfo<cl_platform_id>(CL_DEVICE_PLATFORM); }
|
RawPlatformID PlatformID() const { return GetInfo<cl_platform_id>(CL_DEVICE_PLATFORM); }
|
||||||
std::string Version() const { return GetInfoString(CL_DEVICE_VERSION); }
|
std::string Version() const { return GetInfoString(CL_DEVICE_VERSION); }
|
||||||
size_t VersionNumber() const
|
size_t VersionNumber() const
|
||||||
{
|
{
|
||||||
|
@ -262,11 +270,19 @@ class Device {
|
||||||
unsigned long LocalMemSize() const {
|
unsigned long LocalMemSize() const {
|
||||||
return static_cast<unsigned long>(GetInfo<cl_ulong>(CL_DEVICE_LOCAL_MEM_SIZE));
|
return static_cast<unsigned long>(GetInfo<cl_ulong>(CL_DEVICE_LOCAL_MEM_SIZE));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Capabilities() const { return GetInfoString(CL_DEVICE_EXTENSIONS); }
|
std::string Capabilities() const { return GetInfoString(CL_DEVICE_EXTENSIONS); }
|
||||||
bool HasExtension(const std::string &extension) const {
|
bool HasExtension(const std::string &extension) const {
|
||||||
const auto extensions = Capabilities();
|
const auto extensions = Capabilities();
|
||||||
return extensions.find(extension) != std::string::npos;
|
return extensions.find(extension) != std::string::npos;
|
||||||
}
|
}
|
||||||
|
bool SupportsFP64() const {
|
||||||
|
return HasExtension("cl_khr_fp64");
|
||||||
|
}
|
||||||
|
bool SupportsFP16() const {
|
||||||
|
if (Name() == "Mali-T628") { return true; } // supports fp16 but not cl_khr_fp16 officially
|
||||||
|
return HasExtension("cl_khr_fp16");
|
||||||
|
}
|
||||||
|
|
||||||
size_t CoreClock() const {
|
size_t CoreClock() const {
|
||||||
return static_cast<size_t>(GetInfo<cl_uint>(CL_DEVICE_MAX_CLOCK_FREQUENCY));
|
return static_cast<size_t>(GetInfo<cl_uint>(CL_DEVICE_MAX_CLOCK_FREQUENCY));
|
||||||
|
@ -330,9 +346,8 @@ class Device {
|
||||||
std::string{"."} + std::to_string(GetInfo<cl_uint>(CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV));
|
std::string{"."} + std::to_string(GetInfo<cl_uint>(CL_DEVICE_COMPUTE_CAPABILITY_MINOR_NV));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Accessor to the private data-member
|
// Accessor to the private data-member
|
||||||
const cl_device_id& operator()() const { return device_; }
|
const RawDeviceID& operator()() const { return device_; }
|
||||||
private:
|
private:
|
||||||
cl_device_id device_;
|
cl_device_id device_;
|
||||||
|
|
||||||
|
@ -366,6 +381,9 @@ class Device {
|
||||||
|
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
|
|
||||||
|
// Raw context type
|
||||||
|
using RawContext = cl_context;
|
||||||
|
|
||||||
// C++11 version of 'cl_context'
|
// C++11 version of 'cl_context'
|
||||||
class Context {
|
class Context {
|
||||||
public:
|
public:
|
||||||
|
@ -385,12 +403,12 @@ class Context {
|
||||||
auto status = CL_SUCCESS;
|
auto status = CL_SUCCESS;
|
||||||
const cl_device_id dev = device();
|
const cl_device_id dev = device();
|
||||||
*context_ = clCreateContext(nullptr, 1, &dev, nullptr, nullptr, &status);
|
*context_ = clCreateContext(nullptr, 1, &dev, nullptr, nullptr, &status);
|
||||||
CLError::Check(status, "clCreateContext");
|
CLCudaAPIError::Check(status, "clCreateContext");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accessor to the private data-member
|
// Accessor to the private data-member
|
||||||
const cl_context& operator()() const { return *context_; }
|
const RawContext& operator()() const { return *context_; }
|
||||||
cl_context* pointer() const { return &(*context_); }
|
RawContext* pointer() const { return &(*context_); }
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<cl_context> context_;
|
std::shared_ptr<cl_context> context_;
|
||||||
};
|
};
|
||||||
|
@ -400,9 +418,6 @@ using ContextPointer = cl_context*;
|
||||||
|
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
|
|
||||||
// Enumeration of build statuses of the run-time compilation process
|
|
||||||
enum class BuildStatus { kSuccess, kError, kInvalid };
|
|
||||||
|
|
||||||
// C++11 version of 'cl_program'.
|
// C++11 version of 'cl_program'.
|
||||||
class Program {
|
class Program {
|
||||||
public:
|
public:
|
||||||
|
@ -415,10 +430,10 @@ class Program {
|
||||||
delete p;
|
delete p;
|
||||||
}) {
|
}) {
|
||||||
const char *source_ptr = &source[0];
|
const char *source_ptr = &source[0];
|
||||||
size_t length = source.length();
|
const auto length = source.length();
|
||||||
auto status = CL_SUCCESS;
|
auto status = CL_SUCCESS;
|
||||||
*program_ = clCreateProgramWithSource(context(), 1, &source_ptr, &length, &status);
|
*program_ = clCreateProgramWithSource(context(), 1, &source_ptr, &length, &status);
|
||||||
CLError::Check(status, "clCreateProgramWithSource");
|
CLCudaAPIError::Check(status, "clCreateProgramWithSource");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Binary-based constructor with memory management
|
// Binary-based constructor with memory management
|
||||||
|
@ -428,18 +443,18 @@ class Program {
|
||||||
delete p;
|
delete p;
|
||||||
}) {
|
}) {
|
||||||
const char *binary_ptr = &binary[0];
|
const char *binary_ptr = &binary[0];
|
||||||
size_t length = binary.length();
|
const auto length = binary.length();
|
||||||
auto status1 = CL_SUCCESS;
|
auto status1 = CL_SUCCESS;
|
||||||
auto status2 = CL_SUCCESS;
|
auto status2 = CL_SUCCESS;
|
||||||
const cl_device_id dev = device();
|
const auto dev = device();
|
||||||
*program_ = clCreateProgramWithBinary(context(), 1, &dev, &length,
|
*program_ = clCreateProgramWithBinary(context(), 1, &dev, &length,
|
||||||
reinterpret_cast<const unsigned char**>(&binary_ptr),
|
reinterpret_cast<const unsigned char**>(&binary_ptr),
|
||||||
&status1, &status2);
|
&status1, &status2);
|
||||||
CLError::Check(status1, "clCreateProgramWithBinary (binary status)");
|
CLCudaAPIError::Check(status1, "clCreateProgramWithBinary (binary status)");
|
||||||
CLError::Check(status2, "clCreateProgramWithBinary");
|
CLCudaAPIError::Check(status2, "clCreateProgramWithBinary");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compiles the device program and returns whether or not there where any warnings/errors
|
// Compiles the device program and checks whether or not there are any warnings/errors
|
||||||
void Build(const Device &device, std::vector<std::string> &options) {
|
void Build(const Device &device, std::vector<std::string> &options) {
|
||||||
options.push_back("-cl-std=CL1.1");
|
options.push_back("-cl-std=CL1.1");
|
||||||
auto options_string = std::accumulate(options.begin(), options.end(), std::string{" "});
|
auto options_string = std::accumulate(options.begin(), options.end(), std::string{" "});
|
||||||
|
@ -447,6 +462,11 @@ class Program {
|
||||||
CheckError(clBuildProgram(*program_, 1, &dev, options_string.c_str(), nullptr, nullptr));
|
CheckError(clBuildProgram(*program_, 1, &dev, options_string.c_str(), nullptr, nullptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Confirms whether a certain status code is an actual compilation error or warning
|
||||||
|
bool StatusIsCompilationWarningOrError(const cl_int status) const {
|
||||||
|
return (status == CL_BUILD_PROGRAM_FAILURE);
|
||||||
|
}
|
||||||
|
|
||||||
// Retrieves the warning/error message from the compiler (if any)
|
// Retrieves the warning/error message from the compiler (if any)
|
||||||
std::string GetBuildInfo(const Device &device) const {
|
std::string GetBuildInfo(const Device &device) const {
|
||||||
auto bytes = size_t{0};
|
auto bytes = size_t{0};
|
||||||
|
@ -477,6 +497,9 @@ class Program {
|
||||||
|
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
|
|
||||||
|
// Raw command-queue type
|
||||||
|
using RawCommandQueue = cl_command_queue;
|
||||||
|
|
||||||
// C++11 version of 'cl_command_queue'
|
// C++11 version of 'cl_command_queue'
|
||||||
class Queue {
|
class Queue {
|
||||||
public:
|
public:
|
||||||
|
@ -495,7 +518,7 @@ class Queue {
|
||||||
}) {
|
}) {
|
||||||
auto status = CL_SUCCESS;
|
auto status = CL_SUCCESS;
|
||||||
*queue_ = clCreateCommandQueue(context(), device(), CL_QUEUE_PROFILING_ENABLE, &status);
|
*queue_ = clCreateCommandQueue(context(), device(), CL_QUEUE_PROFILING_ENABLE, &status);
|
||||||
CLError::Check(status, "clCreateCommandQueue");
|
CLCudaAPIError::Check(status, "clCreateCommandQueue");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Synchronizes the queue
|
// Synchronizes the queue
|
||||||
|
@ -523,7 +546,7 @@ class Queue {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accessor to the private data-member
|
// Accessor to the private data-member
|
||||||
const cl_command_queue& operator()() const { return *queue_; }
|
const RawCommandQueue& operator()() const { return *queue_; }
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<cl_command_queue> queue_;
|
std::shared_ptr<cl_command_queue> queue_;
|
||||||
};
|
};
|
||||||
|
@ -587,7 +610,7 @@ class Buffer {
|
||||||
if (access_ == BufferAccess::kWriteOnly) { flags = CL_MEM_WRITE_ONLY; }
|
if (access_ == BufferAccess::kWriteOnly) { flags = CL_MEM_WRITE_ONLY; }
|
||||||
auto status = CL_SUCCESS;
|
auto status = CL_SUCCESS;
|
||||||
*buffer_ = clCreateBuffer(context(), flags, size*sizeof(T), nullptr, &status);
|
*buffer_ = clCreateBuffer(context(), flags, size*sizeof(T), nullptr, &status);
|
||||||
CLError::Check(status, "clCreateBuffer");
|
CLCudaAPIError::Check(status, "clCreateBuffer");
|
||||||
}
|
}
|
||||||
|
|
||||||
// As above, but now with read/write access as a default
|
// As above, but now with read/write access as a default
|
||||||
|
@ -719,7 +742,7 @@ class Kernel {
|
||||||
}) {
|
}) {
|
||||||
auto status = CL_SUCCESS;
|
auto status = CL_SUCCESS;
|
||||||
*kernel_ = clCreateKernel(program(), name.c_str(), &status);
|
*kernel_ = clCreateKernel(program(), name.c_str(), &status);
|
||||||
CLError::Check(status, "clCreateKernel");
|
CLCudaAPIError::Check(status, "clCreateKernel");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sets a kernel argument at the indicated position
|
// Sets a kernel argument at the indicated position
|
||||||
|
|
|
@ -60,7 +60,6 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
|
||||||
event_(event),
|
event_(event),
|
||||||
context_(queue_.GetContext()),
|
context_(queue_.GetContext()),
|
||||||
device_(queue_.GetDevice()),
|
device_(queue_.GetDevice()),
|
||||||
platform_(device_.Platform()),
|
|
||||||
db_(kernel_names) {
|
db_(kernel_names) {
|
||||||
|
|
||||||
InitDatabase(userDatabase);
|
InitDatabase(userDatabase);
|
||||||
|
@ -68,18 +67,19 @@ Routine::Routine(Queue &queue, EventPointer event, const std::string &name,
|
||||||
}
|
}
|
||||||
|
|
||||||
void Routine::InitDatabase(const std::vector<database::DatabaseEntry> &userDatabase) {
|
void Routine::InitDatabase(const std::vector<database::DatabaseEntry> &userDatabase) {
|
||||||
|
const auto platform_id = device_.PlatformID();
|
||||||
for (const auto &kernel_name : kernel_names_) {
|
for (const auto &kernel_name : kernel_names_) {
|
||||||
|
|
||||||
// Queries the cache to see whether or not the kernel parameter database is already there
|
// Queries the cache to see whether or not the kernel parameter database is already there
|
||||||
bool has_db;
|
bool has_db;
|
||||||
db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ platform_, device_(), precision_, kernel_name },
|
db_(kernel_name) = DatabaseCache::Instance().Get(DatabaseKeyRef{ platform_id, device_(), precision_, kernel_name },
|
||||||
&has_db);
|
&has_db);
|
||||||
if (has_db) { continue; }
|
if (has_db) { continue; }
|
||||||
|
|
||||||
// Builds the parameter database for this device and routine set and stores it in the cache
|
// Builds the parameter database for this device and routine set and stores it in the cache
|
||||||
log_debug("Searching database for kernel '" + kernel_name + "'");
|
log_debug("Searching database for kernel '" + kernel_name + "'");
|
||||||
db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase);
|
db_(kernel_name) = Database(device_, kernel_name, precision_, userDatabase);
|
||||||
DatabaseCache::Instance().Store(DatabaseKey{ platform_, device_(), precision_, kernel_name },
|
DatabaseCache::Instance().Store(DatabaseKey{ platform_id, device_(), precision_, kernel_name },
|
||||||
Database{ db_(kernel_name) });
|
Database{ db_(kernel_name) });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -123,13 +123,13 @@ void Routine::InitProgram(std::initializer_list<const char *> source) {
|
||||||
// Otherwise, the kernel will be compiled and program will be built. Both the binary and the
|
// Otherwise, the kernel will be compiled and program will be built. Both the binary and the
|
||||||
// program will be added to the cache.
|
// program will be added to the cache.
|
||||||
|
|
||||||
// Inspects whether or not cl_khr_fp64 is supported in case of double precision
|
// Inspects whether or not FP64 is supported in case of double precision
|
||||||
if ((precision_ == Precision::kDouble && !PrecisionSupported<double>(device_)) ||
|
if ((precision_ == Precision::kDouble && !PrecisionSupported<double>(device_)) ||
|
||||||
(precision_ == Precision::kComplexDouble && !PrecisionSupported<double2>(device_))) {
|
(precision_ == Precision::kComplexDouble && !PrecisionSupported<double2>(device_))) {
|
||||||
throw RuntimeErrorCode(StatusCode::kNoDoublePrecision);
|
throw RuntimeErrorCode(StatusCode::kNoDoublePrecision);
|
||||||
}
|
}
|
||||||
|
|
||||||
// As above, but for cl_khr_fp16 (half precision)
|
// As above, but for FP16 (half precision)
|
||||||
if (precision_ == Precision::kHalf && !PrecisionSupported<half>(device_)) {
|
if (precision_ == Precision::kHalf && !PrecisionSupported<half>(device_)) {
|
||||||
throw RuntimeErrorCode(StatusCode::kNoHalfPrecision);
|
throw RuntimeErrorCode(StatusCode::kNoHalfPrecision);
|
||||||
}
|
}
|
||||||
|
@ -188,8 +188,8 @@ void Routine::InitProgram(std::initializer_list<const char *> source) {
|
||||||
program_ = Program(context_, source_string);
|
program_ = Program(context_, source_string);
|
||||||
try {
|
try {
|
||||||
program_.Build(device_, options);
|
program_.Build(device_, options);
|
||||||
} catch (const CLError &e) {
|
} catch (const CLCudaAPIBuildError &e) {
|
||||||
if (e.status() == CL_BUILD_PROGRAM_FAILURE) {
|
if (program_.StatusIsCompilationWarningOrError(e.status())) {
|
||||||
fprintf(stdout, "OpenCL compiler error/warning: %s\n",
|
fprintf(stdout, "OpenCL compiler error/warning: %s\n",
|
||||||
program_.GetBuildInfo(device_).c_str());
|
program_.GetBuildInfo(device_).c_str());
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,7 +75,6 @@ class Routine {
|
||||||
EventPointer event_;
|
EventPointer event_;
|
||||||
const Context context_;
|
const Context context_;
|
||||||
const Device device_;
|
const Device device_;
|
||||||
const cl_platform_id platform_;
|
|
||||||
|
|
||||||
// Compiled program (either retrieved from cache or compiled in slow path)
|
// Compiled program (either retrieved from cache or compiled in slow path)
|
||||||
Program program_;
|
Program program_;
|
||||||
|
|
|
@ -55,7 +55,7 @@ StatusCode DispatchException()
|
||||||
} catch (BLASError &e) {
|
} catch (BLASError &e) {
|
||||||
// no message is printed for invalid argument errors
|
// no message is printed for invalid argument errors
|
||||||
status = e.status();
|
status = e.status();
|
||||||
} catch (CLError &e) {
|
} catch (CLCudaAPIError &e) {
|
||||||
message = e.what();
|
message = e.what();
|
||||||
status = static_cast<StatusCode>(e.status());
|
status = static_cast<StatusCode>(e.status());
|
||||||
} catch (RuntimeErrorCode &e) {
|
} catch (RuntimeErrorCode &e) {
|
||||||
|
|
|
@ -391,16 +391,9 @@ template <> Precision PrecisionValue<double2>() { return Precision::kComplexDoub
|
||||||
// Returns false is this precision is not supported by the device
|
// Returns false is this precision is not supported by the device
|
||||||
template <> bool PrecisionSupported<float>(const Device &) { return true; }
|
template <> bool PrecisionSupported<float>(const Device &) { return true; }
|
||||||
template <> bool PrecisionSupported<float2>(const Device &) { return true; }
|
template <> bool PrecisionSupported<float2>(const Device &) { return true; }
|
||||||
template <> bool PrecisionSupported<double>(const Device &device) {
|
template <> bool PrecisionSupported<double>(const Device &device) { return device.SupportsFP64(); }
|
||||||
return device.HasExtension(kKhronosDoublePrecision);
|
template <> bool PrecisionSupported<double2>(const Device &device) { return device.SupportsFP64(); }
|
||||||
}
|
template <> bool PrecisionSupported<half>(const Device &device) { return device.SupportsFP16(); }
|
||||||
template <> bool PrecisionSupported<double2>(const Device &device) {
|
|
||||||
return device.HasExtension(kKhronosDoublePrecision);
|
|
||||||
}
|
|
||||||
template <> bool PrecisionSupported<half>(const Device &device) {
|
|
||||||
if (device.Name() == "Mali-T628") { return true; } // supports fp16 but not cl_khr_fp16 officially
|
|
||||||
return device.HasExtension(kKhronosHalfPrecision);
|
|
||||||
}
|
|
||||||
|
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
|
|
||||||
|
|
|
@ -31,15 +31,13 @@ namespace clblast {
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
|
|
||||||
// Shorthands for half-precision
|
// Shorthands for half-precision
|
||||||
using half = cl_half; // based on the OpenCL type, which is actually an 'unsigned short'
|
using half = unsigned short; // the 'cl_half' OpenCL type is actually an 'unsigned short'
|
||||||
|
|
||||||
// Shorthands for complex data-types
|
// Shorthands for complex data-types
|
||||||
using float2 = std::complex<float>;
|
using float2 = std::complex<float>;
|
||||||
using double2 = std::complex<double>;
|
using double2 = std::complex<double>;
|
||||||
|
|
||||||
// Khronos OpenCL extensions
|
// Khronos OpenCL extensions
|
||||||
const std::string kKhronosHalfPrecision = "cl_khr_fp16";
|
|
||||||
const std::string kKhronosDoublePrecision = "cl_khr_fp64";
|
|
||||||
const std::string kKhronosAttributesAMD = "cl_amd_device_attribute_query";
|
const std::string kKhronosAttributesAMD = "cl_amd_device_attribute_query";
|
||||||
const std::string kKhronosAttributesNVIDIA = "cl_nv_device_attribute_query";
|
const std::string kKhronosAttributesNVIDIA = "cl_nv_device_attribute_query";
|
||||||
|
|
||||||
|
|
|
@ -85,7 +85,7 @@ void OpenCLDiagnostics(int argc, char *argv[]) {
|
||||||
printf("* device.Name() %.4lf ms\n", TimeFunction(kNumRuns, [&](){device.Name();} ));
|
printf("* device.Name() %.4lf ms\n", TimeFunction(kNumRuns, [&](){device.Name();} ));
|
||||||
printf("* device.Vendor() %.4lf ms\n", TimeFunction(kNumRuns, [&](){device.Vendor();} ));
|
printf("* device.Vendor() %.4lf ms\n", TimeFunction(kNumRuns, [&](){device.Vendor();} ));
|
||||||
printf("* device.Version() %.4lf ms\n", TimeFunction(kNumRuns, [&](){device.Version();} ));
|
printf("* device.Version() %.4lf ms\n", TimeFunction(kNumRuns, [&](){device.Version();} ));
|
||||||
printf("* device.Platform() %.4lf ms\n", TimeFunction(kNumRuns, [&](){device.Platform();} ));
|
printf("* device.Platform() %.4lf ms\n", TimeFunction(kNumRuns, [&](){ device.PlatformID();} ));
|
||||||
printf("* Buffer<float>(context, 1024) %.4lf ms\n", TimeFunction(kNumRuns, [&](){Buffer<float>(context, 1024);} ));
|
printf("* Buffer<float>(context, 1024) %.4lf ms\n", TimeFunction(kNumRuns, [&](){Buffer<float>(context, 1024);} ));
|
||||||
|
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
|
@ -88,7 +88,7 @@ void FloatToHalfBuffer(std::vector<half>& result, const std::vector<float>& sour
|
||||||
}
|
}
|
||||||
|
|
||||||
// As above, but now for OpenCL data-types instead of std::vectors
|
// As above, but now for OpenCL data-types instead of std::vectors
|
||||||
Buffer<float> HalfToFloatBuffer(const Buffer<half>& source, cl_command_queue queue_raw) {
|
Buffer<float> HalfToFloatBuffer(const Buffer<half>& source, RawCommandQueue queue_raw) {
|
||||||
const auto size = source.GetSize() / sizeof(half);
|
const auto size = source.GetSize() / sizeof(half);
|
||||||
auto queue = Queue(queue_raw);
|
auto queue = Queue(queue_raw);
|
||||||
auto context = queue.GetContext();
|
auto context = queue.GetContext();
|
||||||
|
@ -99,7 +99,7 @@ Buffer<float> HalfToFloatBuffer(const Buffer<half>& source, cl_command_queue que
|
||||||
result.Write(queue, size, result_cpu);
|
result.Write(queue, size, result_cpu);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
void FloatToHalfBuffer(Buffer<half>& result, const Buffer<float>& source, cl_command_queue queue_raw) {
|
void FloatToHalfBuffer(Buffer<half>& result, const Buffer<float>& source, RawCommandQueue queue_raw) {
|
||||||
const auto size = source.GetSize() / sizeof(float);
|
const auto size = source.GetSize() / sizeof(float);
|
||||||
auto queue = Queue(queue_raw);
|
auto queue = Queue(queue_raw);
|
||||||
auto context = queue.GetContext();
|
auto context = queue.GetContext();
|
||||||
|
|
|
@ -89,8 +89,8 @@ std::vector<float> HalfToFloatBuffer(const std::vector<half>& source);
|
||||||
void FloatToHalfBuffer(std::vector<half>& result, const std::vector<float>& source);
|
void FloatToHalfBuffer(std::vector<half>& result, const std::vector<float>& source);
|
||||||
|
|
||||||
// As above, but now for OpenCL data-types instead of std::vectors
|
// As above, but now for OpenCL data-types instead of std::vectors
|
||||||
Buffer<float> HalfToFloatBuffer(const Buffer<half>& source, cl_command_queue queue_raw);
|
Buffer<float> HalfToFloatBuffer(const Buffer<half>& source, RawCommandQueue queue_raw);
|
||||||
void FloatToHalfBuffer(Buffer<half>& result, const Buffer<float>& source, cl_command_queue queue_raw);
|
void FloatToHalfBuffer(Buffer<half>& result, const Buffer<float>& source, RawCommandQueue queue_raw);
|
||||||
|
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
} // namespace clblast
|
} // namespace clblast
|
||||||
|
|
Loading…
Reference in a new issue