Clean-up of the routine class, moved RunKernel to the routine/common file
parent
7b4c0e1cf0
commit
bacb5d2bb2
|
@ -141,7 +141,7 @@ set(PRECISIONS 32 64 3232 6464)
|
|||
|
||||
# Gathers all source-files
|
||||
set(SOURCES src/clblast.cc src/database.cc src/routine.cc src/cache.cc
|
||||
src/utilities.cc src/clblast_c.cc)
|
||||
src/utilities.cc src/clblast_c.cc src/routines/common.cc)
|
||||
foreach(ROUTINE ${LEVEL1_ROUTINES})
|
||||
set(SOURCES ${SOURCES} src/routines/level1/${ROUTINE}.cc)
|
||||
endforeach()
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "internal/utilities.h"
|
||||
#include "internal/database.h"
|
||||
#include "internal/buffer_test.h"
|
||||
#include "internal/routines/common.h"
|
||||
|
||||
namespace clblast {
|
||||
// =================================================================================================
|
||||
|
@ -40,8 +41,7 @@ class Routine {
|
|||
|
||||
protected:
|
||||
|
||||
// Non-static variable for the precision. Note that the same variable (but static) might exist in
|
||||
// a derived class.
|
||||
// Non-static variable for the precision
|
||||
const Precision precision_;
|
||||
|
||||
// The routine's name and its kernel-source in string form
|
||||
|
@ -61,23 +61,8 @@ class Routine {
|
|||
const Database db_;
|
||||
};
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// Enqueues a kernel, waits for completion, and checks for errors
|
||||
StatusCode RunKernel(Kernel &kernel, Queue queue, const Device device,
|
||||
std::vector<size_t> global, const std::vector<size_t> &local,
|
||||
EventPointer event, std::vector<Event>& waitForEvents);
|
||||
|
||||
// As above, but without an event waiting list
|
||||
StatusCode RunKernel(Kernel &kernel, Queue queue, const Device device,
|
||||
std::vector<size_t> global, const std::vector<size_t> &local,
|
||||
EventPointer event);
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
||||
|
||||
// Temporary fix: TODO place include in a more logical place
|
||||
#include "internal/routines/common.h"
|
||||
|
||||
// CLBLAST_ROUTINE_H_
|
||||
#endif
|
||||
|
|
|
@ -8,7 +8,8 @@
|
|||
// Cedric Nugteren <www.cedricnugteren.nl>
|
||||
//
|
||||
// This file contains all the interfaces to common kernels, such as copying, padding, and
|
||||
// transposing a matrix. These functions are templated and thus header-only.
|
||||
// transposing a matrix. These functions are templated and thus header-only. This file also contains
|
||||
// other common functions to routines, such as a function to launch a kernel.
|
||||
//
|
||||
// =================================================================================================
|
||||
|
||||
|
@ -18,17 +19,30 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "internal/utilities.h"
|
||||
#include "internal/routine.h"
|
||||
#include "clblast.h"
|
||||
#include "internal/clpp11.h"
|
||||
#include "internal/database.h"
|
||||
|
||||
namespace clblast {
|
||||
// =================================================================================================
|
||||
|
||||
// Enqueues a kernel, waits for completion, and checks for errors
|
||||
StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
|
||||
std::vector<size_t> global, const std::vector<size_t> &local,
|
||||
EventPointer event, std::vector<Event>& waitForEvents);
|
||||
|
||||
// As above, but without an event waiting list
|
||||
StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
|
||||
std::vector<size_t> global, const std::vector<size_t> &local,
|
||||
EventPointer event);
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// Copies or transposes a matrix and optionally pads/unpads it with zeros. This method is also able
|
||||
// to write to symmetric and triangular matrices through optional arguments.
|
||||
template <typename T>
|
||||
StatusCode PadCopyTransposeMatrix(Queue queue, const Device device, const Context context,
|
||||
const Database db,
|
||||
StatusCode PadCopyTransposeMatrix(Queue &queue, const Device &device, const Context &context,
|
||||
const Database &db,
|
||||
EventPointer event, std::vector<Event>& waitForEvents,
|
||||
const size_t src_one, const size_t src_two,
|
||||
const size_t src_ld, const size_t src_offset,
|
||||
|
|
|
@ -386,7 +386,7 @@ files = [
|
|||
path_clblast+"/test/wrapper_cblas.h",
|
||||
]
|
||||
header_lines = [84, 74, 93, 22, 29, 41]
|
||||
footer_lines = [17, 71, 19, 14, 6, 6]
|
||||
footer_lines = [17, 75, 19, 14, 6, 6]
|
||||
|
||||
# Checks whether the command-line arguments are valid; exits otherwise
|
||||
for f in files:
|
||||
|
|
|
@ -29,10 +29,10 @@
|
|||
#include "internal/routines/level1/xdotc.h"
|
||||
#include "internal/routines/level1/xnrm2.h"
|
||||
#include "internal/routines/level1/xasum.h"
|
||||
#include "internal/routines/level1/xsum.h" // non-BLAS function
|
||||
#include "internal/routines/level1/xsum.h" // non-BLAS routine
|
||||
#include "internal/routines/level1/xamax.h"
|
||||
#include "internal/routines/level1/xmax.h" // non-BLAS function
|
||||
#include "internal/routines/level1/xmin.h" // non-BLAS function
|
||||
#include "internal/routines/level1/xmax.h" // non-BLAS routine
|
||||
#include "internal/routines/level1/xmin.h" // non-BLAS routine
|
||||
|
||||
// BLAS level-2 includes
|
||||
#include "internal/routines/level2/xgemv.h"
|
||||
|
@ -68,7 +68,7 @@
|
|||
#include "internal/routines/level3/xher2k.h"
|
||||
#include "internal/routines/level3/xtrmm.h"
|
||||
|
||||
// Extra includes (level-x)
|
||||
// Level-x includes (non-BLAS)
|
||||
#include "internal/routines/levelx/xomatcopy.h"
|
||||
|
||||
namespace clblast {
|
||||
|
@ -2123,6 +2123,7 @@ template StatusCode PUBLIC_API Omatcopy<half>(const Layout, const Transpose,
|
|||
StatusCode ClearCache() { return CacheClearAll(); }
|
||||
|
||||
// Fills the cache with all binaries for a specific device
|
||||
// TODO: Add half-precision FP16 set-up calls
|
||||
StatusCode FillCache(const cl_device_id device) {
|
||||
try {
|
||||
|
||||
|
@ -2171,7 +2172,7 @@ StatusCode FillCache(const cl_device_id device) {
|
|||
Xsyr2<float>(queue, nullptr).SetUp(); Xsyr2<double>(queue, nullptr).SetUp();
|
||||
Xspr2<float>(queue, nullptr).SetUp(); Xspr2<double>(queue, nullptr).SetUp();
|
||||
|
||||
// Runs all the level 1 set-up functions
|
||||
// Runs all the level 3 set-up functions
|
||||
Xgemm<float>(queue, nullptr).SetUp(); Xgemm<double>(queue, nullptr).SetUp(); Xgemm<float2>(queue, nullptr).SetUp(); Xgemm<double2>(queue, nullptr).SetUp();
|
||||
Xsymm<float>(queue, nullptr).SetUp(); Xsymm<double>(queue, nullptr).SetUp(); Xsymm<float2>(queue, nullptr).SetUp(); Xsymm<double2>(queue, nullptr).SetUp();
|
||||
Xhemm<float2>(queue, nullptr).SetUp(); Xhemm<double2>(queue, nullptr).SetUp();
|
||||
|
@ -2181,6 +2182,9 @@ StatusCode FillCache(const cl_device_id device) {
|
|||
Xher2k<float2,float>(queue, nullptr).SetUp(); Xher2k<double2,double>(queue, nullptr).SetUp();
|
||||
Xtrmm<float>(queue, nullptr).SetUp(); Xtrmm<double>(queue, nullptr).SetUp(); Xtrmm<float2>(queue, nullptr).SetUp(); Xtrmm<double2>(queue, nullptr).SetUp();
|
||||
|
||||
// Runs all the level-x (non-BLAS) set-up functions
|
||||
Xomatcopy<float>(queue, nullptr).SetUp(); Xomatcopy<double>(queue, nullptr).SetUp(); Xomatcopy<float2>(queue, nullptr).SetUp(); Xomatcopy<double2>(queue, nullptr).SetUp();
|
||||
|
||||
} catch (...) { return StatusCode::kBuildProgramFailure; }
|
||||
return StatusCode::kSuccess;
|
||||
}
|
||||
|
|
|
@ -127,50 +127,5 @@ StatusCode Routine::SetUp() {
|
|||
return StatusCode::kSuccess;
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// Enqueues a kernel, waits for completion, and checks for errors
|
||||
StatusCode RunKernel(Kernel &kernel, Queue queue, const Device device,
|
||||
std::vector<size_t> global, const std::vector<size_t> &local,
|
||||
EventPointer event, std::vector<Event>& waitForEvents) {
|
||||
|
||||
// Tests for validity of the local thread sizes
|
||||
if (local.size() > device.MaxWorkItemDimensions()) {
|
||||
return StatusCode::kInvalidLocalNumDimensions;
|
||||
}
|
||||
const auto max_work_item_sizes = device.MaxWorkItemSizes();
|
||||
for (auto i=size_t{0}; i<local.size(); ++i) {
|
||||
if (local[i] > max_work_item_sizes[i]) { return StatusCode::kInvalidLocalThreadsDim; }
|
||||
}
|
||||
auto local_size = size_t{1};
|
||||
for (auto &item: local) { local_size *= item; }
|
||||
if (local_size > device.MaxWorkGroupSize()) { return StatusCode::kInvalidLocalThreadsTotal; }
|
||||
|
||||
// Make sure the global thread sizes are at least equal to the local sizes
|
||||
for (auto i=size_t{0}; i<global.size(); ++i) {
|
||||
if (global[i] < local[i]) { global[i] = local[i]; }
|
||||
}
|
||||
|
||||
// Tests for local memory usage
|
||||
const auto local_mem_usage = kernel.LocalMemUsage(device);
|
||||
if (!device.IsLocalMemoryValid(local_mem_usage)) { return StatusCode::kInvalidLocalMemUsage; }
|
||||
|
||||
// Launches the kernel (and checks for launch errors)
|
||||
try {
|
||||
kernel.Launch(queue, global, local, event, waitForEvents);
|
||||
} catch (...) { return StatusCode::kKernelLaunchError; }
|
||||
|
||||
// No errors, normal termination of this function
|
||||
return StatusCode::kSuccess;
|
||||
}
|
||||
|
||||
// As above, but without an event waiting list
|
||||
StatusCode RunKernel(Kernel &kernel, Queue queue, const Device device,
|
||||
std::vector<size_t> global, const std::vector<size_t> &local,
|
||||
EventPointer event) {
|
||||
auto emptyWaitingList = std::vector<Event>();
|
||||
return RunKernel(kernel, queue, device, global, local, event, emptyWaitingList);
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
|
||||
// =================================================================================================
|
||||
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
|
||||
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
|
||||
// width of 100 characters per line.
|
||||
//
|
||||
// Author(s):
|
||||
// Cedric Nugteren <www.cedricnugteren.nl>
|
||||
//
|
||||
// This file implements the common routine functions (see the header for more information).
|
||||
//
|
||||
// =================================================================================================
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "internal/routines/common.h"
|
||||
|
||||
namespace clblast {
|
||||
// =================================================================================================
|
||||
|
||||
// Enqueues a kernel, waits for completion, and checks for errors
|
||||
StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
|
||||
std::vector<size_t> global, const std::vector<size_t> &local,
|
||||
EventPointer event, std::vector<Event>& waitForEvents) {
|
||||
|
||||
// Tests for validity of the local thread sizes
|
||||
if (local.size() > device.MaxWorkItemDimensions()) {
|
||||
return StatusCode::kInvalidLocalNumDimensions;
|
||||
}
|
||||
const auto max_work_item_sizes = device.MaxWorkItemSizes();
|
||||
for (auto i=size_t{0}; i<local.size(); ++i) {
|
||||
if (local[i] > max_work_item_sizes[i]) { return StatusCode::kInvalidLocalThreadsDim; }
|
||||
}
|
||||
auto local_size = size_t{1};
|
||||
for (auto &item: local) { local_size *= item; }
|
||||
if (local_size > device.MaxWorkGroupSize()) { return StatusCode::kInvalidLocalThreadsTotal; }
|
||||
|
||||
// Make sure the global thread sizes are at least equal to the local sizes
|
||||
for (auto i=size_t{0}; i<global.size(); ++i) {
|
||||
if (global[i] < local[i]) { global[i] = local[i]; }
|
||||
}
|
||||
|
||||
// Tests for local memory usage
|
||||
const auto local_mem_usage = kernel.LocalMemUsage(device);
|
||||
if (!device.IsLocalMemoryValid(local_mem_usage)) { return StatusCode::kInvalidLocalMemUsage; }
|
||||
|
||||
// Launches the kernel (and checks for launch errors)
|
||||
try {
|
||||
kernel.Launch(queue, global, local, event, waitForEvents);
|
||||
} catch (...) { return StatusCode::kKernelLaunchError; }
|
||||
|
||||
// No errors, normal termination of this function
|
||||
return StatusCode::kSuccess;
|
||||
}
|
||||
|
||||
// As above, but without an event waiting list
|
||||
StatusCode RunKernel(Kernel &kernel, Queue &queue, const Device &device,
|
||||
std::vector<size_t> global, const std::vector<size_t> &local,
|
||||
EventPointer event) {
|
||||
auto emptyWaitingList = std::vector<Event>();
|
||||
return RunKernel(kernel, queue, device, global, local, event, emptyWaitingList);
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
Loading…
Reference in New Issue