diff --git a/.gitignore b/.gitignore index 08aa2a8..503a473 100644 --- a/.gitignore +++ b/.gitignore @@ -1,43 +1,46 @@ -*.o -*.a -.cache/ -.coreml/ -.test/ -.vs/ -.vscode/ -.DS_Store - -build/ -build-em/ -build-debug/ -build-release/ -build-static/ -build-cublas/ -build-no-accel/ -build-sanitize-addr/ -build-sanitize-thread/ - -/main -/stream -/command -/talk -/talk-llama -/bench -/quantize - -arm_neon.h -sync.sh -libwhisper.a -libwhisper.so -compile_commands.json - -examples/arm_neon.h -examples/whisper.objc/whisper.objc.xcodeproj/xcshareddata -examples/whisper.objc/whisper.objc.xcodeproj/xcuserdata/ -examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata - -extra/bench-gg.txt - -models/*.mlmodel -models/*.mlmodelc -models/*.mlpackage +*.o +*.a +.cache/ +.coreml/ +.test/ +.vs/ +.vscode/ +.DS_Store + +build/ +build-em/ +build-debug/ +build-release/ +build-static/ +build-cublas/ +build-no-accel/ +build-sanitize-addr/ +build-sanitize-thread/ + +/main +/stream +/command +/talk +/talk-llama +/bench +/quantize + +arm_neon.h +sync.sh +libwhisper.a +libwhisper.so +compile_commands.json + +examples/arm_neon.h +examples/whisper.objc/whisper.objc.xcodeproj/xcshareddata +examples/whisper.objc/whisper.objc.xcodeproj/xcuserdata/ +examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata + +extra/bench-gg.txt + +models/*.mlmodel +models/*.mlmodelc +models/*.mlpackage +bindings/java/.gradle/ +bindings/java/.idea/ +.idea/ diff --git a/README.md b/README.md index 75d6f81..7d76f90 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ Supported platforms: - [x] Mac OS (Intel and Arm) - [x] [iOS](examples/whisper.objc) - [x] [Android](examples/whisper.android) +- [x] [Java](bindings/java/README.md) - [x] Linux / [FreeBSD](https://github.com/ggerganov/whisper.cpp/issues/56#issuecomment-1350920264) - [x] [WebAssembly](examples/whisper.wasm) - [x] Windows 
([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168)] diff --git a/bindings/java/.idea/uiDesigner.xml b/bindings/java/.idea/uiDesigner.xml new file mode 100644 index 0000000..6d50cd4 --- /dev/null +++ b/bindings/java/.idea/uiDesigner.xml @@ -0,0 +1,124 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/bindings/java/CMakeLists.txt b/bindings/java/CMakeLists.txt new file mode 100644 index 0000000..7e47bb3 --- /dev/null +++ b/bindings/java/CMakeLists.txt @@ -0,0 +1,50 @@ +cmake_minimum_required(VERSION 3.10) + +project(whisper_java VERSION 1.4.2) + +# Set the target name and source file/s +set(TARGET_NAME whisper_java) +set(SOURCES src/main/cpp/whisper_java.cpp) + +# include +include_directories(../../) + +# Set the output directory for the DLL/shared library based on the platform as required by JNA +if(WIN32) + set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/win32-x86-64) +elseif(UNIX AND NOT APPLE) + set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/linux-x86-64) +elseif(APPLE) + set(OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated/resources/main/macos-x86-64) +endif() + +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OUTPUT_DIR}) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${OUTPUT_DIR}) + +# Create the whisper_java library +add_library(${TARGET_NAME} SHARED ${SOURCES}) + +# Link against ../../build/Release/whisper.dll (or so/dynlib) +target_link_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/../../../build/${CMAKE_BUILD_TYPE}) +target_link_libraries(${TARGET_NAME} PRIVATE whisper) + +# Set the appropriate compiler flags for Windows, Linux, and macOS +if(WIN32) + 
 target_compile_options(${TARGET_NAME} PRIVATE /W4 /D_CRT_SECURE_NO_WARNINGS) +elseif(UNIX AND NOT APPLE) + target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra) +elseif(APPLE) + target_compile_options(${TARGET_NAME} PRIVATE -Wall -Wextra) +endif() + +target_compile_definitions(${TARGET_NAME} PRIVATE WHISPER_SHARED) +# add_definitions(-DWHISPER_SHARED) + +# Force CMake to save the libs to build/generated/resources/main/${os}-${arch} as required by JNA +foreach(OUTPUTCONFIG ${CMAKE_CONFIGURATION_TYPES}) + string(TOUPPER ${OUTPUTCONFIG} OUTPUTCONFIG) + set_target_properties(${TARGET_NAME} PROPERTIES + RUNTIME_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR} + LIBRARY_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR} + ARCHIVE_OUTPUT_DIRECTORY_${OUTPUTCONFIG} ${OUTPUT_DIR}) +endforeach(OUTPUTCONFIG CMAKE_CONFIGURATION_TYPES) diff --git a/bindings/java/README.md b/bindings/java/README.md new file mode 100644 index 0000000..429287c --- /dev/null +++ b/bindings/java/README.md @@ -0,0 +1,63 @@ +# Java JNA bindings for Whisper + +This package provides Java JNA bindings for whisper.cpp. They have been tested on: + + * Darwin (OS X) 12.6 on x86_64 + * Ubuntu on x86_64 + * Windows on x86_64 + +The "low level" bindings are in `WhisperCppJnaLibrary` and `WhisperJavaJnaLibrary`; the latter caches `whisper_full_params` and `whisper_context` in `whisper_java.cpp`. + +There are a lot of classes in the `callbacks`, `ggml`, `model` and `params` directories, but most of them have not been tested. + +The simplest usage is as follows: + +```java +import io.github.ggerganov.whispercpp.WhisperCpp; + +public class Example { + + public static void main(String[] args) { + String modelpath; + WhisperCpp whisper = new WhisperCpp(); + // By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin" + // or you can provide the absolute path to the model file.
+ whisper.initContext("base.en"); + + long context = whisper.initContext(modelpath); + try { + whisper.fullTranscribe(context, samples); + + int segmentCount = whisper.getTextSegmentCount(context); + for (int i = 0; i < segmentCount; i++) { + String text = whisper.getTextSegment(context, i); + System.out.println(segment.getText()); + } + } finally { + whisper.freeContext(context); + } + } +} +``` + +## Building & Testing + +In order to build, you need to have the JDK 8 or higher installed. Run the tests with: + +```bash +git clone https://github.com/ggerganov/whisper.cpp.git +cd whisper.cpp/bindings/java + +mkdir build +pushd build +cmake .. +cmake --build . +popd + +./gradlew build +``` + +## License + +The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details. + diff --git a/bindings/java/build.gradle b/bindings/java/build.gradle new file mode 100644 index 0000000..4a9b02f --- /dev/null +++ b/bindings/java/build.gradle @@ -0,0 +1,104 @@ +plugins { + id 'java' + id 'java-library' + id 'maven-publish' +} + +archivesBaseName = 'whispercpp' +group = 'io.github.ggerganov' +version = '1.4.0' + +sourceCompatibility = 1.8 +targetCompatibility = 1.8 + +sourceSets { + main { + resources { + srcDirs = ['src/main/resources', 'build/generated/resources/main'] + } + } + test { + runtimeClasspath += files('build/generated/resources/main') + } +} + +tasks.register('copyLibwhisperSo', Copy) { + from '../../build' + include 'libwhisper.so' + into 'build/generated/resources/main/linux-x86-64' +} + +tasks.register('copyWhisperDll', Copy) { + from '../../build/Release' + include 'whisper.dll' + into 'build/generated/resources/main/windows-x86-64' +} + +tasks.build.dependsOn copyLibwhisperSo, copyWhisperDll + +test { + systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath +} + +java { + withSourcesJar() + withJavadocJar() +} + +jar { + 
exclude '**/whisper_java.exp', '**/whisper_java.lib' +} + +javadoc { + options.addStringOption('Xdoclint:none', '-quiet') +} + +tasks.withType(Test) { + useJUnitPlatform() +} + +dependencies { + implementation "net.java.dev.jna:jna:5.13.0" + testImplementation "org.junit.jupiter:junit-jupiter:5.9.2" + testImplementation "org.assertj:assertj-core:3.24.2" +} + +repositories { + mavenCentral() +} + +publishing { + publications { + mavenJava(MavenPublication) { + artifactId = 'whispercpp' + from components.java + pom { + name = 'whispercpp' + description = "Java JNA bindings for OpenAI's Whisper model, implemented in C/C++" + url = 'https://github.com/ggerganov/whisper.cpp' + licenses { + license { + name = 'MIT licence' + url = 'https://raw.githubusercontent.com/ggerganov/whisper.cpp/master/LICENSE' + } + } + developers { + developer { + id = 'ggerganov' + name = 'Georgi Gerganov' + email = 'ggerganov@gmail.com' + } + developer { + id = 'nalbion' + name = 'Nicholas Albion' + email = 'nalbion@yahoo.com' + } + } + scm { + connection = 'scm:git:git://github.com/ggerganov/whisper.cpp.git' + url = 'https://github.com/ggerganov/whisper.cpp' + } + } + } + } +} diff --git a/bindings/java/gradle.properties b/bindings/java/gradle.properties new file mode 100644 index 0000000..3ea68c2 --- /dev/null +++ b/bindings/java/gradle.properties @@ -0,0 +1,6 @@ +org.gradle.jvmargs=-Xms256m -Xmx1024m +system.include.dir=/usr/include +#system.local.include.dir=../../include +system.local.include.dir=./build/generated/sources/headers/java/main +jni.include.dir=/usr/lib/jvm/java-8-openjdk-amd64/include/ +jni.lib.dir=/usr/lib/jvm/java-8-openjdk-amd64/lib/ diff --git a/bindings/java/gradle/wrapper/gradle-wrapper.jar b/bindings/java/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000..ccebba7 Binary files /dev/null and b/bindings/java/gradle/wrapper/gradle-wrapper.jar differ diff --git a/bindings/java/gradle/wrapper/gradle-wrapper.properties 
b/bindings/java/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000..0c85a1f --- /dev/null +++ b/bindings/java/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.1-bin.zip +networkTimeout=10000 +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/bindings/java/gradlew b/bindings/java/gradlew new file mode 100644 index 0000000..79a61d4 --- /dev/null +++ b/bindings/java/gradlew @@ -0,0 +1,244 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. 
If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. 
+while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. 
That's why the result is checked to see if it worked. + # shellcheck disable=SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. 
+ shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + +# Collect all arguments for the java command; +# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of +# shell script including quotes and variable substitutions, so put them in +# double quotes to make sure that they get re-expanded; and +# * put everything else in single quotes, so that it's not re-expanded. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/bindings/java/gradlew.bat b/bindings/java/gradlew.bat new file mode 100644 index 0000000..6689b85 --- /dev/null +++ b/bindings/java/gradlew.bat @@ -0,0 +1,92 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. 
+@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. 
+ +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/bindings/java/settings.gradle b/bindings/java/settings.gradle new file mode 100644 index 0000000..dbc6f38 --- /dev/null +++ b/bindings/java/settings.gradle @@ -0,0 +1 @@ +rootProject.name = "whispercpp" diff --git a/bindings/java/src/main/cpp/whisper_java.cpp b/bindings/java/src/main/cpp/whisper_java.cpp new file mode 100644 index 0000000..9e06aa0 --- /dev/null +++ b/bindings/java/src/main/cpp/whisper_java.cpp @@ -0,0 +1,33 @@ +#include +#include "whisper_java.h" + +struct whisper_full_params default_params; +struct whisper_context * whisper_ctx = nullptr; + +struct void whisper_java_default_params(enum whisper_sampling_strategy strategy) { + default_params = whisper_full_default_params(strategy); + +// struct whisper_java_params result = {}; +// return result; + return; +} + +void whisper_java_init_from_file(const char * path_model) { + whisper_ctx = whisper_init_from_file(path_model); + if (0 == default_params.n_threads) { + whisper_java_default_params(WHISPER_SAMPLING_GREEDY); + } +} + +/** Delegates to whisper_full, but without having to pass `whisper_full_params` */ +int whisper_java_full( + struct whisper_context * ctx, +// struct whisper_java_params params, + const float * samples, + int n_samples) { + return 
whisper_full(ctx, default_params, samples, n_samples); +} + +void whisper_java_free() { +// free(default_params); +} diff --git a/bindings/java/src/main/cpp/whisper_java.h b/bindings/java/src/main/cpp/whisper_java.h new file mode 100644 index 0000000..d64866b --- /dev/null +++ b/bindings/java/src/main/cpp/whisper_java.h @@ -0,0 +1,24 @@ +#define WHISPER_BUILD +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct whisper_java_params { +}; + +WHISPER_API void whisper_java_default_params(enum whisper_sampling_strategy strategy); + +WHISPER_API void whisper_java_init_from_file(const char * path_model); + +WHISPER_API int whisper_java_full( + struct whisper_context * ctx, +// struct whisper_java_params params, + const float * samples, + int n_samples); + + +#ifdef __cplusplus +} +#endif diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperContext.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperContext.java new file mode 100644 index 0000000..22d4ce8 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperContext.java @@ -0,0 +1,39 @@ +package io.github.ggerganov.whispercpp; + +import com.sun.jna.Structure; +import com.sun.jna.ptr.PointerByReference; +import io.github.ggerganov.whispercpp.ggml.GgmlType; +import io.github.ggerganov.whispercpp.WhisperModel; + +import java.util.List; + +public class WhisperContext extends Structure { + int t_load_us = 0; + int t_start_us = 0; + + /** weight type (FP32 / FP16 / QX) */ + GgmlType wtype = GgmlType.GGML_TYPE_F16; + /** intermediate type (FP32 or FP16) */ + GgmlType itype = GgmlType.GGML_TYPE_F16; + +// WhisperModel model; + public PointerByReference model; +// whisper_vocab vocab; +// whisper_state * state = nullptr; + public PointerByReference vocab; + public PointerByReference state; + + /** populated by whisper_init_from_file() */ + String path_model; + +// public static class ByReference extends WhisperContext implements 
Structure.ByReference { +// } +// +// public static class ByValue extends WhisperContext implements Structure.ByValue { +// } +// +// @Override +// protected List getFieldOrder() { +// return List.of("t_load_us", "t_start_us", "wtype", "itype", "model", "vocab", "state", "path_model"); +// } +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java new file mode 100644 index 0000000..f014407 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java @@ -0,0 +1,124 @@ +package io.github.ggerganov.whispercpp; + +import com.sun.jna.Pointer; +import io.github.ggerganov.whispercpp.params.WhisperJavaParams; +import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; + +/** + * Before calling most methods, you must call `initContext(modelPath)` to initialise the `ctx` Pointer. 
+ */ +public class WhisperCpp implements AutoCloseable { + private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance; + private WhisperJavaJnaLibrary javaLib = WhisperJavaJnaLibrary.instance; + private Pointer ctx = null; + + public File modelDir() { + String modelDirPath = System.getenv("XDG_CACHE_HOME"); + if (modelDirPath == null) { + modelDirPath = System.getProperty("user.home") + "/.cache"; + } + + return new File(modelDirPath, "whisper"); + } + + /** + * @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en") + * @return a Pointer to the WhisperContext + */ + void initContext(String modelPath) throws FileNotFoundException { + if (ctx != null) { + lib.whisper_free(ctx); + } + + if (!modelPath.contains("/") && !modelPath.contains("\\")) { + if (!modelPath.endsWith(".bin")) { + modelPath = "ggml-" + modelPath.replace("-", ".") + ".bin"; + } + + modelPath = new File(modelDir(), modelPath).getAbsolutePath(); + } + + javaLib.whisper_java_init_from_file(modelPath); + ctx = lib.whisper_init_from_file(modelPath); + + if (ctx == null) { + throw new FileNotFoundException(modelPath); + } + } + + /** + * Initialises `whisper_full_params` internally in whisper_java.cpp so JNA doesn't have to map everything. + * `whisper_java_init_from_file()` calls `whisper_java_default_params(WHISPER_SAMPLING_GREEDY)` for convenience. 
+ */ + public void getDefaultJavaParams(WhisperSamplingStrategy strategy) { + javaLib.whisper_java_default_params(strategy.ordinal()); +// return lib.whisper_full_default_params(strategy.value) + } + +// whisper_full_default_params was too hard to integrate with, so for now we use javaLib.whisper_java_default_params +// fun getDefaultParams(strategy: WhisperSamplingStrategy): WhisperFullParams { +// return lib.whisper_full_default_params(strategy.value) +// } + + @Override + public void close() { + freeContext(); + System.out.println("Whisper closed"); + } + + private void freeContext() { + if (ctx != null) { + lib.whisper_free(ctx); + } + } + + /** + * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text. + * Not thread safe for same context + * Uses the specified decoding strategy to obtain the text. + */ + public String fullTranscribe(/*WhisperJavaParams whisperParams,*/ float[] audioData) throws IOException { + if (ctx == null) { + throw new IllegalStateException("Model not initialised"); + } + + if (javaLib.whisper_java_full(ctx, /*whisperParams,*/ audioData, audioData.length) != 0) { + throw new IOException("Failed to process audio"); + } + + int nSegments = lib.whisper_full_n_segments(ctx); + + StringBuilder str = new StringBuilder(); + + for (int i = 0; i < nSegments; i++) { + String text = lib.whisper_full_get_segment_text(ctx, i); + System.out.println("Segment:" + text); + str.append(text); + } + + return str.toString().trim(); + } + +// public int getTextSegmentCount(Pointer ctx) { +// return lib.whisper_full_n_segments(ctx); +// } +// public String getTextSegment(Pointer ctx, int index) { +// return lib.whisper_full_get_segment_text(ctx, index); +// } + + public String getSystemInfo() { + return lib.whisper_print_system_info(); + } + + public int benchMemcpy(int nthread) { + return lib.whisper_bench_memcpy(nthread); + } + + public int benchGgmlMulMat(int nthread) { + return lib.whisper_bench_ggml_mul_mat(nthread); + } +} diff 
--git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java new file mode 100644 index 0000000..6602565 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java @@ -0,0 +1,365 @@ +package io.github.ggerganov.whispercpp; + +import com.sun.jna.Library; +import com.sun.jna.Native; +import com.sun.jna.Pointer; +import io.github.ggerganov.whispercpp.model.WhisperModelLoader; +import io.github.ggerganov.whispercpp.model.WhisperTokenData; +import io.github.ggerganov.whispercpp.params.WhisperFullParams; + +public interface WhisperCppJnaLibrary extends Library { + WhisperCppJnaLibrary instance = Native.load("whisper", WhisperCppJnaLibrary.class); + + String whisper_print_system_info(); + + /** + * Allocate (almost) all memory needed for the model by loading from a file. + * + * @param path_model Path to the model file + * @return Whisper context on success, null on failure + */ + Pointer whisper_init_from_file(String path_model); + + /** + * Allocate (almost) all memory needed for the model by loading from a buffer. + * + * @param buffer Model buffer + * @param buffer_size Size of the model buffer + * @return Whisper context on success, null on failure + */ + Pointer whisper_init_from_buffer(Pointer buffer, int buffer_size); + + /** + * Allocate (almost) all memory needed for the model using a model loader. + * + * @param loader Model loader + * @return Whisper context on success, null on failure + */ + Pointer whisper_init(WhisperModelLoader loader); + + /** + * Allocate (almost) all memory needed for the model by loading from a file without allocating the state. 
+ * + * @param path_model Path to the model file + * @return Whisper context on success, null on failure + */ + Pointer whisper_init_from_file_no_state(String path_model); + + /** + * Allocate (almost) all memory needed for the model by loading from a buffer without allocating the state. + * + * @param buffer Model buffer + * @param buffer_size Size of the model buffer + * @return Whisper context on success, null on failure + */ + Pointer whisper_init_from_buffer_no_state(Pointer buffer, int buffer_size); + +// Pointer whisper_init_from_buffer_no_state(Pointer buffer, long buffer_size); + + /** + * Allocate (almost) all memory needed for the model using a model loader without allocating the state. + * + * @param loader Model loader + * @return Whisper context on success, null on failure + */ + Pointer whisper_init_no_state(WhisperModelLoader loader); + + /** + * Allocate memory for the Whisper state. + * + * @param ctx Whisper context + * @return Whisper state on success, null on failure + */ + Pointer whisper_init_state(Pointer ctx); + + /** + * Free all allocated memory associated with the Whisper context. + * + * @param ctx Whisper context + */ + void whisper_free(Pointer ctx); + + /** + * Free all allocated memory associated with the Whisper state. + * + * @param state Whisper state + */ + void whisper_free_state(Pointer state); + + + /** + * Convert RAW PCM audio to log mel spectrogram. + * The resulting spectrogram is stored inside the default state of the provided whisper context. 
+ * + * @param ctx Pointer to a WhisperContext + * @return 0 on success + */ + int whisper_pcm_to_mel(Pointer ctx, final float[] samples, int n_samples, int n_threads); + + /** + * @param ctx Pointer to a WhisperContext + * @param state Pointer to WhisperState + * @param n_samples presumably the number of entries of samples to process (TODO confirm against whisper.h) + * @param n_threads presumably the number of threads to use (TODO confirm against whisper.h) + * @return 0 on success + */ + int whisper_pcm_to_mel_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads); + + /** + * This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context. + * Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. + * n_mel must be 80 + * @return 0 on success + */ + int whisper_set_mel(Pointer ctx, final float[] data, int n_len, int n_mel); + int whisper_set_mel_with_state(Pointer ctx, Pointer state, final float[] data, int n_len, int n_mel); + + /** + * Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context. + * Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. + * Offset can be used to specify the offset of the first frame in the spectrogram. + * @return 0 on success + */ + int whisper_encode(Pointer ctx, int offset, int n_threads); + + int whisper_encode_with_state(Pointer ctx, Pointer state, int offset, int n_threads); + + /** + * Run the Whisper decoder to obtain the logits and probabilities for the next token. + * Make sure to call whisper_encode() first. + * tokens + n_tokens is the provided context for the decoder. + * n_past is the number of tokens to use from previous decoder calls.
+ * Returns 0 on success + * TODO: add support for multiple decoders + */ + int whisper_decode(Pointer ctx, Pointer tokens, int n_tokens, int n_past, int n_threads); + + /** + * @param ctx Pointer to a WhisperContext + * @param state Pointer to WhisperState + * @param tokens Pointer to int tokens + * @param n_tokens number of entries in tokens (assumed to mirror whisper_decode(); confirm) + * @param n_past number of tokens to use from previous decoder calls (assumed to mirror whisper_decode(); confirm) + * @param n_threads presumably the number of threads to use (TODO confirm against whisper.h) + * @return presumably 0 on success, like whisper_decode() (TODO confirm) + */ + int whisper_decode_with_state(Pointer ctx, Pointer state, Pointer tokens, int n_tokens, int n_past, int n_threads); + + /** + * Convert the provided text into tokens. + * The tokens pointer must be large enough to hold the resulting tokens. + * Returns the number of tokens on success, no more than n_max_tokens + * Returns -1 on failure + * TODO: not sure if correct + */ + int whisper_tokenize(Pointer ctx, String text, Pointer tokens, int n_max_tokens); + + /** Largest language id (i.e. number of available languages - 1) */ + int whisper_lang_max_id(); + + /** + * @return the id of the specified language, returns -1 if not found. + * Examples: + * "de" -> 2 + * "german" -> 2 + */ + int whisper_lang_id(String lang); + + /** @return the short string of the specified language id (e.g. 2 -> "de"), returns null (a native nullptr) if not found */ + String whisper_lang_str(int id); + + /** + * Use mel data at offset_ms to try and auto-detect the spoken language.
+ * Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first + * Returns the top language id or negative on failure + * If not null, fills the lang_probs array with the probabilities of all languages + * The array must be whisper_lang_max_id() + 1 in size + * + * ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69 + */ + int whisper_lang_auto_detect(Pointer ctx, int offset_ms, int n_threads, float[] lang_probs); + + int whisper_lang_auto_detect_with_state(Pointer ctx, Pointer state, int offset_ms, int n_threads, float[] lang_probs); + + int whisper_n_len (Pointer ctx); // mel length + int whisper_n_len_from_state(Pointer state); // mel length + int whisper_n_vocab (Pointer ctx); + int whisper_n_text_ctx (Pointer ctx); + int whisper_n_audio_ctx (Pointer ctx); + int whisper_is_multilingual (Pointer ctx); + + int whisper_model_n_vocab (Pointer ctx); + int whisper_model_n_audio_ctx (Pointer ctx); + int whisper_model_n_audio_state(Pointer ctx); + int whisper_model_n_audio_head (Pointer ctx); + int whisper_model_n_audio_layer(Pointer ctx); + int whisper_model_n_text_ctx (Pointer ctx); + int whisper_model_n_text_state (Pointer ctx); + int whisper_model_n_text_head (Pointer ctx); + int whisper_model_n_text_layer (Pointer ctx); + int whisper_model_n_mels (Pointer ctx); + int whisper_model_ftype (Pointer ctx); + int whisper_model_type (Pointer ctx); + + /** + * Token logits obtained from the last call to whisper_decode(). + * The logits for the last token are stored in the last row + * Rows: n_tokens + * Cols: n_vocab (NOTE(review): the declared float[] return type carries no length; confirm JNA can map it, otherwise a Pointer return may be needed) + */ + float[] whisper_get_logits (Pointer ctx); + float[] whisper_get_logits_from_state(Pointer state); + + // Token Id -> String.
Uses the vocabulary in the provided context + String whisper_token_to_str(Pointer ctx, int token); + String whisper_model_type_readable(Pointer ctx); + + // Special tokens + int whisper_token_eot (Pointer ctx); + int whisper_token_sot (Pointer ctx); + int whisper_token_prev(Pointer ctx); + int whisper_token_solm(Pointer ctx); + int whisper_token_not (Pointer ctx); + int whisper_token_beg (Pointer ctx); + int whisper_token_lang(Pointer ctx, int lang_id); + + // Task tokens + int whisper_token_translate(); + int whisper_token_transcribe(); + + // Performance information from the default state. + void whisper_print_timings(Pointer ctx); + void whisper_reset_timings(Pointer ctx); + + /** + * @param strategy numeric sampling strategy; WhisperSamplingStrategy declares no value member, so presumably its ordinal() (0 = WHISPER_SAMPLING_GREEDY, 1 = WHISPER_SAMPLING_BEAM_SEARCH) - confirm + */ + WhisperFullParams whisper_full_default_params(int strategy); + + /** + * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + * Not thread safe for the same context + * Uses the specified decoding strategy to obtain the text. + */ + int whisper_full(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples); + + int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples); + + // Split the input audio in chunks and process each chunk separately using whisper_full_with_state() + // Result is stored in the default state of the context + // Not thread safe if executed in parallel on the same context. + // It seems this approach can offer some speedup in some cases. + // However, the transcription accuracy can be worse at the beginning and end of each chunk. + int whisper_full_parallel(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples, int n_processors); + + /** + * Number of generated text segments. + * A segment can be a few words, a sentence, or even a paragraph.
+ * @param ctx Pointer to WhisperContext + */ + int whisper_full_n_segments (Pointer ctx); + + /** + * @param state Pointer to WhisperState + */ + int whisper_full_n_segments_from_state(Pointer state); + + /** + * Language id associated with the context's default state. + * @param ctx Pointer to WhisperContext + */ + int whisper_full_lang_id(Pointer ctx); + + /** Language id associated with the provided state */ + int whisper_full_lang_id_from_state(Pointer state); + + /** + * Convert RAW PCM audio to a log mel spectrogram, applying a Phase Vocoder to speed up the audio x2. + * The resulting spectrogram is stored inside the default state of the provided whisper context. + * @return 0 on success + */ + int whisper_pcm_to_mel_phase_vocoder(Pointer ctx, final float[] samples, int n_samples, int n_threads); + + int whisper_pcm_to_mel_phase_vocoder_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads); + + /** Get the start time of the specified segment (time unit is not stated here; TODO confirm against whisper.h). */ + long whisper_full_get_segment_t0(Pointer ctx, int i_segment); + + /** Get the start time of the specified segment from the state. */ + long whisper_full_get_segment_t0_from_state(Pointer state, int i_segment); + + /** Get the end time of the specified segment (time unit is not stated here; TODO confirm against whisper.h). */ + long whisper_full_get_segment_t1(Pointer ctx, int i_segment); + + /** Get the end time of the specified segment from the state. */ + long whisper_full_get_segment_t1_from_state(Pointer state, int i_segment); + + /** Get the text of the specified segment. */ + String whisper_full_get_segment_text(Pointer ctx, int i_segment); + + /** Get the text of the specified segment from the state. */ + String whisper_full_get_segment_text_from_state(Pointer state, int i_segment); + + /** Get the number of tokens in the specified segment. */ + int whisper_full_n_tokens(Pointer ctx, int i_segment); + + /** Get the number of tokens in the specified segment from the state.
*/ + int whisper_full_n_tokens_from_state(Pointer state, int i_segment); + + /** Get the token text of the specified token in the specified segment. */ + String whisper_full_get_token_text(Pointer ctx, int i_segment, int i_token); + + + /** Get the token text of the specified token in the specified segment from the state. Unlike the other *_from_state accessors declared in this interface, this one takes the context as well as the state. */ + String whisper_full_get_token_text_from_state(Pointer ctx, Pointer state, int i_segment, int i_token); + + /** Get the token ID of the specified token in the specified segment. */ + int whisper_full_get_token_id(Pointer ctx, int i_segment, int i_token); + + /** Get the token ID of the specified token in the specified segment from the state. */ + int whisper_full_get_token_id_from_state(Pointer state, int i_segment, int i_token); + + /** Get token data for the specified token in the specified segment. */ + WhisperTokenData whisper_full_get_token_data(Pointer ctx, int i_segment, int i_token); + + /** Get token data for the specified token in the specified segment from the state. */ + WhisperTokenData whisper_full_get_token_data_from_state(Pointer state, int i_segment, int i_token); + + /** Get the probability of the specified token in the specified segment. */ + float whisper_full_get_token_p(Pointer ctx, int i_segment, int i_token); + + /** Get the probability of the specified token in the specified segment from the state. */ + float whisper_full_get_token_p_from_state(Pointer state, int i_segment, int i_token); + + /** + * Benchmark function for memcpy. + * + * @param nThreads Number of threads to use for the benchmark. + * @return The result of the benchmark. + */ + int whisper_bench_memcpy(int nThreads); + + /** + * Benchmark function for memcpy as a string. + * + * @param nThreads Number of threads to use for the benchmark. + * @return The result of the benchmark as a string. + */ + String whisper_bench_memcpy_str(int nThreads); + + /** + * Benchmark function for ggml_mul_mat.
+ * + * @param nThreads Number of threads to use for the benchmark. + * @return The result of the benchmark (an int whose meaning is not specified here; TODO confirm against whisper.h). + */ + int whisper_bench_ggml_mul_mat(int nThreads); + + /** + * Benchmark function for ggml_mul_mat as a string. + * + * @param nThreads Number of threads to use for the benchmark. + * @return The result of the benchmark as a string. + */ + String whisper_bench_ggml_mul_mat_str(int nThreads); +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperJavaJnaLibrary.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperJavaJnaLibrary.java new file mode 100644 index 0000000..74f8459 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperJavaJnaLibrary.java @@ -0,0 +1,23 @@ +package io.github.ggerganov.whispercpp; + +import com.sun.jna.Library; +import com.sun.jna.Native; +import com.sun.jna.Pointer; +import io.github.ggerganov.whispercpp.params.WhisperJavaParams; + +interface WhisperJavaJnaLibrary extends Library { + WhisperJavaJnaLibrary instance = Native.load("whisper_java", WhisperJavaJnaLibrary.class); + + void whisper_java_default_params(int strategy); + + void whisper_java_free(); + + void whisper_java_init_from_file(String modelPath); + + /** + * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text. + * Not thread safe for the same context + * Uses the specified decoding strategy to obtain the text.
+ */ + int whisper_java_full(Pointer ctx, /*WhisperJavaParams params, */float[] samples, int nSamples); +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperEncoderBeginCallback.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperEncoderBeginCallback.java new file mode 100644 index 0000000..b5e9797 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperEncoderBeginCallback.java @@ -0,0 +1,24 @@ +package io.github.ggerganov.whispercpp.callbacks; + +import com.sun.jna.Callback; +import com.sun.jna.Pointer; +import io.github.ggerganov.whispercpp.WhisperContext; +import io.github.ggerganov.whispercpp.model.WhisperState; + +/** + * Callback before the encoder starts. + * If not null, called before the encoder starts. + * If it returns false, the computation is aborted. + */ +public interface WhisperEncoderBeginCallback extends Callback { + + /** + * Callback method before the encoder starts. + * + * @param ctx The whisper context. + * @param state The whisper state. + * @param user_data User data (presumably the encoder_begin_callback_user_data pointer set in WhisperFullParams; confirm). + * @return True if the computation should proceed, false otherwise.
+ */ + boolean callback(WhisperContext ctx, WhisperState state, Pointer user_data); +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperLogitsFilterCallback.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperLogitsFilterCallback.java new file mode 100644 index 0000000..5377b4e --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperLogitsFilterCallback.java @@ -0,0 +1,28 @@ +package io.github.ggerganov.whispercpp.callbacks; + +import com.sun.jna.Pointer; +import io.github.ggerganov.whispercpp.WhisperContext; +import io.github.ggerganov.whispercpp.model.WhisperState; +import io.github.ggerganov.whispercpp.model.WhisperTokenData; + +import com.sun.jna.Callback; + +/** + * Callback to filter logits. + * Can be used to modify the logits before sampling. + * If not null, called after applying temperature to logits. Extends the JNA Callback interface (com.sun.jna.Callback), like WhisperEncoderBeginCallback, so it can be stored in the WhisperFullParams structure. + */ +public interface WhisperLogitsFilterCallback extends Callback { + + /** + * Callback method to filter logits. + * + * @param ctx The whisper context. + * @param state The whisper state. + * @param tokens The array of whisper_token_data. + * @param n_tokens The number of tokens. + * @param logits The array of logits. + * @param user_data User data.
+ */ + void callback(WhisperContext ctx, WhisperState state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data); +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperNewSegmentCallback.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperNewSegmentCallback.java new file mode 100644 index 0000000..95ca346 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperNewSegmentCallback.java @@ -0,0 +1,24 @@ +package io.github.ggerganov.whispercpp.callbacks; + +import com.sun.jna.Callback; +import com.sun.jna.Pointer; +import io.github.ggerganov.whispercpp.WhisperContext; +import io.github.ggerganov.whispercpp.model.WhisperState; + +/** + * Callback for the text segment. + * Called on every newly generated text segment. + * Use the whisper_full_...() functions to obtain the text segments. + */ +public interface WhisperNewSegmentCallback extends Callback { + + /** + * Callback method for the text segment. + * + * @param ctx The whisper context. + * @param state The whisper state. + * @param n_new The number of newly generated text segments. + * @param user_data User data. + */ + void callback(WhisperContext ctx, WhisperState state, int n_new, Pointer user_data); +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperProgressCallback.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperProgressCallback.java new file mode 100644 index 0000000..8866215 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/WhisperProgressCallback.java @@ -0,0 +1,23 @@ +package io.github.ggerganov.whispercpp.callbacks; + +import com.sun.jna.Pointer; +import io.github.ggerganov.whispercpp.WhisperContext; +import io.github.ggerganov.whispercpp.model.WhisperState; + +import com.sun.jna.Callback; + +/** + * Callback for progress updates. Extends the JNA Callback interface (com.sun.jna.Callback), like WhisperNewSegmentCallback.
+ */ +public interface WhisperProgressCallback extends Callback { + + /** + * Callback method for progress updates. + * + * @param ctx The whisper context. + * @param state The whisper state. + * @param progress The progress value (range and units are not stated here; TODO confirm against whisper.h). + * @param user_data User data. + */ + void callback(WhisperContext ctx, WhisperState state, int progress, Pointer user_data); +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/ggml/GgmlTensor.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/ggml/GgmlTensor.java new file mode 100644 index 0000000..2569957 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/ggml/GgmlTensor.java @@ -0,0 +1,4 @@ +package io.github.ggerganov.whispercpp.ggml; + +public class GgmlTensor { +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/ggml/GgmlType.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/ggml/GgmlType.java new file mode 100644 index 0000000..363120e --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/ggml/GgmlType.java @@ -0,0 +1,18 @@ +package io.github.ggerganov.whispercpp.ggml; + +public enum GgmlType { + GGML_TYPE_F32, + GGML_TYPE_F16, + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_1, + REMOVED_GGML_TYPE_Q4_2, // support has been removed; the entry still occupies its ordinal slot (presumably to keep later ordinals aligned with the native enum - confirm) + REMOVED_GGML_TYPE_Q4_3, // support has been removed; the entry still occupies its ordinal slot + GGML_TYPE_Q5_0, + GGML_TYPE_Q5_1, + GGML_TYPE_Q8_0, + GGML_TYPE_Q8_1, + GGML_TYPE_I8, + GGML_TYPE_I16, + GGML_TYPE_I32, + GGML_TYPE_COUNT, +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/EModel.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/EModel.java new file mode 100644 index 0000000..b2475b3 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/EModel.java @@ -0,0 +1,10 @@ +package io.github.ggerganov.whispercpp.model; + +public enum EModel { + MODEL_UNKNOWN, + MODEL_TINY, + MODEL_BASE, + MODEL_SMALL, + MODEL_MEDIUM, + MODEL_LARGE, +} diff
--git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/WhisperModel.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/WhisperModel.java new file mode 100644 index 0000000..497ef42 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/WhisperModel.java @@ -0,0 +1,49 @@ +package io.github.ggerganov.whispercpp; + +import io.github.ggerganov.whispercpp.ggml.GgmlTensor; +import io.github.ggerganov.whispercpp.model.EModel; + +public class WhisperModel { +// EModel type = EModel.MODEL_UNKNOWN; +// NOTE(review): every member of this class is commented out, so it is currently an empty placeholder; it is also declared in package io.github.ggerganov.whispercpp although the file lives under model/ - confirm which is intended +// WhisperHParams hparams; +// WhisperFilters filters; +// +// // encoder.positional_embedding +// GgmlTensor e_pe; +// +// // encoder.conv1 +// GgmlTensor e_conv_1_w; +// GgmlTensor e_conv_1_b; +// +// // encoder.conv2 +// GgmlTensor e_conv_2_w; +// GgmlTensor e_conv_2_b; +// +// // encoder.ln_post +// GgmlTensor e_ln_w; +// GgmlTensor e_ln_b; +// +// // decoder.positional_embedding +// GgmlTensor d_pe; +// +// // decoder.token_embedding +// GgmlTensor d_te; +// +// // decoder.ln +// GgmlTensor d_ln_w; +// GgmlTensor d_ln_b; +// +// std::vector layers_encoder; +// std::vector layers_decoder; +// +// // context +// struct ggml_context * ctx; +// +// // the model memory buffer is read-only and can be shared between processors +// std::vector * buf; +// +// // tensors +// int n_loaded; +// Map tensors; +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/WhisperModelLoader.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/WhisperModelLoader.java new file mode 100644 index 0000000..82615d9 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/WhisperModelLoader.java @@ -0,0 +1,62 @@ +package io.github.ggerganov.whispercpp.model; + +import com.sun.jna.Callback; +import com.sun.jna.Pointer; +import com.sun.jna.Structure; + + +public class WhisperModelLoader extends Structure { + public Pointer context; + public ReadFunction read; +
public EOFFunction eof; + public CloseFunction close; + + public static class ReadFunction implements Callback { + public Pointer invoke(Pointer ctx, Pointer output, int readSize) { + // TODO: placeholder - returns ctx unchanged and never touches output + return ctx; + } + } + + public static class EOFFunction implements Callback { + public boolean invoke(Pointer ctx) { + // TODO: placeholder - always returns false + return false; + } + } + + public static class CloseFunction implements Callback { + public void invoke(Pointer ctx) { + // TODO: placeholder - does nothing + } + } + +// public WhisperModelLoader(Pointer p) { +// super(p); +// read = new ReadFunction(); +// eof = new EOFFunction(); +// close = new CloseFunction(); +// read.setCallback(this); +// eof.setCallback(this); +// close.setCallback(this); +// read.write(); +// eof.write(); +// close.write(); +// } + + public WhisperModelLoader() { + super(); + } + + public interface ReadCallback extends Callback { + Pointer invoke(Pointer ctx, Pointer output, int readSize); + } + + public interface EOFCallback extends Callback { + boolean invoke(Pointer ctx); + } + + public interface CloseCallback extends Callback { + void invoke(Pointer ctx); + } +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/WhisperState.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/WhisperState.java new file mode 100644 index 0000000..af93772 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/WhisperState.java @@ -0,0 +1,4 @@ +package io.github.ggerganov.whispercpp.model; + +public class WhisperState { +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/WhisperTokenData.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/WhisperTokenData.java new file mode 100644 index 0000000..bfa83f9 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/model/WhisperTokenData.java @@ -0,0 +1,50 @@ +package io.github.ggerganov.whispercpp.model; + +import com.sun.jna.Structure; + +import java.util.Arrays; +import
java.util.List; + +/** + * Structure representing token data. getFieldOrder() lists the fields in the same order in which they are declared below. + */ +public class WhisperTokenData extends Structure { + + /** Token ID. */ + public int id; + + /** Forced timestamp token ID. */ + public int tid; + + /** Probability of the token. */ + public float p; + + /** Log probability of the token. */ + public float plog; + + /** Probability of the timestamp token. */ + public float pt; + + /** Sum of probabilities of all timestamp tokens. */ + public float ptsum; + + /** + * Start time of the token (token-level timestamp data). + * Do not use if you haven't computed token-level timestamps. + */ + public long t0; + + /** + * End time of the token (token-level timestamp data). + * Do not use if you haven't computed token-level timestamps. + */ + public long t1; + + /** Voice length of the token. */ + public float vlen; + + @Override + protected List getFieldOrder() { + return Arrays.asList("id", "tid", "p", "plog", "pt", "ptsum", "t0", "t1", "vlen"); + } +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFilters.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFilters.java new file mode 100644 index 0000000..b035243 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFilters.java @@ -0,0 +1,10 @@ +package io.github.ggerganov.whispercpp.params; + +import java.util.List; + +public class WhisperFilters { + int n_mel; + int n_fft; + + List data; +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java new file mode 100644 index 0000000..ea0bccf --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java @@ -0,0 +1,187 @@ +package io.github.ggerganov.whispercpp.params; + +import com.sun.jna.Callback; +import com.sun.jna.Pointer; +import com.sun.jna.Structure; +import
io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback; +import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback; +import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback; +import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback; + +/** + * Parameters for the whisper_full() function. + * If you change the order or add new parameters, make sure to update the default values in whisper.cpp: + * whisper_full_default_params() (NOTE(review): unlike WhisperTokenData, this Structure declares no getFieldOrder(); confirm JNA maps the fields in declaration order) + */ +public class WhisperFullParams extends Structure { + + /** Sampling strategy for whisper_full() function. */ + public int strategy; + + /** Number of threads. */ + public int n_threads; + + /** Maximum tokens to use from past text as a prompt for the decoder. */ + public int n_max_text_ctx; + + /** Start offset in milliseconds. */ + public int offset_ms; + + /** Audio duration to process in milliseconds. */ + public int duration_ms; + + /** Translate flag. */ + public boolean translate; + + /** Flag controlling use of past transcription (if any) as an initial prompt for the decoder; the field name suggests that true disables it - confirm against whisper.h. */ + public boolean no_context; + + /** Flag to force single segment output (useful for streaming). */ + public boolean single_segment; + + /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). */ + public boolean print_special; + + /** Flag to print progress information. */ + public boolean print_progress; + + /** Flag to print results from within whisper.cpp (avoid it, use callback instead). */ + public boolean print_realtime; + + /** Flag to print timestamps for each text segment when printing realtime. */ + public boolean print_timestamps; + + /** [EXPERIMENTAL] Flag to enable token-level timestamps. */ + public boolean token_timestamps; + + /** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). */ + public float thold_pt; + + /** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01).
*/ + public float thold_ptsum; + + /** Maximum segment length in characters. */ + public int max_len; + + /** Flag to split on word rather than on token (when used with max_len). */ + public boolean split_on_word; + + /** Maximum tokens per segment (0 = no limit). */ + public int max_tokens; + + /** Flag to speed up the audio by 2x using Phase Vocoder. */ + public boolean speed_up; + + /** Overwrite the audio context size (0 = use default). */ + public int audio_ctx; + + /** Tokens to provide to the whisper decoder as an initial prompt. + * These are prepended to any existing text context from a previous call. */ + public String initial_prompt; + + /** Prompt tokens. */ + public Pointer prompt_tokens; + + /** Number of prompt tokens. */ + public int prompt_n_tokens; + + /** Language selection. + * For auto-detection, set to `null`, `""`, or "auto". */ + public String language; + + /** Flag to indicate whether to detect language automatically. */ + public boolean detect_language; + + /* Section: common decoding parameters. */ + + /** Flag to suppress blank tokens. */ + public boolean suppress_blank; + + /** Flag to suppress non-speech tokens. */ + public boolean suppress_non_speech_tokens; + + /** Initial decoding temperature. */ + public float temperature; + + /** Maximum initial timestamp. */ + public float max_initial_ts; + + /** Length penalty. */ + public float length_penalty; + + /* Section: fallback parameters. */ + + /** Temperature increment. */ + public float temperature_inc; + + /** Entropy threshold (similar to OpenAI's "compression_ratio_threshold"). */ + public float entropy_thold; + + /** Log probability threshold. */ + public float logprob_thold; + + /** No speech threshold. */ + public float no_speech_thold; + + public static class GreedyParams extends Structure { + /** https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 */ + public int best_of; + } + + /** Greedy decoding parameters. GreedyParams is a public static nested Structure so that it can be instantiated without an enclosing WhisperFullParams instance.
*/ + public GreedyParams greedy; + + public static class BeamSearchParams extends Structure { + /** ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 */ + public int beam_size; + + /** ref: https://arxiv.org/pdf/2204.05424.pdf */ + public float patience; + } + + /** + * Beam search decoding parameters. The nested struct's fields are public because JNA only maps public Structure fields. NOTE(review): BeamSearchParams declares no field order for its two fields; confirm whether JNA needs one. + */ + public BeamSearchParams beam_search; + + /** + * Callback for every newly generated text segment. + */ + public WhisperNewSegmentCallback new_segment_callback; + + /** + * User data for the new_segment_callback. + */ + public Pointer new_segment_callback_user_data; + + /** + * Callback on each progress update. + */ + public WhisperProgressCallback progress_callback; + + /** + * User data for the progress_callback. + */ + public Pointer progress_callback_user_data; + + /** + * Callback each time before the encoder starts. + */ + public WhisperEncoderBeginCallback encoder_begin_callback; + + /** + * User data for the encoder_begin_callback. + */ + public Pointer encoder_begin_callback_user_data; + + /** + * Callback by each decoder to filter obtained logits. + */ + public WhisperLogitsFilterCallback logits_filter_callback; + + /** + * User data for the logits_filter_callback.
+ */ + public Pointer logits_filter_callback_user_data; +} + diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperHParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperHParams.java new file mode 100644 index 0000000..99feae0 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperHParams.java @@ -0,0 +1,15 @@ +package io.github.ggerganov.whispercpp.params; + +public class WhisperHParams { + int n_vocab = 51864; + int n_audio_ctx = 1500; + int n_audio_state = 384; + int n_audio_head = 6; + int n_audio_layer = 4; + int n_text_ctx = 448; + int n_text_state = 384; + int n_text_head = 6; + int n_text_layer = 4; + int n_mels = 80; + int ftype = 1; +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperJavaParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperJavaParams.java new file mode 100644 index 0000000..728485c --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperJavaParams.java @@ -0,0 +1,7 @@ +package io.github.ggerganov.whispercpp.params; + +import com.sun.jna.Structure; + +public class WhisperJavaParams extends Structure { + +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperSamplingStrategy.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperSamplingStrategy.java new file mode 100644 index 0000000..a32c793 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperSamplingStrategy.java @@ -0,0 +1,10 @@ +package io.github.ggerganov.whispercpp.params; + +/** Available sampling strategies. Declared in the order GREEDY, BEAM_SEARCH, so their ordinal() values are 0 and 1. */ +public enum WhisperSamplingStrategy { + /** Similar to OpenAI's GreedyDecoder. */ + WHISPER_SAMPLING_GREEDY, + + /** Similar to OpenAI's BeamSearchDecoder. */ + WHISPER_SAMPLING_BEAM_SEARCH +} diff --git
a/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java b/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java new file mode 100644 index 0000000..98390aa --- /dev/null +++ b/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java @@ -0,0 +1,80 @@ +package io.github.ggerganov.whispercpp; + +import static org.junit.jupiter.api.Assertions.*; + +import io.github.ggerganov.whispercpp.params.WhisperJavaParams; +import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import javax.sound.sampled.AudioInputStream; +import javax.sound.sampled.AudioSystem; +import java.io.File; +import java.io.FileNotFoundException; + +class WhisperCppTest { + private static WhisperCpp whisper = new WhisperCpp(); + private static boolean modelInitialised = false; + + @BeforeAll + static void init() throws FileNotFoundException { + // By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin" + // or you can provide the absolute path to the model file.
+ String modelName = "base.en"; + try { + whisper.initContext(modelName); + whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY); +// whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH); + modelInitialised = true; + } catch (FileNotFoundException ex) { + System.out.println("Model " + modelName + " not found"); + } + } + + @Test + void testGetDefaultJavaParams() { + // When + whisper.getDefaultJavaParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH); + + // Then if it doesn't throw we've connected to whisper.cpp + } + + @Test + void testFullTranscribe() throws Exception { + if (!modelInitialised) { + System.out.println("Model not initialised, skipping test"); + return; + } + + // Given: the JFK sample, read fully and decoded below as 16-bit little-endian PCM scaled by 1/32767 + File file = new File(System.getProperty("user.dir"), "../../samples/jfk.wav"); + AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(file); + + byte[] b = new byte[audioInputStream.available()]; + float[] floats = new float[b.length / 2]; + + try { + // A single read() may return fewer bytes than requested, so keep reading until the buffer is full or the stream ends. + for (int filled = 0; filled < b.length; ) { + int n = audioInputStream.read(b, filled, b.length - filled); + if (n <= 0) break; + filled += n; + } + + for (int i = 0, j = 0; i < b.length; i += 2, j++) { + int intSample = (int) (b[i + 1]) << 8 | (int) (b[i]) & 0xFF; + floats[j] = intSample / 32767.0f; + } + + // When + String result = whisper.fullTranscribe(/*params,*/ floats); + + // Then + System.out.println(result); + assertEquals("And so my fellow Americans, ask not what your country can do for you, " + + "ask what you can do for your country.", + result); + } finally { + audioInputStream.close(); + } + } +} diff --git a/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperJnaLibraryTest.java b/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperJnaLibraryTest.java new file mode 100644 index 0000000..07a340c --- /dev/null +++ b/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperJnaLibraryTest.java @@ -0,0 +1,17 @@ +package io.github.ggerganov.whispercpp; + +import static org.junit.jupiter.api.Assertions.*; + +import
org.junit.jupiter.api.Test; + +class WhisperJnaLibraryTest { + + @Test + void testWhisperPrint_system_info() { + String systemInfo = WhisperCppJnaLibrary.instance.whisper_print_system_info(); + // eg: "AVX = 1 | AVX2 = 1 | AVX512 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 + // | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | VSX = 0 | COREML = 0 | " + System.out.println("System info: " + systemInfo); + assertTrue(systemInfo.length() > 10); + } +}