diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index ac75f8e16..51445b5e7 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -8247,7 +8247,8 @@ static void clamp_f32(const float * x, float * dst, const float min, const float dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); } -static void im2col_f32_f16(const float *x, sycl::half *dst, int offset_delta, +template +static void im2col_kernel(const float *x, T *dst, int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW, int s0, int s1, int p0, int p1, int d0, int d1, @@ -11019,7 +11020,8 @@ static void soft_max_f32_sycl(const float *x, const float *y, float *dst, }); } -static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH, +template +static void im2col_sycl(const float *x, T *dst, int IW, int IH, int OW, int OH, int KW, int KH, int IC, int offset_delta, int s0, int s1, int p0, int p1, int d0, int d1, @@ -11036,7 +11038,7 @@ static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH, sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - im2col_f32_f16(x, dst, offset_delta, IW, IH, OW, KW, KH, + im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1, item_ct1); }); @@ -12424,7 +12426,7 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0, GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; @@ -12447,8 +12449,11 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0, const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 - im2col_f32_f16_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, - IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + if (dst->type == GGML_TYPE_F16) { + im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + } else { + im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); + } (void) src0; (void) src0_dd;