ggml : multi-thread mul and diag_mask ops (#1428)

commit 66841fdb0e (parent 905d87b70a)
Author: Georgi Gerganov
Date:   2023-05-13 16:48:03 +03:00 (committed by GitHub)

ggml.c (26 changed lines)

@@ -7765,12 +7765,13 @@ static void ggml_compute_forward_mul_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(params->ith == 0);
     assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
+    const int ith = params->ith;
+    const int nth = params->nth;
 
     const int nr = ggml_nrows(src0);
     const int64_t ne0 = src0->ne[0];
@@ -7796,7 +7797,7 @@ static void ggml_compute_forward_mul_f32(
     GGML_ASSERT(nb00 == sizeof(float));
 
     if (nb10 == sizeof(float)) {
-        for (int ir = 0; ir < nr; ++ir) {
+        for (int ir = ith; ir < nr; ir += nth) {
             // src0, src1 and dst are same shape => same indices
             const int i3 = ir/(ne2*ne1);
             const int i2 = (ir - i3*ne2*ne1)/ne1;
@@ -7822,7 +7823,7 @@ static void ggml_compute_forward_mul_f32(
         }
     } else {
         // src1 is not contiguous
-        for (int ir = 0; ir < nr; ++ir) {
+        for (int ir = ith; ir < nr; ir += nth) {
             // src0, src1 and dst are same shape => same indices
             const int i3 = ir/(ne2*ne1);
             const int i2 = (ir - i3*ne2*ne1)/ne1;
@@ -10317,7 +10318,6 @@ static void ggml_compute_forward_diag_mask_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst,
         const float value) {
-    assert(params->ith == 0);
     assert(src1->type == GGML_TYPE_I32);
     assert(ggml_nelements(src1) == 2);
 
@@ -10325,6 +10325,9 @@ static void ggml_compute_forward_diag_mask_f32(
         return;
     }
 
+    const int ith = params->ith;
+    const int nth = params->nth;
+
     const int n_past = ((int32_t *) src1->data)[0];
     const bool inplace = (bool)((int32_t *) src1->data)[1];
 
@@ -10343,7 +10346,7 @@ static void ggml_compute_forward_diag_mask_f32(
     assert(src0->nb[0] == sizeof(float));
 
     for (int k = 0; k < nz; k++) {
-        for (int j = 0; j < nr; j++) {
+        for (int j = ith; j < nr; j += nth) {
             for (int i = n_past; i < nc; i++) {
                 if (i > n_past + j) {
                     *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
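Both kernels above use the same parallelization scheme: instead of asserting that only thread 0 runs the op (the removed assert(params->ith == 0)), every thread walks an interleaved subset of rows, with thread ith handling rows ith, ith + nth, ith + 2*nth, and so on. The row sets are disjoint, so no locking is needed inside the op. A minimal standalone sketch of this striding pattern follows; the pthread harness, mul_rows kernel, and sizes are illustrative stand-ins for ggml's thread pool, not ggml code:

#include <pthread.h>
#include <stdio.h>

#define N_THREADS 4
#define N_ROWS    10
#define N_COLS    8

static float src0[N_ROWS][N_COLS];
static float src1[N_ROWS][N_COLS];
static float dst [N_ROWS][N_COLS];

struct params {
    int ith; // index of this thread
    int nth; // total number of threads
};

// each thread multiplies an interleaved subset of rows:
// thread ith handles rows ith, ith + nth, ith + 2*nth, ...
static void * mul_rows(void * arg) {
    const struct params * p = arg;
    for (int ir = p->ith; ir < N_ROWS; ir += p->nth) {
        for (int ic = 0; ic < N_COLS; ic++) {
            dst[ir][ic] = src0[ir][ic] * src1[ir][ic];
        }
    }
    return NULL;
}

int main(void) {
    // fill inputs with recognizable values: dst[r][c] should become r*c
    for (int r = 0; r < N_ROWS; r++) {
        for (int c = 0; c < N_COLS; c++) {
            src0[r][c] = (float) r;
            src1[r][c] = (float) c;
        }
    }

    pthread_t     threads[N_THREADS];
    struct params params [N_THREADS];

    for (int i = 0; i < N_THREADS; i++) {
        params[i].ith = i;
        params[i].nth = N_THREADS;
        pthread_create(&threads[i], NULL, mul_rows, &params[i]);
    }
    for (int i = 0; i < N_THREADS; i++) {
        pthread_join(threads[i], NULL);
    }

    printf("dst[3][5] = %.1f\n", dst[3][5]); // expect 15.0
    return 0;
}

Interleaved striding balances load when rows cost roughly the same, as in element-wise mul; for diag_mask the per-row work grows with j, and interleaving still spreads long and short rows fairly evenly across threads.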
@@ -13609,7 +13612,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_SUB:
-            case GGML_OP_MUL:
             case GGML_OP_DIV:
             case GGML_OP_SQR:
             case GGML_OP_SQRT:
@@ -13626,18 +13628,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 {
                     node->n_tasks = 1;
                 } break;
+            case GGML_OP_MUL:
             case GGML_OP_GELU:
-                {
-                    node->n_tasks = n_threads;
-                } break;
             case GGML_OP_SILU:
-                {
-                    node->n_tasks = n_threads;
-                } break;
             case GGML_OP_SILU_BACK:
-                {
-                    node->n_tasks = n_threads;
-                } break;
             case GGML_OP_NORM:
             case GGML_OP_RMS_NORM:
             case GGML_OP_RMS_NORM_BACK:
@@ -13715,11 +13709,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             case GGML_OP_GET_ROWS:
             case GGML_OP_GET_ROWS_BACK:
             case GGML_OP_DIAG:
-            case GGML_OP_DIAG_MASK_INF:
             case GGML_OP_DIAG_MASK_ZERO:
                 {
                     node->n_tasks = 1;
                 } break;
+            case GGML_OP_DIAG_MASK_INF:
             case GGML_OP_SOFT_MAX:
             case GGML_OP_ROPE:
             case GGML_OP_ROPE_BACK:
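The ggml_graph_compute hunks are what enable the new code paths: moving GGML_OP_MUL and GGML_OP_DIAG_MASK_INF out of the n_tasks = 1 group and into the n_tasks = n_threads group tells the graph executor to run those nodes on every worker, each invocation receiving its own ith/nth pair (GGML_OP_DIAG_MASK_ZERO stays single-threaded here). A rough sketch of that contract follows; the struct, function names, and serial loop below are a simplified stand-in for ggml's actual dispatcher, which runs the invocations concurrently:

#include <stdio.h>

// simplified stand-in for ggml_compute_params: for a node with
// n_tasks == nth, each kernel invocation gets an index ith in [0, nth)
struct compute_params {
    int ith;
    int nth;
};

typedef void (*op_kernel)(const struct compute_params * params);

// serial model of the dispatch contract: a node with n_tasks == T is
// executed as T kernel calls (ggml runs these on its worker threads)
static void run_node(op_kernel kernel, int n_tasks) {
    for (int t = 0; t < n_tasks; t++) {
        const struct compute_params params = { .ith = t, .nth = n_tasks };
        kernel(&params);
    }
}

static void demo_kernel(const struct compute_params * params) {
    printf("worker %d of %d\n", params->ith, params->nth);
}

int main(void) {
    run_node(demo_kernel, 1); // old behavior for MUL / DIAG_MASK_INF
    run_node(demo_kernel, 4); // new behavior with n_threads == 4
    return 0;
}

Under this contract the striding loops in the kernels partition the rows exactly once whatever n_tasks is: with n_tasks = 1 the stride loop degenerates to the original full-range loop, which is why the same kernel body serves both the single-threaded and multi-threaded cases.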