diff --git a/examples/common.cpp b/examples/common.cpp index fd6dbc0e3..476d56594 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -586,7 +586,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param lparams.n_batch = params.n_batch; lparams.n_gpu_layers = params.n_gpu_layers; lparams.main_gpu = params.main_gpu; - memcpy(lparams.tensor_split, params.tensor_split, LLAMA_MAX_DEVICES*sizeof(float)); + lparams.tensor_split = params.tensor_split; lparams.low_vram = params.low_vram; lparams.seed = params.seed; lparams.f16_kv = params.memory_f16; diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d3054a7fa..6537897b9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2512,6 +2512,9 @@ void ggml_init_cublas() { } void ggml_cuda_set_tensor_split(const float * tensor_split) { + if (tensor_split == nullptr) { + return; + } bool all_zero = true; for (int i = 0; i < g_device_count; ++i) { if (tensor_split[i] != 0.0f) { diff --git a/llama.cpp b/llama.cpp index 796dfdacb..23e746d62 100644 --- a/llama.cpp +++ b/llama.cpp @@ -849,7 +849,7 @@ struct llama_context_params llama_context_default_params() { /*.n_batch =*/ 512, /*.gpu_layers =*/ 0, /*.main_gpu =*/ 0, - /*.tensor_split =*/ {0}, + /*.tensor_split =*/ nullptr, /*.rope_freq_base =*/ 10000.0f, /*.rope_freq_scale =*/ 1.0f, /*.progress_callback =*/ nullptr, @@ -1289,7 +1289,7 @@ static bool llama_model_load( int n_batch, int n_gpu_layers, int main_gpu, - float * tensor_split, + const float * tensor_split, float rope_freq_base, float rope_freq_scale, bool low_vram, diff --git a/llama.h b/llama.h index b676a383b..c565f6a00 100644 --- a/llama.h +++ b/llama.h @@ -88,7 +88,8 @@ extern "C" { int32_t n_batch; // prompt processing batch size int32_t n_gpu_layers; // number of layers to store in VRAM int32_t main_gpu; // the GPU that is used for scratch and small tensors - float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs + + const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) // ref: https://github.com/ggerganov/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency