diff --git a/ggml-metal.m b/ggml-metal.m index 3f098d396..b47a98e21 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -7,6 +7,11 @@ #import #import +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + #ifdef GGML_METAL_NDEBUG #define metal_printf(...) #else @@ -15,6 +20,8 @@ #define UNUSED(x) (void)(x) +#define GGML_MAX_CONCUR (2*GGML_MAX_NODES) + struct ggml_metal_buffer { const char * name; @@ -36,7 +43,7 @@ struct ggml_metal_context { int n_buffers; struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; - int concur_list[GGML_MAX_NODES]; + int concur_list[GGML_MAX_CONCUR]; int concur_list_len; // custom kernels @@ -370,15 +377,15 @@ void ggml_metal_graph_find_concurrency( struct ggml_metal_context * ctx, struct ggml_cgraph * gf) { int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time - int nodes_unused[GGML_MAX_NODES]; + int nodes_unused[GGML_MAX_CONCUR]; - for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;} - for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;} + for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; } + for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; } ctx->concur_list_len = 0; - int n_left = gf->n_nodes; - int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list - int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos + int n_left = gf->n_nodes; + int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list + int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos while (n_left > 0) { // number of nodes at a layer (that can be issued concurrently) @@ -386,28 +393,40 @@ void ggml_metal_graph_find_concurrency( for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) { if (nodes_unused[i]) { // if the requirements for gf->nodes[i] are satisfied - int exe_flag=1; + int exe_flag = 1; + // scan all srcs for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) { struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind]; if (src_cur) { // if is leaf nodes it's satisfied. - if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;} + // TODO: ggml_is_leaf() + if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) { + continue; + } // otherwise this src should be the output from previous nodes. int is_found = 0; + // scan 2*search_depth back because we inserted barrier. - for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) { - if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;} + //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) { + for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) { + if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) { + is_found = 1; + break; + } + } + if (is_found == 0) { + exe_flag = 0; + break; } - if (is_found == 0) {exe_flag = 0; break;} } } if (exe_flag) { // check if nodes[i]'s data will be overwritten by a node before nodes[i]. // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3] int64_t data_start = (int64_t) gf->nodes[i]->data; - int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]); + int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]); for (int j = n_start; j < i; j++) { if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \ && gf->nodes[j]->op != GGML_OP_VIEW \ @@ -416,9 +435,9 @@ void ggml_metal_graph_find_concurrency( if (((int64_t)gf->nodes[j]->data) >= data_start + length || \ ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) { continue; - } else { - exe_flag = 0; } + + exe_flag = 0; } } } @@ -435,11 +454,13 @@ void ggml_metal_graph_find_concurrency( ctx->concur_list[level_pos + concurrency] = -1; ctx->concur_list_len++; // jump all sorted nodes at nodes_bak - while (!nodes_unused[n_start]) {n_start++;} + while (!nodes_unused[n_start]) { + n_start++; + } level_pos += concurrency + 1; } - if (ctx->concur_list_len > GGML_MAX_NODES) { + if (ctx->concur_list_len > GGML_MAX_CONCUR) { fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__); } } @@ -453,7 +474,7 @@ void ggml_metal_graph_compute( // else fallback to serial dispatch MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; - const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES; + const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR; const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes; edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;