mirror of
https://github.com/CNugteren/CLBlast.git
synced 2024-07-04 21:36:57 +02:00
Implemented proper im2col reference function and completed tests
This commit is contained in:
parent
777681dcbd
commit
132e62892d
|
@ -57,6 +57,7 @@ class TestBlas: public Tester<T,U> {
|
|||
static const std::vector<size_t> kMatrixVectorDims;
|
||||
static const std::vector<size_t> kBandSizes;
|
||||
static const std::vector<size_t> kPadSizes;
|
||||
static const std::vector<size_t> kDilationSizes;
|
||||
static const std::vector<size_t> kKernelSizes;
|
||||
static const std::vector<size_t> kBatchCounts;
|
||||
const std::vector<size_t> kOffsets;
|
||||
|
@ -132,7 +133,8 @@ template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kMatr
|
|||
// Value sets for the matrix-vector, band, batch, padding, dilation and kernel-size test options.
// Each static member is defined exactly once: kPadSizes previously had two definitions here
// ({ 0 } and { 0, 1 }), which is a redefinition; the { 0, 1 } set is the one that is kept.
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kMatrixVectorDims = { 61, 256 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kBandSizes = { 4, 19 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kBatchCounts = { 1, 3 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kPadSizes = { 0, 1 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kDilationSizes = { 1, 2 };
template <typename T, typename U> const std::vector<size_t> TestBlas<T,U>::kKernelSizes = { 1, 3 };
|
||||
|
||||
// Test settings for the invalid tests
|
||||
|
@ -282,7 +284,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na
|
|||
if (option == kArgImaxOffset) { imax_offsets = tester.kOffsets; }
|
||||
if (option == kArgAlpha) { alphas = tester.kAlphaValues; }
|
||||
if (option == kArgBeta) { betas = tester.kBetaValues; }
|
||||
if (option == kArgChannels) { channelss = tester.kMatrixDims; }
|
||||
if (option == kArgChannels) { channelss = tester.kKernelSizes; }
|
||||
if (option == kArgHeight) { heights = tester.kMatrixDims; }
|
||||
if (option == kArgWidth) { widths = tester.kMatrixDims; }
|
||||
if (option == kArgKernelH) { kernel_hs = tester.kKernelSizes; }
|
||||
|
@ -291,8 +293,8 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na
|
|||
if (option == kArgPadW) { pad_ws = tester.kPadSizes; }
|
||||
if (option == kArgStrideH) { stride_hs = tester.kKernelSizes; }
|
||||
if (option == kArgStrideW) { stride_ws = tester.kKernelSizes; }
|
||||
if (option == kArgDilationH) { dilation_hs = tester.kKernelSizes; }
|
||||
if (option == kArgDilationW) { dilation_ws = tester.kKernelSizes; }
|
||||
if (option == kArgDilationH) { dilation_hs = tester.kDilationSizes; }
|
||||
if (option == kArgDilationW) { dilation_ws = tester.kDilationSizes; }
|
||||
if (option == kArgBatchCount) { batch_counts = tester.kBatchCounts; }
|
||||
|
||||
if (option == kArgXOffset) { x_sizes = tester.kVecSizes; }
|
||||
|
|
|
@ -371,6 +371,12 @@ std::string Tester<T,U>::GetOptionsString(const Arguments<U> &args) {
|
|||
if (o == kArgWidth) { result += kArgWidth + equals + ToString(args.width) + " "; }
|
||||
if (o == kArgKernelH) { result += kArgKernelH + equals + ToString(args.kernel_h) + " "; }
|
||||
if (o == kArgKernelW) { result += kArgKernelW + equals + ToString(args.kernel_w) + " "; }
|
||||
if (o == kArgPadH) { result += kArgPadH + equals + ToString(args.pad_h) + " "; }
|
||||
if (o == kArgPadW) { result += kArgPadW + equals + ToString(args.pad_w) + " "; }
|
||||
if (o == kArgStrideH) { result += kArgStrideH + equals + ToString(args.stride_h) + " "; }
|
||||
if (o == kArgStrideW) { result += kArgStrideW + equals + ToString(args.stride_w) + " "; }
|
||||
if (o == kArgDilationH){ result += kArgDilationH + equals + ToString(args.dilation_h) + " "; }
|
||||
if (o == kArgDilationW){ result += kArgDilationW + equals + ToString(args.dilation_w) + " "; }
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -21,38 +21,6 @@
|
|||
namespace clblast {
|
||||
// =================================================================================================
|
||||
|
||||
template <typename T>
|
||||
StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host) {
|
||||
for (auto c_id = size_t{0}; c_id < args.channels; ++c_id) { // input channels
|
||||
for (auto kh_id = size_t{0}; kh_id < args.kernel_h; ++kh_id) { // kernel height
|
||||
for (auto kw_id = size_t{0}; kw_id < args.kernel_w; ++kw_id) { // kernel width
|
||||
for (auto h_id = size_t{0}; h_id < args.height; h_id += args.stride_h) { // image height
|
||||
for (auto w_id = size_t{0}; w_id < args.width; w_id += args.stride_w) { // image width
|
||||
|
||||
// Retrieves the input value
|
||||
const auto h_index = -args.pad_h + kh_id * args.dilation_h + h_id;
|
||||
const auto w_index = -args.pad_w + kw_id * args.dilation_w + w_id;
|
||||
auto val = ConstantZero<T>();
|
||||
if (h_index < args.height && w_index < args.width) {
|
||||
const auto input_index = w_index + args.width * (h_index + args.height * c_id);
|
||||
val = buffers_host.a_mat[input_index + args.a_offset];
|
||||
}
|
||||
|
||||
// Sets the output value
|
||||
const auto kernel_index = kw_id + args.kernel_w * kh_id;
|
||||
const auto patch_index = w_id + ((args.width / args.stride_w) * h_id + (args.height / args.stride_h) * c_id);
|
||||
const auto output_index = kernel_index + args.kernel_h * args.kernel_w * patch_index;
|
||||
buffers_host.b_mat[output_index + args.b_offset] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return StatusCode::kSuccess;
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// See comment at top of file for a description of the class
|
||||
template <typename T>
|
||||
class TestXim2col {
|
||||
|
@ -71,8 +39,20 @@ public:
|
|||
static std::vector<std::string> BuffersOut() { return {kBufMatB}; }
|
||||
|
||||
// Describes how to obtain the sizes of the buffers
|
||||
static size_t OutputHeight(const Arguments<T> &args) {
|
||||
const auto size = args.height + 2 * args.pad_h;
|
||||
const auto padding = args.dilation_h * (args.kernel_h - 1) + 1;
|
||||
if (size >= padding) { return (size - padding) / args.stride_h + 1; }
|
||||
return 1;
|
||||
}
|
||||
static size_t OutputWidth(const Arguments<T> &args) {
|
||||
const auto size = args.width + 2 * args.pad_w;
|
||||
const auto padding = args.dilation_w * (args.kernel_w - 1) + 1;
|
||||
if (size >= padding) { return (size - padding) / args.stride_w + 1; }
|
||||
return 1;
|
||||
}
|
||||
static size_t NumPatches(const Arguments<T> &args) {
|
||||
return (args.width / args.stride_w) * (args.height / args.stride_h) * args.channels;
|
||||
return OutputHeight(args) * OutputWidth(args) * args.channels;
|
||||
}
|
||||
static size_t GetSizeA(const Arguments<T> &args) {
|
||||
return args.height * args.width * args.channels + args.a_offset;
|
||||
|
@ -158,6 +138,42 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
template <typename T>
|
||||
StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host) {
|
||||
const auto output_h = TestXim2col<T>::OutputHeight(args);
|
||||
const auto output_w = TestXim2col<T>::OutputWidth(args);
|
||||
for (auto c_id = size_t{0}; c_id < args.channels; ++c_id) { // input channels
|
||||
for (auto kh_id = size_t{0}; kh_id < args.kernel_h; ++kh_id) { // kernel height
|
||||
for (auto kw_id = size_t{0}; kw_id < args.kernel_w; ++kw_id) { // kernel width
|
||||
for (auto h_id = size_t{0}; h_id < output_h; ++h_id) { // image height
|
||||
for (auto w_id = size_t{0}; w_id < output_w; ++w_id) { // image width
|
||||
|
||||
// Retrieves the input value
|
||||
const auto h_index = -args.pad_h + kh_id * args.dilation_h + args.stride_h * h_id;
|
||||
const auto w_index = -args.pad_w + kw_id * args.dilation_w + args.stride_w * w_id;
|
||||
auto val = ConstantZero<T>();
|
||||
if (h_index >= 0 && h_index < args.height &&
|
||||
w_index >= 0 && w_index < args.width) {
|
||||
const auto input_index = w_index + args.width * (h_index + args.height * c_id);
|
||||
val = buffers_host.a_mat[input_index + args.a_offset];
|
||||
}
|
||||
|
||||
// Sets the output value
|
||||
const auto kernel_index = kw_id + args.kernel_w * kh_id;
|
||||
const auto patch_index = w_id + output_w * h_id;
|
||||
const auto output_index = patch_index + kernel_index * output_w * output_h +
|
||||
c_id * output_w * output_h * args.kernel_h * args.kernel_w;
|
||||
buffers_host.b_mat[output_index + args.b_offset] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return StatusCode ::kSuccess;
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
} // namespace clblast
|
||||
|
||||
|
|
Loading…
Reference in a new issue