All enums in the C API are now prefixed with CLBlast to avoid potential name clashes with other projects

This commit is contained in:
Cedric Nugteren 2016-10-22 16:14:56 +02:00
parent 4a5516aa78
commit a670c4c4bf
13 changed files with 5141 additions and 4744 deletions

View file

@ -1,6 +1,7 @@
Development version (next release) Development version (next release)
- Updated to version 8.0 of the CLCudaAPI C++11 OpenCL header - Updated to version 8.0 of the CLCudaAPI C++11 OpenCL header
- Changed the enums in the C API to avoid potential name clashes with external code
- Greatly improved the way exceptions are handled in the library (thanks to 'intelfx') - Greatly improved the way exceptions are handled in the library (thanks to 'intelfx')
- Improved performance of GEMM kernels for small sizes by using a direct single-kernel implementation - Improved performance of GEMM kernels for small sizes by using a direct single-kernel implementation
- Fixed a bug in the tests and samples related to waiting for an invalid event - Fixed a bug in the tests and samples related to waiting for an invalid event

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -106,13 +106,13 @@ void run_example_routine(const cl_device_id device) {
clock_t start = clock(); clock_t start = clock();
// Calls an example routine // Calls an example routine
StatusCode status = CLBlastSasum(n, CLBlastStatusCode status = CLBlastSasum(n,
device_output, 0, device_output, 0,
device_input, 0, 1, device_input, 0, 1,
&queue, &event); &queue, &event);
// Wait for completion // Wait for completion
if (status == kSuccess) { if (status == CLBlastSuccess) {
clWaitForEvents(1, &event); clWaitForEvents(1, &event);
clReleaseEvent(event); clReleaseEvent(event);
} }

View file

@ -74,17 +74,17 @@ int main(void) {
clEnqueueWriteBuffer(queue, device_y, CL_TRUE, 0, m*sizeof(double), host_y, 0, NULL, NULL); clEnqueueWriteBuffer(queue, device_y, CL_TRUE, 0, m*sizeof(double), host_y, 0, NULL, NULL);
// Call the DGEMV routine. // Call the DGEMV routine.
StatusCode status = CLBlastDgemv(kRowMajor, kNo, CLBlastStatusCode status = CLBlastDgemv(CLBlastLayoutRowMajor, CLBlastTransposeNo,
m, n, m, n,
alpha, alpha,
device_a, 0, a_ld, device_a, 0, a_ld,
device_x, 0, 1, device_x, 0, 1,
beta, beta,
device_y, 0, 1, device_y, 0, 1,
&queue, &event); &queue, &event);
// Wait for completion // Wait for completion
if (status == kSuccess) { if (status == CLBlastSuccess) {
clWaitForEvents(1, &event); clWaitForEvents(1, &event);
clReleaseEvent(event); clReleaseEvent(event);
} }

View file

@ -71,13 +71,13 @@ int main(void) {
clEnqueueWriteBuffer(queue, device_b, CL_TRUE, 0, n*sizeof(cl_half), host_b, 0, NULL, NULL); clEnqueueWriteBuffer(queue, device_b, CL_TRUE, 0, n*sizeof(cl_half), host_b, 0, NULL, NULL);
// Call the HAXPY routine. // Call the HAXPY routine.
StatusCode status = CLBlastHaxpy(n, alpha, CLBlastStatusCode status = CLBlastHaxpy(n, alpha,
device_a, 0, 1, device_a, 0, 1,
device_b, 0, 1, device_b, 0, 1,
&queue, &event); &queue, &event);
// Wait for completion // Wait for completion
if (status == kSuccess) { if (status == CLBlastSuccess) {
clWaitForEvents(1, &event); clWaitForEvents(1, &event);
clReleaseEvent(event); clReleaseEvent(event);
} }

View file

@ -67,13 +67,13 @@ int main(void) {
clEnqueueWriteBuffer(queue, device_output, CL_TRUE, 0, 1*sizeof(float), host_output, 0, NULL, NULL); clEnqueueWriteBuffer(queue, device_output, CL_TRUE, 0, 1*sizeof(float), host_output, 0, NULL, NULL);
// Call the SASUM routine. // Call the SASUM routine.
StatusCode status = CLBlastSasum(n, CLBlastStatusCode status = CLBlastSasum(n,
device_output, 0, device_output, 0,
device_input, 0, 1, device_input, 0, 1,
&queue, &event); &queue, &event);
// Wait for completion // Wait for completion
if (status == kSuccess) { if (status == CLBlastSuccess) {
clWaitForEvents(1, &event); clWaitForEvents(1, &event);
clReleaseEvent(event); clReleaseEvent(event);
} }

View file

@ -77,17 +77,18 @@ int main(void) {
clEnqueueWriteBuffer(queue, device_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 0, NULL, NULL); clEnqueueWriteBuffer(queue, device_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 0, NULL, NULL);
// Call the SGEMM routine. // Call the SGEMM routine.
StatusCode status = CLBlastSgemm(kRowMajor, kNo, kNo, CLBlastStatusCode status = CLBlastSgemm(CLBlastLayoutRowMajor,
m, n, k, CLBlastTransposeNo, CLBlastTransposeNo,
alpha, m, n, k,
device_a, 0, a_ld, alpha,
device_b, 0, b_ld, device_a, 0, a_ld,
beta, device_b, 0, b_ld,
device_c, 0, c_ld, beta,
&queue, &event); device_c, 0, c_ld,
&queue, &event);
// Wait for completion // Wait for completion
if (status == kSuccess) { if (status == CLBlastSuccess) {
clWaitForEvents(1, &event); clWaitForEvents(1, &event);
clReleaseEvent(event); clReleaseEvent(event);
} }

View file

@ -73,7 +73,7 @@ def clblast_c_h(routine):
"""The C API header (.h)""" """The C API header (.h)"""
result = NL + "// " + routine.description + ": " + routine.short_names() + NL result = NL + "// " + routine.description + ": " + routine.short_names() + NL
for flavour in routine.flavours: for flavour in routine.flavours:
result += routine.routine_header_c(flavour, 31, " PUBLIC_API") + ";" + NL result += routine.routine_header_c(flavour, 38, " PUBLIC_API") + ";" + NL
return result return result
@ -82,13 +82,15 @@ def clblast_c_cc(routine):
result = NL + "// " + routine.name.upper() + NL result = NL + "// " + routine.name.upper() + NL
for flavour in routine.flavours: for flavour in routine.flavours:
template = "<" + flavour.template + ">" if routine.no_scalars() else "" template = "<" + flavour.template + ">" if routine.no_scalars() else ""
indent = " " * (45 + routine.length() + len(template)) indent = " " * (16 + routine.length() + len(template))
result += routine.routine_header_c(flavour, 20, "") + " {" + NL result += routine.routine_header_c(flavour, 27, "") + " {" + NL
result += " try {" + NL result += " try {" + NL
result += " return static_cast<StatusCode>(clblast::" + routine.name.capitalize() + template + "(" result += " return static_cast<CLBlastStatusCode>(" + NL
result += " clblast::" + routine.name.capitalize() + template + "("
result += ("," + NL + indent).join([a for a in routine.arguments_cast(flavour, indent)]) result += ("," + NL + indent).join([a for a in routine.arguments_cast(flavour, indent)])
result += "," + NL + indent + "queue, event));" + NL result += "," + NL + indent + "queue, event)" + NL
result += " } catch (...) { return static_cast<StatusCode>(clblast::DispatchExceptionForC()); }" + NL result += " );" + NL
result += " } catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }" + NL
result += "}" + NL result += "}" + NL
return result return result

View file

@ -32,7 +32,7 @@ def generate(routine):
result += "C API:" + NL result += "C API:" + NL
result += "```" + NL result += "```" + NL
for flavour in routine.flavours: for flavour in routine.flavours:
result += routine.routine_header_c(flavour, 20, "") + NL result += routine.routine_header_c(flavour, 27, "") + NL
result += "```" + NL + NL result += "```" + NL + NL
# Routine arguments # Routine arguments

View file

@ -349,6 +349,13 @@ class Routine:
return [", ".join(definitions)] return [", ".join(definitions)]
return [] return []
def options_def_c(self):
"""As above, but now for the C API"""
if self.options:
definitions = ["const CLBlast" + convert.option_to_clblast(o) + " " + o for o in self.options]
return [", ".join(definitions)]
return []
def options_def_wrapper_clblas(self): def options_def_wrapper_clblas(self):
"""As above, but now using clBLAS data-types""" """As above, but now using clBLAS data-types"""
if self.options: if self.options:
@ -453,6 +460,17 @@ class Routine:
list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_second()])) + list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_second()])) +
list(chain(*[self.scalar_def(s, flavour) for s in self.other_scalars()]))) list(chain(*[self.scalar_def(s, flavour) for s in self.other_scalars()])))
def arguments_def_c(self, flavour):
"""As above, but for the C API"""
return (self.options_def_c() + self.sizes_def() +
list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_first()])) +
self.scalar_def("alpha", flavour) +
list(chain(*[self.buffer_def(b) for b in self.buffers_first()])) +
self.scalar_def("beta", flavour) +
list(chain(*[self.buffer_def(b) for b in self.buffers_second()])) +
list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_second()])) +
list(chain(*[self.scalar_def(s, flavour) for s in self.other_scalars()])))
def arguments_def_wrapper_clblas(self, flavour): def arguments_def_wrapper_clblas(self, flavour):
"""As above, but clBLAS wrapper plain data-types""" """As above, but clBLAS wrapper plain data-types"""
return (self.options_def_wrapper_clblas() + self.sizes_def() + return (self.options_def_wrapper_clblas() + self.sizes_def() +
@ -523,8 +541,8 @@ class Routine:
def routine_header_c(self, flavour, spaces, extra_qualifier): def routine_header_c(self, flavour, spaces, extra_qualifier):
"""As above, but now for C""" """As above, but now for C"""
indent = " " * (spaces + self.length()) indent = " " * (spaces + self.length())
result = "StatusCode" + extra_qualifier + " CLBlast" + flavour.name + self.name + "(" result = "CLBlastStatusCode" + extra_qualifier + " CLBlast" + flavour.name + self.name + "("
result += (",\n" + indent).join([a for a in self.arguments_def(flavour)]) result += (",\n" + indent).join([a for a in self.arguments_def_c(flavour)])
result += ",\n" + indent + "cl_command_queue* queue, cl_event* event)" result += ",\n" + indent + "cl_command_queue* queue, cl_event* event)"
return result return result

File diff suppressed because it is too large Load diff

View file

@ -248,7 +248,7 @@ void Tester<T,U>::TestErrorCodes(const StatusCode clblas_status, const StatusCod
} }
// Could not compile the CLBlast kernel properly // Could not compile the CLBlast kernel properly
else if (clblast_status == StatusCode::kBuildProgramFailure || else if (clblast_status == StatusCode::kOpenCLBuildProgramFailure ||
clblast_status == StatusCode::kNotImplemented) { clblast_status == StatusCode::kNotImplemented) {
PrintTestResult(kSkippedCompilation); PrintTestResult(kSkippedCompilation);
ReportSkipped(); ReportSkipped();