From 66841fdb0eaf0cc210757cc7f683d0f4eebadc21 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 13 May 2023 16:48:03 +0300 Subject: [PATCH] ggml : multi-thread mul and diag_mask ops (#1428) --- ggml.c | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/ggml.c b/ggml.c index 057463839..e5b3528d8 100644 --- a/ggml.c +++ b/ggml.c @@ -7765,12 +7765,13 @@ static void ggml_compute_forward_mul_f32( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - assert(params->ith == 0); assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } + const int ith = params->ith; + const int nth = params->nth; const int nr = ggml_nrows(src0); const int64_t ne0 = src0->ne[0]; @@ -7796,7 +7797,7 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT(nb00 == sizeof(float)); if (nb10 == sizeof(float)) { - for (int ir = 0; ir < nr; ++ir) { + for (int ir = ith; ir < nr; ir += nth) { // src0, src1 and dst are same shape => same indices const int i3 = ir/(ne2*ne1); const int i2 = (ir - i3*ne2*ne1)/ne1; @@ -7822,7 +7823,7 @@ static void ggml_compute_forward_mul_f32( } } else { // src1 is not contiguous - for (int ir = 0; ir < nr; ++ir) { + for (int ir = ith; ir < nr; ir += nth) { // src0, src1 and dst are same shape => same indices const int i3 = ir/(ne2*ne1); const int i2 = (ir - i3*ne2*ne1)/ne1; @@ -10317,7 +10318,6 @@ static void ggml_compute_forward_diag_mask_f32( const struct ggml_tensor * src1, struct ggml_tensor * dst, const float value) { - assert(params->ith == 0); assert(src1->type == GGML_TYPE_I32); assert(ggml_nelements(src1) == 2); @@ -10325,6 +10325,9 @@ static void ggml_compute_forward_diag_mask_f32( return; } + const int ith = params->ith; + const int nth = params->nth; + const int n_past = ((int32_t *) src1->data)[0]; const bool inplace = (bool)((int32_t *) src1->data)[1]; @@ -10343,7 +10346,7 @@ static void ggml_compute_forward_diag_mask_f32( assert(src0->nb[0] == sizeof(float)); for (int k = 0; k < nz; k++) { - for (int j = 0; j < nr; j++) { + for (int j = ith; j < nr; j += nth) { for (int i = n_past; i < nc; i++) { if (i > n_past + j) { *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value; @@ -13609,7 +13612,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) work_size = MAX(work_size, cur); } break; case GGML_OP_SUB: - case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_SQR: case GGML_OP_SQRT: @@ -13626,18 +13628,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { node->n_tasks = 1; } break; + case GGML_OP_MUL: case GGML_OP_GELU: - { - node->n_tasks = n_threads; - } break; case GGML_OP_SILU: - { - node->n_tasks = n_threads; - } break; case GGML_OP_SILU_BACK: - { - node->n_tasks = n_threads; - } break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: @@ -13715,11 +13709,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS_BACK: case GGML_OP_DIAG: - case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_ZERO: { node->n_tasks = 1; } break; + case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: