diff --git a/common/common.cpp b/common/common.cpp
index d7e1a5725..1623ba21f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -289,7 +289,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_batch = std::stoi(argv[i]);
-            params.n_batch = std::min(512, params.n_batch);
         } else if (arg == "--keep") {
             if (++i >= argc) {
                 invalid_param = true;
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index c0fb9fb65..8ab29bb20 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -3887,13 +3887,13 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 
 // rope == RoPE == rotary positional embedding
 static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0, const float p_delta, const int p_delta_rows, const float theta_scale) {
-    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
         return;
     }
 
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int i = row*ncols + col;
 
     const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
@@ -3965,8 +3965,8 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
 }
 
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
-    const int col = blockDim.x*blockIdx.x + threadIdx.x;
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int col = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (col >= ncols) {
         return;
@@ -3982,9 +3982,9 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 // values are also not normalized to the maximum value by subtracting it in the exponential function
 // theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int block_size = blockDim.x;
-    const int tid = threadIdx.x;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int block_size = blockDim.y;
+    const int tid = threadIdx.y;
 
     float tmp = 0.0;
 
@@ -4776,9 +4776,9 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 
 static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 2 == 0);
-    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(num_blocks_x, nrows, 1);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
 
@@ -4800,15 +4800,15 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
 }
 
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
-    const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
-    const dim3 block_nums(block_num_x, nrows_x, 1);
+    const dim3 block_nums(nrows_x, block_num_x, 1);
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
 static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(1, nrows_x, 1);
+    const dim3 block_dims(1, WARP_SIZE, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }
 
@@ -6313,7 +6313,7 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     return extra;
 }
 
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -6322,14 +6322,19 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
-            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
+            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace);
+        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
     }
 
     tensor->backend = GGML_BACKEND_GPU;
+
+    if (scratch && no_alloc) {
+        return;
+    }
+
     struct ggml_tensor_extra_gpu * extra;
 
     const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
@@ -6381,16 +6386,48 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     tensor->extra = extra;
 }
 
+void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) {
+    if (g_scratch_size == 0) {
+        return;
+    }
+    if (g_scratch_buffer == nullptr) {
+        CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
+    }
+
+    struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
+
+    const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
+        tensor->op == GGML_OP_VIEW;
+
+    if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
+        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
+        char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+        size_t view_offset = 0;
+        if (tensor->op == GGML_OP_VIEW) {
+            memcpy(&view_offset, tensor->op_params, sizeof(size_t));
+        }
+        extra->data_device[g_main_device] = src0_ddc + view_offset;
+    } else {
+        extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset;
+    }
+
+    tensor->extra = extra;
+}
+
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true, false);
+    ggml_cuda_assign_buffers_impl(tensor, true, false, false);
+}
+
+void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, true, false, true);
 }
 
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false, false);
 }
 
 void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, true);
+    ggml_cuda_assign_buffers_impl(tensor, false, true, false);
 }
 
 void ggml_cuda_set_main_device(int main_device) {
diff --git a/ggml-cuda.h b/ggml-cuda.h
index cad05f5fa..f66bb1678 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -16,9 +16,14 @@ GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const str
 GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
 GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
 GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
+
 GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
 GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+
+GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
+GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+
 GGML_API void ggml_cuda_set_main_device(int main_device);
 GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
 GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
diff --git a/llama.cpp b/llama.cpp
index c97aaee69..8b151dc84 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -10,13 +10,7 @@
 
 #include "ggml.h"
 
-#if !defined(GGML_USE_CUBLAS)
-# include "ggml-alloc.h"
-# define LLAMA_USE_ALLOCATOR
-#else
-# define LLAMA_USE_SCRATCH
-# define LLAMA_MAX_SCRATCH_BUFFERS 16
-#endif
+#include "ggml-alloc.h"
 
 #ifdef GGML_USE_CUBLAS
 # include "ggml-cuda.h"
@@ -588,14 +582,6 @@ struct llama_state {
 
 static llama_state g_state;
 
-//
-// memory sizes (calculated for n_batch == 512)
-//
-
-// computed for n_ctx == 2048
-// TODO: dynamically determine these sizes
-// needs modifications in ggml
-
 // available llama models
 enum e_model {
     MODEL_UNKNOWN,
@@ -610,76 +596,6 @@ enum e_model {
 static const size_t kB = 1024;
 static const size_t MB = 1024*1024;
 
-static std::map<e_model, size_t> MEM_REQ_SCRATCH0(int n_ctx)
-{
-    std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,  ((size_t) n_ctx / 16ull + 92ull)  * MB },
-        { MODEL_7B,  ((size_t) n_ctx / 16ull + 100ull) * MB },
-        { MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB },
-        { MODEL_30B, ((size_t) n_ctx /  9ull + 160ull) * MB },
-        { MODEL_65B, ((size_t) n_ctx /  6ull + 256ull) * MB }, // guess
-        { MODEL_70B, ((size_t) n_ctx /  7ull + 164ull) * MB },
-    };
-    return k_sizes;
-}
-
-static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,  128ull * MB },
-        { MODEL_7B,  160ull * MB },
-        { MODEL_13B, 192ull * MB },
-        { MODEL_30B, 256ull * MB },
-        { MODEL_65B, 384ull * MB }, // guess
-        { MODEL_70B, 304ull * MB },
-    };
-    return k_sizes;
-}
-
-// used to store the compute graph tensors + non-scratch data
-static const std::map<e_model, size_t> & MEM_REQ_EVAL()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,   8ull * MB },
-        { MODEL_7B,  10ull * MB },
-        { MODEL_13B, 12ull * MB },
-        { MODEL_30B, 16ull * MB },
-        { MODEL_65B, 24ull * MB }, // guess
-        { MODEL_70B, 24ull * MB },
-    };
-    return k_sizes;
-}
-
-// amount of VRAM needed per batch size to hold temporary results
-// the values for 3b are not derived from testing but instead chosen conservatively
-static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,   512ull * kB },
-        { MODEL_7B,   512ull * kB },
-        { MODEL_13B,  640ull * kB },
-        { MODEL_30B,  768ull * kB },
-        { MODEL_65B, 1280ull * kB },
-        { MODEL_70B, 1280ull * kB },
-    };
-    return k_sizes;
-}
-
-// amount of VRAM needed per batch size and context to hold temporary results
-// the values for 3b are not derived from testing but instead chosen conservatively
-static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,  128ull },
-        { MODEL_7B,  128ull },
-        { MODEL_13B, 160ull },
-        { MODEL_30B, 208ull },
-        { MODEL_65B, 256ull },
-        { MODEL_70B, 256ull },
-    };
-    return k_sizes;
-}
-
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     uint32_t n_vocab = 32000;
@@ -857,11 +773,9 @@ struct llama_context {
             ggml_metal_free(ctx_metal);
         }
 #endif
-#ifdef LLAMA_USE_ALLOCATOR
         if (alloc) {
             ggml_allocr_free(alloc);
         }
-#endif
     }
 
     std::mt19937 rng;
@@ -901,17 +815,8 @@ struct llama_context {
 
     // memory buffers used to evaluate the model
     llama_buffer buf_compute;
-#ifdef LLAMA_USE_ALLOCATOR
     llama_buffer buf_alloc;
     ggml_allocr * alloc = NULL;
-#endif
-
-#ifdef LLAMA_USE_SCRATCH
-    llama_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
-
-    int buf_last = 0;
-    size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
-#endif
 
 #ifdef GGML_USE_METAL
     ggml_metal_context * ctx_metal = NULL;
@@ -920,37 +825,6 @@ struct llama_context {
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
 #endif
-
-    void use_buf(struct ggml_context * ctx, int i) { // NOLINT
-#if defined(LLAMA_USE_SCRATCH)
-        size_t last_size = 0;
-
-        if (i == -1) {
-            last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
-        } else {
-            auto & buf = buf_scratch[i];
-            last_size = ggml_set_scratch(ctx, { 0, buf.size, buf.data, });
-        }
-
-        if (buf_last >= 0) {
-            buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
-        }
-
-        buf_last = i;
-#else
-        (void) i;
-        (void) ctx;
-#endif
-    }
-
-    size_t get_buf_max_mem(int i) { // NOLINT
-#if defined(LLAMA_USE_SCRATCH)
-        return buf_max_size[i];
-#else
-        (void) i;
-        return 0;
-#endif
-    }
 };
 
 //
@@ -1620,7 +1494,6 @@ static void llama_model_load_internal(
 
     // prepare memory for the weights
     size_t vram_weights = 0;
-    size_t vram_scratch = 0;
     {
         const uint32_t n_embd = hparams.n_embd;
         const uint32_t n_embd_gqa = hparams.n_embd_gqa();
@@ -1701,13 +1574,6 @@ static void llama_model_load_internal(
             ctx_size +
             mmapped_size - vram_weights; // weights in VRAM not in memory
 
-#ifndef LLAMA_USE_ALLOCATOR
-        mem_required +=
-            MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
-            MEM_REQ_SCRATCH1().at(model.type) +
-            MEM_REQ_EVAL().at(model.type);
-#endif
-
         // this is the memory required by one llama_state
         const size_t mem_required_state = scale*hparams.kv_size();
 
@@ -1715,24 +1581,7 @@ static void llama_model_load_internal(
         LLAMA_LOG_INFO("%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
                 mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
 
-        (void) vram_scratch;
         (void) n_batch;
-#ifdef GGML_USE_CUBLAS
-        if (low_vram) {
-            LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
-            ggml_cuda_set_scratch_size(0); // disable scratch
-        } else {
-            const size_t vram_scratch_base = VRAM_REQ_SCRATCH_BASE().at(model.type);
-            const size_t vram_scratch_per_context = VRAM_REQ_SCRATCH_PER_CONTEXT().at(model.type);
-            vram_scratch = n_batch * (vram_scratch_base + n_ctx * vram_scratch_per_context);
-            ggml_cuda_set_scratch_size(vram_scratch);
-            if (n_gpu_layers > 0) {
-                LLAMA_LOG_INFO("%s: allocating batch_size x (%zd kB + n_ctx x %zd B) = %zd MB VRAM for the scratch buffer\n",
-                               __func__, vram_scratch_base / kB, vram_scratch_per_context,
-                               (vram_scratch + MB - 1) / MB); // round up
-            }
-        }
-#endif // GGML_USE_CUBLAS
 
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
         const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
@@ -1769,8 +1618,8 @@ static void llama_model_load_internal(
 
         LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n",
                 __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
-        LLAMA_LOG_INFO("%s: total VRAM used: %zu MB\n",
-                __func__, (vram_weights + vram_scratch + vram_kv_cache + MB - 1) / MB); // round up
+        LLAMA_LOG_INFO("%s: VRAM used: %zu MB\n",
+                __func__, (vram_weights + vram_kv_cache + MB - 1) / MB); // round up
 #else
         (void) n_gpu_layers;
 #endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
@@ -1875,9 +1724,7 @@ static struct ggml_cgraph * llama_build_graph(
         /*.no_alloc =*/ false,
     };
 
-#ifdef LLAMA_USE_ALLOCATOR
     params.no_alloc = true;
-#endif
 
     struct ggml_context * ctx0 = ggml_init(params);
 
@@ -1889,14 +1736,10 @@ static struct ggml_cgraph * llama_build_graph(
 
     if (tokens) {
         struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-#ifdef LLAMA_USE_ALLOCATOR
         ggml_allocr_alloc(lctx.alloc, inp_tokens);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
             memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
         }
-#else
-        memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
-#endif
         ggml_set_name(inp_tokens, "inp_tokens");
 
         inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
@@ -1907,14 +1750,10 @@ static struct ggml_cgraph * llama_build_graph(
 
         inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
 
-#ifdef LLAMA_USE_ALLOCATOR
         ggml_allocr_alloc(lctx.alloc, inpL);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
            memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
         }
-#else
-        memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
-#endif
     }
 
     const int i_gpu_start = n_layer - n_gpu_layers;
@@ -1931,25 +1770,21 @@ static struct ggml_cgraph * llama_build_graph(
 
 #ifdef GGML_USE_CUBLAS
     if (n_gpu_layers > n_layer) {
-        offload_func_nr = ggml_cuda_assign_buffers;
+        offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
     }
     if (n_gpu_layers > n_layer + 1) {
-        offload_func_v = ggml_cuda_assign_buffers;
+        offload_func_v = ggml_cuda_assign_buffers_no_alloc;
    }
     if (n_gpu_layers > n_layer + 2) {
-        offload_func_kq = ggml_cuda_assign_buffers;
+        offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
     }
 #endif // GGML_USE_CUBLAS
 
     struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-#ifdef LLAMA_USE_ALLOCATOR
     ggml_allocr_alloc(lctx.alloc, KQ_scale);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
     }
-#else
-    ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
-#endif
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
     for (int il = 0; il < n_layer; ++il) {
@@ -1959,14 +1794,12 @@ static struct ggml_cgraph * llama_build_graph(
 
 #ifdef GGML_USE_CUBLAS
         if (il >= i_gpu_start) {
-            offload_func = ggml_cuda_assign_buffers;
+            offload_func = ggml_cuda_assign_buffers_no_alloc;
         }
 #endif // GGML_USE_CUBLAS
 
         struct ggml_tensor * inpSA = inpL;
 
-        lctx.use_buf(ctx0, 0);
-
         // norm
         {
             cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
@@ -2104,8 +1937,6 @@ static struct ggml_cgraph * llama_build_graph(
             ggml_set_name(cur, "result_wo");
         }
 
-        lctx.use_buf(ctx0, 1);
-
         struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
         offload_func(inpFF);
         ggml_set_name(inpFF, "inpFF");
@@ -2160,8 +1991,6 @@ static struct ggml_cgraph * llama_build_graph(
         inpL = cur;
     }
 
-    lctx.use_buf(ctx0, 0);
-
     // norm
     {
         cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
@@ -2178,8 +2007,6 @@ static struct ggml_cgraph * llama_build_graph(
     cur = ggml_mul_mat(ctx0, model.output, cur);
     ggml_set_name(cur, "result_output");
 
-    lctx.use_buf(ctx0, -1);
-
     // logits -> probs
     //cur = ggml_soft_max_inplace(ctx0, cur);
 
@@ -2189,15 +2016,6 @@ static struct ggml_cgraph * llama_build_graph(
         mem_per_token = ggml_used_mem(ctx0)/N;
     }
 
-#if 0
-    LLAMA_LOG_INFO("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
-            ggml_used_mem(ctx0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(1)/1024.0/1024.0,
-            lctx.work_buffer.size()/1024.0/1024.0,
-            n_past, N);
-#endif
-
     ggml_free(ctx0);
 
     return gf;
@@ -2248,14 +2066,26 @@ static bool llama_eval_internal(
     const int64_t n_embd = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
 
-#ifdef LLAMA_USE_ALLOCATOR
     ggml_allocr_reset(lctx.alloc);
-#endif
 
     ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);
 
-#ifdef LLAMA_USE_ALLOCATOR
     ggml_allocr_alloc_graph(lctx.alloc, gf);
+
+#ifdef GGML_USE_CUBLAS
+    for (int i = 0; i < gf->n_leafs; i++) {
+        ggml_tensor * node = gf->leafs[i];
+        if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
+            ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
+        }
+    }
+
+    for (int i = 0; i < gf->n_nodes; i++) {
+        ggml_tensor * node = gf->nodes[i];
+        if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
+            ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
+        }
+    }
 #endif
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -4319,7 +4149,6 @@ struct llama_context * llama_new_context_with_model(
            ctx->embedding.resize(hparams.n_embd);
         }
 
-#ifdef LLAMA_USE_ALLOCATOR
         {
             static const size_t tensor_alignment = 32;
             // the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data
@@ -4350,13 +4179,6 @@ struct llama_context * llama_new_context_with_model(
 
             LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
 
-            // debug - for comparison with scratch buffer
-            //size_t prev_req =
-            //    MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type) +
-            //    MEM_REQ_SCRATCH1().at(ctx->model.type) +
-            //    MEM_REQ_EVAL().at(ctx->model.type);
-            //LLAMA_LOG_INFO("%s: (debug) equivalent with scratch buffer = %7.2f MB\n", __func__, prev_req / 1024.0 / 1024.0);
-
             // recreate allocator with exact memory requirements
             ggml_allocr_free(ctx->alloc);
 
@@ -4366,16 +4188,17 @@ struct llama_context * llama_new_context_with_model(
             if (ctx->ctx_metal) {
                 ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal));
             }
+#endif
+#ifdef GGML_USE_CUBLAS
+            if (params.low_vram) {
+                LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
+                ggml_cuda_set_scratch_size(0); // disable scratch
+            } else {
+                ggml_cuda_set_scratch_size(alloc_size);
+                LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0);
+            }
 #endif
         }
-#else
-        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
-#endif
-
-#ifdef LLAMA_USE_SCRATCH
-        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
-        ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
-#endif
     }
 
 #ifdef GGML_USE_METAL
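
For context, the offload path after this patch works in two phases: during graph construction, tensors that should live in VRAM are only tagged with `ggml_cuda_assign_buffers_no_alloc` (this sets `backend = GGML_BACKEND_GPU` but leaves `extra`, the device pointer, unset); then, after `ggml_allocr_alloc_graph` has placed every tensor inside `buf_alloc`, each tagged tensor's allocator offset is mapped into the CUDA scratch buffer with `ggml_cuda_assign_scratch_offset`. A minimal sketch of that second phase, using a hypothetical helper name `assign_gpu_scratch_offsets` and base pointer `buf_alloc_base` (the real loop lives in `llama_eval_internal` above and also covers `gf->leafs`):

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-cuda.h"

// Map each GPU-tagged tensor's ggml-alloc offset into the VRAM scratch buffer.
// The scratch buffer was sized with ggml_cuda_set_scratch_size(alloc_size), i.e.
// exactly the measured size of buf_alloc, so every offset is guaranteed to fit.
static void assign_gpu_scratch_offsets(struct ggml_cgraph * gf, void * buf_alloc_base) {
    for (int i = 0; i < gf->n_nodes; i++) {
        struct ggml_tensor * node = gf->nodes[i];
        if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
            // node->data points into the host-side allocator buffer;
            // reuse the same offset inside the device scratch buffer
            const size_t offset = (size_t) ((char *) node->data - (char *) buf_alloc_base);
            ggml_cuda_assign_scratch_offset(node, offset);
        }
    }
}
```

Because the VRAM scratch buffer is now sized from the allocator's measured `alloc_size` rather than from per-model lookup tables, its size scales automatically with `n_ctx` and `n_batch` instead of relying on the removed `MEM_REQ_*`/`VRAM_REQ_SCRATCH_*` heuristics.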