Fix im2col with fp32 (#5286)

This commit is contained in:
AidanBeltonS 2024-02-03 08:11:37 +00:00 committed by GitHub
parent 191221178f
commit a305dba8ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -8247,7 +8247,8 @@ static void clamp_f32(const float * x, float * dst, const float min, const float
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
} }
static void im2col_f32_f16(const float *x, sycl::half *dst, int offset_delta, template <typename T>
static void im2col_kernel(const float *x, T *dst, int offset_delta,
int IW, int IH, int OW, int KW, int KH, int IW, int IH, int OW, int KW, int KH,
int pelements, int CHW, int s0, int s1, int p0, int pelements, int CHW, int s0, int s1, int p0,
int p1, int d0, int d1, int p1, int d0, int d1,
@ -11019,7 +11020,8 @@ static void soft_max_f32_sycl(const float *x, const float *y, float *dst,
}); });
} }
static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH, template <typename T>
static void im2col_sycl(const float *x, T *dst, int IW, int IH,
int OW, int OH, int KW, int KH, int IC, int OW, int OH, int KW, int KH, int IC,
int offset_delta, int s0, int s1, int p0, int offset_delta, int s0, int s1, int p0,
int p1, int d0, int d1, int p1, int d0, int d1,
@ -11036,7 +11038,7 @@ static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH,
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
im2col_f32_f16(x, dst, offset_delta, IW, IH, OW, KW, KH, im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH,
parallel_elements, (IC * KH * KW), s0, s1, p0, parallel_elements, (IC * KH * KW), s0, s1, p0,
p1, d0, d1, item_ct1); p1, d0, d1, item_ct1);
}); });
@ -12424,7 +12426,7 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0,
GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
@ -12447,8 +12449,11 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0,
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
im2col_f32_f16_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, if (dst->type == GGML_TYPE_F16) {
IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
} else {
im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
}
(void) src0; (void) src0;
(void) src0_dd; (void) src0_dd;