Added a CUDA version of the GEMM temp-buffer optional argument
parent
af14fff1e9
commit
44431daecc
|
@ -492,7 +492,8 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
|
|||
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
const T beta,
|
||||
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
|
||||
const CUcontext context, const CUdevice device);
|
||||
const CUcontext context, const CUdevice device,
|
||||
CUdeviceptr temp_buffer = nullptr);
|
||||
|
||||
// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
|
||||
template <typename T>
|
||||
|
|
|
@ -60,12 +60,12 @@ def clblast_cc(routine, cuda=False):
|
|||
result += " auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, " + event + ");" + NL
|
||||
if routine.batched:
|
||||
result += " " + (NL + " ").join(routine.batched_transform_to_cpp()) + NL
|
||||
if routine.temp_buffer and not cuda:
|
||||
if routine.temp_buffer:
|
||||
result += " const auto temp_buffer_provided = temp_buffer != nullptr;\n"
|
||||
result += " auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(nullptr);\n"
|
||||
result += " routine.Do" + routine.capitalized_name() + "("
|
||||
result += ("," + NL + indent1).join([a for a in routine.arguments_clcudaapi()])
|
||||
if routine.temp_buffer and not cuda:
|
||||
if routine.temp_buffer:
|
||||
result += ",\n" + indent1 + "temp_buffer_cpp, temp_buffer_provided"
|
||||
result += ");" + NL
|
||||
result += " return StatusCode::kSuccess;" + NL
|
||||
|
@ -84,6 +84,8 @@ def clblast_cc(routine, cuda=False):
|
|||
result += "," + NL + indent2
|
||||
if cuda:
|
||||
result += "const CUcontext, const CUdevice"
|
||||
if routine.temp_buffer:
|
||||
result += ", CUdeviceptr"
|
||||
else:
|
||||
result += "cl_command_queue*, cl_event*"
|
||||
if routine.temp_buffer:
|
||||
|
|
|
@ -819,7 +819,7 @@ class Routine:
|
|||
result += "const CUcontext context, const CUdevice device"
|
||||
else:
|
||||
result += "cl_command_queue* queue, cl_event* event" + default_event
|
||||
if self.temp_buffer and not cuda:
|
||||
if self.temp_buffer:
|
||||
result += ",\n" + indent + mem_type + " temp_buffer"
|
||||
if not implementation:
|
||||
result += " = nullptr"
|
||||
|
|
|
@ -1725,19 +1725,23 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
|
|||
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
const T beta,
|
||||
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
|
||||
const CUcontext context, const CUdevice device) {
|
||||
const CUcontext context, const CUdevice device,
|
||||
CUdeviceptr temp_buffer) {
|
||||
try {
|
||||
const auto context_cpp = Context(context);
|
||||
const auto device_cpp = Device(device);
|
||||
auto queue_cpp = Queue(context_cpp, device_cpp);
|
||||
auto routine = Xgemm<T>(queue_cpp, nullptr);
|
||||
const auto temp_buffer_provided = temp_buffer != nullptr;
|
||||
auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(nullptr);
|
||||
routine.DoGemm(layout, a_transpose, b_transpose,
|
||||
m, n, k,
|
||||
alpha,
|
||||
Buffer<T>(a_buffer), a_offset, a_ld,
|
||||
Buffer<T>(b_buffer), b_offset, b_ld,
|
||||
beta,
|
||||
Buffer<T>(c_buffer), c_offset, c_ld);
|
||||
Buffer<T>(c_buffer), c_offset, c_ld,
|
||||
temp_buffer_cpp, temp_buffer_provided);
|
||||
return StatusCode::kSuccess;
|
||||
} catch (...) { return DispatchException(); }
|
||||
}
|
||||
|
@ -1748,7 +1752,7 @@ template StatusCode PUBLIC_API Gemm<float>(const Layout, const Transpose, const
|
|||
const CUdeviceptr, const size_t, const size_t,
|
||||
const float,
|
||||
CUdeviceptr, const size_t, const size_t,
|
||||
const CUcontext, const CUdevice);
|
||||
const CUcontext, const CUdevice, CUdeviceptr);
|
||||
template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const Transpose,
|
||||
const size_t, const size_t, const size_t,
|
||||
const double,
|
||||
|
@ -1756,7 +1760,7 @@ template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const
|
|||
const CUdeviceptr, const size_t, const size_t,
|
||||
const double,
|
||||
CUdeviceptr, const size_t, const size_t,
|
||||
const CUcontext, const CUdevice);
|
||||
const CUcontext, const CUdevice, CUdeviceptr);
|
||||
template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const Transpose,
|
||||
const size_t, const size_t, const size_t,
|
||||
const float2,
|
||||
|
@ -1764,7 +1768,7 @@ template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const
|
|||
const CUdeviceptr, const size_t, const size_t,
|
||||
const float2,
|
||||
CUdeviceptr, const size_t, const size_t,
|
||||
const CUcontext, const CUdevice);
|
||||
const CUcontext, const CUdevice, CUdeviceptr);
|
||||
template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, const Transpose,
|
||||
const size_t, const size_t, const size_t,
|
||||
const double2,
|
||||
|
@ -1772,7 +1776,7 @@ template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, cons
|
|||
const CUdeviceptr, const size_t, const size_t,
|
||||
const double2,
|
||||
CUdeviceptr, const size_t, const size_t,
|
||||
const CUcontext, const CUdevice);
|
||||
const CUcontext, const CUdevice, CUdeviceptr);
|
||||
template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const Transpose,
|
||||
const size_t, const size_t, const size_t,
|
||||
const half,
|
||||
|
@ -1780,7 +1784,7 @@ template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const T
|
|||
const CUdeviceptr, const size_t, const size_t,
|
||||
const half,
|
||||
CUdeviceptr, const size_t, const size_t,
|
||||
const CUcontext, const CUdevice);
|
||||
const CUcontext, const CUdevice, CUdeviceptr);
|
||||
|
||||
// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
|
||||
template <typename T>
|
||||
|
|
Loading…
Reference in New Issue