mirror of
https://github.com/CNugteren/CLBlast.git
synced 2024-07-02 12:26:57 +02:00
Initial changes in preparation for half-precision fp16 support
This commit is contained in:
parent
1c72d225c5
commit
f2ba75890c
|
@ -20,6 +20,8 @@
|
|||
|
||||
#include <cltune.h>
|
||||
|
||||
#include "internal/utilities.h"
|
||||
|
||||
namespace clblast {
|
||||
// =================================================================================================
|
||||
|
||||
|
|
|
@ -27,6 +27,9 @@
|
|||
namespace clblast {
|
||||
// =================================================================================================
|
||||
|
||||
// Host data-type for half-precision floating-point (16-bit)
|
||||
using half = cl_half;
|
||||
|
||||
// Shorthands for complex data-types
|
||||
using float2 = std::complex<float>;
|
||||
using double2 = std::complex<double>;
|
||||
|
|
|
@ -13,10 +13,13 @@
|
|||
# ==================================================================================================
|
||||
|
||||
# Short-hands for data-types
|
||||
HLF = "half"
|
||||
FLT = "float"
|
||||
DBL = "double"
|
||||
FLT2 = "float2"
|
||||
DBL2 = "double2"
|
||||
|
||||
HCL = "cl_half"
|
||||
F2CL = "cl_float2"
|
||||
D2CL = "cl_double2"
|
||||
|
||||
|
|
|
@ -28,11 +28,12 @@ import os.path
|
|||
|
||||
# Local files
|
||||
from routine import Routine
|
||||
from datatype import DataType, FLT, DBL, FLT2, DBL2, F2CL, D2CL
|
||||
from datatype import DataType, HLF, FLT, DBL, FLT2, DBL2, HCL, F2CL, D2CL
|
||||
|
||||
# ==================================================================================================
|
||||
|
||||
# Regular data-types
|
||||
H = DataType("H", "H", HLF, [HLF, HLF, HCL, HCL], HLF ) # half (16)
|
||||
S = DataType("S", "S", FLT, [FLT, FLT, FLT, FLT], FLT ) # single (32)
|
||||
D = DataType("D", "D", DBL, [DBL, DBL, DBL, DBL], DBL ) # double (64)
|
||||
C = DataType("C", "C", FLT2, [FLT2, FLT2, F2CL, F2CL], FLT2) # single-complex (3232)
|
||||
|
@ -67,7 +68,7 @@ routines = [
|
|||
Routine(True, True, "1", "swap", T, [S,D,C,Z], ["n"], [], [], ["x","y"], [], "", "Swap two vectors", "Interchanges the contents of vectors x and y.", []),
|
||||
Routine(True, True, "1", "scal", T, [S,D,C,Z], ["n"], [], [], ["x"], ["alpha"], "", "Vector scaling", "Multiplies all elements of vector x by a scalar constant alpha.", []),
|
||||
Routine(True, True, "1", "copy", T, [S,D,C,Z], ["n"], [], ["x"], ["y"], [], "", "Vector copy", "Copies the contents of vector x into vector y.", []),
|
||||
Routine(True, True, "1", "axpy", T, [S,D,C,Z], ["n"], [], ["x"], ["y"], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation y = alpha * x + y, in which x and y are vectors and alpha is a scalar constant.", []),
|
||||
Routine(True, True, "1", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation y = alpha * x + y, in which x and y are vectors and alpha is a scalar constant.", []),
|
||||
Routine(True, True, "1", "dot", T, [S,D], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two vectors", "Multiplies the vectors x and y element-wise and accumulates the results. The sum is stored in the dot buffer.", []),
|
||||
Routine(True, True, "1", "dotu", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two complex vectors", "See the regular xDOT routine.", []),
|
||||
Routine(True, True, "1", "dotc", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two complex vectors, one conjugated", "See the regular xDOT routine.", []),
|
||||
|
@ -229,22 +230,23 @@ def wrapper_clblas(routines):
|
|||
result = ""
|
||||
for routine in routines:
|
||||
if routine.has_tests:
|
||||
result += "\n// Forwards the clBLAS calls for %s\n" % (routine.ShortNames())
|
||||
result += "\n// Forwards the clBLAS calls for %s\n" % (routine.ShortNamesTested())
|
||||
if routine.NoScalars():
|
||||
result += routine.RoutineHeaderWrapperCL(routine.template, True, 21)+";\n"
|
||||
for flavour in routine.flavours:
|
||||
indent = " "*(17 + routine.Length())
|
||||
result += routine.RoutineHeaderWrapperCL(flavour, False, 21)+" {\n"
|
||||
arguments = routine.ArgumentsWrapperCL(flavour)
|
||||
if routine.scratch:
|
||||
result += " auto queue = Queue(queues[0]);\n"
|
||||
result += " auto context = queue.GetContext();\n"
|
||||
result += " auto scratch_buffer = Buffer<"+flavour.template+">(context, "+routine.scratch+");\n"
|
||||
arguments += ["scratch_buffer()"]
|
||||
result += " return clblas"+flavour.name+routine.name+"("
|
||||
result += (",\n"+indent).join([a for a in arguments])
|
||||
result += ",\n"+indent+"num_queues, queues, num_wait_events, wait_events, events);"
|
||||
result += "\n}\n"
|
||||
if flavour.precision_name in ["S","D","C","Z"]:
|
||||
indent = " "*(17 + routine.Length())
|
||||
result += routine.RoutineHeaderWrapperCL(flavour, False, 21)+" {\n"
|
||||
arguments = routine.ArgumentsWrapperCL(flavour)
|
||||
if routine.scratch:
|
||||
result += " auto queue = Queue(queues[0]);\n"
|
||||
result += " auto context = queue.GetContext();\n"
|
||||
result += " auto scratch_buffer = Buffer<"+flavour.template+">(context, "+routine.scratch+");\n"
|
||||
arguments += ["scratch_buffer()"]
|
||||
result += " return clblas"+flavour.name+routine.name+"("
|
||||
result += (",\n"+indent).join([a for a in arguments])
|
||||
result += ",\n"+indent+"num_queues, queues, num_wait_events, wait_events, events);"
|
||||
result += "\n}\n"
|
||||
return result
|
||||
|
||||
# The wrapper to the reference CBLAS routines (for performance/correctness testing)
|
||||
|
@ -252,44 +254,45 @@ def wrapper_cblas(routines):
|
|||
result = ""
|
||||
for routine in routines:
|
||||
if routine.has_tests:
|
||||
result += "\n// Forwards the Netlib BLAS calls for %s\n" % (routine.ShortNames())
|
||||
result += "\n// Forwards the Netlib BLAS calls for %s\n" % (routine.ShortNamesTested())
|
||||
for flavour in routine.flavours:
|
||||
indent = " "*(10 + routine.Length())
|
||||
result += routine.RoutineHeaderWrapperC(flavour, False, 12)+" {\n"
|
||||
arguments = routine.ArgumentsWrapperC(flavour)
|
||||
if flavour.precision_name in ["S","D","C","Z"]:
|
||||
indent = " "*(10 + routine.Length())
|
||||
result += routine.RoutineHeaderWrapperC(flavour, False, 12)+" {\n"
|
||||
arguments = routine.ArgumentsWrapperC(flavour)
|
||||
|
||||
# Double-precision scalars
|
||||
for scalar in routine.scalars:
|
||||
if flavour.IsComplex(scalar):
|
||||
result += " const auto "+scalar+"_array = std::vector<"+flavour.buffertype[:-1]+">{"+scalar+".real(), "+scalar+".imag()};\n"
|
||||
# Double-precision scalars
|
||||
for scalar in routine.scalars:
|
||||
if flavour.IsComplex(scalar):
|
||||
result += " const auto "+scalar+"_array = std::vector<"+flavour.buffertype[:-1]+">{"+scalar+".real(), "+scalar+".imag()};\n"
|
||||
|
||||
# Special case for scalar outputs
|
||||
assignment = ""
|
||||
postfix = ""
|
||||
endofline = ""
|
||||
extra_argument = ""
|
||||
for output_buffer in routine.outputs:
|
||||
if output_buffer in routine.ScalarBuffersFirst():
|
||||
if flavour in [C,Z]:
|
||||
postfix += "_sub"
|
||||
indent += " "
|
||||
extra_argument += ",\n"+indent+"reinterpret_cast<return_pointer_"+flavour.buffertype[:-1]+">(&"+output_buffer+"_buffer["+output_buffer+"_offset])"
|
||||
elif output_buffer in routine.IndexBuffers():
|
||||
assignment = "((int*)&"+output_buffer+"_buffer[0])["+output_buffer+"_offset] = "
|
||||
indent += " "*len(assignment)
|
||||
else:
|
||||
assignment = output_buffer+"_buffer["+output_buffer+"_offset]"
|
||||
if (flavour.name in ["Sc","Dz"]):
|
||||
assignment = assignment+".real("
|
||||
endofline += ")"
|
||||
# Special case for scalar outputs
|
||||
assignment = ""
|
||||
postfix = ""
|
||||
endofline = ""
|
||||
extra_argument = ""
|
||||
for output_buffer in routine.outputs:
|
||||
if output_buffer in routine.ScalarBuffersFirst():
|
||||
if flavour in [C,Z]:
|
||||
postfix += "_sub"
|
||||
indent += " "
|
||||
extra_argument += ",\n"+indent+"reinterpret_cast<return_pointer_"+flavour.buffertype[:-1]+">(&"+output_buffer+"_buffer["+output_buffer+"_offset])"
|
||||
elif output_buffer in routine.IndexBuffers():
|
||||
assignment = "((int*)&"+output_buffer+"_buffer[0])["+output_buffer+"_offset] = "
|
||||
indent += " "*len(assignment)
|
||||
else:
|
||||
assignment = assignment+" = "
|
||||
indent += " "*len(assignment)
|
||||
assignment = output_buffer+"_buffer["+output_buffer+"_offset]"
|
||||
if (flavour.name in ["Sc","Dz"]):
|
||||
assignment = assignment+".real("
|
||||
endofline += ")"
|
||||
else:
|
||||
assignment = assignment+" = "
|
||||
indent += " "*len(assignment)
|
||||
|
||||
result += " "+assignment+"cblas_"+flavour.name.lower()+routine.name+postfix+"("
|
||||
result += (",\n"+indent).join([a for a in arguments])
|
||||
result += extra_argument+endofline+");"
|
||||
result += "\n}\n"
|
||||
result += " "+assignment+"cblas_"+flavour.name.lower()+routine.name+postfix+"("
|
||||
result += (",\n"+indent).join([a for a in arguments])
|
||||
result += extra_argument+endofline+");"
|
||||
result += "\n}\n"
|
||||
return result
|
||||
|
||||
# ==================================================================================================
|
||||
|
@ -368,9 +371,10 @@ for level in [1,2,3]:
|
|||
body += "int main(int argc, char *argv[]) {\n"
|
||||
not_first = "false"
|
||||
for flavour in routine.flavours:
|
||||
body += " clblast::RunTests<clblast::TestX"+routine.name+flavour.TestTemplate()
|
||||
body += ">(argc, argv, "+not_first+", \""+flavour.name+routine.name.upper()+"\");\n"
|
||||
not_first = "true"
|
||||
if flavour.precision_name in ["S","D","C","Z"]:
|
||||
body += " clblast::RunTests<clblast::TestX"+routine.name+flavour.TestTemplate()
|
||||
body += ">(argc, argv, "+not_first+", \""+flavour.name+routine.name.upper()+"\");\n"
|
||||
not_first = "true"
|
||||
body += " return 0;\n"
|
||||
body += "}\n"
|
||||
f.write(header+"\n")
|
||||
|
@ -397,7 +401,7 @@ for level in [1,2,3]:
|
|||
body += " case clblast::Precision::k"+PrecisionToFullName(precision)+":"
|
||||
found = False
|
||||
for flavour in routine.flavours:
|
||||
if flavour.precision_name == precision:
|
||||
if flavour.precision_name == precision and flavour.precision_name in ["S","D","C","Z"]:
|
||||
body += "\n clblast::RunClient<clblast::TestX"+routine.name+flavour.TestTemplate()
|
||||
body += ">(argc, argv); break;\n"
|
||||
found = True
|
||||
|
|
|
@ -119,6 +119,12 @@ class Routine():
|
|||
def ShortNames(self):
|
||||
return "/".join([f.name+self.name.upper() for f in self.flavours])
|
||||
|
||||
# As above, but excludes some
|
||||
def ShortNamesTested(self):
|
||||
names = [f.name+self.name.upper() for f in self.flavours]
|
||||
if "H"+self.name.upper() in names: names.remove("H"+self.name.upper())
|
||||
return "/".join(names)
|
||||
|
||||
# Determines which buffers go first (between alpha and beta) and which ones go after
|
||||
def BuffersFirst(self):
|
||||
if self.level == "2b":
|
||||
|
|
|
@ -19,7 +19,7 @@ R"(
|
|||
// Parameters set by the tuner or by the database. Here they are given a basic default value in case
|
||||
// this file is used outside of the CLBlast library.
|
||||
#ifndef PRECISION
|
||||
#define PRECISION 32 // Data-types: single or double precision, complex or regular
|
||||
#define PRECISION 32 // Data-types: half, single or double precision, complex or regular
|
||||
#endif
|
||||
|
||||
// =================================================================================================
|
||||
|
@ -31,8 +31,19 @@ R"(
|
|||
#endif
|
||||
#endif
|
||||
|
||||
// Half-precision
|
||||
#if PRECISION == 16
|
||||
typedef half real;
|
||||
typedef half2 real2;
|
||||
typedef half4 real4;
|
||||
typedef half8 real8;
|
||||
typedef half16 real16;
|
||||
#define ZERO 0.0
|
||||
#define ONE 1.0
|
||||
#define SMALLEST -1.0e37
|
||||
|
||||
// Single-precision
|
||||
#if PRECISION == 32
|
||||
#elif PRECISION == 32
|
||||
typedef float real;
|
||||
typedef float2 real2;
|
||||
typedef float4 real4;
|
||||
|
@ -68,7 +79,7 @@ R"(
|
|||
#define ONE 1.0f
|
||||
#define SMALLEST -1.0e37f
|
||||
|
||||
// Complex Double-precision
|
||||
// Complex double-precision
|
||||
#elif PRECISION == 6464
|
||||
typedef struct cdouble {double x; double y;} real;
|
||||
typedef struct cdouble2 {real x; real y;} real2;
|
||||
|
|
|
@ -397,6 +397,7 @@ StatusCode Routine<T>::PadCopyTransposeMatrix(EventPointer event, std::vector<Ev
|
|||
// =================================================================================================
|
||||
|
||||
// Compiles the templated class
|
||||
template class Routine<half>;
|
||||
template class Routine<float>;
|
||||
template class Routine<double>;
|
||||
template class Routine<float2>;
|
||||
|
|
|
@ -29,6 +29,7 @@ std::string ToString(T value) {
|
|||
}
|
||||
template std::string ToString<int>(int value);
|
||||
template std::string ToString<size_t>(size_t value);
|
||||
template std::string ToString<half>(half value);
|
||||
template std::string ToString<float>(float value);
|
||||
template std::string ToString<double>(double value);
|
||||
|
||||
|
@ -105,6 +106,9 @@ template <typename T>
|
|||
T ConvertArgument(const char* value) {
|
||||
return static_cast<T>(std::stoi(value));
|
||||
}
|
||||
template <> half ConvertArgument(const char* value) {
|
||||
return static_cast<half>(std::stod(value));
|
||||
}
|
||||
template <> float ConvertArgument(const char* value) {
|
||||
return static_cast<float>(std::stod(value));
|
||||
}
|
||||
|
@ -147,6 +151,7 @@ T GetArgument(const int argc, char *argv[], std::string &help,
|
|||
// Compiles the above function
|
||||
template int GetArgument<int>(const int, char **, std::string&, const std::string&, const int);
|
||||
template size_t GetArgument<size_t>(const int, char **, std::string&, const std::string&, const size_t);
|
||||
template half GetArgument<half>(const int, char **, std::string&, const std::string&, const half);
|
||||
template float GetArgument<float>(const int, char **, std::string&, const std::string&, const float);
|
||||
template double GetArgument<double>(const int, char **, std::string&, const std::string&, const double);
|
||||
template float2 GetArgument<float2>(const int, char **, std::string&, const std::string&, const float2);
|
||||
|
@ -227,6 +232,16 @@ void PopulateVector(std::vector<double2> &vector) {
|
|||
for (auto &element: vector) { element.real(dist(mt)); element.imag(dist(mt)); }
|
||||
}
|
||||
|
||||
// Specialized versions of the above for half-precision
|
||||
template <>
|
||||
void PopulateVector(std::vector<half> &vector) {
|
||||
auto lower_limit = static_cast<float>(kTestDataLowerLimit);
|
||||
auto upper_limit = static_cast<float>(kTestDataUpperLimit);
|
||||
std::mt19937 mt(GetRandomSeed());
|
||||
std::uniform_real_distribution<float> dist(lower_limit, upper_limit);
|
||||
for (auto &element: vector) { element = static_cast<half>(dist(mt)); }
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// Returns a scalar with a default value
|
||||
|
@ -234,6 +249,7 @@ template <typename T>
|
|||
T GetScalar() {
|
||||
return static_cast<T>(2.0);
|
||||
}
|
||||
template half GetScalar<half>();
|
||||
template float GetScalar<float>();
|
||||
template double GetScalar<double>();
|
||||
|
||||
|
@ -288,6 +304,10 @@ template <> bool PrecisionSupported<double2>(const Device &device) {
|
|||
auto extensions = device.Capabilities();
|
||||
return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true;
|
||||
}
|
||||
template <> bool PrecisionSupported<half>(const Device &device) {
|
||||
auto extensions = device.Capabilities();
|
||||
return (extensions.find(kKhronosHalfPrecision) == std::string::npos) ? false : true;
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
||||
|
|
Loading…
Reference in a new issue