From 93935980f8bcc3d230d313174ff59635c3c80d1b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 15 Sep 2023 12:18:18 +0300 Subject: [PATCH] whisper : Metal and ggml-alloc support (#1270) * metal : init * whisper : factor out graph builds * whisper : allocate encoder and decoder using ggml-alloc * whisper : ggml-alloc is now supported * whisper : CoreML support ggml-alloc * build : fix ggml-alloc * ios : update submodule * extra : update sync-ggml.sh script to also sync ggml-alloc * ci : see if this is causing the crash * whisper : refactor ggml-alloc init * whisper.android : try to fix build * whisper : initial Metal version * ci : try to debug vmem issue * metal : decoder works on GPU! * metal : add multi-decoder support * ggml : fix ggml_nbytes (probably temp solution) * metal : run "cross" step on the GPU * whisper : remove ggml_repeat in the encoder * whisper : offload the Encoder to Metal * ggml : use simpler ggml_bytes() implementation * ggml-alloc : try to make CI happy by reducing vram to 128GB * whisper : add whisper_allocr to wrap ggml_allocr * whisper : factor out alloc init in a function * cmake : update to support Metal build * whisper : add header * objc : fix build (no Metal yet) * ios : add Metal support * swiftui : fix build * metal : speed-up KQ multiplication * metal : sync latest llama.cpp kernels * readme : add Metal info * ios : update submodule * coreml : add code to toggle Core ML config (CPU, ANE, GPU) * bench : fix timings by running a pre-heat * bench : start benching the decoder * whisper : add ggml_mul_mat_pad * bench : fix uninitialized vars * whisper : add comment for disabling mul-mat padding * whisper : add description of ggml_mul_mat_pad * whisper : clean-up ggml_mul_mat_pad * metal : remove the "concurrent" flag * bench : variable n_past * ios : update SPM package --- CMakeLists.txt | 66 +- Makefile | 23 +- README.md | 8 +- bindings/ios | 2 +- coreml/whisper-encoder.mm | 8 +- examples/bench/bench.cpp | 45 +- examples/talk-llama/CMakeLists.txt | 2 +- .../app/src/main/jni/whisper/CMakeLists.txt | 3 +- examples/whisper.objc/README.md | 12 + .../whisper.objc.xcodeproj/project.pbxproj | 29 +- .../whisper.swiftui.xcodeproj/project.pbxproj | 10 +- extra/bench-all.sh | 19 +- extra/sync-ggml.sh | 28 +- ggml-alloc.c | 187 +- ggml-metal.m | 107 +- ggml-metal.metal | 497 ++-- ggml.c | 22 +- whisper.cpp | 2039 +++++++++-------- 18 files changed, 1855 insertions(+), 1252 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c50e2a..8b800e0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required (VERSION 3.0) +cmake_minimum_required (VERSION 3.5) project(whisper.cpp VERSION 1.4.2) @@ -35,6 +35,12 @@ endif() # options +if (APPLE) + set(WHISPER_METAL_DEFAULT ON) +else() + set(WHISPER_METAL_DEFAULT OFF) +endif() + option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT}) option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON) @@ -58,6 +64,8 @@ option(WHISPER_OPENVINO "whisper: support for OpenVINO" OFF) if (APPLE) option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF) + option(WHISPER_METAL "whisper: use Metal" ${WHISPER_METAL_DEFAULT}) + option(WHISPER_METAL_NDEBUG "whisper: disable Metal debugging" OFF) option(WHISPER_COREML "whisper: enable Core ML framework" OFF) option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF) else() @@ -113,6 +121,34 @@ if (APPLE) endif() endif() + if (WHISPER_METAL) + find_library(FOUNDATION_LIBRARY Foundation REQUIRED) + find_library(METAL_FRAMEWORK Metal REQUIRED) + find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) + + if (METAL_FRAMEWORK) + message(STATUS "Metal framework found") + + set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} + ${FOUNDATION_LIBRARY} + ${METAL_FRAMEWORK} + ${METALKIT_FRAMEWORK} + ) + set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_METAL) + + if (WHISPER_METAL_NDEBUG) + set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_METAL_NDEBUG) + endif() + else() + message(WARNING "Metal framework not found") + endif() + + set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h) + + # copy ggml-metal.metal to bin directory + configure_file(ggml-metal.metal bin/ggml-metal.metal COPYONLY) + endif() + if (WHISPER_COREML) find_library(FOUNDATION_FRAMEWORK Foundation) find_library(COREML_FRAMEWORK CoreML) @@ -177,7 +213,7 @@ if (WHISPER_CUBLAS) enable_language(CUDA) - set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h) + set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h) add_compile_definitions(GGML_USE_CUBLAS) @@ -228,7 +264,7 @@ if (WHISPER_CLBLAST) if (CLBlast_FOUND) message(STATUS "CLBlast found") - set(GGML_OPENCL_SOURCES ggml-opencl.cpp ggml-opencl.h) + set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h) add_compile_definitions(GGML_USE_CLBLAST) @@ -426,8 +462,11 @@ set(TARGET whisper) add_library(${TARGET} ggml.h ggml.c - ${GGML_CUDA_SOURCES} - ${GGML_OPENCL_SOURCES} + ggml-alloc.h + ggml-alloc.c + ${GGML_SOURCES_METAL} + ${GGML_SOURCES_CUDA} + ${GGML_SOURCES_OPENCL} whisper.h whisper.cpp ) @@ -468,9 +507,15 @@ if (BUILD_SHARED_LIBS) WHISPER_BUILD GGML_BUILD ) + + if (WHISPER_METAL) + # TODO: I think this should make ggml-metal.m "see" the ggml-metal.metal file from the "bin" directory + # but for some reason it does not work here like it does in llama.cpp + set_target_properties(${TARGET} PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") + endif() endif() -if (GGML_CUDA_SOURCES) +if (GGML_SOURCES_CUDA) message(STATUS "GGML CUDA sources found, configuring CUDA architecture") set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF) set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto") @@ -486,10 +531,13 @@ target_compile_definitions(${TARGET} PUBLIC set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h") +include(GNUInstallDirs) + install(TARGETS ${TARGET} - LIBRARY DESTINATION lib - ARCHIVE DESTINATION lib/static - RUNTIME DESTINATION bin + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib/static + RUNTIME DESTINATION bin + RESOURCE DESTINATION bin PUBLIC_HEADER DESTINATION include ) diff --git a/Makefile b/Makefile index ecbbcff..2df5111 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,7 @@ ifndef NVCC_VERSION endif endif -CCV := $(shell $(CC) --version | head -n 1) +CCV := $(shell $(CC) --version | head -n 1) CXXV := $(shell $(CXX) --version | head -n 1) # Mac OS + Arm can report x86_64 @@ -182,6 +182,15 @@ ifdef WHISPER_COREML_ALLOW_FALLBACK endif endif +ifndef WHISPER_NO_METAL + ifeq ($(UNAME_S),Darwin) + WHISPER_METAL := 1 + + CXXFLAGS += -DGGML_USE_METAL + LDFLAGS += -framework Foundation -framework Metal -framework MetalKit + endif +endif + ifdef WHISPER_OPENBLAS CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas LDFLAGS += -lopenblas @@ -288,6 +297,11 @@ $(info ) ggml.o: ggml.c ggml.h ggml-cuda.h $(CC) $(CFLAGS) -c $< -o $@ +ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h + $(CC) $(CFLAGS) -c $< -o $@ + +WHISPER_OBJ += ggml-alloc.o + whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h $(CXX) $(CXXFLAGS) -c $< -o $@ @@ -303,6 +317,13 @@ whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-imp WHISPER_OBJ += whisper.o whisper-encoder.o whisper-encoder-impl.o endif +ifdef WHISPER_METAL +ggml-metal.o: ggml-metal.m ggml-metal.h + $(CC) $(CFLAGS) -c $< -o $@ + +WHISPER_OBJ += ggml-metal.o +endif + libwhisper.a: ggml.o $(WHISPER_OBJ) $(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ) diff --git a/README.md b/README.md index 5f18060..3707b93 100644 --- a/README.md +++ b/README.md @@ -11,14 +11,14 @@ Beta: [v1.4.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.2) / S High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: - Plain C/C++ implementation without dependencies -- Apple silicon first-class citizen - optimized via ARM NEON, Accelerate framework and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support) +- Apple Silicon first-class citizen - optimized via ARM NEON, Accelerate framework, Metal and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support) - AVX intrinsics support for x86 architectures - VSX intrinsics support for POWER architectures - Mixed F16 / F32 precision - [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization) - Low memory usage (Flash Attention) - Zero memory allocations at runtime -- Runs on the CPU +- Support for CPU-only inference - [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas) - [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast) - [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas) @@ -50,6 +50,10 @@ You can also easily make your own offline voice assistant application: [command] https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4 +On Apply Silicon, the inference runs fully on the GPU via Metal: + +https://github.com/ggerganov/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225 + Or you can even run it straight in the browser: [talk.wasm](examples/talk.wasm) ## Implementation details diff --git a/bindings/ios b/bindings/ios index de46d9e..22a9eef 160000 --- a/bindings/ios +++ b/bindings/ios @@ -1 +1 @@ -Subproject commit de46d9e7817fe851c109d66080239d415812d32a +Subproject commit 22a9eef021afc67f2154bc9811ed620b26299d1b diff --git a/coreml/whisper-encoder.mm b/coreml/whisper-encoder.mm index 6cd90ed..499edae 100644 --- a/coreml/whisper-encoder.mm +++ b/coreml/whisper-encoder.mm @@ -22,7 +22,13 @@ struct whisper_coreml_context * whisper_coreml_init(const char * path_model) { NSURL * url_model = [NSURL fileURLWithPath: path_model_str]; - const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model error:nil]); + // select which device to run the Core ML model on + MLModelConfiguration *config = [[MLModelConfiguration alloc] init]; + config.computeUnits = MLComputeUnitsCPUAndGPU; + //config.computeUnits = MLComputeUnitsCPUAndNeuralEngine; + //config.computeUnits = MLComputeUnitsAll; + + const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model configuration:config error:nil]); if (data == NULL) { return NULL; diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index 49daaa0..ac0e6bb 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -44,13 +44,13 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what); - fprintf(stderr, " %-7s 0 - whisper encoder\n", ""); + fprintf(stderr, " %-7s 0 - whisper\n", ""); fprintf(stderr, " %-7s 1 - memcpy\n", ""); fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", ""); fprintf(stderr, "\n"); } -int whisper_bench_encoder(const whisper_params & params) { +int whisper_bench_full(const whisper_params & params) { // whisper init struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); @@ -69,12 +69,49 @@ int whisper_bench_encoder(const whisper_params & params) { fprintf(stderr, "error: failed to set mel: %d\n", ret); return 3; } - + // heat encoder if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { fprintf(stderr, "error: failed to encode model: %d\n", ret); return 4; } + whisper_token tokens[512]; + memset(tokens, 0, sizeof(tokens)); + + // prompt heat + if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { + fprintf(stderr, "error: failed to encode model: %d\n", ret); + return 4; + } + + // text-generation heat + if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) { + fprintf(stderr, "error: failed to encode model: %d\n", ret); + return 4; + } + + whisper_reset_timings(ctx); + + // actual run + if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { + fprintf(stderr, "error: failed to encode model: %d\n", ret); + return 4; + } + + for (int i = 0; i < 16; i++) { + if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { + fprintf(stderr, "error: failed to encode model: %d\n", ret); + return 4; + } + } + + for (int i = 0; i < 256; i++) { + if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) { + fprintf(stderr, "error: failed to encode model: %d\n", ret); + return 4; + } + } + whisper_print_timings(ctx); whisper_free(ctx); @@ -103,7 +140,7 @@ int main(int argc, char ** argv) { int ret = -1; switch (params.what) { - case 0: ret = whisper_bench_encoder(params); break; + case 0: ret = whisper_bench_full(params); break; case 1: ret = whisper_bench_memcpy(params.n_threads); break; case 2: ret = whisper_bench_ggml_mul_mat(params.n_threads); break; default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break; diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt index cbdfb41..af5b547 100644 --- a/examples/talk-llama/CMakeLists.txt +++ b/examples/talk-llama/CMakeLists.txt @@ -7,7 +7,7 @@ if (WHISPER_SDL2) # TODO: this is temporary # need to export ggml symbols for MSVC, but too lazy .. - add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp) + add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../ggml-alloc.c ../../whisper.cpp) target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../) target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) diff --git a/examples/whisper.android/app/src/main/jni/whisper/CMakeLists.txt b/examples/whisper.android/app/src/main/jni/whisper/CMakeLists.txt index 55a4725..eac718a 100644 --- a/examples/whisper.android/app/src/main/jni/whisper/CMakeLists.txt +++ b/examples/whisper.android/app/src/main/jni/whisper/CMakeLists.txt @@ -8,6 +8,7 @@ set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../../../../../) set( SOURCE_FILES ${WHISPER_LIB_DIR}/ggml.c + ${WHISPER_LIB_DIR}/ggml-alloc.c ${WHISPER_LIB_DIR}/whisper.cpp ${CMAKE_SOURCE_DIR}/jni.c ) @@ -20,7 +21,7 @@ function(build_library target_name) SHARED ${SOURCE_FILES} ) - + target_link_libraries(${target_name} ${LOG_LIB} android) if (${target_name} STREQUAL "whisper_v8fp16_va") diff --git a/examples/whisper.objc/README.md b/examples/whisper.objc/README.md index 6833ebb..bb55653 100644 --- a/examples/whisper.objc/README.md +++ b/examples/whisper.objc/README.md @@ -28,6 +28,8 @@ This can significantly improve the performance of the transcription: image +## Core ML + If you want to enable Core ML support, you can add the `-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK` compiler flag for `whisper.cpp` in Build Phases: image @@ -35,3 +37,13 @@ If you want to enable Core ML support, you can add the `-DWHISPER_USE_COREML -DW Then follow the [`Core ML support` section of readme](../../README.md#core-ml-support) for convert the model. In this project, it also added `-O3 -DNDEBUG` to `Other C Flags`, but adding flags to app proj is not ideal in real world (applies to all C/C++ files), consider splitting xcodeproj in workspace in your own project. + +## Metal + +You can also enable Metal to make the inference run on the GPU of your device. This might or might not be more efficient +compared to Core ML depending on the model and device that you use. + +To enable Metal, just add `-DGGML_USE_METAL` instead off the `-DWHISPER_USE_COREML` flag and you are ready. +This will make both the Encoder and the Decoder run on the GPU. + +If you want to run the Encoder with Core ML and the Decoder with Metal then simply add both `-DWHISPER_USE_COREML -DGGML_USE_METAL` flags. That's all! diff --git a/examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj b/examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj index 49bd74e..f34b9c5 100644 --- a/examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj +++ b/examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj @@ -7,6 +7,9 @@ objects = { /* Begin PBXBuildFile section */ + 1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 184447182AB211A2007D6BFE /* ggml-alloc.c */; }; + 1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */ = {isa = PBXBuildFile; fileRef = 1844471B2AB21655007D6BFE /* ggml-metal.m */; settings = {COMPILER_FLAGS = "-framework Foundation -framework Metal -framework MetalKit -fno-objc-arc"; }; }; + 184447212AB21B43007D6BFE /* ggml-metal.metal in CopyFiles */ = {isa = PBXBuildFile; fileRef = 1844471D2AB2195F007D6BFE /* ggml-metal.metal */; }; 18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7A29052BDF00BD2A04 /* AppDelegate.m */; }; 18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7D29052BDF00BD2A04 /* SceneDelegate.m */; }; 18627C8129052BDF00BD2A04 /* ViewController.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8029052BDF00BD2A04 /* ViewController.m */; }; @@ -14,7 +17,7 @@ 18627C8629052BE000BD2A04 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8529052BE000BD2A04 /* Assets.xcassets */; }; 18627C8929052BE000BD2A04 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8729052BE000BD2A04 /* LaunchScreen.storyboard */; }; 18627C8C29052BE000BD2A04 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8B29052BE000BD2A04 /* main.m */; }; - 18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK"; }; }; + 18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML"; }; }; 18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE"; }; }; 18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */ = {isa = PBXBuildFile; fileRef = 18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */; }; 7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */; }; @@ -23,7 +26,24 @@ 7FE3424F2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc in Resources */ = {isa = PBXBuildFile; fileRef = 7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */; }; /* End PBXBuildFile section */ +/* Begin PBXCopyFilesBuildPhase section */ + 184447202AB21B25007D6BFE /* CopyFiles */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 2147483647; + dstPath = ""; + dstSubfolderSpec = 7; + files = ( + 184447212AB21B43007D6BFE /* ggml-metal.metal in CopyFiles */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXCopyFilesBuildPhase section */ + /* Begin PBXFileReference section */ + 184447182AB211A2007D6BFE /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-alloc.c"; path = "../../../ggml-alloc.c"; sourceTree = ""; }; + 184447192AB211A2007D6BFE /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-alloc.h"; path = "../../../ggml-alloc.h"; sourceTree = ""; }; + 1844471B2AB21655007D6BFE /* ggml-metal.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; name = "ggml-metal.m"; path = "../../../ggml-metal.m"; sourceTree = ""; }; + 1844471D2AB2195F007D6BFE /* ggml-metal.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; name = "ggml-metal.metal"; path = "../../../ggml-metal.metal"; sourceTree = ""; }; 18627C7629052BDF00BD2A04 /* whisper.objc.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = whisper.objc.app; sourceTree = BUILT_PRODUCTS_DIR; }; 18627C7929052BDF00BD2A04 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; 18627C7A29052BDF00BD2A04 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = ""; }; @@ -80,6 +100,10 @@ 18627C7829052BDF00BD2A04 /* whisper.objc */ = { isa = PBXGroup; children = ( + 1844471D2AB2195F007D6BFE /* ggml-metal.metal */, + 1844471B2AB21655007D6BFE /* ggml-metal.m */, + 184447182AB211A2007D6BFE /* ggml-alloc.c */, + 184447192AB211A2007D6BFE /* ggml-alloc.h */, 7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */, 7FE342442A0C3FA20015A058 /* coreml */, 18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */, @@ -126,6 +150,7 @@ 18627C7229052BDF00BD2A04 /* Sources */, 18627C7329052BDF00BD2A04 /* Frameworks */, 18627C7429052BDF00BD2A04 /* Resources */, + 184447202AB21B25007D6BFE /* CopyFiles */, ); buildRules = ( ); @@ -194,8 +219,10 @@ 18627C9629052C5800BD2A04 /* ggml.c in Sources */, 18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */, 7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */, + 1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */, 18627C8C29052BE000BD2A04 /* main.m in Sources */, 18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */, + 1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */, 7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */, ); runOnlyForDeploymentPostprocessing = 0; diff --git a/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj b/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj index ab9f688..d2d0b05 100644 --- a/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj +++ b/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj @@ -20,6 +20,7 @@ 0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -Wno-shorten-64-to-32"; }; }; 0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */; }; 0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DD02953A394003032C3 /* LibWhisper.swift */; }; + 18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 18AED47F2AB21F2B009D854F /* ggml-alloc.c */; }; /* End PBXBuildFile section */ /* Begin PBXFileReference section */ @@ -41,6 +42,8 @@ 0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = ""; }; 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = ""; }; 0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = ""; }; + 18AED47F2AB21F2B009D854F /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-alloc.c"; sourceTree = ""; }; + 18AED4802AB21F2B009D854F /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-alloc.h"; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -124,6 +127,8 @@ 0AAC5DC529539E89003032C3 /* whisper.cpp */ = { isa = PBXGroup; children = ( + 18AED47F2AB21F2B009D854F /* ggml-alloc.c */, + 18AED4802AB21F2B009D854F /* ggml-alloc.h */, 0AAC5DC929539EB0003032C3 /* ggml.c */, 0AAC5DCA29539EB0003032C3 /* ggml.h */, 0AAC5DC729539EB0003032C3 /* whisper.cpp */, @@ -242,6 +247,7 @@ 0AA7514C2953B569001EE061 /* RiffWaveUtils.swift in Sources */, 0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */, 0AA7514E2953D958001EE061 /* Recorder.swift in Sources */, + 18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -369,7 +375,7 @@ CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\""; - DEVELOPMENT_TEAM = 3TZ9BM962G; + DEVELOPMENT_TEAM = P8JZH34X63; ENABLE_HARDENED_RUNTIME = YES; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; @@ -410,7 +416,7 @@ CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\""; - DEVELOPMENT_TEAM = 3TZ9BM962G; + DEVELOPMENT_TEAM = P8JZH34X63; ENABLE_HARDENED_RUNTIME = YES; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; diff --git a/extra/bench-all.sh b/extra/bench-all.sh index 43f989d..352a223 100755 --- a/extra/bench-all.sh +++ b/extra/bench-all.sh @@ -44,27 +44,26 @@ if [ "$encoder_only" -eq 0 ]; then printf "\n" fi -printf "| CPU | OS | Config | Model | Th | Load | Enc. | Commit |\n" -printf "| --- | -- | ------ | ----- | -- | ---- | ---- | ------ |\n" +printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit" +printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" for model in "${models[@]}"; do - # run once to heat-up the cache - ./bench -m ./models/ggml-$model.bin -t $n_threads 2>/dev/null 1>/dev/null - # actual run # store stderr output in a variable in order to parse it later output=$(./bench -m ./models/ggml-$model.bin -t $n_threads 2>&1) ret=$? # parse the output: - load_time=$(echo "$output" | grep "load time" | awk '{print $5}') - encode_time=$(echo "$output" | grep "encode time" | awk '{print $5}') + encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}') + decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}') + prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}') system_info=$(echo "$output" | grep "system_info") n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}') # floor to milliseconds - load_time=${load_time%.*} - encode_time=${encode_time%.*} + #encode_time=${encode_time%.*} + #decode_time=${decode_time%.*} + #prompt_time=${prompt_time%.*} config="" @@ -87,6 +86,6 @@ for model in "${models[@]}"; do commit=$(git rev-parse --short HEAD) if [ $ret -eq 0 ]; then - printf "| | | $config | $model | $n_threads | $load_time | $encode_time | $commit |\n" + printf "| | | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit" fi done diff --git a/extra/sync-ggml.sh b/extra/sync-ggml.sh index 3bd99e3..0070e9e 100755 --- a/extra/sync-ggml.sh +++ b/extra/sync-ggml.sh @@ -1,18 +1,20 @@ #!/bin/bash -cp -rpv ../ggml/src/ggml.c ./ggml.c -cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h -cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu -cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h -cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp -cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h -cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m -cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal -cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h -cp -rpv ../ggml/examples/common.h ./examples/common.h -cp -rpv ../ggml/examples/common.cpp ./examples/common.cpp -cp -rpv ../ggml/examples/common-ggml.h ./examples/common-ggml.h -cp -rpv ../ggml/examples/common-ggml.cpp ./examples/common-ggml.cpp +cp -rpv ../ggml/src/ggml.c ./ggml.c +cp -rpv ../ggml/src/ggml-alloc.c ./ggml-alloc.c +cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h +cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu +cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h +cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp +cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h +cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m +cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal +cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h +cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h +cp -rpv ../ggml/examples/common.h ./examples/common.h +cp -rpv ../ggml/examples/common.cpp ./examples/common.cpp +cp -rpv ../ggml/examples/common-ggml.h ./examples/common-ggml.h +cp -rpv ../ggml/examples/common-ggml.cpp ./examples/common-ggml.cpp cp -rpv ../ggml/examples/whisper/whisper.h ./whisper.h cp -rpv ../ggml/examples/whisper/whisper.cpp ./whisper.cpp diff --git a/ggml-alloc.c b/ggml-alloc.c index 856a4cd..304964b 100644 --- a/ggml-alloc.c +++ b/ggml-alloc.c @@ -6,6 +6,26 @@ #include #include +#ifdef __has_include + #if __has_include() + #include + #if defined(_POSIX_MAPPED_FILES) + #include + #include + #endif + #endif +#endif + +#if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #ifndef NOMINMAX + #define NOMINMAX + #endif + #include + #include +#endif + + #define UNUSED(x) (void)(x) #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define GGML_MAX_CONCUR (2*GGML_MAX_NODES) @@ -99,15 +119,28 @@ static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tens } #endif - -static size_t ggml_allocator_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { +static size_t ggml_allocr_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { return ggml_nbytes(tensor); UNUSED(alloc); } +// check if a tensor is allocated by this buffer +static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_tensor * tensor) { + void * ptr = tensor->data; + return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size; +} + +static bool ggml_is_view(struct ggml_tensor * t) { + return t->view_src != NULL; +} + void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { - size_t size = ggml_allocator_get_alloc_size(alloc, tensor); +#ifdef GGML_ALLOCATOR_DEBUG + GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources + GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated +#endif + size_t size = ggml_allocr_get_alloc_size(alloc, tensor); size = aligned_offset(NULL, size, alloc->alignment); AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size); @@ -131,14 +164,14 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) if (best_fit_block == -1) { // the last block is our last resort struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1]; + max_avail = MAX(max_avail, block->size); if (block->size >= size) { best_fit_block = alloc->n_free_blocks - 1; - max_avail = MAX(max_avail, block->size); } else { fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n", __func__, size, max_avail); GGML_ASSERT(!"not enough space in the buffer"); - return; + return; } } struct free_block * block = &alloc->free_blocks[best_fit_block]; @@ -173,17 +206,17 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) } // this is a very naive implementation, but for our case the number of free blocks should be very small -static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { +static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { void * ptr = tensor->data; - if (ptr < alloc->data || (char*)ptr >= (char*)alloc->data + alloc->max_size) { + if (ggml_allocr_is_own(alloc, tensor) == false) { // the tensor was not allocated in this buffer // this can happen because the graph allocator will try to free weights and other tensors from different buffers // the easiest way to deal with this is just to ignore it return; } - size_t size = ggml_allocator_get_alloc_size(alloc, tensor); + size_t size = ggml_allocr_get_alloc_size(alloc, tensor); size = aligned_offset(NULL, size, alloc->alignment); AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks); @@ -277,17 +310,68 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) return alloc; } -// address and size of the buffer when measuring -// it needs to be large enough to fit all the tensors, but it cannot overlap with other existing buffers -static void * const MEASURE_BASE_ADDR = (void *) 0x1000; -static const size_t MEASURE_MAX_SIZE = 1ULL<<40; // 1 TB +// OS specific functions to allocate and free uncommitted virtual memory +static void * alloc_vmem(size_t size) { +#if defined(_WIN32) + return VirtualAlloc(NULL, size, MEM_RESERVE, PAGE_NOACCESS); +#elif defined(_POSIX_MAPPED_FILES) + void * ptr = mmap(NULL, size, PROT_NONE, MAP_PRIVATE | MAP_ANON, -1, 0); + if (ptr == MAP_FAILED) { + return NULL; + } + return ptr; +#else + // use a fixed address for other platforms + uintptr_t base_addr = (uintptr_t)-size - 0x100; + return (void *)base_addr; +#endif +} + +static void free_vmem(void * base_addr, size_t size) { +#if defined(_WIN32) + VirtualFree(base_addr, 0, MEM_RELEASE); + UNUSED(size); +#elif defined(_POSIX_MAPPED_FILES) + munmap(base_addr, size); +#else + // nothing to do + UNUSED(base_addr); + UNUSED(size); +#endif +} + +// allocate uncommitted virtual memory to measure the size of the graph +static void alloc_measure_vmem(void ** base_addr, size_t * size) { + // 128GB for 64-bit, 1GB for 32-bit + *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37; + do { + *base_addr = alloc_vmem(*size); + if (*base_addr != NULL) { + AT_PRINTF("allocated %.2f GB of virtual memory for measure buffer at %p\n", *size / 1024.0 / 1024.0 / 1024.0, *base_addr); + return; + } + // try again with half the size + *size /= 2; + } while (*size > 0); + + GGML_ASSERT(!"failed to allocate virtual memory for measure buffer"); +} + +static void free_measure_vmem(void * base_addr, size_t size) { + free_vmem(base_addr, size); +} struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) { struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */); + void * base_addr; + size_t size; + + alloc_measure_vmem(&base_addr, &size); + *alloc = (struct ggml_allocr){ - /*.data = */ MEASURE_BASE_ADDR, - /*.size = */ MEASURE_MAX_SIZE, + /*.data = */ base_addr, + /*.size = */ size, /*.alignment = */ alignment, /*.n_free_blocks = */ 0, /*.free_blocks = */ {{0}}, @@ -307,6 +391,9 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) { } void ggml_allocr_free(struct ggml_allocr * alloc) { + if (alloc->measure) { + free_measure_vmem(alloc->data, alloc->size); + } free(alloc); } @@ -316,11 +403,6 @@ bool ggml_allocr_is_measure(struct ggml_allocr * alloc) { //////////// compute graph allocator -static bool ggml_is_view(struct ggml_tensor * t) { - return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE || - t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY; -} - static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { if (a->type != b->type) { return false; @@ -336,28 +418,6 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml return true; } -static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) { - switch (t->op) { - case GGML_OP_PERMUTE: - case GGML_OP_RESHAPE: - case GGML_OP_TRANSPOSE: - case GGML_OP_VIEW: - return t->src[0]; - case GGML_OP_CPY: - return t->src[1]; - default: - return NULL; - } -} - -static struct ggml_tensor * get_view_source(struct ggml_tensor * t) { - struct ggml_tensor * parent = t; - do { - parent = get_view_parent(parent); - } while (ggml_is_view(parent)); - return parent; -} - static bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { case GGML_OP_SCALE: @@ -365,7 +425,6 @@ static bool ggml_op_can_inplace(enum ggml_op op) { case GGML_OP_DIAG_MASK_INF: case GGML_OP_ADD: case GGML_OP_ADD1: - case GGML_OP_ACC: case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: @@ -375,10 +434,8 @@ static bool ggml_op_can_inplace(enum ggml_op op) { case GGML_OP_UNARY: case GGML_OP_ROPE: case GGML_OP_RMS_NORM: - case GGML_OP_SET: case GGML_OP_SOFT_MAX: case GGML_OP_CONT: - case GGML_OP_ADD_REL_POS: return true; default: @@ -390,24 +447,8 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) struct hash_node * ht = alloc->hash_table; if (node->data == NULL) { if (ggml_is_view(node)) { - size_t offset; - switch(node->op) { - case GGML_OP_VIEW: - memcpy(&offset, node->op_params, sizeof(size_t)); - node->data = (char *) node->src[0]->data + offset; - break; - case GGML_OP_PERMUTE: - case GGML_OP_RESHAPE: - case GGML_OP_TRANSPOSE: - node->data = node->src[0]->data; - break; - case GGML_OP_CPY: - node->data = node->src[1]->data; - break; - default: - GGML_ASSERT(!"unknown view op"); - break; - } + assert(node->view_src->data != NULL); + node->data = (char *)node->view_src->data + node->view_offs; } else { // see if we can reuse a parent's buffer (inplace) if (ggml_op_can_inplace(node->op)) { @@ -418,8 +459,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) } // if the node's data is external, then we cannot re-use it - if ((char *) parent->data < (char *) alloc->data || - (char *) parent->data >= ((char *) alloc->data + alloc->size)) { + if (ggml_allocr_is_own(alloc, parent) == false) { AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data); continue; } @@ -427,7 +467,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) struct hash_node * p_hn = hash_get(ht, parent); if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) { if (ggml_is_view(parent)) { - struct ggml_tensor * view_src = get_view_source(parent); + struct ggml_tensor * view_src = parent->view_src; struct hash_node * view_src_hn = hash_get(ht, view_src); if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) { // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite @@ -453,7 +493,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) } } -static size_t ggml_allocator_alloc_graph_tensors_n( +static size_t ggml_allocr_alloc_graph_tensors_n( struct ggml_allocr * alloc, struct ggml_cgraph ** graphs, int n_graphs, struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) { @@ -469,7 +509,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n( struct ggml_tensor * node = gf->nodes[i]; if (ggml_is_view(node)) { - struct ggml_tensor * view_src = get_view_source(node); + struct ggml_tensor * view_src = node->view_src; hash_get(ht, view_src)->n_views += 1; } @@ -531,11 +571,10 @@ static size_t ggml_allocator_alloc_graph_tensors_n( AT_PRINTF("\n"); } - // update parents // update immediately if there is no parse_seq // update only at barriers if there is parse_seq - if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] == -1) { + if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) { int update_start = alloc->parse_seq_len ? last_barrier_pos : ind; int update_end = alloc->parse_seq_len ? ind : ind + 1; for (int i = update_start; i < update_end; i++) { @@ -554,17 +593,17 @@ static size_t ggml_allocator_alloc_graph_tensors_n( if (p_hn->n_children == 0 && p_hn->n_views == 0) { if (ggml_is_view(parent)) { - struct ggml_tensor * view_src = get_view_source(parent); + struct ggml_tensor * view_src = parent->view_src; struct hash_node * view_src_hn = hash_get(ht, view_src); view_src_hn->n_views -= 1; AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views); if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) { - ggml_allocator_free_tensor(alloc, view_src); + ggml_allocr_free_tensor(alloc, view_src); } } else { if (parent->data != node->data) { - ggml_allocator_free_tensor(alloc, parent); + ggml_allocr_free_tensor(alloc, parent); } } } @@ -581,7 +620,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n( for (int i = 0; outputs[g][i] != NULL; i++) { struct ggml_tensor * output = outputs[g][i]; AT_PRINTF("output: %s\n", output->name); - ggml_allocator_free_tensor(alloc, output); + ggml_allocr_free_tensor(alloc, output); } } } @@ -590,5 +629,5 @@ static size_t ggml_allocator_alloc_graph_tensors_n( } size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) { - return ggml_allocator_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL); + return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL); } diff --git a/ggml-metal.m b/ggml-metal.m index 7e2355c..b438b83 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -63,7 +63,10 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(relu); GGML_METAL_DECL_KERNEL(gelu); GGML_METAL_DECL_KERNEL(soft_max); + GGML_METAL_DECL_KERNEL(soft_max_4); GGML_METAL_DECL_KERNEL(diag_mask_inf); + GGML_METAL_DECL_KERNEL(diag_mask_inf_8); + GGML_METAL_DECL_KERNEL(get_rows_f32); GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); @@ -77,6 +80,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row); + GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32); @@ -117,14 +121,17 @@ static NSString * const msl_library_source = @"see metal.metal"; struct ggml_metal_context * ggml_metal_init(int n_cb) { metal_printf("%s: allocating\n", __func__); - // Show all the Metal device instances in the system - NSArray * devices = MTLCopyAllDevices(); id device; NSString * s; + +#if TARGET_OS_OSX + // Show all the Metal device instances in the system + NSArray * devices = MTLCopyAllDevices(); for (device in devices) { s = [device name]; metal_printf("%s: found device: %s\n", __func__, [s UTF8String]); } +#endif // Pick and show default Metal device device = MTLCreateSystemDefaultDevice(); @@ -139,14 +146,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { ctx->n_buffers = 0; ctx->concur_list_len = 0; - ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT); + ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); -#if 0 - // compile from source string and show compile log +#ifdef GGML_SWIFT + // load the default.metallib file { NSError * error = nil; - ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error]; + NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; + NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"]; + NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath]; + NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"]; + NSURL * libURL = [NSURL fileURLWithPath:libPath]; + + // Load the metallib file into a Metal library + ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; + if (error) { metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; @@ -161,7 +176,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"]; NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; - NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]); NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; @@ -207,7 +222,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(relu); GGML_METAL_ADD_KERNEL(gelu); GGML_METAL_ADD_KERNEL(soft_max); + GGML_METAL_ADD_KERNEL(soft_max_4); GGML_METAL_ADD_KERNEL(diag_mask_inf); + GGML_METAL_ADD_KERNEL(diag_mask_inf_8); + GGML_METAL_ADD_KERNEL(get_rows_f32); GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); @@ -221,6 +239,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row); + GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32); @@ -247,13 +266,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { #undef GGML_METAL_ADD_KERNEL } - metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); +#if TARGET_OS_OSX + metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); if (ctx->device.maxTransferRate != 0) { metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); } else { metal_printf("%s: maxTransferRate = built-in GPU\n", __func__); } +#endif return ctx; } @@ -273,7 +294,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(relu); GGML_METAL_DEL_KERNEL(gelu); GGML_METAL_DEL_KERNEL(soft_max); + GGML_METAL_DEL_KERNEL(soft_max_4); GGML_METAL_DEL_KERNEL(diag_mask_inf); + GGML_METAL_DEL_KERNEL(diag_mask_inf_8); + GGML_METAL_DEL_KERNEL(get_rows_f32); GGML_METAL_DEL_KERNEL(get_rows_f16); GGML_METAL_DEL_KERNEL(get_rows_q4_0); GGML_METAL_DEL_KERNEL(get_rows_q4_1); @@ -287,6 +311,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(norm); GGML_METAL_DEL_KERNEL(mul_mat_f16_f32); GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row); + GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4); GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32); @@ -365,6 +390,7 @@ static id ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru for (int i = 0; i < ctx->n_buffers; ++i) { const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; + //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name); if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) { *offs = (size_t) ioffs; @@ -454,6 +480,7 @@ bool ggml_metal_add_buffer( } } +#if TARGET_OS_OSX metal_printf(", (%8.2f / %8.2f)", ctx->device.currentAllocatedSize / 1024.0 / 1024.0, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); @@ -463,6 +490,9 @@ bool ggml_metal_add_buffer( } else { metal_printf("\n"); } +#else + metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0); +#endif } return true; @@ -698,6 +728,7 @@ void ggml_metal_graph_compute( case GGML_OP_ADD: { GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); // utilize float4 GGML_ASSERT(ne00 % 4 == 0); @@ -705,6 +736,7 @@ void ggml_metal_graph_compute( if (ggml_nelements(src1) == ne10) { // src1 is a row + GGML_ASSERT(ne11 == 1); [encoder setComputePipelineState:ctx->pipeline_add_row]; } else { [encoder setComputePipelineState:ctx->pipeline_add]; @@ -721,6 +753,7 @@ void ggml_metal_graph_compute( case GGML_OP_MUL: { GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); // utilize float4 GGML_ASSERT(ne00 % 4 == 0); @@ -728,6 +761,7 @@ void ggml_metal_graph_compute( if (ggml_nelements(src1) == ne10) { // src1 is a row + GGML_ASSERT(ne11 == 1); [encoder setComputePipelineState:ctx->pipeline_mul_row]; } else { [encoder setComputePipelineState:ctx->pipeline_mul]; @@ -743,6 +777,8 @@ void ggml_metal_graph_compute( } break; case GGML_OP_SCALE: { + GGML_ASSERT(ggml_is_contiguous(src0)); + const float scale = *(const float *) src1->data; [encoder setComputePipelineState:ctx->pipeline_scale]; @@ -750,7 +786,7 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - const int64_t n = ggml_nelements(dst); + const int64_t n = ggml_nelements(dst)/4; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; @@ -762,7 +798,7 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); + const int64_t n = ggml_nelements(dst)/4; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; @@ -782,7 +818,7 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst); + const int64_t n = ggml_nelements(dst)/4; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; @@ -796,13 +832,16 @@ void ggml_metal_graph_compute( { const int nth = 32; - [encoder setComputePipelineState:ctx->pipeline_soft_max]; + if (ne00%4 == 0) { + [encoder setComputePipelineState:ctx->pipeline_soft_max_4]; + } else { + [encoder setComputePipelineState:ctx->pipeline_soft_max]; + } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -810,14 +849,23 @@ void ggml_metal_graph_compute( { const int n_past = ((int32_t *)(dst->op_params))[0]; - [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; + if (ne00%8 == 0) { + [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8]; + } else { + [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; + } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; - [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + if (ne00%8 == 0) { + [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + else { + [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } } break; case GGML_OP_MUL_MAT: { @@ -830,8 +878,8 @@ void ggml_metal_graph_compute( // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if (ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && + if (!ggml_is_transposed(src0) && + !ggml_is_transposed(src1) && src1t == GGML_TYPE_F32 && [ctx->device supportsFamily:MTLGPUFamilyApple7] && ne00%32 == 0 && @@ -856,14 +904,18 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { int nth0 = 32; int nth1 = 1; + int nrows = 1; // use custom matrix x vector kernel switch (src0t) { @@ -873,8 +925,14 @@ void ggml_metal_graph_compute( nth1 = 1; if (ne11 * ne12 < 4) { [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row]; + //} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + } else if (false) { + // TODO: with ggml_mul_mat_pad this kernel no longer seems to be needed + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4]; + nrows = ne11; } else { [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + nrows = 4; } } break; case GGML_TYPE_Q4_0: @@ -995,7 +1053,7 @@ void ggml_metal_graph_compute( else if (src0t == GGML_TYPE_Q6_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { - int64_t ny = (ne11 + 3)/4; + int64_t ny = (ne11 + nrows - 1)/nrows; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } @@ -1003,6 +1061,7 @@ void ggml_metal_graph_compute( case GGML_OP_GET_ROWS: { switch (src0->type) { + case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; @@ -1018,9 +1077,9 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5]; const int64_t n = ggml_nelements(src1); diff --git a/ggml-metal.metal b/ggml-metal.metal index 5070561..0db037c 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -38,7 +38,7 @@ kernel void kernel_add_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb, + constant int64_t & nb, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] + src1[tpig % nb]; } @@ -63,18 +63,18 @@ kernel void kernel_mul_row( } kernel void kernel_scale( - device const float * src0, - device float * dst, + device const float4 * src0, + device float4 * dst, constant float & scale, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * scale; } kernel void kernel_silu( - device const float * src0, - device float * dst, + device const float4 * src0, + device float4 * dst, uint tpig[[thread_position_in_grid]]) { - float x = src0[tpig]; + device const float4 & x = src0[tpig]; dst[tpig] = x / (1.0f + exp(-x)); } @@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; kernel void kernel_gelu( - device const float * src0, - device float * dst, + device const float4 * src0, + device float4 * dst, uint tpig[[thread_position_in_grid]]) { - float x = src0[tpig]; + device const float4 & x = src0[tpig]; // BEWARE !!! // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! @@ -107,7 +107,6 @@ kernel void kernel_soft_max( constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, - threadgroup float * buf [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -119,64 +118,70 @@ kernel void kernel_soft_max( device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; // parallel max - buf[tpitg[0]] = -INFINITY; - for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { - buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]); + float lmax = psrc0[tpitg[0]]; + for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) { + lmax = MAX(lmax, psrc0[i00]); } - - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg[0]/2; i > 0; i /= 2) { - if (tpitg[0] < i) { - buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of - // the loop, and when that is done, buf[0] has the correct (synchronized) value - //if (tpitg[0] == 0) { - // buf[0] = buf[0]; - //} - - //threadgroup_barrier(mem_flags::mem_threadgroup); - - const float max = buf[0]; + const float max = simd_max(lmax); // parallel sum - buf[tpitg[0]] = 0.0f; + float lsum = 0.0f; for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { const float exp_psrc0 = exp(psrc0[i00] - max); - buf[tpitg[0]] += exp_psrc0; + lsum += exp_psrc0; // Remember the result of exp here. exp is expensive, so we really do not // whish to compute it twice. pdst[i00] = exp_psrc0; } - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg[0]/2; i > 0; i /= 2) { - if (tpitg[0] < i) { - buf[tpitg[0]] += buf[tpitg[0] + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // broadcast - not needed, see above - //// broadcast - //if (tpitg[0] == 0) { - // buf[0] = buf[0]; - //} - - //threadgroup_barrier(mem_flags::mem_threadgroup); - - const float sum = buf[0]; + const float sum = simd_sum(lsum); for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { pdst[i00] /= sum; } } +kernel void kernel_soft_max_4( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + // parallel max + float4 lmax4 = psrc4[tpitg[0]]; + for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) { + lmax4 = fmax(lmax4, psrc4[i00]); + } + float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + const float max = simd_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) { + const float4 exp_psrc4 = exp(psrc4[i00] - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + const float sum = simd_sum(lsum); + + for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) { + pdst4[i00] /= sum; + } +} + kernel void kernel_diag_mask_inf( device const float * src0, device float * dst, @@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf( dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; } else { dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + device const float4 * src0, + device float4 * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + + const int64_t i = 2*tpig[0]; + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int64_t i4 = 4*i; + const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + const int64_t i01 = i4/(ne00); i4 -= i01*ne00; + const int64_t i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + dst[i+1][k] = -INFINITY; + if (i00 + k > n_past + i01) { + dst[i][k] = -INFINITY; + } } } @@ -616,6 +648,49 @@ kernel void kernel_mul_mat_f16_f32( } } +// Assumes row size (ne00) is a multiple of 4 +kernel void kernel_mul_mat_f16_f32_l4( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int nrows = ne11; + const int64_t r0 = tgpig.x; + const int64_t im = tgpig.z; + + device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + + for (int r1 = 0; r1 < nrows; ++r1) { + device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + kernel void kernel_alibi_f32( device const float * src0, device float * dst, @@ -1123,31 +1198,40 @@ kernel void kernel_mul_mat_q3_K_f32( device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - float yl[16]; + float yl[32]; - const uint16_t kmask1 = 0x0303; + const uint16_t kmask1 = 0x3030; const uint16_t kmask2 = 0x0f0f; - const int tid = tiisg/2; - const int ix = tiisg%2; - const int ip = tid/8; // 0 or 1 - const int il = tid/2 - 4*ip; // 0...3 + const int tid = tiisg/4; + const int ix = tiisg%4; + const int ip = tid/4; // 0 or 1 + const int il = 2*((tid%4)/2); // 0 or 2 const int ir = tid%2; const int n = 8; const int l0 = n*ir; - const uint16_t m1 = 1 << (4*ip + il); - const uint16_t m2 = m1 << 8; + // One would think that the Metal compiler would figure out that ip and il can only have + // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it + // with these two tales. + // + // Possible masks for the high bit + const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0 + {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2 + {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0 + {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2 + + // Possible masks for the low 2 bits + const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}}; + + const ushort4 hm = mm[2*ip + il/2]; const int shift = 2*il; - const uint16_t qm1 = 0x0003 << shift; - const uint16_t qm2 = 0x0300 << shift; - const int32_t v1 = 4 << shift; - const int32_t v2 = 1024 << shift; + const float v1 = il == 0 ? 4.f : 64.f; + const float v2 = 4.f * v1; const uint16_t s_shift1 = 4*ip; - const uint16_t s_shift2 = s_shift1 + 2*(il/2); - const int ik = 4 + (il%2); + const uint16_t s_shift2 = s_shift1 + il; const int q_offset = 32*ip + l0; const int y_offset = 128*ip + 32*il + l0; @@ -1156,12 +1240,19 @@ kernel void kernel_mul_mat_q3_K_f32( device const float * y1 = yy + ix*QK_K + y_offset; - float sumf1[2] = {0.f}, sumf2[2] = {0.f}; - for (int i = ix; i < nb; i += 2) { + uint32_t scales32, aux32; + thread uint16_t * scales16 = (thread uint16_t *)&scales32; + thread const int8_t * scales = (thread const int8_t *)&scales32; + + float sumf1[2] = {0.f}; + float sumf2[2] = {0.f}; + for (int i = ix; i < nb; i += 4) { for (int l = 0; l < 8; ++l) { - yl[l+0] = y1[l+ 0]; - yl[l+8] = y1[l+16]; + yl[l+ 0] = y1[l+ 0]; + yl[l+ 8] = y1[l+16]; + yl[l+16] = y1[l+32]; + yl[l+24] = y1[l+48]; } device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); @@ -1172,27 +1263,43 @@ kernel void kernel_mul_mat_q3_K_f32( for (int row = 0; row < 2; ++row) { const float d_all = (float)dh[0]; - const char2 scales = as_type((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4))); - float s1 = 0, s2 = 0; - for (int l = 0; l < n; l += 2) { - const uint16_t qs = q[l/2]; - s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1)); - s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2)); - } - float d = d_all * (s1 + 1.f/256.f * s2); - sumf1[row] += d * scales[0]; - sumf2[row] += d; + scales16[0] = a[4]; + scales16[1] = a[5]; + aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; + scales16[0] = a[il+0]; + scales16[1] = a[il+1]; + scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; - s1 = s2 = 0; + float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; for (int l = 0; l < n; l += 2) { - const uint16_t qs = q[l/2+8]; - s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1)); - s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2)); + const int32_t qs = q[l/2]; + s1 += yl[l+0] * (qs & qm[il/2][0]); + s2 += yl[l+1] * (qs & qm[il/2][1]); + s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); + s4 += yl[l+16] * (qs & qm[il/2][2]); + s5 += yl[l+17] * (qs & qm[il/2][3]); + s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); } - d = d_all * (s1 + 1.f/256.f * s2); - sumf1[row] += d * scales[1]; - sumf2[row] += d; + float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[0] - 32); + sumf2[row] += d2 * (scales[2] - 32); + + s1 = s2 = s3 = s4 = s5 = s6 = 0; + for (int l = 0; l < n; l += 2) { + const int32_t qs = q[l/2+8]; + s1 += yl[l+8] * (qs & qm[il/2][0]); + s2 += yl[l+9] * (qs & qm[il/2][1]); + s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); + s4 += yl[l+24] * (qs & qm[il/2][2]); + s5 += yl[l+25] * (qs & qm[il/2][3]); + s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]); + } + d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[1] - 32); + sumf2[row] += d2 * (scales[3] - 32); q += step; h += step; @@ -1201,15 +1308,17 @@ kernel void kernel_mul_mat_q3_K_f32( } - y1 += 2 * QK_K; + y1 += 4 * QK_K; } for (int row = 0; row < 2; ++row) { - const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift); - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot; + const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); + sumf1[row] = simd_sum(sumf); + } + if (tiisg == 0) { + for (int row = 0; row < 2; ++row) { + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row]; } } } @@ -1564,17 +1673,25 @@ kernel void kernel_mul_mat_q5_K_f32( sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); - float4 acc = {0.f, 0.f, 0.f, 0.f}; + float4 acc1 = {0.f}; + float4 acc2 = {0.f}; for (int l = 0; l < n; ++l) { uint8_t h = qh[l]; - acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0)); - acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0)); - acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0)); - acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0)); + acc1[0] += yl[l+0] * (q1[l] & 0x0F); + acc1[1] += yl[l+8] * (q1[l] & 0xF0); + acc1[2] += yh[l+0] * (q2[l] & 0x0F); + acc1[3] += yh[l+8] * (q2[l] & 0xF0); + acc2[0] += h & hm1 ? yl[l+0] : 0.f; + acc2[1] += h & hm2 ? yl[l+8] : 0.f; + acc2[2] += h & hm3 ? yh[l+0] : 0.f; + acc2[3] += h & hm4 ? yh[l+8] : 0.f; } const float dall = dh[0]; const float dmin = dh[1]; - sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) - + sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + + sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + + sc8[4] * (acc1[2] + 16.f*acc2[2]) + + sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); q1 += step; @@ -1747,6 +1864,15 @@ kernel void kernel_mul_mat_q6_K_f32( //============================= templates and their specializations ============================= +// NOTE: this is not dequantizing - we are simply fitting the template +template +void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { + float4x4 temp = *(((device float4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + template void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { half4x4 temp = *(((device half4x4 *)src)); @@ -1758,28 +1884,30 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) template void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 1); - const half d = il ? (xb->d / 16.h) : xb->d; - const half m = il ? ( -8.h * 16.h) : -8.h; + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = il ? 0xF000 : 0x0F00; + const ushort mask1 = mask0 << 8; for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d; - reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d; + reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; + reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; } } template void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 2); - const half d = il ? (xb->d / 16.h) : xb->d; - const half m = xb->m; + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = il ? 0xF000 : 0x0F00; + const ushort mask1 = mask0 << 8; for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m; - reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m; + reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; + reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; } } @@ -1815,7 +1943,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg template void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { - const float d_all = (float)(xb->d); + const half d_all = xb->d; device const uint8_t * q = (device const uint8_t *)xb->qs; device const uint8_t * h = (device const uint8_t *)xb->hmask; device const int8_t * scales = (device const int8_t *)xb->scales; @@ -1828,16 +1956,18 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg ((il/4)>0 ? 12 : 3); uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; - int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \ - (scale_2&kmask2) | ((scale_1&kmask1) << 4); - float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); + int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) + : (scale_2&kmask2) | ((scale_1&kmask1) << 4); + half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h); + const half ml = 4.h * dl; - il = (il/2)%4; - float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + il = (il/2) & 3; + const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl *= coef; for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef)); + reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); } #else float kcoef = il&1 ? 1.f/16.f : 1.f; @@ -1852,26 +1982,31 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg #endif } +static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { + return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} + : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; +} + template void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { - device const uint8_t * q = xb->qs; + device const uchar * q = xb->qs; #if QK_K == 256 - const float d = (float)(xb->d); - const float min = (float)(xb->dmin); short is = (il/4) * 2; q = q + (il/4) * 32 + 16 * (il&1); - il = il%4; - const uchar4 sc = get_scale_min_k4(is, xb->scales); - const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h; - const float ml = il<2 ? min * sc[1] : min * sc[3]; + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const half d = il < 2 ? xb->d : xb->d / 16.h; + const half min = xb->dmin; + const half dl = d * sc[0]; + const half ml = min * sc[1]; #else q = q + 16 * (il&1); device const uint8_t * s = xb->scales; device const half2 * dh = (device const half2 *)xb->d; const float2 d = (float2)dh[0]; const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h; - const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4); + const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4); #endif const ushort mask = il<2 ? 0x0F : 0xF0; for (int i = 0; i < 16; ++i) { @@ -1885,19 +2020,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg device const uint8_t * qh = xb->qh; #if QK_K == 256 - const float d = (float)(xb->d); - const float min = (float)(xb->dmin); short is = (il/4) * 2; q = q + 32 * (il/4) + 16 * (il&1); qh = qh + 16 * (il&1); uint8_t ul = 1 << (il/2); - il = il%4; - const uchar4 sc = get_scale_min_k4(is, xb->scales); - const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h; - const float ml = il<2 ? min * sc[1] : min * sc[3]; + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const half d = il < 2 ? xb->d : xb->d / 16.h; + const half min = xb->dmin; + const half dl = d * sc[0]; + const half ml = min * sc[1]; - const ushort mask = il<2 ? 0x0F : 0xF0; - const float qh_val = il<2 ? 16.f : 256.f; + const ushort mask = il<2 ? 0x0F : 0xF0; + const half qh_val = il<2 ? 16.h : 256.h; for (int i = 0; i < 16; ++i) { reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; } @@ -1916,7 +2051,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg template void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { - const float d_all = (float)(xb->d); + const half d_all = xb->d; device const uint8_t * ql = (device const uint8_t *)xb->ql; device const uint8_t * qh = (device const uint8_t *)xb->qh; device const int8_t * scales = (device const int8_t *)xb->scales; @@ -1924,19 +2059,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg #if QK_K == 256 ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); qh = qh + 32*(il/8) + 16*(il&1); - float sc = scales[(il%2) + 2 * ((il/2))]; - il = (il/2)%4; + half sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; #else ql = ql + 16 * (il&1); - float sc = scales[il]; + half sc = scales[il]; #endif + const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; + const half coef = il>1 ? 1.f/16.h : 1.h; + const half ml = d_all * sc * 32.h; + const half dl = d_all * sc * coef; for (int i = 0; i < 16; ++i) { - uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; - const float coef = il>1 ? 1.f/16.f : 1.f; - float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \ - ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef; - reg[i/4][i%4] = d_all * sc * q * coef; + const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) + : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); + reg[i/4][i%4] = dl * q - ml; } } @@ -1976,22 +2113,25 @@ kernel void kernel_get_rows( // each block_q contains 16*nl weights template kernel void kernel_mul_mm(device const uchar * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & gqa, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup half * sa = ((threadgroup half *)shared_memory); + threadgroup half * sa = (threadgroup half *)(shared_memory); threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); const uint r0 = tgpig.y; @@ -2004,7 +2144,7 @@ kernel void kernel_mul_mm(device const uchar * src0, short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_half8x8 ma[4]; + simdgroup_half8x8 ma[4]; simdgroup_float8x8 mb[2]; simdgroup_float8x8 c_res[8]; for (int i = 0; i < 8; i++){ @@ -2012,10 +2152,15 @@ kernel void kernel_mul_mm(device const uchar * src0, } short il = (tiitg % THREAD_PER_ROW); - uint offset0 = im/gqa*nb02; ushort offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; - device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \ - + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1; + + uint offset0 = im/gqa*nb02; + ushort offset1 = il/nl; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + device const float * y = (device const float *)(src1 + + nb12 * im + + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { //load data and store to threadgroup memory @@ -2095,6 +2240,7 @@ kernel void kernel_mul_mm(device const uchar * src0, typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \ constant uint64_t &, constant uint64_t &, uint, uint, uint); +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; @@ -2105,14 +2251,27 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; -typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\ - constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \ - constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint); +typedef void (mat_mm_t)( + device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + threadgroup uchar *, uint3, uint, uint); -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; diff --git a/ggml.c b/ggml.c index dcdebd2..c5b5dd6 100644 --- a/ggml.c +++ b/ggml.c @@ -4303,10 +4303,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) { } size_t ggml_nbytes(const struct ggml_tensor * tensor) { - size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type); - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + size_t nbytes; + size_t blck_size = ggml_blck_size(tensor->type); + if (blck_size == 1) { + nbytes = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + } } + else { + nbytes = tensor->ne[0]*tensor->nb[0]/blck_size; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + } + } + return nbytes; } @@ -18345,10 +18356,11 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { for (int i = 0; i < cgraph->n_leafs; i++) { struct ggml_tensor * node = cgraph->leafs[i]; - GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n", + GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n", i, node->ne[0], node->ne[1], - ggml_op_name(node->op)); + ggml_op_name(node->op), + ggml_get_name(node)); } for (int i = 0; i < GGML_OP_COUNT; i++) { diff --git a/whisper.cpp b/whisper.cpp index f5a9a71..23ebd7e 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3,11 +3,16 @@ #include "coreml/whisper-encoder.h" #endif +#ifdef GGML_USE_METAL +# include "ggml-metal.h" +#endif + #ifdef WHISPER_USE_OPENVINO #include "openvino/whisper-openvino-encoder.h" #endif #include "ggml.h" +#include "ggml-alloc.h" #include #include @@ -24,6 +29,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -115,9 +121,6 @@ static void byteswap_tensor(ggml_tensor * tensor) { //#define WHISPER_USE_FLASH_FF #define WHISPER_MAX_DECODERS 16 -#define WHISPER_USE_SCRATCH -#define WHISPER_MAX_SCRATCH_BUFFERS 16 - // // ggml helpers // @@ -133,6 +136,44 @@ static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * ggml_graph_compute(graph, &plan); } +// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad" +// the idea is to represent the original matrix multiplication: +// +// Z = X @ Y +// +// with the sum of two matrix multiplications: +// +// Z = (X_0 @ Y_0) + (X_1 @ Y_1) +// +// here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad" +// and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more +// general-purpose kernels +// +static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) { + // use padding only if dimension 0 is at least 8 times larger than the padding + // else we won't get much benefit from the optimization + const int n_pad_req = 8; + + if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) { + return ggml_mul_mat(ctx, x, y); + } + + struct ggml_tensor * x_0 = ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0); + struct ggml_tensor * x_1 = ggml_view_3d(ctx, x, x->ne[0]%pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]); + + struct ggml_tensor * y_0 = ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0); + struct ggml_tensor * y_1 = ggml_view_3d(ctx, y, y->ne[0]%pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]); + + return ggml_add(ctx, + ggml_mul_mat(ctx, x_0, y_0), + ggml_mul_mat(ctx, x_1, y_1)); +} + +// TODO: check if other platforms can benefit from this optimization +#if defined(GGML_USE_METAL) +#define ggml_mul_mat ggml_mul_mat_pad +#endif + // available whisper models enum e_model { MODEL_UNKNOWN, @@ -247,38 +288,7 @@ static const std::map> g_lang = { static const size_t MB = 1ull*1024*1024; -static const std::map MEM_REQ_SCRATCH0 = { - { MODEL_TINY, 62ull*MB }, - { MODEL_BASE, 80ull*MB }, - { MODEL_SMALL, 120ull*MB }, - { MODEL_MEDIUM, 158ull*MB }, - { MODEL_LARGE, 198ull*MB }, -}; - -static const std::map MEM_REQ_SCRATCH1 = { - { MODEL_TINY, 18ull*MB }, - { MODEL_BASE, 24ull*MB }, - { MODEL_SMALL, 36ull*MB }, - { MODEL_MEDIUM, 48ull*MB }, - { MODEL_LARGE, 60ull*MB }, -}; - -static const std::map MEM_REQ_SCRATCH2 = { - { MODEL_TINY, 4ull*MB }, - { MODEL_BASE, 4ull*MB }, - { MODEL_SMALL, 6ull*MB }, - { MODEL_MEDIUM, 7ull*MB }, - { MODEL_LARGE, 9ull*MB }, -}; - -static const std::map MEM_REQ_SCRATCH3 = { - { MODEL_TINY, 4ull*MB }, - { MODEL_BASE, 4ull*MB }, - { MODEL_SMALL, 6ull*MB }, - { MODEL_MEDIUM, 7ull*MB }, - { MODEL_LARGE, 9ull*MB }, -}; - +// TODO: avoid using GGUF static const std::map> MEM_REQ_MODEL = { { GGML_TYPE_F32, { @@ -345,38 +355,6 @@ static const std::map> MEM_REQ_MODEL = { }, }; -static const std::map MEM_REQ_KV_SELF = { - { MODEL_TINY, 3ull*MB }, - { MODEL_BASE, 6ull*MB }, - { MODEL_SMALL, 16ull*MB }, - { MODEL_MEDIUM, 43ull*MB }, - { MODEL_LARGE, 71ull*MB }, -}; - -static const std::map MEM_REQ_KV_CROSS = { - { MODEL_TINY, 9ull*MB }, - { MODEL_BASE, 18ull*MB }, - { MODEL_SMALL, 53ull*MB }, - { MODEL_MEDIUM, 141ull*MB }, - { MODEL_LARGE, 235ull*MB }, -}; - -static const std::map MEM_REQ_ENCODE = { - { MODEL_TINY, 30ull*MB }, - { MODEL_BASE, 38ull*MB }, - { MODEL_SMALL, 56ull*MB }, - { MODEL_MEDIUM, 74ull*MB }, - { MODEL_LARGE, 94ull*MB }, -}; - -static const std::map MEM_REQ_DECODE = { - { MODEL_TINY, 3ull*MB }, - { MODEL_BASE, 5ull*MB }, - { MODEL_SMALL, 10ull*MB }, - { MODEL_MEDIUM, 18ull*MB }, - { MODEL_LARGE, 27ull*MB }, -}; - struct whisper_mel { int n_len; int n_len_org; @@ -657,15 +635,57 @@ struct kv_buf { std::vector v; }; +// ggml_allocr wrapper for whisper usage +struct whisper_allocr { + ggml_allocr * alloc = nullptr; + + std::vector meta; + std::vector data; +}; + +static size_t whisper_allocr_size(struct whisper_allocr & allocr) { + return allocr.meta.size() + allocr.data.size(); +} + +// measure the memory usage of a graph and prepare the allocr's internal data buffer +static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function && get_graph) { + const int tensor_alignment = 32; + + auto & alloc = allocr.alloc; + auto & meta = allocr.meta; + auto & data = allocr.data; + + meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead()); + + alloc = ggml_allocr_new_measure(tensor_alignment); + + const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment; + + ggml_allocr_free(alloc); + + data.resize(alloc_size); + + alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment); +} + +static void whisper_allocr_free(struct whisper_allocr & allocr) { + if (allocr.alloc) { + ggml_allocr_free(allocr.alloc); + allocr.alloc = nullptr; + } +} + struct whisper_state { int64_t t_sample_us = 0; int64_t t_encode_us = 0; int64_t t_decode_us = 0; + int64_t t_prompt_us = 0; int64_t t_mel_us = 0; int32_t n_sample = 0; // number of tokens sampled int32_t n_encode = 0; // number of encoder calls - int32_t n_decode = 0; // number of decoder calls + int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) + int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding) int32_t n_fail_p = 0; // number of logprob threshold failures int32_t n_fail_h = 0; // number of entropy threshold failures @@ -679,13 +699,20 @@ struct whisper_state { // buffer for swapping KV caches between decoders during beam-search std::vector kv_swap_bufs; - // memory buffers used by encode / decode contexts - std::vector buf_compute; - std::vector buf_work; - std::vector buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS]; + // reusable buffer for `struct ggml_graph_plan.work_data` + std::vector work_buffer; - int buf_last = 0; - size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 }; + // ggml-alloc: + // - stores meta info about the intermediate tensors into the `meta` buffers + // - stores the actual tensor data into the `data` buffers + whisper_allocr alloc_conv; + whisper_allocr alloc_encode; + whisper_allocr alloc_cross; + whisper_allocr alloc_decode; + + // result of the encoder + struct ggml_tensor * embd_conv = nullptr; + struct ggml_tensor * embd_enc = nullptr; // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; @@ -705,6 +732,10 @@ struct whisper_state { whisper_coreml_context * ctx_coreml = nullptr; #endif +#ifdef GGML_USE_METAL + ggml_metal_context * ctx_metal = nullptr; +#endif + #ifdef WHISPER_USE_OPENVINO whisper_openvino_context * ctx_openvino = nullptr; #endif @@ -717,37 +748,6 @@ struct whisper_state { // [EXPERIMENTAL] speed-up techniques int32_t exp_n_audio_ctx = 0; // 0 - use default - - void use_buf(struct ggml_context * ctx, int i) { -#if defined(WHISPER_USE_SCRATCH) - size_t last_size = 0; - - if (i == -1) { - last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, }); - } else { - auto & buf = buf_scratch[i]; - last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), }); - } - - if (buf_last >= 0) { - buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size); - } - - buf_last = i; -#else - (void) i; - (void) ctx; -#endif - } - - size_t get_buf_max_mem(int i) const { -#if defined(WHISPER_USE_SCRATCH) - return buf_max_size[i]; -#else - (void) i; - return 0; -#endif - } }; struct whisper_context { @@ -794,10 +794,17 @@ static void read_safe(whisper_model_loader * loader, T & dest) { static bool kv_cache_init( const struct whisper_hparams & hparams, - const size_t mem_bytes, struct whisper_kv_cache & cache, ggml_type wtype, int n_ctx) { + const int64_t n_text_state = hparams.n_text_state; + const int64_t n_text_layer = hparams.n_text_layer; + + const int64_t n_mem = n_text_layer*n_ctx; + const int64_t n_elements = n_text_state*n_mem; + + const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead()); + cache.buf.resize(mem_bytes); struct ggml_init_params params = { @@ -813,12 +820,6 @@ static bool kv_cache_init( return false; } - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; - - const int n_mem = n_text_layer*n_ctx; - const int n_elements = n_text_state*n_mem; - cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); @@ -961,22 +962,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // print memory requirements { - // this is the total memory required to run the inference - const size_t mem_required = - MEM_REQ_SCRATCH0.at(model.type) + - MEM_REQ_SCRATCH1.at(model.type) + - MEM_REQ_SCRATCH2.at(model.type) + - MEM_REQ_SCRATCH3.at(model.type) + - scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) + - scale*MEM_REQ_KV_CROSS.at(model.type) + - scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); - - // this is the memory required by one decoder - const size_t mem_required_decoder = - scale*MEM_REQ_KV_SELF.at(model.type); - - log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, - mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); + // TODO + //log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, + // mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); } // initialize all memory buffers @@ -1485,6 +1473,441 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con return true; } +static bool whisper_encode_external(const whisper_state & wstate) { + GGML_UNUSED(wstate); + +#ifndef WHISPER_USE_COREML + const bool use_coreml = false; +#else + const bool use_coreml = wstate.ctx_coreml != nullptr; +#endif + +#ifndef WHISPER_USE_OPENVINO + const bool use_openvino = false; +#else + const bool use_openvino = wstate.ctx_openvino != nullptr; +#endif + + return use_coreml || use_openvino; +} + +static struct ggml_cgraph * whisper_build_graph_conv( + whisper_context & wctx, + whisper_state & wstate, + const int mel_offset) { + const auto & model = wctx.model; + const auto & mel_inp = wstate.mel; + const auto & hparams = model.hparams; + + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; GGML_UNUSED(n_state); + + const int n_mels = hparams.n_mels; + + struct ggml_init_params params = { + /*.mem_size =*/ wstate.alloc_conv.meta.size(), + /*.mem_buffer =*/ wstate.alloc_conv.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + ggml_allocr * alloc = wstate.alloc_conv.alloc; + + struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); + ggml_allocr_alloc(alloc, mel); + + assert(mel->type == GGML_TYPE_F32); + if (!ggml_allocr_is_measure(alloc)) { + assert(mel_inp.n_mel == n_mels); + + float * dst = (float *) mel->data; + memset(dst, 0, ggml_nbytes(mel)); + + const int i0 = std::min(mel_offset, mel_inp.n_len); + const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len); + + for (int j = 0; j < mel_inp.n_mel; ++j) { + for (int i = i0; i < i1; ++i) { + dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i]; + } + } + } + + struct ggml_tensor * cur = nullptr; + + if (!whisper_encode_external(wstate)) { + // convolution + gelu + { + cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1); + cur = ggml_add(ctx0, + ggml_repeat(ctx0, + model.e_conv_1_b, + cur), + cur); + + cur = ggml_gelu(ctx0, cur); + + cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1); + cur = ggml_add(ctx0, + ggml_repeat(ctx0, + model.e_conv_2_b, + cur), + cur); + + cur = ggml_gelu(ctx0, cur); + } + + wstate.embd_conv = cur; + } else { +#ifdef WHISPER_USE_COREML + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx); + ggml_allocr_alloc(alloc, cur); + + if (!ggml_allocr_is_measure(alloc)) { + whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data); + } +#endif +#ifdef WHISPER_USE_OPENVINO + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx); + ggml_allocr_alloc(alloc, cur); + + if (!ggml_allocr_is_measure(alloc)) { + whisper_openvino_encode(wstate.ctx_openvino, mel, cur); + } +#endif + + wstate.embd_enc = cur; + } + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * whisper_build_graph_encoder( + whisper_context & wctx, + whisper_state & wstate) { + const auto & model = wctx.model; + const auto & hparams = model.hparams; + + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; + const int n_head = hparams.n_audio_head; + const int n_layer = hparams.n_audio_layer; + + struct ggml_init_params params = { + /*.mem_size =*/ wstate.alloc_encode.meta.size(), + /*.mem_buffer =*/ wstate.alloc_encode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + ggml_allocr * alloc = wstate.alloc_encode.alloc; + + struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_allocr_alloc(alloc, KQscale); + + if (!ggml_allocr_is_measure(alloc)) { + ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head)); + } + + struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv); + + // =================================================================== + // NOTE: experimenting with partial evaluation of the encoder (ignore) + //static int iter = -1; + //const int n_iter = 1500/n_ctx; + + //iter = (iter + 1) % n_iter; + + //if (iter == 0) { + // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); + // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); + //} + + static int iter = 0; + + const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe); + const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter; + + struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); + + cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur))); + + // =================================================================== + + // original: + //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); + + struct ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers_encoder[il]; + + // norm + { + cur = ggml_norm(ctx0, inpL, hparams.eps); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, layer.attn_ln_0_w), + layer.attn_ln_0_b); + } + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); + + Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b); + + //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + // note: no bias for Key + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); + + //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); + + Vcur = ggml_add(ctx0, Vcur, layer.attn_v_b); + + // ------ + +#ifdef WHISPER_USE_FLASH_ATTN + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Kcur, + ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct ggml_tensor * V = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + Vcur, + n_state/n_head, n_head, n_ctx), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)); + + struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false); +#else + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Kcur, + ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled); + + struct ggml_tensor * V = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + Vcur, + n_state/n_head, n_head, n_ctx), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head) + ); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); +#endif + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); + } + + // projection + { + cur = ggml_mul_mat(ctx0, + layer.attn_ln_1_w, + cur); + + cur = ggml_add(ctx0, cur, layer.attn_ln_1_b); + } + + // add the input + cur = ggml_add(ctx0, cur, inpL); + + struct ggml_tensor * inpFF = cur; + + // feed-forward network + { + // norm + { + cur = ggml_norm(ctx0, inpFF, hparams.eps); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, layer.mlp_ln_w), + layer.mlp_ln_b); + } + +#ifdef WHISPER_USE_FLASH_FF + cur = ggml_flash_ff(ctx0, + ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)), + layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); +#else + // fully connected + cur = ggml_mul_mat(ctx0, + layer.mlp_0_w, + cur); + + cur = ggml_add(ctx0, cur, layer.mlp_0_b); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // projection + cur = ggml_mul_mat(ctx0, + layer.mlp_1_w, + cur); + + cur = ggml_add(ctx0, cur, layer.mlp_1_b); +#endif + } + + inpL = ggml_add(ctx0, cur, inpFF); + } + + cur = inpL; + + // norm + { + cur = ggml_norm(ctx0, cur, hparams.eps); + + // cur = ln_f_g*cur + ln_f_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.e_ln_w), + model.e_ln_b); + } + + ggml_build_forward_expand(gf, cur); + + wstate.embd_enc = cur; + + //ggml_graph_print(gf); + + //////////////////////////////////////////////////////////////////////////// + + //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1024.0/1024.0, + // wstate.get_buf_max_mem(0)/1024.0/1024.0, + // wstate.get_buf_max_mem(1)/1024.0/1024.0, + // wstate.get_buf_max_mem(2)/1024.0/1024.0, + // wstate.get_buf_max_mem(3)/1024.0/1024.0); + + ggml_free(ctx0); + + return gf; +} + +// pre-compute cross-attention memory +static struct ggml_cgraph * whisper_build_graph_cross( + whisper_context & wctx, + whisper_state & wstate) { + const auto & model = wctx.model; + const auto & hparams = model.hparams; + + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; + const int n_head = hparams.n_audio_head; + + struct ggml_init_params params = { + /*.mem_size =*/ wstate.alloc_cross.meta.size(), + /*.mem_buffer =*/ wstate.alloc_cross.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + ggml_allocr * alloc = wstate.alloc_cross.alloc; + + struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc); + + struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_allocr_alloc(alloc, Kscale); + + if (!ggml_allocr_is_measure(alloc)) { + ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25)); + } + + for (int il = 0; il < model.hparams.n_text_layer; ++il) { + auto & layer = model.layers_decoder[il]; + + struct ggml_tensor* Kcross = ggml_mul_mat(ctx0, + layer.cross_attn_k_w, + cur); + + Kcross = ggml_scale(ctx0, Kcross, Kscale); + + struct ggml_tensor* Vcross = ggml_mul_mat(ctx0, + layer.cross_attn_v_w, + cur); + + Vcross = ggml_add(ctx0, + Vcross, + layer.cross_attn_v_b); + + Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); + + struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, + n_state*n_ctx, + (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + + struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, + ( n_ctx)*ggml_element_size(wstate.kv_cross.v), + (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v)); + } + + //ggml_graph_print(gf); + + ggml_free(ctx0); + + return gf; +} + // evaluate the encoder with the given state // // given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder @@ -1499,453 +1922,69 @@ static bool whisper_encode_internal( whisper_context & wctx, whisper_state & wstate, const int mel_offset, - const int n_threads){ - + const int n_threads) { const int64_t t_start_us = ggml_time_us(); - const auto & model = wctx.model; - const auto & mel_inp = wstate.mel; - const auto & hparams = model.hparams; - - const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; - const int n_state = hparams.n_audio_state; - const int n_head = hparams.n_audio_head; - const int n_layer = hparams.n_audio_layer; - - const int n_mels = hparams.n_mels; - assert(mel_inp.n_mel == n_mels); - - struct ggml_init_params params = { - /*.mem_size =*/ wstate.buf_compute.size(), - /*.mem_buffer =*/ wstate.buf_compute.data(), - /*.no_alloc =*/ false, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - wstate.use_buf(ctx0, 0); - - struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); - assert(mel->type == GGML_TYPE_F32); + // conv { - float * dst = (float *) mel->data; - memset(dst, 0, ggml_nbytes(mel)); + auto & alloc = wstate.alloc_conv.alloc; - const int i0 = std::min(mel_offset, mel_inp.n_len); - const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len); + ggml_allocr_reset(alloc); - for (int j = 0; j < mel_inp.n_mel; ++j) { - for (int i = i0; i < i1; ++i) { - dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i]; - } + ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset); + + ggml_allocr_alloc_graph(alloc, gf); + + if (!whisper_encode_external(wstate)) { + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); } } - struct ggml_tensor * cur; + // encoder + if (!whisper_encode_external(wstate)) { + auto & alloc = wstate.alloc_encode.alloc; -#ifndef WHISPER_USE_COREML - const bool use_coreml = false; -#else - const bool use_coreml = wstate.ctx_coreml != nullptr; -#endif + ggml_allocr_reset(alloc); -#ifndef WHISPER_USE_OPENVINO - const bool use_openvino = false; -#else - const bool use_openvino = wstate.ctx_openvino != nullptr; -#endif + ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate); - if (!use_coreml && !use_openvino) { - // convolution + gelu - { - wstate.use_buf(ctx0, 1); + ggml_allocr_alloc_graph(alloc, gf); - cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, - model.e_conv_1_b, - cur), - cur); - - cur = ggml_gelu(ctx0, cur); - - wstate.use_buf(ctx0, 0); - - cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1); - cur = ggml_add(ctx0, - ggml_repeat(ctx0, - model.e_conv_2_b, - cur), - cur); - - cur = ggml_gelu(ctx0, cur); +#ifdef GGML_USE_METAL + if (wstate.ctx_metal) { + ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); + ggml_metal_graph_compute(wstate.ctx_metal, gf); + } else { + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); } - - wstate.use_buf(ctx0, 3); - - // =================================================================== - // NOTE: experimenting with partial evaluation of the encoder (ignore) - //static int iter = -1; - //const int n_iter = 1500/n_ctx; - - //iter = (iter + 1) % n_iter; - - //if (iter == 0) { - // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); - // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); - //} - - static int iter = 0; - - const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe); - const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter; - - struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); - - cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur)); - - // =================================================================== - - // original: - //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); - - struct ggml_tensor * inpL = cur; - - for (int il = 0; il < n_layer; ++il) { - const auto & layer = model.layers_encoder[il]; - - // norm - { - wstate.use_buf(ctx0, 0); - - cur = ggml_norm(ctx0, inpL, hparams.eps); - - // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.attn_ln_0_w, cur), - cur), - ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); - } - - // self-attention - { - wstate.use_buf(ctx0, 1); - - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, - layer.attn_q_w, - cur); - - Qcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.attn_q_b, - Qcur), - Qcur); - - //Qcur = ggml_scale_inplace(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - - // note: no bias for Key - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, - layer.attn_k_w, - cur); - - //Kcur = ggml_scale_inplace(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, - layer.attn_v_w, - cur); - - Vcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.attn_v_b, - Vcur), - Vcur); - - // ------ - - wstate.use_buf(ctx0, 0); - -#ifdef WHISPER_USE_FLASH_ATTN - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Kcur, - ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - struct ggml_tensor * V = - ggml_cpy(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - Vcur, - n_state/n_head, n_head, n_ctx), - 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)); - - struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false); #else - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Kcur, - ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - - struct ggml_tensor * KQ_scaled = - ggml_scale_inplace(ctx0, - KQ, - ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) - ); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_scaled); - - struct ggml_tensor * V = - ggml_cpy(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - Vcur, - n_state/n_head, n_head, n_ctx), - 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head) - ); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); #endif - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - wstate.use_buf(ctx0, 1); - - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); - } - - // projection - { - wstate.use_buf(ctx0, 0); - - cur = ggml_mul_mat(ctx0, - layer.attn_ln_1_w, - cur); - - wstate.use_buf(ctx0, 1); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.attn_ln_1_b, cur), - cur); - } - - wstate.use_buf(ctx0, 2); - - // add the input - cur = ggml_add(ctx0, cur, inpL); - - struct ggml_tensor * inpFF = cur; - - // feed-forward network - { - // norm - { - wstate.use_buf(ctx0, 0); - - cur = ggml_norm(ctx0, inpFF, hparams.eps); - - wstate.use_buf(ctx0, 1); - - // cur = mlp_ln_w*cur + mlp_ln_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.mlp_ln_w, cur), - cur), - ggml_repeat(ctx0, layer.mlp_ln_b, cur)); - } - -#ifdef WHISPER_USE_FLASH_FF - wstate.use_buf(ctx0, 0); - - cur = ggml_flash_ff(ctx0, - ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)), - layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); -#else - wstate.use_buf(ctx0, 0); - - // fully connected - cur = ggml_mul_mat(ctx0, - layer.mlp_0_w, - cur); - - wstate.use_buf(ctx0, 1); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.mlp_0_b, cur), - cur); - - wstate.use_buf(ctx0, 0); - - // GELU activation - cur = ggml_gelu(ctx0, cur); - - wstate.use_buf(ctx0, 1); - - // projection - cur = ggml_mul_mat(ctx0, - layer.mlp_1_w, - cur); - - wstate.use_buf(ctx0, 0); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.mlp_1_b, cur), - cur); -#endif - } - - wstate.use_buf(ctx0, 3); - - inpL = ggml_add(ctx0, cur, inpFF); - } - - cur = inpL; - - // norm - { - wstate.use_buf(ctx0, 0); - - cur = ggml_norm(ctx0, cur, hparams.eps); - - wstate.use_buf(ctx0, 1); - - // cur = ln_f_g*cur + ln_f_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.e_ln_w, cur), - cur), - ggml_repeat(ctx0, model.e_ln_b, cur)); - } - - wstate.use_buf(ctx0, -1); - - // run the computation - { - struct ggml_cgraph gf = {}; - - ggml_build_forward_expand(&gf, cur); - ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads); - - //ggml_graph_print(&gf); - } } -#ifdef WHISPER_USE_COREML - else if (use_coreml) { - wstate.use_buf(ctx0, -1); - cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx); - - whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data); - } -#endif -#ifdef WHISPER_USE_OPENVINO - else if (use_openvino) { - wstate.use_buf(ctx0, -1); - - cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx); - - if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) { - return false; - } - } -#endif - - // cur - //{ - // printf("ne0 = %d\n", cur->ne[0]); - // printf("ne1 = %d\n", cur->ne[1]); - // for (int i = 0; i < 10; ++i) { - // printf("%8.4f ", ((float *)(cur->data))[i]); - // } - // printf("... "); - // for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) { - // printf("%8.4f ", ((float *)(cur->data))[i]); - // } - // printf("\n"); - //} - - // pre-compute cross-attention memory + // cross { - struct ggml_cgraph gf = {}; + auto & alloc = wstate.alloc_cross.alloc; - // TODO: hack to disconnect the encoded features from the previous graph - cur->op = GGML_OP_NONE; - cur->src[0] = nullptr; - cur->src[1] = nullptr; + ggml_allocr_reset(alloc); - for (int il = 0; il < model.hparams.n_text_layer; ++il) { - auto& layer = model.layers_decoder[il]; + ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate); - wstate.use_buf(ctx0, 0); + ggml_allocr_alloc_graph(alloc, gf); - struct ggml_tensor* Kcross = ggml_mul_mat(ctx0, - layer.cross_attn_k_w, - cur); - - Kcross = ggml_scale_inplace(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25))); - - wstate.use_buf(ctx0, 1); - - struct ggml_tensor* Vcross = ggml_mul_mat(ctx0, - layer.cross_attn_v_w, - cur); - - Vcross = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.cross_attn_v_b, - Vcross), - Vcross); - - wstate.use_buf(ctx0, -1); - - Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); - - struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); - struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, - ( n_ctx)*ggml_element_size(wstate.kv_cross.v), - (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state); - - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k)); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v)); +#ifdef GGML_USE_METAL + if (wstate.ctx_metal) { + ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); + ggml_metal_graph_compute(wstate.ctx_metal, gf); + } else { + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); } - - ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads); - //ggml_graph_print(&gf); +#else + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); +#endif } - //////////////////////////////////////////////////////////////////////////// - - //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, - // ggml_used_mem(ctx0)/1024.0/1024.0, - // wstate.get_buf_max_mem(0)/1024.0/1024.0, - // wstate.get_buf_max_mem(1)/1024.0/1024.0, - // wstate.get_buf_max_mem(2)/1024.0/1024.0, - // wstate.get_buf_max_mem(3)/1024.0/1024.0); - - ggml_free(ctx0); + // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); wstate.t_encode_us += ggml_time_us() - t_start_us; wstate.n_encode++; @@ -1953,6 +1992,343 @@ static bool whisper_encode_internal( return true; } +static struct ggml_cgraph * whisper_build_graph_decoder( + whisper_context & wctx, + whisper_state & wstate, + whisper_decoder & decoder, + const whisper_token * tokens, + int n_tokens, + int n_past) { + const auto & model = wctx.model; + const auto & hparams = model.hparams; + + auto & kv_self = decoder.kv_self; + + WHISPER_ASSERT(!!kv_self.ctx); + + const int n_ctx = hparams.n_text_ctx; + const int n_state = hparams.n_text_state; + const int n_head = hparams.n_text_head; + const int n_layer = hparams.n_text_layer; + + const int N = n_tokens; + const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + + //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx); + + struct ggml_init_params params = { + /*.mem_size =*/ wstate.alloc_decode.meta.size(), + /*.mem_buffer =*/ wstate.alloc_decode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + ggml_allocr * alloc = wstate.alloc_decode.alloc; + + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + ggml_allocr_alloc(alloc, embd); + + if (!ggml_allocr_is_measure(alloc)) { + memcpy(embd->data, tokens, N*ggml_element_size(embd)); + } + + struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + ggml_allocr_alloc(alloc, position); + + if (!ggml_allocr_is_measure(alloc)) { + for (int i = 0; i < N; ++i) { + ((int32_t *) position->data)[i] = n_past + i; + } + } + + struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_allocr_alloc(alloc, KQscale); + + if (!ggml_allocr_is_measure(alloc)) { + ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25)); + } + + // token encoding + position encoding + struct ggml_tensor * cur = + ggml_add(ctx0, + ggml_get_rows(ctx0, model.d_te, embd), + ggml_get_rows(ctx0, model.d_pe, position)); + + struct ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers_decoder[il]; + + // norm + { + cur = ggml_norm(ctx0, inpL, hparams.eps); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + cur, + layer.attn_ln_0_w), + layer.attn_ln_0_b); + } + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); + + Qcur = ggml_add(ctx0, + Qcur, + layer.attn_q_b); + + Qcur = ggml_scale(ctx0, Qcur, KQscale); + + // note: no bias for Key + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); + + Kcur = ggml_scale(ctx0, Kcur, KQscale); + + // store key and value to memory + { + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); + + Vcur = ggml_add(ctx0, + Vcur, + layer.attn_v_b); + + Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N)); + + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v)); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + // ------ + + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_state/n_head, n_past + N, n_head, + ggml_element_size(kv_self.k)*n_state, + ggml_element_size(kv_self.k)*n_state/n_head, + ggml_element_size(kv_self.k)*n_state*n_ctx*il); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); + + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_past + N, n_state/n_head, n_head, + n_ctx*ggml_element_size(kv_self.v), + n_ctx*ggml_element_size(kv_self.v)*n_state/n_head, + il*n_ctx*ggml_element_size(kv_self.v)*n_state); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); + } + + // projection + { + cur = ggml_mul_mat(ctx0, + layer.attn_ln_1_w, + cur); + + cur = ggml_add(ctx0, + cur, + layer.attn_ln_1_b); + } + + // add the input + struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL); + + // norm + { + cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + cur, + layer.cross_attn_ln_0_w), + layer.cross_attn_ln_0_b); + } + + // cross-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, + layer.cross_attn_q_w, + cur); + + Qcur = ggml_add(ctx0, + Qcur, + layer.cross_attn_q_b); + + Qcur = ggml_scale(ctx0, Qcur, KQscale); + + // Kcross is already scaled + struct ggml_tensor * Kcross = + ggml_view_3d(ctx0, wstate.kv_cross.k, + n_state/n_head, M, n_head, + ggml_element_size(wstate.kv_cross.k)*n_state, + ggml_element_size(wstate.kv_cross.k)*n_state/n_head, + ggml_element_size(wstate.kv_cross.k)*n_state*M*il); + + //struct ggml_tensor * Vcross = + // ggml_reshape_3d(ctx0, + // ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state), + // n_state/n_head, n_head, M); + + //struct ggml_tensor * V_trans = + // ggml_cpy(ctx0, + // ggml_permute(ctx0, Vcross, 1, 2, 0, 3), + // ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head)); + + struct ggml_tensor * V = + ggml_view_3d(ctx0, wstate.kv_cross.v, + M, n_state/n_head, n_head, + M*ggml_element_size(wstate.kv_cross.v), + M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head, + il*M*ggml_element_size(wstate.kv_cross.v)*n_state); + + // ------ + + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q); + + //struct ggml_tensor * KQ_scaled = + // ggml_scale(ctx0, + // KQ, + // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) + // ); + + // no masking for cross-attention + //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_state, N) + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); + } + + // projection + { + cur = ggml_mul_mat(ctx0, + layer.cross_attn_ln_1_w, + cur); + + cur = ggml_add(ctx0, + cur, + layer.cross_attn_ln_1_b); + } + + // add the input + cur = ggml_add(ctx0, cur, inpCA); + + struct ggml_tensor * inpFF = cur; + + // feed-forward network + { + // norm + { + cur = ggml_norm(ctx0, inpFF, hparams.eps); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + cur, + layer.mlp_ln_w), + layer.mlp_ln_b); + } + + // fully connected + cur = ggml_mul_mat(ctx0, + layer.mlp_0_w, + cur); + + cur = ggml_add(ctx0, + cur, + layer.mlp_0_b); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // projection + cur = ggml_mul_mat(ctx0, + layer.mlp_1_w, + cur); + + cur = ggml_add(ctx0, + cur, + layer.mlp_1_b); + } + + inpL = ggml_add(ctx0, cur, inpFF); + } + + cur = inpL; + + // norm + { + cur = ggml_norm(ctx0, cur, hparams.eps); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, + cur, + model.d_ln_w), + model.d_ln_b); + } + + // compute logits only for the last token + // comment this line to compute logits for all N tokens + // might be useful in the future + cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]); + + struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); + + ggml_build_forward_expand(gf, logits); + + ggml_free(ctx0); + + return gf; +} + // evaluate the decoder // // given text prompt + audio features -> computes the logits for the next token @@ -1976,388 +2352,45 @@ static bool whisper_decode_internal( const auto & model = wctx.model; const auto & hparams = model.hparams; - auto & kv_self = decoder.kv_self; - - WHISPER_ASSERT(!!kv_self.ctx); + const int n_vocab = hparams.n_vocab; auto & logits_out = wstate.logits; - const int n_vocab = hparams.n_vocab; + struct ggml_tensor * logits; - const int n_ctx = hparams.n_text_ctx; - const int n_state = hparams.n_text_state; - const int n_head = hparams.n_text_head; - const int n_layer = hparams.n_text_layer; - - const int N = n_tokens; - const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; - - //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx); - - struct ggml_init_params params = { - /*.mem_size =*/ wstate.buf_compute.size(), - /*.mem_buffer =*/ wstate.buf_compute.data(), - /*.no_alloc =*/ false, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - struct ggml_cgraph gf = {}; - - struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - memcpy(embd->data, tokens, N*ggml_element_size(embd)); - - struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - for (int i = 0; i < N; ++i) { - ((int32_t *) position->data)[i] = n_past + i; - } - - wstate.use_buf(ctx0, 3); - - // token encoding + position encoding - struct ggml_tensor * cur = - ggml_add(ctx0, - ggml_get_rows(ctx0, model.d_te, embd), - ggml_get_rows(ctx0, model.d_pe, position)); - - struct ggml_tensor * inpL = cur; - - for (int il = 0; il < n_layer; ++il) { - const auto & layer = model.layers_decoder[il]; - - // norm - { - wstate.use_buf(ctx0, 0); - - cur = ggml_norm(ctx0, inpL, hparams.eps); - - // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.attn_ln_0_w, cur), - cur), - ggml_repeat(ctx0, layer.attn_ln_0_b, cur)); - } - - // self-attention - { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, - layer.attn_q_w, - cur); - - Qcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.attn_q_b, - Qcur), - Qcur); - - Qcur = ggml_scale_inplace(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - - // note: no bias for Key - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, - layer.attn_k_w, - cur); - - Kcur = ggml_scale_inplace(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - - // store key and value to memory - { - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, - layer.attn_v_w, - cur); - - Vcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.attn_v_b, - Vcur), - Vcur); - - Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N)); - - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v)); - - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); - } - - // ------ - - wstate.use_buf(ctx0, 0); - - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)), - 0, 2, 1, 3); - - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state), - n_state/n_head, n_head, n_past + N), - 0, 2, 1, 3); - - wstate.use_buf(ctx0, 1); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - - //struct ggml_tensor * KQ_scaled = - // ggml_scale_inplace(ctx0, - // KQ, - // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) - // ); - - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ, n_past); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); - - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_state/n_head, n_head, - n_ctx*ggml_element_size(kv_self.v), - n_ctx*ggml_element_size(kv_self.v)*n_state/n_head, - il*n_ctx*ggml_element_size(kv_self.v)*n_state); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); - } - - // projection - { - wstate.use_buf(ctx0, 0); - - cur = ggml_mul_mat(ctx0, - layer.attn_ln_1_w, - cur); - - wstate.use_buf(ctx0, 1); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.attn_ln_1_b, cur), - cur); - } - - wstate.use_buf(ctx0, 2); - - // add the input - struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL); - - // norm - { - wstate.use_buf(ctx0, 0); - - cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here - - // cur = ln_0_w*cur + ln_0_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.cross_attn_ln_0_w, cur), - cur), - ggml_repeat(ctx0, layer.cross_attn_ln_0_b, cur)); - } - - // cross-attention - { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, - layer.cross_attn_q_w, - cur); - - Qcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.cross_attn_q_b, - Qcur), - Qcur); - - Qcur = ggml_scale_inplace(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - - // Kcross is already scaled - struct ggml_tensor * Kcross = - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state), - n_state/n_head, n_head, M); - - //struct ggml_tensor * Vcross = - // ggml_reshape_3d(ctx0, - // ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state), - // n_state/n_head, n_head, M); - - //struct ggml_tensor * V_trans = - // ggml_cpy(ctx0, - // ggml_permute(ctx0, Vcross, 1, 2, 0, 3), - // ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head)); - - struct ggml_tensor * V = - ggml_view_3d(ctx0, wstate.kv_cross.v, - M, n_state/n_head, n_head, - M*ggml_element_size(wstate.kv_cross.v), - M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head, - il*M*ggml_element_size(wstate.kv_cross.v)*n_state); - - // ------ - - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)), - 0, 2, 1, 3); - - struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - - //struct ggml_tensor * KQ_scaled = - // ggml_scale_inplace(ctx0, - // KQ, - // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head)) - // ); - - // no masking for cross-attention - //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - // cur = KQV_merged.contiguous().view(n_state, N) - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N)); - } - - // projection - { - wstate.use_buf(ctx0, 0); - - cur = ggml_mul_mat(ctx0, - layer.cross_attn_ln_1_w, - cur); - - wstate.use_buf(ctx0, 1); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur), - cur); - } - - wstate.use_buf(ctx0, 2); - - // add the input - cur = ggml_add(ctx0, cur, inpCA); - - struct ggml_tensor * inpFF = cur; - - // feed-forward network - { - // norm - { - wstate.use_buf(ctx0, 0); - - cur = ggml_norm(ctx0, inpFF, hparams.eps); - - wstate.use_buf(ctx0, 1); - - // cur = mlp_ln_w*cur + mlp_ln_b - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, layer.mlp_ln_w, cur), - cur), - ggml_repeat(ctx0, layer.mlp_ln_b, cur)); - } - - wstate.use_buf(ctx0, 0); - - // fully connected - cur = ggml_mul_mat(ctx0, - layer.mlp_0_w, - cur); - - wstate.use_buf(ctx0, 1); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.mlp_0_b, cur), - cur); - - wstate.use_buf(ctx0, 0); - - // GELU activation - cur = ggml_gelu(ctx0, cur); - - wstate.use_buf(ctx0, 1); - - // projection - cur = ggml_mul_mat(ctx0, - layer.mlp_1_w, - cur); - - wstate.use_buf(ctx0, 0); - - cur = ggml_add(ctx0, - ggml_repeat(ctx0, layer.mlp_1_b, cur), - cur); - } - - wstate.use_buf(ctx0, 3); - - inpL = ggml_add(ctx0, cur, inpFF); - } - - cur = inpL; - - // norm + // decoder { - wstate.use_buf(ctx0, 0); + auto & alloc = wstate.alloc_decode.alloc; - cur = ggml_norm(ctx0, cur, hparams.eps); + ggml_allocr_reset(alloc); - wstate.use_buf(ctx0, 1); + ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past); - cur = ggml_add(ctx0, - ggml_mul(ctx0, - ggml_repeat(ctx0, model.d_ln_w, cur), - cur), - ggml_repeat(ctx0, model.d_ln_b, cur)); - } + ggml_allocr_alloc_graph(alloc, gf); - wstate.use_buf(ctx0, 0); + logits = gf->nodes[gf->n_nodes - 1]; - // compute logits only for the last token - // comment this line to compute logits for all N tokens - // might be useful in the future - cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]); - - struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); - - wstate.use_buf(ctx0, -1); - - // run the computation - { - ggml_build_forward_expand(&gf, logits); - ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads); +#ifdef GGML_USE_METAL + if (wstate.ctx_metal) { + ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); + ggml_metal_graph_compute(wstate.ctx_metal, gf); + } else { + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + } +#else + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); +#endif } // extract logits for all N tokens - //logits_out.resize(N*n_vocab); - //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab); + //logits_out.resize(n_tokens*n_vocab); + //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab); // extract logits only for the last token logits_out.resize(n_vocab); memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); - if (N > 1) { + if (n_tokens > 1) { //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, // ggml_used_mem(ctx0)/1024.0/1024.0, // wstate.get_buf_max_mem(0)/1024.0/1024.0, @@ -2366,14 +2399,18 @@ static bool whisper_decode_internal( // wstate.get_buf_max_mem(3)/1024.0/1024.0); } - ggml_free(ctx0); - - wstate.t_decode_us += ggml_time_us() - t_start_us; - wstate.n_decode++; + if (n_tokens == 1) { + wstate.t_decode_us += ggml_time_us() - t_start_us; + wstate.n_decode++; + } else { + wstate.t_prompt_us += ggml_time_us() - t_start_us; + wstate.n_prompt++; + } return true; } + // 500 -> 00:05.000 // 6000 -> 01:00.000 static std::string to_timestamp(int64_t t, bool comma = false) { @@ -2782,9 +2819,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { fill_sin_cos_table(); whisper_state * state = new whisper_state; - const size_t scale = ctx->model.hparams.ftype ? 1 : 2; - - if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) { + if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) { log("%s: kv_cache_init() failed for self-attention cache\n", __func__); delete state; return nullptr; @@ -2795,7 +2830,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } - if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) { + if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) { log("%s: kv_cache_init() failed for cross-attention cache\n", __func__); delete state; return nullptr; @@ -2816,6 +2851,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { if (!state->ctx_coreml) { log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str()); #ifndef WHISPER_COREML_ALLOW_FALLBACK + delete state; return nullptr; #endif } else { @@ -2830,15 +2866,111 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // TAGS: WHISPER_DECODER_INIT state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx); - state->decoders[0].probs.reserve(ctx->vocab.n_vocab); - state->decoders[0].logits.reserve(ctx->vocab.n_vocab); + state->decoders[0].probs.reserve (ctx->vocab.n_vocab); + state->decoders[0].logits.reserve (ctx->vocab.n_vocab); state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab); - state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type))); - state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type)); - state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type)); - state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type)); - state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type)); + // conv allocator + { + whisper_allocr_graph_init(state->alloc_conv, + [&]() { + return whisper_build_graph_conv(*ctx, *state, 0); + }); + + log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0); + } + + // encoder allocator + if (!whisper_encode_external(*state)) { + whisper_allocr_graph_init(state->alloc_encode, + [&]() { + return whisper_build_graph_encoder(*ctx, *state); + }); + + log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0); + } + + // cross allocator + { + whisper_allocr_graph_init(state->alloc_cross, + [&]() { + return whisper_build_graph_cross(*ctx, *state); + }); + + log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0); + } + + // decoder allocator + { + whisper_allocr_graph_init(state->alloc_decode, + [&]() { + const auto & hparams = ctx->model.hparams; + + // TODO: make sure this is the worst-case scenario + const int n_tokens = hparams.n_text_ctx; + const int n_past = 0; + + return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past); + }); + + log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0); + } + +#ifdef GGML_USE_METAL + state->ctx_metal = ggml_metal_init(1); + if (!state->ctx_metal) { + log("%s: ggml_metal_init() failed\n", __func__); + delete state; + return nullptr; + } + + log("%s: Metal context initialized\n", __func__); + + // this allocates all Metal resources and memory buffers + + void * data_ptr = NULL; + size_t data_size = 0; + + // TODO: add mmap support + //if (params.use_mmap) { + // data_ptr = ctx->model.mapping->addr; + // data_size = ctx->model.mapping->size; + //} else { + // data_ptr = ggml_get_mem_buffer(ctx->model.ctx); + // data_size = ggml_get_mem_size (ctx->model.ctx); + //} + + data_ptr = ggml_get_mem_buffer(ctx->model.ctx); + data_size = ggml_get_mem_size (ctx->model.ctx); + + const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); + + log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); + +#define WHISPER_METAL_CHECK_BUF(result) \ + if (!(result)) { \ + log("%s: failed to add metal buffer\n", __func__); \ + delete state; \ + return nullptr; \ + } + + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size)); + + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0)); + + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0)); + + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0)); + + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0)); +#undef WHISPER_METAL_CHECK_BUF +#endif state->rng = std::mt19937(0); @@ -2895,7 +3027,6 @@ int whisper_ctx_init_openvino_encoder( } struct whisper_context * whisper_init_from_file_no_state(const char * path_model) { - log("%s: loading model from '%s'\n", __func__, path_model); auto fin = std::ifstream(path_model, std::ios::binary); @@ -3048,6 +3179,13 @@ void whisper_free_state(struct whisper_state * state) } #endif +#ifdef GGML_USE_METAL + if (state->ctx_metal) { + ggml_metal_free(state->ctx_metal); + state->ctx_metal = nullptr; + } +#endif + #ifdef WHISPER_USE_OPENVINO if (state->ctx_openvino != nullptr) { whisper_openvino_free(state->ctx_openvino); @@ -3055,6 +3193,11 @@ void whisper_free_state(struct whisper_state * state) } #endif + whisper_allocr_free(state->alloc_conv); + whisper_allocr_free(state->alloc_decode); + whisper_allocr_free(state->alloc_cross); + whisper_allocr_free(state->alloc_encode); + delete state; } } @@ -3475,12 +3618,14 @@ void whisper_print_timings(struct whisper_context * ctx) { const int32_t n_sample = std::max(1, ctx->state->n_sample); const int32_t n_encode = std::max(1, ctx->state->n_encode); const int32_t n_decode = std::max(1, ctx->state->n_decode); + const int32_t n_prompt = std::max(1, ctx->state->n_prompt); log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + log("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt); } log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); } @@ -3490,6 +3635,11 @@ void whisper_reset_timings(struct whisper_context * ctx) { ctx->state->t_sample_us = 0; ctx->state->t_encode_us = 0; ctx->state->t_decode_us = 0; + ctx->state->t_prompt_us = 0; + ctx->state->n_sample = 0; + ctx->state->n_encode = 0; + ctx->state->n_decode = 0; + ctx->state->n_prompt = 0; } } @@ -4339,6 +4489,21 @@ int whisper_full_with_state( decoder.probs.resize (ctx->vocab.n_vocab); decoder.logits.resize (ctx->vocab.n_vocab); decoder.logprobs.resize(ctx->vocab.n_vocab); + + // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0 +#ifdef GGML_USE_METAL +#define WHISPER_METAL_CHECK_BUF(result) \ + if (!(result)) { \ + log("%s: failed to add metal buffer\n", __func__); \ + return 0; \ + } + + const std::string kv_name = "kv_self_" + std::to_string(j); + auto & kv_self = decoder.kv_self; + + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0)); +#undef WHISPER_METAL_CHECK_BUF +#endif } } @@ -4531,8 +4696,8 @@ int whisper_full_with_state( decoder.kv_self.n += prompt.size(); - memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); - memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); + memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); + memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); } @@ -5045,6 +5210,12 @@ int whisper_full_parallel( ctx->state->t_sample_us += states[i]->t_sample_us; ctx->state->t_encode_us += states[i]->t_encode_us; ctx->state->t_decode_us += states[i]->t_decode_us; + ctx->state->t_prompt_us += states[i]->t_prompt_us; + + ctx->state->n_sample += states[i]->n_sample; + ctx->state->n_encode += states[i]->n_encode; + ctx->state->n_decode += states[i]->n_decode; + ctx->state->n_prompt += states[i]->n_prompt; whisper_free_state(states[i]); } @@ -5241,8 +5412,8 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { // b: N*N*sizeof(float) // c: N*N*sizeof(float) // when F16 is used, there is an extra work buffer of size N*N*sizeof(float) - std::vector buf (3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead()); - std::vector work(1llu*N_max*N_max*sizeof(float) + 1*ggml_tensor_overhead()); + std::vector buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead()); + std::vector work; // put a bunch of random data in the buffer for (size_t i = 0; i < buf.size(); i++) buf[i] = i;