From 72d967bce4c03b0a48f9491129e3e0e9cd7b1e80 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 17 Oct 2022 21:44:16 +0300
Subject: [PATCH] Use Accelerate framework on Apple silicon

Huge performance improvement in the Encode (almost x2 on MacBook M1 Pro)

Also various extra optimizations:

- Multi-threaded NORM operator
- Faster GELU via F16 cast
---
 Makefile    |   9 +-
 README.md   |  20 ++--
 ggml.c      | 295 +++++++++++++++++++++++++++++++++-------------------
 main.cpp    |   2 +-
 whisper.cpp |  12 +--
 5 files changed, 217 insertions(+), 121 deletions(-)

diff --git a/Makefile b/Makefile
index 6552b28..61384ae 100644
--- a/Makefile
+++ b/Makefile
@@ -8,6 +8,7 @@ UNAME_M := $(shell uname -m)

 CFLAGS   = -O3 -std=c11
 CXXFLAGS = -O3 -std=c++11
+LDFLAGS  =

 CFLAGS   += -Wall -Wextra -Wno-unused-parameter -Wno-unused-function
 CXXFLAGS += -Wall -Wextra -Wno-unused-parameter -Wno-unused-function
@@ -37,7 +38,11 @@ ifeq ($(UNAME_M),amd64)
     CFLAGS += -mavx -mavx2 -mfma -mf16c
 endif
 ifneq ($(filter arm%,$(UNAME_M)),)
-    # Mac M1
+    # Mac M1 - include Accelerate framework
+    ifeq ($(UNAME_S),Darwin)
+        CFLAGS  += -DGGML_USE_ACCELERATE
+        LDFLAGS += -framework Accelerate
+    endif
 endif
 ifneq ($(filter aarch64%,$(UNAME_M)),)
 endif
@@ -59,7 +64,7 @@ endif
 #

 main: main.cpp ggml.o whisper.o
-    $(CXX) $(CXXFLAGS) main.cpp whisper.o ggml.o -o main
+    $(CXX) $(CXXFLAGS) main.cpp whisper.o ggml.o -o main $(LDFLAGS)
     ./main -h

 ggml.o: ggml.c ggml.h
diff --git a/README.md b/README.md
index 19de94f..36972ac 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,8 @@
 High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:

 - Plain C/C++ implementation without dependencies
-- ARM_NEON and AVX intrinsics support
+- Apple silicon first-class citizen - optimized via Arm Neon and Accelerate framework
+- AVX intrinsics support for x86 architectures
 - Mixed F16 / F32 precision
 - Low memory usage (Flash Attention + Flash Forward)
 - Zero memory allocations at runtime
@@ -224,7 +225,7 @@ https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a
 The `stream` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:

 ```bash
-# Install SDL2 on Linux 
+# Install SDL2 on Linux
 sudo apt-get install libsdl2-dev

 # Install SDL2 on Mac OS
 brew install sdl2

 make stream
 ```
@@ -240,6 +241,10 @@ make stream

 - Simple usage is demonstrated in [main.cpp](main.cpp)
 - Sample real-time audio transcription from the microphone is demonstrated in [stream.cpp](stream.cpp)

+The tensor operators are optimized heavily for Apple silicon CPUs. Depending on the computation size, Arm Neon SIMD
+intrinsics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
+the framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
+
 ## Limitations

 - Very basic greedy sampling scheme - always pick up the top token.
You can implement your own strategy @@ -250,11 +255,12 @@ make stream | Model | Disk | Mem | | --- | --- | --- | -| tiny | 75 MB | ~240 MB | -| base | 142 MB | ~380 MB | -| small | 466 MB | ~970 MB | -| medium | 1.5 GB | ~2.5 GB | -| large | 2.9 GB | ~4.6 GB | +| tiny | 75 MB | ~280 MB | +| base | 142 MB | ~430 MB | +| small | 466 MB | ~1.0 GB | +| medium | 1.5 GB | ~2.6 GB | +| large | 2.9 GB | ~4.7 GB | + ## ggml format diff --git a/ggml.c b/ggml.c index 6c585d8..7f11c96 100644 --- a/ggml.c +++ b/ggml.c @@ -716,12 +716,6 @@ inline static float ggml_gelu_f32(float x) { return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x))); } -inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { - for (int i = 0; i < n; ++i) { - y[i] = ggml_gelu_f32(x[i]); - } -} - inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { const uint16_t * i16 = (const uint16_t *) x; for (int i = 0; i < n; ++i) { @@ -729,6 +723,21 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp } } +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = ggml_fp32_to_fp16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = table_gelu_f16[t]; + } +} + +//inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { +// for (int i = 0; i < n; ++i) { +// y[i] = ggml_gelu_f32(x[i]); +// } +//} + inline static void ggml_vec_sum_f32 (const int n, float * s, const float * x) { ggml_float sum = 0.0; for (int i = 0; i < n; ++i) sum += x[i]; *s += sum; } inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); } @@ -2867,13 +2876,15 @@ void ggml_compute_forward_add_f32( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - GGML_ASSERT(params->ith == 0); GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } + const int ith = params->ith; + const int nth = params->nth; + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; @@ -2890,7 +2901,7 @@ void ggml_compute_forward_add_f32( GGML_ASSERT(nb00 == sizeof(float)); if (nb10 == sizeof(float)) { - for (int j = 0; j < n; j++) { + for (int j = ith; j < n; j += nth) { ggml_vec_add_f32(nc, (float *) ((char *) dst->data + j*nb1), (float *) ((char *) src0->data + j*nb01), @@ -2898,7 +2909,7 @@ void ggml_compute_forward_add_f32( } } else { // src1 is not contiguous - for (int j = 0; j < n; j++) { + for (int j = ith; j < n; j += nth) { float * dst_ptr = (float *) ((char *) dst->data + j*nb1); float * src0_ptr = (float *) ((char *) src0->data + j*nb01); for (int i = 0; i < nc; i++) { @@ -3669,14 +3680,16 @@ void ggml_compute_forward_norm_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - assert(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; @@ -3696,7 +3709,7 @@ void ggml_compute_forward_norm_f32( // TODO: optimize for (int i03 = 0; i03 < ne03; i03++) { for (int 
i02 = 0; i02 < ne02; i02++) { - for (int i01 = 0; i01 < ne01; i01++) { + for (int i01 = ith; i01 < ne01; i01 += nth) { const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); ggml_float mean = 0.0; @@ -3745,6 +3758,28 @@ void ggml_compute_forward_norm( // ggml_compute_forward_mul_mat +// helper function to determine if it is better to use BLAS or not +// for large matrices, BLAS is faster +bool ggml_compute_forward_mul_mat_use_blas( + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + UNUSED(src0); + + const int ne10 = src1->ne[0]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + // TODO: find the optimal values for these + if (ggml_is_contiguous(src1) && ne0 >= 32 && ne1 >= 32 && ne10 >= 32) { + //printf("BLAS: %d %d %d\n", ne0, ne1, ne10); + return true; + } + + return false; +} + void ggml_compute_forward_mul_mat_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -3812,6 +3847,47 @@ void ggml_compute_forward_mul_mat_f32( // nb00 < nb01 - src0 is transposed // compute by src0 columns +//#ifdef GGML_USE_ACCELERATE +// if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { +// GGML_ASSERT(ggml_is_contiguous(src0)); +// GGML_ASSERT(nb10 == sizeof(float)); +// +// if (params->ith != 0) return; +// +// if (params->type == GGML_TASK_INIT) { +// return; +// } +// +// if (params->type == GGML_TASK_FINALIZE) { +// return; +// } +// +// float * const wdata = params->wdata; +// +// for (int i03 = 0; i03 < ne03; i03++) { +// for (int i02 = 0; i02 < ne02; i02++) { +// const float * x = (float *) (src0->data); +// const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); +// +// float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); +// +// // zT = y * xT +// { +// cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, +// ne11, ne01, ne10, +// 1.0f, y, ne10, +// x, ne10, +// 0.0f, d, ne01); +// } +// } +// } +// +// //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); +// +// return; +// } +//#endif + if (params->type == GGML_TASK_INIT) { if (nb01 >= nb00) { return; @@ -3848,78 +3924,6 @@ void ggml_compute_forward_mul_mat_f32( return; } -//#ifdef GGML_USE_ACCELERATE -// // try to use BLAS -// -// if (nb01 >= nb00 && ne0 > 1024 && ne1 > 1024) { -// if (params->ith != 0) return; -// printf("XXXXXXXX\n"); -// -// GGML_ASSERT(ggml_is_contiguous(src0)); -// GGML_ASSERT(ggml_is_contiguous(src1)); -// -// printf("ne00 = %d, ne01 = %d, ne02 = %d, ne03 = %d\n", ne00, ne01, ne02, ne03); -// printf("ne10 = %d, ne11 = %d, ne12 = %d, ne13 = %d\n", ne10, ne11, ne12, ne13); -// printf("ne0 = %d, ne1 = %d, ne2 = %d, ne3 = %d\n", ne0, ne1, ne2, ne3); -// -// printf("nb00 = %d, nb01 = %d, nb02 = %d, nb03 = %d\n", nb00, nb01, nb02, nb03); -// printf("nb10 = %d, nb11 = %d, nb12 = %d, nb13 = %d\n", nb10, nb11, nb12, nb13); -// printf("nb0 = %d, nb1 = %d, nb2 = %d, nb3 = %d\n", nb0, nb1, nb2, nb3); -// -// float * const wdata = params->wdata; -// -// int64_t tsum = 0.0; -// for (int i03 = 0; i03 < ne03; i03++) { -// for (int i02 = 0; i02 < ne02; i02++) { -// const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03); -// const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); -// float * z = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); -// -// // transpose src1 -// for (int j = 0; j < ne11; ++j) { -// for (int i = 0; i < ne10; ++i) { -// wdata[i*ne11 + j] = y[j*ne10 + i]; -// } 
-// } -// -// { -// const int64_t tt0 = ggml_time_us(); -// cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, -// 1500, 1500, 64, -// 1.0, x, 64, -// wdata, 1500, -// 0.0, z, 1500); -// const int64_t tt1 = ggml_time_us(); -// tsum += tt1 - tt0; -// } -// -// // transpose z -// for (int j = 0; j < ne1; ++j) { -// for (int i = 0; i < ne0; ++i) { -// wdata[i*ne1 + j] = z[j*ne0 + i]; -// } -// } -// -// memcpy(z, wdata, ne0*ne1*sizeof(float)); -// -// //cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, -// // ne0, ne1, 64, -// // 1.0f, -// // x, ne00, -// // y, ne11, -// // 0.0f, -// // z, 1500); -// } -// } -// printf("time = %f ms\n", tsum/1000.0); -// return; -// } else { -// //cblas_sgemv(CblasRowMajor, CblasTrans, ne00, ne01, 1.0, src0->data, ne01, src1->data, 1, 0.0, dst->data, 1); -// } -// -//#endif - - if (nb01 >= nb00) { // TODO: do not support transposed src1 assert(nb10 == sizeof(float)); @@ -4064,24 +4068,24 @@ void ggml_compute_forward_mul_mat_f16_f32( const int ith = params->ith; const int nth = params->nth; - assert(ne02 == ne12); - assert(ne03 == ne13); - assert(ne2 == ne12); - assert(ne3 == ne13); + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); // TODO: we don't support permuted src0 - assert(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t)); // dst cannot be transposed or permuted - assert(nb0 == sizeof(float)); - assert(nb0 <= nb1); - assert(nb1 <= nb2); - assert(nb2 <= nb3); + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); - assert(ne0 == ne01); - assert(ne1 == ne11); - assert(ne2 == ne02); - assert(ne3 == ne03); + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); // nb01 >= nb00 - src0 is not transposed // compute by src0 rows @@ -4089,6 +4093,73 @@ void ggml_compute_forward_mul_mat_f16_f32( // nb00 < nb01 - src0 is transposed // compute by src0 columns +#ifdef GGML_USE_ACCELERATE + if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->ith != 0) return; + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + float * const wdata = params->wdata; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + { + int id = 0; + for (int i01 = 0; i01 < ne01; ++i01) { + for (int i00 = 0; i00 < ne00; ++i00) { + wdata[id++] = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); + } + } + } + + const float * x = wdata; + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + // float * z = wdata + ne00*ne01; + + // z = x * yT + //{ + // cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + // ne01, ne11, ne00, + // 1.0f, x, ne00, + // y, ne00, + // 0.0f, z, ne11); + //} + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // transpose z + //for (int j = 0; j < ne11; ++j) { + // for (int i = 0; i < ne01; ++i) { + // d[j*ne01 + i] = z[i*ne11 + j]; + // } + //} + + // zT = y * xT + { + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01); + } + } + } + + //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif + if (params->type 
== GGML_TASK_INIT) { if (nb01 >= nb00) { ggml_fp16_t * const wdata = params->wdata; @@ -6534,7 +6605,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) switch (node->op) { case GGML_OP_DUP: + { + node->n_tasks = 1; + } break; case GGML_OP_ADD: + { + node->n_tasks = 1; + } break; case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: @@ -6553,11 +6630,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } break; case GGML_OP_GELU: { - node->n_tasks = MIN(n_threads, ggml_nrows(node->src0)); + node->n_tasks = n_threads; } break; case GGML_OP_NORM: { - node->n_tasks = 1; + node->n_tasks = n_threads; } break; case GGML_OP_MUL_MAT: { @@ -6572,7 +6649,15 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } else { if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) { +#ifdef GGML_USE_ACCELERATE + if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + cur = sizeof(float)*(node->src0->ne[0]*node->src0->ne[1]); + } else { + cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1); + } +#else cur = sizeof(ggml_fp16_t)*ggml_nelements(node->src1); +#endif } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { cur = 0; @@ -6585,7 +6670,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } break; case GGML_OP_SCALE: { - node->n_tasks = MIN(n_threads, ggml_nrows(node->src0)); + node->n_tasks = n_threads; } break; case GGML_OP_CPY: case GGML_OP_RESHAPE: @@ -6599,7 +6684,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } break; case GGML_OP_SOFT_MAX: { - node->n_tasks = MIN(n_threads, ggml_nrows(node->src0)); + node->n_tasks = n_threads; } break; case GGML_OP_ROPE: { @@ -6714,7 +6799,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) struct ggml_compute_params params = { /*.type =*/ GGML_TASK_INIT, /*.ith =*/ 0, - /*.nth =*/ n_threads, + /*.nth =*/ node->n_tasks, /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0, /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL, }; @@ -6898,9 +6983,9 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { perf_total_per_op_us[node->op] += node->perf_time_us; - GGML_PRINT(" - %3d: [ %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", + GGML_PRINT(" - %3d: [ %6d, %6d, %6d] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", i, - node->ne[0], node->ne[1], + node->ne[0], node->ne[1], node->ne[2], GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? 
"g" : " ", node->perf_runs, (double) node->perf_cycles / (double) ggml_cycles_per_ms(), (double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs, diff --git a/main.cpp b/main.cpp index 43838cf..b913522 100644 --- a/main.cpp +++ b/main.cpp @@ -21,7 +21,7 @@ std::string to_timestamp(int64_t t) { msec = msec - min * (1000 * 60); int64_t sec = msec / 1000; msec = msec - sec * 1000; - + char buf[32]; snprintf(buf, sizeof(buf), "%02d:%02d:%02d.%03d", (int) hr, (int) min, (int) sec, (int) msec); diff --git a/whisper.cpp b/whisper.cpp index b984c46..d0e6e76 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -15,7 +15,7 @@ #include #define USE_FLASH_ATTN -#define USE_FLASH_FF +//#define USE_FLASH_FF // available whisper models enum e_model { @@ -148,11 +148,11 @@ static const std::map MEM_REQ_ENCODE = { }; static const std::map MEM_REQ_ENCODE_LAYER = { - { MODEL_TINY, 64ull*MB }, - { MODEL_BASE, 84ull*MB }, - { MODEL_SMALL, 128ull*MB }, - { MODEL_MEDIUM, 172ull*MB }, - { MODEL_LARGE, 216ull*MB }, + { MODEL_TINY, 104ull*MB }, + { MODEL_BASE, 138ull*MB }, + { MODEL_SMALL, 208ull*MB }, + { MODEL_MEDIUM, 280ull*MB }, + { MODEL_LARGE, 354ull*MB }, }; static const std::map MEM_REQ_DECODE = {