From e410cfc3ce9a47a3ea676e46b39aa863117faf4e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 May 2023 18:56:30 +0300 Subject: [PATCH] ggml : sync latest ggml repo - new Q4 and Q8 quantization - updated CUDA --- examples/common.cpp | 3 +- ggml-cuda.cu | 291 ++++++++++++++------------ ggml-cuda.h | 2 + ggml.c | 485 +++++++++++++++++++++++++++++--------------- ggml.h | 16 +- 5 files changed, 502 insertions(+), 295 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index a8461fb..76da30d 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -26,7 +26,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else if (arg == "-n" || arg == "--n_predict") { params.n_predict = std::stoi(argv[++i]); } else if (arg == "--top_k") { - params.top_k = std::stoi(argv[++i]); + params.top_k = std::max(1, std::stoi(argv[++i])); } else if (arg == "--top_p") { params.top_p = std::stof(argv[++i]); } else if (arg == "--temp") { @@ -259,6 +259,7 @@ std::vector gpt_tokenize(const gpt_vocab & vocab, const std::stri if (it != vocab.token_to_id.end()) { tokens.push_back(it->second); i = j; + j = n; break; } --j; diff --git a/ggml-cuda.cu b/ggml-cuda.cu index eb9f0df..35d2e45 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -42,19 +42,19 @@ typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, #define QK4_0 32 #define QR4_0 2 typedef struct { - float d; // delta + half d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants } block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); #define QK4_1 32 #define QR4_1 2 typedef struct { - float d; // delta - float m; // min + half d; // delta + half m; // min uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; -static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); +static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); #define QK5_0 32 #define QR5_0 2 @@ -78,12 +78,23 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + #define QK8_0 32 #define QR8_0 1 typedef struct { - float d; // delta + half d; // delta int8_t qs[QK8_0]; // quants } block_q8_0; -static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); -#define CUDA_DMMV_BLOCK_SIZE 32 +#define CUDA_MUL_BLOCK_SIZE 256 +#define CUDA_DEQUANTIZE_BLOCK_SIZE 256 +#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec + +static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= kx) { + return; + } + dst[i] = x[i] * y[i%ky]; +} static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ const block_q4_0 * x = (const block_q4_0 *) vx; @@ -170,104 +181,23 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs, v1 = __half2float(x[ib + 1]); } -static __global__ void dequantize_block_q4_0(const void * vx, float * y) { - static const int qk = QK4_0; +template +static __global__ void dequantize_block(const void * vx, float * y, const int k) { + const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; - const block_q4_0 * x = (const block_q4_0 *) vx; - - const int i = blockIdx.x; - - const float d = x[i].d; - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0xf) - 8; - const int x1 = (x[i].qs[j] >> 4) - 8; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; + if (i >= k) { + return; } -} -static __global__ void dequantize_block_q4_1(const void * vx, float * y) { - static const int qk = QK4_1; + const int ib = i/qk; // block index + const int iqs = (i%qk)/qr; // quant index + const int iybs = i - i%qk; // y block start index + const int y_offset = qr == 1 ? 1 : qk/2; - const block_q4_1 * x = (const block_q4_1 *) vx; - - const int i = blockIdx.x; - - const float d = x[i].d; - const float m = x[i].m; - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0xf); - const int x1 = (x[i].qs[j] >> 4); - - y[i*qk + j + 0 ] = x0*d + m; - y[i*qk + j + qk/2] = x1*d + m; - } -} - -static __global__ void dequantize_block_q5_0(const void * vx, float * y) { - static const int qk = QK5_0; - - const block_q5_0 * x = (const block_q5_0 *) vx; - - const int i = blockIdx.x; - - const float d = x[i].d; - - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16; - const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; - } -} - -static __global__ void dequantize_block_q5_1(const void * vx, float * y) { - static const int qk = QK5_1; - - const block_q5_1 * x = (const block_q5_1 *) vx; - - const int i = blockIdx.x; - - const float d = x[i].d; - const float m = x[i].m; - - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int x0 = (x[i].qs[j] & 0xf) | xh_0; - const int x1 = (x[i].qs[j] >> 4) | xh_1; - - y[i*qk + j + 0 ] = x0*d + m; - y[i*qk + j + qk/2] = x1*d + m; - } -} - -static __global__ void dequantize_block_q8_0(const void * vx, float * y) { - static const int qk = QK8_0; - - const block_q8_0 * x = (const block_q8_0 *) vx; - - const int i = blockIdx.x; - - const float d = x[i].d; - - for (int j = 0; j < qk; ++j) { - y[i*qk + j] = x[i].qs[j]*d; - } + // dequantize + float & v0 = y[iybs + iqs + 0]; + float & v1 = y[iybs + iqs + y_offset]; + dequantize_kernel(vx, ib, iqs, v0, v1); } template @@ -308,29 +238,34 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, } } -static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK4_0; - dequantize_block_q4_0<<>>(vx, y); +static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { + const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; + mul_f32<<>>(x, y, dst, kx, ky); } -static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK4_1; - dequantize_block_q4_1<<>>(vx, y); +static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); } -static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK5_0; - dequantize_block_q5_0<<>>(vx, y); +static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); } -static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK5_1; - dequantize_block_q5_1<<>>(vx, y); +static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); } -static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK8_0; - dequantize_block_q8_0<<>>(vx, y); +static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); +} + +static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); } static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { @@ -363,17 +298,9 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f <<>>(vx, y, dst, ncols); } -// TODO: optimize -static __global__ void convert_fp16_to_fp32(const void * vx, float * y) { - const half * x = (const half *) vx; - - const int i = blockIdx.x; - - y[i] = __half2float(x[i]); -} - -static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) { - convert_fp16_to_fp32<<>>(x, y); +static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<32, 1, convert_f16><<>>(vx, y, k); } static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { @@ -555,6 +482,67 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor } } +static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA); + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[2]; + const int64_t ne0 = ne00 * ne01 * ne02 * ne03; + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + size_t x_size, d_size; + + float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0 + float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted. + float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + const int i0 = i03*ne02 + i02; + float * c_X2 = d_X + i0*ne01*ne00; + float * c_D2 = d_D + i0*ne01*ne00; + + cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS]; + cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS]; + cudaEvent_t cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS]; + + // copy src0 to device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2)); + CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); + + // wait for data + CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); + + for (int64_t i01 = 0; i01 < ne01; i01++) { + const int64_t i13 = i03%ne13; + const int64_t i12 = i02%ne12; + const int64_t i11 = i01%ne11; + const int i1 = i13*ne12*ne11 + i12*ne11 + i11; + + float * c_X1 = c_X2 + i01*ne00; + float * c_Y = d_Y + i1*ne10; + float * c_D1 = c_D2 + i01*ne00; + + // compute + mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream); + CUDA_CHECK(cudaGetLastError()); + } + + // copy dst to host + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream)); + } + } + CUDA_CHECK(cudaDeviceSynchronize()); + ggml_cuda_pool_free(d_X, x_size); + ggml_cuda_pool_free(d_D, d_size); +} + static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; @@ -812,6 +800,11 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor ggml_cuda_pool_free(d_Q, q_size); } +void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); + ggml_cuda_mul_f32(src0, src1, dst); +} + bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { const int64_t ne10 = src1->ne[0]; @@ -885,14 +878,48 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) { const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type); size_t q_size; - char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size); + char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size); cudaStream_t cudaStream2 = g_cudaStreams2[0]; // copy tensor to device - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2)); - CUDA_CHECK(cudaDeviceSynchronize()); + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + int i = i3*ne2 + i2; + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2)); + } + } - tensor->data = d_Q; + tensor->data = dst; tensor->backend = GGML_BACKEND_CUDA; } + +void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) { + FILE * fp = fopen(fname, "rb"); + + const size_t size = ggml_nbytes(tensor); + + void * buf; + CUDA_CHECK(cudaMalloc(&buf, size)); + void * buf_host = malloc(size); + +#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64) offset, SEEK_SET); +#else + int ret = fseek(fp, (long) offset, SEEK_SET); +#endif + GGML_ASSERT(ret == 0); // same + + size_t ret2 = fread(buf_host, size, 1, fp); + if (ret2 != 1) { + fprintf(stderr, "unexpectedly reached end of file"); + exit(1); + } + + cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); + + tensor->data = buf; + free(buf_host); + fclose(fp); +} diff --git a/ggml-cuda.h b/ggml-cuda.h index 4e2c242..6a04dde 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -6,6 +6,7 @@ extern "C" { void ggml_init_cublas(void); +void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize); @@ -15,6 +16,7 @@ void * ggml_cuda_host_malloc(size_t size); void ggml_cuda_host_free(void * ptr); void ggml_cuda_transform_tensor(struct ggml_tensor * tensor); +void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset); #ifdef __cplusplus } diff --git a/ggml.c b/ggml.c index da3d914..f4a5a8d 100644 --- a/ggml.c +++ b/ggml.c @@ -512,7 +512,7 @@ static inline int hsum_i32_4(const __m128i a) { return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); } -#if __AVX2__ || __AVX512F__ +#if defined(__AVX2__) || defined(__AVX512F__) // spread 32 bits to 32 bytes { 0x00, 0xFF } static inline __m256i bytes_from_bits_32(const uint8_t * x) { uint32_t x32; @@ -543,12 +543,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) { return _mm256_cvtepi32_ps(summed_pairs); } -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(x, x); - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(y, x); +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { #if __AVXVNNI__ const __m256i zero = _mm256_setzero_si256(); const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); @@ -560,6 +555,21 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { #endif } +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + return mul_sum_us8_pairs_float(ax, sy); +#endif +} + static inline __m128i packNibbles( __m256i bytes ) { // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh @@ -619,6 +629,17 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { return _mm256_cvtepi32_ps(summed_pairs); } +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + const __m128i axl = _mm256_castsi256_si128(ax); + const __m128i axh = _mm256_extractf128_si256(ax, 1); + const __m128i syl = _mm256_castsi256_si128(sy); + const __m128i syh = _mm256_extractf128_si256(sy, 1); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + // multiply int8_t, add results pairwise twice and return as float vector static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { const __m128i xl = _mm256_castsi256_si128(x); @@ -667,7 +688,7 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 #endif // __AVX__ || __AVX2__ || __AVX512F__ #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) -#if __ARM_NEON +#if defined(__ARM_NEON) #if !defined(__aarch64__) @@ -748,18 +769,18 @@ int32x4_t vcvtnq_s32_f32(float32x4_t v) { #define QK4_0 32 typedef struct { - float d; // delta + ggml_fp16_t d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants } block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); #define QK4_1 32 typedef struct { - float d; // delta - float m; // min + ggml_fp16_t d; // delta + ggml_fp16_t m; // min uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; -static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding"); +static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); #define QK5_0 32 typedef struct { @@ -780,16 +801,16 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + #define QK8_0 32 typedef struct { - float d; // delta - int8_t qs[QK8_0]; // quants + ggml_fp16_t d; // delta + int8_t qs[QK8_0]; // quants } block_q8_0; -static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); #define QK8_1 32 typedef struct { - float d; // delta - float s; // d * sum(qs[i]) - int8_t qs[QK8_1]; // quants + float d; // delta + float s; // d * sum(qs[i]) + int8_t qs[QK8_1]; // quants } block_q8_1; static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); @@ -816,7 +837,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r const float d = max / -8; const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); for (int j = 0; j < qk/2; ++j) { const float x0 = x[i*qk + 0 + j]*id; @@ -856,8 +877,8 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r const float d = (max - min) / ((1 << 4) - 1); const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; - y[i].m = min; + y[i].d = GGML_FP32_TO_FP16(d); + y[i].m = GGML_FP32_TO_FP16(min); for (int j = 0; j < qk/2; ++j) { const float x0 = (x[i*qk + 0 + j] - min)*id; @@ -988,7 +1009,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r const float d = amax / ((1 << 7) - 1); const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); for (int j = 0; j < QK8_0; ++j) { const float x0 = x[i*QK8_0 + j]*id; @@ -1023,7 +1044,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int const float d = amax / ((1 << 7) - 1); const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); for (int j = 0; j < 8; j++) { const float32x4_t v = vmulq_n_f32(srcv[j], id); @@ -1058,7 +1079,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int // Quantize these floats const float d = maxScalar / 127.f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; const __m256 mul = _mm256_set1_ps( id ); @@ -1157,7 +1178,7 @@ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * r sum += y[i].qs[QK8_1/2 + j]; } - y[i].s = d * sum; + y[i].s = sum*d; } } @@ -1309,7 +1330,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict const int nb = k / qk; for (int i = 0; i < nb; i++) { - const float d = x[i].d; + const float d = GGML_FP16_TO_FP32(x[i].d); for (int j = 0; j < qk/2; ++j) { const int x0 = (x[i].qs[j] & 0x0F) - 8; @@ -1329,8 +1350,8 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict const int nb = k / qk; for (int i = 0; i < nb; i++) { - const float d = x[i].d; - const float m = x[i].m; + const float d = GGML_FP16_TO_FP32(x[i].d); + const float m = GGML_FP16_TO_FP32(x[i].m); for (int j = 0; j < qk/2; ++j) { const int x0 = (x[i].qs[j] & 0x0F); @@ -1405,7 +1426,7 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in const block_q8_0 * restrict x = vx; for (int i = 0; i < nb; i++) { - const float d = x[i].d; + const float d = GGML_FP16_TO_FP32(x[i].d); for (int j = 0; j < qk; ++j) { y[i*qk + j] = x[i].qs[j]*d; @@ -1669,8 +1690,9 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) { static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { float tmp[8]; - for (int i = 0; i < 8; i++) + for (int i = 0; i < 8; i++) { tmp[i] = GGML_FP16_TO_FP32(x[i]); + } return _mm256_loadu_ps(tmp); } @@ -2090,8 +2112,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; - const uint8x16_t m4b = vdupq_n_u8(0x0F); - const int8x16_t s8b = vdupq_n_s8(0x8); + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); const uint8x16_t v0_0 = vld1q_u8(x0->qs); const uint8x16_t v0_1 = vld1q_u8(x1->qs); @@ -2119,8 +2141,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); @@ -2137,8 +2159,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); #endif } @@ -2150,7 +2172,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; ++i) { /* Compute combined scale for the block */ - const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); __m256i bx = bytes_from_nibbles_32(x[i].qs); @@ -2174,7 +2196,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; ++i) { // Compute combined scale for the block - const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); const __m128i lowMask = _mm_set1_epi8(0xF); const __m128i off = _mm_set1_epi8(8); @@ -2216,7 +2238,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[0].d ), _mm_set1_ps( y[0].d ) ); + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); @@ -2234,7 +2256,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[1].d ), _mm_set1_ps( y[1].d ) ); + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); @@ -2267,7 +2289,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) ); + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); @@ -2285,7 +2307,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[i + 1].d ), _mm_set1_ps( y[i + 1].d ) ); + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); @@ -2333,7 +2355,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); } - sumf += (x[i].d*y[i].d)*sumi; + sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); } *s = sumf; @@ -2363,7 +2385,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * const block_q8_1 * restrict y0 = &y[i + 0]; const block_q8_1 * restrict y1 = &y[i + 1]; - summs += x0->m * y0->s + x1->m * y1->s; + summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s; const uint8x16_t m4b = vdupq_n_u8(0x0F); @@ -2387,8 +2409,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l)); @@ -2405,8 +2427,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); #endif } @@ -2419,13 +2441,13 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; ++i) { - const float * d0 = &x[i].d; - const float * d1 = &y[i].d; + const float d0 = GGML_FP16_TO_FP32(x[i].d); + const float d1 = y[i].d; - summs += x[i].m * y[i].s; + summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; - const __m256 d0v = _mm256_broadcast_ss( d0 ); - const __m256 d1v = _mm256_broadcast_ss( d1 ); + const __m256 d0v = _mm256_set1_ps( d0 ); + const __m256 d1v = _mm256_set1_ps( d1 ); // Compute combined scales const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); @@ -2434,7 +2456,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * const __m256i bx = bytes_from_nibbles_32(x[i].qs); const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); - const __m256 xy = mul_sum_i8_pairs_float(bx, by); + const __m256 xy = mul_sum_us8_pairs_float(bx, by); // Accumulate d0*d1*x*y #if defined(__AVX2__) @@ -2459,7 +2481,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); } - sumf += (x[i].d*y[i].d)*sumi + x[i].m*y[i].s; + sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; } *s = sumf; @@ -2535,16 +2557,13 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - const float x0d = GGML_FP16_TO_FP32(x0->d); - const float x1d = GGML_FP16_TO_FP32(x1->d); - #if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d); + vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d); + vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); @@ -2561,8 +2580,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); #endif } @@ -2637,7 +2656,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; i++) { /* Compute combined scale for the block */ - const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d)); + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); __m256i bx = bytes_from_nibbles_32(x[i].qs); __m256i bxhi = bytes_from_bits_32(x[i].qh); @@ -2661,7 +2680,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; i++) { /* Compute combined scale for the block */ - const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d)); + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); __m256i bx = bytes_from_nibbles_32(x[i].qs); const __m256i bxhi = bytes_from_bits_32(x[i].qh); @@ -2704,7 +2723,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); } - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi; + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; } *s = sumf; @@ -2786,16 +2805,13 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - const float x0d = GGML_FP16_TO_FP32(x0->d); - const float x1d = GGML_FP16_TO_FP32(x1->d); - #if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d); + vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d); + vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); @@ -2812,8 +2828,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); #endif } @@ -2873,15 +2889,14 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - const float x0d = GGML_FP16_TO_FP32(x0->d); - // dot product - sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( - wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d))); + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d)); } *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + @@ -2903,10 +2918,10 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); bx = _mm256_or_si256(bx, bxhi); - const __m256 dy = _mm256_broadcast_ss(&y[i].d); + const __m256 dy = _mm256_set1_ps(y[i].d); const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - const __m256 q = mul_sum_i8_pairs_float(bx, by); + const __m256 q = mul_sum_us8_pairs_float(bx, by); acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); } @@ -2937,10 +2952,10 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * bxh = _mm_or_si128(bxh, bxhih); bx = _mm256_set_m128i(bxh, bxl); - const __m256 dy = _mm256_broadcast_ss(&y[i].d); + const __m256 dy = _mm256_set1_ps(y[i].d); const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - const __m256 q = mul_sum_i8_pairs_float(bx, by); + const __m256 q = mul_sum_us8_pairs_float(bx, by); acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); } @@ -3007,11 +3022,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * #if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), - vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d); + vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), - vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d); + vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); #else const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0)); @@ -3029,8 +3044,8 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); #endif } @@ -3042,7 +3057,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; ++i) { // Compute combined scale for the block - const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs); __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); @@ -3068,7 +3083,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * sumi += x[i].qs[j]*y[i].qs[j]; } - sumf += (x[i].d*y[i].d)*sumi; + sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); } *s = sumf; @@ -3457,6 +3472,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "ROPE", "ROPE_BACK", "ALIBI", + "CLAMP", "CONV_1D_1S", "CONV_1D_2S", @@ -3467,7 +3483,8 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "MAP_BINARY", }; -static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50"); +static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51"); + static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3517,6 +3534,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rope(x)", "rope_back(x)", "alibi(x)", + "clamp(x)", "conv_1d_1s(x)", "conv_1d_2s(x)", @@ -3527,7 +3545,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "f(x,y)", }; -static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50"); +static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); @@ -3761,6 +3779,12 @@ static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct g (t1->ne[3]%t0->ne[3] == 0); } +static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1); +} + static inline int ggml_up32(int n) { return (n + 31) & ~31; } @@ -4643,11 +4667,15 @@ struct ggml_tensor * ggml_mul_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - GGML_ASSERT(ggml_are_same_shape(a, b)); + // TODO: support less-strict constraint + // GGML_ASSERT(ggml_can_repeat(b, a)); + GGML_ASSERT(ggml_can_repeat_rows(b, a)); bool is_node = false; if (!inplace && (a->grad || b->grad)) { + // TODO: support backward pass for broadcasting + GGML_ASSERT(ggml_are_same_shape(a, b)); is_node = true; } @@ -6189,7 +6217,8 @@ struct ggml_tensor * ggml_alibi( struct ggml_context * ctx, struct ggml_tensor * a, int n_past, - int n_head) { + int n_head, + float bias_max) { GGML_ASSERT(n_past >= 0); bool is_node = false; @@ -6208,6 +6237,8 @@ struct ggml_tensor * ggml_alibi( ((int32_t *) b->data)[0] = n_past; ((int32_t *) b->data)[1] = n_head; + GGML_ASSERT(sizeof(float) == sizeof(int32_t)); + (((float *) b->data)[2]) = bias_max; ggml_scratch_load(ctx); @@ -6219,6 +6250,40 @@ struct ggml_tensor * ggml_alibi( return result; } +// ggml_clamp + +struct ggml_tensor * ggml_clamp( + struct ggml_context * ctx, + struct ggml_tensor * a, + float min, + float max) { + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + ggml_scratch_save(ctx); + + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); + + ((float *) b->data)[0] = min; + ((float *) b->data)[1] = max; + + ggml_scratch_load(ctx); + + result->op = GGML_OP_CLAMP; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + // ggml_conv_1d_1s struct ggml_tensor * ggml_conv_1d_1s( @@ -7945,7 +8010,7 @@ static void ggml_compute_forward_mul_f32( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -7953,10 +8018,25 @@ static void ggml_compute_forward_mul_f32( const int ith = params->ith; const int nth = params->nth; - const int nr = ggml_nrows(src0); - const int64_t ne0 = src0->ne[0]; - const int64_t ne1 = src0->ne[1]; - const int64_t ne2 = src0->ne[2]; +#ifdef GGML_USE_CUBLAS + if (src1->backend == GGML_BACKEND_CUDA) { + if (ith == 0) { + ggml_cuda_mul(src0, src1, dst); + } + return; + } +#endif + + const int64_t nr = ggml_nrows(src0); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; const size_t nb00 = src0->nb[0]; const size_t nb01 = src0->nb[1]; @@ -7975,44 +8055,51 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(ne00 == ne10); if (nb10 == sizeof(float)) { - for (int ir = ith; ir < nr; ir += nth) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); #ifdef GGML_USE_ACCELERATE UNUSED(ggml_vec_mul_f32); - vDSP_vmul( - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne0); + vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00); #else - ggml_vec_mul_f32(ne0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr); #endif // } // } } } else { // src1 is not contiguous - for (int ir = ith; ir < nr; ir += nth) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i0 = 0; i0 < ne0; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne00; i0++) { + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10); dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); } @@ -10501,34 +10588,29 @@ static void ggml_compute_forward_diag_mask_f32( assert(src1->type == GGML_TYPE_I32); assert(ggml_nelements(src1) == 2); - const int n_past = ((int32_t *) src1->data)[0]; - const bool inplace = (bool)((int32_t *) src1->data)[1]; - - if (params->type == GGML_TASK_INIT) { - // TODO: this hack is not good, need a better way to handle this - if (!inplace) { - // use the init task to copy src -> dst - struct ggml_compute_params params_cpy = *params; - - params_cpy.ith = 0; - params_cpy.nth = 1; - params_cpy.type = GGML_TASK_COMPUTE; - - ggml_compute_forward_dup_same_cont(¶ms_cpy, src0, dst); - } - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; + const int n_past = ((int32_t *) src1->data)[0]; + const bool inplace = (bool)((int32_t *) src1->data)[1]; + assert(n_past >= 0); + if (!inplace && (params->type == GGML_TASK_INIT)) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); + } + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + // TODO: handle transposed/permuted matrices const int n = ggml_nrows(src0); @@ -10682,14 +10764,15 @@ static void ggml_compute_forward_alibi_f32( struct ggml_tensor * dst) { assert(params->ith == 0); assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 2); + assert(ggml_nelements(src1) == 3); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - const int n_past = ((int32_t *) src1->data)[0]; - const int n_head = ((int32_t *) src1->data)[1]; + const int n_past = ((int32_t *) src1->data)[0]; + const int n_head = ((int32_t *) src1->data)[1]; + const float max_bias = ((float *) src1->data)[2]; assert(n_past >= 0); @@ -10712,8 +10795,8 @@ static void ggml_compute_forward_alibi_f32( // add alibi to src0 (KQ_scaled) const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor); - const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor); + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); for (int i = 0; i < ne0; i++) { for (int j = 0; j < ne1; j++) { @@ -10731,13 +10814,13 @@ static void ggml_compute_forward_alibi_f32( m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); } - pdst[0] = i * m_k + src[0]; + pdst[0] = (i-ne0+1) * m_k + src[0]; + } } } } - static void ggml_compute_forward_alibi_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -10745,14 +10828,15 @@ static void ggml_compute_forward_alibi_f16( struct ggml_tensor * dst) { assert(params->ith == 0); assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 2); + assert(ggml_nelements(src1) == 3); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - const int n_past = ((int32_t *) src1->data)[0]; - const int n_head = ((int32_t *) src1->data)[1]; + const int n_past = ((int32_t *) src1->data)[0]; + const int n_head = ((int32_t *) src1->data)[1]; + const float max_bias = ((float *) src1->data)[2]; assert(n_past >= 0); @@ -10775,8 +10859,8 @@ static void ggml_compute_forward_alibi_f16( // add alibi to src0 (KQ_scaled) const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor); - const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor); + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); for (int i = 0; i < ne0; i++) { for (int j = 0; j < ne1; j++) { @@ -10795,7 +10879,7 @@ static void ggml_compute_forward_alibi_f16( } // we return F32 - pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]); + pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]); } } } @@ -10831,6 +10915,77 @@ static void ggml_compute_forward_alibi( } } + +// ggml_compute_forward_clamp + +static void ggml_compute_forward_clamp_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(src1->type == GGML_TYPE_I32); + assert(ggml_nelements(src1) == 2); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int min = ((float *) src1->data)[0]; + const int max = ((float *) src1->data)[1]; + + const int ith = params->ith; + const int nth = params->nth; + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + for (int j = ith; j < n; j += nth) { + float * dst_ptr = (float *) ((char *) dst->data + j*nb1); + float * src0_ptr = (float *) ((char *) src0->data + j*nb01); + + for (int i = 0; i < nc; i++) { + dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min); + } + } +} + +static void ggml_compute_forward_clamp( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_clamp_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_rope static void ggml_compute_forward_rope_f32( @@ -12812,6 +12967,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor); } break; + case GGML_OP_CLAMP: + { + ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor); + } break; case GGML_OP_CONV_1D_1S: { ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor); @@ -13119,6 +13278,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_CLAMP: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_SILU: { // necessary for llama @@ -13998,6 +14161,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { node->n_tasks = 1; //TODO } break; + case GGML_OP_CLAMP: + { + node->n_tasks = 1; //TODO + } break; case GGML_OP_CONV_1D_1S: case GGML_OP_CONV_1D_2S: { diff --git a/ggml.h b/ggml.h index 255541d..51a616c 100644 --- a/ggml.h +++ b/ggml.h @@ -190,7 +190,7 @@ #define GGML_FILE_MAGIC 0x67676d6c // "ggml" #define GGML_FILE_VERSION 1 -#define GGML_QNT_VERSION 1 // bump this on quantization format changes +#define GGML_QNT_VERSION 2 // bump this on quantization format changes #define GGML_QNT_VERSION_FACTOR 1000 // do not change this #define GGML_MAX_DIMS 4 @@ -313,6 +313,7 @@ extern "C" { GGML_OP_ROPE, GGML_OP_ROPE_BACK, GGML_OP_ALIBI, + GGML_OP_CLAMP, GGML_OP_CONV_1D_1S, GGML_OP_CONV_1D_2S, @@ -849,7 +850,7 @@ extern "C" { int n_past); // in-place, returns view(a) - GGML_API struct ggml_tensor * gml_diag_mask_zero_inplace( + GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace( struct ggml_context * ctx, struct ggml_tensor * a, int n_past); @@ -897,7 +898,16 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a, int n_past, - int n_head); + int n_head, + float bias_max); + + // clamp + // in-place, returns view(a) + struct ggml_tensor * ggml_clamp( + struct ggml_context * ctx, + struct ggml_tensor * a, + float min, + float max); // padding = 1 // TODO: we don't support extra parameters for now