diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6575700..2e25ef6 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -125,8 +125,10 @@ jobs: include: - arch: Win32 s2arc: x86 + jnaPath: win32-x86 - arch: x64 s2arc: x64 + jnaPath: win32-x86-64 - sdl2: ON s2ver: 2.26.0 @@ -159,6 +161,12 @@ jobs: if: matrix.sdl2 == 'ON' run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} + - name: Upload dll + uses: actions/upload-artifact@v3 + with: + name: ${{ matrix.jnaPath }}_whisper.dll + path: build/bin/${{ matrix.build }}/whisper.dll + - name: Upload binaries if: matrix.sdl2 == 'ON' uses: actions/upload-artifact@v1 @@ -363,3 +371,42 @@ jobs: run: | cd examples/whisper.android ./gradlew assembleRelease --no-daemon + + java: + needs: [ 'windows' ] + runs-on: windows-latest + steps: + - uses: actions/checkout@v1 + + - name: Install Java + uses: actions/setup-java@v1 + with: + java-version: 17 + + - name: Download Windows lib + uses: actions/download-artifact@v3 + with: + name: win32-x86-64_whisper.dll + path: bindings/java/build/generated/resources/main/win32-x86-64 + + - name: Build + run: | + models\download-ggml-model.cmd tiny.en + cd bindings/java + chmod +x ./gradlew + ./gradlew build + + - name: Upload jar + uses: actions/upload-artifact@v3 + with: + name: whispercpp.jar + path: bindings/java/build/libs/whispercpp-*.jar + +# - name: Publish package +# if: ${{ github.ref == 'refs/heads/master' }} +# uses: gradle/gradle-build-action@v2 +# with: +# arguments: publish +# env: +# MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }} +# MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }} diff --git a/bindings/java/CMakeLists.txt b/bindings/java/CMakeLists.txt deleted file mode 100644 index 7e47bb3..0000000 --- a/bindings/java/CMakeLists.txt +++ /dev/null @@ -1,50 +0,0 @@ -cmake_minimum_required(VERSION 3.10) - -project(whisper_java VERSION 1.4.2) - -# Set the target name and source file/s -set(TARGET_NAME whisper_java) -set(SOURCES src/main/cpp/whisper_java.cpp) - -# include -include_directories(../../) - -# Set the output directory for the DLL/shared library based on the platform as required by JNA -if(WIN32) - set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/win32-x86-64) -elseif(UNIX AND NOT APPLE) - set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/linux-x86-64) -elseif(APPLE) - set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/macos-x86-64) -endif() - -set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OUTPUT_DIR}) -set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OUTPUT_DIR}) - -# Create the whisper_java library -add_library(${TARGET_NAME} SHARED ${SOURCES}) - -# Link against ../../build/Release/whisper.dll (or so/dynlib) -target_link_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/../../../build/${CMAKE_BUILD_TYPE}) -target_link_libraries(${TARGET_NAME} PRIVATE whisper) - -# Set the appropriate compiler flags for Windows, Linux, and macOS -if(WIN32) - target_compile_options(${TARGET_NAME} PRIVATE /W4 /D_CRT_SECURE_NO_WARNINGS) -elseif(UNIX AND NOT APPLE) - target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra) -elseif(APPLE) - target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra) -endif() - -target_compile_definitions(${TARGET_NAME} PRIVATE WHISPER_SHARED) -# add_definitions(-DWHISPER_SHARED) - -# Force CMake to save the libs to build/generated/resources/main/${os}-${arch} as required by JNA -foreach(OUTPUTCONFIG ${CMAKE_CONFIGURATION_TYPES}) - string(TOUPPER ${OUTPUTCONFIG} OUTPUTCONFIG) - set_target_properties(${TARGET_NAME} PROPERTIES - RUNTIME_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR} - LIBRARY_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR} - ARCHIVE_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR}) -endforeach(OUTPUTCONFIG CMAKE_CONFIGURATION_TYPES) diff --git a/bindings/java/README.md b/bindings/java/README.md index 429287c..24c461e 100644 --- a/bindings/java/README.md +++ b/bindings/java/README.md @@ -6,11 +6,7 @@ This package provides Java JNI bindings for whisper.cpp. They have been tested o * Ubuntu on x86_64 * Windows on x86_64 -The "low level" bindings are in `WhisperCppJnaLibrary` and `WhisperJavaJnaLibrary` which caches `whisper_full_params` and `whisper_context` in `whisper_java.cpp`. - -There are a lot of classes in the `callbacks`, `ggml`, `model` and `params` directories but most of them have not been tested. - -The most simple usage is as follows: +The "low level" bindings are in `WhisperCppJnaLibrary`. The most simple usage is as follows: ```java import io.github.ggerganov.whispercpp.WhisperCpp; @@ -48,12 +44,6 @@ In order to build, you need to have the JDK 8 or higher installed. Run the tests git clone https://github.com/ggerganov/whisper.cpp.git cd whisper.cpp/bindings/java -mkdir build -pushd build -cmake .. -cmake --build . -popd - ./gradlew build ``` diff --git a/bindings/java/build.gradle b/bindings/java/build.gradle index 4a9b02f..3028f6f 100644 --- a/bindings/java/build.gradle +++ b/bindings/java/build.gradle @@ -22,6 +22,12 @@ sourceSets { } } +tasks.register('copyLibwhisperDynlib', Copy) { + from '../../build' + include 'libwhisper.dynlib' + into 'build/generated/resources/main/darwin' +} + tasks.register('copyLibwhisperSo', Copy) { from '../../build' include 'libwhisper.so' @@ -34,7 +40,9 @@ tasks.register('copyWhisperDll', Copy) { into 'build/generated/resources/main/windows-x86-64' } -tasks.build.dependsOn copyLibwhisperSo, copyWhisperDll +tasks.register('copyLibs') { + dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDll +} test { systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath diff --git a/bindings/java/src/main/cpp/whisper_java.cpp b/bindings/java/src/main/cpp/whisper_java.cpp deleted file mode 100644 index 9e06aa0..0000000 --- a/bindings/java/src/main/cpp/whisper_java.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include -#include "whisper_java.h" - -struct whisper_full_params default_params; -struct whisper_context * whisper_ctx = nullptr; - -struct void whisper_java_default_params(enum whisper_sampling_strategy strategy) { - default_params = whisper_full_default_params(strategy); - -// struct whisper_java_params result = {}; -// return result; - return; -} - -void whisper_java_init_from_file(const char * path_model) { - whisper_ctx = whisper_init_from_file(path_model); - if (0 == default_params.n_threads) { - whisper_java_default_params(WHISPER_SAMPLING_GREEDY); - } -} - -/** Delegates to whisper_full, but without having to pass `whisper_full_params` */ -int whisper_java_full( - struct whisper_context * ctx, -// struct whisper_java_params params, - const float * samples, - int n_samples) { - return whisper_full(ctx, default_params, samples, n_samples); -} - -void whisper_java_free() { -// free(default_params); -} diff --git a/bindings/java/src/main/cpp/whisper_java.h b/bindings/java/src/main/cpp/whisper_java.h deleted file mode 100644 index d64866b..0000000 --- a/bindings/java/src/main/cpp/whisper_java.h +++ /dev/null @@ -1,24 +0,0 @@ -#define WHISPER_BUILD -#include - -#ifdef __cplusplus -extern "C" { -#endif - -struct whisper_java_params { -}; - -WHISPER_API void whisper_java_default_params(enum whisper_sampling_strategy strategy); - -WHISPER_API void whisper_java_init_from_file(const char * path_model); - -WHISPER_API int whisper_java_full( - struct whisper_context * ctx, -// struct whisper_java_params params, - const float * samples, - int n_samples); - - -#ifdef __cplusplus -} -#endif diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java index f014407..9bc1a86 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java @@ -1,7 +1,8 @@ package io.github.ggerganov.whispercpp; +import com.sun.jna.Native; import com.sun.jna.Pointer; -import io.github.ggerganov.whispercpp.params.WhisperJavaParams; +import io.github.ggerganov.whispercpp.params.WhisperFullParams; import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy; import java.io.File; @@ -13,8 +14,9 @@ import java.io.IOException; */ public class WhisperCpp implements AutoCloseable { private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance; - private WhisperJavaJnaLibrary javaLib = WhisperJavaJnaLibrary.instance; private Pointer ctx = null; + private Pointer greedyPointer = null; + private Pointer beamPointer = null; public File modelDir() { String modelDirPath = System.getenv("XDG_CACHE_HOME"); @@ -27,9 +29,8 @@ public class WhisperCpp implements AutoCloseable { /** * @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en") - * @return a Pointer to the WhisperContext */ - void initContext(String modelPath) throws FileNotFoundException { + public void initContext(String modelPath) throws FileNotFoundException { if (ctx != null) { lib.whisper_free(ctx); } @@ -42,7 +43,6 @@ public class WhisperCpp implements AutoCloseable { modelPath = new File(modelDir(), modelPath).getAbsolutePath(); } - javaLib.whisper_java_init_from_file(modelPath); ctx = lib.whisper_init_from_file(modelPath); if (ctx == null) { @@ -51,22 +51,38 @@ public class WhisperCpp implements AutoCloseable { } /** - * Initialises `whisper_full_params` internally in whisper_java.cpp so JNA doesn't have to map everything. - * `whisper_java_init_from_file()` calls `whisper_java_default_params(WHISPER_SAMPLING_GREEDY)` for convenience. + * Provides default params which can be used with `whisper_full()` etc. + * Because this function allocates memory for the params, the caller must call either: + * - call `whisper_free_params()` + * - `Native.free(Pointer.nativeValue(pointer));` + * + * @param strategy - GREEDY */ - public void getDefaultJavaParams(WhisperSamplingStrategy strategy) { - javaLib.whisper_java_default_params(strategy.ordinal()); -// return lib.whisper_full_default_params(strategy.value) - } + public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) { + Pointer pointer; -// whisper_full_default_params was too hard to integrate with, so for now we use javaLib.whisper_java_default_params -// fun getDefaultParams(strategy: WhisperSamplingStrategy): WhisperFullParams { -// return lib.whisper_full_default_params(strategy.value) -// } + // whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy. + if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) { + if (greedyPointer == null) { + greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal()); + } + pointer = greedyPointer; + } else { + if (beamPointer == null) { + beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal()); + } + pointer = beamPointer; + } + + WhisperFullParams params = new WhisperFullParams(pointer); + params.read(); + return params; + } @Override public void close() { freeContext(); + freeParams(); System.out.println("Whisper closed"); } @@ -76,17 +92,28 @@ public class WhisperCpp implements AutoCloseable { } } + private void freeParams() { + if (greedyPointer != null) { + Native.free(Pointer.nativeValue(greedyPointer)); + greedyPointer = null; + } + if (beamPointer != null) { + Native.free(Pointer.nativeValue(beamPointer)); + beamPointer = null; + } + } + /** * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text. * Not thread safe for same context * Uses the specified decoding strategy to obtain the text. */ - public String fullTranscribe(/*WhisperJavaParams whisperParams,*/ float[] audioData) throws IOException { + public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException { if (ctx == null) { throw new IllegalStateException("Model not initialised"); } - if (javaLib.whisper_java_full(ctx, /*whisperParams,*/ audioData, audioData.length) != 0) { + if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) { throw new IOException("Failed to process audio"); } diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java index 6602565..c1fb4f8 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java @@ -231,10 +231,21 @@ public interface WhisperCppJnaLibrary extends Library { void whisper_print_timings(Pointer ctx); void whisper_reset_timings(Pointer ctx); + // Note: Even if `whisper_full_params is stripped back to just 4 ints, JNA throws "Invalid memory access" + // when `whisper_full_default_params()` tries to return a struct. + // WhisperFullParams whisper_full_default_params(int strategy); + /** + * Provides default params which can be used with `whisper_full()` etc. + * Because this function allocates memory for the params, the caller must call either: + * - call `whisper_free_params()` + * - `Native.free(Pointer.nativeValue(pointer));` + * * @param strategy - WhisperSamplingStrategy.value */ - WhisperFullParams whisper_full_default_params(int strategy); + Pointer whisper_full_default_params_by_ref(int strategy); + + void whisper_free_params(Pointer params); /** * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperJavaJnaLibrary.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperJavaJnaLibrary.java deleted file mode 100644 index 74f8459..0000000 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperJavaJnaLibrary.java +++ /dev/null @@ -1,23 +0,0 @@ -package io.github.ggerganov.whispercpp; - -import com.sun.jna.Library; -import com.sun.jna.Native; -import com.sun.jna.Pointer; -import io.github.ggerganov.whispercpp.params.WhisperJavaParams; - -interface WhisperJavaJnaLibrary extends Library { - WhisperJavaJnaLibrary instance = Native.load("whisper_java", WhisperJavaJnaLibrary.class); - - void whisper_java_default_params(int strategy); - - void whisper_java_free(); - - void whisper_java_init_from_file(String modelPath); - - /** - * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text. - * Not thread safe for same context - * Uses the specified decoding strategy to obtain the text. - */ - int whisper_java_full(Pointer ctx, /*WhisperJavaParams params, */float[] samples, int nSamples); -} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperEncoderBeginCallback.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperEncoderBeginCallback.java index b5e9797..3d228cb 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperEncoderBeginCallback.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperEncoderBeginCallback.java @@ -20,5 +20,5 @@ public interface WhisperEncoderBeginCallback extends Callback { * @param user_data User data. * @return True if the computation should proceed, false otherwise. */ - boolean callback(WhisperContext ctx, WhisperState state, Pointer user_data); + boolean callback(Pointer ctx, Pointer state, Pointer user_data); } diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperLogitsFilterCallback.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperLogitsFilterCallback.java index 5377b4e..9777c76 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperLogitsFilterCallback.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperLogitsFilterCallback.java @@ -1,12 +1,9 @@ package io.github.ggerganov.whispercpp.callbacks; +import com.sun.jna.Callback; import com.sun.jna.Pointer; -import io.github.ggerganov.whispercpp.WhisperContext; -import io.github.ggerganov.whispercpp.model.WhisperState; import io.github.ggerganov.whispercpp.model.WhisperTokenData; -import javax.security.auth.callback.Callback; - /** * Callback to filter logits. * Can be used to modify the logits before sampling. @@ -24,5 +21,5 @@ public interface WhisperLogitsFilterCallback extends Callback { * @param logits The array of logits. * @param user_data User data. */ - void callback(WhisperContext ctx, WhisperState state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data); + void callback(Pointer ctx, Pointer state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data); } diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperNewSegmentCallback.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperNewSegmentCallback.java index 95ca346..27b1c61 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperNewSegmentCallback.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperNewSegmentCallback.java @@ -20,5 +20,5 @@ public interface WhisperNewSegmentCallback extends Callback { * @param n_new The number of newly generated text segments. * @param user_data User data. */ - void callback(WhisperContext ctx, WhisperState state, int n_new, Pointer user_data); + void callback(Pointer ctx, Pointer state, int n_new, Pointer user_data); } diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperProgressCallback.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperProgressCallback.java index 8866215..c64f0ab 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperProgressCallback.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperProgressCallback.java @@ -1,11 +1,10 @@ package io.github.ggerganov.whispercpp.callbacks; +import com.sun.jna.Callback; import com.sun.jna.Pointer; import io.github.ggerganov.whispercpp.WhisperContext; import io.github.ggerganov.whispercpp.model.WhisperState; -import javax.security.auth.callback.Callback; - /** * Callback for progress updates. */ @@ -19,5 +18,5 @@ public interface WhisperProgressCallback extends Callback { * @param progress The progress value. * @param user_data User data. */ - void callback(WhisperContext ctx, WhisperState state, int progress, Pointer user_data); + void callback(Pointer ctx, Pointer state, int progress, Pointer user_data); } diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/BeamSearchParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/BeamSearchParams.java new file mode 100644 index 0000000..fd621dd --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/BeamSearchParams.java @@ -0,0 +1,19 @@ +package io.github.ggerganov.whispercpp.params; + +import com.sun.jna.Structure; + +import java.util.Arrays; +import java.util.List; + +public class BeamSearchParams extends Structure { + /** ref: ... */ + public int beam_size; + + /** ref: ... */ + public float patience; + + @Override + protected List getFieldOrder() { + return Arrays.asList("beam_size", "patience"); + } +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/CBool.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/CBool.java new file mode 100644 index 0000000..1f6814b --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/CBool.java @@ -0,0 +1,30 @@ +package io.github.ggerganov.whispercpp.params; + +import com.sun.jna.IntegerType; + +import java.util.function.BooleanSupplier; + +public class CBool extends IntegerType implements BooleanSupplier { + public static final int SIZE = 1; + public static final CBool FALSE = new CBool(0); + public static final CBool TRUE = new CBool(1); + + + public CBool() { + this(0); + } + + public CBool(long value) { + super(SIZE, value, true); + } + + @Override + public boolean getAsBoolean() { + return intValue() == 1; + } + + @Override + public String toString() { + return intValue() == 1 ? "true" : "false"; + } +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/GreedyParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/GreedyParams.java new file mode 100644 index 0000000..e3b0138 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/GreedyParams.java @@ -0,0 +1,16 @@ +package io.github.ggerganov.whispercpp.params; + +import com.sun.jna.Structure; + +import java.util.Collections; +import java.util.List; + +public class GreedyParams extends Structure { + /** ... */ + public int best_of; + + @Override + protected List getFieldOrder() { + return Collections.singletonList("best_of"); + } +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java index ea0bccf..07e6894 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java @@ -1,13 +1,14 @@ package io.github.ggerganov.whispercpp.params; -import com.sun.jna.Callback; -import com.sun.jna.Pointer; -import com.sun.jna.Structure; +import com.sun.jna.*; import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback; import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback; import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback; import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback; +import java.util.Arrays; +import java.util.List; + /** * Parameters for the whisper_full() function. * If you change the order or add new parameters, make sure to update the default values in whisper.cpp: @@ -15,62 +16,123 @@ import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback; */ public class WhisperFullParams extends Structure { + public WhisperFullParams(Pointer p) { + super(p); +// super(p, ALIGN_MSVC); +// super(p, ALIGN_GNUC); + } + /** Sampling strategy for whisper_full() function. */ public int strategy; - /** Number of threads. */ + /** Number of threads. (default = 4) */ public int n_threads; - /** Maximum tokens to use from past text as a prompt for the decoder. */ + /** Maximum tokens to use from past text as a prompt for the decoder. (default = 16384) */ public int n_max_text_ctx; - /** Start offset in milliseconds. */ + /** Start offset in milliseconds. (default = 0) */ public int offset_ms; - /** Audio duration to process in milliseconds. */ + /** Audio duration to process in milliseconds. (default = 0) */ public int duration_ms; - /** Translate flag. */ - public boolean translate; + /** Translate flag. (default = false) */ + public CBool translate; - /** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. */ - public boolean no_context; + /** The compliment of translateMode() */ + public void transcribeMode() { + translate = CBool.FALSE; + } - /** Flag to force single segment output (useful for streaming). */ - public boolean single_segment; + /** The compliment of transcribeMode() */ + public void translateMode() { + translate = CBool.TRUE; + } - /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). */ - public boolean print_special; + /** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */ + public CBool no_context; - /** Flag to print progress information. */ - public boolean print_progress; + /** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */ + public void enableContext(boolean enable) { + no_context = enable ? CBool.FALSE : CBool.TRUE; + } - /** Flag to print results from within whisper.cpp (avoid it, use callback instead). */ - public boolean print_realtime; + /** Flag to force single segment output (useful for streaming). (default = false) */ + public CBool single_segment; - /** Flag to print timestamps for each text segment when printing realtime. */ - public boolean print_timestamps; + /** Flag to force single segment output (useful for streaming). (default = false) */ + public void singleSegment(boolean single) { + single_segment = single ? CBool.TRUE : CBool.FALSE; + } - /** [EXPERIMENTAL] Flag to enable token-level timestamps. */ - public boolean token_timestamps; + /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */ + public CBool print_special; - /** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). */ + /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */ + public void printSpecial(boolean enable) { + print_special = enable ? CBool.TRUE : CBool.FALSE; + } + + /** Flag to print progress information. (default = true) */ + public CBool print_progress; + + /** Flag to print progress information. (default = true) */ + public void printProgress(boolean enable) { + print_progress = enable ? CBool.TRUE : CBool.FALSE; + } + + /** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */ + public CBool print_realtime; + + /** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */ + public void printRealtime(boolean enable) { + print_realtime = enable ? CBool.TRUE : CBool.FALSE; + } + + /** Flag to print timestamps for each text segment when printing realtime. (default = true) */ + public CBool print_timestamps; + + /** Flag to print timestamps for each text segment when printing realtime. (default = true) */ + public void printTimestamps(boolean enable) { + print_timestamps = enable ? CBool.TRUE : CBool.FALSE; + } + + /** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */ + public CBool token_timestamps; + + /** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */ + public void tokenTimestamps(boolean enable) { + token_timestamps = enable ? CBool.TRUE : CBool.FALSE; + } + + /** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). (default = 0.01) */ public float thold_pt; /** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */ public float thold_ptsum; - /** Maximum segment length in characters. */ + /** Maximum segment length in characters. (default = 0) */ public int max_len; - /** Flag to split on word rather than on token (when used with max_len). */ - public boolean split_on_word; + /** Flag to split on word rather than on token (when used with max_len). (default = false) */ + public CBool split_on_word; - /** Maximum tokens per segment (0 = no limit). */ + /** Flag to split on word rather than on token (when used with max_len). (default = false) */ + public void splitOnWord(boolean enable) { + split_on_word = enable ? CBool.TRUE : CBool.FALSE; + } + + /** Maximum tokens per segment (0, default = no limit) */ public int max_tokens; - /** Flag to speed up the audio by 2x using Phase Vocoder. */ - public boolean speed_up; + /** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */ + public CBool speed_up; + + /** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */ + public void speedUp(boolean enable) { + speed_up = enable ? CBool.TRUE : CBool.FALSE; + } /** Overwrite the audio context size (0 = use default). */ public int audio_ctx; @@ -79,9 +141,15 @@ public class WhisperFullParams extends Structure { * These are prepended to any existing text context from a previous call. */ public String initial_prompt; - /** Prompt tokens. */ + /** Prompt tokens. (int*) */ public Pointer prompt_tokens; + public void setPromptTokens(int[] tokens) { + Memory mem = new Memory(tokens.length * 4L); + mem.write(0, tokens, 0, tokens.length); + prompt_tokens = mem; + } + /** Number of prompt tokens. */ public int prompt_n_tokens; @@ -90,15 +158,29 @@ public class WhisperFullParams extends Structure { public String language; /** Flag to indicate whether to detect language automatically. */ - public boolean detect_language; + public CBool detect_language; - /** Common decoding parameters. */ + /** Flag to indicate whether to detect language automatically. */ + public void detectLanguage(boolean enable) { + detect_language = enable ? CBool.TRUE : CBool.FALSE; + } + + // Common decoding parameters. /** Flag to suppress blank tokens. */ - public boolean suppress_blank; + public CBool suppress_blank; + + public void suppressBlanks(boolean enable) { + suppress_blank = enable ? CBool.TRUE : CBool.FALSE; + } /** Flag to suppress non-speech tokens. */ - public boolean suppress_non_speech_tokens; + public CBool suppress_non_speech_tokens; + + /** Flag to suppress non-speech tokens. */ + public void suppressNonSpeechTokens(boolean enable) { + suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE; + } /** Initial decoding temperature. */ public float temperature; @@ -109,7 +191,7 @@ public class WhisperFullParams extends Structure { /** Length penalty. */ public float length_penalty; - /** Fallback parameters. */ + // Fallback parameters. /** Temperature increment. */ public float temperature_inc; @@ -123,31 +205,41 @@ public class WhisperFullParams extends Structure { /** No speech threshold. */ public float no_speech_thold; - class GreedyParams extends Structure { - /** https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 */ - public int best_of; - } - /** Greedy decoding parameters. */ public GreedyParams greedy; - class BeamSearchParams extends Structure { - /** ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 */ - int beam_size; - - /** ref: https://arxiv.org/pdf/2204.05424.pdf */ - float patience; - } - /** * Beam search decoding parameters. */ public BeamSearchParams beam_search; + public void setBestOf(int bestOf) { + if (greedy == null) { + greedy = new GreedyParams(); + } + greedy.best_of = bestOf; + } + + public void setBeamSize(int beamSize) { + if (beam_search == null) { + beam_search = new BeamSearchParams(); + } + beam_search.beam_size = beamSize; + } + + public void setBeamSizeAndPatience(int beamSize, float patience) { + if (beam_search == null) { + beam_search = new BeamSearchParams(); + } + beam_search.beam_size = beamSize; + beam_search.patience = patience; + } + /** * Callback for every newly generated text segment. + * WhisperNewSegmentCallback */ - public WhisperNewSegmentCallback new_segment_callback; + public Pointer new_segment_callback; /** * User data for the new_segment_callback. @@ -156,8 +248,9 @@ public class WhisperFullParams extends Structure { /** * Callback on each progress update. + * WhisperProgressCallback */ - public WhisperProgressCallback progress_callback; + public Pointer progress_callback; /** * User data for the progress_callback. @@ -166,8 +259,9 @@ public class WhisperFullParams extends Structure { /** * Callback each time before the encoder starts. + * WhisperEncoderBeginCallback */ - public WhisperEncoderBeginCallback encoder_begin_callback; + public Pointer encoder_begin_callback; /** * User data for the encoder_begin_callback. @@ -176,12 +270,44 @@ public class WhisperFullParams extends Structure { /** * Callback by each decoder to filter obtained logits. + * WhisperLogitsFilterCallback */ - public WhisperLogitsFilterCallback logits_filter_callback; + public Pointer logits_filter_callback; /** * User data for the logits_filter_callback. */ public Pointer logits_filter_callback_user_data; -} + + public void setNewSegmentCallback(WhisperNewSegmentCallback callback) { + new_segment_callback = CallbackReference.getFunctionPointer(callback); + } + + public void setProgressCallback(WhisperProgressCallback callback) { + progress_callback = CallbackReference.getFunctionPointer(callback); + } + + public void setEncoderBeginCallbackeginCallbackCallback(WhisperEncoderBeginCallback callback) { + encoder_begin_callback = CallbackReference.getFunctionPointer(callback); + } + + public void setLogitsFilterCallback(WhisperLogitsFilterCallback callback) { + logits_filter_callback = CallbackReference.getFunctionPointer(callback); + } + + @Override + protected List getFieldOrder() { + return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate", + "no_context", "single_segment", + "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps", + "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx", + "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", + "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty", + "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search", + "new_segment_callback", "new_segment_callback_user_data", + "progress_callback", "progress_callback_user_data", + "encoder_begin_callback", "encoder_begin_callback_user_data", + "logits_filter_callback", "logits_filter_callback_user_data"); + } +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperJavaParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperJavaParams.java deleted file mode 100644 index 728485c..0000000 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperJavaParams.java +++ /dev/null @@ -1,7 +0,0 @@ -package io.github.ggerganov.whispercpp.params; - -import com.sun.jna.Structure; - -public class WhisperJavaParams extends Structure { - -} diff --git a/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java b/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java index 98390aa..66e18f9 100644 --- a/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java +++ b/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java @@ -2,7 +2,8 @@ package io.github.ggerganov.whispercpp; import static org.junit.jupiter.api.Assertions.*; -import io.github.ggerganov.whispercpp.params.WhisperJavaParams; +import io.github.ggerganov.whispercpp.params.CBool; +import io.github.ggerganov.whispercpp.params.WhisperFullParams; import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -19,11 +20,11 @@ class WhisperCppTest { static void init() throws FileNotFoundException { // By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin" // or you can provide the absolute path to the model file. - String modelName = "base.en"; + String modelName = "../../models/ggml-tiny.en.bin"; try { whisper.initContext(modelName); - whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY); -// whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH); +// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY); +// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH); modelInitialised = true; } catch (FileNotFoundException ex) { System.out.println("Model " + modelName + " not found"); @@ -31,11 +32,30 @@ class WhisperCppTest { } @Test - void testGetDefaultJavaParams() { + void testGetDefaultFullParams_BeamSearch() { // When - whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH); + WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH); - // Then if it doesn't throw we've connected to whisper.cpp + // Then + assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal(), params.strategy); + assertNotEquals(0, params.n_threads); + assertEquals(16384, params.n_max_text_ctx); + assertFalse(params.translate); + assertEquals(0.01f, params.thold_pt); + assertEquals(2, params.beam_search.beam_size); + assertEquals(-1.0f, params.beam_search.patience); + } + + @Test + void testGetDefaultFullParams_Greedy() { + // When + WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY); + + // Then + assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy); + assertNotEquals(0, params.n_threads); + assertEquals(16384, params.n_max_text_ctx); + assertEquals(2, params.greedy.best_of); } @Test @@ -52,6 +72,13 @@ class WhisperCppTest { byte[] b = new byte[audioInputStream.available()]; float[] floats = new float[b.length / 2]; +// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY); + WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH); + params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress)); + params.print_progress = CBool.FALSE; +// params.initial_prompt = "and so my fellow Americans um, like"; + + try { audioInputStream.read(b); @@ -61,13 +88,13 @@ class WhisperCppTest { } // When - String result = whisper.fullTranscribe(/*params,*/ floats); + String result = whisper.fullTranscribe(params, floats); // Then - System.out.println(result); - assertEquals("And so my fellow Americans, ask not what your country can do for you, " + + System.err.println(result); + assertEquals("And so my fellow Americans ask not what your country can do for you " + "ask what you can do for your country.", - result); + result.replace(",", "")); } finally { audioInputStream.close(); } diff --git a/whisper.cpp b/whisper.cpp index 6faa3f2..0cdd4a1 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2852,6 +2852,12 @@ void whisper_free(struct whisper_context * ctx) { } } +void whisper_free_params(struct whisper_full_params * params) { + if (params) { + delete params; + } +} + int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) { fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); @@ -3285,6 +3291,14 @@ const char * whisper_print_system_info(void) { //////////////////////////////////////////////////////////////////////////// +struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) { + struct whisper_full_params params = whisper_full_default_params(strategy); + + struct whisper_full_params* result = new whisper_full_params(); + *result = params; + return result; +} + struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { struct whisper_full_params result = { /*.strategy =*/ strategy, diff --git a/whisper.h b/whisper.h index 2d5b3eb..e983c7d 100644 --- a/whisper.h +++ b/whisper.h @@ -113,6 +113,7 @@ extern "C" { // Frees all allocated memory WHISPER_API void whisper_free (struct whisper_context * ctx); WHISPER_API void whisper_free_state(struct whisper_state * state); + WHISPER_API void whisper_free_params(struct whisper_full_params * params); // Convert RAW PCM audio to log mel spectrogram. // The resulting spectrogram is stored inside the default state of the provided whisper context. @@ -409,6 +410,8 @@ extern "C" { void * logits_filter_callback_user_data; }; + // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params() + WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy); WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text