CUDA API now takes context and device in instead of stream

pull/204/head
Cedric Nugteren 2017-10-12 12:20:43 +02:00
parent b901809345
commit cc5b475425
5 changed files with 476 additions and 371 deletions

View File

@ -103,7 +103,7 @@ StatusCode Rotg(CUdeviceptr sa_buffer, const size_t sa_offset,
CUdeviceptr sb_buffer, const size_t sb_offset, CUdeviceptr sb_buffer, const size_t sb_offset,
CUdeviceptr sc_buffer, const size_t sc_offset, CUdeviceptr sc_buffer, const size_t sc_offset,
CUdeviceptr ss_buffer, const size_t ss_offset, CUdeviceptr ss_buffer, const size_t ss_offset,
CUstream* stream); const CUcontext context, const CUdevice device);
// Generate modified Givens plane rotation: SROTMG/DROTMG // Generate modified Givens plane rotation: SROTMG/DROTMG
template <typename T> template <typename T>
@ -112,7 +112,7 @@ StatusCode Rotmg(CUdeviceptr sd1_buffer, const size_t sd1_offset,
CUdeviceptr sx1_buffer, const size_t sx1_offset, CUdeviceptr sx1_buffer, const size_t sx1_offset,
const CUdeviceptr sy1_buffer, const size_t sy1_offset, const CUdeviceptr sy1_buffer, const size_t sy1_offset,
CUdeviceptr sparam_buffer, const size_t sparam_offset, CUdeviceptr sparam_buffer, const size_t sparam_offset,
CUstream* stream); const CUcontext context, const CUdevice device);
// Apply Givens plane rotation: SROT/DROT // Apply Givens plane rotation: SROT/DROT
template <typename T> template <typename T>
@ -121,7 +121,7 @@ StatusCode Rot(const size_t n,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
const T cos, const T cos,
const T sin, const T sin,
CUstream* stream); const CUcontext context, const CUdevice device);
// Apply modified Givens plane rotation: SROTM/DROTM // Apply modified Givens plane rotation: SROTM/DROTM
template <typename T> template <typename T>
@ -129,28 +129,28 @@ StatusCode Rotm(const size_t n,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr sparam_buffer, const size_t sparam_offset, CUdeviceptr sparam_buffer, const size_t sparam_offset,
CUstream* stream); const CUcontext context, const CUdevice device);
// Swap two vectors: SSWAP/DSWAP/CSWAP/ZSWAP/HSWAP // Swap two vectors: SSWAP/DSWAP/CSWAP/ZSWAP/HSWAP
template <typename T> template <typename T>
StatusCode Swap(const size_t n, StatusCode Swap(const size_t n,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Vector scaling: SSCAL/DSCAL/CSCAL/ZSCAL/HSCAL // Vector scaling: SSCAL/DSCAL/CSCAL/ZSCAL/HSCAL
template <typename T> template <typename T>
StatusCode Scal(const size_t n, StatusCode Scal(const size_t n,
const T alpha, const T alpha,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Vector copy: SCOPY/DCOPY/CCOPY/ZCOPY/HCOPY // Vector copy: SCOPY/DCOPY/CCOPY/ZCOPY/HCOPY
template <typename T> template <typename T>
StatusCode Copy(const size_t n, StatusCode Copy(const size_t n,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY/HAXPY // Vector-times-constant plus vector: SAXPY/DAXPY/CAXPY/ZAXPY/HAXPY
template <typename T> template <typename T>
@ -158,7 +158,7 @@ StatusCode Axpy(const size_t n,
const T alpha, const T alpha,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Dot product of two vectors: SDOT/DDOT/HDOT // Dot product of two vectors: SDOT/DDOT/HDOT
template <typename T> template <typename T>
@ -166,7 +166,7 @@ StatusCode Dot(const size_t n,
CUdeviceptr dot_buffer, const size_t dot_offset, CUdeviceptr dot_buffer, const size_t dot_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Dot product of two complex vectors: CDOTU/ZDOTU // Dot product of two complex vectors: CDOTU/ZDOTU
template <typename T> template <typename T>
@ -174,7 +174,7 @@ StatusCode Dotu(const size_t n,
CUdeviceptr dot_buffer, const size_t dot_offset, CUdeviceptr dot_buffer, const size_t dot_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Dot product of two complex vectors, one conjugated: CDOTC/ZDOTC // Dot product of two complex vectors, one conjugated: CDOTC/ZDOTC
template <typename T> template <typename T>
@ -182,56 +182,56 @@ StatusCode Dotc(const size_t n,
CUdeviceptr dot_buffer, const size_t dot_offset, CUdeviceptr dot_buffer, const size_t dot_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Euclidean norm of a vector: SNRM2/DNRM2/ScNRM2/DzNRM2/HNRM2 // Euclidean norm of a vector: SNRM2/DNRM2/ScNRM2/DzNRM2/HNRM2
template <typename T> template <typename T>
StatusCode Nrm2(const size_t n, StatusCode Nrm2(const size_t n,
CUdeviceptr nrm2_buffer, const size_t nrm2_offset, CUdeviceptr nrm2_buffer, const size_t nrm2_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Absolute sum of values in a vector: SASUM/DASUM/ScASUM/DzASUM/HASUM // Absolute sum of values in a vector: SASUM/DASUM/ScASUM/DzASUM/HASUM
template <typename T> template <typename T>
StatusCode Asum(const size_t n, StatusCode Asum(const size_t n,
CUdeviceptr asum_buffer, const size_t asum_offset, CUdeviceptr asum_buffer, const size_t asum_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Sum of values in a vector (non-BLAS function): SSUM/DSUM/ScSUM/DzSUM/HSUM // Sum of values in a vector (non-BLAS function): SSUM/DSUM/ScSUM/DzSUM/HSUM
template <typename T> template <typename T>
StatusCode Sum(const size_t n, StatusCode Sum(const size_t n,
CUdeviceptr sum_buffer, const size_t sum_offset, CUdeviceptr sum_buffer, const size_t sum_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Index of absolute maximum value in a vector: iSAMAX/iDAMAX/iCAMAX/iZAMAX/iHAMAX // Index of absolute maximum value in a vector: iSAMAX/iDAMAX/iCAMAX/iZAMAX/iHAMAX
template <typename T> template <typename T>
StatusCode Amax(const size_t n, StatusCode Amax(const size_t n,
CUdeviceptr imax_buffer, const size_t imax_offset, CUdeviceptr imax_buffer, const size_t imax_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Index of absolute minimum value in a vector (non-BLAS function): iSAMIN/iDAMIN/iCAMIN/iZAMIN/iHAMIN // Index of absolute minimum value in a vector (non-BLAS function): iSAMIN/iDAMIN/iCAMIN/iZAMIN/iHAMIN
template <typename T> template <typename T>
StatusCode Amin(const size_t n, StatusCode Amin(const size_t n,
CUdeviceptr imin_buffer, const size_t imin_offset, CUdeviceptr imin_buffer, const size_t imin_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Index of maximum value in a vector (non-BLAS function): iSMAX/iDMAX/iCMAX/iZMAX/iHMAX // Index of maximum value in a vector (non-BLAS function): iSMAX/iDMAX/iCMAX/iZMAX/iHMAX
template <typename T> template <typename T>
StatusCode Max(const size_t n, StatusCode Max(const size_t n,
CUdeviceptr imax_buffer, const size_t imax_offset, CUdeviceptr imax_buffer, const size_t imax_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Index of minimum value in a vector (non-BLAS function): iSMIN/iDMIN/iCMIN/iZMIN/iHMIN // Index of minimum value in a vector (non-BLAS function): iSMIN/iDMIN/iCMIN/iZMIN/iHMIN
template <typename T> template <typename T>
StatusCode Min(const size_t n, StatusCode Min(const size_t n,
CUdeviceptr imin_buffer, const size_t imin_offset, CUdeviceptr imin_buffer, const size_t imin_offset,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// ================================================================================================= // =================================================================================================
// BLAS level-2 (matrix-vector) routines // BLAS level-2 (matrix-vector) routines
@ -246,7 +246,7 @@ StatusCode Gemv(const Layout layout, const Transpose a_transpose,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta, const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// General banded matrix-vector multiplication: SGBMV/DGBMV/CGBMV/ZGBMV/HGBMV // General banded matrix-vector multiplication: SGBMV/DGBMV/CGBMV/ZGBMV/HGBMV
template <typename T> template <typename T>
@ -257,7 +257,7 @@ StatusCode Gbmv(const Layout layout, const Transpose a_transpose,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta, const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Hermitian matrix-vector multiplication: CHEMV/ZHEMV // Hermitian matrix-vector multiplication: CHEMV/ZHEMV
template <typename T> template <typename T>
@ -268,7 +268,7 @@ StatusCode Hemv(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta, const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Hermitian banded matrix-vector multiplication: CHBMV/ZHBMV // Hermitian banded matrix-vector multiplication: CHBMV/ZHBMV
template <typename T> template <typename T>
@ -279,7 +279,7 @@ StatusCode Hbmv(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta, const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Hermitian packed matrix-vector multiplication: CHPMV/ZHPMV // Hermitian packed matrix-vector multiplication: CHPMV/ZHPMV
template <typename T> template <typename T>
@ -290,7 +290,7 @@ StatusCode Hpmv(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta, const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Symmetric matrix-vector multiplication: SSYMV/DSYMV/HSYMV // Symmetric matrix-vector multiplication: SSYMV/DSYMV/HSYMV
template <typename T> template <typename T>
@ -301,7 +301,7 @@ StatusCode Symv(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta, const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Symmetric banded matrix-vector multiplication: SSBMV/DSBMV/HSBMV // Symmetric banded matrix-vector multiplication: SSBMV/DSBMV/HSBMV
template <typename T> template <typename T>
@ -312,7 +312,7 @@ StatusCode Sbmv(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta, const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Symmetric packed matrix-vector multiplication: SSPMV/DSPMV/HSPMV // Symmetric packed matrix-vector multiplication: SSPMV/DSPMV/HSPMV
template <typename T> template <typename T>
@ -323,7 +323,7 @@ StatusCode Spmv(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const T beta, const T beta,
CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Triangular matrix-vector multiplication: STRMV/DTRMV/CTRMV/ZTRMV/HTRMV // Triangular matrix-vector multiplication: STRMV/DTRMV/CTRMV/ZTRMV/HTRMV
template <typename T> template <typename T>
@ -331,7 +331,7 @@ StatusCode Trmv(const Layout layout, const Triangle triangle, const Transpose a_
const size_t n, const size_t n,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Triangular banded matrix-vector multiplication: STBMV/DTBMV/CTBMV/ZTBMV/HTBMV // Triangular banded matrix-vector multiplication: STBMV/DTBMV/CTBMV/ZTBMV/HTBMV
template <typename T> template <typename T>
@ -339,7 +339,7 @@ StatusCode Tbmv(const Layout layout, const Triangle triangle, const Transpose a_
const size_t n, const size_t k, const size_t n, const size_t k,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Triangular packed matrix-vector multiplication: STPMV/DTPMV/CTPMV/ZTPMV/HTPMV // Triangular packed matrix-vector multiplication: STPMV/DTPMV/CTPMV/ZTPMV/HTPMV
template <typename T> template <typename T>
@ -347,7 +347,7 @@ StatusCode Tpmv(const Layout layout, const Triangle triangle, const Transpose a_
const size_t n, const size_t n,
const CUdeviceptr ap_buffer, const size_t ap_offset, const CUdeviceptr ap_buffer, const size_t ap_offset,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Solves a triangular system of equations: STRSV/DTRSV/CTRSV/ZTRSV // Solves a triangular system of equations: STRSV/DTRSV/CTRSV/ZTRSV
template <typename T> template <typename T>
@ -355,7 +355,7 @@ StatusCode Trsv(const Layout layout, const Triangle triangle, const Transpose a_
const size_t n, const size_t n,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Solves a banded triangular system of equations: STBSV/DTBSV/CTBSV/ZTBSV // Solves a banded triangular system of equations: STBSV/DTBSV/CTBSV/ZTBSV
template <typename T> template <typename T>
@ -363,7 +363,7 @@ StatusCode Tbsv(const Layout layout, const Triangle triangle, const Transpose a_
const size_t n, const size_t k, const size_t n, const size_t k,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// Solves a packed triangular system of equations: STPSV/DTPSV/CTPSV/ZTPSV // Solves a packed triangular system of equations: STPSV/DTPSV/CTPSV/ZTPSV
template <typename T> template <typename T>
@ -371,7 +371,7 @@ StatusCode Tpsv(const Layout layout, const Triangle triangle, const Transpose a_
const size_t n, const size_t n,
const CUdeviceptr ap_buffer, const size_t ap_offset, const CUdeviceptr ap_buffer, const size_t ap_offset,
CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUstream* stream); const CUcontext context, const CUdevice device);
// General rank-1 matrix update: SGER/DGER/HGER // General rank-1 matrix update: SGER/DGER/HGER
template <typename T> template <typename T>
@ -381,7 +381,7 @@ StatusCode Ger(const Layout layout,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// General rank-1 complex matrix update: CGERU/ZGERU // General rank-1 complex matrix update: CGERU/ZGERU
template <typename T> template <typename T>
@ -391,7 +391,7 @@ StatusCode Geru(const Layout layout,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// General rank-1 complex conjugated matrix update: CGERC/ZGERC // General rank-1 complex conjugated matrix update: CGERC/ZGERC
template <typename T> template <typename T>
@ -401,7 +401,7 @@ StatusCode Gerc(const Layout layout,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// Hermitian rank-1 matrix update: CHER/ZHER // Hermitian rank-1 matrix update: CHER/ZHER
template <typename T> template <typename T>
@ -410,7 +410,7 @@ StatusCode Her(const Layout layout, const Triangle triangle,
const T alpha, const T alpha,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// Hermitian packed rank-1 matrix update: CHPR/ZHPR // Hermitian packed rank-1 matrix update: CHPR/ZHPR
template <typename T> template <typename T>
@ -419,7 +419,7 @@ StatusCode Hpr(const Layout layout, const Triangle triangle,
const T alpha, const T alpha,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr ap_buffer, const size_t ap_offset, CUdeviceptr ap_buffer, const size_t ap_offset,
CUstream* stream); const CUcontext context, const CUdevice device);
// Hermitian rank-2 matrix update: CHER2/ZHER2 // Hermitian rank-2 matrix update: CHER2/ZHER2
template <typename T> template <typename T>
@ -429,7 +429,7 @@ StatusCode Her2(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// Hermitian packed rank-2 matrix update: CHPR2/ZHPR2 // Hermitian packed rank-2 matrix update: CHPR2/ZHPR2
template <typename T> template <typename T>
@ -439,7 +439,7 @@ StatusCode Hpr2(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr ap_buffer, const size_t ap_offset, CUdeviceptr ap_buffer, const size_t ap_offset,
CUstream* stream); const CUcontext context, const CUdevice device);
// Symmetric rank-1 matrix update: SSYR/DSYR/HSYR // Symmetric rank-1 matrix update: SSYR/DSYR/HSYR
template <typename T> template <typename T>
@ -448,7 +448,7 @@ StatusCode Syr(const Layout layout, const Triangle triangle,
const T alpha, const T alpha,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// Symmetric packed rank-1 matrix update: SSPR/DSPR/HSPR // Symmetric packed rank-1 matrix update: SSPR/DSPR/HSPR
template <typename T> template <typename T>
@ -457,7 +457,7 @@ StatusCode Spr(const Layout layout, const Triangle triangle,
const T alpha, const T alpha,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
CUdeviceptr ap_buffer, const size_t ap_offset, CUdeviceptr ap_buffer, const size_t ap_offset,
CUstream* stream); const CUcontext context, const CUdevice device);
// Symmetric rank-2 matrix update: SSYR2/DSYR2/HSYR2 // Symmetric rank-2 matrix update: SSYR2/DSYR2/HSYR2
template <typename T> template <typename T>
@ -467,7 +467,7 @@ StatusCode Syr2(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// Symmetric packed rank-2 matrix update: SSPR2/DSPR2/HSPR2 // Symmetric packed rank-2 matrix update: SSPR2/DSPR2/HSPR2
template <typename T> template <typename T>
@ -477,7 +477,7 @@ StatusCode Spr2(const Layout layout, const Triangle triangle,
const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc, const CUdeviceptr x_buffer, const size_t x_offset, const size_t x_inc,
const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc, const CUdeviceptr y_buffer, const size_t y_offset, const size_t y_inc,
CUdeviceptr ap_buffer, const size_t ap_offset, CUdeviceptr ap_buffer, const size_t ap_offset,
CUstream* stream); const CUcontext context, const CUdevice device);
// ================================================================================================= // =================================================================================================
// BLAS level-3 (matrix-matrix) routines // BLAS level-3 (matrix-matrix) routines
@ -492,7 +492,7 @@ StatusCode Gemm(const Layout layout, const Transpose a_transpose, const Transpos
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
const T beta, const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM // Symmetric matrix-matrix multiplication: SSYMM/DSYMM/CSYMM/ZSYMM/HSYMM
template <typename T> template <typename T>
@ -503,7 +503,7 @@ StatusCode Symm(const Layout layout, const Side side, const Triangle triangle,
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
const T beta, const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// Hermitian matrix-matrix multiplication: CHEMM/ZHEMM // Hermitian matrix-matrix multiplication: CHEMM/ZHEMM
template <typename T> template <typename T>
@ -514,7 +514,7 @@ StatusCode Hemm(const Layout layout, const Side side, const Triangle triangle,
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
const T beta, const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK // Rank-K update of a symmetric matrix: SSYRK/DSYRK/CSYRK/ZSYRK/HSYRK
template <typename T> template <typename T>
@ -524,7 +524,7 @@ StatusCode Syrk(const Layout layout, const Triangle triangle, const Transpose a_
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
const T beta, const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// Rank-K update of a Hermitian matrix: CHERK/ZHERK // Rank-K update of a Hermitian matrix: CHERK/ZHERK
template <typename T> template <typename T>
@ -534,7 +534,7 @@ StatusCode Herk(const Layout layout, const Triangle triangle, const Transpose a_
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
const T beta, const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K // Rank-2K update of a symmetric matrix: SSYR2K/DSYR2K/CSYR2K/ZSYR2K/HSYR2K
template <typename T> template <typename T>
@ -545,7 +545,7 @@ StatusCode Syr2k(const Layout layout, const Triangle triangle, const Transpose a
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
const T beta, const T beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// Rank-2K update of a Hermitian matrix: CHER2K/ZHER2K // Rank-2K update of a Hermitian matrix: CHER2K/ZHER2K
template <typename T, typename U> template <typename T, typename U>
@ -556,7 +556,7 @@ StatusCode Her2k(const Layout layout, const Triangle triangle, const Transpose a
const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, const CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
const U beta, const U beta,
CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld, CUdeviceptr c_buffer, const size_t c_offset, const size_t c_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM // Triangular matrix-matrix multiplication: STRMM/DTRMM/CTRMM/ZTRMM/HTRMM
template <typename T> template <typename T>
@ -565,7 +565,7 @@ StatusCode Trmm(const Layout layout, const Side side, const Triangle triangle, c
const T alpha, const T alpha,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM // Solves a triangular system of equations: STRSM/DTRSM/CTRSM/ZTRSM
template <typename T> template <typename T>
@ -574,7 +574,7 @@ StatusCode Trsm(const Layout layout, const Side side, const Triangle triangle, c
const T alpha, const T alpha,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// ================================================================================================= // =================================================================================================
// Extra non-BLAS routines (level-X) // Extra non-BLAS routines (level-X)
@ -587,14 +587,14 @@ StatusCode Omatcopy(const Layout layout, const Transpose a_transpose,
const T alpha, const T alpha,
const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld, const CUdeviceptr a_buffer, const size_t a_offset, const size_t a_ld,
CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld, CUdeviceptr b_buffer, const size_t b_offset, const size_t b_ld,
CUstream* stream); const CUcontext context, const CUdevice device);
// Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL // Im2col function (non-BLAS function): SIM2COL/DIM2COL/CIM2COL/ZIM2COL/HIM2COL
template <typename T> template <typename T>
StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, StatusCode Im2col(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w,
const CUdeviceptr im_buffer, const size_t im_offset, const CUdeviceptr im_buffer, const size_t im_offset,
CUdeviceptr col_buffer, const size_t col_offset, CUdeviceptr col_buffer, const size_t col_offset,
CUstream* stream); const CUcontext context, const CUdevice device);
// Batched version of AXPY: SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED // Batched version of AXPY: SAXPYBATCHED/DAXPYBATCHED/CAXPYBATCHED/ZAXPYBATCHED/HAXPYBATCHED
template <typename T> template <typename T>
@ -603,7 +603,7 @@ StatusCode AxpyBatched(const size_t n,
const CUdeviceptr x_buffer, const size_t *x_offsets, const size_t x_inc, const CUdeviceptr x_buffer, const size_t *x_offsets, const size_t x_inc,
CUdeviceptr y_buffer, const size_t *y_offsets, const size_t y_inc, CUdeviceptr y_buffer, const size_t *y_offsets, const size_t y_inc,
const size_t batch_count, const size_t batch_count,
CUstream* stream); const CUcontext context, const CUdevice device);
// Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED // Batched version of GEMM: SGEMMBATCHED/DGEMMBATCHED/CGEMMBATCHED/ZGEMMBATCHED/HGEMMBATCHED
template <typename T> template <typename T>
@ -615,7 +615,7 @@ StatusCode GemmBatched(const Layout layout, const Transpose a_transpose, const T
const T *betas, const T *betas,
CUdeviceptr c_buffer, const size_t *c_offsets, const size_t c_ld, CUdeviceptr c_buffer, const size_t *c_offsets, const size_t c_ld,
const size_t batch_count, const size_t batch_count,
CUstream* stream); const CUcontext context, const CUdevice device);
// ================================================================================================= // =================================================================================================

View File

@ -50,7 +50,12 @@ def clblast_cc(routine, cuda=False):
if routine.implemented: if routine.implemented:
result += routine.routine_header_cpp(12, "", cuda) + " {" + NL result += routine.routine_header_cpp(12, "", cuda) + " {" + NL
result += " try {" + NL result += " try {" + NL
result += " auto queue_cpp = Queue(*queue);" + NL if cuda:
result += " const auto context_cpp = Context(context);" + NL
result += " const auto device_cpp = Device(device);" + NL
result += " auto queue_cpp = Queue(context_cpp, device_cpp);" + NL
else:
result += " auto queue_cpp = Queue(*queue);" + NL
result += " auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, event);" + NL result += " auto routine = X" + routine.plain_name() + "<" + routine.template.template + ">(queue_cpp, event);" + NL
if routine.batched: if routine.batched:
result += " " + (NL + " ").join(routine.batched_transform_to_cpp()) + NL result += " " + (NL + " ").join(routine.batched_transform_to_cpp()) + NL
@ -72,7 +77,7 @@ def clblast_cc(routine, cuda=False):
result += ("," + NL + indent2).join([a for a in arguments]) result += ("," + NL + indent2).join([a for a in arguments])
result += "," + NL + indent2 result += "," + NL + indent2
if cuda: if cuda:
result += "CUstream*" result += "const CUcontext, const CUdevice"
else: else:
result += "cl_command_queue*, cl_event*" result += "cl_command_queue*, cl_event*"
result += ");" + NL result += ");" + NL

View File

@ -813,7 +813,7 @@ class Routine:
result += (",\n" + indent).join([a for a in arguments]) result += (",\n" + indent).join([a for a in arguments])
result += ",\n" + indent result += ",\n" + indent
if cuda: if cuda:
result += "CUstream* stream" result += "const CUcontext context, const CUdevice device"
else: else:
result += "cl_command_queue* queue, cl_event* event" + default_event result += "cl_command_queue* queue, cl_event* event" + default_event
result += ")" result += ")"
@ -830,7 +830,7 @@ class Routine:
result += (",\n" + indent).join([a for a in arguments]) result += (",\n" + indent).join([a for a in arguments])
result += ",\n" + indent result += ",\n" + indent
if cuda: if cuda:
result += "CUstream* stream" result += "const CUcontext, const CUdevice"
else: else:
result += "cl_command_queue*, cl_event*" result += "cl_command_queue*, cl_event*"
result += ")" result += ")"

File diff suppressed because it is too large Load Diff

View File

@ -15,7 +15,7 @@
#ifndef CLBLAST_BUFFER_TEST_H_ #ifndef CLBLAST_BUFFER_TEST_H_
#define CLBLAST_BUFFER_TEST_H_ #define CLBLAST_BUFFER_TEST_H_
#include "utilities/utilities.hpp #include "utilities/utilities.hpp"
namespace clblast { namespace clblast {
// ================================================================================================= // =================================================================================================