diff --git a/.gitignore b/.gitignore
index d5c4b0c..9ff35d0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,7 @@
 .DS_Store
 build/
+build-coreml/
 build-em/
 build-debug/
 build-release/
diff --git a/Makefile b/Makefile
index d134b76..20b02c3 100644
--- a/Makefile
+++ b/Makefile
@@ -307,7 +307,7 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
 ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
	$(CC) $(CFLAGS) -c $< -o $@

-WHISPER_OBJ += ggml-alloc.o ggml-backend.o ggml-quants.o
+WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o

 whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
	$(CXX) $(CXXFLAGS) -c $< -o $@
@@ -331,11 +331,11 @@ ggml-metal.o: ggml-metal.m ggml-metal.h
 WHISPER_OBJ += ggml-metal.o
 endif

-libwhisper.a: ggml.o $(WHISPER_OBJ)
-	$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
+libwhisper.a: $(WHISPER_OBJ)
+	$(AR) rcs libwhisper.a $(WHISPER_OBJ)

-libwhisper.so: ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
+libwhisper.so: $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) -shared -o libwhisper.so $(WHISPER_OBJ) $(LDFLAGS)

 clean:
	rm -f *.o main stream command talk talk-llama bench quantize lsp libwhisper.a libwhisper.so
@@ -349,30 +349,30 @@ CC_SDL=`sdl2-config --cflags --libs`
 SRC_COMMON = examples/common.cpp examples/common-ggml.cpp
 SRC_COMMON_SDL = examples/common-sdl.cpp

-main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
+main: examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o main $(LDFLAGS)
	./main -h

-bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
+bench: examples/bench/bench.cpp $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/bench/bench.cpp $(WHISPER_OBJ) -o bench $(LDFLAGS)

-quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
-	$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o quantize $(LDFLAGS)
+quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
+	$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o quantize $(LDFLAGS)

-stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
+stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)

-command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
+command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)

-lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
+lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
-talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
+talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)

-talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
-	$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
+talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
+	$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)

 #
 # Audio samples
diff --git a/examples/common.h b/examples/common.h
index 9a94bab..54f0b00 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -181,7 +181,7 @@ private:
     // It is assumed that PCM data is normalized to a range from -1 to 1
     bool write_audio(const float * data, size_t length) {
         for (size_t i = 0; i < length; ++i) {
-            const auto intSample = static_cast<int16_t>(data[i] * 32767);
+            const int16_t intSample = data[i] * 32767;
             file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
             dataSize += sizeof(int16_t);
         }
diff --git a/examples/talk/gpt-2.cpp b/examples/talk/gpt-2.cpp
index a2319db..8f9a3e9 100644
--- a/examples/talk/gpt-2.cpp
+++ b/examples/talk/gpt-2.cpp
@@ -121,13 +121,13 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
             return false;
         }

-        std::string word;
+        char word[129];
+
         for (int i = 0; i < n_vocab; i++) {
             uint32_t len;
             fin.read((char *) &len, sizeof(len));
-
-            word.resize(len);
-            fin.read((char *) word.data(), len);
+            word[len] = '\0';
+            fin.read((char *) word, len);

             vocab.token_to_id[word] = i;
             vocab.id_to_token[i] = word;
diff --git a/extra/bench-all.sh b/extra/bench-all.sh
index 8fd18b7..db04267 100755
--- a/extra/bench-all.sh
+++ b/extra/bench-all.sh
@@ -18,11 +18,11 @@ else
 fi

 models=( \
-    "tiny" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
-    "base" "base-q5_0" "base-q5_1" "base-q8_0" \
-    "small" "small-q5_0" "small-q5_1" "small-q8_0" \
-    "medium" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
-    "large" "large-q5_0" "large-q5_1" "large-q8_0" \
+    "tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
+    "base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
+    "small" "small-q4_0" "small-q4_1" "small-q5_0" "small-q5_1" "small-q8_0" \
+    "medium" "medium-q4_0" "medium-q4_1" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
+    "large" "large-q4_0" "large-q4_1" "large-q5_0" "large-q5_1" "large-q8_0" \
 )

 if [ "$encoder_only" -eq 0 ]; then
@@ -83,6 +83,10 @@ for model in "${models[@]}"; do
         config="$config COREML"
     fi

+    if [[ $system_info == *"CUDA = 1"* ]]; then
+        config="$config CUDA"
+    fi
+
     if [[ $system_info == *"METAL = 1"* ]]; then
         config="$config METAL"
     fi
diff --git a/extra/quantize-all.sh b/extra/quantize-all.sh
index bfef21e..767462b 100755
--- a/extra/quantize-all.sh
+++ b/extra/quantize-all.sh
@@ -15,33 +15,13 @@ declare -a filedex
 cd `dirname $0`
 cd ../

-# Let's loop across all the objects in the 'models' dir:
-for i in ./models/*; do
-    # Check to see if it's a file or directory
-    if [ -d "$i" ]; then
-        # It's a directory! We should make sure it's not empty first:
-        if [ "$(ls -A $i)" ]; then
-            # Passed! Let's go searching for bin files (shouldn't need to go more than a layer deep here)
-            for f in "$i"/*.bin; do
-                # [Neuron Activation]
-                newfile=`echo "${f##*/}" | cut -d _ -f 1`;
-                if [ "$newfile" != "q5" ]; then
-                    ./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" ${qtype1};
-                    ./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" ${qtype0};
-                    filedex+=( "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" )
-                fi
-            done
-        fi
-    else
-        # It's a file! Let's make sure it's the right type:
-        if [ "${i##*.}" == "bin" ]; then
-            # And we probably want to skip the testing files
-            if [ "${i:9:8}" != "for-test" ]; then
-                # [Neuron Activation]
-                ./quantize "${i}" "${i:-4}-${qtype1}.bin" ${qtype1};
-                ./quantize "${i}" "${i:-4}-${qtype0}.bin" ${qtype0};
-                filedex+=( "${i:-4}-${qtype1}.bin" "${i:-4}-${qtype0}.bin" )
-            fi
+for i in `ls ./models | grep ^ggml-.*.bin | grep -v "\-q"`; do
+    m="models/$i"
+    if [ -f "$m" ]; then
+        if [ "${m##*.}" == "bin" ]; then
+            ./quantize "${m}" "${m::${#m}-4}-${qtype1}.bin" ${qtype1};
+            ./quantize "${m}" "${m::${#m}-4}-${qtype0}.bin" ${qtype0};
+            filedex+=( "${m::${#m}-4}-${qtype1}.bin" "${m::${#m}-4}-${qtype0}.bin" )
         fi
     fi
 done
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index f4a6795..34c45f3 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4476,6 +4476,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     *dsti = __float2half(*xi);
 }

+static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = *xi;
+}
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
                                    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -4729,6 +4736,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
+static __global__ void im2col_f32_f16(
+        const float * x, half * dst,
+        int ofs0, int ofs1, int IW, int IH, int CHW,
+        int s0, int s1, int p0, int p1, int d0, int d1) {
+    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
+    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+
+    const int offset_dst =
+        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
+        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = __float2half(0.0f);
+    } else {
+        const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
+        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+    }
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
@@ -5618,6 +5644,16 @@ static void ggml_cpy_f32_f16_cuda(
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }

+static void ggml_cpy_f16_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
 }
@@ -5701,6 +5737,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }

+static void im2col_f32_f16_cuda(const float * x, half * dst,
+    int OH, int IW, int IH, int OW, int IC,
+    int KH, int KW, int N, int ofs0, int ofs1,
+    int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
+    dim3 block_nums(IC, OH, OW);
+    dim3 block_dims(N, KH, KW);
+    im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256
@@ -6483,7 +6528,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
-        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
+        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
         size_t dst_f16_as = 0;
         half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
@@ -6659,6 +6704,45 @@ inline void ggml_cuda_op_alibi(
     (void) src1_dd;
 }

+inline void ggml_cuda_op_im2col(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t N  = src1->ne[is_2D ? 3 : 2];
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW = src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW = src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW = dst->ne[1];
+
+    const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+    const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+
+    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
+        OH, IW, IH, OW, IC, KH, KW, N,
+        ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
+
+    (void) src0;
+    (void) src0_dd;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7549,6 +7633,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
                               ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7580,6 +7667,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }

+void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
+}
+
 static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -7943,6 +8034,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_ALIBI:
             func = ggml_cuda_alibi;
             break;
+        case GGML_OP_IM2COL:
+            func = ggml_cuda_im2col;
+            break;
         default:
             return false;
     }
diff --git a/ggml-metal.h b/ggml-metal.h
index 096b844..be2731f 100644
--- a/ggml-metal.h
+++ b/ggml-metal.h
@@ -26,7 +26,7 @@
 #include <stdbool.h>

 // max memory buffers that can be mapped to the device
-#define GGML_METAL_MAX_BUFFERS 16
+#define GGML_METAL_MAX_BUFFERS 64
 #define GGML_METAL_MAX_COMMAND_BUFFERS 32

 struct ggml_tensor;
diff --git a/ggml-metal.m b/ggml-metal.m
index 3bee839..6293908 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -86,6 +86,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@@ -114,6 +115,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rope_f32);
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
+    GGML_METAL_DECL_KERNEL(im2col_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
@@ -287,6 +289,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@@ -317,6 +320,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
+        GGML_METAL_ADD_KERNEL(im2col_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
@@ -386,6 +390,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
    GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@@ -416,6 +421,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
+    GGML_METAL_DEL_KERNEL(im2col_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
@@ -473,6 +479,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
     const int64_t tsize = ggml_nbytes(t);

+    if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
+        ctx = t->buffer->backend->context;
+    }
+
     // find the view that contains the tensor fully
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
@@ -1139,6 +1149,7 @@ void ggml_metal_graph_compute(
                         switch (src0t) {
                             case GGML_TYPE_F32:
                                 {
+                                    GGML_ASSERT(src1t == GGML_TYPE_F32);
                                     [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
                                     nrows = 4;
                                 } break;
@@ -1146,13 +1157,18 @@
                                 {
                                     nth0 = 32;
                                     nth1 = 1;
-                                    if (ne11 * ne12 < 4) {
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
-                                    } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
-                                        nrows = ne11;
+                                    if (src1t == GGML_TYPE_F32) {
+                                        if (ne11 * ne12 < 4) {
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+                                        } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+                                            nrows = ne11;
+                                        } else {
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                            nrows = 4;
+                                        }
                                     } else {
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
                                         nrows = 4;
                                     }
                                 } break;
@@
 -1464,6 +1480,58 @@ void ggml_metal_graph_compute(
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
+            case GGML_OP_IM2COL:
+                {
+                    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+                    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+                    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+                    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+                    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+                    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+                    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+                    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+                    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+                    const int32_t N  = src1->ne[is_2D ? 3 : 2];
+                    const int32_t IC = src1->ne[is_2D ? 2 : 1];
+                    const int32_t IH = is_2D ? src1->ne[1] : 1;
+                    const int32_t IW = src1->ne[0];
+
+                    const int32_t KH = is_2D ? src0->ne[1] : 1;
+                    const int32_t KW = src0->ne[0];
+
+                    const int32_t OH = is_2D ? dst->ne[2] : 1;
+                    const int32_t OW = dst->ne[1];
+
+                    const int32_t CHW = IC * KH * KW;
+
+                    const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+                    const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
+                        case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+                        default: GGML_ASSERT(false);
+                    };
+
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+                    [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+                    [encoder setBytes:&IW   length:sizeof( int32_t) atIndex:4];
+                    [encoder setBytes:&IH   length:sizeof( int32_t) atIndex:5];
+                    [encoder setBytes:&CHW  length:sizeof( int32_t) atIndex:6];
+                    [encoder setBytes:&s0   length:sizeof( int32_t) atIndex:7];
+                    [encoder setBytes:&s1   length:sizeof( int32_t) atIndex:8];
+                    [encoder setBytes:&p0   length:sizeof( int32_t) atIndex:9];
+                    [encoder setBytes:&p1   length:sizeof( int32_t) atIndex:10];
+                    [encoder setBytes:&d0   length:sizeof( int32_t) atIndex:11];
+                    [encoder setBytes:&d1   length:sizeof( int32_t) atIndex:12];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+                } break;
             case GGML_OP_DUP:
             case GGML_OP_CPY:
             case GGML_OP_CONT:
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 7c35f23..5d1357c 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
     constant   int64_t & ne0,
     constant   int64_t & ne1,
     uint3 tgpig[[threadgroup_position_in_grid]],
-    uint tiisg[[thread_index_in_simdgroup]]) {
+    uint  tiisg[[thread_index_in_simdgroup]]) {
     const int64_t r0 = tgpig.x;
     const int64_t rb = tgpig.y*N_F32_F32;
@@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
     }
 }

+#define N_F16_F16 4
+
+kernel void kernel_mul_mv_f16_f16(
+    device const  char * src0,
+    device const  char * src1,
+    device       float * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant   int64_t & ne10,
+    constant   int64_t & ne11,
+    constant   int64_t & ne12,
+    constant  uint64_t & nb10,
+    constant  uint64_t & nb11,
+    constant  uint64_t & nb12,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F16;
+    const int64_t im = tgpig.z;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (half) x[i] * (half) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const half  * y  = (device const half  *) (src1 + r1*nb11 + im*nb12);
+            device const half4 * y4 = (device const half4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
 kernel void kernel_mul_mv_f16_f32_1row(
     device const  char * src0,
     device const  char * src1,
@@ -1229,6 +1302,39 @@ kernel void kernel_rope(
 template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
 template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;

+kernel void kernel_im2col_f16(
+    device const float * x,
+    device        half * dst,
+    constant   int32_t & ofs0,
+    constant   int32_t & ofs1,
+    constant   int32_t & IW,
+    constant   int32_t & IH,
+    constant   int32_t & CHW,
+    constant   int32_t & s0,
+    constant   int32_t & s1,
+    constant   int32_t & p0,
+    constant   int32_t & p1,
+    constant   int32_t & d0,
+    constant   int32_t & d1,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3  tgpg[[threadgroups_per_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+    const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
+    const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+
+    const int32_t offset_dst =
+        (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
 kernel void kernel_cpy_f16_f16(
     device const half * src0,
     device       half * dst,
diff --git a/ggml.c b/ggml.c
index d1b7e94..584ee46 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1634,13 +1634,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "ROPE_BACK",
     "ALIBI",
     "CLAMP",
-    "CONV_1D",
-    "CONV_1D_STAGE_0",
-    "CONV_1D_STAGE_1",
     "CONV_TRANSPOSE_1D",
-    "CONV_2D",
-    "CONV_2D_STAGE_0",
-    "CONV_2D_STAGE_1",
+    "IM2COL",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
@@ -1671,7 +1666,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };

-static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
+static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1721,13 +1716,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rope_back(x)",
     "alibi(x)",
     "clamp(x)",
-    "conv_1d(x)",
-    "conv_1d_stage_0(x)",
-    "conv_1d_stage_1(x)",
     "conv_transpose_1d(x)",
-    "conv_2d(x)",
-    "conv_2d_stage_0(x)",
-    "conv_2d_stage_1(x)",
+    "im2col(x)",
     "conv_transpose_2d(x)",
"pool_1d(x)", "pool_2d(x)", @@ -1758,7 +1748,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73"); +static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1786,13 +1776,7 @@ static void ggml_setup_op_has_task_pass(void) { p[GGML_OP_GET_ROWS_BACK ] = true; p[GGML_OP_DIAG_MASK_INF ] = true; p[GGML_OP_DIAG_MASK_ZERO ] = true; - p[GGML_OP_CONV_1D ] = true; - p[GGML_OP_CONV_1D_STAGE_0 ] = true; - p[GGML_OP_CONV_1D_STAGE_1 ] = true; p[GGML_OP_CONV_TRANSPOSE_1D ] = true; - p[GGML_OP_CONV_2D ] = true; - p[GGML_OP_CONV_2D_STAGE_0 ] = true; - p[GGML_OP_CONV_2D_STAGE_1 ] = true; p[GGML_OP_CONV_TRANSPOSE_2D ] = true; p[GGML_OP_FLASH_ATTN_BACK ] = true; p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; @@ -5137,82 +5121,6 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; } -// im2col: [N, IC, IL] => [N, OL, IC*K] -// a: [OC,IC, K] -// b: [N, IC, IL] -// result: [N, OL, IC*K] -static struct ggml_tensor * ggml_conv_1d_stage_0( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int p0, - int d0) { - GGML_ASSERT(a->ne[1] == b->ne[1]); - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); - - const int64_t ne[4] = { - a->ne[1] * a->ne[0], - OL, - b->ne[2], - 1, - }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); - - int32_t params[] = { s0, p0, d0 }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_CONV_1D_STAGE_0; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_conv_1d_stage_1 - -// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] -// a: [OC, IC, K] -// b: [N, OL, IC * K] -// result: [N, OC, OL] -static struct ggml_tensor * ggml_conv_1d_stage_1( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t ne[4] = { - b->ne[1], - a->ne[2], - b->ne[2], - 1, - }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - - result->op = GGML_OP_CONV_1D_STAGE_1; - result->grad = is_node ? 
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-// ggml_conv_1d
-
 GGML_API struct ggml_tensor * ggml_conv_1d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -5220,44 +5128,18 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
         int                   s0,
         int                   p0,
         int                   d0) {
-    struct ggml_tensor * result = ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0);
-    result = ggml_conv_1d_stage_1(ctx, a, result);
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
+
+    struct ggml_tensor * result =
+        ggml_mul_mat(ctx,
+                ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
+                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                    // [OC,IC, K] => [OC, IC * K]
+
+    result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
+
     return result;
 }

-// GGML_API struct ggml_tensor * ggml_conv_1d(
-//         struct ggml_context * ctx,
-//         struct ggml_tensor  * a,
-//         struct ggml_tensor  * b,
-//         int                   s0,
-//         int                   p0,
-//         int                   d0) {
-//     GGML_ASSERT(ggml_is_matrix(b));
-//     GGML_ASSERT(a->ne[1] == b->ne[1]);
-//     bool is_node = false;
-
-//     if (a->grad || b->grad) {
-//         GGML_ASSERT(false); // TODO: implement backward
-//         is_node = true;
-//     }
-
-//     const int64_t ne[4] = {
-//         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
-//         a->ne[2], 1, 1,
-//     };
-//     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
-
-//     int32_t params[] = { s0, p0, d0 };
-//     ggml_set_op_params(result, params, sizeof(params));
-
-//     result->op = GGML_OP_CONV_1D;
-//     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-//     result->src[0] = a;
-//     result->src[1] = b;
-
-//     return result;
-// }
-
 // ggml_conv_1d_ph

 struct ggml_tensor* ggml_conv_1d_ph(
@@ -5319,7 +5201,7 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
 // result: [N, OH, OW, IC*KH*KW]
-static struct ggml_tensor * ggml_conv_2d_stage_0(
+struct ggml_tensor * ggml_im2col(
     struct ggml_context * ctx,
     struct ggml_tensor  * a,
     struct ggml_tensor  * b,
@@ -5328,9 +5210,14 @@ static struct ggml_tensor * ggml_conv_2d_stage_0(
     int                   p0,
     int                   p1,
     int                   d0,
-    int                   d1) {
+    int                   d1,
+    bool                  is_2D) {

-    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    if(is_2D) {
+        GGML_ASSERT(a->ne[2] == b->ne[2]);
+    } else {
+        GGML_ASSERT(a->ne[1] == b->ne[1]);
+    }
     bool is_node = false;

     if (a->grad || b->grad) {
@@ -5338,81 +5225,51 @@ static struct ggml_tensor * ggml_conv_2d_stage_0(
         is_node = true;
     }

-    const int64_t OH = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
-    const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+    const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
+    const int64_t OW =         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);

     const int64_t ne[4] = {
-        a->ne[2] * a->ne[1] * a->ne[0],
+        is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
         OW,
-        OH,
-        b->ne[3],
+        is_2D ? OH : b->ne[2],
+        is_2D ? b->ne[3] : 1,
     };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);

-    int32_t params[] = { s0, s1, p0, p1, d0, d1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
+    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
     ggml_set_op_params(result, params, sizeof(params));

-    result->op = GGML_OP_CONV_2D_STAGE_0;
+    result->op = GGML_OP_IM2COL;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;

     return result;
-
-}
-
-// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-// a: [OC, IC, KH, KW]
-// b: [N, OH, OW, IC * KH * KW]
-// result: [N, OC, OH, OW]
-static struct ggml_tensor * ggml_conv_2d_stage_1(
-    struct ggml_context * ctx,
-    struct ggml_tensor  * a,
-    struct ggml_tensor  * b) {
-
-    bool is_node = false;
-
-    if (a->grad || b->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
-    }
-
-    const int64_t ne[4] = {
-        b->ne[1],
-        b->ne[2],
-        a->ne[3],
-        b->ne[3],
-    };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-    result->op = GGML_OP_CONV_2D_STAGE_1;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-
 }

 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
 // result: [N, OC, OH, OW]
 struct ggml_tensor * ggml_conv_2d(
-    struct ggml_context * ctx,
-    struct ggml_tensor  * a,
-    struct ggml_tensor  * b,
-    int                   s0,
-    int                   s1,
-    int                   p0,
-    int                   p1,
-    int                   d0,
-    int                   d1) {
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   s1,
+        int                   p0,
+        int                   p1,
+        int                   d0,
+        int                   d1) {
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]

-    struct ggml_tensor * result = ggml_conv_2d_stage_0(ctx, a, b, s0, s1, p0, p1, d0, d1); // [N, OH, OW, IC * KH * KW]
-    result = ggml_conv_2d_stage_1(ctx, a, result);
+    struct ggml_tensor * result =
+        ggml_mul_mat(ctx,
+                ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
+                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3]));                       // [OC,IC, KH, KW] => [OC, IC * KH * KW]
+
+    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW]

     return result;
-
 }

 // ggml_conv_2d_sk_p0
@@ -9507,6 +9364,8 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     // TODO: find the optimal values for these
     if (ggml_is_contiguous(src0) &&
         ggml_is_contiguous(src1) &&
+        src0->type == GGML_TYPE_F32 &&
+        src1->type == GGML_TYPE_F32 &&
         (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {

         /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
@@ -9517,6 +9376,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 }
 #endif

+
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -9545,7 +9405,7 @@ static void ggml_compute_forward_mul_mat(

     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == ggml_type_size(type));
-    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));

     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -11637,416 +11497,6 @@ static void ggml_compute_forward_rope_back(
     }
 }

-// ggml_compute_forward_conv_1d
-
-static void ggml_compute_forward_conv_1d_f16_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nk = ne00;
-
-    // size of the convolution row - the kernel size unrolled across all input channels
-    const int ew0 = nk*ne01;
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(params->wdata, 0, params->wsize);
-
-        ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-        for (int64_t i11 = 0; i11 < ne11; i11++) {
-            const float * const src = (float *)((char *) src1->data + i11*nb11);
-            ggml_fp16_t * dst_data = wdata;
-
-            for (int64_t i0 = 0; i0 < ne0; i0++) {
-                for (int64_t ik = 0; ik < nk; ik++) {
-                    const int idx0 = i0*s0 + ik*d0 - p0;
-
-                    if(!(idx0 < 0 || idx0 >= ne10)) {
-                        dst_data[i0*ew0 + i11*nk + ik] = GGML_FP32_TO_FP16(src[idx0]);
-                    }
-                }
-            }
-        }
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // total rows in dst
-    const int nr = ne2;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-    for (int i2 = 0; i2 < ne2; i2++) {
-        for (int i1 = ir0; i1 < ir1; i1++) {
-            float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
-
-            for (int i0 = 0; i0 < ne0; i0++) {
-                ggml_vec_dot_f16(ew0, dst_data + i0,
-                        (ggml_fp16_t *) ((char *) src0->data + i1*nb02),
-                        (ggml_fp16_t *) wdata + i2*nb2 + i0*ew0);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_conv_1d_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nk = ne00;
-
-    const int ew0 = nk*ne01;
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
-
-    GGML_ASSERT(nb00 == sizeof(float));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(params->wdata, 0, params->wsize);
-
-        float * const wdata = (float *) params->wdata + 0;
-
-        for (int64_t i11 = 0; i11 < ne11; i11++) {
-            const float * const src = (float *)((char *) src1->data + i11*nb11);
-            float * dst_data = wdata;
-
-            for (int64_t i0 = 0; i0 < ne0; i0++) {
-                for (int64_t ik = 0; ik < nk; ik++) {
-                    const int idx0 = i0*s0 + ik*d0 - p0;
-
-                    if(!(idx0 < 0 || idx0 >= ne10)) {
-                        dst_data[i0*ew0 + i11*nk + ik] = src[idx0];
-                    }
-                }
-            }
-        }
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // total rows in dst
-    const int nr = ne02;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * const wdata = (float *) params->wdata + 0;
-
-    for (int i2 = 0; i2 < ne2; i2++) {
-        for (int i1 = ir0; i1 < ir1; i1++) {
-            float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1);
-
-            for (int i0 = 0; i0 < ne0; i0++) {
-                ggml_vec_dot_f32(ew0, dst_data + i0,
-                        (float *) ((char *) src0->data + i1*nb02),
-                        (float *) wdata + i2*nb2 + i0*ew0);
-            }
-        }
-    }
-}
-
-// TODO: reuse ggml_mul_mat or implement ggml_im2col and remove stage_0 and stage_1
-static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
-                             ggml_fp16_t * A,
-                             ggml_fp16_t * B,
-                             float       * C,
-                             const int ith, const int nth) {
-    // does not seem to make a difference
-    int64_t m0, m1, n0, n1;
-    // patches per thread
-    if (m > n) {
-        n0 = 0;
-        n1 = n;
-
-        // total patches in dst
-        const int np = m;
-
-        // patches per thread
-        const int dp = (np + nth - 1)/nth;
-
-        // patch range for this thread
-        m0 = dp*ith;
-        m1 = MIN(m0 + dp, np);
-    } else {
-        m0 = 0;
-        m1 = m;
-
-        // total patches in dst
-        const int np = n;
-
-        // patches per thread
-        const int dp = (np + nth - 1)/nth;
-
-        // patch range for this thread
-        n0 = dp*ith;
-        n1 = MIN(n0 + dp, np);
-    }
-
-    // block-tiling attempt
-    int64_t blck_n = 16;
-    int64_t blck_m = 16;
-
-    // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB
-    // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K);
-    // if (blck_size > 0) {
-    //     blck_0 = 4;
-    //     blck_1 = blck_size / blck_0;
-    //     if (blck_1 < 0) {
-    //         blck_1 = 1;
-    //     }
-    //     // blck_0 = (int64_t)sqrt(blck_size);
-    //     // blck_1 = blck_0;
-    // }
-    // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1);
-
-    for (int j = n0; j < n1; j+=blck_n) {
-        for (int i = m0; i < m1; i+=blck_m) {
-            // printf("i j k => %d %d %d\n", i, j, K);
-            for (int ii = i; ii < i + blck_m && ii < m1; ii++) {
-                for (int jj = j; jj < j + blck_n && jj < n1; jj++) {
-                    ggml_vec_dot_f16(k,
-                                    C + ii*n + jj,
-                                    A + ii * k,
-                                    B + jj * k);
-                }
-            }
-        }
-    }
-}
-
-// src0: kernel [OC, IC, K]
-// src1: signal [N, IC, IL]
-// dst:  result [N, OL, IC*K]
-static void ggml_compute_forward_conv_1d_stage_0_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F16);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int64_t N  = ne12;
-    const int64_t IC = ne11;
-    const int64_t IL = ne10;
-
-    const int64_t K = ne00;
-
-    const int64_t OL = ne1;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[2];
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(dst->data, 0, ggml_nbytes(dst));
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // im2col: [N, IC, IL] => [N, OL, IC*K]
-    {
-        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
-
-        for (int64_t in = 0; in < N; in++) {
-            for (int64_t iol = 0; iol < OL; iol++) {
-                for (int64_t iic = ith; iic < IC; iic+=nth) {
-
-                    // micro kernel
-                    ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K]
-                    const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL]
-
-                    for (int64_t ik = 0; ik < K; ik++) {
-                        const int64_t iil = iol*s0 + ik*d0 - p0;
-
-                        if (!(iil < 0 || iil >= IL)) {
-                            dst_data[iic*K + ik] = GGML_FP32_TO_FP16(src_data[iil]);
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
-// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
-// src0: [OC, IC, K]
-// src1: [N, OL, IC * K]
-// result: [N, OC, OL]
-static void ggml_compute_forward_conv_1d_stage_1_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    if (params->type == GGML_TASK_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb0  == sizeof(float));
-
-    const int N = ne12;
-    const int OL = ne11;
-
-    const int OC = ne02;
-    const int IC = ne01;
-    const int K  = ne00;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    int64_t m = OC;
-    int64_t n = OL;
-    int64_t k = IC * K;
-
-    // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K]
-    for (int i = 0; i < N; i++) {
-        ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
-        ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
-        float * C = (float *)dst->data + i * m * n; // [m, n]
-
-        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
-    }
-}
-
-static void ggml_compute_forward_conv_1d(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch(src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_conv_1d_f32(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_1d_stage_0(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch(src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_1d_stage_1(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch(src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
 // ggml_compute_forward_conv_transpose_1d

 static void ggml_compute_forward_conv_transpose_1d_f16_f32(
@@ -12258,12 +11708,10 @@ static void ggml_compute_forward_conv_transpose_1d(
     }
 }

-// ggml_compute_forward_conv_2d
-
 // src0: kernel [OC, IC, KH, KW]
 // src1: image  [N, IC, IH, IW]
 // dst:  result [N, OH, OW, IC*KH*KW]
-static void ggml_compute_forward_conv_2d_stage_0_f32(
+static void ggml_compute_forward_im2col_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -12277,34 +11725,35 @@ static void ggml_compute_forward_conv_2d_stage_0_f32(

     GGML_TENSOR_BINARY_OP_LOCALS;

-    const int64_t N = ne13;
-    const int64_t IC = ne12;
-    const int64_t IH = ne11;
-    const int64_t IW = ne10;
-
-    // const int64_t OC = ne03;
-    // const int64_t IC = ne02;
-    const int64_t KH = ne01;
-    const int64_t KW = ne00;
-
-    const int64_t OH = ne2;
-    const int64_t OW = ne1;
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;

     const int ith = params->ith;
     const int nth = params->nth;

-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+    const int64_t N  = is_2D ? ne13 : ne12;
+    const int64_t IC = is_2D ? ne12 : ne11;
+    const int64_t IH = is_2D ? ne11 : 1;
+    const int64_t IW = ne10;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne2 : 1;
+    const int64_t OW = ne1;
+
+    int ofs0 = is_2D ? nb13 : nb12;
+    int ofs1 = is_2D ? nb12 : nb11;

     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));

     if (params->type == GGML_TASK_INIT) {
-        memset(dst->data, 0, ggml_nbytes(dst));
         return;
     }

@@ -12317,20 +11766,22 @@ static void ggml_compute_forward_conv_2d_stage_0_f32(
         ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;

         for (int64_t in = 0; in < N; in++) {
-            for (int64_t ioh = 0; ioh < OH; ioh++) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
                 for (int64_t iow = 0; iow < OW; iow++) {
-                    for (int64_t iic = ith; iic < IC; iic+=nth) {
+                    for (int64_t iic = ith; iic < IC; iic += nth) {

                         // micro kernel
                         ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                        const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
+                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]

-                        for (int64_t ikh = 0; ikh < KH; ikh++) {
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
                             for (int64_t ikw = 0; ikw < KW; ikw++) {
                                 const int64_t iiw = iow*s0 + ikw*d0 - p0;
                                 const int64_t iih = ioh*s1 + ikh*d1 - p1;

-                                if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
+                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+                                } else {
                                     dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
                                 }
                             }
@@ -12342,180 +11793,7 @@ static void ggml_compute_forward_conv_2d_stage_0_f32(
     }
 }

-// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-// src0: [OC, IC, KH, KW]
-// src1: [N, OH, OW, IC * KH * KW]
-// result: [N, OC, OH, OW]
-static void ggml_compute_forward_conv_2d_stage_1_f16(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    if (params->type == GGML_TASK_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb0  == sizeof(float));
-
-    const int N = ne13;
-    const int OH = ne12;
-    const int OW = ne11;
-
-    const int OC = ne03;
-    const int IC = ne02;
-    const int KH = ne01;
-    const int KW = ne00;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    int64_t m = OC;
-    int64_t n = OH * OW;
-    int64_t k = IC * KH * KW;
-
-    // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-    for (int i = 0; i < N; i++) {
-        ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
-        ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
-        float * C = (float *)dst->data + i * m * n; // [m, n]
-
-        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
-    }
-}
-
-static void ggml_compute_forward_conv_2d_f16_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    // src1: image [N, IC, IH, IW]
-    // src0: kernel [OC, IC, KH, KW]
-    // dst:  result [N, OC, OH, OW]
-    // ne12: IC
-    // ne0:  OW
-    // ne1:  OH
-    // nk0:  KW
-    // nk1:  KH
-    // ne13: N
-
-    const int N = ne13;
-    const int IC = ne12;
-    const int IH = ne11;
-    const int IW = ne10;
-
-    const int OC = ne03;
-    // const int IC = ne02;
-    const int KH = ne01;
-    const int KW = ne00;
-
-    const int OH = ne1;
-    const int OW = ne0;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    // const int nk0 = ne00;
-    // const int nk1 = ne01;
-
-    // size of the convolution row - the kernel size unrolled across all channels
-    // const int ew0 = nk0*nk1*ne02;
-    // ew0: IC*KH*KW
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
-
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
-    if (params->type == GGML_TASK_INIT) {
-        memset(params->wdata, 0, params->wsize);
-
-        // prepare source data (src1)
-        // im2col: [N, IC, IH, IW] => [N*OH*OW, IC*KH*KW]
-
-        {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-
-            for (int in = 0; in < N; in++) {
-                for (int iic = 0; iic < IC; iic++) {
-                    for (int ioh = 0; ioh < OH; ioh++) {
-                        for (int iow = 0; iow < OW; iow++) {
-
-                            // micro kernel
-                            ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                            const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
-
-                            for (int ikh = 0; ikh < KH; ikh++) {
-                                for (int ikw = 0; ikw < KW; ikw++) {
-                                    const int iiw = iow*s0 + ikw*d0 - p0;
-                                    const int iih = ioh*s1 + ikh*d1 - p1;
-
-                                    if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
-                                        dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
-    // wdata: [N*OH*OW, IC*KH*KW]
-    // dst: result [N, OC, OH, OW]
-    // src0: kernel [OC, IC, KH, KW]
-
-    int64_t m = OC;
-    int64_t n = OH * OW;
-    int64_t k = IC * KH * KW;
-
-    // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
-    for (int i = 0; i < N; i++) {
-        ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
-        ggml_fp16_t * B = (ggml_fp16_t *)wdata + i * m * k; // [n, k]
-        float * C = (float *)dst->data + i * m * n; // [m * k]
-
-        gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
-    }
-}
-
-static void ggml_compute_forward_conv_2d(
+static void ggml_compute_forward_im2col(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
        const struct ggml_tensor * src1,
@@ -12523,50 +11801,7 @@ static void
 ggml_compute_forward_conv_2d(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                //ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
-                GGML_ASSERT(false);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_2d_stage_0(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_2d_stage_0_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(false);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
-static void ggml_compute_forward_conv_2d_stage_1(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_2d_stage_1_f16(params, src0, src1, dst);
+                ggml_compute_forward_im2col_f16(params, src0, src1, dst);
             } break;
         case GGML_TYPE_F32:
             {
@@ -14783,33 +14018,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_clamp(params, tensor->src[0], tensor);
             } break;
-        case GGML_OP_CONV_1D:
-            {
-                ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_1D_STAGE_0:
-            {
-                ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_1D_STAGE_1:
-            {
-                ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor);
             } break;
-        case GGML_OP_CONV_2D:
+        case GGML_OP_IM2COL:
             {
-                ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_2D_STAGE_0:
-            {
-                ggml_compute_forward_conv_2d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
-            } break;
-        case GGML_OP_CONV_2D_STAGE_1:
-            {
-                ggml_compute_forward_conv_2d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_im2col(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_CONV_TRANSPOSE_2D:
             {
@@ -15780,31 +14995,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_CONV_1D:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_1D_STAGE_0:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_1D_STAGE_1:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
-        case GGML_OP_CONV_2D:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_2D_STAGE_0:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
-        case GGML_OP_CONV_2D_STAGE_1:
+        case GGML_OP_IM2COL:
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
@@ -16533,31 +15728,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             {
                 n_tasks = 1; //TODO
             } break;
-        case GGML_OP_CONV_1D:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CONV_1D_STAGE_0:
-            {
-                n_tasks = n_threads;
-            } break;
-        case GGML_OP_CONV_1D_STAGE_1:
-            {
-                n_tasks = n_threads;
-            } break;
GGML_OP_CONV_TRANSPOSE_1D: { n_tasks = n_threads; } break; - case GGML_OP_CONV_2D: - { - n_tasks = n_threads; - } break; - case GGML_OP_CONV_2D_STAGE_0: - { - n_tasks = n_threads; - } break; - case GGML_OP_CONV_2D_STAGE_1: + case GGML_OP_IM2COL: { n_tasks = n_threads; } break; @@ -16642,6 +15817,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; default: { + printf("%s: op %s not implemented\n", __func__, ggml_op_name(node->op)); GGML_ASSERT(false); } break; } @@ -16844,38 +16020,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; } } break; - case GGML_OP_CONV_1D: - { - GGML_ASSERT(node->src[0]->ne[3] == 1); - GGML_ASSERT(node->src[1]->ne[2] == 1); - GGML_ASSERT(node->src[1]->ne[3] == 1); - - const int64_t ne00 = node->src[0]->ne[0]; - const int64_t ne01 = node->src[0]->ne[1]; - const int64_t ne02 = node->src[0]->ne[2]; - - const int64_t ne10 = node->src[1]->ne[0]; - const int64_t ne11 = node->src[1]->ne[1]; - - const int64_t ne0 = node->ne[0]; - const int64_t ne1 = node->ne[1]; - const int64_t nk = ne00; - const int64_t ew0 = nk * ne01; - - UNUSED(ne02); - UNUSED(ne10); - UNUSED(ne11); - - if (node->src[0]->type == GGML_TYPE_F16 && - node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0); - } else if (node->src[0]->type == GGML_TYPE_F32 && - node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)*(ne0*ne1*ew0); - } else { - GGML_ASSERT(false); - } - } break; case GGML_OP_CONV_TRANSPOSE_1D: { GGML_ASSERT(node->src[0]->ne[3] == 1); @@ -16901,37 +16045,9 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { GGML_ASSERT(false); } } break; - case GGML_OP_CONV_2D: + case GGML_OP_IM2COL: { - const int64_t ne00 = node->src[0]->ne[0]; // W - const int64_t ne01 = node->src[0]->ne[1]; // H - const int64_t ne02 = node->src[0]->ne[2]; // C - const int64_t ne03 = node->src[0]->ne[3]; // N - - const int64_t ne10 = node->src[1]->ne[0]; // W - const int64_t ne11 = node->src[1]->ne[1]; // H - const int64_t ne12 = node->src[1]->ne[2]; // C - - const int64_t ne0 = node->ne[0]; - const int64_t ne1 = node->ne[1]; - const int64_t ne2 = node->ne[2]; - const int64_t ne3 = node->ne[3]; - const int64_t nk = ne00*ne01; - const int64_t ew0 = nk * ne02; - - UNUSED(ne03); - UNUSED(ne2); - - if (node->src[0]->type == GGML_TYPE_F16 && - node->src[1]->type == GGML_TYPE_F32) { - // im2col: [N*OH*OW, IC*KH*KW] - cur = sizeof(ggml_fp16_t)*(ne3*ne0*ne1*ew0); - } else if (node->src[0]->type == GGML_TYPE_F32 && - node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)* (ne10*ne11*ne12); - } else { - GGML_ASSERT(false); - } + n_tasks = n_threads; } break; case GGML_OP_CONV_TRANSPOSE_2D: { diff --git a/ggml.h b/ggml.h index e56a833..52ae675 100644 --- a/ggml.h +++ b/ggml.h @@ -403,13 +403,8 @@ extern "C" { GGML_OP_ROPE_BACK, GGML_OP_ALIBI, GGML_OP_CLAMP, - GGML_OP_CONV_1D, - GGML_OP_CONV_1D_STAGE_0, // internal - GGML_OP_CONV_1D_STAGE_1, // internal GGML_OP_CONV_TRANSPOSE_1D, - GGML_OP_CONV_2D, - GGML_OP_CONV_2D_STAGE_0, // internal - GGML_OP_CONV_2D_STAGE_1, // internal + GGML_OP_IM2COL, GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_POOL_1D, GGML_OP_POOL_2D, @@ -1398,6 +1393,18 @@ extern "C" { float min, float max); + GGML_API struct ggml_tensor * ggml_im2col( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D); + GGML_API struct 
ggml_tensor * ggml_conv_1d(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
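With the staged conv ops removed, a convolution in ggml is now expressed as an im2col transform followed by a plain matrix multiplication. A rough sketch of the composition in terms of the `ggml_im2col` declared above — the reshape/permute details here only approximate the upstream `ggml_conv_2d` and are illustrative:

```cpp
// a: kernel [OC, IC, KH, KW], b: image [N, IC, IH, IW]
static struct ggml_tensor * conv_2d_via_im2col(
        struct ggml_context * ctx,
        struct ggml_tensor * a, struct ggml_tensor * b,
        int s0, int s1, int p0, int p1, int d0, int d1) {
    // unroll every receptive field into a row: [N, OH, OW, IC*KH*KW]
    struct ggml_tensor * col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, /*is_2D =*/ true);

    // [N, OH, OW, IC*KH*KW] => [N*OH*OW, IC*KH*KW]
    struct ggml_tensor * col_2d = ggml_reshape_2d(ctx, col,
            col->ne[0], col->ne[3]*col->ne[2]*col->ne[1]);

    // [OC, IC, KH, KW] => [OC, IC*KH*KW]
    struct ggml_tensor * ker_2d = ggml_reshape_2d(ctx, a,
            a->ne[0]*a->ne[1]*a->ne[2], a->ne[3]);

    // GEMM over the shared IC*KH*KW dimension: result is [N*OH*OW, OC]
    struct ggml_tensor * out = ggml_reshape_4d(ctx,
            ggml_mul_mat(ctx, col_2d, ker_2d),
            col->ne[1], col->ne[2], col->ne[3], a->ne[3]); // [OW, OH, N, OC]

    // bring the batch dimension outermost: [OW, OH, OC, N], i.e. [N, OC, OH, OW]
    return ggml_cont(ctx, ggml_permute(ctx, out, 0, 1, 3, 2));
}
```

This is exactly the GEMM formulation computed by `ggml_compute_forward_im2col` above: the `[N*OH*OW, IC*KH*KW]` work buffer becomes one operand of `gemm_f16_out_f32`, the flattened kernel the other.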
diff --git a/whisper.cpp b/whisper.cpp
index 681727f..244cfeb 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -1,10 +1,15 @@
 #include "whisper.h"
+
 #ifdef WHISPER_USE_COREML
 #include "coreml/whisper-encoder.h"
 #endif
 
 #ifdef GGML_USE_METAL
-# include "ggml-metal.h"
+#include "ggml-metal.h"
+#endif
+
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
 #endif
 
 #ifdef WHISPER_USE_OPENVINO
@@ -13,6 +18,7 @@
 #include "ggml.h"
 #include "ggml-alloc.h"
+#include "ggml-backend.h"
 
 #include <algorithm>
 #include <cassert>
@@ -97,10 +103,32 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 #define BYTESWAP_TENSOR(t) do {} while (0)
 #endif
 
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define WHISPER_ATTRIBUTE_FORMAT(...)
+#endif
+
+//
+// logging
+//
+
+WHISPER_ATTRIBUTE_FORMAT(2, 3)
+static void whisper_log_internal        (ggml_log_level level, const char * format, ...);
+static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data);
+
+#define WHISPER_LOG_INFO(...)  whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
+#define WHISPER_LOG_WARN(...)  whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
+#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
 #define WHISPER_ASSERT(x) \
     do { \
         if (!(x)) { \
-            log("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
+            WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
             abort(); \
         } \
     } while (0)
@@ -127,8 +155,8 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 //
 static void ggml_graph_compute_helper(
+          struct ggml_cgraph * graph,
        std::vector<uint8_t> & buf,
-                ggml_cgraph * graph,
                         int   n_threads,
      whisper_abort_callback   abort_callback,
                       void *  abort_callback_data) {
@@ -145,6 +173,21 @@ static void ggml_graph_compute_helper(
     ggml_graph_compute(graph, &plan);
 }
 
+static void ggml_graph_compute_helper(
+        struct ggml_backend * backend,
+         struct ggml_cgraph * graph,
+                        int   n_threads) {
+    if (ggml_backend_is_cpu(backend)) {
+        ggml_backend_cpu_set_n_threads(backend, n_threads);
+    }
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(backend)) {
+        ggml_backend_metal_set_n_cb(backend, n_threads);
+    }
+#endif
+    ggml_backend_graph_compute(backend, graph);
+}
+
 // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
 // the idea is to represent the original matrix multiplication:
 //
@@ -179,6 +222,7 @@ static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct g
 }
 
 // TODO: check if other platforms can benefit from this optimization
+// TODO: CUDA is currently broken - seems ggml_mul_mat does not handle views correctly
 #if defined(GGML_USE_METAL)
 #define ggml_mul_mat ggml_mul_mat_pad
 #endif
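The new overload routes graph execution through a backend object instead of an explicit work buffer. A minimal self-contained sketch of the flow, assuming the ggml-alloc/ggml-backend API of this vintage (`ggml_new_graph`, `ggml_sqr`, etc. are standard ggml calls; the fixed buffer size here is deliberately oversized rather than measured):

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

static void backend_roundtrip(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ ggml_tensor_overhead()*16 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // tensor data will live in a backend buffer
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * y = ggml_sqr(ctx, x);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);

    ggml_backend_t backend = ggml_backend_cpu_init();
    ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(backend, 4096); // oversized on purpose
    ggml_allocr * alloc = ggml_allocr_new_from_buffer(buf);

    // place the input, upload its data, then place the remaining graph tensors
    // (same order the whisper graph builders use)
    ggml_allocr_alloc(alloc, x);
    const float inp[4] = { 1, 2, 3, 4 };
    ggml_backend_tensor_set(x, inp, 0, sizeof(inp));
    ggml_allocr_alloc_graph(alloc, gf);

    ggml_backend_cpu_set_n_threads(backend, 1);
    ggml_backend_graph_compute(backend, gf);

    float out[4];
    ggml_backend_tensor_get(y, out, 0, sizeof(out)); // { 1, 4, 9, 16 }

    ggml_allocr_free(alloc);
    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
}
```

Note that nothing touches `tensor->data` directly: `ggml_backend_tensor_set`/`_get` keep the same code working when the buffer lives in device memory.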
@@ -305,75 +349,6 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "yue", { 99, "cantonese",      } },
 };
 
-static const size_t MB = 1ull*1024*1024;
-
-// TODO: avoid using GGUF
-static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
-    { GGML_TYPE_F32,
-        {
-            { MODEL_TINY,     74ull*MB },
-            { MODEL_BASE,    142ull*MB },
-            { MODEL_SMALL,   466ull*MB },
-            { MODEL_MEDIUM, 1464ull*MB },
-            { MODEL_LARGE,  2952ull*MB },
-        },
-    },
-    { GGML_TYPE_F16,
-        {
-            { MODEL_TINY,     74ull*MB },
-            { MODEL_BASE,    142ull*MB },
-            { MODEL_SMALL,   466ull*MB },
-            { MODEL_MEDIUM, 1464ull*MB },
-            { MODEL_LARGE,  2952ull*MB },
-        },
-    },
-    { GGML_TYPE_Q4_0,
-        {
-            { MODEL_TINY,     26ull*MB },
-            { MODEL_BASE,     50ull*MB },
-            { MODEL_SMALL,   154ull*MB },
-            { MODEL_MEDIUM,  470ull*MB },
-            { MODEL_LARGE,   940ull*MB },
-        },
-    },
-    { GGML_TYPE_Q4_1,
-        {
-            { MODEL_TINY,     32ull*MB },
-            { MODEL_BASE,     58ull*MB },
-            { MODEL_SMALL,   182ull*MB },
-            { MODEL_MEDIUM,  562ull*MB },
-            { MODEL_LARGE,  1124ull*MB },
-        },
-    },
-    { GGML_TYPE_Q5_0,
-        {
-            { MODEL_TINY,     30ull*MB },
-            { MODEL_BASE,     54ull*MB },
-            { MODEL_SMALL,   170ull*MB },
-            { MODEL_MEDIUM,  516ull*MB },
-            { MODEL_LARGE,  1034ull*MB },
-        },
-    },
-    { GGML_TYPE_Q5_1,
-        {
-            { MODEL_TINY,     32ull*MB },
-            { MODEL_BASE,     58ull*MB },
-            { MODEL_SMALL,   182ull*MB },
-            { MODEL_MEDIUM,  562ull*MB },
-            { MODEL_LARGE,  1124ull*MB },
-        },
-    },
-    { GGML_TYPE_Q8_0,
-        {
-            { MODEL_TINY,     45ull*MB },
-            { MODEL_BASE,     84ull*MB },
-            { MODEL_SMALL,   268ull*MB },
-            { MODEL_MEDIUM,  834ull*MB },
-            { MODEL_LARGE,  1674ull*MB },
-        },
-    },
-};
-
 struct whisper_mel {
     int n_len;
     int n_len_org;
@@ -554,8 +529,7 @@ struct whisper_kv_cache {
 
     struct ggml_context * ctx;
 
-    // buf points to the memory allocated for both ggml_tensor 'k' and 'v' (see kv_cache_init)
-    std::vector<uint8_t> buf;
+    ggml_backend_buffer_t buffer;
 
     int n; // number of tokens currently in the cache
 };
@@ -594,11 +568,11 @@ struct whisper_model {
     std::vector<whisper_layer_encoder> layers_encoder;
     std::vector<whisper_layer_decoder> layers_decoder;
 
-    // context
+    // ggml context that contains all the meta information about the model tensors
     struct ggml_context * ctx;
 
-    // the model memory buffer is read-only and can be shared between processors
-    std::vector<uint8_t> * buf;
+    // the model backend data is read-only and can be shared between processors
+    struct ggml_backend_buffer * buffer;
 
     // tensors
     int n_loaded;
@@ -663,37 +637,47 @@ struct whisper_allocr {
     ggml_allocr * alloc = nullptr;
 
     std::vector<uint8_t> meta;
-    std::vector<uint8_t> data;
+
+    ggml_backend_buffer_t buffer;
 };
 
 static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-    return allocr.meta.size() + allocr.data.size();
+    return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
 }
 
 // measure the memory usage of a graph and prepare the allocr's internal data buffer
-static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function<struct ggml_cgraph *()> && get_graph) {
-    const int tensor_alignment = 32;
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
+    auto & alloc = allocr.alloc;
+    auto & meta  = allocr.meta;
 
-    auto & alloc = allocr.alloc;
-    auto & meta  = allocr.meta;
-    auto & data  = allocr.data;
+    alloc = ggml_allocr_new_measure_from_backend(backend);
 
     meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
 
-    alloc = ggml_allocr_new_measure(tensor_alignment);
+    ggml_allocr_alloc_graph(alloc, get_graph());
+}
 
-    const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
+static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
+    if (allocr.alloc == nullptr) {
+        // this can be null if we use external encoder like CoreML or OpenVINO
+        return;
+    }
+
+    auto & alloc  = allocr.alloc;
+    auto & buffer = allocr.buffer;
+
+    size_t size = ggml_allocr_max_size(alloc);
 
     ggml_allocr_free(alloc);
 
-    data.resize(alloc_size);
-
-    alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
+    buffer = ggml_backend_alloc_buffer(backend, size);
+    alloc  = ggml_allocr_new_from_buffer(buffer);
 }
 
 static void whisper_allocr_free(struct whisper_allocr & allocr) {
     if (allocr.alloc) {
         ggml_allocr_free(allocr.alloc);
+        ggml_backend_buffer_free(allocr.buffer);
         allocr.alloc = nullptr;
     }
 }
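`whisper_allocr_graph_init` and `whisper_allocr_graph_realloc` split compute-buffer allocation into a measure pass and a commit pass. A hypothetical driver making the sequence explicit — `build_graph` stands in for the `std::function` the call sites pass in, and at eval time the graphs are rebuilt and placed on the returned allocator:

```cpp
#include <functional>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

static ggml_allocr * alloc_graph_on_backend(
        ggml_backend_t backend,
        ggml_backend_buffer_t & buffer,
        const std::function<struct ggml_cgraph *()> & build_graph) {
    // phase 1: dry run on a measure allocator - no real memory is handed out,
    // but the allocator records the peak memory the graph would need
    ggml_allocr * alloc = ggml_allocr_new_measure_from_backend(backend);
    ggml_allocr_alloc_graph(alloc, build_graph());

    const size_t size = ggml_allocr_max_size(alloc);
    ggml_allocr_free(alloc);

    // phase 2: allocate a real backend buffer of exactly the measured size
    // and put a fresh allocator on top of it
    buffer = ggml_backend_alloc_buffer(backend, size);
    return ggml_allocr_new_from_buffer(buffer);
}
```

This replaces the old scheme of a CPU-side `std::vector` plus `ggml_allocr_new`, which could not place tensors in device memory.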
@@ -722,8 +706,7 @@ struct whisper_state {
     // buffer for swapping KV caches between decoders during beam-search
     std::vector<kv_buf> kv_swap_bufs;
 
-    // reusable buffer for `struct ggml_graph_plan.work_data`
-    std::vector<uint8_t> work_buffer;
+    ggml_backend_t backend = nullptr;
 
     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
@@ -737,6 +720,9 @@ struct whisper_state {
     struct ggml_tensor * embd_conv = nullptr;
     struct ggml_tensor * embd_enc  = nullptr;
 
+    // helper for GPU offloading
+    std::vector<float> inp_mel;
+
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
 
@@ -751,22 +737,21 @@ struct whisper_state {
     int lang_id = 0; // english by default
 
     std::string path_model; // populated by whisper_init_from_file_with_params()
+
 #ifdef WHISPER_USE_COREML
     whisper_coreml_context * ctx_coreml = nullptr;
 #endif
 
-#ifdef GGML_USE_METAL
-    ggml_metal_context * ctx_metal = nullptr;
-#endif
-
 #ifdef WHISPER_USE_OPENVINO
     whisper_openvino_context * ctx_openvino = nullptr;
 #endif
 
     // [EXPERIMENTAL] token-level timestamps data
-    int64_t t_beg = 0;
+    int64_t t_beg  = 0;
     int64_t t_last = 0;
+
     whisper_token tid_last;
+
     std::vector<float> energy; // PCM signal energy
 
     // [EXPERIMENTAL] speed-up techniques
@@ -780,35 +765,25 @@ struct whisper_context {
     ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
     ggml_type itype = ggml_type::GGML_TYPE_F16; // intermediate type (FP32 or FP16)
 
+    whisper_context_params params;
+
     whisper_model model;
     whisper_vocab vocab;
+
     whisper_state * state = nullptr;
 
+    ggml_backend_t backend = nullptr;
+
     std::string path_model; // populated by whisper_init_from_file_with_params()
-
-    whisper_context_params params;
 };
 
-static void whisper_default_log(const char * text) {
-    fprintf(stderr, "%s", text);
-}
+struct whisper_global {
+    // We save the log callback globally
+    ggml_log_callback log_callback = whisper_log_callback_default;
+    void * log_callback_user_data = nullptr;
+};
 
-static whisper_log_callback whisper_log = whisper_default_log;
-
-#ifdef __GNUC__
-#ifdef __MINGW32__
-__attribute__((gnu_format(printf, 1, 2)))
-#else
-__attribute__((format(printf, 1, 2)))
-#endif
-#endif
-static void log(const char * fmt, ...) {
-    if (!whisper_log) return;
-    char buf[1024];
-    va_list args;
-    va_start(args, fmt);
-    vsnprintf(buf, sizeof(buf), fmt, args);
-    whisper_log(buf);
-}
+static whisper_global g_state;
 
 template<typename T>
 static void read_safe(whisper_model_loader * loader, T & dest) {
@@ -819,6 +794,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
 static bool kv_cache_init(
         const struct whisper_hparams & hparams,
              struct whisper_kv_cache & cache,
+                      ggml_backend_t   backend,
                            ggml_type   wtype,
                                  int   n_ctx) {
     const int64_t n_text_state = hparams.n_text_state;
@@ -827,30 +803,41 @@ static bool kv_cache_init(
     const int64_t n_mem      = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;
 
-    const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead());
-
-    cache.buf.resize(mem_bytes);
-
     struct ggml_init_params params = {
-        /*.mem_size   =*/ cache.buf.size(),
-        /*.mem_buffer =*/ cache.buf.data(),
-        /*.no_alloc   =*/ false,
+        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
     };
 
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {
-        log("%s: failed to allocate memory for kv cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
        return false;
     }
 
     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
+    const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v);
+
+    cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes);
+
+    // allocate the tensors into the backend buffer
+    {
+        ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer);
+
+        ggml_allocr_alloc(alloc, cache.k);
+        ggml_allocr_alloc(alloc, cache.v);
+
+        ggml_allocr_free(alloc);
+    }
+
     return true;
 }
 
-static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
+// TODO: remove after batched decoding
+static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) {
     WHISPER_ASSERT(cache.ctx);
 
     const int n_elements = ggml_nelements(cache.k);
@@ -859,34 +846,78 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
     const ggml_type wtype = cache.k->type;
     WHISPER_ASSERT(wtype == cache.v->type);
 
-    WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_sizef(wtype));
-
     struct ggml_init_params params = {
-        /*.mem_size   =*/ cache.buf.size(),
-        /*.mem_buffer =*/ cache.buf.data(),
-        /*.no_alloc   =*/ false,
+        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
     };
 
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {
-        log("%s: failed to allocate memory for kv cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
         return false;
     }
 
     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
+    const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v);
+
+    cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes);
+
+    // allocate the tensors into the backend buffer
+    {
+        ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer);
+
+        ggml_allocr_alloc(alloc, cache.k);
+        ggml_allocr_alloc(alloc, cache.v);
+
+        ggml_allocr_free(alloc);
+    }
+
     return true;
 }
 
 static void kv_cache_free(struct whisper_kv_cache & cache) {
     if (cache.ctx) {
         ggml_free(cache.ctx);
+        ggml_backend_buffer_free(cache.buffer);
         cache.ctx = nullptr;
     }
 }
 
+static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
+    ggml_backend_t backend_gpu = NULL;
+
+    // initialize the backends
+#ifdef
GGML_USE_CUBLAS + if (params.use_gpu) { + WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__); + backend_gpu = ggml_backend_cuda_init(); + if (!backend_gpu) { + WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (params.use_gpu) { + WHISPER_LOG_INFO("%s: using Metal backend\n", __func__); + ggml_metal_log_set_callback(whisper_log_callback_default, nullptr); + backend_gpu = ggml_backend_metal_init(); + if (!backend_gpu) { + WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if (backend_gpu) { + return backend_gpu; + } + return ggml_backend_cpu_init(); +} + // load the model from a ggml file // // file format: @@ -899,7 +930,7 @@ static void kv_cache_free(struct whisper_kv_cache & cache) { // see the convert-pt-to-ggml.py script for details // static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) { - log("%s: loading model\n", __func__); + WHISPER_LOG_INFO("%s: loading model\n", __func__); const int64_t t_start_us = ggml_time_us(); @@ -913,7 +944,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con uint32_t magic; read_safe(loader, magic); if (magic != GGML_FILE_MAGIC) { - log("%s: invalid model data (bad magic)\n", __func__); + WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__); return false; } } @@ -970,41 +1001,23 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // in order to save memory and also to speed up the computation wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype)); if (wctx.wtype == GGML_TYPE_COUNT) { - log("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype); + WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype); return false; } - const size_t scale = model.hparams.ftype ? 1 : 2; - - log("%s: n_vocab = %d\n", __func__, hparams.n_vocab); - log("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); - log("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); - log("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); - log("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); - log("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); - log("%s: n_text_state = %d\n", __func__, hparams.n_text_state); - log("%s: n_text_head = %d\n", __func__, hparams.n_text_head); - log("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); - log("%s: n_mels = %d\n", __func__, hparams.n_mels); - log("%s: ftype = %d\n", __func__, model.hparams.ftype); - log("%s: qntvr = %d\n", __func__, qntvr); - log("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str()); - - // print memory requirements - { - // TODO - //log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, - // mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); - } - - // initialize all memory buffers - // always have at least one decoder - - wctx.model.buf = new std::vector(); - wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type)); - - // we skip initialization of the state until it is needed - // because it might be that state will always be provided externally. 
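`whisper_backend_init` above never fails outright: if no GPU backend is compiled in, `use_gpu` is false, or GPU initialization fails, it falls back to the CPU backend. A hypothetical wrapper showing how a caller can report which backend was actually selected (`ggml_backend_name` is the same ggml-backend call used by the loader further down):

```cpp
static ggml_backend_t init_backend_verbose(const whisper_context_params & params) {
    ggml_backend_t backend = whisper_backend_init(params);
    // never NULL: the helper falls back to ggml_backend_cpu_init()
    WHISPER_LOG_INFO("%s: selected backend: %s\n", __func__, ggml_backend_name(backend));
    return backend;
}
```

Note also that buffer sizes no longer come from the removed `MEM_REQ_MODEL` table; they are derived from the actual tensors. For example, with the base model's text hparams (n_text_state = 512, n_text_layer = 6, n_text_ctx = 448) and the default F16 cache type, `kv_cache_init` allocates per tensor n_elements = 512 * (6 * 448) = 1,376,256, so k and v together take 2 * 1,376,256 * 2 bytes ≈ 5.25 MB — the value printed by the "kv self size" log below.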
+ WHISPER_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + WHISPER_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + WHISPER_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + WHISPER_LOG_INFO("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); + WHISPER_LOG_INFO("%s: n_text_state = %d\n", __func__, hparams.n_text_state); + WHISPER_LOG_INFO("%s: n_text_head = %d\n", __func__, hparams.n_text_head); + WHISPER_LOG_INFO("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); + WHISPER_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels); + WHISPER_LOG_INFO("%s: ftype = %d\n", __func__, model.hparams.ftype); + WHISPER_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr); + WHISPER_LOG_INFO("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str()); } // load mel filters @@ -1025,7 +1038,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con read_safe(loader, n_vocab); //if (n_vocab != model.hparams.n_vocab) { - // log("%s: invalid model file '%s' (bad vocab size %d != %d)\n", + // WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n", // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); // return false; //} @@ -1045,7 +1058,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con word.assign(&tmp[0], tmp.size()); } else { // seems like we have an empty-string token in multi-language models (i = 50256) - //log("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + //WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); word = ""; } @@ -1073,7 +1086,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } if (n_vocab < model.hparams.n_vocab) { - log("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); + WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); for (int i = n_vocab; i < model.hparams.n_vocab; i++) { if (i > vocab.token_beg) { word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]"; @@ -1099,140 +1112,35 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } - log("%s: n_langs = %d\n", __func__, vocab.num_languages()); + WHISPER_LOG_INFO("%s: n_langs = %d\n", __func__, vocab.num_languages()); } - size_t ctx_size = 0; - const ggml_type wtype = wctx.wtype; const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? 
GGML_TYPE_F32 : GGML_TYPE_F16; // conv type + // create the ggml context { const auto & hparams = model.hparams; - const int n_vocab = hparams.n_vocab; - - const int n_audio_ctx = hparams.n_audio_ctx; - const int n_audio_state = hparams.n_audio_state; const int n_audio_layer = hparams.n_audio_layer; + const int n_text_layer = hparams.n_text_layer; - const int n_text_ctx = hparams.n_text_ctx; - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; + const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer; - const int n_mels = hparams.n_mels; - - // encoder - { - ctx_size += n_audio_ctx*n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_pe; - - ctx_size += 3*n_mels*n_audio_state*ggml_type_sizef(vtype); // e_conv_1_w - ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_1_b - - ctx_size += 3*n_audio_state*n_audio_state*ggml_type_sizef(vtype); // e_conv_2_w - ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_2_b - - ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_w; - ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_b; - } - - // decoder - { - ctx_size += n_text_ctx*n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_pe; - - ctx_size += n_vocab*n_text_state*ggml_type_sizef(wtype); // d_te; - - ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_w; - ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_b; - } - - // encoder layers - { - ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w - ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b - - ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_0_w - ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b - - ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_1_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b - - ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w - ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_q_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_k_w - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_v_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_ln_1_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b - } - - // decoder layers - { - ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w - ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b - - ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_0_w - ctx_size += n_text_layer*( 4*n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b - - ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b - - ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // 
attn_ln_0_w - ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_q_w - ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_k_w - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_v_w - ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_ln_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b - // - ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_w - ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_q_w - ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_q_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_k_w - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_v_w - ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_v_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_ln_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_1_b - } - - ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead - - log("%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); - } - - // create the ggml context - { struct ggml_init_params params = { - /*.mem_size =*/ wctx.model.buf->size(), - /*.mem_buffer =*/ wctx.model.buf->data(), - /*.no_alloc =*/ false, + /*.mem_size =*/ n_tensors*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, }; model.ctx = ggml_init(params); if (!model.ctx) { - log("%s: ggml_init() failed\n", __func__); + WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__); return false; } } - // prepare memory for the weights + // prepare tensors for the weights { auto & ctx = model.ctx; @@ -1255,16 +1163,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // encoder { - model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx); + model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx); - model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state); - model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); + model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state); + model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_audio_ctx, n_audio_state); - model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state); - model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); + model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state); + model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_ctx, n_audio_state); - model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + model.e_ln_b = ggml_new_tensor_1d(ctx, 
GGML_TYPE_F32, n_audio_state);
 
             // map by name
             model.tensors["encoder.positional_embedding"] = model.e_pe;
@@ -1428,12 +1336,37 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
+    wctx.backend = whisper_backend_init(wctx.params);
+
+    {
+        size_t size_main = 0;
+
+        for (const auto & t : model.tensors) {
+            size_main += ggml_nbytes(t.second) + ggml_tensor_overhead();
+        }
+
+        model.buffer = ggml_backend_alloc_buffer(wctx.backend, size_main);
+
+        WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1024.0 / 1024.0);
+    }
+
+    ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer);
+
+    // allocate tensors in the backend buffers
+    {
+        for (const auto & t : model.tensors) {
+            ggml_allocr_alloc(alloc, t.second);
+        }
+    }
+
     // load weights
     {
         size_t total_size = 0;
 
         model.n_loaded = 0;
 
+        std::vector<char> read_buf;
+
         while (true) {
             int32_t n_dims;
             int32_t length;
@@ -1460,50 +1393,92 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             name.assign(&tmp[0], tmp.size());
 
             if (model.tensors.find(name) == model.tensors.end()) {
-                log("%s: unknown tensor '%s' in model file\n", __func__, name.data());
+                WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
                 return false;
             }
 
             auto tensor = model.tensors[name.data()];
-            if (ggml_nelements(tensor) != nelements) {
-                log("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
-                log("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
-                    __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
-                return false;
+
+            const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias");
+
+            if (!is_conv_bias) {
+                if (ggml_nelements(tensor) != nelements) {
+                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
+                    WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
+                        __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
+                    return false;
+                }
+
+                if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
+                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
+                        __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
+                    return false;
+                }
+
+                const size_t bpe = ggml_type_size(ggml_type(ttype));
+
+                if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
+                    WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
+                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
+                    return false;
+                }
             }
 
-            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
-                log("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
-                    __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
-                return false;
+            ggml_backend_t backend = wctx.backend;
+
+            //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
+
+            if ((ggml_backend_is_cpu(backend)
+#ifdef GGML_USE_METAL
+                || ggml_backend_is_metal(backend)
+#endif
+                ) && !is_conv_bias) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
+                BYTESWAP_TENSOR(tensor);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(ggml_nbytes(tensor));
+
+                // we repeat the 2 bias tensors along dim 0:
+                // [1, 512] -> [3000, 512] (conv1.bias)
+                // [1, 512] -> [1500, 512] (conv2.bias)
+                if (is_conv_bias) {
+                    loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]);
+
+                    float * data_f32 = (float *) read_buf.data();
+                    for (int64_t y = 0; y < tensor->ne[1]; ++y) {
+                        const int64_t yy = tensor->ne[1] - y - 1;
+                        const float val = data_f32[yy];
+
+                        for (int64_t x = 0; x < tensor->ne[0]; ++x) {
+                            data_f32[yy*tensor->ne[0] + x] = val;
+                        }
+                    }
+                } else {
+                    loader->read(loader->context, read_buf.data(), read_buf.size());
+                }
+
+                ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
             }
 
-            const size_t bpe = ggml_type_size(ggml_type(ttype));
-
-            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
-                log("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
-                    __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
-                return false;
-            }
-
-            loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
-            BYTESWAP_TENSOR(tensor);
-
             //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1024.0/1024.0);
             total_size += ggml_nbytes(tensor);
             model.n_loaded++;
         }
 
-        log("%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
+        WHISPER_LOG_INFO("%s: model size    = %7.2f MB\n", __func__, total_size/1024.0/1024.0);
 
         if (model.n_loaded == 0) {
-            log("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
+            WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
        } else if (model.n_loaded != (int) model.tensors.size()) {
-            log("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
+            WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
             return false;
         }
     }
 
+    ggml_allocr_free(alloc);
+
     wctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
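The load path above distinguishes host-visible backends from device backends. A hypothetical helper condensing the same logic — note that `ggml_backend_tensor_set` would also be correct for the CPU case; the direct read merely skips the staging copy:

```cpp
#include <cstring>

#include "ggml.h"
#include "ggml-backend.h"

static bool upload_tensor(ggml_backend_t backend, struct ggml_tensor * t,
                          const void * src, size_t nbytes) {
    if (nbytes != ggml_nbytes(t)) {
        return false; // size mismatch - same check as in whisper_model_load
    }
    if (ggml_backend_is_cpu(backend)) {
        memcpy(t->data, src, nbytes);               // tensor data is host-visible
    } else {
        ggml_backend_tensor_set(t, src, 0, nbytes); // copies into device memory
    }
    return true;
}
```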
@@ -1559,10 +1534,12 @@ static struct ggml_cgraph * whisper_build_graph_conv(
     if (!ggml_allocr_is_measure(alloc)) {
         assert(mel_inp.n_mel == n_mels);
 
-        float * dst = (float *) mel->data;
+        wstate.inp_mel.resize(ggml_nelements(mel));
+
+        float * dst = wstate.inp_mel.data();
         memset(dst, 0, ggml_nbytes(mel));
 
-        const int i0 = std::min(mel_offset, mel_inp.n_len);
+        const int i0 = std::min(mel_offset,           mel_inp.n_len);
         const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
 
         for (int j = 0; j < mel_inp.n_mel; ++j) {
@@ -1570,6 +1547,8 @@ static struct ggml_cgraph * whisper_build_graph_conv(
                 dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
             }
         }
+
+        ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
     }
 
     struct ggml_tensor * cur = nullptr;
 
@@ -1578,24 +1557,27 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         // convolution + gelu
         {
             cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        model.e_conv_1_b,
-                        cur),
-                    cur);
+            cur = ggml_add(ctx0, cur, model.e_conv_1_b);
+            //cur = ggml_add(ctx0,
+            //        ggml_repeat(ctx0,
+            //            model.e_conv_1_b,
+            //            cur),
+            //        cur);
 
             cur = ggml_gelu(ctx0, cur);
 
             cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
-            cur = ggml_add(ctx0,
-                    ggml_repeat(ctx0,
-                        model.e_conv_2_b,
-                        cur),
-                    cur);
+            cur = ggml_add(ctx0, cur, model.e_conv_2_b);
+            //cur = ggml_add(ctx0,
+            //        ggml_repeat(ctx0,
+            //            model.e_conv_2_b,
+            //            cur),
+            //        cur);
 
             cur = ggml_gelu(ctx0, cur);
         }
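To see why `conv1.bias` / `conv2.bias` are stored pre-broadcast at load time, trace the encoder front-end shapes (assuming the standard Whisper mel front end, with base-model widths):

```cpp
// Front-end shapes for one 30 s window (n_mels = 80; n_audio_state = 512 for
// the base model; n_ctx = n_audio_ctx = 1500) - assumed standard Whisper dims:
//
//   mel input:              [2*n_ctx, n_mels] = [3000, 80]
//   conv1 (k = 3, s = 1):   [3000, 512]   e_conv_1_b: [2*n_audio_ctx, 512]
//   conv2 (k = 3, s = 2):   [1500, 512]   e_conv_2_b: [  n_audio_ctx, 512]
//
// The expanded biases line up element-for-element with the conv outputs, so a
// plain ggml_add suffices and the commented-out ggml_repeat is not needed.
```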
+        ggml_set_name(cur, "embd_conv");
         wstate.embd_conv = cur;
     } else {
 #ifdef WHISPER_USE_COREML
@@ -1615,6 +1597,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         }
 #endif
 
+        ggml_set_name(cur, "embd_enc");
         wstate.embd_enc = cur;
     }
 
@@ -1648,15 +1631,22 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
     ggml_allocr * alloc = wstate.alloc_encode.alloc;
 
+    //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state);
+    //ggml_allocr_alloc(alloc, cur);
+
+    //if (!ggml_allocr_is_measure(alloc)) {
+    //    ggml_backend_tensor_copy(wstate.embd_conv, cur);
+    //}
+
+    struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
+
     struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(alloc, KQscale);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head));
+        const float val = 1.0f/sqrtf(float(n_state)/n_head);
+        ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
     }
 
-    struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
-
     // ===================================================================
     // NOTE: experimenting with partial evaluation of the encoder (ignore)
     //static int iter = -1;
@@ -1675,7 +1665,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
 
     struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
-
     cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
 
     // ===================================================================
@@ -1897,13 +1886,20 @@ static struct ggml_cgraph * whisper_build_graph_cross(
 
     ggml_allocr * alloc = wstate.alloc_cross.alloc;
 
+    //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+    //ggml_allocr_alloc(alloc, cur);
+
+    //if (!ggml_allocr_is_measure(alloc)) {
+    //    ggml_backend_tensor_copy(wstate.embd_enc, cur);
+    //}
 
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
 
     struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(alloc, Kscale);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25));
+        const float val = pow(float(n_state) / n_head, -0.25);
+        ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float));
     }
 
     for (int il = 0; il < model.hparams.n_text_layer; ++il) {
@@ -1974,7 +1970,7 @@ static bool whisper_encode_internal(
         ggml_allocr_alloc_graph(alloc, gf);
 
         if (!whisper_encode_external(wstate)) {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
+            ggml_graph_compute_helper(wstate.backend, gf, n_threads);
         }
     }
 
@@ -1988,16 +1984,7 @@ static bool whisper_encode_internal(
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-#ifdef
GGML_USE_METAL - if (wstate.ctx_metal) { - ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); - ggml_metal_graph_compute(wstate.ctx_metal, gf); - } else { - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); - } -#else - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); -#endif + ggml_graph_compute_helper(wstate.backend, gf, n_threads); } - // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); - wstate.t_encode_us += ggml_time_us() - t_start_us; wstate.n_encode++; @@ -2070,7 +2046,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_allocr_alloc(alloc, embd); if (!ggml_allocr_is_measure(alloc)) { - memcpy(embd->data, tokens, N*ggml_element_size(embd)); + ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd)); } struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); @@ -2078,7 +2054,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder( if (!ggml_allocr_is_measure(alloc)) { for (int i = 0; i < N; ++i) { - ((int32_t *) position->data)[i] = n_past + i; + const int32_t val = n_past + i; + ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t)); } } @@ -2086,7 +2063,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_allocr_alloc(alloc, KQscale); if (!ggml_allocr_is_measure(alloc)) { - ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25)); + const float val = pow(float(n_state)/n_head, -0.25); + ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float)); } // token encoding + position encoding @@ -2410,25 +2388,18 @@ static bool whisper_decode_internal( logits = gf->nodes[gf->n_nodes - 1]; -#ifdef GGML_USE_METAL - if (wstate.ctx_metal) { - ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); - ggml_metal_graph_compute(wstate.ctx_metal, gf); - } else { - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); - } -#else - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); -#endif + ggml_graph_compute_helper(wstate.backend, gf, n_threads); } // extract logits for all N tokens //logits_out.resize(n_tokens*n_vocab); //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab); + //ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab); // extract logits only for the last token logits_out.resize(n_vocab); - memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); + //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); + ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab); if (n_tokens > 1) { //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, @@ -2794,7 +2765,7 @@ static std::vector tokenize(const whisper_vocab & vocab, cons --j; } if (!found) { - log("unknown token\n"); + WHISPER_LOG_ERROR("unknown token\n"); ++i; } } @@ -2857,45 +2828,48 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) { struct whisper_state * whisper_init_state(whisper_context * ctx) { fill_sin_cos_table(); + whisper_state * state = new whisper_state; - if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) { - log("%s: kv_cache_init() failed for self-attention cache\n", __func__); + state->backend = whisper_backend_init(ctx->params); + + if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, 
ctx->itype, ctx->model.hparams.n_text_ctx)) { + WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); delete state; return nullptr; } { const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v); - log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } - if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) { - log("%s: kv_cache_init() failed for cross-attention cache\n", __func__); + if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) { + WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__); delete state; return nullptr; } { const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v); - log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } #ifdef WHISPER_USE_COREML const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model); - log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); - log("%s: first run on a device may take a while ...\n", __func__); + WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); + WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__); state->ctx_coreml = whisper_coreml_init(path_coreml.c_str()); if (!state->ctx_coreml) { - log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str()); + WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str()); #ifndef WHISPER_COREML_ALLOW_FALLBACK delete state; return nullptr; #endif } else { - log("%s: Core ML model loaded\n", __func__); + WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__); } #endif @@ -2912,37 +2886,37 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // conv allocator { - whisper_allocr_graph_init(state->alloc_conv, + whisper_allocr_graph_init(state->alloc_conv, ctx->backend, [&]() { return whisper_build_graph_conv(*ctx, *state, 0); }); - log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0); } // encoder allocator if (!whisper_encode_external(*state)) { - whisper_allocr_graph_init(state->alloc_encode, + whisper_allocr_graph_init(state->alloc_encode, ctx->backend, [&]() { return whisper_build_graph_encoder(*ctx, *state); }); - log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0); } // cross allocator { - whisper_allocr_graph_init(state->alloc_cross, + whisper_allocr_graph_init(state->alloc_cross, ctx->backend, [&]() { return whisper_build_graph_cross(*ctx, *state); }); - log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0); } // decoder allocator { - 
whisper_allocr_graph_init(state->alloc_decode, + whisper_allocr_graph_init(state->alloc_decode, ctx->backend, [&]() { const auto & hparams = ctx->model.hparams; @@ -2953,69 +2927,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past); }); - log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0); } -#ifdef GGML_USE_METAL - if (ctx->params.use_gpu) { - state->ctx_metal = ggml_metal_init(1); - if (!state->ctx_metal) { - log("%s: ggml_metal_init() failed\n", __func__); - delete state; - return nullptr; - } - } - - if (state->ctx_metal) { - log("%s: Metal context initialized\n", __func__); - - // this allocates all Metal resources and memory buffers - - void * data_ptr = NULL; - size_t data_size = 0; - - // TODO: add mmap support - //if (params.use_mmap) { - // data_ptr = ctx->model.mapping->addr; - // data_size = ctx->model.mapping->size; - //} else { - // data_ptr = ggml_get_mem_buffer(ctx->model.ctx); - // data_size = ggml_get_mem_size (ctx->model.ctx); - //} - - data_ptr = ggml_get_mem_buffer(ctx->model.ctx); - data_size = ggml_get_mem_size (ctx->model.ctx); - - const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); - - log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); - -#define WHISPER_METAL_CHECK_BUF(result) \ - if (!(result)) { \ - log("%s: failed to add metal buffer\n", __func__); \ - delete state; \ - return nullptr; \ - } - - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size)); - - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0)); - - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0)); - - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0)); - - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0)); -#undef WHISPER_METAL_CHECK_BUF - - } -#endif + whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend); + whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend); + whisper_allocr_graph_realloc(state->alloc_cross, 
ctx->backend);
+    whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
 
     state->rng = std::mt19937(0);
 
@@ -3036,7 +2954,7 @@ int whisper_ctx_init_openvino_encoder(
     return 1;
 #else
     if (!model_path && ctx->path_model.empty()) {
-        log("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
+        WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
         return 1;
     }
 
@@ -3056,15 +2974,15 @@ int whisper_ctx_init_openvino_encoder(
         path_cache = cache_dir;
     }
 
-    log("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
-    log("%s: first run on a device may take a while ...\n", __func__);
+    WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
+    WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);
 
     ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
     if (!ctx->state->ctx_openvino) {
-        log("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
+        WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
         return 1;
     } else {
-        log("%s: OpenVINO model loaded\n", __func__);
+        WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__);
     }
 
     return 0;
@@ -3079,11 +2997,11 @@ struct whisper_context_params whisper_context_default_params() {
 }
 
 struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
-    log("%s: loading model from '%s'\n", __func__, path_model);
+    WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
 
     auto fin = std::ifstream(path_model, std::ios::binary);
     if (!fin) {
-        log("%s: failed to open '%s'\n", __func__, path_model);
+        WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
         return nullptr;
     }
 
@@ -3125,7 +3043,7 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
 
     buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };
 
-    log("%s: loading model from buffer\n", __func__);
+    WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__);
 
     whisper_model_loader loader = {};
 
@@ -3161,7 +3079,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
 
     if (!whisper_model_load(loader, *ctx)) {
         loader->close(loader->context);
-        log("%s: failed to load model\n", __func__);
+        WHISPER_LOG_ERROR("%s: failed to load model\n", __func__);
         delete ctx;
         return nullptr;
     }
@@ -3256,13 +3174,6 @@ void whisper_free_state(struct whisper_state * state)
     }
 #endif
 
-#ifdef GGML_USE_METAL
-    if (state->ctx_metal) {
-        ggml_metal_free(state->ctx_metal);
-        state->ctx_metal = nullptr;
-    }
-#endif
-
 #ifdef WHISPER_USE_OPENVINO
     if (state->ctx_openvino != nullptr) {
         whisper_openvino_free(state->ctx_openvino);
@@ -3271,9 +3182,11 @@ void whisper_free_state(struct whisper_state * state)
 #endif
 
     whisper_allocr_free(state->alloc_conv);
-    whisper_allocr_free(state->alloc_decode);
-    whisper_allocr_free(state->alloc_cross);
     whisper_allocr_free(state->alloc_encode);
+    whisper_allocr_free(state->alloc_cross);
+    whisper_allocr_free(state->alloc_decode);
+
+    ggml_backend_free(state->backend);
 
     delete state;
 }
@@ -3284,12 +3197,15 @@ void whisper_free(struct whisper_context * ctx) {
         if (ctx->model.ctx) {
             ggml_free(ctx->model.ctx);
         }
-        if (ctx->model.buf) {
-            delete ctx->model.buf;
+
+        if (ctx->model.buffer) {
+            ggml_backend_buffer_free(ctx->model.buffer);
         }
 
         whisper_free_state(ctx->state);
 
+        ggml_backend_free(ctx->backend);
+
delete ctx; } } @@ -3308,7 +3224,7 @@ void whisper_free_params(struct whisper_full_params * params) { int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { - log("%s: failed to compute mel spectrogram\n", __func__); + WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -3322,7 +3238,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good) int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { - log("%s: failed to compute mel spectrogram\n", __func__); + WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -3350,7 +3266,7 @@ int whisper_set_mel_with_state( int n_len, int n_mel) { if (n_mel != ctx->model.filters.n_mel) { - log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel); + WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel); return -1; } @@ -3374,7 +3290,7 @@ int whisper_set_mel( int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) { if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) { - log("%s: failed to eval\n", __func__); + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return -1; } @@ -3383,7 +3299,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) { - log("%s: failed to eval\n", __func__); + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return -1; } @@ -3394,7 +3310,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state const int selected_decoder_id = 0; if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) { - log("%s: failed to eval\n", __func__); + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return 1; } @@ -3406,12 +3322,12 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i const int selected_decoder_id = 0; if (ctx->state == nullptr) { - log("%s: ERROR state was not loaded.\n", __func__); + WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__); return false; } if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) { - log("%s: failed to eval\n", __func__); + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return 1; } @@ -3422,7 +3338,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to const auto res = tokenize(ctx->vocab, text); if (n_max_tokens < (int) res.size()) { - log("%s: too 
many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); + WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); return -1; } @@ -3450,7 +3366,7 @@ int whisper_lang_id(const char * lang) { } } - log("%s: unknown language '%s'\n", __func__, lang); + WHISPER_LOG_ERROR("%s: unknown language '%s'\n", __func__, lang); return -1; } return g_lang.at(lang).first; @@ -3463,7 +3379,7 @@ const char * whisper_lang_str(int id) { } } - log("%s: unknown language id %d\n", __func__, id); + WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id); return nullptr; } @@ -3476,25 +3392,25 @@ int whisper_lang_auto_detect_with_state( const int seek = offset_ms/10; if (seek < 0) { - log("%s: offset %dms is before the start of the audio\n", __func__, offset_ms); + WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms); return -1; } if (seek >= state->mel.n_len_org) { - log("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10); + WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10); return -2; } // run the encoder if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) { - log("%s: failed to encode\n", __func__); + WHISPER_LOG_ERROR("%s: failed to encode\n", __func__); return -6; } const std::vector prompt = { whisper_token_sot(ctx) }; if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) { - log("%s: failed to decode\n", __func__); + WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); return -7; } @@ -3694,8 +3610,8 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) { void whisper_print_timings(struct whisper_context * ctx) { const int64_t t_end_us = ggml_time_us(); - log("\n"); - log("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + WHISPER_LOG_INFO("\n"); + WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); if (ctx->state != nullptr) { const int32_t n_sample = std::max(1, ctx->state->n_sample); @@ -3703,14 +3619,14 @@ void whisper_print_timings(struct whisper_context * ctx) { const int32_t n_decode = std::max(1, ctx->state->n_decode); const int32_t n_prompt = std::max(1, ctx->state->n_prompt); - log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); - log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); - log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); - log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); - log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); - log("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt); + WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); + WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); + WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * 
ctx->state->t_sample_us / n_sample); + WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); + WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt); } - log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); + WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); } void whisper_reset_timings(struct whisper_context * ctx) { @@ -3762,6 +3678,7 @@ const char * whisper_print_system_info(void) { s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | "; s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; + s += "CUDA = " + std::to_string(ggml_cpu_has_cublas()) + " | "; s += "COREML = " + std::to_string(whisper_has_coreml()) + " | "; s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | "; @@ -4056,7 +3973,7 @@ static void whisper_process_logits( const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg; - //log("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); + //WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); if (last_was_timestamp) { if (penultimate_was_timestamp) { @@ -4132,7 +4049,7 @@ static void whisper_process_logits( const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg); - //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); + //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); if (timestamp_logprob > max_text_token_logprob) { for (int i = 0; i < vocab.token_beg; ++i) { @@ -4427,8 +4344,10 @@ static bool whisper_kv_swap_fast( for (auto & i : two_copy) { // make a copy of KV caches WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i); - memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size()); - memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size()); + //memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size()); + //memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size()); + ggml_backend_tensor_get(src[i].kv_self.k, kv_swap_bufs[i].k.data(), 0, kv_swap_bufs[i].k.size()); + ggml_backend_tensor_get(src[i].kv_self.v, kv_swap_bufs[i].v.data(), 0, kv_swap_bufs[i].v.size()); } // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first @@ -4441,13 +4360,17 @@ static bool whisper_kv_swap_fast( if (two_copy.find(view[i]) != two_copy.end()) { // modify KV caches of decoder using data from kv_swap_bufs WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i); - memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size()); - 
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size()); + //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size()); + //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size()); + ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size()); + ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size()); } else { // modify KV caches of decoder using data from correspond decoder KV caches directly WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i); - memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k)); - memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v)); + //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k)); + //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v)); + ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k); + ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v); } } @@ -4461,13 +4384,17 @@ static bool whisper_kv_swap_fast( if (two_copy.find(view[i]) != two_copy.end()) { // modify KV caches of decoder using data from kv_swap_bufs WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i); - memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size()); - memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size()); + //memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size()); + //memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size()); + ggml_backend_tensor_set(src[i].kv_self.k, kv_swap_bufs[view[i]].k.data(), 0, kv_swap_bufs[view[i]].k.size()); + ggml_backend_tensor_set(src[i].kv_self.v, kv_swap_bufs[view[i]].v.data(), 0, kv_swap_bufs[view[i]].v.size()); } else { // modify KV caches of decoder using data from correspond decoder KV caches directly WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i); - memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k)); - memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v)); + //memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k)); + //memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v)); + ggml_backend_tensor_copy(src[view[i]].kv_self.k, src[i].kv_self.k); + ggml_backend_tensor_copy(src[view[i]].kv_self.v, src[i].kv_self.v); } } @@ -4495,11 +4422,11 @@ int whisper_full_with_state( // compute log mel spectrogram if (params.speed_up) { // TODO: Replace PV with more advanced algorithm - log("%s: failed to compute log mel spectrogram\n", __func__); + WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); return -1; } else { if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { - log("%s: failed to compute log mel spectrogram\n", __func__); + WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); return -2; } } @@ -4511,13 +4438,13 @@ int whisper_full_with_state( const auto lang_id = 
whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data()); if (lang_id < 0) { - log("%s: failed to auto-detect language\n", __func__); + WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__); return -3; } state->lang_id = lang_id; params.language = whisper_lang_str(lang_id); - log("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); + WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); if (params.detect_language) { return 0; } @@ -4575,8 +4502,8 @@ int whisper_full_with_state( if (decoder.kv_self.ctx == nullptr) { decoder.kv_self = state->decoders[0].kv_self; - if (!kv_cache_reinit(decoder.kv_self)) { - log("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); + if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) { + WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); return -4; } @@ -4587,23 +4514,6 @@ int whisper_full_with_state( decoder.probs.resize (ctx->vocab.n_vocab); decoder.logits.resize (ctx->vocab.n_vocab); decoder.logprobs.resize(ctx->vocab.n_vocab); - - // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0 -#ifdef GGML_USE_METAL - if (state->ctx_metal) { -#define WHISPER_METAL_CHECK_BUF(result) \ - if (!(result)) { \ - log("%s: failed to add metal buffer\n", __func__); \ - return 0; \ - } - - const std::string kv_name = "kv_self_" + std::to_string(j); - auto & kv_self = decoder.kv_self; - - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0)); -#undef WHISPER_METAL_CHECK_BUF - } -#endif } } @@ -4637,7 +4547,7 @@ int whisper_full_with_state( // overwrite audio_ctx, max allowed is hparams.n_audio_ctx if (params.audio_ctx > whisper_n_audio_ctx(ctx)) { - log("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); + WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); return -5; } state->exp_n_audio_ctx = params.audio_ctx; @@ -4662,7 +4572,7 @@ int whisper_full_with_state( // distilled models require the "no_timestamps" token // TODO: add input parameter (#1229) if (is_distil) { - log("%s: using distilled model - forcing no_timestamps\n", __func__); + WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__); prompt_init.push_back(whisper_token_not(ctx)); } } @@ -4699,14 +4609,14 @@ int whisper_full_with_state( if (params.encoder_begin_callback) { if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) { - log("%s: encoder_begin_callback returned false - aborting\n", __func__); + WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__); break; } } // encode audio features starting at offset seek if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { - log("%s: failed to encode\n", __func__); + WHISPER_LOG_ERROR("%s: failed to encode\n", __func__); return -6; } @@ -4789,7 +4699,7 @@ int whisper_full_with_state( WHISPER_PRINT_DEBUG("\n\n"); if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { - 
log("%s: failed to decode\n", __func__); + WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); return -7; } @@ -4803,8 +4713,11 @@ int whisper_full_with_state( for (int j = 1; j < n_decoders_cur; ++j) { auto & decoder = state->decoders[j]; - memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); - memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); + // TODO: fix CUDA + //memcpy(decoder.kv_self.k->data, state->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); + //memcpy(decoder.kv_self.v->data, state->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); + ggml_backend_tensor_copy(state->decoders[0].kv_self.k, decoder.kv_self.k); + ggml_backend_tensor_copy(state->decoders[0].kv_self.v, decoder.kv_self.v); decoder.kv_self.n += prompt.size(); @@ -5013,7 +4926,7 @@ int whisper_full_with_state( //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta); if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { - log("%s: failed to decode\n", __func__); + WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); return -8; } @@ -5339,12 +5252,12 @@ int whisper_full_parallel( ctx->state->t_decode_us /= n_processors; // print information about the audio boundaries - log("\n"); - log("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors); + WHISPER_LOG_WARN("\n"); + WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors); for (int i = 0; i < n_processors - 1; ++i) { - log("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str()); + WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str()); } - log("%s: the transcription quality may be degraded near these boundaries\n", __func__); + WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__); return ret; } @@ -5586,12 +5499,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { double tsum = 0.0; // heat-up - ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr); + ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr); for (int i = 0; i < n_max; ++i) { const int64_t t0 = ggml_time_us(); - ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr); + ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr); const int64_t t1 = ggml_time_us(); @@ -5709,7 +5622,7 @@ static void whisper_exp_compute_token_level_timestamps( const int n_samples = state.energy.size(); if (n_samples == 0) { - log("%s: no signal data available\n", __func__); + WHISPER_LOG_ERROR("%s: no signal data available\n", __func__); return; } @@ -5930,6 +5843,38 @@ static void whisper_exp_compute_token_level_timestamps( //} } -void whisper_set_log_callback(whisper_log_callback callback) { - whisper_log = callback; +void whisper_log_set(ggml_log_callback log_callback, void * user_data) { + g_state.log_callback = log_callback ? 
log_callback : whisper_log_callback_default; + g_state.log_callback_user_data = user_data; +} + +static void whisper_log_internal_v(ggml_log_level level, const char * format, va_list args) { + va_list args_copy; + va_copy(args_copy, args); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + g_state.log_callback(level, buffer, g_state.log_callback_user_data); + } else { + char* buffer2 = new char[len+1]; + vsnprintf(buffer2, len+1, format, args_copy); + buffer2[len] = 0; + g_state.log_callback(level, buffer2, g_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args_copy); +} + +static void whisper_log_internal(ggml_log_level level, const char * format, ...) { + va_list args; + va_start(args, format); + whisper_log_internal_v(level, format, args); + va_end(args); +} + +static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); } diff --git a/whisper.h b/whisper.h index ed1612b..0ea5237 100644 --- a/whisper.h +++ b/whisper.h @@ -1,6 +1,8 @@ #ifndef WHISPER_H #define WHISPER_H +#include "ggml.h" + #include #include #include @@ -110,15 +112,15 @@ extern "C" { // Various functions for loading a ggml whisper model. // Allocate (almost) all memory needed for the model. // Return NULL on failure - WHISPER_API struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params); - WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params); - WHISPER_API struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_from_file_with_params (const char * path_model, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_with_params (struct whisper_model_loader * loader, struct whisper_context_params params); // These are the same as the above, but the internal state of the context is not allocated automatically // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523) - WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params); - WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params); - WHISPER_API struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state (const char * path_model, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_with_params_no_state (struct whisper_model_loader * loader, struct whisper_context_params params); WHISPER_DEPRECATED( WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model), @@ -570,8 +572,7 @@ extern "C" { // Control logging output; default behavior is to print to 
stderr - typedef void (*whisper_log_callback)(const char * line); - WHISPER_API void whisper_set_log_callback(whisper_log_callback callback); + WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data); #ifdef __cplusplus }
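
A note on the resource-management changes in whisper_free / whisper_free_state above: the model weights now live in a ggml_backend_buffer_t and each context owns a ggml_backend_t, so teardown frees the buffer before the backend. The same pairing in isolation, as a minimal sketch (not part of the patch; ggml_backend_cpu_init and ggml_backend_alloc_buffer are assumed from the ggml-backend API and do not appear in the hunks above):

    #include "ggml-backend.h"

    int main() {
        // assumed API: a CPU backend plus a raw 1 MiB buffer allocated on it
        ggml_backend_t backend = ggml_backend_cpu_init();
        ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, 1024*1024);

        // ... place tensors in `buffer` and evaluate graphs on `backend` ...

        ggml_backend_buffer_free(buffer); // free the buffer first
        ggml_backend_free(backend);       // then the backend that allocated it
        return 0;
    }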
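On the whisper_kv_swap_fast and decoder KV-cache hunks: once the caches are allocated through ggml-backend, tensor->data may point at device (e.g. CUDA) memory, so the raw memcpy calls are replaced by ggml_backend_tensor_get / ggml_backend_tensor_set for device-to-host staging and ggml_backend_tensor_copy for direct tensor-to-tensor copies. The same pattern reduced to a standalone sketch (copy_kv_via_host, k_src, and k_dst are illustrative names, not symbols from the patch):

    #include "ggml.h"
    #include "ggml-backend.h"

    #include <cstdint>
    #include <vector>

    // Snapshot one KV tensor into host memory, then write it into another tensor.
    static void copy_kv_via_host(struct ggml_tensor * k_src, struct ggml_tensor * k_dst) {
        std::vector<uint8_t> buf(ggml_nbytes(k_src));

        ggml_backend_tensor_get(k_src, buf.data(), 0, buf.size()); // device -> host
        ggml_backend_tensor_set(k_dst, buf.data(), 0, buf.size()); // host -> device

        // when no host-side snapshot is needed, a direct copy avoids the staging buffer:
        //ggml_backend_tensor_copy(k_src, k_dst);
    }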
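On the logging rework: whisper_set_log_callback(callback) is replaced by whisper_log_set(log_callback, user_data), and the callback now has the ggml_log_callback shape (level, text, user_data) instead of receiving a bare line. A minimal client-side sketch, assuming the GGML_LOG_LEVEL_* values that ggml.h defines for ggml_log_level:

    #include "whisper.h" // now pulls in ggml.h for ggml_log_level / ggml_log_callback

    #include <cstdio>

    // Route whisper logs to a FILE* passed through user_data, dropping INFO chatter.
    static void my_whisper_log(ggml_log_level level, const char * text, void * user_data) {
        if (level == GGML_LOG_LEVEL_INFO) {
            return; // keep only warnings and errors
        }
        fputs(text, (FILE *) user_data);
    }

    int main() {
        whisper_log_set(my_whisper_log, stderr);
        // ... whisper_init_from_file_with_params(...) etc. now report through my_whisper_log
        return 0;
    }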