Removed complex numbers support for CONVGEMM

pull/319/head
Cedric Nugteren 2018-07-29 10:37:14 +02:00
parent 5903820ba2
commit 2dd539f911
10 changed files with 8 additions and 82 deletions

View File

@ -3099,16 +3099,6 @@ CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, c
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastCconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastZconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event)
CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,

View File

@ -636,7 +636,7 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event = nullptr);
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/CCONVGEMM/ZCONVGEMM/HCONVGEMM
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
template <typename T>
StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,

View File

@ -1410,7 +1410,7 @@ CLBlastStatusCode PUBLIC_API CLBlastHim2col(const size_t channels, const size_t
cl_mem col_buffer, const size_t col_offset,
cl_command_queue* queue, cl_event* event);
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/CCONVGEMM/ZCONVGEMM/HCONVGEMM
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
CLBlastStatusCode PUBLIC_API CLBlastSconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
@ -1421,16 +1421,6 @@ CLBlastStatusCode PUBLIC_API CLBlastDconvgemm(const size_t channels, const size_
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastCconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastZconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event);
CLBlastStatusCode PUBLIC_API CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,

View File

@ -608,7 +608,7 @@ StatusCode Im2col(const size_t channels, const size_t height, const size_t width
CUdeviceptr col_buffer, const size_t col_offset,
const CUcontext context, const CUdevice device);
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/CCONVGEMM/ZCONVGEMM/HCONVGEMM
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
template <typename T>
StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const CUdeviceptr im_buffer, const size_t im_offset,

View File

@ -181,7 +181,7 @@ ROUTINES = [
Routine(True, True, 0, False, "x", "had", T, [S,D,C,Z,H], ["n"], [], ["x","y"], ["z"], [xn,yn,zn], ["alpha","beta"], "", "Element-wise vector product (Hadamard)", "Performs the Hadamard element-wise product _z = alpha * x * y + beta * z_, in which _x_, _y_, and _z_ are vectors and _alpha_ and _beta_ are scalar constants.", []),
Routine(True, True, 0, False, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], [amn,bnma], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
Routine(True, True, 0, False, "x", "im2col", T, [S,D,C,Z,H], im2col_constants, [], ["im"], ["col"], [im,col], [""], "", "Im2col function (non-BLAS function)", "Performs the im2col algorithm, in which _im_ is the input matrix and _col_ is the output matrix.", []),
Routine(True, True, 0, False, "x", "convgemm", T, [S,D,C,Z,H], convgemm_constants, [], ["im","kernel"], ["result"], [imb,kernel,result],[""], "", "Batched convolution as GEMM (non-BLAS function)", "Integrates im2col and GEMM for batched 3D convolution, in which _im_ is the 4D input tensor (NCHW - batch-channelin-height-width), _kernel_ the 4D kernel weights tensor (KCHW - channelout-channelin-height-width), and _result_ the 4D output tensor (NCHW - batch-channelout-height-width).", []),
Routine(True, True, 0, False, "x", "convgemm", T, [S,D,H], convgemm_constants, [], ["im","kernel"], ["result"], [imb,kernel,result],[""], "", "Batched convolution as GEMM (non-BLAS function)", "Integrates im2col and GEMM for batched 3D convolution, in which _im_ is the 4D input tensor (NCHW - batch-channelin-height-width), _kernel_ the 4D kernel weights tensor (KCHW - channelout-channelin-height-width), and _result_ the 4D output tensor (NCHW - batch-channelout-height-width).", []),
# Batched routines:
Routine(True, True, 1, False, "x", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [xn,yn], ["alpha"], "", "Batched version of AXPY", "As AXPY, but multiple operations are batched together for better performance.", []),
Routine(True, True, 1, False, "x", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], [amk,bkn,cmn], ["alpha","beta"], "", "Batched version of GEMM", "As GEMM, but multiple operations are batched together for better performance.", [ald_transa_m_k, bld_transb_k_n, cld_m]),

View File

@ -2252,7 +2252,7 @@ template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const si
cl_mem, const size_t,
cl_command_queue*, cl_event*);
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/CCONVGEMM/ZCONVGEMM/HCONVGEMM
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
template <typename T>
StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
@ -2279,16 +2279,6 @@ template StatusCode PUBLIC_API Convgemm<double>(const size_t, const size_t, cons
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API Convgemm<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API Convgemm<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
const cl_mem, const size_t,
cl_mem, const size_t,
cl_command_queue*, cl_event*);
template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const cl_mem, const size_t,
const cl_mem, const size_t,

View File

@ -3710,36 +3710,6 @@ CLBlastStatusCode CLBlastDconvgemm(const size_t channels, const size_t height, c
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastCconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
clblast::Convgemm<float2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
im_buffer, im_offset,
kernel_buffer, kernel_offset,
result_buffer, result_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastZconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,
cl_mem result_buffer, const size_t result_offset,
cl_command_queue* queue, cl_event* event) {
try {
return static_cast<CLBlastStatusCode>(
clblast::Convgemm<double2>(channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, num_kernels, batch_count,
im_buffer, im_offset,
kernel_buffer, kernel_offset,
result_buffer, result_offset,
queue, event)
);
} catch (...) { return static_cast<CLBlastStatusCode>(clblast::DispatchExceptionForC()); }
}
CLBlastStatusCode CLBlastHconvgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const cl_mem im_buffer, const size_t im_offset,
const cl_mem kernel_buffer, const size_t kernel_offset,

View File

@ -2350,7 +2350,7 @@ template StatusCode PUBLIC_API Im2col<half>(const size_t, const size_t, const si
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/CCONVGEMM/ZCONVGEMM/HCONVGEMM
// Batched convolution as GEMM (non-BLAS function): SCONVGEMM/DCONVGEMM/HCONVGEMM
template <typename T>
StatusCode Convgemm(const size_t channels, const size_t height, const size_t width, const size_t kernel_h, const size_t kernel_w, const size_t pad_h, const size_t pad_w, const size_t stride_h, const size_t stride_w, const size_t dilation_h, const size_t dilation_w, const size_t num_kernels, const size_t batch_count,
const CUdeviceptr im_buffer, const size_t im_offset,
@ -2379,16 +2379,6 @@ template StatusCode PUBLIC_API Convgemm<double>(const size_t, const size_t, cons
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Convgemm<float2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Convgemm<double2>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
const CUdeviceptr, const size_t,
CUdeviceptr, const size_t,
const CUcontext, const CUdevice);
template StatusCode PUBLIC_API Convgemm<half>(const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t, const size_t,
const CUdeviceptr, const size_t,
const CUdeviceptr, const size_t,

View File

@ -17,8 +17,6 @@ int main(int argc, char *argv[]) {
auto errors = size_t{0};
errors += clblast::RunTests<clblast::TestXconvgemm<float>, float, float>(argc, argv, false, "SCONVGEMM");
errors += clblast::RunTests<clblast::TestXconvgemm<double>, double, double>(argc, argv, true, "DCONVGEMM");
errors += clblast::RunTests<clblast::TestXconvgemm<clblast::float2>, clblast::float2, clblast::float2>(argc, argv, true, "CCONVGEMM");
errors += clblast::RunTests<clblast::TestXconvgemm<clblast::double2>, clblast::double2, clblast::double2>(argc, argv, true, "ZCONVGEMM");
errors += clblast::RunTests<clblast::TestXconvgemm<clblast::half>, clblast::half, clblast::half>(argc, argv, true, "HCONVGEMM");
if (errors > 0) { return 1; } else { return 0; }
}

View File

@ -22,10 +22,8 @@ int main(int argc, char *argv[]) {
clblast::RunClient<clblast::TestXconvgemm<float>, float, float>(argc, argv); break;
case clblast::Precision::kDouble:
clblast::RunClient<clblast::TestXconvgemm<double>, double, double>(argc, argv); break;
case clblast::Precision::kComplexSingle:
clblast::RunClient<clblast::TestXconvgemm<clblast::float2>, clblast::float2, clblast::float2>(argc, argv); break;
case clblast::Precision::kComplexDouble:
clblast::RunClient<clblast::TestXconvgemm<clblast::double2>, clblast::double2, clblast::double2>(argc, argv); break;
case clblast::Precision::kComplexSingle: throw std::runtime_error("Unsupported precision mode");
case clblast::Precision::kComplexDouble: throw std::runtime_error("Unsupported precision mode");
}
return 0;
}