From b6c5f49b78b214b7b4aa7392a8ba489c78b7382a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Wed, 15 Nov 2023 16:12:52 +0200 Subject: [PATCH] whisper : add batched decoding (#1486) * whisper : add whisper_batch * whisper : move kv_self to whisper_state * whisper : full batched decoding support * whisper : fix memory leak in whisper_batch * whisper : fix mem leak again + remove obsolete function * whisper : clear kv cache when using whisper_decode API * whisper : speed-up sampling * whisper : fix decoders initializer * bench : add batch size 5 bench * whisper : add comment about the KV cache size * whisper : add check for max number of decoders * whisper : avoid starting sampling threads with bs=1 * whisper : enable beam-search by default * cuda : sync llama.cpp fixes --- examples/bench/bench.cpp | 30 +- examples/main/main.cpp | 8 +- extra/bench-all.sh | 7 +- ggml-cuda.cu | 306 ++++++----- ggml-cuda.h | 5 + whisper.cpp | 1048 ++++++++++++++++++++++---------------- whisper.h | 4 +- 7 files changed, 836 insertions(+), 572 deletions(-) diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index db1c4e8..949e573 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) { } // heat encoder if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { - fprintf(stderr, "error: failed to encode model: %d\n", ret); + fprintf(stderr, "error: failed to encode: %d\n", ret); return 4; } @@ -90,13 +90,13 @@ int whisper_bench_full(const whisper_params & params) { // prompt heat if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { - fprintf(stderr, "error: failed to encode model: %d\n", ret); + fprintf(stderr, "error: failed to decode: %d\n", ret); return 4; } // text-generation heat if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) { - fprintf(stderr, "error: failed to encode model: %d\n", ret); + fprintf(stderr, "error: failed to decode: %d\n", ret); return 4; } @@ -104,20 +104,30 @@ int whisper_bench_full(const whisper_params & params) { // actual run if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { - fprintf(stderr, "error: failed to encode model: %d\n", ret); + fprintf(stderr, "error: failed to encode: %d\n", ret); return 4; } - for (int i = 0; i < 16; i++) { - if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { - fprintf(stderr, "error: failed to encode model: %d\n", ret); + // text-generation + for (int i = 0; i < 256; i++) { + if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) { + fprintf(stderr, "error: failed to decode: %d\n", ret); return 4; } } - for (int i = 0; i < 256; i++) { - if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) { - fprintf(stderr, "error: failed to encode model: %d\n", ret); + // batched decoding + for (int i = 0; i < 64; i++) { + if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) { + fprintf(stderr, "error: failed to decode: %d\n", ret); + return 4; + } + } + + // prompt processing + for (int i = 0; i < 16; i++) { + if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { + fprintf(stderr, "error: failed to decode: %d\n", ret); + return 4; } } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e43dfe3..98af583 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -62,8 +62,8 @@ struct whisper_params { int32_t progress_step = 5; int32_t max_context = -1; int32_t max_len = 0; - int32_t 
best_of = 2; - int32_t beam_size = -1; + int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of; + int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size; float word_thold = 0.01f; float entropy_thold = 2.40f; @@ -925,9 +925,9 @@ int main(int argc, char ** argv) { if (params.detect_language) { params.language = "auto"; } - fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n", + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n", __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, - params.n_threads, params.n_processors, + params.n_threads, params.n_processors, params.beam_size, params.best_of, params.language.c_str(), params.translate ? "translate" : "transcribe", params.tinydiarize ? "tdrz = 1, " : "", diff --git a/extra/bench-all.sh b/extra/bench-all.sh index db04267..af8f675 100755 --- a/extra/bench-all.sh +++ b/extra/bench-all.sh @@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then printf "\n" fi -printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit" -printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" +printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit" +printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---" for model in "${models[@]}"; do # actual run @@ -56,6 +56,7 @@ for model in "${models[@]}"; do # parse the output: encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}') decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}') + batchd_time=$(echo "$output" | grep "batchd time" | awk '{print $11}') prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}') system_info=$(echo "$output" | grep "system_info") n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}') @@ -94,6 +95,6 @@ for model in "${models[@]}"; do commit=$(git rev-parse --short HEAD) if [ $ret -eq 0 ]; then - printf "| | | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit" + printf "| | | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit" fi done diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 058011a..c0c9edd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -39,7 +39,6 @@ #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess -#define cudaDeviceGetMemPool hipDeviceGetMemPool #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t @@ -49,7 +48,6 @@ #define cudaEvent_t hipEvent_t #define cudaEventDestroy hipEventDestroy #define cudaFree hipFree -#define cudaFreeAsync hipFreeAsync #define cudaFreeHost hipHostFree #define cudaGetDevice hipGetDevice #define cudaGetDeviceCount hipGetDeviceCount @@ -57,7 +55,6 @@ #define cudaGetErrorString hipGetErrorString #define cudaGetLastError 
hipGetLastError #define cudaMalloc hipMalloc -#define cudaMallocFromPoolAsync hipMallocFromPoolAsync #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) #define cudaMemcpy hipMemcpy #define cudaMemcpy2DAsync hipMemcpy2DAsync @@ -66,9 +63,6 @@ #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost #define cudaMemcpyHostToDevice hipMemcpyHostToDevice #define cudaMemcpyKind hipMemcpyKind -#define cudaMemPool_t hipMemPool_t -#define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold -#define cudaMemPoolSetAttribute hipMemPoolSetAttribute #define cudaMemset hipMemset #define cudaMemsetAsync hipMemsetAsync #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize @@ -94,6 +88,8 @@ #define CC_OFFSET_AMD 1000000 #define CC_RDNA2 (CC_OFFSET_AMD + 1030) +#define GGML_CUDA_MAX_NODES 8192 + // define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication // on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant // for large computational tasks. the drawback is that this requires some extra amount of VRAM: @@ -188,11 +184,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); do { \ cudaError_t err_ = (err); \ if (err_ != cudaSuccess) { \ - int dev_id; \ - cudaGetDevice(&dev_id); \ + int id; \ + cudaGetDevice(&id); \ fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ cudaGetErrorString(err_)); \ - fprintf(stderr, "current device: %d\n", dev_id); \ + fprintf(stderr, "current device: %d\n", id); \ exit(1); \ } \ } while (0) @@ -202,11 +198,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); do { \ cublasStatus_t err_ = (err); \ if (err_ != CUBLAS_STATUS_SUCCESS) { \ - int dev_id; \ - cudaGetDevice(&dev_id); \ + int id; \ + cudaGetDevice(&id); \ fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \ err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \ - fprintf(stderr, "current device: %d\n", dev_id); \ + fprintf(stderr, "current device: %d\n", id); \ exit(1); \ } \ } while (0) @@ -440,6 +436,8 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_MUL_BLOCK_SIZE 256 #define CUDA_GELU_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256 +#define CUDA_RELU_BLOCK_SIZE 256 +#define CUDA_SQR_BLOCK_SIZE 256 #define CUDA_CPY_BLOCK_SIZE 32 #define CUDA_SCALE_BLOCK_SIZE 256 #define CUDA_CLAMP_BLOCK_SIZE 256 @@ -472,7 +470,6 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA #define MAX_STREAMS 8 static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr }; -static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr }; struct ggml_tensor_extra_gpu { void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors @@ -561,6 +558,24 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) { dst[i] = x[i] / (1.0f + expf(-x[i])); } +static __global__ void relu_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = fmaxf(x[i], 0); +} + +static __global__ void sqr_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = x[i] * x[i]; +} + static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { @@ -990,7 
+1005,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - const int row = blockIdx.y*blockDim.y + threadIdx.y; + const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row > nrows) return; const int num_blocks_per_row = ncols / QK_K; @@ -1094,7 +1109,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - const int row = blockIdx.y*blockDim.y + threadIdx.y; + const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row > nrows) return; const int num_blocks_per_row = ncols / QK_K; @@ -1198,7 +1213,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - const int row = blockIdx.y*blockDim.y + threadIdx.y; + const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row > nrows) return; const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; @@ -1452,7 +1467,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - const int row = blockIdx.y*blockDim.y + threadIdx.y; + const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row > nrows) return; const int num_blocks_per_row = ncols / QK_K; @@ -4262,7 +4277,7 @@ template static __global__ void template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda> static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) { - const int row = blockIdx.y*blockDim.y + threadIdx.y; + const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row >= nrows) { return; @@ -4302,7 +4317,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel> static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { // qk = quantized weights per x block // qr = number of quantized weights per data value in x block - const int row = blockIdx.y*blockDim.y + threadIdx.y; + const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row >= nrows) { return; @@ -4741,7 +4756,7 @@ static __global__ void im2col_f32_f16( int ofs0, int ofs1, int IW, int IH, int CHW, int s0, int s1, int p0, int p1, int d0, int d1) { const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0; - const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1; + const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1; const int offset_dst = (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW + @@ -4793,6 +4808,16 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_ silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k); } +static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; + relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k); +} + +static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE; + sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k); +} + static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, 
cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { @@ -4901,7 +4926,8 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0> <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); @@ -4910,7 +4936,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1> <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); @@ -4919,7 +4945,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0> <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); @@ -4928,7 +4954,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1> <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); @@ -4937,7 +4963,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0> <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); @@ -4947,7 +4973,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f GGML_ASSERT(ncols % QK_K == 0); const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(32, ny, 1); dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); } @@ -4956,7 +4982,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, 
const float * y, f GGML_ASSERT(ncols % QK_K == 0); const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(32, ny, 1); dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); } @@ -4965,7 +4991,7 @@ static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, f GGML_ASSERT(ncols % QK_K == 0); const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(32, ny, 1); dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); } @@ -4980,7 +5006,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f GGML_ASSERT(ncols % QK_K == 0); const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(32, ny, 1); dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); } @@ -4988,7 +5014,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); @@ -4997,7 +5023,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK4_1 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); @@ -5006,7 +5032,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK5_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); @@ -5015,7 +5041,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK5_1 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); @@ -5024,7 +5050,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { 
GGML_ASSERT(ncols % QK8_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); @@ -5033,7 +5059,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); @@ -5042,7 +5068,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); @@ -5051,7 +5077,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); @@ -5060,7 +5086,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); @@ -5069,7 +5095,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); @@ -5088,7 +5114,7 @@ static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cu static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec<1, 1, convert_f16> <<<block_nums, block_dims, 0, stream>>>(vx, 
y, dst, ncols, nrows); @@ -5825,16 +5851,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { return ptr; } -static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) { - if (g_cudaMemPools[id] == nullptr) { - return ggml_cuda_pool_malloc(size, actual_size); - } - void *ptr; - CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream)); - *actual_size = size; - return ptr; -} - static void ggml_cuda_pool_free(void * ptr, size_t size) { scoped_spin_lock lock(g_cuda_pool_lock); int id; @@ -5852,12 +5868,10 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) { CUDA_CHECK(cudaFree(ptr)); } +static bool g_cublas_loaded = false; -static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) { - if (g_cudaMemPools[id] == nullptr) { - return ggml_cuda_pool_free(ptr, actual_size); - } - CUDA_CHECK(cudaFreeAsync(ptr, stream)); +bool ggml_cublas_loaded(void) { + return g_cublas_loaded; } void ggml_init_cublas() { @@ -5872,7 +5886,12 @@ void ggml_init_cublas() { CUDA_CHECK(cudaDeviceSynchronize()); #endif - CUDA_CHECK(cudaGetDeviceCount(&g_device_count)); + if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) { + initialized = true; + g_cublas_loaded = false; + return; + } + GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0; #if defined(GGML_CUDA_FORCE_MMQ) @@ -5914,19 +5933,13 @@ void ggml_init_cublas() { // create cublas handle CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id])); CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH)); - - // configure memory pool - cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id); - if (err == cudaSuccess) { - size_t treshold = UINT64_MAX; - CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold)); - } } // configure logging to stdout // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); initialized = true; + g_cublas_loaded = true; } } @@ -6193,6 +6206,34 @@ inline void ggml_cuda_op_silu( (void) src1_dd; } +inline void ggml_cuda_op_relu( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_sqr( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + inline void ggml_cuda_op_norm( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { @@ -6514,7 +6555,7 @@ inline void ggml_cuda_op_mul_mat_cublas( const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); GGML_ASSERT(to_fp16_cuda != nullptr); size_t ne = row_diff*ne00; - src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream); + src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as); to_fp16_cuda(src0_dd_i, 
src0_as_f16, ne, stream); } const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16; @@ -6525,12 +6566,12 @@ inline void ggml_cuda_op_mul_mat_cublas( const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); GGML_ASSERT(to_fp16_cuda != nullptr); size_t ne = src1_ncols*ne10; - src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream); + src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as); to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream); } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16; - size_t dst_f16_as = 0; - half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream); + size_t dst_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as); const half alpha_f16 = 1.0f; const half beta_f16 = 0.0f; @@ -6548,15 +6589,14 @@ inline void ggml_cuda_op_mul_mat_cublas( const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); - if (dst_f16_as != 0) { - ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream); - } + ggml_cuda_pool_free(dst_f16, dst_as); if (src0_as != 0) { - ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream); + ggml_cuda_pool_free(src0_as_f16, src0_as); } + if (src1_as != 0) { - ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream); + ggml_cuda_pool_free(src1_as_f16, src1_as); } } else { @@ -6566,7 +6606,7 @@ inline void ggml_cuda_op_mul_mat_cublas( if (src0->type != GGML_TYPE_F32) { const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); GGML_ASSERT(to_fp32_cuda != nullptr); - src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT + src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream); } const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32; @@ -6583,7 +6623,7 @@ inline void ggml_cuda_op_mul_mat_cublas( &beta, dst_dd_i, ldc)); if (src0_as != 0) { - ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream); + ggml_cuda_pool_free(src0_ddq_as_f32, src0_as); } } @@ -7008,6 +7048,8 @@ static void ggml_cuda_op_mul_mat( int64_t row_low[GGML_CUDA_MAX_DEVICES]; int64_t row_high[GGML_CUDA_MAX_DEVICES]; + int used_devices = 0; + for (int64_t id = 0; id < g_device_count; ++id) { // by default, use all rows row_low[id] = 0; @@ -7035,6 +7077,8 @@ static void ggml_cuda_op_mul_mat( continue; } + used_devices++; + const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device; const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; @@ -7045,22 +7089,21 @@ static void ggml_cuda_op_mul_mat( src0_dd[id] = (char *) src0_extra->data_device[id]; } else { const size_t size_src0_ddq = split ? 
(row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0); - src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream); + src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]); } if (src1_on_device && src1_is_contiguous) { src1_ddf[id] = (float *) src1_extra->data_device[id]; } else { - src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream); + src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]); } if (convert_src1_to_q8_1) { - const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; - src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream); + src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]); if (src1_on_device && src1_is_contiguous) { quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream); - // CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaGetLastError()); } } @@ -7068,18 +7111,18 @@ static void ggml_cuda_op_mul_mat( dst_dd[id] = (float *) dst_extra->data_device[id]; } else { const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst); - dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream); + dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]); } } // if multiple devices are used they need to wait for the main device // here an event is recorded that signals that the main device has finished calculating the input data - if (split && g_device_count > 1) { + if (split && used_devices > 1) { CUDA_CHECK(ggml_cuda_set_device(g_main_device)); CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0])); } - const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; + const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) { const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0; const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? 
ne11 - src1_col_0 : src1_col_stride; @@ -7194,6 +7237,27 @@ static void ggml_cuda_op_mul_mat( } } + for (int64_t id = 0; id < g_device_count; ++id) { + if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { + continue; + } + CUDA_CHECK(ggml_cuda_set_device(id)); + + // free buffers again when done + if (src0_as[id] > 0) { + ggml_cuda_pool_free(src0_dd[id], src0_as[id]); + } + if (src1_asf[id] > 0) { + ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]); + } + if (src1_asq[id] > 0) { + ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]); + } + if (dst_as[id] > 0) { + ggml_cuda_pool_free(dst_dd[id], dst_as[id]); + } + } + // main device waits for all other devices to be finished if (split && g_device_count > 1) { int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; @@ -7201,6 +7265,9 @@ static void ggml_cuda_op_mul_mat( CUDA_CHECK(ggml_cuda_set_device(g_main_device)); for (int64_t id = 0; id < g_device_count; ++id) { + if (row_low[id] == row_high[id]) { + continue; + } for (int64_t is = 0; is < is_max; ++is) { CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0)); } @@ -7211,21 +7278,6 @@ static void ggml_cuda_op_mul_mat( CUDA_CHECK(ggml_cuda_set_device(g_main_device)); CUDA_CHECK(cudaDeviceSynchronize()); } - - for (int64_t id = 0; id < g_device_count; ++id) { - if (src0_as[id] > 0) { - ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]); - } - if (src1_asf[id] > 0) { - ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]); - } - if (src1_asq[id] > 0) { - ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]); - } - if (dst_as[id] > 0) { - ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]); - } - } } static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -7252,6 +7304,14 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu); } +static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu); +} + +static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr); +} + static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm); } @@ -7261,6 +7321,8 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src } bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + if (!g_cublas_loaded) return false; + const int64_t ne10 = src1->ne[0]; const int64_t ne0 = dst->ne[0]; @@ -7412,11 +7474,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const GGML_ASSERT(to_fp16_cuda != nullptr); size_t src1_as = 0; - half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream); + half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as); to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); size_t dst_as = 0; - half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream); + half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % 
ne03 == 0); @@ -7470,8 +7532,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const size_t ptrs_src_s = 0; size_t ptrs_dst_s = 0; - ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream); - ptrs_dst = ( void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream); + ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s); + ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s); dim3 block_dims(ne13, ne12); k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( @@ -7484,6 +7546,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const dst->nb[2], dst->nb[3], r2, r3); CUDA_CHECK(cudaGetLastError()); + CUBLAS_CHECK( cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, @@ -7495,30 +7558,29 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const CUBLAS_GEMM_DEFAULT_TENSOR_OP)); if (ptrs_src_s != 0) { - ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream); + ggml_cuda_pool_free(ptrs_src, ptrs_src_s); } if (ptrs_dst_s != 0) { - ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream); + ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s); } } #endif const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); - if (src1_as != 0) { - ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream); - } - if (dst_as != 0) { - ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream); - } + + ggml_cuda_pool_free(src1_as_f16, src1_as); + ggml_cuda_pool_free(dst_f16, dst_as); } static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool all_on_device = - (src0->backend == GGML_BACKEND_GPU) && + (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && (src1->backend == GGML_BACKEND_GPU) && ( dst->backend == GGML_BACKEND_GPU); + const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; + int64_t min_compute_capability = INT_MAX; for (int64_t id = 0; id < g_device_count; ++id) { if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? 
g_tensor_split[id + 1] : 1.0f)) { @@ -7540,13 +7602,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { // KQ single-batch ggml_cuda_mul_mat_vec_p021(src0, src1, dst); - } else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + } else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch ggml_cuda_mul_mat_vec_nc(src0, src1, dst); - } else if (all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) { + } else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) { // KQ + KQV multi-batch ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst); } else if (src0->type == GGML_TYPE_F32) { @@ -7667,7 +7729,7 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi); } -void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col); } @@ -7782,11 +7844,11 @@ static size_t g_temp_tensor_extra_index = 0; static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { if (g_temp_tensor_extras == nullptr) { - g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE]; + g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES]; } size_t alloc_index = g_temp_tensor_extra_index; - g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE; + g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES; ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index]; memset(extra, 0, sizeof(*extra)); @@ -7953,6 +8015,8 @@ void ggml_cuda_free_scratch() { } bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { + if (!g_cublas_loaded) return false; + ggml_cuda_func_t func; const bool any_on_device = tensor->backend == GGML_BACKEND_GPU || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) @@ -7995,6 +8059,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_UNARY_OP_SILU: func = ggml_cuda_silu; break; + case GGML_UNARY_OP_RELU: + func = ggml_cuda_relu; + break; default: return false; } break; @@ -8013,6 +8080,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case 
GGML_OP_SCALE: func = ggml_cuda_scale; break; + case GGML_OP_SQR: + func = ggml_cuda_sqr; + break; case GGML_OP_CLAMP: if (!any_on_device) { return false; @@ -8105,11 +8175,11 @@ struct ggml_backend_buffer_context_cuda { ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { if (temp_tensor_extras == nullptr) { - temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE]; + temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES]; } size_t alloc_index = temp_tensor_extra_index; - temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE; + temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES; ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index]; memset(extra, 0, sizeof(*extra)); diff --git a/ggml-cuda.h b/ggml-cuda.h index 57adc9c..528e66c 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -17,7 +17,12 @@ extern "C" { #define GGML_CUDA_MAX_DEVICES 16 +// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`. GGML_API void ggml_init_cublas(void); + +// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`. +GGML_API bool ggml_cublas_loaded(void); + GGML_API void * ggml_cuda_host_malloc(size_t size); GGML_API void ggml_cuda_host_free(void * ptr); diff --git a/whisper.cpp b/whisper.cpp index c0e9115..a3e0fbd 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -20,6 +20,7 @@ #include "ggml-alloc.h" #include "ggml-backend.h" +#include #include #include #define _USE_MATH_DEFINES @@ -147,7 +148,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text //#define WHISPER_USE_FLASH_ATTN //#define WHISPER_USE_FLASH_FF -#define WHISPER_MAX_DECODERS 16 +#define WHISPER_MAX_DECODERS 8 #define WHISPER_MAX_NODES 4096 // @@ -406,6 +407,121 @@ struct whisper_segment { bool speaker_turn_next; }; +struct whisper_batch { + int32_t n_tokens; + + whisper_token * token; + whisper_pos * pos; + int32_t * n_seq_id; + whisper_seq_id ** seq_id; // null terminated + int8_t * logits; +}; + +static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) { + whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, }; + + batch.token = (whisper_token * ) malloc(sizeof(whisper_token) * (n_tokens)); + batch.pos = (whisper_pos *) malloc(sizeof(whisper_pos) * (n_tokens)); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.seq_id = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) { + batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max); + } + batch.seq_id[n_tokens] = nullptr; + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +static void whisper_batch_free(struct whisper_batch batch) { + if (batch.token) free(batch.token); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i]; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} + +static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) { + batch.n_tokens = n_tokens; + for (int i = 0; i < n_tokens; ++i) { + if (tokens) { + batch.token[i] = tokens[i]; + } + batch.pos [i] = n_past + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i][0] = seq_id; + batch.logits [i] = 0; + } + batch.logits[n_tokens - 1] = 1; +} + +// 
replace std::pair by using customized pair struct (reason: std::pair is very slow) template<typename A, typename B> struct whisper_pair { A first; B second; // Define a constructor that takes two arguments. whisper_pair(const A& a, const B& b) : first(a), second(b) {} // Define a constructor that takes no argument. whisper_pair() : first(A()), second(B()) {} }; // ggml_allocr wrapper for whisper usage struct whisper_allocr { ggml_allocr * alloc = nullptr; std::vector<uint8_t> meta; ggml_backend_buffer_t buffer; }; static size_t whisper_allocr_size(struct whisper_allocr & allocr) { return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc); } // measure the memory usage of a graph and prepare the allocr's internal data buffer static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) { auto & alloc = allocr.alloc; auto & meta = allocr.meta; alloc = ggml_allocr_new_measure_from_backend(backend); meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead()); ggml_allocr_alloc_graph(alloc, get_graph()); } static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) { if (allocr.alloc == nullptr) { // this can be null if we use external encoder like CoreML or OpenVINO return; } auto & alloc = allocr.alloc; auto & buffer = allocr.buffer; size_t size = ggml_allocr_max_size(alloc); ggml_allocr_free(alloc); buffer = ggml_backend_alloc_buffer(backend, size); alloc = ggml_allocr_new_from_buffer(buffer); } static void whisper_allocr_free(struct whisper_allocr & allocr) { if (allocr.alloc) { ggml_allocr_free(allocr.alloc); ggml_backend_buffer_free(allocr.buffer); allocr.alloc = nullptr; } } // medium // hparams: { // 'n_mels': 80, @@ -523,15 +639,31 @@ struct whisper_layer_decoder { struct ggml_tensor * mlp_1_b; }; +struct whisper_kv_cell { + whisper_pos pos = -1; + + std::set<whisper_seq_id> seq_id; + + bool has_seq_id(const whisper_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } +}; + struct whisper_kv_cache { + uint32_t head = 0; + uint32_t size = 0; + + // computed before each graph build + uint32_t n = 0; + + std::vector<whisper_kv_cell> cells; + struct ggml_tensor * k; struct ggml_tensor * v; struct ggml_context * ctx; ggml_backend_buffer_t buffer; - - int n; // number of tokens currently in the cache }; struct whisper_model { @@ -585,11 +717,11 @@ struct whisper_partial_utf8 { }; struct whisper_grammar { - /*const*/ std::vector<std::vector<whisper_grammar_element>> rules; - std::vector<std::vector<const whisper_grammar_element *>> stacks; // buffer for partially generated UTF-8 sequence from accepted tokens - whisper_partial_utf8 partial_utf8; + /*const*/ std::vector<std::vector<whisper_grammar_element>> rules; + std::vector<std::vector<const whisper_grammar_element *>> stacks; // buffer for partially generated UTF-8 sequence from accepted tokens + whisper_partial_utf8 partial_utf8; }; struct whisper_grammar_candidate { @@ -613,15 +745,13 @@ struct whisper_sequence { // TAGS: WHISPER_DECODER_INIT struct whisper_decoder { - // each decoder keeps its own KV-cache - whisper_kv_cache kv_self; - // the currently generated sequence of tokens whisper_sequence sequence; // grammar parse state of generated sequence of tokens whisper_grammar grammar; + int i_batch; // the index of the token in the current batch int seek_delta; // the window shift found so far based on the decoded timestamp tokens bool failed; // has the current segment failed to decode? 
@@ -633,100 +763,40 @@ struct whisper_decoder { std::vector<float> logits; std::vector<float> logprobs; - std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls + // work container used to avoid memory allocations + std::vector<whisper_pair<double, whisper_vocab::id>> logits_id; + + mutable std::mt19937 rng; // used for sampling at t > 0.0 }; -// replace std::pair by using customized pair struct (reason: std::pair is very slow) -template<typename A, typename B> -struct whisper_pair { - A first; - B second; - - // Define a constructor that takes two arguments. - whisper_pair(const A& a, const B& b) : first(a), second(b) {} - // Define a constructor that takes no argument. - whisper_pair() : first(A()), second(B()) {} -}; - -// beam-search helpers -struct kv_buf { - std::vector<uint8_t> k; - std::vector<uint8_t> v; -}; - -// ggml_allocr wrapper for whisper usage -struct whisper_allocr { - ggml_allocr * alloc = nullptr; - - std::vector<uint8_t> meta; - - ggml_backend_buffer_t buffer; -}; - -static size_t whisper_allocr_size(struct whisper_allocr & allocr) { - return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc); -} - -// measure the memory usage of a graph and prepare the allocr's internal data buffer -static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) { - auto & alloc = allocr.alloc; - auto & meta = allocr.meta; - - alloc = ggml_allocr_new_measure_from_backend(backend); - - meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead()); - - ggml_allocr_alloc_graph(alloc, get_graph()); -} - -static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) { - if (allocr.alloc == nullptr) { - // this can be null if we use external encoder like CoreML or OpenVINO - return; - } - - auto & alloc = allocr.alloc; - auto & buffer = allocr.buffer; - - size_t size = ggml_allocr_max_size(alloc); - - ggml_allocr_free(alloc); - - buffer = ggml_backend_alloc_buffer(backend, size); - alloc = ggml_allocr_new_from_buffer(buffer); -} - -static void whisper_allocr_free(struct whisper_allocr & allocr) { - if (allocr.alloc) { - ggml_allocr_free(allocr.alloc); - ggml_backend_buffer_free(allocr.buffer); - allocr.alloc = nullptr; - } -} - struct whisper_state { int64_t t_sample_us = 0; int64_t t_encode_us = 0; int64_t t_decode_us = 0; + int64_t t_batchd_us = 0; int64_t t_prompt_us = 0; int64_t t_mel_us = 0; int32_t n_sample = 0; // number of tokens sampled int32_t n_encode = 0; // number of encoder calls - int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) - int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding) + int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) + int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding) + int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding) int32_t n_fail_p = 0; // number of logprob threshold failures int32_t n_fail_h = 0; // number of entropy threshold failures + // unified self-attention KV cache for all decoders + whisper_kv_cache kv_self; + // cross-attention KV cache for the decoders // shared between all decoders whisper_kv_cache kv_cross; + whisper_mel mel; - whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; + whisper_batch batch; - // buffer for swapping KV caches between decoders during beam-search - std::vector<kv_buf> kv_swap_bufs; + whisper_decoder decoders[WHISPER_MAX_DECODERS]; ggml_backend_t backend = nullptr; @@ -742,8 +812,9 @@ struct whisper_state { struct ggml_tensor * embd_conv = 
nullptr; struct ggml_tensor * embd_enc = nullptr; - // helper for GPU offloading + // helpers for GPU offloading std::vector<float> inp_mel; + std::vector<float> inp_mask; // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector<float> logits; std::vector<whisper_segment> result_all; std::vector<whisper_token> prompt_past; - // work container used to avoid memory allocations - std::vector<whisper_pair<double, whisper_vocab::id>> logits_id; - - mutable std::mt19937 rng; // used for sampling at t > 0.0 - int lang_id = 0; // english by default std::string path_model; // populated by whisper_init_from_file_with_params() @@ -831,48 +897,11 @@ static bool kv_cache_init( /*.no_alloc =*/ true, }; - cache.ctx = ggml_init(params); + cache.head = 0; + cache.size = n_ctx; - if (!cache.ctx) { - WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__); - return false; - } - - cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); - cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); - - const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v); - - cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes); - - // allocate the tensors into the backend buffer - { - ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer); - - ggml_allocr_alloc(alloc, cache.k); - ggml_allocr_alloc(alloc, cache.v); - - ggml_allocr_free(alloc); - } - - return true; -} - -// TODO: remove after batched decoding -static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) { - WHISPER_ASSERT(cache.ctx); - - const int n_elements = ggml_nelements(cache.k); - WHISPER_ASSERT(n_elements == ggml_nelements(cache.v)); - - const ggml_type wtype = cache.k->type; - WHISPER_ASSERT(wtype == cache.v->type); - - struct ggml_init_params params = { - /*.mem_size =*/ 2*ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; + cache.cells.clear(); + cache.cells.resize(n_ctx); cache.ctx = ggml_init(params); @@ -909,12 +938,130 @@ static void kv_cache_free(struct whisper_kv_cache & cache) { } } +static bool whisper_kv_cache_find_slot( + struct whisper_kv_cache & cache, + const struct whisper_batch & batch) { + const uint32_t n_ctx = cache.size; + const uint32_t n_tokens = batch.n_tokens; + + if (n_tokens > n_ctx) { + WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); + return false; + } + + uint32_t n_tested = 0; + + while (true) { + if (cache.head + n_tokens > n_ctx) { + n_tested += n_ctx - cache.head; + cache.head = 0; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cache.cells[cache.head + i].pos >= 0) { + found = false; + cache.head += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= n_ctx) { + //WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; + } + } + + for (uint32_t i = 0; i < n_tokens; i++) { + cache.cells[cache.head + i].pos = batch.pos[i]; + + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); + } + } + + return true; +} + +// find how many cells are currently in use +static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) { + for (uint32_t i = cache.size - 1; i > 0; --i) { + if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) { + return i + 1; + } + } + + return 1; +} + +static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) { + for (int32_t i = 0; i < (int32_t) cache.size; ++i) { 
+static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
+    for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
+        cache.cells[i].pos = -1;
+        cache.cells[i].seq_id.clear();
+    }
+    cache.head = 0;
+}
+
+static void whisper_kv_cache_seq_rm(
+        struct whisper_kv_cache & cache,
+                 whisper_seq_id   seq_id,
+                    whisper_pos   p0,
+                    whisper_pos   p1) {
+    uint32_t new_head = cache.size;
+
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            if (seq_id < 0) {
+                cache.cells[i].seq_id.clear();
+            } else if (cache.cells[i].has_seq_id(seq_id)) {
+                cache.cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
+            if (cache.cells[i].seq_id.empty()) {
+                cache.cells[i].pos = -1;
+                if (new_head == cache.size) new_head = i;
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
+}
+
+static void whisper_kv_cache_seq_cp(
+        struct whisper_kv_cache & cache,
+                 whisper_seq_id   seq_id_src,
+                 whisper_seq_id   seq_id_dst,
+                    whisper_pos   p0,
+                    whisper_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+    cache.head = 0;
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            cache.cells[i].seq_id.insert(seq_id_dst);
+        }
+    }
+}
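These two helpers are the whole sequence API the decoder needs, and neither one touches the K/V tensors; they only edit cell metadata. A hedged usage sketch (the sequence numbers are made up for illustration):

    // drop everything decoder 2 wrote at positions >= 10, e.g. when rewinding
    whisper_kv_cache_seq_rm(kv_self, 2, 10, -1);

    // fork decoder 0's whole context into decoder 3 (zero tensor copies)
    whisper_kv_cache_seq_cp(kv_self, 0, 3, -1, -1);

    // free every cell that only sequence 3 owned
    whisper_kv_cache_seq_rm(kv_self, 3, -1, -1);

Negative p0/p1 mean "from the start" / "to the end", matching the llama.cpp conventions this code is synced from.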
 static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
     ggml_backend_t backend_gpu = NULL;
 
     // initialize the backends
 #ifdef GGML_USE_CUBLAS
-    if (params.use_gpu) {
+    if (params.use_gpu && ggml_cublas_loaded()) {
         WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
         backend_gpu = ggml_backend_cuda_init();
         if (!backend_gpu) {
@@ -1116,6 +1263,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 word = "[_EOT_]";
             } else if (i == vocab.token_sot) {
                 word = "[_SOT_]";
+            } else if (i == vocab.token_translate) {
+                word = "[_TRANSLATE_]";
+            } else if (i == vocab.token_transcribe) {
+                word = "[_TRANSCRIBE_]";
             } else if (i == vocab.token_solm) {
                 word = "[_SOLM_]";
             } else if (i == vocab.token_prev) {
@@ -1126,6 +1277,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 word = "[_NOT_]";
             } else if (i == vocab.token_beg) {
                 word = "[_BEG_]";
+            } else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
+                word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
             } else {
                 word = "[_extra_token_" + std::to_string(i) + "]";
             }
@@ -2031,26 +2184,28 @@ static bool whisper_encode_internal(
 
 static struct ggml_cgraph * whisper_build_graph_decoder(
          whisper_context & wctx,
          whisper_state   & wstate,
-         whisper_decoder & decoder,
-     const whisper_token * tokens,
-                      int   n_tokens,
-                      int   n_past) {
+     const whisper_batch & batch) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
-    auto & kv_self = decoder.kv_self;
+    auto & kv_self = wstate.kv_self;
 
     WHISPER_ASSERT(!!kv_self.ctx);
 
-    const int n_ctx   = hparams.n_text_ctx;
+    ggml_allocr * alloc = wstate.alloc_decode.alloc;
+
+    const int n_ctx   = kv_self.size;
     const int n_state = hparams.n_text_state;
     const int n_head  = hparams.n_text_head;
    const int n_layer = hparams.n_text_layer;
 
-    const int N = n_tokens;
-    const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
+    const int n_tokens    = batch.n_tokens;
+    const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
 
-    //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
+    const int32_t n_kv    = ggml_allocr_is_measure(alloc) ? n_ctx            : kv_self.n;
+    const int32_t kv_head = ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
+
+    //WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
 
     struct ggml_init_params params = {
         /*.mem_size   =*/ wstate.alloc_decode.meta.size(),
@@ -2062,21 +2217,19 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
 
-    ggml_allocr * alloc = wstate.alloc_decode.alloc;
-
-    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
     ggml_allocr_alloc(alloc, embd);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd));
+        ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*ggml_element_size(embd));
     }
 
-    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
     ggml_allocr_alloc(alloc, position);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        for (int i = 0; i < N; ++i) {
-            const int32_t val = n_past + i;
+        for (int i = 0; i < n_tokens; ++i) {
+            const int32_t val = batch.pos[i];
             ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
         }
     }
@@ -2089,6 +2242,31 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
         ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
     }
 
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+    ggml_allocr_alloc(alloc, KQ_mask);
+
+    if (!ggml_allocr_is_measure(alloc)) {
+        wstate.inp_mask.resize(n_kv*n_tokens);
+
+        float * data = wstate.inp_mask.data();
+        memset(data, 0, ggml_nbytes(KQ_mask));
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                const whisper_pos    pos    = batch.pos[j];
+                const whisper_seq_id seq_id = batch.seq_id[j][0];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+            }
+        }
+
+        ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
+    }
+
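Row j of KQ_mask is the attention mask for batch token j: a KV cell stays visible only if it belongs to the same sequence and holds a position <= pos[j]; every other cell gets -INFINITY added to its attention score before the softmax. For a single sequence this reduces to the familiar causal mask. A minimal illustration (a hypothetical 4-cell cache, one sequence at positions 0..2, new token at pos 3):

    // cells:     pos =    0     1     2    -1 (free)
    // mask row j:         0     0     0  -inf

so the batched mask strictly generalizes the ggml_diag_mask_inf() path it replaces in the next hunk.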
     // token encoding + position encoding
     struct ggml_tensor * cur =
         ggml_add(ctx0,
@@ -2141,12 +2319,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         Vcur,
                         layer.attn_v_b);
 
-            Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
+            Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
 
-            struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
-            struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
+            struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
+            struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
                     (   n_ctx)*ggml_element_size(kv_self.v),
-                    (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
+                    (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
 
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
@@ -2156,12 +2334,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_state/n_head, n_past + N, n_head,
+                        n_state/n_head, n_kv, n_head,
                         ggml_element_size(kv_self.k)*n_state,
                         ggml_element_size(kv_self.k)*n_state/n_head,
                         ggml_element_size(kv_self.k)*n_state*n_ctx*il);
@@ -2171,16 +2349,17 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
 
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
+            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
+            struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask);
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_past + N, n_state/n_head, n_head,
+                        n_kv, n_state/n_head, n_head,
                         n_ctx*ggml_element_size(kv_self.v),
                         n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
-                        il*n_ctx*ggml_element_size(kv_self.v)*n_state);
+                        n_ctx*ggml_element_size(kv_self.v)*n_state*il);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
@@ -2188,7 +2367,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             cur = ggml_cpy(ctx0,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
         }
 
         // projection
@@ -2232,33 +2411,33 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
                 ggml_view_3d(ctx0, wstate.kv_cross.k,
-                        n_state/n_head, M, n_head,
+                        n_state/n_head, n_audio_ctx, n_head,
                         ggml_element_size(wstate.kv_cross.k)*n_state,
                         ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
-                        ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
+                        ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
 
             //struct ggml_tensor * Vcross =
             //    ggml_reshape_3d(ctx0,
-            //            ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
-            //            n_state/n_head, n_head, M);
+            //            ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
+            //            n_state/n_head, n_head, n_audio_ctx);
 
             //struct ggml_tensor * V_trans =
             //    ggml_cpy(ctx0,
             //            ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
-            //            ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
+            //            ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, wstate.kv_cross.v,
-                        M, n_state/n_head, n_head,
-                        M*ggml_element_size(wstate.kv_cross.v),
-                        M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
-                        il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
+                        n_audio_ctx, n_state/n_head, n_head,
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
 
             // ------
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
             // K * Q
@@ -2279,10 +2458,10 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            // cur = KQV_merged.contiguous().view(n_state, N)
+            // cur = KQV_merged.contiguous().view(n_state, n_tokens)
             cur = ggml_cpy(ctx0,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
         }
 
         // projection
@@ -2354,9 +2533,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     }
 
     // compute logits only for the last token
-    // comment this line to compute logits for all N tokens
+    // comment this line to compute logits for all n_tokens
     // might be useful in the future
-    cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
+    //cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
@@ -2380,10 +2559,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 static bool whisper_decode_internal(
         whisper_context & wctx,
           whisper_state & wstate,
-        whisper_decoder & decoder,
-    const whisper_token * tokens,
-              const int   n_tokens,
-              const int   n_past,
+    const whisper_batch & batch,
               const int   n_threads,
  whisper_abort_callback   abort_callback,
                    void * abort_callback_data) {
@@ -2392,19 +2568,33 @@ static bool whisper_decode_internal(
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
-    const int n_vocab = hparams.n_vocab;
+    const int n_vocab  = hparams.n_vocab;
+    const int n_tokens = batch.n_tokens;
 
     auto & logits_out = wstate.logits;
 
     struct ggml_tensor * logits;
 
+    // find KV slot for the batch
+    {
+        auto & kv_self = wstate.kv_self;
+
+        if (!whisper_kv_cache_find_slot(kv_self, batch)) {
+            return false;
+        }
+
+        kv_self.n = whisper_kv_cache_cell_max(kv_self);
+        //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
+        //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
+    }
+
     // decoder
     {
         auto & alloc = wstate.alloc_decode.alloc;
 
         ggml_allocr_reset(alloc);
 
-        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
+        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch);
 
         ggml_allocr_alloc_graph(alloc, gf);
@@ -2413,17 +2603,15 @@ static bool whisper_decode_internal(
         ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
-    // extract logits for all N tokens
-    //logits_out.resize(n_tokens*n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
-    //ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab);
+    logits_out.resize(n_tokens*n_vocab);
+    for (int i = 0; i < n_tokens; i++) {
+        if (batch.logits[i] == 0) {
+            continue;
+        }
+        ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
+    }
 
-    // extract logits only for the last token
-    logits_out.resize(n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
-    ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
-
-    if (n_tokens > 1) {
+    if (batch.n_tokens > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
         //        ggml_used_mem(ctx0)/1024.0/1024.0,
         //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
@@ -2432,18 +2620,20 @@ static bool whisper_decode_internal(
         //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
     }
 
-    if (n_tokens == 1) {
+    if (batch.n_tokens == 1) {
         wstate.t_decode_us += ggml_time_us() - t_start_us;
         wstate.n_decode++;
+    } else if (batch.n_tokens < 16) {
+        wstate.t_batchd_us += ggml_time_us() - t_start_us;
+        wstate.n_batchd += n_tokens;
     } else {
         wstate.t_prompt_us += ggml_time_us() - t_start_us;
-        wstate.n_prompt++;
+        wstate.n_prompt += n_tokens;
     }
 
     return !(abort_callback && abort_callback(abort_callback_data));
 }
-
 // 500 -> 00:05.000
 // 6000 -> 01:00.000
 static std::string to_timestamp(int64_t t, bool comma = false) {
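With the batch API, one decoder call now serves all live beams. The flow inside whisper_decode_internal is: reserve KV cells for the batch, build and run the graph, then copy back only the logit rows that were requested. A condensed sketch of the caller's side (names as in the patch, error handling elided):

    // sketch: evaluate one token per live decoder in a single call
    whisper_batch & batch = state->batch;
    // ... fill batch.token/pos/seq_id and set batch.logits[i] = 1 where output is needed ...
    if (!whisper_decode_internal(*ctx, *state, batch, n_threads, nullptr, nullptr)) {
        // decode failed (e.g. no free KV slot)
    }
    // row i of state->logits now holds n_vocab logits for batch token i

Note that logits_out keeps its full [n_tokens][n_vocab] shape; rows with batch.logits[i] == 0 are simply left untouched rather than compacted away.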
@@ -2855,14 +3045,18 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->backend = whisper_backend_init(ctx->params);
 
-    if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
+    // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
+    // in theory, there can be a case where this is not enough, but in practice it should always be enough
+    const int factor = 3;
+
+    if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
         return nullptr;
     }
 
     {
-        const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v);
+        const size_t memory_size = ggml_nbytes(state->kv_self.k) + ggml_nbytes(state->kv_self.v);
         WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
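Sizing the unified cache is now a policy decision instead of a per-decoder allocation. A back-of-the-envelope check of factor = 3, assuming Whisper's n_text_ctx of 448 and the cap of roughly n_text_ctx/2 tokens per segment that whisper.cpp enforces elsewhere:

    // rough worst case for one segment with the default of 5 decoders:
    //   prompt cells (shared via seq_cp, stored once)   <= 448/2        =  224
    //   generated cells, 5 sequences, nothing shared    <= 5 * 448/2    = 1120
    //   total                                           <= 1344 = 3*448 cells

which is exactly the "in theory not enough, in practice always enough" trade-off the comment above describes: pathological sharing patterns could exceed it, but the defaults cannot.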
@@ -2897,14 +3091,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
 
-    state->logits_id.reserve(ctx->model.hparams.n_vocab);
+    state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
 
     // TAGS: WHISPER_DECODER_INIT
     state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
 
-    state->decoders[0].probs.reserve   (ctx->vocab.n_vocab);
-    state->decoders[0].logits.reserve  (ctx->vocab.n_vocab);
-    state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
+    state->decoders[0].probs.reserve    (ctx->vocab.n_vocab);
+    state->decoders[0].logits.reserve   (ctx->vocab.n_vocab);
+    state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab);
+    state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab);
+
+    state->decoders[0].rng = std::mt19937(0);
 
     // conv allocator
     {
@@ -2946,7 +3143,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
             const int n_tokens = hparams.n_text_ctx;
             const int n_past   = 0;
 
-            return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
+            whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
+
+            return whisper_build_graph_decoder(*ctx, *state, state->batch);
         });
 
         WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
@@ -2957,8 +3156,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     whisper_allocr_graph_realloc(state->alloc_cross,  ctx->backend);
     whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
 
-    state->rng = std::mt19937(0);
-
     return state;
 }
 
@@ -3183,12 +3380,9 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
 
 void whisper_free_state(struct whisper_state * state)
 {
     if (state) {
+        kv_cache_free(state->kv_self);
         kv_cache_free(state->kv_cross);
 
-        for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) {
-            kv_cache_free(state->decoders[i].kv_self);
-        }
-
 #ifdef WHISPER_USE_COREML
         if (state->ctx_coreml != nullptr) {
             whisper_coreml_free(state->ctx_coreml);
@@ -3203,6 +3397,8 @@ void whisper_free_state(struct whisper_state * state)
     }
 #endif
 
+    whisper_batch_free(state->batch);
+
     whisper_allocr_free(state->alloc_conv);
     whisper_allocr_free(state->alloc_encode);
     whisper_allocr_free(state->alloc_cross);
@@ -3329,9 +3525,11 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 }
 
 int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    const int selected_decoder_id = 0;
+    whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
 
-    if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
+    whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1);
+
+    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -3340,15 +3538,16 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
 }
 
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
-    // TODO: add selected_decoder_id to state
-    const int selected_decoder_id = 0;
-
     if (ctx->state == nullptr) {
         WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
         return false;
     }
 
-    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
+    whisper_kv_cache_seq_rm(ctx->state->kv_self, 0, n_past, -1);
+
+    whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past, 0);
+
+    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->batch, n_threads, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -3436,7 +3635,7 @@ int whisper_lang_auto_detect_with_state(
         return -7;
     }
 
-    auto & logits_id = state->logits_id;
+    auto & logits_id = state->decoders[0].logits_id;
     logits_id.clear();
 
     for (const auto & kv : g_lang) {
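The old whisper_decode()/whisper_decode_with_state() entry points survive as thin wrappers: they drop sequence 0's cells at positions >= n_past, pack the caller's token array into the one reusable batch, and dispatch to the batched path. whisper_batch_prep_legacy itself is defined in an earlier part of the patch; its behavior is roughly the following (a hedged sketch under that assumption, not the patch's exact code):

    // sketch: emulate the old (tokens, n_tokens, n_past) interface with a batch
    static void batch_prep_legacy_sketch(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) {
        batch.n_tokens = n_tokens;
        for (int i = 0; i < n_tokens; ++i) {
            if (tokens) {
                batch.token[i] = tokens[i];
            }
            batch.pos     [i]    = n_past + i; // consecutive positions, as before
            batch.n_seq_id[i]    = 1;
            batch.seq_id  [i][0] = seq_id;     // everything lives in one sequence
            batch.logits  [i]    = 0;
        }
        batch.logits[n_tokens - 1] = 1;        // the old API only ever read the last row
    }

Everything downstream then treats a legacy call like any other single-sequence batch.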
@@ -3639,6 +3838,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
         const int32_t n_sample = std::max(1, ctx->state->n_sample);
         const int32_t n_encode = std::max(1, ctx->state->n_encode);
         const int32_t n_decode = std::max(1, ctx->state->n_decode);
+        const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
         const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
 
         WHISPER_LOG_INFO("%s:     fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
@@ -3646,6 +3846,7 @@
         WHISPER_LOG_INFO("%s:   sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
         WHISPER_LOG_INFO("%s:   encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
         WHISPER_LOG_INFO("%s:   decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
+        WHISPER_LOG_INFO("%s:   batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
         WHISPER_LOG_INFO("%s:   prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
     }
     WHISPER_LOG_INFO("%s:    total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
@@ -3662,6 +3863,7 @@ void whisper_reset_timings(struct whisper_context * ctx) {
         ctx->state->n_sample = 0;
         ctx->state->n_encode = 0;
         ctx->state->n_decode = 0;
+        ctx->state->n_batchd = 0;
         ctx->state->n_prompt = 0;
     }
 }
@@ -3969,8 +4171,7 @@ static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_
             if (*tok.code_points == 0) {
                 // reached end of full codepoints in token, reject iff it ended in a partial sequence
                 // that cannot satisfy this position in grammar
-                if (tok.partial_utf8.n_remain != 0 &&
-                        !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
+                if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
                     rejects.push_back(tok);
                 }
             } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) {
@@ -4189,7 +4390,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.max_initial_ts   =*/  1.0f,
         /*.length_penalty   =*/ -1.0f,
 
-        /*.temperature_inc  =*/  0.4f,
+        /*.temperature_inc  =*/  0.2f,
         /*.entropy_thold    =*/  2.4f,
         /*.logprob_thold    =*/ -1.0f,
         /*.no_speech_thold  =*/  0.6f,
@@ -4229,13 +4430,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         case WHISPER_SAMPLING_GREEDY:
             {
                 result.greedy = {
-                    /*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
+                    /*.best_of =*/ 5,
                 };
             } break;
         case WHISPER_SAMPLING_BEAM_SEARCH:
             {
                 result.beam_search = {
-                    /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
+                    /*.beam_size =*/ 5,
 
                     /*.patience =*/ -1.0f,
                 };
@@ -4325,11 +4526,12 @@ static const std::vector<std::string> non_speech_tokens = {
 // process the logits for the selected decoder
 // - applies logit filters
 // - computes logprobs and probs
+// TODO: optimize
 static void whisper_process_logits(
         struct whisper_context & ctx,
          struct whisper_state  & state,
-        const struct whisper_full_params   params,
               struct whisper_decoder & decoder,
+        const struct whisper_full_params   params,
                                  float   temperature) {
     const auto & vocab      = ctx.vocab;
     const auto & tokens_cur = decoder.sequence.tokens;
@@ -4346,7 +4548,7 @@ static void whisper_process_logits(
     auto & logprobs = decoder.logprobs;
     {
         logits.resize(n_logits);
-        memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
+        memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));
 
         if (temperature > 0.0f) {
             for (int i = 0; i < n_logits; i++) {
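Because the batched decoder returns one logits row per batch entry, each decoder now records its row index (decoder.i_batch) when the batch is assembled, and whisper_process_logits reads its own slice of the shared buffer instead of assuming "last row":

    // sketch: per-decoder view into the shared [n_tokens][n_vocab] logits buffer
    const float * row = state.logits.data() + decoder.i_batch * n_logits;

For the prompt pass there is a single row of interest, so decoder 0 sets i_batch = prompt.size() - 1 (see the whisper_full_with_state hunks further down).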
@@ -4512,30 +4714,31 @@ static void whisper_process_logits(
         //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
 
         if (timestamp_logprob > max_text_token_logprob) {
-            //printf("sampling timestamp\n");
             for (int i = 0; i < vocab.token_beg; ++i) {
                 logits[i]   = -INFINITY;
                 logprobs[i] = -INFINITY;
             }
-        } else if (params.n_grammar_rules > 0) {
-            whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
+        } else {
+            if (params.n_grammar_rules > 0) {
+                whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
 
-            // populate the logprobs array (log_softmax)
-            {
-                const float logit_max = *std::max_element(logits.begin(), logits.end());
-                float logsumexp = 0.0f;
-                for (int i = 0; i < n_logits; ++i) {
-                    if (logits[i] > -INFINITY) {
-                        logsumexp += expf(logits[i] - logit_max);
+                // populate the logprobs array (log_softmax)
+                {
+                    const float logit_max = *std::max_element(logits.begin(), logits.end());
+                    float logsumexp = 0.0f;
+                    for (int i = 0; i < n_logits; ++i) {
+                        if (logits[i] > -INFINITY) {
+                            logsumexp += expf(logits[i] - logit_max);
+                        }
                     }
-                }
-                logsumexp = logf(logsumexp) + logit_max;
+                    logsumexp = logf(logsumexp) + logit_max;
 
-                for (int i = 0; i < n_logits; ++i) {
-                    if (logits[i] > -INFINITY) {
-                        logprobs[i] = logits[i] - logsumexp;
-                    } else {
-                        logprobs[i] = -INFINITY;
+                    for (int i = 0; i < n_logits; ++i) {
+                        if (logits[i] > -INFINITY) {
+                            logprobs[i] = logits[i] - logsumexp;
+                        } else {
+                            logprobs[i] = -INFINITY;
+                        }
                     }
                 }
             }
@@ -4610,7 +4813,6 @@ static void whisper_process_logits(
 
 static whisper_token_data whisper_sample_token(
             whisper_context & ctx,
-              whisper_state & state,
       const whisper_decoder & decoder,
                        bool   best) {
     whisper_token_data result = {
@@ -4655,7 +4857,7 @@ static whisper_token_data whisper_sample_token(
         } else {
             std::discrete_distribution<> dist(probs.begin(), probs.end());
 
-            result.id   = dist(state.rng);
+            result.id   = dist(decoder.rng);
             result.p    = probs[result.id];
             result.plog = logprobs[result.id];
         }
@@ -4665,15 +4867,12 @@ static whisper_token_data whisper_sample_token(
         result.pt = result.p;
     }
 
-    state.n_sample++;
-
     return result;
 }
 
 static std::vector<whisper_token_data> whisper_sample_token_topk(
             whisper_context & ctx,
-              whisper_state & state,
-      const whisper_decoder & decoder,
+            whisper_decoder & decoder,
                         int   k) {
     const auto & vocab = ctx.vocab;
 
@@ -4683,7 +4882,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
 
     const int n_logits = vocab.n_vocab;
 
-    auto & logits_id = state.logits_id;
+    auto & logits_id = decoder.logits_id;
 
     logits_id.resize(n_logits);
     for (int i = 0; i < n_logits; ++i) {
@@ -4732,7 +4931,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         std::discrete_distribution<> dist(probs.begin(), probs.end());
 
         for (int i = 0; i < k; ++i) {
-            const auto id = dist(state.rng);
+            const auto id = dist(decoder.rng);
             //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
 
             result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
@@ -4743,8 +4942,6 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
         }
     }
 
-    state.n_sample++;
-
     return result;
 }
 
@@ -4797,125 +4994,6 @@ static void whisper_sequence_score(
     }
 }
 
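Moving rng and logits_id from whisper_state into whisper_decoder makes sampling embarrassingly parallel: each decoder samples from its own probability buffers with its own generator, so the sampling threads introduced later in this patch never share mutable state. A hedged sketch of the t > 0 path:

    // sketch: temperature sampling from one decoder's distribution
    std::discrete_distribution<> dist(decoder.probs.begin(), decoder.probs.end());
    const int id = dist(decoder.rng); // decoder-local RNG, seeded once at init
    // the token id, its probability and log-probability are then recorded in the sequence

Note the bookkeeping change that comes with it: state.n_sample is no longer bumped inside the samplers (that would be a data race); the caller counts sampling events instead.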
-static bool whisper_kv_swap_fast(
-                   std::vector<int> & view,
-                    whisper_decoder   src[],
-                std::vector<kv_buf> & kv_swap_bufs,
-                          const int & n_decoders) {
-    WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
-
-    // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
-    std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
-
-    // (buffer->decoder or decoder->decoder)
-    std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
-
-    // (decoder<->decoder)
-    std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
-    std::vector<whisper_pair<int, int>> p_swap_vec;
-    p_swap_vec.reserve(n_decoders);
-
-    // see https://github.com/ggerganov/whisper.cpp/wiki
-    for (int i = 0; i < n_decoders; i++) {
-        // zero-copy (no modification)
-        if (i == view[i] || view[i] < 0) {
-            continue;
-        }
-
-        bool is_one_copy = true;
-        // since we modify data sequentially, we only consider decoder indices after current index
-        for (int j = i + 1; j < n_decoders; j++) {
-            if (i == view[j]) {
-                // detect symmetric diagram
-                if (j == view[i]) {
-                    p_swap_set.insert(i);
-                    p_swap_set.insert(j);
-                    p_swap_vec.emplace_back(i, j);
-                } else {
-                    two_copy.insert(i);
-                    is_one_copy = false;
-                }
-                break;
-            }
-        }
-        if (is_one_copy) {
-            one_copy.insert(i);
-        }
-    }
-
-    kv_swap_bufs.resize(n_decoders);
-
-    for (int i = 0; i < n_decoders; i++) {
-        kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k));
-        kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v));
-    }
-
-    for (auto & i : two_copy) {
-        // make a copy of KV caches
-        WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
-        //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
-        //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
-        ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size());
-        ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size());
-    }
-
-    // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
-    for (auto & i : two_copy) {
-        // skip the decoder indices that require pointer swapping
-        if (p_swap_set.find(i) != p_swap_set.end()) {
-            continue;
-        }
-
-        if (two_copy.find(view[i]) != two_copy.end()) {
-            // modify KV caches of decoder using data from kv_swap_bufs
-            WHISPER_PRINT_DEBUG("%s: two-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
-            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
-            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
-        } else {
-            // modify KV caches of decoder using data from correspond decoder KV caches directly
-            WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
-            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
-            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
-        }
-    }
-
-    // then modify one-copy decoder KV caches
-    for (auto & i : one_copy) {
-        // skip the decoder indices that require pointer swapping
-        if (p_swap_set.find(i) != p_swap_set.end()) {
-            continue;
-        }
-
-        if (two_copy.find(view[i]) != two_copy.end()) {
-            // modify KV caches of decoder using data from kv_swap_bufs
-            WHISPER_PRINT_DEBUG("%s: one-copy decoder using   swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
-            //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
-            ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size());
-            ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size());
        } else {
-            // modify KV caches of decoder using data from correspond decoder KV caches directly
-            WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
-            //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
-            //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
-            ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k);
-            ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v);
-        }
-    }
-
-    // swap the pointers
-    for (auto & i : p_swap_vec) {
-        WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
-        std::swap(src[i.first].kv_self, src[i.second].kv_self);
-    }
-
-    return true;
-}
-
 int whisper_full_with_state(
         struct whisper_context * ctx,
           struct whisper_state * state,
@@ -5005,25 +5083,23 @@ int whisper_full_with_state(
 
     n_decoders = std::max(1, n_decoders);
 
+    if (n_decoders > WHISPER_MAX_DECODERS) {
+        WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS);
+        return -4;
+    }
+
     // TAGS: WHISPER_DECODER_INIT
     for (int j = 1; j < n_decoders; j++) {
         auto & decoder = state->decoders[j];
 
-        if (decoder.kv_self.ctx == nullptr) {
-            decoder.kv_self = state->decoders[0].kv_self;
-            if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) {
-                WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
-                return -4;
-            }
+        decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
 
-            WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
+        decoder.probs.resize   (ctx->vocab.n_vocab);
+        decoder.logits.resize  (ctx->vocab.n_vocab);
+        decoder.logprobs.resize(ctx->vocab.n_vocab);
+        decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
 
-            decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
-
-            decoder.probs.resize   (ctx->vocab.n_vocab);
-            decoder.logits.resize  (ctx->vocab.n_vocab);
-            decoder.logprobs.resize(ctx->vocab.n_vocab);
-        }
+        decoder.rng = std::mt19937(0);
     }
 
     // the accumulated text context so far
@@ -5100,8 +5176,10 @@ int whisper_full_with_state(
         bool   has_ts;
 
         whisper_sequence sequence;
+        whisper_grammar  grammar;
     };
 
+    std::vector<std::vector<beam_candidate>> bc_per_dec(n_decoders);
     std::vector<beam_candidate> beam_candidates;
 
     // main loop
@@ -5169,8 +5247,6 @@ int whisper_full_with_state(
             for (int j = 0; j < n_decoders_cur; ++j) {
                 auto & decoder = state->decoders[j];
 
-                decoder.kv_self.n = 0;
-
                 decoder.sequence.tokens.clear();
                 decoder.sequence.result_len       = 0;
                 decoder.sequence.sum_logprobs_all = 0.0;
@@ -5186,15 +5262,14 @@
                 decoder.has_ts      = false;
 
                 if (params.grammar_rules != nullptr) {
-                    decoder.grammar = whisper_grammar_init(
-                        params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
+                    decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
                 } else {
                     decoder.grammar = {};
                 }
             }
 
             // init prompt and kv cache for the current iteration
-            // run whisper_decoder() only for decoder 0 and copy the results for the other decoders
+            // TODO: do not recompute the prompt if it is the same as previous time
            {
                 prompt.clear();
 
@@ -5216,7 +5291,11 @@
                 }
                 WHISPER_PRINT_DEBUG("\n\n");
 
-                if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                whisper_kv_cache_clear(state->kv_self);
+
+                whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
+
+                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                     WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                     return -7;
                 }
 
@@ -5224,20 +5303,14 @@
                 {
                     const int64_t t_start_sample_us = ggml_time_us();
 
-                    whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
+                    state->decoders[0].i_batch = prompt.size() - 1;
 
-                    state->decoders[0].kv_self.n += prompt.size();
+                    whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
 
                     for (int j = 1; j < n_decoders_cur; ++j) {
                         auto & decoder = state->decoders[j];
 
-                        // TODO: fix CUDA
-                        //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k));
-                        //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v));
-                        ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k);
-                        ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v);
-
-                        decoder.kv_self.n += prompt.size();
+                        whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1);
 
                         memcpy(decoder.probs.data(),  state->decoders[0].probs.data(),  decoder.probs.size()*sizeof(decoder.probs[0]));
                         memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
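The hunk above is where batching pays off for the prompt: it is decoded once as sequence 0, and every other decoder receives it via whisper_kv_cache_seq_cp, i.e. by tagging the already-computed cells with an extra sequence id. The per-iteration pattern, condensed:

    // sketch: evaluate the prompt once, then fork it to all decoders
    whisper_kv_cache_clear(state->kv_self);
    whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0); // seq 0
    whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr);
    for (int j = 1; j < n_decoders_cur; ++j) {
        whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1); // metadata only
    }

Under the old code this required a full tensor copy of the K and V buffers per decoder (the ggml_backend_tensor_copy calls deleted here).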
@@ -5252,41 +5325,81 @@
                 const int64_t t_start_sample_us = ggml_time_us();
 
                 if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
-                    beam_candidates.clear();
+                    for (auto & bc : bc_per_dec) {
+                        bc.clear();
+                    }
                 }
 
-                // generate new sequence candidates for each decoder
-                for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = state->decoders[j];
+                // sampling
+                // TODO: avoid memory allocations, optimize, avoid threads?
+                {
+                    std::atomic<int> j_cur(0);
 
-                    if (decoder.completed || decoder.failed) {
-                        continue;
-                    }
+                    auto process = [&]() {
+                        while (true) {
+                            const int j = j_cur.fetch_add(1);
 
-                    switch (params.strategy) {
-                        case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
-                            {
-                                if (t_cur < 1e-6f) {
-                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
-                                } else {
-                                    decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
-                                }
+                            if (j >= n_decoders_cur) {
+                                break;
+                            }
 
-                                decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
-                            } break;
-                        case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
-                            {
-                                const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
+                            auto & decoder = state->decoders[j];
 
-                                for (const auto & token : tokens_new) {
-                                    beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence });
-                                    beam_candidates.back().sequence.tokens.push_back(token);
-                                    beam_candidates.back().sequence.sum_logprobs_all += token.plog;
+                            if (decoder.completed || decoder.failed) {
+                                continue;
+                            }
 
-                                    //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
-                                }
-                            } break;
+                            switch (params.strategy) {
+                                case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+                                    {
+                                        if (t_cur < 1e-6f) {
+                                            decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
+                                        } else {
+                                            decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
+                                        }
+
+                                        decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
+                                    } break;
+                                case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+                                    {
+                                        const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+
+                                        for (const auto & token : tokens_new) {
+                                            bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
+                                            bc_per_dec[j].back().sequence.tokens.push_back(token);
+                                            bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
+                                        }
+                                    } break;
+                            };
+                        }
+                    };
+
+                    const int n_threads = std::min(params.n_threads, n_decoders_cur);
+
+                    if (n_threads == 1) {
+                        process();
+                    } else {
+                        std::vector<std::thread> threads(n_threads - 1);
+
+                        for (int t = 0; t < n_threads - 1; ++t) {
+                            threads[t] = std::thread(process);
+                        }
+
+                        process();
+
+                        for (int t = 0; t < n_threads - 1; ++t) {
+                            threads[t].join();
+                        }
+                    }
+                }
+
+                beam_candidates.clear();
+                for (const auto & bc : bc_per_dec) {
+                    beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
+
+                    if (!bc.empty()) {
+                        state->n_sample += 1;
+                    }
+                }
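Sampling across decoders is distributed with a tiny work-stealing loop rather than a parallel-for: each worker atomically claims the next decoder index until none remain, and the calling thread doubles as a worker, so the n_threads == 1 case costs nothing extra. The same pattern is reused below for whisper_process_logits. In isolation the pattern looks like this (a self-contained sketch, not patch code; the helper name is made up):

    #include <atomic>
    #include <thread>
    #include <vector>

    template <typename F>
    void for_each_index_threaded(int n, int n_threads, F && work) {
        std::atomic<int> next(0);
        auto process = [&]() {
            for (int i = next.fetch_add(1); i < n; i = next.fetch_add(1)) {
                work(i); // each index is claimed exactly once
            }
        };
        std::vector<std::thread> threads;
        for (int t = 0; t < n_threads - 1; ++t) threads.emplace_back(process);
        process(); // the main thread participates
        for (auto & th : threads) th.join();
    }

Decoders that are completed or failed are skipped inside the worker, so load balancing is automatic: a skipped index just means the thread grabs another one sooner.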
 
                 // for beam-search, choose the top candidates and update the KV caches
@@ -5299,7 +5412,6 @@
                     });
 
                     uint32_t cur_c = 0;
-                    std::vector<int> decoder_idx(n_decoders_cur, -1);
 
                     for (int j = 0; j < n_decoders_cur; ++j) {
                         auto & decoder = state->decoders[j];
@@ -5318,17 +5430,28 @@
                             ++cur_c;
                         }
 
-                        decoder.sequence   = cur.sequence;
                         decoder.seek_delta = cur.seek_delta;
                         decoder.has_ts     = cur.has_ts;
+                        decoder.sequence   = cur.sequence;
+                        decoder.grammar    = cur.grammar;
+
+                        whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
 
-                        decoder_idx[j] = cur.decoder_idx;
                         WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
                                 __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
                     }
 
-                    // update KV caches
-                    whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
+                    for (int j = 0; j < n_decoders_cur; ++j) {
+                        auto & decoder = state->decoders[j];
+
+                        if (decoder.completed || decoder.failed) {
+                            continue;
+                        }
+
+                        whisper_kv_cache_seq_rm(state->kv_self, j,                         -1, -1);
+                        whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1);
+                        whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j,    -1, -1);
+                    }
                 }
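Reordering the beams after the sort needs care: decoder j may have to adopt decoder i's cells while decoder i simultaneously adopts decoder k's. The patch sidesteps the aliasing problem with scratch sequence ids: every surviving candidate is first copied under a temporary id in [WHISPER_MAX_DECODERS, 2*WHISPER_MAX_DECODERS), and only then are the real ids rebuilt. Schematically:

    // sketch: two-phase beam reorder using scratch sequence ids
    // phase 1: snapshot each winner under a scratch id (no aliasing possible)
    whisper_kv_cache_seq_cp(kv, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
    // phase 2: rebuild id j from its snapshot, then drop the snapshot
    whisper_kv_cache_seq_rm(kv, j, -1, -1);
    whisper_kv_cache_seq_cp(kv, WHISPER_MAX_DECODERS + j, j, -1, -1);
    whisper_kv_cache_seq_rm(kv, WHISPER_MAX_DECODERS + j, -1, -1);

This is the metadata-only replacement for the removed whisper_kv_swap_fast(), which had to classify decoders into one-copy/two-copy/pointer-swap cases and shuttle whole K/V tensors through host buffers.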
 
                 // update the decoder state
@@ -5437,32 +5560,83 @@ int whisper_full_with_state(
                 state->t_sample_us += ggml_time_us() - t_start_sample_us;
 
                 // obtain logits for the next token
-                for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = state->decoders[j];
+                {
+                    auto & batch = state->batch;
 
-                    if (decoder.failed || decoder.completed) {
-                        continue;
+                    batch.n_tokens = 0;
+
+                    const int n_past = prompt.size() + i;
+
+                    for (int j = 0; j < n_decoders_cur; ++j) {
+                        auto & decoder = state->decoders[j];
+
+                        if (decoder.failed || decoder.completed) {
+                            continue;
+                        }
+
+                        //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
+
+                        decoder.i_batch = batch.n_tokens;
+
+                        batch.token   [batch.n_tokens]    = decoder.sequence.tokens.back().id;
+                        batch.pos     [batch.n_tokens]    = n_past;
+                        batch.n_seq_id[batch.n_tokens]    = 1;
+                        batch.seq_id  [batch.n_tokens][0] = j;
+                        batch.logits  [batch.n_tokens]    = 1;
+                        batch.n_tokens++;
                     }
 
-                    decoder.tokens_tmp.resize(1);
-                    decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id;
+                    assert(batch.n_tokens > 0);
 
-                    //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
-
-                    if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                    if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                         WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                         return -8;
                     }
 
+                    const int64_t t_start_sample_us = ggml_time_us();
+
+                    // TODO: avoid memory allocations, optimize, avoid threads?
                     {
-                        const int64_t t_start_sample_us = ggml_time_us();
+                        std::atomic<int> j_cur(0);
 
-                        whisper_process_logits(*ctx, *state, params, decoder, t_cur);
+                        auto process = [&]() {
+                            while (true) {
+                                const int j = j_cur.fetch_add(1);
 
-                        ++decoder.kv_self.n;
+                                if (j >= n_decoders_cur) {
+                                    break;
+                                }
 
-                        state->t_sample_us += ggml_time_us() - t_start_sample_us;
+                                auto & decoder = state->decoders[j];
+
+                                if (decoder.failed || decoder.completed) {
+                                    continue;
+                                }
+
+                                whisper_process_logits(*ctx, *state, decoder, params, t_cur);
+                            }
+                        };
+
+                        const int n_threads = std::min(params.n_threads, n_decoders_cur);
+
+                        if (n_threads == 1) {
+                            process();
+                        } else {
+                            std::vector<std::thread> threads(n_threads - 1);
+
+                            for (int t = 0; t < n_threads - 1; ++t) {
+                                threads[t] = std::thread(process);
+                            }
+
+                            process();
+
+                            for (int t = 0; t < n_threads - 1; ++t) {
+                                threads[t].join();
+                            }
+                        }
                     }
+
+                    state->t_sample_us += ggml_time_us() - t_start_sample_us;
                 }
             }
 
@@ -5759,11 +5933,13 @@ int whisper_full_parallel(
         ctx->state->t_sample_us += states[i]->t_sample_us;
         ctx->state->t_encode_us += states[i]->t_encode_us;
         ctx->state->t_decode_us += states[i]->t_decode_us;
+        ctx->state->t_batchd_us += states[i]->t_batchd_us;
         ctx->state->t_prompt_us += states[i]->t_prompt_us;
 
         ctx->state->n_sample += states[i]->n_sample;
         ctx->state->n_encode += states[i]->n_encode;
         ctx->state->n_decode += states[i]->n_decode;
+        ctx->state->n_batchd += states[i]->n_batchd;
         ctx->state->n_prompt += states[i]->n_prompt;
 
         whisper_free_state(states[i]);
diff --git a/whisper.h b/whisper.h
index 50f84a8..8454098 100644
--- a/whisper.h
+++ b/whisper.h
@@ -78,7 +78,9 @@ extern "C" {
     struct whisper_state;
     struct whisper_full_params;
 
-    typedef int whisper_token;
+    typedef int32_t whisper_pos;
+    typedef int32_t whisper_token;
+    typedef int32_t whisper_seq_id;
 
     struct whisper_context_params {
         bool use_gpu;