CUDA API: accept an optional caller-provided temporary buffer in Gemm.

Gemm gains a trailing `temp_buffer` argument in the CUDA API, mirroring the
OpenCL API. The generator (cpp.py / routine.py) no longer suppresses the
temp-buffer code for CUDA, and the generated header and source are updated
to match.

CUdeviceptr is an unsigned integer typedef in the CUDA driver API, not a
pointer type, so the "no buffer supplied" sentinel on the CUDA path is 0.
The OpenCL path (cl_mem, a pointer type) keeps nullptr. The generator picks
the sentinel from its `cuda` flag so regenerated files stay consistent with
the hand-visible hunks below.

Note: the post-image blob ids on the `index` lines predate the sentinel fix;
they are informational only for a plain `git apply` / `patch`.

diff --git a/include/clblast_cuda.h b/include/clblast_cuda.h
index e28f68e5..0f510981 100644
--- a/include/clblast_cuda.h
+++ b/include/clblast_cuda.h
@@ -492,7 +492,8 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
                 const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
                 const T beta,
                 CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
-                const CUcontext context, const CUdevice device);
+                const CUcontext context, const CUdevice device,
+                CUdeviceptr temp_buffer = 0);
 
 // Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
 template <typename T>
diff --git a/scripts/generator/generator/cpp.py b/scripts/generator/generator/cpp.py
index a850a032..656253d7 100644
--- a/scripts/generator/generator/cpp.py
+++ b/scripts/generator/generator/cpp.py
@@ -60,12 +60,13 @@ def clblast_cc(routine, cuda=False):
         result += "    auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, " + event + ");" + NL
         if routine.batched:
             result += "    " + (NL + "    ").join(routine.batched_transform_to_cpp()) + NL
-        if routine.temp_buffer and not cuda:
-            result += "    const auto temp_buffer_provided = temp_buffer != nullptr;\n"
-            result += "    auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(nullptr);\n"
+        if routine.temp_buffer:
+            null = "0" if cuda else "nullptr"
+            result += "    const auto temp_buffer_provided = temp_buffer != " + null + ";\n"
+            result += "    auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(" + null + ");\n"
         result += "    routine.Do" + routine.capitalized_name() + "("
         result += ("," + NL + indent1).join([a for a in routine.arguments_clcudaapi()])
-        if routine.temp_buffer and not cuda:
+        if routine.temp_buffer:
             result += ",\n" + indent1 + "temp_buffer_cpp, temp_buffer_provided"
         result += ");" + NL
         result += "    return StatusCode::kSuccess;" + NL
@@ -84,6 +85,8 @@ def clblast_cc(routine, cuda=False):
             result += "," + NL + indent2
             if cuda:
                 result += "const CUcontext, const CUdevice"
+                if routine.temp_buffer:
+                    result += ", CUdeviceptr"
             else:
                 result += "cl_command_queue*, cl_event*"
                 if routine.temp_buffer:
diff --git a/scripts/generator/generator/routine.py b/scripts/generator/generator/routine.py
index 11b9080f..22be02b0 100644
--- a/scripts/generator/generator/routine.py
+++ b/scripts/generator/generator/routine.py
@@ -819,7 +819,7 @@ class Routine:
             result += "const CUcontext context, const CUdevice device"
         else:
             result += "cl_command_queue* queue, cl_event* event" + default_event
-        if self.temp_buffer and not cuda:
+        if self.temp_buffer:
             result += ",\n" + indent + mem_type + " temp_buffer"
             if not implementation:
-                result += " = nullptr"
+                result += (" = 0" if cuda else " = nullptr")
diff --git a/src/clblast_cuda.cpp b/src/clblast_cuda.cpp
index 0e3d949d..187443eb 100644
--- a/src/clblast_cuda.cpp
+++ b/src/clblast_cuda.cpp
@@ -1725,19 +1725,23 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
                 const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
                 const T beta,
                 CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
-                const CUcontext context, const CUdevice device) {
+                const CUcontext context, const CUdevice device,
+                CUdeviceptr temp_buffer) {
   try {
     const auto context_cpp = Context(context);
     const auto device_cpp = Device(device);
     auto queue_cpp = Queue(context_cpp, device_cpp);
     auto routine = Xgemm<T>(queue_cpp, nullptr);
+    const auto temp_buffer_provided = temp_buffer != 0;
+    auto temp_buffer_cpp = temp_buffer_provided ? Buffer<T>(temp_buffer) : Buffer<T>(0);
     routine.DoGemm(layout, a_transpose, b_transpose,
                    m, n, k,
                    alpha,
                    Buffer<T>(a_buffer), a_offset, a_ld,
                    Buffer<T>(b_buffer), b_offset, b_ld,
                    beta,
-                   Buffer<T>(c_buffer), c_offset, c_ld);
+                   Buffer<T>(c_buffer), c_offset, c_ld,
+                   temp_buffer_cpp, temp_buffer_provided);
     return StatusCode::kSuccess;
   } catch (...) { return DispatchException(); }
 }
@@ -1748,7 +1752,7 @@ template StatusCode PUBLIC_API Gemm<float>(const Layout, const Transpose, const 
                                            const CUdeviceptr, const size_t, const size_t,
                                            const float,
                                            CUdeviceptr, const size_t, const size_t,
-                                           const CUcontext, const CUdevice);
+                                           const CUcontext, const CUdevice, CUdeviceptr);
 template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const Transpose,
                                             const size_t, const size_t, const size_t,
                                             const double,
@@ -1756,7 +1760,7 @@ template StatusCode PUBLIC_API Gemm<double>(const Layout, const Transpose, const
                                             const CUdeviceptr, const size_t, const size_t,
                                             const double,
                                             CUdeviceptr, const size_t, const size_t,
-                                            const CUcontext, const CUdevice);
+                                            const CUcontext, const CUdevice, CUdeviceptr);
 template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const Transpose,
                                             const size_t, const size_t, const size_t,
                                             const float2,
@@ -1764,7 +1768,7 @@ template StatusCode PUBLIC_API Gemm<float2>(const Layout, const Transpose, const
                                             const CUdeviceptr, const size_t, const size_t,
                                             const float2,
                                             CUdeviceptr, const size_t, const size_t,
-                                            const CUcontext, const CUdevice);
+                                            const CUcontext, const CUdevice, CUdeviceptr);
 template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, const Transpose,
                                              const size_t, const size_t, const size_t,
                                              const double2,
@@ -1772,7 +1776,7 @@ template StatusCode PUBLIC_API Gemm<double2>(const Layout, const Transpose, cons
                                              const CUdeviceptr, const size_t, const size_t,
                                              const double2,
                                              CUdeviceptr, const size_t, const size_t,
-                                             const CUcontext, const CUdevice);
+                                             const CUcontext, const CUdevice, CUdeviceptr);
 template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const Transpose,
                                           const size_t, const size_t, const size_t,
                                           const half,
@@ -1780,7 +1784,7 @@ template StatusCode PUBLIC_API Gemm<half>(const Layout, const Transpose, const T
                                           const CUdeviceptr, const size_t, const size_t,
                                           const half,
                                           CUdeviceptr, const size_t, const size_t,
-                                          const CUcontext, const CUdevice);
+                                          const CUcontext, const CUdevice, CUdeviceptr);
 
 // Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
 template <typename T>