Added a C interface to the OverrideParameters function; added some in-line comments to the API
parent
08bfb75a9d
commit
cda449a5c3
|
@ -621,6 +621,9 @@ StatusCode PUBLIC_API FillCache(const cl_device_id device);
|
|||
|
||||
// =================================================================================================
|
||||
|
||||
// Overrides tuning parameters for a specific device-precision-routine combination. The next time
|
||||
// (and all further times) the target routine is called it will re-compile and use the new
|
||||
// parameters.
|
||||
StatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const std::string &kernel_name,
|
||||
const Precision precision,
|
||||
const std::unordered_map<std::string,size_t> ¶meters);
|
||||
|
|
|
@ -96,6 +96,8 @@ typedef enum CLBlastStatusCode_ {
|
|||
CLBlastInsufficientMemoryY = -1007, // Vector Y's OpenCL buffer is too small
|
||||
|
||||
// Custom additional status codes for CLBlast
|
||||
CLBlastInvalidOverrideKernel = -2048, // Trying to override parameters for an invalid kernel
|
||||
CLBlastMissingOverrideParameter = -2047, // Missing override parameter(s) for the target kernel
|
||||
CLBlastInvalidLocalMemUsage = -2046, // Not enough local memory available on this device
|
||||
CLBlastNoHalfPrecision = -2045, // Half precision (16-bits) not supported by the device
|
||||
CLBlastNoDoublePrecision = -2044, // Double precision (64-bits) not supported by the device
|
||||
|
@ -117,6 +119,11 @@ typedef enum CLBlastDiagonal_ { CLBlastDiagonalNonUnit = 131,
|
|||
CLBlastDiagonalUnit = 132 } CLBlastDiagonal;
|
||||
typedef enum CLBlastSide_ { CLBlastSideLeft = 141, CLBlastSideRight = 142 } CLBlastSide;
|
||||
|
||||
// Precision enum (values in bits)
|
||||
typedef enum CLBlastPrecision_ { CLBlastPrecisionHalf = 16, CLBlastPrecisionSingle = 32,
|
||||
CLBlastPrecisionDouble = 64, CLBlastPrecisionComplexSingle = 3232,
|
||||
CLBlastPrecisionComplexDouble = 6464 } CLBlastPrecision;
|
||||
|
||||
// =================================================================================================
|
||||
// BLAS level-1 (vector-vector) routines
|
||||
// =================================================================================================
|
||||
|
@ -1338,6 +1345,15 @@ CLBlastStatusCode PUBLIC_API CLBlastFillCache(const cl_device_id device);
|
|||
|
||||
// =================================================================================================
|
||||
|
||||
// Overrides tuning parameters for a specific device-precision-routine combination. The next time
|
||||
// (and all further times) the target routine is called it will re-compile and use the new
|
||||
// parameters.
|
||||
CLBlastStatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const char* kernel_name,
|
||||
const CLBlastPrecision precision, const size_t num_parameters,
|
||||
const char** parameters_names, const size_t* parameters_values);
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
|
|
@ -41,8 +41,8 @@ FILES = [
|
|||
"/include/clblast_netlib_c.h",
|
||||
"/src/clblast_netlib_c.cpp",
|
||||
]
|
||||
HEADER_LINES = [121, 73, 118, 22, 29, 41, 65, 32]
|
||||
FOOTER_LINES = [23, 138, 19, 18, 6, 6, 9, 2]
|
||||
HEADER_LINES = [121, 73, 125, 23, 29, 41, 65, 32]
|
||||
FOOTER_LINES = [26, 139, 28, 38, 6, 6, 9, 2]
|
||||
|
||||
# Different possibilities for requirements
|
||||
ald_m = "The value of `a_ld` must be at least `m`."
|
||||
|
|
|
@ -2255,6 +2255,7 @@ StatusCode FillCache(const cl_device_id device) {
|
|||
|
||||
// =================================================================================================
|
||||
|
||||
// Overrides the tuning parameters for this device-precision-kernel combination
|
||||
StatusCode OverrideParameters(const cl_device_id device, const std::string &kernel_name,
|
||||
const Precision precision,
|
||||
const std::unordered_map<std::string,size_t> ¶meters) {
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
// =================================================================================================
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "utilities/utilities.hpp"
|
||||
#include "clblast_c.h"
|
||||
|
@ -3484,3 +3485,23 @@ CLBlastStatusCode CLBlastFillCache(const cl_device_id device) {
|
|||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// Overrides the tuning parameters for this device-precision-kernel combination
|
||||
CLBlastStatusCode PUBLIC_API OverrideParameters(const cl_device_id device, const char* kernel_name,
|
||||
const CLBlastPrecision precision, const size_t num_parameters,
|
||||
const char** parameters_names, const size_t* parameters_values) {
|
||||
try {
|
||||
const auto kernel_name_cpp = std::string(kernel_name);
|
||||
const auto precision_cpp = static_cast<clblast::Precision>(precision);
|
||||
auto parameters = std::unordered_map<std::string, size_t>();
|
||||
for (auto i = size_t{0}; i < num_parameters; ++i) {
|
||||
const auto parameter_name = std::string(parameters_names[i]);
|
||||
const auto parameter_value = parameters_values[i];
|
||||
parameters[parameter_name] = parameter_value;
|
||||
}
|
||||
const auto status = clblast::OverrideParameters(device, kernel_name_cpp, precision_cpp, parameters);
|
||||
return static_cast<CLBlastStatusCode>(status);
|
||||
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
|
Loading…
Reference in New Issue