From 83456076f0fd8cdc72fbd0dbd275c435db6569fc Mon Sep 17 00:00:00 2001
From: katsu560
Date: Wed, 23 Nov 2022 20:23:24 +0900
Subject: [PATCH] add AVX support

---
 ggml.c      | 166 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 ggml.h      |   1 +
 whisper.cpp |   1 +
 3 files changed, 168 insertions(+)

diff --git a/ggml.c b/ggml.c
index d58d988..ab49c23 100644
--- a/ggml.c
+++ b/ggml.c
@@ -372,6 +372,49 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     sumf = _mm_cvtss_f32(r1);
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        sumf += x[i]*y[i];
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    __m256 sum0 = _mm256_setzero_ps();
+    __m256 sum1 = _mm256_setzero_ps();
+    __m256 sum2 = _mm256_setzero_ps();
+    __m256 sum3 = _mm256_setzero_ps();
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = _mm256_loadu_ps(x + i + 0);
+        x1 = _mm256_loadu_ps(x + i + 8);
+        x2 = _mm256_loadu_ps(x + i + 16);
+        x3 = _mm256_loadu_ps(x + i + 24);
+
+        y0 = _mm256_loadu_ps(y + i + 0);
+        y1 = _mm256_loadu_ps(y + i + 8);
+        y2 = _mm256_loadu_ps(y + i + 16);
+        y3 = _mm256_loadu_ps(y + i + 24);
+
+        sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
+        sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
+        sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
+        sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
+    }
+
+    sum0 = _mm256_add_ps(sum0, sum1);
+    sum2 = _mm256_add_ps(sum2, sum3);
+    sum0 = _mm256_add_ps(sum0, sum2);
+
+    const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0), _mm256_extractf128_ps(sum0, 1));
+    const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
+    const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+
+    sumf = _mm_cvtss_f32(r1);
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         sumf += x[i]*y[i];
     }
@@ -569,6 +612,50 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 
     sumf = _mm_cvtss_f32(r1);
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        //GGML_ASSERT(false);
+        sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    __m256 sum0 = _mm256_setzero_ps();
+    __m256 sum1 = _mm256_setzero_ps();
+    __m256 sum2 = _mm256_setzero_ps();
+    __m256 sum3 = _mm256_setzero_ps();
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
+        x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
+        x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
+        x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
+
+        y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
+        y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
+        y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
+        y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+
+        sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
+        sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
+        sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
+        sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
+    }
+
+    const __m256 sum01 = _mm256_add_ps(sum0, sum1);
+    const __m256 sum23 = _mm256_add_ps(sum2, sum3);
+    const __m256 sum0123 = _mm256_add_ps(sum01, sum23);
+
+    const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum0123), _mm256_extractf128_ps(sum0123, 1));
+    const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
+    const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));
+
+    sumf = _mm_cvtss_f32(r1);
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         //GGML_ASSERT(false);
@@ -698,6 +785,41 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
         _mm256_storeu_ps(y + i + 24, y3);
     }
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    const __m256 v4 = _mm256_set1_ps(v);
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        x0 = _mm256_loadu_ps(x + i + 0);
+        x1 = _mm256_loadu_ps(x + i + 8);
+        x2 = _mm256_loadu_ps(x + i + 16);
+        x3 = _mm256_loadu_ps(x + i + 24);
+
+        y0 = _mm256_loadu_ps(y + i + 0);
+        y1 = _mm256_loadu_ps(y + i + 8);
+        y2 = _mm256_loadu_ps(y + i + 16);
+        y3 = _mm256_loadu_ps(y + i + 24);
+
+        y0 = _mm256_add_ps(_mm256_mul_ps(x0, v4), y0);
+        y1 = _mm256_add_ps(_mm256_mul_ps(x1, v4), y1);
+        y2 = _mm256_add_ps(_mm256_mul_ps(x2, v4), y2);
+        y3 = _mm256_add_ps(_mm256_mul_ps(x3, v4), y3);
+
+        _mm256_storeu_ps(y + i + 0, y0);
+        _mm256_storeu_ps(y + i + 8, y1);
+        _mm256_storeu_ps(y + i + 16, y2);
+        _mm256_storeu_ps(y + i + 24, y3);
+    }
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         y[i] += x[i]*v;
@@ -859,6 +981,42 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
         _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
     }
 
+    // leftovers
+    for (int i = n32; i < n; ++i) {
+        GGML_ASSERT(false);
+        y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
+    }
+#elif defined(__AVX__)
+    // AVX 256-bit
+    const int n32 = (n & ~31);
+
+    const __m256 v8 = _mm256_set1_ps(v);
+
+    __m256 x0, x1, x2, x3;
+    __m256 y0, y1, y2, y3;
+
+    for (int i = 0; i < n32; i += 32) {
+        y0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
+        y1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
+        y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
+        y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
+
+        x0 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
+        x1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
+        x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
+        x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
+
+        y0 = _mm256_add_ps(_mm256_mul_ps(x0, v8), y0);
+        y1 = _mm256_add_ps(_mm256_mul_ps(x1, v8), y1);
+        y2 = _mm256_add_ps(_mm256_mul_ps(x2, v8), y2);
+        y3 = _mm256_add_ps(_mm256_mul_ps(x3, v8), y3);
+
+        _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 16), _mm256_cvtps_ph(y2, 0));
+        _mm_storeu_si128((__m128i*)(y + i + 24), _mm256_cvtps_ph(y3, 0));
+    }
+
     // leftovers
     for (int i = n32; i < n; ++i) {
         GGML_ASSERT(false);
@@ -8081,6 +8239,14 @@ enum ggml_opt_result ggml_opt(
 
 ////////////////////////////////////////////////////////////////////////////////
 
+int ggml_cpu_has_avx(void) {
+#if defined(__AVX__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_avx2(void) {
 #if defined(__AVX2__)
     return 1;
diff --git a/ggml.h b/ggml.h
index f352e71..3e4e962 100644
--- a/ggml.h
+++ b/ggml.h
@@ -723,6 +723,7 @@ enum ggml_opt_result ggml_opt(
 // system info
 //
 
+int ggml_cpu_has_avx(void);
 int ggml_cpu_has_avx2(void);
 int ggml_cpu_has_avx512(void);
 int ggml_cpu_has_neon(void);
diff --git a/whisper.cpp b/whisper.cpp
index 4f23cde..d729dba 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -3041,6 +3041,7 @@ const char * whisper_print_system_info() {
     static std::string s;
 
     s = "";
+    s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
     s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
     s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
     s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
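
Note on the inner loops: each new #elif defined(__AVX__) branch processes 32 floats
per iteration in four independent 8-lane ymm accumulators (which breaks the add
dependency chain), reduces 256 -> 128 -> 64 -> 32 bits, and finishes the remaining
n - n32 elements with scalar code. The multiply-add is spelled
_mm256_add_ps(_mm256_mul_ps(...)) because plain AVX has no fused multiply-add;
FMA is a separate ISA extension. The sketch below is hypothetical and not part of
the patch (the function name dot_f32_avx is illustrative, and __FMA__ is the
GCC/Clang macro set by -mfma); it shows the same dot-product pattern with an
optional fused variant for builds that also enable FMA:

    /*
     * Hypothetical sketch, not part of the patch: single-accumulator f32 dot
     * product with the same horizontal reduction, plus an FMA variant for
     * compilers invoked with -mavx -mfma.
     */
    #include <immintrin.h>

    static float dot_f32_avx(const int n, const float * x, const float * y) {
        const int n8 = (n & ~7);   // elements covered by the 8-wide SIMD loop

        __m256 sum = _mm256_setzero_ps();
        for (int i = 0; i < n8; i += 8) {
            const __m256 ax = _mm256_loadu_ps(x + i);
            const __m256 ay = _mm256_loadu_ps(y + i);
    #if defined(__FMA__)
            sum = _mm256_fmadd_ps(ax, ay, sum);               // fused a*b + c
    #else
            sum = _mm256_add_ps(_mm256_mul_ps(ax, ay), sum);  // plain AVX
    #endif
        }

        // horizontal reduction as in the patch: 256 -> 128 -> 64 -> 32 bits
        const __m128 r4 = _mm_add_ps(_mm256_castps256_ps128(sum),
                                     _mm256_extractf128_ps(sum, 1));
        const __m128 r2 = _mm_add_ps(r4, _mm_movehl_ps(r4, r4));
        const __m128 r1 = _mm_add_ss(r2, _mm_movehdup_ps(r2));

        float sumf = _mm_cvtss_f32(r1);

        // scalar leftovers
        for (int i = n8; i < n; ++i) {
            sumf += x[i]*y[i];
        }
        return sumf;
    }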
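One caveat for the fp16 paths: _mm256_cvtph_ps and _mm256_cvtps_ph are F16C
instructions, not baseline AVX, so the new ggml_vec_dot_f16/ggml_vec_mad_f16
branches compile only when the toolchain also enables F16C; the first AVX CPUs
(Intel Sandy Bridge) do not implement it. A stricter guard, shown below as a
hypothetical alternative rather than what the patch does (__F16C__ is the
GCC/Clang macro set by -mf16c), would key the half-precision path on both macros:

    /* Hypothetical guard sketch: require F16C explicitly for the half-float
     * conversions; plain-AVX machines would fall through to another path. */
    #if defined(__AVX__) && defined(__F16C__)
        // fp16 AVX path: needs -mavx -mf16c (Ivy Bridge or later)
    #elif defined(__AVX__)
        // fp32-only AVX path: no half-float conversion instructions assumed
    #endif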
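With ggml_cpu_has_avx() exported from ggml.h, callers can query the flag
directly, and whisper_print_system_info() now reports it alongside
AVX2/AVX512/NEON. A minimal usage sketch, assuming only the headers touched by
this patch and linking against ggml:

    /* Prints 1 when ggml was compiled with AVX enabled (-mavx or /arch:AVX),
     * else 0. */
    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        printf("AVX = %d\n", ggml_cpu_has_avx());
        return 0;
    }

Since the implementation is a compile-time #if on __AVX__, the value reflects
how the library was built, not the capabilities of the CPU it happens to run on.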