diff --git a/ggml.c b/ggml.c index e8384ed..c0354f4 100644 --- a/ggml.c +++ b/ggml.c @@ -1136,6 +1136,7 @@ struct ggml_state { // global state struct ggml_state g_state; +atomic_bool g_state_barrier = 0; //////////////////////////////////////////////////////////////////////////////// @@ -1265,6 +1266,17 @@ int ggml_up64(int n) { //////////////////////////////////////////////////////////////////////////////// struct ggml_context * ggml_init(struct ggml_init_params params) { + // make this function thread safe + { + int processing = atomic_fetch_add(&g_state_barrier, 1); + while (processing > 0) { + // wait for other threads to finish + atomic_fetch_sub(&g_state_barrier, 1); + sched_yield(); + processing = atomic_fetch_add(&g_state_barrier, 1); + } + } + static bool is_first_call = true; if (is_first_call) { const uint64_t t_start = ggml_time_us(); UNUSED(t_start); @@ -1308,6 +1320,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { if (ctx == NULL) { GGML_PRINT_DEBUG("%s: no unused context found\n", __func__); + + atomic_fetch_sub(&g_state_barrier, 1); + return NULL; } @@ -1322,10 +1337,25 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { ggml_assert_aligned(ctx->mem_buffer); + GGML_PRINT_DEBUG("%s: context initialized\n", __func__); + + atomic_fetch_sub(&g_state_barrier, 1); + return ctx; } void ggml_free(struct ggml_context * ctx) { + // make this function thread safe + { + int processing = atomic_fetch_add(&g_state_barrier, 1); + while (processing > 0) { + // wait for other threads to finish + atomic_fetch_sub(&g_state_barrier, 1); + sched_yield(); + processing = atomic_fetch_add(&g_state_barrier, 1); + } + } + for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { if (&g_state.contexts[i].context == ctx) { g_state.contexts[i].used = false; @@ -1337,11 +1367,15 @@ void ggml_free(struct ggml_context * ctx) { free(ctx->mem_buffer); } + atomic_fetch_sub(&g_state_barrier, 1); + return; } } GGML_PRINT_DEBUG("%s: context not found\n", __func__); + + atomic_fetch_sub(&g_state_barrier, 1); } size_t ggml_used_mem(const struct ggml_context * ctx) {