Initial changes in preparation for half-precision fp16 support

This commit is contained in:
Cedric Nugteren 2016-05-12 19:56:21 +02:00
parent 1c72d225c5
commit f2ba75890c
8 changed files with 105 additions and 55 deletions

View file

@ -20,6 +20,8 @@
#include <cltune.h>
#include "internal/utilities.h"
namespace clblast {
// =================================================================================================

View file

@ -27,6 +27,9 @@
namespace clblast {
// =================================================================================================
// Host data-type for half-precision floating-point (16-bit): an alias of the OpenCL host type
// cl_half.
// NOTE(review): cl_half is presumably a 16-bit integer storage type in the OpenCL headers, in which
// case a static_cast from float/double converts the numeric value instead of producing IEEE half
// bits - confirm before doing host-side arithmetic or conversions on 'half'.
using half = cl_half;
// Shorthands for complex data-types
using float2 = std::complex<float>;
using double2 = std::complex<double>;

View file

@ -13,10 +13,13 @@
# ==================================================================================================
# Short-hands for data-types: each constant holds a type name as a string
# Type names without a 'cl_' prefix
HLF = "half"
FLT = "float"
DBL = "double"
FLT2 = "float2"
DBL2 = "double2"
# Type names with a 'cl_' prefix
HCL = "cl_half"
F2CL = "cl_float2"
D2CL = "cl_double2"

View file

@ -28,11 +28,12 @@ import os.path
# Local files
from routine import Routine
from datatype import DataType, FLT, DBL, FLT2, DBL2, F2CL, D2CL
from datatype import DataType, HLF, FLT, DBL, FLT2, DBL2, HCL, F2CL, D2CL
# ==================================================================================================
# Regular data-types
H = DataType("H", "H", HLF, [HLF, HLF, HCL, HCL], HLF ) # half (16)
S = DataType("S", "S", FLT, [FLT, FLT, FLT, FLT], FLT ) # single (32)
D = DataType("D", "D", DBL, [DBL, DBL, DBL, DBL], DBL ) # double (64)
C = DataType("C", "C", FLT2, [FLT2, FLT2, F2CL, F2CL], FLT2) # single-complex (3232)
@ -67,7 +68,7 @@ routines = [
Routine(True, True, "1", "swap", T, [S,D,C,Z], ["n"], [], [], ["x","y"], [], "", "Swap two vectors", "Interchanges the contents of vectors x and y.", []),
Routine(True, True, "1", "scal", T, [S,D,C,Z], ["n"], [], [], ["x"], ["alpha"], "", "Vector scaling", "Multiplies all elements of vector x by a scalar constant alpha.", []),
Routine(True, True, "1", "copy", T, [S,D,C,Z], ["n"], [], ["x"], ["y"], [], "", "Vector copy", "Copies the contents of vector x into vector y.", []),
Routine(True, True, "1", "axpy", T, [S,D,C,Z], ["n"], [], ["x"], ["y"], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation y = alpha * x + y, in which x and y are vectors and alpha is a scalar constant.", []),
Routine(True, True, "1", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation y = alpha * x + y, in which x and y are vectors and alpha is a scalar constant.", []),
Routine(True, True, "1", "dot", T, [S,D], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two vectors", "Multiplies the vectors x and y element-wise and accumulates the results. The sum is stored in the dot buffer.", []),
Routine(True, True, "1", "dotu", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two complex vectors", "See the regular xDOT routine.", []),
Routine(True, True, "1", "dotc", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two complex vectors, one conjugated", "See the regular xDOT routine.", []),
@ -229,22 +230,23 @@ def wrapper_clblas(routines):
result = ""
for routine in routines:
if routine.has_tests:
result += "\n// Forwards the clBLAS calls for %s\n" % (routine.ShortNames())
result += "\n// Forwards the clBLAS calls for %s\n" % (routine.ShortNamesTested())
if routine.NoScalars():
result += routine.RoutineHeaderWrapperCL(routine.template, True, 21)+";\n"
for flavour in routine.flavours:
indent = " "*(17 + routine.Length())
result += routine.RoutineHeaderWrapperCL(flavour, False, 21)+" {\n"
arguments = routine.ArgumentsWrapperCL(flavour)
if routine.scratch:
result += " auto queue = Queue(queues[0]);\n"
result += " auto context = queue.GetContext();\n"
result += " auto scratch_buffer = Buffer<"+flavour.template+">(context, "+routine.scratch+");\n"
arguments += ["scratch_buffer()"]
result += " return clblas"+flavour.name+routine.name+"("
result += (",\n"+indent).join([a for a in arguments])
result += ",\n"+indent+"num_queues, queues, num_wait_events, wait_events, events);"
result += "\n}\n"
if flavour.precision_name in ["S","D","C","Z"]:
indent = " "*(17 + routine.Length())
result += routine.RoutineHeaderWrapperCL(flavour, False, 21)+" {\n"
arguments = routine.ArgumentsWrapperCL(flavour)
if routine.scratch:
result += " auto queue = Queue(queues[0]);\n"
result += " auto context = queue.GetContext();\n"
result += " auto scratch_buffer = Buffer<"+flavour.template+">(context, "+routine.scratch+");\n"
arguments += ["scratch_buffer()"]
result += " return clblas"+flavour.name+routine.name+"("
result += (",\n"+indent).join([a for a in arguments])
result += ",\n"+indent+"num_queues, queues, num_wait_events, wait_events, events);"
result += "\n}\n"
return result
# The wrapper to the reference CBLAS routines (for performance/correctness testing)
@ -252,44 +254,45 @@ def wrapper_cblas(routines):
result = ""
for routine in routines:
if routine.has_tests:
result += "\n// Forwards the Netlib BLAS calls for %s\n" % (routine.ShortNames())
result += "\n// Forwards the Netlib BLAS calls for %s\n" % (routine.ShortNamesTested())
for flavour in routine.flavours:
indent = " "*(10 + routine.Length())
result += routine.RoutineHeaderWrapperC(flavour, False, 12)+" {\n"
arguments = routine.ArgumentsWrapperC(flavour)
if flavour.precision_name in ["S","D","C","Z"]:
indent = " "*(10 + routine.Length())
result += routine.RoutineHeaderWrapperC(flavour, False, 12)+" {\n"
arguments = routine.ArgumentsWrapperC(flavour)
# Complex scalars: stores the real and imaginary parts in a two-element array
for scalar in routine.scalars:
if flavour.IsComplex(scalar):
result += " const auto "+scalar+"_array = std::vector<"+flavour.buffertype[:-1]+">{"+scalar+".real(), "+scalar+".imag()};\n"
# Complex scalars: stores the real and imaginary parts in a two-element array
for scalar in routine.scalars:
if flavour.IsComplex(scalar):
result += " const auto "+scalar+"_array = std::vector<"+flavour.buffertype[:-1]+">{"+scalar+".real(), "+scalar+".imag()};\n"
# Special case for scalar outputs
assignment = ""
postfix = ""
endofline = ""
extra_argument = ""
for output_buffer in routine.outputs:
if output_buffer in routine.ScalarBuffersFirst():
if flavour in [C,Z]:
postfix += "_sub"
indent += " "
extra_argument += ",\n"+indent+"reinterpret_cast<return_pointer_"+flavour.buffertype[:-1]+">(&"+output_buffer+"_buffer["+output_buffer+"_offset])"
elif output_buffer in routine.IndexBuffers():
assignment = "((int*)&"+output_buffer+"_buffer[0])["+output_buffer+"_offset] = "
indent += " "*len(assignment)
else:
assignment = output_buffer+"_buffer["+output_buffer+"_offset]"
if (flavour.name in ["Sc","Dz"]):
assignment = assignment+".real("
endofline += ")"
# Special case for scalar outputs
assignment = ""
postfix = ""
endofline = ""
extra_argument = ""
for output_buffer in routine.outputs:
if output_buffer in routine.ScalarBuffersFirst():
if flavour in [C,Z]:
postfix += "_sub"
indent += " "
extra_argument += ",\n"+indent+"reinterpret_cast<return_pointer_"+flavour.buffertype[:-1]+">(&"+output_buffer+"_buffer["+output_buffer+"_offset])"
elif output_buffer in routine.IndexBuffers():
assignment = "((int*)&"+output_buffer+"_buffer[0])["+output_buffer+"_offset] = "
indent += " "*len(assignment)
else:
assignment = assignment+" = "
indent += " "*len(assignment)
assignment = output_buffer+"_buffer["+output_buffer+"_offset]"
if (flavour.name in ["Sc","Dz"]):
assignment = assignment+".real("
endofline += ")"
else:
assignment = assignment+" = "
indent += " "*len(assignment)
result += " "+assignment+"cblas_"+flavour.name.lower()+routine.name+postfix+"("
result += (",\n"+indent).join([a for a in arguments])
result += extra_argument+endofline+");"
result += "\n}\n"
result += " "+assignment+"cblas_"+flavour.name.lower()+routine.name+postfix+"("
result += (",\n"+indent).join([a for a in arguments])
result += extra_argument+endofline+");"
result += "\n}\n"
return result
# ==================================================================================================
@ -368,9 +371,10 @@ for level in [1,2,3]:
body += "int main(int argc, char *argv[]) {\n"
not_first = "false"
for flavour in routine.flavours:
body += " clblast::RunTests<clblast::TestX"+routine.name+flavour.TestTemplate()
body += ">(argc, argv, "+not_first+", \""+flavour.name+routine.name.upper()+"\");\n"
not_first = "true"
if flavour.precision_name in ["S","D","C","Z"]:
body += " clblast::RunTests<clblast::TestX"+routine.name+flavour.TestTemplate()
body += ">(argc, argv, "+not_first+", \""+flavour.name+routine.name.upper()+"\");\n"
not_first = "true"
body += " return 0;\n"
body += "}\n"
f.write(header+"\n")
@ -397,7 +401,7 @@ for level in [1,2,3]:
body += " case clblast::Precision::k"+PrecisionToFullName(precision)+":"
found = False
for flavour in routine.flavours:
if flavour.precision_name == precision:
if flavour.precision_name == precision and flavour.precision_name in ["S","D","C","Z"]:
body += "\n clblast::RunClient<clblast::TestX"+routine.name+flavour.TestTemplate()
body += ">(argc, argv); break;\n"
found = True

View file

@ -119,6 +119,12 @@ class Routine():
def ShortNames(self):
    # Builds one entry per flavour (flavour name followed by the upper-case routine name) and
    # joins the entries with a forward slash
    prefixed = []
    for flavour in self.flavours:
        prefixed.append(flavour.name + self.name.upper())
    return "/".join(prefixed)
# As ShortNames, but lists only the flavours for which tests and reference wrappers are generated.
# This uses the same criterion as the generator code itself (a precision of S, D, C, or Z), so the
# generated comment always matches the wrappers that follow it; half-precision is left out.
def ShortNamesTested(self):
    tested_precisions = ["S", "D", "C", "Z"]
    names = [f.name + self.name.upper() for f in self.flavours
             if f.precision_name in tested_precisions]
    return "/".join(names)
# Determines which buffers go first (between alpha and beta) and which ones go after
def BuffersFirst(self):
if self.level == "2b":

View file

@ -19,7 +19,7 @@ R"(
// Parameters set by the tuner or by the database. Here they are given a basic default value in case
// this file is used outside of the CLBlast library.
#ifndef PRECISION
#define PRECISION 32 // Data-types: single or double precision, complex or regular
#define PRECISION 32 // Data-types: half, single or double precision, complex or regular
#endif
// =================================================================================================
@ -31,8 +31,19 @@ R"(
#endif
#endif
// Half-precision: 'real' and its vector variants map onto the OpenCL half types
#if PRECISION == 16
typedef half real;
typedef half2 real2;
typedef half4 real4;
typedef half8 real8;
typedef half16 real16;
#define ZERO 0.0
#define ONE 1.0
// NOTE(review): -1.0e37 lies outside the finite range of a 16-bit IEEE half (about +/-65504), and
// these three constants carry no suffix, unlike the 'f'-suffixed constants elsewhere in this file.
// Confirm that they convert to the intended half values on the target compilers.
#define SMALLEST -1.0e37
// Single-precision
#if PRECISION == 32
#elif PRECISION == 32
typedef float real;
typedef float2 real2;
typedef float4 real4;
@ -68,7 +79,7 @@ R"(
#define ONE 1.0f
#define SMALLEST -1.0e37f
// Complex Double-precision
// Complex double-precision
#elif PRECISION == 6464
typedef struct cdouble {double x; double y;} real;
typedef struct cdouble2 {real x; real y;} real2;

View file

@ -397,6 +397,7 @@ StatusCode Routine<T>::PadCopyTransposeMatrix(EventPointer event, std::vector<Ev
// =================================================================================================
// Compiles the templated class
template class Routine<half>;
template class Routine<float>;
template class Routine<double>;
template class Routine<float2>;

View file

@ -29,6 +29,7 @@ std::string ToString(T value) {
}
template std::string ToString<int>(int value);
template std::string ToString<size_t>(size_t value);
template std::string ToString<half>(half value);
template std::string ToString<float>(float value);
template std::string ToString<double>(double value);
@ -105,6 +106,9 @@ template <typename T>
T ConvertArgument(const char* value) {
return static_cast<T>(std::stoi(value));
}
// Specialization for half-precision: parses the text as a double and encodes that value as 16-bit
// IEEE 754 half-precision bits (round-to-nearest-even). The host type cl_half is an unsigned 16-bit
// integer in the OpenCL headers, so a plain static_cast would truncate the number to an integer
// (e.g. "2.5" would become the bit pattern 0x0002) instead of producing a half-precision value.
template <> half ConvertArgument(const char* value) {
  auto magnitude = std::stod(value);
  if (magnitude != magnitude) { return static_cast<half>(0x7E00); } // NaN
  auto sign_bit = 0u;
  if (magnitude < 0.0) { sign_bit = 0x8000u; magnitude = -magnitude; }

  // Everything at or above 65520 rounds to infinity (the largest finite half is 65504)
  if (magnitude >= 65520.0) { return static_cast<half>(sign_bit | 0x7C00u); }

  // Scales the value into [1, 2) while tracking the power of two. The exponent is clamped at -14:
  // below that, halfs are subnormal and the scaled value simply stays in [0, 1).
  auto exponent = 0;
  while (magnitude >= 2.0) { magnitude *= 0.5; ++exponent; }
  while (magnitude < 1.0 && exponent > -14) { magnitude *= 2.0; --exponent; }

  // Rounds to 10 fractional bits, ties to even
  const auto scaled = magnitude * 1024.0;
  auto mantissa = static_cast<unsigned int>(scaled);
  const auto remainder = scaled - static_cast<double>(mantissa);
  if (remainder > 0.5 || (remainder == 0.5 && (mantissa & 1u) != 0u)) { ++mantissa; }

  // For normal numbers the mantissa still contains the implicit leading one (1024), hence the bias
  // of 14 instead of 15. Adding (not or-ing) lets a rounding carry propagate into the exponent.
  const auto bits = sign_bit | ((static_cast<unsigned int>(exponent + 14) << 10) + mantissa);
  return static_cast<half>(bits);
}
// Specialization for single-precision: the text is parsed as a double and then narrowed to float
template <> float ConvertArgument(const char* value) {
  const auto parsed = std::stod(value);
  return static_cast<float>(parsed);
}
@ -147,6 +151,7 @@ T GetArgument(const int argc, char *argv[], std::string &help,
// Compiles the above function
template int GetArgument<int>(const int, char **, std::string&, const std::string&, const int);
template size_t GetArgument<size_t>(const int, char **, std::string&, const std::string&, const size_t);
template half GetArgument<half>(const int, char **, std::string&, const std::string&, const half);
template float GetArgument<float>(const int, char **, std::string&, const std::string&, const float);
template double GetArgument<double>(const int, char **, std::string&, const std::string&, const double);
template float2 GetArgument<float2>(const int, char **, std::string&, const std::string&, const float2);
@ -227,6 +232,16 @@ void PopulateVector(std::vector<double2> &vector) {
for (auto &element: vector) { element.real(dist(mt)); element.imag(dist(mt)); }
}
// Specialized version of the above for half-precision: draws uniformly distributed floats between
// the test-data limits and stores them as 16-bit IEEE 754 half-precision bits. The host type cl_half
// is an unsigned 16-bit integer in the OpenCL headers, so a direct static_cast would truncate every
// sample to an integer, and is undefined behaviour for negative samples.
template <>
void PopulateVector(std::vector<half> &vector) {

  // Encodes a float as half-precision bits (round-to-nearest-even)
  const auto to_half = [](const float input) -> half {
    auto magnitude = static_cast<double>(input);
    if (magnitude != magnitude) { return static_cast<half>(0x7E00); } // NaN
    auto sign_bit = 0u;
    if (magnitude < 0.0) { sign_bit = 0x8000u; magnitude = -magnitude; }

    // Everything at or above 65520 rounds to infinity (the largest finite half is 65504)
    if (magnitude >= 65520.0) { return static_cast<half>(sign_bit | 0x7C00u); }

    // Scales into [1, 2) while tracking the power of two; the exponent is clamped at -14, below
    // which halfs are subnormal and the scaled value stays in [0, 1)
    auto exponent = 0;
    while (magnitude >= 2.0) { magnitude *= 0.5; ++exponent; }
    while (magnitude < 1.0 && exponent > -14) { magnitude *= 2.0; --exponent; }

    // Rounds to 10 fractional bits, ties to even
    const auto scaled = magnitude * 1024.0;
    auto mantissa = static_cast<unsigned int>(scaled);
    const auto remainder = scaled - static_cast<double>(mantissa);
    if (remainder > 0.5 || (remainder == 0.5 && (mantissa & 1u) != 0u)) { ++mantissa; }

    // The mantissa of a normal number still holds the implicit leading one (1024), hence the bias
    // of 14; adding (not or-ing) lets a rounding carry propagate into the exponent field
    return static_cast<half>(sign_bit |
                             ((static_cast<unsigned int>(exponent + 14) << 10) + mantissa));
  };

  const auto lower_limit = static_cast<float>(kTestDataLowerLimit);
  const auto upper_limit = static_cast<float>(kTestDataUpperLimit);
  std::mt19937 mt(GetRandomSeed());
  std::uniform_real_distribution<float> dist(lower_limit, upper_limit);
  for (auto &element: vector) { element = to_half(dist(mt)); }
}
// =================================================================================================
// Returns a scalar with a default value
@ -234,6 +249,7 @@ template <typename T>
T GetScalar() {
return static_cast<T>(2.0);
}
template half GetScalar<half>();
template float GetScalar<float>();
template double GetScalar<double>();
@ -288,6 +304,10 @@ template <> bool PrecisionSupported<double2>(const Device &device) {
auto extensions = device.Capabilities();
return (extensions.find(kKhronosDoublePrecision) == std::string::npos) ? false : true;
}
// Half-precision counts as supported when the Khronos half-precision extension name occurs in the
// capabilities reported by the device
template <> bool PrecisionSupported<half>(const Device &device) {
  const auto capabilities = device.Capabilities();
  const auto position = capabilities.find(kKhronosHalfPrecision);
  return position != std::string::npos;
}
// =================================================================================================
} // namespace clblast