From 3e916a07ac093045d88ef0c4fa78647ae0efc010 Mon Sep 17 00:00:00 2001 From: gwjr <502526+gwjr@users.noreply.github.com> Date: Fri, 17 Nov 2023 14:48:19 +0000 Subject: [PATCH] finetune : speed-up ggml_compute_forward_out_prod_f32 via BLAS (#4079) * Remove logically superfluous assertions and order by dimension * Use cblas_sgemm() to implement ggml_compute_forward_out_prod() * Remove ggml_compute_forward_out_prod_use_blas(), fix compiling errors on cmake/zig, remove trailing whitespace * Add openBLAS support for sgemm() in compute_forward_out_prod() --- ggml.c | 69 +++++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 61 insertions(+), 8 deletions(-) diff --git a/ggml.c b/ggml.c index ada1067da..c7086ba84 100644 --- a/ggml.c +++ b/ggml.c @@ -9611,10 +9611,12 @@ static void ggml_compute_forward_out_prod_f32( const int ith = params->ith; const int nth = params->nth; + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne02); GGML_ASSERT(ne02 == ne12); - GGML_ASSERT(ne03 == ne13); - GGML_ASSERT(ne2 == ne12); GGML_ASSERT(ne3 == ne13); + GGML_ASSERT(ne03 == ne13); // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == sizeof(float)); @@ -9625,18 +9627,25 @@ static void ggml_compute_forward_out_prod_f32( // GGML_ASSERT(nb1 <= nb2); // GGML_ASSERT(nb2 <= nb3); - GGML_ASSERT(ne0 == ne00); - GGML_ASSERT(ne1 == ne10); - GGML_ASSERT(ne2 == ne02); - GGML_ASSERT(ne3 == ne03); - // nb01 >= nb00 - src0 is not transposed // compute by src0 rows // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod - // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) + // TODO: #if defined(GGML_USE_CLBLAST) + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + bool use_blas = ggml_is_matrix(src0) && + ggml_is_matrix(src1) && + ggml_is_contiguous(src0) && + (ggml_is_contiguous(src1) || ggml_is_transposed(src1)); +#endif if (params->type == GGML_TASK_INIT) { +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst + if (use_blas) { + return; + } +#endif ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); return; } @@ -9645,6 +9654,50 @@ static void ggml_compute_forward_out_prod_f32( return; } +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (use_blas) { + if (params->ith != 0) { // All threads other than the first do no work. + return; + } + // Arguments to ggml_compute_forward_out_prod (expressed as major,minor) + // src0: (k,n) + // src1: (k,m) + // dst: (m,n) + // + // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f) + // Also expressed as (major,minor) + // a: (m,k): so src1 transposed + // b: (k,n): so src0 + // c: (m,n) + // + // However, if ggml_is_transposed(src1) is true, then + // src1->data already contains a transposed version, so sgemm mustn't + // transpose it further. + + int n = src0->ne[0]; + int k = src0->ne[1]; + int m = src1->ne[0]; + + int transposeA, lda; + + if (!ggml_is_transposed(src1)) { + transposeA = CblasTrans; + lda = m; + } else { + transposeA = CblasNoTrans; + lda = k; + } + + float * a = (float *) ((char *) src1->data); + float * b = (float *) ((char *) src0->data); + float * c = (float *) ((char *) dst->data); + + cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n); + + return; + } +#endif + // dst[:,:,:,:] = 0 // for i2,i3: // for i1: