From 1117d06607d2d885640ac501f05f0aae5494e2c5 Mon Sep 17 00:00:00 2001 From: shibe2 Date: Wed, 18 Oct 2023 16:09:22 +0400 Subject: [PATCH] opencl : fix element-wise multiplication (#3656) --- ggml-opencl.cpp | 67 ++++++++++++++----------------------------------- 1 file changed, 19 insertions(+), 48 deletions(-) diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index 22fd0e3a7..67ac20eac 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -1395,75 +1395,46 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; const int64_t ne03 = src0->ne[3]; - const int64_t ne0 = ne00 * ne01 * ne02 * ne03; const int64_t ne10 = src1->ne[0]; const int64_t ne11 = src1->ne[1]; const int64_t ne12 = src1->ne[2]; const int64_t ne13 = src1->ne[3]; - const int64_t nb10 = src1->nb[0]; const int nb2 = dst->nb[2]; const int nb3 = dst->nb[3]; size_t x_size; size_t d_size; - cl_mem d_X = ggml_cl_pool_malloc(ne0 * sizeof(float), &x_size); // src0 + cl_mem d_X = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &x_size); // src0 cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted. - cl_mem d_D = ggml_cl_pool_malloc(ne0 * sizeof(float), &d_size); // dst + cl_mem d_D = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &d_size); // dst for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - const int i0 = i03*ne02 + i02; - cl_event ev; // copy src0 to device - CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, i0, src0, i03, i02, &ev)); + CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, &ev)); - if (nb10 == sizeof(float)) { - // Contiguous, avoid overhead from queueing many kernel runs - const int64_t i13 = i03%ne13; - const int64_t i12 = i02%ne12; - const int i1 = i13*ne12*ne11 + i12*ne11; + const int64_t i13 = i03%ne13; + const int64_t i12 = i02%ne12; + const int i1 = i13*ne12*ne11 + i12*ne11; - cl_int x_offset = 0; - cl_int y_offset = i1*ne10; - cl_int d_offset = 0; + cl_int x_offset = 0; + cl_int y_offset = i1*ne10; + cl_int d_offset = 0; - size_t global = ne00 * ne01; - cl_int ky = ne10; - CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky)); - CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL)); - } else { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const int64_t i13 = i03%ne13; - const int64_t i12 = i02%ne12; - const int64_t i11 = i01%ne11; - const int i1 = i13*ne12*ne11 + i12*ne11 + i11; + size_t global = ne00 * ne01; + cl_int ky = ne10 * ne11; - cl_int x_offset = i01*ne00; - cl_int y_offset = i1*ne10; - cl_int d_offset = i01*ne00; - - // compute - size_t global = ne00; - cl_int ky = ne10; - CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky)); - CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL)); - } - } + CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X)); + CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset)); + CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y)); + CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset)); + CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D)); + CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset)); + CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky)); + CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL)); CL_CHECK(clReleaseEvent(ev)); CL_CHECK(clFinish(queue));