Removed complex numbers support for CONVGEMM
parent
5903820ba2
commit
2dd539f911
10
doc/api.md
10
doc/api.md
|
@ -3099,16 +3099,6 @@ CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, c
|
||||||
const cl_mem kernel_buffer, const size_t kernel_offset,
|
const cl_mem kernel_buffer, const size_t kernel_offset,
|
||||||
cl_mem result_buffer, const size_t result_offset,
|
cl_mem result_buffer, const size_t result_offset,
|
||||||
cl_command_queue* queue, cl_event* event)
|
cl_command_queue* queue, cl_event* event)
|
||||||
CLBlastStatusCode CLBlastCconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
|
||||||
const cl_mem im_buffer, const size_t im_offset,
|
|
||||||
const cl_mem kernel_buffer, const size_t kernel_offset,
|
|
||||||
cl_mem result_buffer, const size_t result_offset,
|
|
||||||
cl_command_queue* queue, cl_event* event)
|
|
||||||
CLBlastStatusCode CLBlastZconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
|
||||||
const cl_mem im_buffer, const size_t im_offset,
|
|
||||||
const cl_mem kernel_buffer, const size_t kernel_offset,
|
|
||||||
cl_mem result_buffer, const size_t result_offset,
|
|
||||||
cl_command_queue* queue, cl_event* event)
|
|
||||||
CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
||||||
const cl_mem im_buffer, const size_t im_offset,
|
const cl_mem im_buffer, const size_t im_offset,
|
||||||
const cl_mem kernel_buffer, const size_t kernel_offset,
|
const cl_mem kernel_buffer, const size_t kernel_offset,
|
||||||
|
|
|
@ -636,7 +636,7 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width
|
||||||
cl_mem col_buffer, const size_t col_offset,
|
cl_mem col_buffer, const size_t col_offset,
|
||||||
cl_command_queue* queue, cl_event* event = nullptr);
|
cl_command_queue* queue, cl_event* event = nullptr);
|
||||||
|
|
||||||
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/CCONVGEMM/ZCONVGEMM/HCONVGEMM
|
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
|
||||||
template <typename T>
|
template <typename T>
|
||||||
StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
||||||
const cl_mem im_buffer, const size_t im_offset,
|
const cl_mem im_buffer, const size_t im_offset,
|
||||||
|
|
|
@ -1410,7 +1410,7 @@ CLBlastStatusCode PUBLIC_API CLBlastHim2col(const size_t channels, const size_t
|
||||||
cl_mem col_buffer, const size_t col_offset,
|
cl_mem col_buffer, const size_t col_offset,
|
||||||
cl_command_queue* queue, cl_event* event);
|
cl_command_queue* queue, cl_event* event);
|
||||||
|
|
||||||
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/CCONVGEMM/ZCONVGEMM/HCONVGEMM
|
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
|
||||||
CLBlastStatusCode PUBLIC_API CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
CLBlastStatusCode PUBLIC_API CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
||||||
const cl_mem im_buffer, const size_t im_offset,
|
const cl_mem im_buffer, const size_t im_offset,
|
||||||
const cl_mem kernel_buffer, const size_t kernel_offset,
|
const cl_mem kernel_buffer, const size_t kernel_offset,
|
||||||
|
@ -1421,16 +1421,6 @@ CLBlastStatusCode PUBLIC_API CLBlastDconvgemm(const size_t channels, const size_
|
||||||
const cl_mem kernel_buffer, const size_t kernel_offset,
|
const cl_mem kernel_buffer, const size_t kernel_offset,
|
||||||
cl_mem result_buffer, const size_t result_offset,
|
cl_mem result_buffer, const size_t result_offset,
|
||||||
cl_command_queue* queue, cl_event* event);
|
cl_command_queue* queue, cl_event* event);
|
||||||
CLBlastStatusCode PUBLIC_API CLBlastCconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
|
||||||
const cl_mem im_buffer, const size_t im_offset,
|
|
||||||
const cl_mem kernel_buffer, const size_t kernel_offset,
|
|
||||||
cl_mem result_buffer, const size_t result_offset,
|
|
||||||
cl_command_queue* queue, cl_event* event);
|
|
||||||
CLBlastStatusCode PUBLIC_API CLBlastZconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
|
||||||
const cl_mem im_buffer, const size_t im_offset,
|
|
||||||
const cl_mem kernel_buffer, const size_t kernel_offset,
|
|
||||||
cl_mem result_buffer, const size_t result_offset,
|
|
||||||
cl_command_queue* queue, cl_event* event);
|
|
||||||
CLBlastStatusCode PUBLIC_API CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
CLBlastStatusCode PUBLIC_API CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
||||||
const cl_mem im_buffer, const size_t im_offset,
|
const cl_mem im_buffer, const size_t im_offset,
|
||||||
const cl_mem kernel_buffer, const size_t kernel_offset,
|
const cl_mem kernel_buffer, const size_t kernel_offset,
|
||||||
|
|
|
@ -608,7 +608,7 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width
|
||||||
CUdeviceptr col_buffer, const size_t col_offset,
|
CUdeviceptr col_buffer, const size_t col_offset,
|
||||||
const CUcontext context, const CUdevice device);
|
const CUcontext context, const CUdevice device);
|
||||||
|
|
||||||
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/CCONVGEMM/ZCONVGEMM/HCONVGEMM
|
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
|
||||||
template <typename T>
|
template <typename T>
|
||||||
StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
||||||
const CUdeviceptr im_buffer, const size_t im_offset,
|
const CUdeviceptr im_buffer, const size_t im_offset,
|
||||||
|
|
|
@ -181,7 +181,7 @@ ROUTINES = [
|
||||||
Routine(True, True, 0, False, "x", "had", T, [S,D,C,Z,H], ["n"], [], ["x","y"], ["z"], [xn,yn,zn], ["alpha","beta"], "", "Element-wise vector product (Hadamard)", "Performs the Hadamard element-wise product _z = alpha * x * y + beta * z_, in which _x_, _y_, and _z_ are vectors and _alpha_ and _beta_ are scalar constants.", []),
|
Routine(True, True, 0, False, "x", "had", T, [S,D,C,Z,H], ["n"], [], ["x","y"], ["z"], [xn,yn,zn], ["alpha","beta"], "", "Element-wise vector product (Hadamard)", "Performs the Hadamard element-wise product _z = alpha * x * y + beta * z_, in which _x_, _y_, and _z_ are vectors and _alpha_ and _beta_ are scalar constants.", []),
|
||||||
Routine(True, True, 0, False, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
|
Routine(True, True, 0, False, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
|
||||||
Routine(True, True, 0, False, "x", "im2col", T, [S,D,C,Z,H], im2col_constants, [], ["im"], ["col"], [im,col], [""], "", "Im2col function (non-BLAS function)", "Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is the output matrix.", []),
|
Routine(True, True, 0, False, "x", "im2col", T, [S,D,C,Z,H], im2col_constants, [], ["im"], ["col"], [im,col], [""], "", "Im2col function (non-BLAS function)", "Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is the output matrix.", []),
|
||||||
Routine(True, True, 0, False, "x", "convgemm", T, [S,D,C,Z,H], convgemm_constants, [], ["im","kernel"], ["result"], [imb,kernel,result],[""], "", "Batched convolution as GEMM (non-BLAS function)", "Integrates im2col and GEMM for batched 3D convolution, in which _im_ is the 4D input tensor (NCHW - batch-channelin-height-width), _kernel_ the 4D kernel weights tensor (KCHW - channelout-channelin-height-width), and _result_ the 4D output tensor (NCHW - batch-channelout-height-width).", []),
|
Routine(True, True, 0, False, "x", "convgemm", T, [S,D,H], convgemm_constants, [], ["im","kernel"], ["result"], [imb,kernel,result],[""], "", "Batched convolution as GEMM (non-BLAS function)", "Integrates im2col and GEMM for batched 3D convolution, in which _im_ is the 4D input tensor (NCHW - batch-channelin-height-width), _kernel_ the 4D kernel weights tensor (KCHW - channelout-channelin-height-width), and _result_ the 4D output tensor (NCHW - batch-channelout-height-width).", []),
|
||||||
# Batched routines:
|
# Batched routines:
|
||||||
Routine(True, True, 1, False, "x", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Batched version of AXPY", "As AXPY, but multiple operations are batched together for better performance.", []),
|
Routine(True, True, 1, False, "x", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Batched version of AXPY", "As AXPY, but multiple operations are batched together for better performance.", []),
|
||||||
Routine(True, True, 1, False, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "Batched version of GEMM", "As GEMM, but multiple operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
|
Routine(True, True, 1, False, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "Batched version of GEMM", "As GEMM, but multiple operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
|
||||||
|
|
|
@ -2252,7 +2252,7 @@ template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const si
|
||||||
cl_mem, const size_t,
|
cl_mem, const size_t,
|
||||||
cl_command_queue*, cl_event*);
|
cl_command_queue*, cl_event*);
|
||||||
|
|
||||||
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/CCONVGEMM/ZCONVGEMM/HCONVGEMM
|
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
|
||||||
template <typename T>
|
template <typename T>
|
||||||
StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
||||||
const cl_mem im_buffer, const size_t im_offset,
|
const cl_mem im_buffer, const size_t im_offset,
|
||||||
|
@ -2279,16 +2279,6 @@ template StatusCode PUBLIC_API Convgemm<double>(const size_t, const size_t, cons
|
||||||
const cl_mem, const size_t,
|
const cl_mem, const size_t,
|
||||||
cl_mem, const size_t,
|
cl_mem, const size_t,
|
||||||
cl_command_queue*, cl_event*);
|
cl_command_queue*, cl_event*);
|
||||||
template StatusCode PUBLIC_API Convgemm<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
|
|
||||||
const cl_mem, const size_t,
|
|
||||||
const cl_mem, const size_t,
|
|
||||||
cl_mem, const size_t,
|
|
||||||
cl_command_queue*, cl_event*);
|
|
||||||
template StatusCode PUBLIC_API Convgemm<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
|
|
||||||
const cl_mem, const size_t,
|
|
||||||
const cl_mem, const size_t,
|
|
||||||
cl_mem, const size_t,
|
|
||||||
cl_command_queue*, cl_event*);
|
|
||||||
template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
|
template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
|
||||||
const cl_mem, const size_t,
|
const cl_mem, const size_t,
|
||||||
const cl_mem, const size_t,
|
const cl_mem, const size_t,
|
||||||
|
|
|
@ -3710,36 +3710,6 @@ CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, c
|
||||||
);
|
);
|
||||||
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
|
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
|
||||||
}
|
}
|
||||||
CLBlastStatusCode CLBlastCconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
|
||||||
const cl_mem im_buffer, const size_t im_offset,
|
|
||||||
const cl_mem kernel_buffer, const size_t kernel_offset,
|
|
||||||
cl_mem result_buffer, const size_t result_offset,
|
|
||||||
cl_command_queue* queue, cl_event* event) {
|
|
||||||
try {
|
|
||||||
return static_cast<CLBlastStatusCode>(
|
|
||||||
clblast::Convgemm<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
|
|
||||||
im_buffer, im_offset,
|
|
||||||
kernel_buffer, kernel_offset,
|
|
||||||
result_buffer, result_offset,
|
|
||||||
queue, event)
|
|
||||||
);
|
|
||||||
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
|
|
||||||
}
|
|
||||||
CLBlastStatusCode CLBlastZconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
|
||||||
const cl_mem im_buffer, const size_t im_offset,
|
|
||||||
const cl_mem kernel_buffer, const size_t kernel_offset,
|
|
||||||
cl_mem result_buffer, const size_t result_offset,
|
|
||||||
cl_command_queue* queue, cl_event* event) {
|
|
||||||
try {
|
|
||||||
return static_cast<CLBlastStatusCode>(
|
|
||||||
clblast::Convgemm<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
|
|
||||||
im_buffer, im_offset,
|
|
||||||
kernel_buffer, kernel_offset,
|
|
||||||
result_buffer, result_offset,
|
|
||||||
queue, event)
|
|
||||||
);
|
|
||||||
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
|
|
||||||
}
|
|
||||||
CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
||||||
const cl_mem im_buffer, const size_t im_offset,
|
const cl_mem im_buffer, const size_t im_offset,
|
||||||
const cl_mem kernel_buffer, const size_t kernel_offset,
|
const cl_mem kernel_buffer, const size_t kernel_offset,
|
||||||
|
|
|
@ -2350,7 +2350,7 @@ template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const si
|
||||||
CUdeviceptr, const size_t,
|
CUdeviceptr, const size_t,
|
||||||
const CUcontext, const CUdevice);
|
const CUcontext, const CUdevice);
|
||||||
|
|
||||||
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/CCONVGEMM/ZCONVGEMM/HCONVGEMM
|
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
|
||||||
template <typename T>
|
template <typename T>
|
||||||
StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
|
||||||
const CUdeviceptr im_buffer, const size_t im_offset,
|
const CUdeviceptr im_buffer, const size_t im_offset,
|
||||||
|
@ -2379,16 +2379,6 @@ template StatusCode PUBLIC_API Convgemm<double>(const size_t, const size_t, cons
|
||||||
const CUdeviceptr, const size_t,
|
const CUdeviceptr, const size_t,
|
||||||
CUdeviceptr, const size_t,
|
CUdeviceptr, const size_t,
|
||||||
const CUcontext, const CUdevice);
|
const CUcontext, const CUdevice);
|
||||||
template StatusCode PUBLIC_API Convgemm<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
|
|
||||||
const CUdeviceptr, const size_t,
|
|
||||||
const CUdeviceptr, const size_t,
|
|
||||||
CUdeviceptr, const size_t,
|
|
||||||
const CUcontext, const CUdevice);
|
|
||||||
template StatusCode PUBLIC_API Convgemm<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
|
|
||||||
const CUdeviceptr, const size_t,
|
|
||||||
const CUdeviceptr, const size_t,
|
|
||||||
CUdeviceptr, const size_t,
|
|
||||||
const CUcontext, const CUdevice);
|
|
||||||
template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
|
template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
|
||||||
const CUdeviceptr, const size_t,
|
const CUdeviceptr, const size_t,
|
||||||
const CUdeviceptr, const size_t,
|
const CUdeviceptr, const size_t,
|
||||||
|
|
|
@ -17,8 +17,6 @@ int main(int argc, char *argv[]) {
|
||||||
auto errors = size_t{0};
|
auto errors = size_t{0};
|
||||||
errors += clblast::RunTests<clblast::TestXconvgemm<float>, float, float>(argc, argv, false, "SCONVGEMM");
|
errors += clblast::RunTests<clblast::TestXconvgemm<float>, float, float>(argc, argv, false, "SCONVGEMM");
|
||||||
errors += clblast::RunTests<clblast::TestXconvgemm<double>, double, double>(argc, argv, true, "DCONVGEMM");
|
errors += clblast::RunTests<clblast::TestXconvgemm<double>, double, double>(argc, argv, true, "DCONVGEMM");
|
||||||
errors += clblast::RunTests<clblast::TestXconvgemm<clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CCONVGEMM");
|
|
||||||
errors += clblast::RunTests<clblast::TestXconvgemm<clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZCONVGEMM");
|
|
||||||
errors += clblast::RunTests<clblast::TestXconvgemm<clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HCONVGEMM");
|
errors += clblast::RunTests<clblast::TestXconvgemm<clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HCONVGEMM");
|
||||||
if (errors > 0) { return 1; } else { return 0; }
|
if (errors > 0) { return 1; } else { return 0; }
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,10 +22,8 @@ int main(int argc, char *argv[]) {
|
||||||
clblast::RunClient<clblast::TestXconvgemm<float>, float, float>(argc, argv); break;
|
clblast::RunClient<clblast::TestXconvgemm<float>, float, float>(argc, argv); break;
|
||||||
case clblast::Precision::kDouble:
|
case clblast::Precision::kDouble:
|
||||||
clblast::RunClient<clblast::TestXconvgemm<double>, double, double>(argc, argv); break;
|
clblast::RunClient<clblast::TestXconvgemm<double>, double, double>(argc, argv); break;
|
||||||
case clblast::Precision::kComplexSingle:
|
case clblast::Precision::kComplexSingle: throw std::runtime_error("Unsupported precision mode");
|
||||||
clblast::RunClient<clblast::TestXconvgemm<clblast::float2>, clblast::float2, clblast::float2>(argc, argv); break;
|
case clblast::Precision::kComplexDouble: throw std::runtime_error("Unsupported precision mode");
|
||||||
case clblast::Precision::kComplexDouble:
|
|
||||||
clblast::RunClient<clblast::TestXconvgemm<clblast::double2>, clblast::double2, clblast::double2>(argc, argv); break;
|
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue