ggml : change ggml_graph_compute() API to not require context (#1999)

* ggml_graph_compute: deprecate using ggml_context, try to resolve issue #287

* rewrite: no longer consider backward compatibility; plan and make_plan

* minor: rename ctx as plan; const

* remove ggml_graph_compute from tests/test-grad0.c, but the current change breaks the backward pass

* add static ggml_graph_compute_sugar()

* minor: update comments

* reusable buffers

* ggml : more consistent naming + metal fixes

* ggml : fix docs

* tests : disable grad / opt + minor naming changes

* ggml : add ggml_graph_compute_with_ctx()

- backwards compatible API
- deduplicates a lot of copy-paste

* ci : enable test-grad0

* examples : factor out plan allocation into a helper function

* llama : factor out plan stuff into a helper function

* ci : fix env

* llama : fix duplicate symbols + refactor example benchmark

* ggml : remove obsolete assert + refactor n_tasks section

* ggml : fix indentation in switch

* llama : avoid unnecessary bool

* ggml : remove comments from source file and match order in header

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Qingyou Meng 2023-07-08 00:24:01 +08:00 committed by GitHub
parent 7242140283
commit 1d656d6360
13 changed files with 571 additions and 449 deletions
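For orientation before the per-file diffs: the new calling convention separates planning from execution. The sketch below mirrors the `ggml_graph_compute_helper()` added to the examples and to llama.cpp in this commit; the wrapper name and the already-built graph `gf` are illustrative, not part of the API.

```cpp
// Minimal sketch of the new two-step API (plan, then compute), assuming a
// forward graph has already been built with ggml_build_forward().
#include <vector>
#include "ggml.h"

static void compute_graph(struct ggml_cgraph * gf, int n_threads) {
    // 1) plan: fills in per-node task counts and the required work buffer size
    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);

    // 2) the caller owns the work buffer and allocates it only when needed
    std::vector<uint8_t> work;
    if (plan.work_size > 0) {
        work.resize(plan.work_size);
        plan.work_data = work.data();
    }

    // 3) compute: no ggml_context is required anymore
    ggml_graph_compute(gf, &plan);
}
```

For callers that still own a `ggml_context` with spare memory, `ggml_graph_compute_with_ctx()` (added below) keeps the old single-call ergonomics by allocating the work buffer out of the context.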


@@ -16,7 +16,9 @@ on:
paths: ['**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu']
env:
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
+GGML_NLOOP: 3
+GGML_NITER: 1
jobs:
ubuntu-focal-make:
@@ -64,7 +66,7 @@ jobs:
id: cmake_test
run: |
cd build
-ctest --verbose
+ctest --verbose --timeout 900
ubuntu-latest-cmake-sanitizer:
runs-on: ubuntu-latest
@@ -99,7 +101,7 @@ jobs:
id: cmake_test
run: |
cd build
-ctest --verbose
+ctest --verbose --timeout 900
macOS-latest-make:
runs-on: macos-latest
@@ -147,10 +149,11 @@ jobs:
id: cmake_test
run: |
cd build
-ctest --verbose
+ctest --verbose --timeout 900
windows-latest-cmake:
runs-on: windows-latest
env:
OPENBLAS_VERSION: 0.3.23
OPENCL_VERSION: 2023.04.17
@@ -249,7 +252,7 @@ jobs:
if: ${{ matrix.build != 'clblast' && (matrix.build != 'avx512' || env.HAS_AVX512F == '1') }} # Test AVX-512 only when possible
run: |
cd build
-ctest -C Release --verbose
+ctest -C Release --verbose --timeout 900
- name: Get commit hash
id: commit


@@ -31,6 +31,17 @@ float frand_normal(struct random_normal_distribution * rnd) {
return ((r < rnd->min) ? (rnd->min) : (r > rnd->max) ? (rnd->max) : r);
}
+void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+if (plan.work_size > 0) {
+buf.resize(plan.work_size);
+plan.work_data = buf.data();
+}
+ggml_graph_compute(graph, &plan);
+}
struct ggml_tensor * randomize_tensor(
struct ggml_tensor * tensor,
int ndims,
@@ -1569,6 +1580,8 @@ int main(int argc, char ** argv) {
int n_tokens = model.hparams.n_ctx;
int n_vocab = model.hparams.n_vocab;
+std::vector<uint8_t> work_buffer;
for (int ex=0; ex<n_examples; ++ex) {
struct ggml_init_params params = {
/*.mem_size =*/ compute_size,
@@ -1586,7 +1599,6 @@ int main(int argc, char ** argv) {
int n_past = 0;
ggml_cgraph gf = {};
-gf.n_threads = 1;
get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets);
@@ -1595,7 +1607,7 @@ int main(int argc, char ** argv) {
struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
ggml_build_forward_expand(&gf, e);
-ggml_graph_compute(ctx0, &gf);
+ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
float error_before_opt = ggml_get_f32_1d(e, 0);
@@ -1611,7 +1623,7 @@ int main(int argc, char ** argv) {
ggml_opt(ctx0, opt_params_lbfgs, e);
//
ggml_build_forward_expand(&gf, e);
-ggml_graph_compute(ctx0, &gf);
+ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
float error_after_opt = ggml_get_f32_1d(e, 0);
@@ -1659,13 +1671,12 @@ int main(int argc, char ** argv) {
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph gf = {};
-gf.n_threads = 1;
int n_past = 0;
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
ggml_build_forward_expand(&gf, logits);
-ggml_graph_compute(ctx0, &gf);
+ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
@@ -1687,10 +1698,11 @@ int main(int argc, char ** argv) {
}
print_matrix(model.tok_embeddings);
printf("done\n");
// ggml_free(kv_self.ctx);
// ggml_free(model_lora.ctx);
ggml_free(model.ctx);
return 0;
}


@@ -20,6 +20,17 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
+void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+if (plan.work_size > 0) {
+buf.resize(plan.work_size);
+plan.work_data = buf.data();
+}
+ggml_graph_compute(graph, &plan);
+}
float tensor_sum_elements(const ggml_tensor * tensor) {
float sum = 0;
if (tensor->type==GGML_TYPE_F32) {
@@ -159,13 +170,14 @@ int main(int argc, char ** argv) {
// printf("Creating compute graph\n");
struct ggml_cgraph gf = ggml_build_forward(m11xm2);
-gf.n_threads=benchmark_params.n_threads;
-printf("cgraph->n_threads=%i\n",gf.n_threads);
+printf("n_threads=%i\n", benchmark_params.n_threads);
TENSOR_DUMP(m11);
TENSOR_DUMP(m2);
-ggml_graph_compute(ctx, &gf);
+std::vector<uint8_t> work_buffer;
+ggml_graph_compute_helper(work_buffer, &gf, benchmark_params.n_threads);
TENSOR_DUMP(gf.nodes[0]);
@@ -187,7 +199,6 @@ int main(int argc, char ** argv) {
// printf("Creating compute graph\n");
struct ggml_cgraph gf31 = ggml_build_forward(q31);
-gf31.n_threads=benchmark_params.n_threads;
// Set up a second graph computation to make sure we override the CPU cache lines
// printf("Creating new tensor q12 & Running quantize\n");
@@ -199,8 +210,7 @@ int main(int argc, char ** argv) {
//printf("Creating compute graph\n");
struct ggml_cgraph gf32 = ggml_build_forward(q32);
-gf32.n_threads=benchmark_params.n_threads;
-printf("cgraph->n_threads=%i\n",gf31.n_threads);
+printf("n_threads=%i\n", benchmark_params.n_threads);
const int dimx = sizex;
const int dimy = sizey;
@@ -221,14 +231,15 @@ int main(int argc, char ** argv) {
long long int start = ggml_time_us();
//printf("Running ggml_graph_compute\n");
-ggml_graph_compute(ctx, &gf31);
+ggml_graph_compute_helper(work_buffer, &gf31, benchmark_params.n_threads);
long long int stop = ggml_time_us();
long long int usec = stop-start;
double gflops = (double)(flops_per_matrix)/usec/1000.0;
gflops_sum += gflops;
printf("%9i;%8i;%6i;%6i;%6i;%15lli;%18lli;%10.2f\n",
i,
-gf31.n_threads,
+benchmark_params.n_threads,
sizex, sizey, sizez, flops_per_matrix,
usec,gflops);
@@ -253,7 +264,7 @@ int main(int argc, char ** argv) {
}
// Running a different graph computation to make sure we override the CPU cache lines
-ggml_graph_compute(ctx, &gf32);
+ggml_graph_compute_helper(work_buffer, &gf32, benchmark_params.n_threads);
}
printf("\n");
printf("Average%78.2f\n",gflops_sum/((double)benchmark_params.n_iterations));


@@ -35,10 +35,9 @@ int main(int argc, char ** argv) {
struct ggml_context * ctx_eval = NULL;
struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
-gf.n_threads = 1;
// this allocates all Metal resources and memory buffers
-auto * ctx_metal = ggml_metal_init();
+auto * ctx_metal = ggml_metal_init(1);
const size_t max_size_data = ggml_get_max_tensor_size(ctx_data);
const size_t max_size_eval = ggml_get_max_tensor_size(ctx_eval);


@@ -60,6 +60,17 @@ float frand_uniform(struct random_uniform_distribution * rnd) {
return rnd->rd(rnd->gen);
}
+void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+if (plan.work_size > 0) {
+buf.resize(plan.work_size);
+plan.work_data = buf.data();
+}
+ggml_graph_compute(graph, &plan);
+}
struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) {
float scale = 1.0f; // xavier
switch (tensor->n_dims) {
@@ -1426,11 +1437,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
gf->n_nodes = 0;
gf->n_leafs = 0;
-gf->work_size = 0;
gf->perf_runs = 0;
gf->perf_cycles = 0;
gf->perf_time_us = 0;
-gf->work = NULL;
const auto & hparams = model->hparams;
//const int n_ctx = hparams.n_ctx;
@@ -3162,6 +3171,7 @@ int main(int argc, char ** argv) {
printf("used_mem model+cache: %zu bytes\n", ggml_used_mem(model.ctx));
// ggml_print_tensor_objects(model.ctx);
+// TODO: use std::vector<uint8_t> intead of "new"
size_t compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb);
uint8_t * compute_addr = new uint8_t[compute_size];
@@ -3183,6 +3193,8 @@ int main(int argc, char ** argv) {
GGML_ASSERT(train_samples[i]+n_tokens-1 < (int) train_tokens.size());
}
+std::vector<uint8_t> work_buffer;
printf("%s: begin training\n", __func__);
for (int ex = 0; ex < params.n_examples; ++ex) {
@@ -3217,9 +3229,6 @@ int main(int argc, char ** argv) {
struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
-// ggml_cgraph gf = {};
-gf->n_threads = params.n_threads;
-gb->n_threads = params.n_threads;
get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
@@ -3248,7 +3257,7 @@ int main(int argc, char ** argv) {
*gb = ggml_build_backward(ctx0, gf, true);
}
-ggml_graph_compute(ctx0, gf);
+ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
size_t used_mem_before_opt = ggml_used_mem(ctx0);
@@ -3272,7 +3281,7 @@ int main(int argc, char ** argv) {
model.train_samples += n_batch;
model.train_tokens += n_batch * n_tokens;
-ggml_graph_compute(ctx0, gf);
+ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
float error_after_opt = ggml_get_f32_1d(loss, 0);
@@ -3354,13 +3363,12 @@ int main(int argc, char ** argv) {
struct ggml_context * ctx0 = ggml_init(cparams);
ggml_cgraph gf = {};
-gf.n_threads = params.n_threads;
int n_past = 0;
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
ggml_build_forward_expand(&gf, logits);
-ggml_graph_compute(ctx0, &gf);
+ggml_graph_compute_helper(work_buffer, &gf, params.n_threads);
//struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
//struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
@@ -3386,6 +3394,7 @@ int main(int argc, char ** argv) {
delete[] compute_addr;
delete[] compute_buf_0;
delete[] compute_buf_1;
llama_free(lctx);
llama_free_model(lmodel);
ggml_free(model.ctx);


@@ -34,9 +34,13 @@ extern "C" {
struct ggml_metal_context;
-struct ggml_metal_context * ggml_metal_init(void);
+// number of command buffers to use
+struct ggml_metal_context * ggml_metal_init(int n_cb);
void ggml_metal_free(struct ggml_metal_context * ctx);
+// set the number of command buffers to use
+void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
// creates a mapping between a host memory buffer and a device memory buffer
// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
// - the mapping is used during computation to determine the arguments of the compute kernels
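Since the graph no longer carries `n_threads`, the Metal backend now receives its command-buffer count explicitly. A rough usage sketch, modelled on the llama.cpp changes later in this commit; the `n_threads` and `gf` variables are assumed to exist in the caller:

```cpp
// Assumed: a built graph `gf` and a thread/command-buffer count `n_threads`.
struct ggml_metal_context * ctx_metal = ggml_metal_init(1); // start with 1 command buffer

// ... map the host buffers with ggml_metal_add_buffer() and build the graph ...

// the command-buffer count can be adjusted later, e.g. per evaluation
ggml_metal_set_n_cb(ctx_metal, n_threads);
ggml_metal_graph_compute(ctx_metal, &gf);
```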


@@ -25,6 +25,8 @@ struct ggml_metal_buffer {
};
struct ggml_metal_context {
+int n_cb;
float * logits;
id<MTLDevice> device;
@@ -86,11 +88,12 @@ static NSString * const msl_library_source = @"see metal.metal";
@implementation GGMLMetalClass
@end
-struct ggml_metal_context * ggml_metal_init(void) {
+struct ggml_metal_context * ggml_metal_init(int n_cb) {
fprintf(stderr, "%s: allocating\n", __func__);
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+ctx->n_cb = n_cb;
ctx->device = MTLCreateSystemDefaultDevice();
ctx->queue = [ctx->device newCommandQueue];
ctx->n_buffers = 0;
@@ -208,6 +211,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
free(ctx);
}
+void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
+ctx->n_cb = n_cb;
+}
// finds the Metal buffer that contains the tensor data on the GPU device
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
// Metal buffer based on the host memory pointer
@@ -354,7 +361,7 @@ void ggml_metal_graph_compute(
// create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel
-const int n_cb = gf->n_threads;
+const int n_cb = ctx->n_cb;
NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];

ggml.c (760 changes)

@@ -4583,14 +4583,13 @@ struct ggml_tensor * ggml_new_tensor_impl(
/*.src0 =*/ NULL,
/*.src1 =*/ NULL,
/*.opt =*/ { NULL },
-/*.n_tasks =*/ 0,
/*.perf_runs =*/ 0,
/*.perf_cycles =*/ 0,
/*.perf_time_us =*/ 0,
/*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
/*.name =*/ { 0 },
/*.extra =*/ NULL,
-/*.pad =*/ { 0 },
+/*.padding =*/ { 0 },
};
// TODO: this should not be needed as long as we don't rely on aligned SIMD loads
@@ -10718,8 +10717,6 @@ static void ggml_compute_forward_mul_mat(
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
-assert(ne00 % 32 == 0);
for (int64_t ic = 0; ic < ne11; ++ic) {
vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
}
@@ -15772,9 +15769,6 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
struct ggml_cgraph result = {
/*.n_nodes =*/ 0,
/*.n_leafs =*/ 0,
-/*.n_threads =*/ GGML_DEFAULT_N_THREADS,
-/*.work_size =*/ 0,
-/*.work =*/ NULL,
/*.nodes =*/ { NULL },
/*.grads =*/ { NULL },
/*.leafs =*/ { NULL },
@@ -15945,12 +15939,13 @@ void clear_numa_thread_affinity(void) {}
#endif
struct ggml_compute_state_shared {
-struct ggml_cgraph * cgraph;
+const struct ggml_cgraph * cgraph;
+const struct ggml_cplan * cplan;
int64_t perf_node_start_cycles;
int64_t perf_node_start_time_us;
-int n_threads;
+const int n_threads;
// synchronization primitives
atomic_int n_active; // num active threads
@@ -15974,9 +15969,13 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
static thread_ret_t ggml_graph_compute_thread(void * data) {
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
-struct ggml_cgraph * cgraph = state->shared->cgraph;
-const int n_threads = state->shared->n_threads;
+const struct ggml_cgraph * cgraph = state->shared->cgraph;
+const struct ggml_cplan * cplan = state->shared->cplan;
+const int * n_tasks_arr = cplan->n_tasks;
+const int n_threads = state->shared->n_threads;
set_numa_thread_affinity(state->ith, n_threads);
int node_n = -1;
@@ -15989,15 +15988,15 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
/*.type =*/ GGML_TASK_FINALIZE,
/*.ith =*/ 0,
/*.nth =*/ 0,
-/*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+/*.wsize =*/ cplan->work_size,
+/*.wdata =*/ cplan->work_data,
};
if (node_n != -1) {
/* FINALIZE */
struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
if (GGML_OP_HAS_FINALIZE[node->op]) {
-params.nth = node->n_tasks;
+params.nth = n_tasks_arr[node_n];
ggml_compute_forward(&params, node);
ggml_graph_compute_perf_stats_node(node, state->shared);
}
@@ -16008,11 +16007,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
struct ggml_tensor * node = cgraph->nodes[node_n];
+const int n_tasks = n_tasks_arr[node_n];
state->shared->perf_node_start_cycles = ggml_perf_cycles();
state->shared->perf_node_start_time_us = ggml_perf_time_us();
-params.nth = node->n_tasks;
+params.nth = n_tasks;
/* INIT */
if (GGML_OP_HAS_INIT[node->op]) {
@@ -16020,7 +16020,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
ggml_compute_forward(&params, node);
}
-if (node->n_tasks == 1) {
+if (n_tasks == 1) {
// TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
// they do something more efficient than spinning (?)
params.type = GGML_TASK_COMPUTE;
@@ -16052,16 +16052,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
/* COMPUTE */
struct ggml_tensor * node = cgraph->nodes[node_n];
+const int n_tasks = n_tasks_arr[node_n];
struct ggml_compute_params params = {
/*.type =*/ GGML_TASK_COMPUTE,
/*.ith =*/ state->ith,
-/*.nth =*/ node->n_tasks,
-/*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+/*.nth =*/ n_tasks,
+/*.wsize =*/ cplan->work_size,
+/*.wdata =*/ cplan->work_data,
};
-if (state->ith < node->n_tasks) {
+if (state->ith < n_tasks) {
ggml_compute_forward(&params, node);
}
}
@@ -16069,11 +16070,364 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
return 0;
}
-void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
-const int n_threads = cgraph->n_threads;
+struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
+if (n_threads <= 0) {
n_threads = GGML_DEFAULT_N_THREADS;
}
size_t work_size = 0;
struct ggml_cplan cplan;
memset(&cplan, 0, sizeof(struct ggml_cplan));
// thread scheduling for the different operations + work buffer size estimation
for (int i = 0; i < cgraph->n_nodes; i++) {
int n_tasks = 1;
struct ggml_tensor * node = cgraph->nodes[i];
switch (node->op) {
case GGML_OP_CPY:
case GGML_OP_DUP:
{
n_tasks = n_threads;
size_t cur = 0;
if (ggml_is_quantized(node->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks;
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_ADD:
case GGML_OP_ADD1:
{
n_tasks = n_threads;
size_t cur = 0;
if (ggml_is_quantized(node->src0->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_tasks;
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_ACC:
{
n_tasks = n_threads;
size_t cur = 0;
if (ggml_is_quantized(node->src0->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_tasks;
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_SUB:
case GGML_OP_DIV:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_LOG:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_ABS:
case GGML_OP_SGN:
case GGML_OP_NEG:
case GGML_OP_STEP:
case GGML_OP_TANH:
case GGML_OP_ELU:
case GGML_OP_RELU:
{
n_tasks = 1;
} break;
case GGML_OP_MUL:
case GGML_OP_GELU:
case GGML_OP_GELU_QUICK:
case GGML_OP_SILU:
case GGML_OP_SILU_BACK:
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
{
n_tasks = n_threads;
} break;
case GGML_OP_MUL_MAT:
case GGML_OP_OUT_PROD:
{
n_tasks = n_threads;
// TODO: use different scheduling for different matrix sizes
//const int nr0 = ggml_nrows(node->src0);
//const int nr1 = ggml_nrows(node->src1);
//n_tasks = MIN(n_threads, MAX(1, nr0/128));
//printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
size_t cur = 0;
const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
#if defined(GGML_USE_CUBLAS)
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
} else
#elif defined(GGML_USE_CLBLAST)
if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
} else
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
if (node->src0->type != GGML_TYPE_F32) {
// here we need memory just for single 2D matrix from src0
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
}
} else
#endif
if (node->src1->type != vec_dot_type) {
cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
} else {
cur = 0;
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_SCALE:
{
n_tasks = 1;
} break;
case GGML_OP_SET:
case GGML_OP_CONT:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_GET_ROWS:
case GGML_OP_GET_ROWS_BACK:
case GGML_OP_DIAG:
case GGML_OP_DIAG_MASK_ZERO:
{
n_tasks = 1;
} break;
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
{
n_tasks = n_threads;
} break;
case GGML_OP_ALIBI:
{
n_tasks = 1; //TODO
} break;
case GGML_OP_CLAMP:
{
n_tasks = 1; //TODO
} break;
case GGML_OP_CONV_1D:
{
n_tasks = n_threads;
GGML_ASSERT(node->src0->ne[3] == 1);
GGML_ASSERT(node->src1->ne[2] == 1);
GGML_ASSERT(node->src1->ne[3] == 1);
size_t cur = 0;
const int nk = node->src0->ne[0];
if (node->src0->type == GGML_TYPE_F16 &&
node->src1->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(
nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
);
} else if (node->src0->type == GGML_TYPE_F32 &&
node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*(
nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
);
} else {
GGML_ASSERT(false);
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_CONV_2D:
{
n_tasks = n_threads;
GGML_ASSERT(node->src1->ne[3] == 1);
const int64_t ne00 = node->src0->ne[0]; // W
const int64_t ne01 = node->src0->ne[1]; // H
const int64_t ne02 = node->src0->ne[2]; // C
const int64_t ne03 = node->src0->ne[3]; // N
const int64_t ne10 = node->src1->ne[0]; // W
const int64_t ne11 = node->src1->ne[1]; // H
const int64_t ne12 = node->src1->ne[2]; // C
const int64_t nk = ne00*ne01;
UNUSED(ne02);
UNUSED(ne03);
UNUSED(nk);
size_t cur = 0;
if (node->src0->type == GGML_TYPE_F16 &&
node->src1->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
} else if (node->src0->type == GGML_TYPE_F32 &&
node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)* (ne10*ne11*ne12);
} else {
GGML_ASSERT(false);
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_ATTN:
{
n_tasks = n_threads;
size_t cur = 0;
const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
if (node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
}
if (node->src1->type == GGML_TYPE_F16) {
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_FF:
{
n_tasks = n_threads;
size_t cur = 0;
if (node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2
}
if (node->src1->type == GGML_TYPE_F16) {
cur = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
n_tasks = n_threads;
size_t cur = 0;
const int64_t D = node->src0->ne[0];
const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
if (node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
}
if (node->src1->type == GGML_TYPE_F16) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1:
case GGML_OP_MAP_CUSTOM2:
case GGML_OP_MAP_CUSTOM3:
{
n_tasks = 1;
} break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
n_tasks = n_threads;
size_t cur = ggml_type_size(node->type)*(n_tasks + node->src0->ne[0]*n_tasks);
work_size = MAX(work_size, cur);
} break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
{
n_tasks = n_threads;
size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*n_tasks;
work_size = MAX(work_size, cur);
} break;
case GGML_OP_NONE:
{
n_tasks = 1;
} break;
case GGML_OP_COUNT:
{
GGML_ASSERT(false);
} break;
}
cplan.n_tasks[i] = n_tasks;
}
if (work_size > 0) {
work_size += CACHE_LINE_SIZE*(n_threads - 1);
}
cplan.n_threads = n_threads;
cplan.work_size = work_size;
cplan.work_data = NULL;
return cplan;
}
void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
{
GGML_ASSERT(cplan);
GGML_ASSERT(cplan->n_threads > 0);
if (cplan->work_size > 0) {
GGML_ASSERT(cplan->work_data);
}
for (int i = 0; i < cgraph->n_nodes; ++i) {
if (cgraph->nodes[i]->op != GGML_OP_NONE) {
GGML_ASSERT(cplan->n_tasks[i] > 0);
}
}
}
const int n_threads = cplan->n_threads;
struct ggml_compute_state_shared state_shared = {
/*.cgraph =*/ cgraph,
+/*.cgraph_plan =*/ cplan,
/*.perf_node_start_cycles =*/ 0,
/*.perf_node_start_time_us =*/ 0,
/*.n_threads =*/ n_threads,
@@ -16082,336 +16436,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
};
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
// initialize tasks + work buffer
{
size_t work_size = 0;
// thread scheduling for the different operations
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * node = cgraph->nodes[i];
switch (node->op) {
case GGML_OP_CPY:
case GGML_OP_DUP:
{
node->n_tasks = n_threads;
size_t cur = 0;
if (ggml_is_quantized(node->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_ADD:
case GGML_OP_ADD1:
{
node->n_tasks = n_threads;
size_t cur = 0;
if (ggml_is_quantized(node->src0->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_ACC:
{
node->n_tasks = n_threads;
size_t cur = 0;
if (ggml_is_quantized(node->src0->type)) {
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_threads;
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_SUB:
case GGML_OP_DIV:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_LOG:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_ABS:
case GGML_OP_SGN:
case GGML_OP_NEG:
case GGML_OP_STEP:
case GGML_OP_TANH:
case GGML_OP_ELU:
case GGML_OP_RELU:
{
node->n_tasks = 1;
} break;
case GGML_OP_MUL:
case GGML_OP_GELU:
case GGML_OP_GELU_QUICK:
case GGML_OP_SILU:
case GGML_OP_SILU_BACK:
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
{
node->n_tasks = n_threads;
} break;
case GGML_OP_MUL_MAT:
case GGML_OP_OUT_PROD:
{
node->n_tasks = n_threads;
// TODO: use different scheduling for different matrix sizes
//const int nr0 = ggml_nrows(node->src0);
//const int nr1 = ggml_nrows(node->src1);
//node->n_tasks = MIN(n_threads, MAX(1, nr0/128));
//printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
size_t cur = 0;
const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
#if defined(GGML_USE_CUBLAS)
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
node->n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
}
else
#elif defined(GGML_USE_CLBLAST)
if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
node->n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
}
else
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
node->n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
if (node->src0->type != GGML_TYPE_F32) {
// here we need memory just for single 2D matrix from src0
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
}
} else
#endif
if (node->src1->type != vec_dot_type) {
cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
} else {
cur = 0;
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_SCALE:
{
node->n_tasks = 1;
} break;
case GGML_OP_SET:
case GGML_OP_CONT:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_GET_ROWS:
case GGML_OP_GET_ROWS_BACK:
case GGML_OP_DIAG:
case GGML_OP_DIAG_MASK_ZERO:
{
node->n_tasks = 1;
} break;
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
{
node->n_tasks = n_threads;
} break;
case GGML_OP_ALIBI:
{
node->n_tasks = 1; //TODO
} break;
case GGML_OP_CLAMP:
{
node->n_tasks = 1; //TODO
} break;
case GGML_OP_CONV_1D:
{
node->n_tasks = n_threads;
GGML_ASSERT(node->src0->ne[3] == 1);
GGML_ASSERT(node->src1->ne[2] == 1);
GGML_ASSERT(node->src1->ne[3] == 1);
size_t cur = 0;
const int nk = node->src0->ne[0];
if (node->src0->type == GGML_TYPE_F16 &&
node->src1->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(
nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
);
} else if (node->src0->type == GGML_TYPE_F32 &&
node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*(
nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
);
} else {
GGML_ASSERT(false);
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_CONV_2D:
{
node->n_tasks = n_threads;
GGML_ASSERT(node->src1->ne[3] == 1);
const int64_t ne00 = node->src0->ne[0]; // W
const int64_t ne01 = node->src0->ne[1]; // H
const int64_t ne02 = node->src0->ne[2]; // C
const int64_t ne03 = node->src0->ne[3]; // N
const int64_t ne10 = node->src1->ne[0]; // W
const int64_t ne11 = node->src1->ne[1]; // H
const int64_t ne12 = node->src1->ne[2]; // C
const int64_t nk = ne00*ne01;
UNUSED(ne02);
UNUSED(ne03);
UNUSED(nk);
size_t cur = 0;
if (node->src0->type == GGML_TYPE_F16 &&
node->src1->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
} else if (node->src0->type == GGML_TYPE_F32 &&
node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)* (ne10*ne11*ne12);
} else {
GGML_ASSERT(false);
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_ATTN:
{
node->n_tasks = n_threads;
size_t cur = 0;
const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
if (node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
}
if (node->src1->type == GGML_TYPE_F16) {
cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_FF:
{
node->n_tasks = n_threads;
size_t cur = 0;
if (node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
}
if (node->src1->type == GGML_TYPE_F16) {
cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
node->n_tasks = n_threads;
size_t cur = 0;
const int64_t D = node->src0->ne[0];
const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
if (node->src1->type == GGML_TYPE_F32) {
cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
}
if (node->src1->type == GGML_TYPE_F16) {
cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
}
work_size = MAX(work_size, cur);
} break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:
case GGML_OP_MAP_UNARY:
case GGML_OP_MAP_BINARY:
case GGML_OP_MAP_CUSTOM1:
case GGML_OP_MAP_CUSTOM2:
case GGML_OP_MAP_CUSTOM3:
{
node->n_tasks = 1;
} break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
node->n_tasks = n_threads;
size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
work_size = MAX(work_size, cur);
} break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
{
node->n_tasks = n_threads;
size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
work_size = MAX(work_size, cur);
} break;
case GGML_OP_NONE:
{
node->n_tasks = 1;
} break;
case GGML_OP_COUNT:
{
GGML_ASSERT(false);
} break;
}
}
if (cgraph->work != NULL && work_size > cgraph->work_size) {
GGML_ASSERT(false); // TODO: better handling
}
if (work_size > 0 && cgraph->work == NULL) {
cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
}
}
// create thread pool
if (n_threads > 1) {
for (int j = 1; j < n_threads; ++j) {
@@ -16473,6 +16497,17 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
}
}
+void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
+struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
+struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
+GGML_ASSERT(buf);
+cplan.work_data = buf->data;
+ggml_graph_compute(cgraph, &cplan);
+}
struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {
for (int i = 0; i < cgraph->n_leafs; i++) {
struct ggml_tensor * leaf = cgraph->leafs[i];
@@ -16511,14 +16546,13 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
const int64_t * ne = tensor->ne;
const size_t * nb = tensor->nb;
-fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
+fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
arg,
ggml_type_name(tensor->type),
ggml_op_name (tensor->op),
tensor->n_dims,
ne[0], ne[1], ne[2], ne[3],
nb[0], nb[1], nb[2], nb[3],
-tensor->n_tasks,
tensor->data,
tensor->name);
}
@@ -17254,9 +17288,6 @@ static enum ggml_opt_result ggml_opt_adam(
struct ggml_cgraph * gb) {
GGML_ASSERT(ggml_is_scalar(f));
-gf->n_threads = params.n_threads;
-gb->n_threads = params.n_threads;
// these will store the parameters we want to optimize
struct ggml_tensor * ps[GGML_MAX_PARAMS];
@@ -17303,7 +17334,8 @@ static enum ggml_opt_result ggml_opt_adam(
// compute the function value
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
-ggml_graph_compute(ctx, gb);
+ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
opt->adam.fx_best = opt->adam.fx_prev;
@@ -17383,7 +17415,8 @@ static enum ggml_opt_result ggml_opt_adam(
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
-ggml_graph_compute(ctx, gb);
+ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
const float fx = ggml_get_f32_1d(f, 0);
@@ -17505,7 +17538,8 @@ static enum ggml_opt_result linesearch_backtracking(
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
-ggml_graph_compute(ctx, gb);
+ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);
ggml_opt_get_grad(np, ps, g);
@@ -17573,9 +17607,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
}
}
-gf->n_threads = params.n_threads;
-gb->n_threads = params.n_threads;
const int m = params.lbfgs.m;
// these will store the parameters we want to optimize
@@ -17627,7 +17658,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
-ggml_graph_compute(ctx, gb);
+ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
ggml_opt_get_grad(np, ps, g);

ggml.h (36 changes)

@@ -65,7 +65,7 @@
// ggml_set_f32(a, 3.0f);
// ggml_set_f32(b, 4.0f);
//
-// ggml_graph_compute(ctx0, &gf);
+// ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
//
// printf("f = %f\n", ggml_get_f32_1d(f, 0));
//
@@ -418,9 +418,6 @@ extern "C" {
struct ggml_tensor * src1;
struct ggml_tensor * opt[GGML_MAX_OPT];
-// thread scheduling
-int n_tasks;
// performance
int perf_runs;
int64_t perf_cycles;
@@ -432,19 +429,27 @@ extern "C" {
void * extra; // extra things e.g. for ggml-cuda.cu
-char padding[4];
+char padding[8];
};
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
+// the compute plan that needs to be prepared for ggml_graph_compute()
+// since https://github.com/ggerganov/ggml/issues/287
+struct ggml_cplan {
+size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()`
+uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
+int n_threads;
+// the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
+int n_tasks[GGML_MAX_NODES];
+};
// computation graph
struct ggml_cgraph {
int n_nodes;
int n_leafs;
-int n_threads;
-size_t work_size;
-struct ggml_tensor * work;
struct ggml_tensor * nodes[GGML_MAX_NODES];
struct ggml_tensor * grads[GGML_MAX_NODES];
@@ -1290,15 +1295,22 @@ extern "C" {
GGML_API void ggml_set_param(
struct ggml_context * ctx,
struct ggml_tensor * tensor);
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
-GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
-GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
+// ggml_graph_plan() has to be called before ggml_graph_compute()
+// when plan.work_size > 0, caller must allocate memory for plan.work_data
+GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
+GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
+// same as ggml_graph_compute() but the work data is allocated as a part of the context
+// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
+GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
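The header also keeps a one-call path for code that owns a `ggml_context` with spare memory, per the note above. A small sketch, assuming `ctx` was created with enough `mem_size` for both the graph tensors and the work data; `build_graph()` is a hypothetical placeholder for the usual tensor/op construction:

```cpp
// Assumed: build_graph() stands in for ggml_new_tensor_*/ggml_* op calls on `ctx`.
struct ggml_tensor * f  = build_graph(ctx);
struct ggml_cgraph   gf = ggml_build_forward(f);

// plans internally and allocates plan.work_data out of `ctx`
ggml_graph_compute_with_ctx(ctx, &gf, /*n_threads*/ 4);
```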


@@ -79,6 +79,25 @@ void llama_nop(struct ggml_tensor * tensor) { // don't offload by default
(void) tensor;
}
+//
+// ggml helpers
+//
+static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+if (plan.work_size > 0) {
+buf.resize(plan.work_size);
+plan.work_data = buf.data();
+}
+ggml_graph_compute(graph, &plan);
+}
//
// memory sizes
//
static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0()
{
static std::map<e_model, size_t> k_sizes = {
@@ -321,6 +340,9 @@ struct llama_context {
// input embedding (1-dimensional array: [n_embd])
std::vector<float> embedding;
+// reusable buffer for `struct ggml_graph_plan.work_data`
+std::vector<uint8_t> work_buffer;
// memory buffers used to evaluate the model
// TODO: move in llama_state
llama_ctx_buffer buf_compute;
@@ -758,7 +780,6 @@ struct llama_model_loader {
};
//
// kv cache
//
@@ -1265,7 +1286,7 @@ static bool llama_eval_internal(
const float * embd,
const int n_tokens,
const int n_past,
-const int n_threads,
+int n_threads,
const char * cgraph_fname) {
LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
@@ -1306,10 +1327,11 @@ static bool llama_eval_internal(
struct ggml_context * ctx0 = ggml_init(params);
+ggml_cgraph gf = {};
// for big prompts, if BLAS is enabled, it is better to use only one thread
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
-ggml_cgraph gf = {};
-gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
+n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
@@ -1593,6 +1615,7 @@ static bool llama_eval_internal(
#ifdef GGML_USE_METAL
if (lctx.ctx_metal && N == 1) {
+ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
ggml_metal_graph_compute(lctx.ctx_metal, &gf);
ggml_metal_get_tensor (lctx.ctx_metal, cur);
} else {
@@ -1612,10 +1635,10 @@ static bool llama_eval_internal(
ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v);
}
-ggml_graph_compute(ctx0, &gf);
+ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
}
#else
-ggml_graph_compute(ctx0, &gf);
+ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
#endif
if (cgraph_fname) {
@@ -2575,8 +2598,8 @@ void llama_free_model(struct llama_model * model) {
}
struct llama_context * llama_new_context_with_model(
struct llama_model * model,
struct llama_context_params params) {
if (!model) {
return nullptr;
@@ -2645,7 +2668,7 @@ struct llama_context * llama_new_context_with_model(
#ifdef GGML_USE_METAL
if (params.n_gpu_layers > 0) {
// this allocates all Metal resources and memory buffers
-ctx->ctx_metal = ggml_metal_init();
+ctx->ctx_metal = ggml_metal_init(1);
void * data_ptr = NULL;
size_t data_size = 0;
@@ -2802,6 +2825,9 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
// read tensors and apply
bool warned = false;
int n_tensors = 0;
+std::vector<uint8_t> work_buffer;
while (true) {
int32_t n_dims;
int32_t length;
@@ -2966,8 +2992,8 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const
}
struct ggml_cgraph gf = ggml_build_forward(r);
-gf.n_threads = n_threads;
-ggml_graph_compute(lora_ctx, &gf);
+ggml_graph_compute_helper(work_buffer, &gf, n_threads);
// we won't need these tensors again, reset the context to save memory
ggml_free(lora_ctx);
@@ -3120,7 +3146,6 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
ggml_cgraph gf{};
-gf.n_threads = 1;
ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
kout3d->data = out;
@@ -3140,7 +3165,7 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
-ggml_graph_compute(cpy_ctx, &gf);
+ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
ggml_free(cpy_ctx);
}
@@ -3226,7 +3251,6 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true });
ggml_cgraph gf{};
-gf.n_threads = 1;
ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
kin3d->data = (void *) inp;
@@ -3246,7 +3270,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
-ggml_graph_compute(cpy_ctx, &gf);
+ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1);
ggml_free(cpy_ctx);
}
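A note on the llama.cpp changes: the plan is recomputed on every evaluation, but the work buffer itself lives in the context (`work_buffer` above), so `resize()` stops allocating once the buffer has grown to the largest `work_size` seen. A sketch of the reuse pattern; the driver loop and `build_graph()` are illustrative, not actual llama.cpp code:

```cpp
std::vector<uint8_t> work_buffer; // persists across evaluations, like the member above

for (int i = 0; i < n_batches; ++i) {          // hypothetical driver loop
    ggml_cgraph gf = build_graph(/*...*/);     // hypothetical graph builder
    // same helper as in the diff above: plan, (re)size the buffer, compute
    ggml_graph_compute_helper(work_buffer, &gf, n_threads);
}
```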


@@ -10,5 +10,5 @@ llama_add_test(test-quantize-fns.cpp)
llama_add_test(test-quantize-perf.cpp)
llama_add_test(test-sampling.cpp)
llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
-# llama_add_test(test-grad0.c) # SLOW
+llama_add_test(test-grad0.c) # SLOW
# llama_add_test(test-opt.c) # SLOW


@@ -10,6 +10,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
+#pragma GCC diagnostic ignored "-Wdouble-promotion"
#define MAX_NARGS 3
#undef MIN
@@ -49,7 +51,7 @@ float frand(void) {
int irand(int n) {
if (n == 0) return 0;
-else return rand()%n;
+return rand()%n;
}
void get_random_dims(int64_t * dims, int ndims) {
@@ -159,12 +161,14 @@ struct ggml_tensor * get_random_tensor_int(
float get_element(const struct ggml_tensor * t, int idx) {
if (t->type == GGML_TYPE_F32) {
return ((float *)t->data)[idx];
-} else if (t->type == GGML_TYPE_I32) {
-return ((int32_t *)t->data)[idx];
-} else {
-assert(false);
-return INFINITY;
}
+if (t->type == GGML_TYPE_I32) {
+return ((int32_t *)t->data)[idx];
+}
+assert(false);
+return INFINITY;
}
void set_element(struct ggml_tensor * t, int idx, float value) {
@@ -215,15 +219,14 @@ bool check_gradient(
}
struct ggml_cgraph gf = ggml_build_forward (f);
-gf.n_threads = n_threads;
struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
-gb.n_threads = n_threads;
-ggml_graph_compute(ctx0, &gf);
+ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
ggml_graph_reset (&gf);
ggml_set_f32 (f->grad, 1.0f);
-ggml_graph_compute(ctx0, &gb);
+ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
// ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
// ggml_graph_dump_dot(&gb, &gf, "test-grad0-backward.dot");
@@ -236,15 +239,16 @@ bool check_gradient(
const float xm = x0 - eps;
const float xp = x0 + eps;
set_element(x[i], k, xp);
-ggml_graph_compute(ctx0, &gf);
+ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
const float f0 = ggml_get_f32_1d(f, 0);
set_element(x[i], k, xm);
-ggml_graph_compute(ctx0, &gf);
+ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
const float f1 = ggml_get_f32_1d(f, 0);
const float g0 = (f0 - f1)/(2.0f*eps);
set_element(x[i], k, x0);
@@ -252,12 +256,13 @@ bool check_gradient(
// compute gradient using backward graph
ggml_graph_reset (&gf);
ggml_set_f32 (f->grad, 1.0f);
-ggml_graph_compute(ctx0, &gb);
+ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
const float g1 = get_element(x[i]->grad, k);
const float error_abs = fabsf(g0 - g1);
-const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabs(g0) : 0;
+const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
if (error_abs > max_error_abs || error_rel > max_error_rel) {
printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",


@@ -7,6 +7,7 @@
#define MAX_NARGS 2
+#pragma GCC diagnostic ignored "-Wdouble-promotion"
//
// logging
@@ -33,7 +34,7 @@
#define GGML_PRINT(...) printf(__VA_ARGS__)
-float frand() {
+float frand(void) {
return (float)rand()/(float)RAND_MAX;
}
@@ -114,7 +115,7 @@ void set_element(struct ggml_tensor * t, int idx, float value) {
((float *)t->data)[idx] = value;
}
-int main(int argc, const char ** argv) {
+int main(void) {
struct ggml_init_params params = {
.mem_size = 1024*1024*1024,
.mem_buffer = NULL,
@@ -137,10 +138,11 @@ int main(int argc, const char ** argv) {
struct ggml_tensor * d = ggml_sub(ctx, c, ab);
struct ggml_tensor * e = ggml_sum(ctx, ggml_sqr(ctx, d));
struct ggml_cgraph ge = ggml_build_forward(e);
-ggml_graph_reset (&ge);
-ggml_graph_compute(ctx, &ge);
+ggml_graph_reset(&ge);
+ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
const float fe = ggml_get_f32_1d(e, 0);
printf("%s: e = %.4f\n", __func__, fe);
@@ -148,8 +150,10 @@ int main(int argc, const char ** argv) {
ggml_opt(ctx, opt_params, e);
-ggml_graph_reset (&ge);
-ggml_graph_compute(ctx, &ge);
+ggml_graph_reset(&ge);
+ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
const float fe_opt = ggml_get_f32_1d(e, 0);
printf("%s: original e = %.4f\n", __func__, fe);
printf("%s: optimized e = %.4f\n", __func__, fe_opt);