whisper : expose CUDA device setting in public API (#1840)

* Makefile : allow to override CUDA_ARCH_FLAG

* whisper : allow to select GPU (CUDA) device from public API
pull/1768/merge
Didzis Gosko 2024-02-09 17:27:47 +02:00 committed by GitHub
parent b6559333ff
commit 0f80e5a80a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 5 additions and 3 deletions

View File

@ -215,9 +215,9 @@ endif
ifdef WHISPER_CUBLAS
ifeq ($(shell expr $(NVCC_VERSION) \>= 11.6), 1)
CUDA_ARCH_FLAG=native
CUDA_ARCH_FLAG ?= native
else
CUDA_ARCH_FLAG=all
CUDA_ARCH_FLAG ?= all
endif
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include

View File

@ -1060,7 +1060,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
#ifdef GGML_USE_CUBLAS
if (params.use_gpu && ggml_cublas_loaded()) {
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
backend_gpu = ggml_backend_cuda_init(0);
backend_gpu = ggml_backend_cuda_init(params.gpu_device);
if (!backend_gpu) {
WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
}
@ -3213,6 +3213,7 @@ int whisper_ctx_init_openvino_encoder(
struct whisper_context_params whisper_context_default_params() {
struct whisper_context_params result = {
/*.use_gpu =*/ true,
/*.gpu_device =*/ 0,
};
return result;
}

View File

@ -86,6 +86,7 @@ extern "C" {
struct whisper_context_params {
bool use_gpu;
int gpu_device; // CUDA device
};
typedef struct whisper_token_data {