diff --git a/test/correctness/testblas.hpp b/test/correctness/testblas.hpp index 1c0cf9e3..4e02fd28 100644 --- a/test/correctness/testblas.hpp +++ b/test/correctness/testblas.hpp @@ -57,6 +57,7 @@ class TestBlas: public Tester { static const std::vector kMatrixVectorDims; static const std::vector kBandSizes; static const std::vector kPadSizes; + static const std::vector kDilationSizes; static const std::vector kKernelSizes; static const std::vector kBatchCounts; const std::vector kOffsets; @@ -132,7 +133,8 @@ template const std::vector TestBlas::kMatr template const std::vector TestBlas::kMatrixVectorDims = { 61, 256 }; template const std::vector TestBlas::kBandSizes = { 4, 19 }; template const std::vector TestBlas::kBatchCounts = { 1, 3 }; -template const std::vector TestBlas::kPadSizes = { 0 }; +template const std::vector TestBlas::kPadSizes = { 0, 1 }; +template const std::vector TestBlas::kDilationSizes = { 1, 2 }; template const std::vector TestBlas::kKernelSizes = { 1, 3 }; // Test settings for the invalid tests @@ -282,7 +284,7 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na if (option == kArgImaxOffset) { imax_offsets = tester.kOffsets; } if (option == kArgAlpha) { alphas = tester.kAlphaValues; } if (option == kArgBeta) { betas = tester.kBetaValues; } - if (option == kArgChannels) { channelss = tester.kMatrixDims; } + if (option == kArgChannels) { channelss = tester.kKernelSizes; } if (option == kArgHeight) { heights = tester.kMatrixDims; } if (option == kArgWidth) { widths = tester.kMatrixDims; } if (option == kArgKernelH) { kernel_hs = tester.kKernelSizes; } @@ -291,8 +293,8 @@ size_t RunTests(int argc, char *argv[], const bool silent, const std::string &na if (option == kArgPadW) { pad_ws = tester.kPadSizes; } if (option == kArgStrideH) { stride_hs = tester.kKernelSizes; } if (option == kArgStrideW) { stride_ws = tester.kKernelSizes; } - if (option == kArgDilationH) { dilation_hs = tester.kKernelSizes; } - if (option == kArgDilationW) { dilation_ws = tester.kKernelSizes; } + if (option == kArgDilationH) { dilation_hs = tester.kDilationSizes; } + if (option == kArgDilationW) { dilation_ws = tester.kDilationSizes; } if (option == kArgBatchCount) { batch_counts = tester.kBatchCounts; } if (option == kArgXOffset) { x_sizes = tester.kVecSizes; } diff --git a/test/correctness/tester.cpp b/test/correctness/tester.cpp index 648aef6e..9dbd8934 100644 --- a/test/correctness/tester.cpp +++ b/test/correctness/tester.cpp @@ -371,6 +371,12 @@ std::string Tester::GetOptionsString(const Arguments &args) { if (o == kArgWidth) { result += kArgWidth + equals + ToString(args.width) + " "; } if (o == kArgKernelH) { result += kArgKernelH + equals + ToString(args.kernel_h) + " "; } if (o == kArgKernelW) { result += kArgKernelW + equals + ToString(args.kernel_w) + " "; } + if (o == kArgPadH) { result += kArgPadH + equals + ToString(args.pad_h) + " "; } + if (o == kArgPadW) { result += kArgPadW + equals + ToString(args.pad_w) + " "; } + if (o == kArgStrideH) { result += kArgStrideH + equals + ToString(args.stride_h) + " "; } + if (o == kArgStrideW) { result += kArgStrideW + equals + ToString(args.stride_w) + " "; } + if (o == kArgDilationH){ result += kArgDilationH + equals + ToString(args.dilation_h) + " "; } + if (o == kArgDilationW){ result += kArgDilationW + equals + ToString(args.dilation_w) + " "; } } return result; } diff --git a/test/routines/levelx/xim2col.hpp b/test/routines/levelx/xim2col.hpp index e5bc56cd..e6aefd9e 100644 --- a/test/routines/levelx/xim2col.hpp +++ b/test/routines/levelx/xim2col.hpp @@ -21,38 +21,6 @@ namespace clblast { // ================================================================================================= -template -StatusCode RunReference(const Arguments &args, BuffersHost &buffers_host) { - for (auto c_id = size_t{0}; c_id < args.channels; ++c_id) { // input channels - for (auto kh_id = size_t{0}; kh_id < args.kernel_h; ++kh_id) { // kernel height - for (auto kw_id = size_t{0}; kw_id < args.kernel_w; ++kw_id) { // kernel width - for (auto h_id = size_t{0}; h_id < args.height; h_id += args.stride_h) { // image height - for (auto w_id = size_t{0}; w_id < args.width; w_id += args.stride_w) { // image width - - // Retrieves the input value - const auto h_index = -args.pad_h + kh_id * args.dilation_h + h_id; - const auto w_index = -args.pad_w + kw_id * args.dilation_w + w_id; - auto val = ConstantZero(); - if (h_index < args.height && w_index < args.width) { - const auto input_index = w_index + args.width * (h_index + args.height * c_id); - val = buffers_host.a_mat[input_index + args.a_offset]; - } - - // Sets the output value - const auto kernel_index = kw_id + args.kernel_w * kh_id; - const auto patch_index = w_id + ((args.width / args.stride_w) * h_id + (args.height / args.stride_h) * c_id); - const auto output_index = kernel_index + args.kernel_h * args.kernel_w * patch_index; - buffers_host.b_mat[output_index + args.b_offset] = val; - } - } - } - } - } - return StatusCode::kSuccess; -} - -// ================================================================================================= - // See comment at top of file for a description of the class template class TestXim2col { @@ -71,8 +39,20 @@ public: static std::vector BuffersOut() { return {kBufMatB}; } // Describes how to obtain the sizes of the buffers + static size_t OutputHeight(const Arguments &args) { + const auto size = args.height + 2 * args.pad_h; + const auto padding = args.dilation_h * (args.kernel_h - 1) + 1; + if (size >= padding) { return (size - padding) / args.stride_h + 1; } + return 1; + } + static size_t OutputWidth(const Arguments &args) { + const auto size = args.width + 2 * args.pad_w; + const auto padding = args.dilation_w * (args.kernel_w - 1) + 1; + if (size >= padding) { return (size - padding) / args.stride_w + 1; } + return 1; + } static size_t NumPatches(const Arguments &args) { - return (args.width / args.stride_w) * (args.height / args.stride_h) * args.channels; + return OutputHeight(args) * OutputWidth(args) * args.channels; } static size_t GetSizeA(const Arguments &args) { return args.height * args.width * args.channels + args.a_offset; @@ -158,6 +138,42 @@ public: } }; +// ================================================================================================= + +template +StatusCode RunReference(const Arguments &args, BuffersHost &buffers_host) { + const auto output_h = TestXim2col::OutputHeight(args); + const auto output_w = TestXim2col::OutputWidth(args); + for (auto c_id = size_t{0}; c_id < args.channels; ++c_id) { // input channels + for (auto kh_id = size_t{0}; kh_id < args.kernel_h; ++kh_id) { // kernel height + for (auto kw_id = size_t{0}; kw_id < args.kernel_w; ++kw_id) { // kernel width + for (auto h_id = size_t{0}; h_id < output_h; ++h_id) { // image height + for (auto w_id = size_t{0}; w_id < output_w; ++w_id) { // image width + + // Retrieves the input value + const auto h_index = -args.pad_h + kh_id * args.dilation_h + args.stride_h * h_id; + const auto w_index = -args.pad_w + kw_id * args.dilation_w + args.stride_w * w_id; + auto val = ConstantZero(); + if (h_index >= 0 && h_index < args.height && + w_index >= 0 && w_index < args.width) { + const auto input_index = w_index + args.width * (h_index + args.height * c_id); + val = buffers_host.a_mat[input_index + args.a_offset]; + } + + // Sets the output value + const auto kernel_index = kw_id + args.kernel_w * kh_id; + const auto patch_index = w_id + output_w * h_id; + const auto output_index = patch_index + kernel_index * output_w * output_h + + c_id * output_w * output_h * args.kernel_h * args.kernel_w; + buffers_host.b_mat[output_index + args.b_offset] = val; + } + } + } + } + } + return StatusCode ::kSuccess; +} + // ================================================================================================= } // namespace clblast