Merge pull request #330 from vbkaisetsu/CLBlast-270-col2im

Add col2im function
pull/331/head
Cedric Nugteren 2018-10-31 10:37:21 +01:00 committed by GitHub
commit 4215bbe62a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 109 additions and 33 deletions

View File

@ -24,6 +24,10 @@ R"(
// =================================================================================================
inline int grid_ceil(const int x, const int step) {
return x > 0 ? ((x - 1) / step + 1) * step : x / step * step;
}
__kernel __attribute__((reqd_work_group_size(COPY_DIMX, COPY_DIMY, 1)))
void col2im(const int input_h, const int input_w, const int channels,
const int output_h, const int output_w,
@ -31,38 +35,54 @@ void col2im(const int input_h, const int input_w, const int channels,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int stride_bez_h, const int stride_bez_w,
const int dilation_bez_h, const int dilation_bez_w,
const int gcd_h, const int gcd_w,
const __global real* restrict col_buffer, const int col_offset,
__global real *im_buffer, const int im_offset) {
const int x_x = get_global_id(0) + pad_w;
const int x_y = ((int) get_global_id(1)) % input_h + pad_h;
const int channel = ((int) get_global_id(1)) / input_h;
const int kernel_extent_w = (kernel_w - 1) * dilation_w + 1;
const int kernel_extent_h = (kernel_h - 1) * dilation_h + 1;
const int col_channel_shift = channel * kernel_w * kernel_h * output_h * output_w + col_offset;
const int x_channel_shift = channel * input_h * input_w + im_offset;
const int t_y_begin = (x_y < kernel_extent_h) ? 0 : (x_y - kernel_extent_h) / stride_h + 1;
const int t_y_end = min(x_y / stride_h + 1, output_h);
const int t_x_begin = (x_x < kernel_extent_w) ? 0 : (x_x - kernel_extent_w) / stride_w + 1;
const int t_x_end = min(x_x / stride_w + 1, output_w);
__global real* im_buffer, const int im_offset) {
if (x_x < input_w + pad_w && channel < channels) {
const int input_h_scaled = (input_h - 1) / gcd_h + 1;
// Thread IDs
const int gcd_scale_w = get_global_id(0) + (pad_w - 1) / gcd_w + 1;
const int gcd_scale_h = ((int) get_global_id(1)) % input_h_scaled + (pad_h - 1) / gcd_h + 1;
const int c_id = ((int) get_global_id(1)) / input_h_scaled;
const int w_index = gcd_scale_w * gcd_w - pad_w;
const int h_index = gcd_scale_h * gcd_h - pad_h;
const int th_step = stride_h * dilation_h / gcd_h;
const int th_begin = grid_ceil(max(-stride_bez_h * gcd_scale_h * stride_h,
(dilation_bez_h * gcd_scale_h - kernel_h + 1) * dilation_h),
th_step);
const int th_end = min((output_h - stride_bez_h * gcd_scale_h) * stride_h,
(dilation_bez_h * gcd_scale_h + 1) * dilation_h);
const int tw_step = stride_w * dilation_w / gcd_w;
const int tw_begin = grid_ceil(max(-stride_bez_w * gcd_scale_w * stride_w,
(dilation_bez_w * gcd_scale_w - kernel_w + 1) * dilation_w),
tw_step);
const int tw_end = min((output_w - stride_bez_w * gcd_scale_w) * stride_w,
(dilation_bez_w * gcd_scale_w + 1) * dilation_w);
if (w_index < input_w && c_id < channels) {
real val;
SetToZero(val);
for (int t_y = t_y_begin; t_y < t_y_end; ++t_y) {
for (int t_x = t_x_begin; t_x < t_x_end; ++t_x) {
int w_y = x_y - t_y * stride_h;
int w_x = x_x - t_x * stride_w;
if (w_y % dilation_h == 0 && w_x % dilation_w == 0) {
w_y /= dilation_h;
w_x /= dilation_w;
val += col_buffer[col_channel_shift
+ (w_x + w_y * kernel_w) * output_h * output_w
+ t_y * output_w
+ t_x];
}
for (int th = th_begin; th < th_end; th += th_step) {
for (int tw = tw_begin; tw < tw_end; tw += tw_step) {
const int kh_id = -th / dilation_h + dilation_bez_h * gcd_scale_h;
const int kw_id = -tw / dilation_w + dilation_bez_w * gcd_scale_w;
const int h_id = th / stride_h + stride_bez_h * gcd_scale_h;
const int w_id = tw / stride_w + stride_bez_w * gcd_scale_w;
const int kernel_index = kw_id + kernel_w * kh_id;
const int patch_index = w_id + output_w * h_id;
const int output_index = patch_index + kernel_index * output_w * output_h +
c_id * output_w * output_h * kernel_h * kernel_w;
Add(val, val, col_buffer[output_index + col_offset]);
}
}
im_buffer[x_channel_shift + (x_y - pad_h) * input_w + x_x - pad_w] = val;
// Sets the input value
const int input_index = w_index + input_w * (h_index + input_h * c_id);
im_buffer[input_index + im_offset] = val;
}
}

View File

@ -49,6 +49,15 @@ void Xcol2im<T>::DoCol2im(const size_t channels, const size_t height, const size
const auto padding_w = dilation_w * (kernel_w - 1) + 1;
const auto col_w = (size_w >= padding_w) ? (size_w - padding_w) / stride_w + 1 : 1;
int stride_bez_h = 0;
int stride_bez_w = 0;
int dilation_bez_h = 0;
int dilation_bez_w = 0;
int gcd_h = 0;
int gcd_w = 0;
EuclidGCD(static_cast<int>(stride_h), static_cast<int>(dilation_h), stride_bez_h, dilation_bez_h, gcd_h);
EuclidGCD(static_cast<int>(stride_w), static_cast<int>(dilation_w), stride_bez_w, dilation_bez_w, gcd_w);
// Retrieves the kernel from the compiled binary
auto kernel = Kernel(program_, "col2im");
@ -66,14 +75,20 @@ void Xcol2im<T>::DoCol2im(const size_t channels, const size_t height, const size
kernel.SetArgument(10, static_cast<int>(stride_w));
kernel.SetArgument(11, static_cast<int>(dilation_h));
kernel.SetArgument(12, static_cast<int>(dilation_w));
kernel.SetArgument(13, col_buffer());
kernel.SetArgument(14, static_cast<int>(col_offset));
kernel.SetArgument(15, im_buffer());
kernel.SetArgument(16, static_cast<int>(im_offset));
kernel.SetArgument(13, stride_bez_h);
kernel.SetArgument(14, stride_bez_w);
kernel.SetArgument(15, dilation_bez_h);
kernel.SetArgument(16, dilation_bez_w);
kernel.SetArgument(17, gcd_h);
kernel.SetArgument(18, gcd_w);
kernel.SetArgument(19, col_buffer());
kernel.SetArgument(20, static_cast<int>(col_offset));
kernel.SetArgument(21, im_buffer());
kernel.SetArgument(22, static_cast<int>(im_offset));
// Launches the kernel
const auto w_ceiled = Ceil(col_w, db_["COPY_DIMX"]);
const auto h_ceiled = Ceil(col_h, db_["COPY_DIMY"]);
const auto w_ceiled = Ceil((width - 1) / gcd_w + 1, db_["COPY_DIMX"]);
const auto h_ceiled = Ceil((height - 1) / gcd_h + 1, db_["COPY_DIMY"]);
const auto global = std::vector<size_t>{w_ceiled, h_ceiled * channels};
const auto local = std::vector<size_t>{db_["COPY_DIMX"], db_["COPY_DIMY"]};
RunKernel(kernel, queue_, device_, global, local, event_);

View File

@ -488,5 +488,31 @@ std::string GetDeviceName(const Device& device) {
return device_name;
}
// =================================================================================================
// Solve Bezout's identity
// a * p + b * q = r = GCD(a, b)
void EuclidGCD(int a, int b, int &p, int &q, int &r) {
p = 0;
q = 1;
int p_1 = 1;
int q_1 = 0;
for (;;) {
const int c = a % b;
if (c == 0) {
break;
}
const int p_2 = p_1;
const int q_2 = q_1;
p_1 = p;
q_1 = q;
p = p_2 - p_1 * (a / b);
q = q_2 - q_1 * (a / b);
a = b;
b = c;
}
r = b;
}
// =================================================================================================
} // namespace clblast

View File

@ -371,6 +371,12 @@ std::string GetDeviceVendor(const Device& device);
std::string GetDeviceArchitecture(const Device& device);
std::string GetDeviceName(const Device& device);
// =================================================================================================
// Solve Bezout's identity
// a * p + b * q = r = GCD(a, b)
void EuclidGCD(int a, int b, int &p, int &q, int &r);
// =================================================================================================
} // namespace clblast

View File

@ -159,6 +159,15 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host)
// Reference taken from im2col but swapped the input/output
const auto col_h = TestXcol2im<T>::ColHeight(args);
const auto col_w = TestXcol2im<T>::ColWidth(args);
for (auto c_id = size_t{0}; c_id < args.channels; ++c_id) {
for (auto h_index = size_t{0}; h_index < args.height; ++h_index) {
for (auto w_index = size_t{0}; w_index < args.width; ++w_index) {
const auto im_index = w_index + args.width * (h_index + args.height * c_id);
buffers_host.a_mat[im_index + args.a_offset] = 0;
}
}
}
for (auto c_id = size_t{0}; c_id < args.channels; ++c_id) { // image channels
for (auto kh_id = size_t{0}; kh_id < args.kernel_h; ++kh_id) { // kernel height
for (auto kw_id = size_t{0}; kw_id < args.kernel_w; ++kw_id) { // kernel width
@ -178,7 +187,7 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host)
if (h_index >= 0 && h_index < args.height &&
w_index >= 0 && w_index < args.width) {
const auto im_index = w_index + args.width * (h_index + args.height * c_id);
buffers_host.a_mat[im_index + args.a_offset] = val;
buffers_host.a_mat[im_index + args.a_offset] += val;
}
}
}