From 0bc2cdfc875fa7877d8e01c8bb17066f99c08f21 Mon Sep 17 00:00:00 2001
From: Johannes Gäßler
Date: Sat, 1 Jul 2023 21:49:44 +0200
Subject: [PATCH] Better CUDA synchronization logic (#2057)

---
 ggml-cuda.cu | 63 ++++++++++++++++++++++++++++++++++++++--------------
 ggml-cuda.h  |  4 ----
 2 files changed, 46 insertions(+), 21 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 4e0d3dbde..50df20edd 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -214,6 +214,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
 #endif
 
+struct ggml_tensor_extra_gpu {
+    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+    cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
+};
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1970,7 +1975,6 @@ inline void ggml_cuda_op_add(
     } else {
         GGML_ASSERT(false);
     }
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2002,7 +2006,6 @@ inline void ggml_cuda_op_mul(
 
         // compute
         mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
-        CUDA_CHECK(cudaGetLastError());
     }
 
     (void) dst;
@@ -2023,7 +2026,6 @@ inline void ggml_cuda_op_silu(
 
     // compute
     silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2046,7 +2048,6 @@ inline void ggml_cuda_op_rms_norm(
 
     // compute
     rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2125,7 +2126,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
             GGML_ASSERT(false);
             break;
     }
-    CUDA_CHECK(cudaGetLastError());
 
 #ifdef GGML_CUDA_DMMV_F16
     if (src1_convert_f16) {
@@ -2202,7 +2202,6 @@ inline void ggml_cuda_op_rope(
 
     // compute
     rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
     (void) src0_ddq_i;
@@ -2226,7 +2225,6 @@ inline void ggml_cuda_op_diag_mask_inf(
 
     // compute
     diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
     (void) src0_ddq_i;
@@ -2248,7 +2246,6 @@ inline void ggml_cuda_op_soft_max(
 
     // compute
     soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2344,10 +2341,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
     size_t  dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
 
-    // if multiple GPUs are used they need to wait for the main GPU to finish
+    // if multiple devices are used they need to wait for the main device
+    // here an event is recorded that signifies that the main device has finished calculating the input data
     if (split && g_device_count > 1) {
         CUDA_CHECK(cudaSetDevice(g_main_device));
-        CUDA_CHECK(cudaDeviceSynchronize());
+        CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
     }
 
     for (int id = 0; id < g_device_count; ++id) {
@@ -2373,6 +2371,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         int64_t row_diff = row_high - row_low;
 
         cudaSetDevice(id);
+        cudaStream_t cudaStream_main = g_cudaStreams_main[id];
+
+        // wait for main GPU data if necessary
+        if (split && id != g_main_device) {
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
+        }
 
         if (src0_on_device && src0_is_contiguous) {
             if (src0_is_f32) {
@@ -2448,8 +2452,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 }
                 const int64_t i11 = i13*ne12 + i12;
 
-                cudaStream_t cudaStream_main = g_cudaStreams_main[id];
-
                 // for split tensors the data begins at i0 == i0_offset_low
                 char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
                 float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
@@ -2509,6 +2511,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 
                 // do the computation
                 op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
+                CUDA_CHECK(cudaGetLastError());
 
                 // copy dst to host or other device if necessary
                 if (!dst_on_device) {
@@ -2538,6 +2541,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                         CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
                     }
                 }
+
+                // signify to main device that other device is done
+                if (split && g_device_count > 1 && id != g_main_device) {
+                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
+                }
             }
         }
     }
@@ -2549,7 +2557,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         }
 
         CUDA_CHECK(cudaSetDevice(id));
-        CUDA_CHECK(cudaDeviceSynchronize());
 
         if (src0_asq[id] > 0) {
             ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
@@ -2564,6 +2571,21 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
         }
     }
+
+    // main device waits for all other devices to be finished
+    if (split && g_device_count > 1) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        for (int id = 0; id < g_device_count; ++id) {
+            if (id != g_main_device) {
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
+            }
+        }
+    }
+
+    if (dst->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
 }
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2803,6 +2825,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
             cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
 
         extra->data_device[id] = buf;
+
+        if (backend == GGML_BACKEND_GPU_SPLIT) {
+            CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
+        }
     }
 
     tensor->extra = extra;
@@ -2816,12 +2842,15 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
     for (int id = 0; id < g_device_count; ++id) {
-        if (extra->data_device[id] == nullptr) {
-            continue;
+        if (extra->data_device[id] != nullptr) {
+            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(cudaFree(extra->data_device[id]));
         }
 
-        CUDA_CHECK(cudaSetDevice(id));
-        CUDA_CHECK(cudaFree(extra->data_device[id]));
+        if (extra->events[id] != nullptr) {
+            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(cudaEventDestroy(extra->events[id]));
+        }
     }
 
     delete extra;
diff --git a/ggml-cuda.h b/ggml-cuda.h
index 7a65a3558..3c1e8deb6 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -8,10 +8,6 @@ extern "C" {
 
 #define GGML_CUDA_MAX_DEVICES       16
 
-struct ggml_tensor_extra_gpu {
-    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
-};
-
 void   ggml_init_cublas(void);
 void   ggml_cuda_set_tensor_split(const float * tensor_split);
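
The pattern this patch introduces can be shown in isolation: the main device records an event on its stream once the shared input is ready, the other devices make their streams wait on that event, each of them records its own completion event, and the main stream waits on those before anything consumes the result; the host itself only blocks when the destination lives on the CPU. The following standalone sketch demonstrates that record/wait choreography. It is not ggml code: the names (MAX_DEVICES, streams, events, main_device) and the structure are illustrative assumptions, and only the CUDA runtime calls (cudaEventCreateWithFlags, cudaEventRecord, cudaStreamWaitEvent, cudaEventDestroy) mirror the ones the patch uses.

// sync_sketch.cu: minimal event-based multi-GPU ordering, analogous to
// the scheme introduced in ggml_cuda_op. All names are illustrative.
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define MAX_DEVICES 16

#define CUDA_CHECK(call)                                              \
    do {                                                              \
        cudaError_t err_ = (call);                                    \
        if (err_ != cudaSuccess) {                                    \
            fprintf(stderr, "CUDA error %s at %s:%d\n",               \
                cudaGetErrorString(err_), __FILE__, __LINE__);        \
            exit(1);                                                  \
        }                                                             \
    } while (0)

int main() {
    int n_devices = 0;
    CUDA_CHECK(cudaGetDeviceCount(&n_devices));
    if (n_devices > MAX_DEVICES) {
        n_devices = MAX_DEVICES;
    }
    const int main_device = 0;

    cudaStream_t streams[MAX_DEVICES];
    cudaEvent_t  events[MAX_DEVICES];
    for (int id = 0; id < n_devices; ++id) {
        CUDA_CHECK(cudaSetDevice(id));
        CUDA_CHECK(cudaStreamCreate(&streams[id]));
        // timing is disabled because the events are used purely for
        // ordering, matching the flag used in ggml_cuda_transform_tensor
        CUDA_CHECK(cudaEventCreateWithFlags(&events[id], cudaEventDisableTiming));
    }

    // ... the main device would queue the work that produces the
    // shared input data on streams[main_device] here ...

    // 1. record an event instead of synchronizing the whole device
    CUDA_CHECK(cudaSetDevice(main_device));
    CUDA_CHECK(cudaEventRecord(events[main_device], streams[main_device]));

    for (int id = 0; id < n_devices; ++id) {
        CUDA_CHECK(cudaSetDevice(id));
        // 2. secondary devices order their streams after the main
        //    device's event; this wait does not block the host
        if (id != main_device) {
            CUDA_CHECK(cudaStreamWaitEvent(streams[id], events[main_device], 0));
        }

        // ... each device would queue its slice of the computation
        // on streams[id] here ...

        // 3. secondary devices record that their slice is queued
        if (id != main_device) {
            CUDA_CHECK(cudaEventRecord(events[id], streams[id]));
        }
    }

    // 4. the main stream waits for every other device, and the host
    //    blocks only here, when the result is actually needed on the CPU
    CUDA_CHECK(cudaSetDevice(main_device));
    for (int id = 0; id < n_devices; ++id) {
        if (id != main_device) {
            CUDA_CHECK(cudaStreamWaitEvent(streams[main_device], events[id], 0));
        }
    }
    CUDA_CHECK(cudaDeviceSynchronize());

    for (int id = 0; id < n_devices; ++id) {
        CUDA_CHECK(cudaSetDevice(id));
        CUDA_CHECK(cudaEventDestroy(events[id]));
        CUDA_CHECK(cudaStreamDestroy(streams[id]));
    }
    return 0;
}

The key property this relies on is that cudaStreamWaitEvent may wait on an event recorded on a different device, so inter-GPU ordering is expressed entirely on the device timeline; that is what lets the patch drop the per-iteration cudaDeviceSynchronize() calls without losing correctness.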