diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 50344ae87..9e9eac487 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -221,10 +221,13 @@ typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v); typedef void (*cpy_kernel_t)(const char * cx, char * cdst); typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); -typedef void (*ggml_cuda_op_t)( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i, - float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main); +typedef void (*ggml_cuda_op_mul_mat_t)( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream); +typedef void (*ggml_cuda_op_flatten_t)( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream); // QK = number of values after dequantization // QR = QK / number of values before dequantization @@ -405,11 +408,29 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); #endif +#define MUL_MAT_SRC1_COL_STRIDE 128 + +#define MAX_STREAMS 8 +static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr }; + struct ggml_tensor_extra_gpu { void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors - cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs + cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs }; +// this is faster on Windows +// probably because the Windows CUDA libraries forget to make this check before invoking the drivers +inline cudaError_t ggml_cuda_set_device(const int device) { + int current_device; + CUDA_CHECK(cudaGetDevice(¤t_device)); + + if (device == current_device) { + return cudaSuccess; + } + + return cudaSetDevice(device); +} + static int g_device_count = -1; static int g_main_device = 0; static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES]; @@ -422,8 +443,6 @@ static size_t g_scratch_offset = 0; static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; -static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr }; - static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -5139,25 +5158,27 @@ void ggml_init_cublas() { GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0; fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count); - for (int id = 0; id < g_device_count; ++id) { + for (int64_t id = 0; id < g_device_count; ++id) { cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); - fprintf(stderr, " Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor); + fprintf(stderr, " Device %ld: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor); g_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; g_compute_capabilities[id] = 100*prop.major + 10*prop.minor; } - for (int id = 0; id < g_device_count; ++id) { + for (int64_t id = 0; id < g_device_count; ++id) { g_tensor_split[id] /= total_vram; } - for (int id = 0; id < g_device_count; ++id) { - CUDA_CHECK(cudaSetDevice(id)); + for (int64_t id = 0; id < g_device_count; ++id) { + CUDA_CHECK(ggml_cuda_set_device(id)); - // create main stream - CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking)); + // create cuda streams + for (int64_t is = 0; is < MAX_STREAMS; ++is) { + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[id][is], cudaStreamNonBlocking)); + } // create cublas handle CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id])); @@ -5265,225 +5286,169 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( } inline void ggml_cuda_op_add( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, - float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main){ + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr); - GGML_ASSERT(src1_ddf_i != nullptr); - GGML_ASSERT(dst_ddf_i != nullptr); - - const int64_t ne00 = src0->ne[0]; - const int64_t i01_diff = i01_high - i01_low; + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); const int64_t ne10 = src1->ne[0]; const int64_t ne11 = src1->ne[1]; - // compute if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main); + add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main); + add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream); } else { GGML_ASSERT(false); } (void) src1; (void) dst; - (void) src0_ddq_i; - (void) i02; - (void) i1; } inline void ggml_cuda_op_mul( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, - float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main){ + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0_ddf_i != nullptr); - GGML_ASSERT(src1_ddf_i != nullptr); - GGML_ASSERT(dst_ddf_i != nullptr); - - const int64_t ne00 = src0->ne[0]; - const int64_t i01_diff = i01_high - i01_low; + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); const int64_t ne10 = src1->ne[0]; const int64_t ne11 = src1->ne[1]; - mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main); + mul_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream); (void) dst; - (void) src0_ddq_i; - (void) i02; - (void) i1; } inline void ggml_cuda_op_gelu( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, - float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main){ + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0_ddf_i != nullptr); - GGML_ASSERT(dst_ddf_i != nullptr); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); - const int64_t ne00 = src0->ne[0]; - const int64_t i01_diff = i01_high - i01_low; - - // compute - gelu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main); + gelu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); (void) src1; (void) dst; - (void) src0_ddq_i; - (void) src1_ddf_i; - (void) i02; - (void) i1; + (void) src1_dd; } inline void ggml_cuda_op_silu( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, - float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main){ + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0_ddf_i != nullptr); - GGML_ASSERT(dst_ddf_i != nullptr); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); - const int64_t ne00 = src0->ne[0]; - const int64_t i01_diff = i01_high - i01_low; - - // compute - silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main); + silu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); (void) src1; (void) dst; - (void) src0_ddq_i; - (void) src1_ddf_i; - (void) i02; - (void) i1; + (void) src1_dd; } inline void ggml_cuda_op_norm( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, - float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main){ + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0_ddf_i != nullptr); - GGML_ASSERT(dst_ddf_i != nullptr); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); const int64_t ne00 = src0->ne[0]; - const int64_t i01_diff = i01_high - i01_low; + const int64_t nrows = ggml_nrows(src0); - // compute - norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main); + norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream); (void) src1; (void) dst; - (void) src0_ddq_i; - (void) src1_ddf_i; - (void) i02; - (void) i1; + (void) src1_dd; } inline void ggml_cuda_op_rms_norm( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, - float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main){ + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0_ddf_i != nullptr); - GGML_ASSERT(dst_ddf_i != nullptr); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); const int64_t ne00 = src0->ne[0]; - const int64_t i01_diff = i01_high - i01_low; + const int64_t nrows = ggml_nrows(src0); float eps; memcpy(&eps, dst->op_params, sizeof(float)); - // compute - rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main); + rms_norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, eps, main_stream); (void) src1; (void) dst; - (void) src0_ddq_i; - (void) src1_ddf_i; - (void) i02; - (void) i1; + (void) src1_dd; } inline void ggml_cuda_op_mul_mat_q( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, - float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main){ - - GGML_ASSERT(src0_ddq_i != nullptr); - GGML_ASSERT(src1_ddf_i != nullptr); - GGML_ASSERT(dst_ddf_i != nullptr); + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream) { const int64_t ne00 = src0->ne[0]; const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; GGML_ASSERT(ne10 % QK8_1 == 0); const int64_t ne0 = dst->ne[0]; - const int64_t i01_diff = i01_high - i01_low; + const int64_t row_diff = row_high - row_low; int id; CUDA_CHECK(cudaGetDevice(&id)); // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into - const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff; - - const int64_t padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ? - ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; - size_t as; - void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne11*sizeof(block_q8_1)/QK8_1, &as); - quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, cudaStream_main); + const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; switch (src0->type) { case GGML_TYPE_Q4_0: - ggml_mul_mat_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); break; case GGML_TYPE_Q4_1: - ggml_mul_mat_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); break; case GGML_TYPE_Q5_0: - ggml_mul_mat_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); break; case GGML_TYPE_Q5_1: - ggml_mul_mat_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); break; case GGML_TYPE_Q8_0: - ggml_mul_mat_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); break; case GGML_TYPE_Q2_K: - ggml_mul_mat_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); break; case GGML_TYPE_Q3_K: - ggml_mul_mat_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); break; case GGML_TYPE_Q4_K: - ggml_mul_mat_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); break; case GGML_TYPE_Q5_K: - ggml_mul_mat_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); break; case GGML_TYPE_Q6_K: - ggml_mul_mat_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); break; default: GGML_ASSERT(false); break; } - ggml_cuda_pool_free(src1_q8_1, as); - (void) src1; (void) dst; - (void) src0_ddf_i; - (void) i02; - (void) i1; + (void) src1_ddf_i; } static int64_t get_row_rounding(ggml_type type) { @@ -5517,168 +5482,144 @@ static int64_t get_row_rounding(ggml_type type) { } } -inline void ggml_cuda_op_mul_mat_vec( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, - float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main){ - - GGML_ASSERT(src0_ddq_i != nullptr); - GGML_ASSERT(src1_ddf_i != nullptr); - GGML_ASSERT(dst_ddf_i != nullptr); +inline void ggml_cuda_op_mul_mat_vec_q( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream) { const int64_t ne00 = src0->ne[0]; - const int64_t nrows = i01_high - i01_low; + const int64_t row_diff = row_high - row_low; -#ifdef GGML_CUDA_FORCE_DMMV - const bool use_mul_mat_vec_q = false; - (void) g_compute_capabilities[0]; -#else - int id; - CUDA_CHECK(cudaGetDevice(&id)); - - bool mul_mat_vec_q_implemented = - src0->type == GGML_TYPE_Q4_0 || - src0->type == GGML_TYPE_Q4_1 || - src0->type == GGML_TYPE_Q5_0 || - src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0; -#if QK_K == 256 - mul_mat_vec_q_implemented = mul_mat_vec_q_implemented || - src0->type == GGML_TYPE_Q2_K || - src0->type == GGML_TYPE_Q3_K || - src0->type == GGML_TYPE_Q4_K || - src0->type == GGML_TYPE_Q5_K || - src0->type == GGML_TYPE_Q6_K; -#endif // QK_K == 256 - - const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= MIN_CC_DP4A && mul_mat_vec_q_implemented; -#endif - - if (use_mul_mat_vec_q) { - const int64_t padded_row_size = ne00 % MATRIX_ROW_PADDING == 0 ? - ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; - size_t as; - void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as); - quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, 1, padded_row_size, cudaStream_main); - - switch (src0->type) { - case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - default: - GGML_ASSERT(false); - break; - } - - ggml_cuda_pool_free(src1_q8_1, as); - } else { - // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics -#ifdef GGML_CUDA_F16 - size_t ash; - dfloat * src1_dfloat = nullptr; // dfloat == half - - bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || - src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; - - if (src1_convert_f16) { - src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash); - ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00, - ne00, 1, sizeof(float), 0, 0, - ne00, 1, sizeof(half), 0, 0, cudaStream_main); - } -#else - dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion -#endif // GGML_CUDA_F16 - - switch (src0->type) { - case GGML_TYPE_Q4_0: - dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q4_1: - dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q5_0: - dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q5_1: - dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q8_0: - dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q2_K: - dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q3_K: - dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q4_K: - dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q5_K: - dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q6_K: - dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_F16: - convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - default: - GGML_ASSERT(false); - break; - } - -#ifdef GGML_CUDA_F16 - if (src1_convert_f16) { - ggml_cuda_pool_free(src1_dfloat, ash); - } -#endif // GGML_CUDA_F16 + switch (src0->type) { + case GGML_TYPE_Q4_0: + mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_K: + mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q6_K: + mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + default: + GGML_ASSERT(false); + break; } (void) src1; (void) dst; - (void) src0_ddf_i; - (void) i02; - (void) i1; + (void) src1_ddf_i; + (void) src1_ncols; + (void) src1_padded_row_size; +} + +inline void ggml_cuda_op_dequantize_mul_mat_vec( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics +#ifdef GGML_CUDA_F16 + size_t ash; + dfloat * src1_dfloat = nullptr; // dfloat == half + + bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || + src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || + src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; + + if (src1_convert_f16) { + src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash); + ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00, + ne00, 1, sizeof(float), 0, 0, + ne00, 1, sizeof(half), 0, 0, stream); + } +#else + const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion +#endif // GGML_CUDA_F16 + + switch (src0->type) { + case GGML_TYPE_Q4_0: + dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_1: + dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_0: + dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_1: + dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q8_0: + dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q2_K: + dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q3_K: + dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_K: + dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_K: + dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q6_K: + dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_F16: + convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + default: + GGML_ASSERT(false); + break; + } + +#ifdef GGML_CUDA_F16 + if (src1_convert_f16) { + ggml_cuda_pool_free(src1_dfloat, ash); + } +#endif // GGML_CUDA_F16 + + (void) src1; + (void) dst; + (void) src1_ddq_i; + (void) src1_ncols; + (void) src1_padded_row_size; } inline void ggml_cuda_op_mul_mat_cublas( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, - float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main){ + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream) { - GGML_ASSERT(src0_ddf_i != nullptr); + GGML_ASSERT(src0_dd_i != nullptr); GGML_ASSERT(src1_ddf_i != nullptr); - GGML_ASSERT(dst_ddf_i != nullptr); + GGML_ASSERT(dst_dd_i != nullptr); const float alpha = 1.0f; const float beta = 0.0f; @@ -5686,43 +5627,48 @@ inline void ggml_cuda_op_mul_mat_cublas( const int64_t ne00 = src0->ne[0]; const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; const int64_t ne0 = dst->ne[0]; - const int64_t i01_diff = i01_high - i01_low; + const int64_t row_diff = row_high - row_low; + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); + size_t src0_as; + float * src0_ddf_i = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); + to_fp32_cuda(src0_dd_i, src0_ddf_i, row_diff*ne00, stream); int id; CUDA_CHECK(cudaGetDevice(&id)); // the main device has a larger memory buffer to hold the results from all GPUs // ldc == nrows of the matrix that cuBLAS writes into - int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff; + int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; - CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], cudaStream_main)); + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); CUBLAS_CHECK( cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, - i01_diff, ne11, ne10, + row_diff, src1_ncols, ne10, &alpha, src0_ddf_i, ne00, - src1_ddf_i, ne10, - &beta, dst_ddf_i, ldc)); + src1_ddf_i, ne10, + &beta, dst_dd_i, ldc)); + + ggml_cuda_pool_free(src0_ddf_i, src0_as); (void) dst; - (void) src0_ddq_i; - (void) i02; - (void) i1; + (void) src0_dd_i; + (void) src1_ddq_i; + (void) src1_padded_row_size; } inline void ggml_cuda_op_rope( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, - float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main){ + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0_ddf_i != nullptr); - GGML_ASSERT(dst_ddf_i != nullptr); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; - const int64_t i01_diff = i01_high - i01_low; + const int64_t nrows = ggml_nrows(src0); const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; @@ -5742,33 +5688,30 @@ inline void ggml_cuda_op_rope( // compute if (is_glm) { - rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, n_ctx, cudaStream_main); + rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, n_ctx, main_stream); } else if (is_neox) { GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); - rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main); + rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream); } else { - rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main); + rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, p0, freq_scale, ne01, theta_scale, main_stream); } (void) src1; (void) dst; - (void) src0_ddq_i; - (void) src1_ddf_i; - (void) i1; + (void) src1_dd; } inline void ggml_cuda_op_alibi( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, - float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main){ + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0_ddf_i != nullptr); - GGML_ASSERT(dst_ddf_i != nullptr); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; - const int64_t i01_diff = i01_high - i01_low; + const int64_t nrows = ggml_nrows(src0); const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; @@ -5783,334 +5726,355 @@ inline void ggml_cuda_op_alibi( const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - // compute - alibi_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_heads_log2_floor, m0, m1, cudaStream_main); + alibi_f32_cuda(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream); (void) src1; - (void) src0_ddq_i; - (void) src1_ddf_i; - (void) i1; + (void) src1_dd; } inline void ggml_cuda_op_diag_mask_inf( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, - float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main){ + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0_ddf_i != nullptr); - GGML_ASSERT(dst_ddf_i != nullptr); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; - const int64_t i01_diff = i01_high - i01_low; + const int nrows0 = ggml_nrows(src0); const int n_past = ((int32_t *) dst->op_params)[0]; - // compute - diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main); + diag_mask_inf_f32_cuda(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); (void) src1; (void) dst; - (void) src0_ddq_i; - (void) src1_ddf_i; - (void) i02; - (void) i1; + (void) src1_dd; } inline void ggml_cuda_op_soft_max( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, - float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main){ + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0_ddf_i != nullptr); - GGML_ASSERT(dst_ddf_i != nullptr); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); const int64_t ne00 = src0->ne[0]; - const int64_t i01_diff = i01_high - i01_low; + const int64_t nrows = ggml_nrows(src0); - // compute - soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main); + soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream); (void) src1; (void) dst; - (void) src0_ddq_i; - (void) src1_ddf_i; - (void) i02; - (void) i1; + (void) src1_dd; } inline void ggml_cuda_op_scale( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, - float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, - cudaStream_t & cudaStream_main){ + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0_ddf_i != nullptr); - GGML_ASSERT(dst_ddf_i != nullptr); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); const float scale = ((float *) src1->data)[0]; - const int64_t ne00 = src0->ne[0]; - const int64_t i01_diff = i01_high - i01_low; - - // compute - scale_f32_cuda(src0_ddf_i, dst_ddf_i, scale, ne00*i01_diff, cudaStream_main); + scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream); CUDA_CHECK(cudaGetLastError()); (void) src1; (void) dst; - (void) src0_ddq_i; - (void) src1_ddf_i; - (void) i02; - (void) i1; + (void) src1_dd; } -static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - ggml_cuda_op_t op, bool src0_needs_f32, bool flatten_rows) { +static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) { + const int64_t nrows0 = ggml_nrows(src0); + + const bool use_src1 = src1 != nullptr; + const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1; + + GGML_ASSERT( src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT( dst->backend != GGML_BACKEND_GPU_SPLIT); + + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr; + struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + + const bool src0_on_device = src0->backend == GGML_BACKEND_GPU; + const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU; + const bool dst_on_device = dst->backend == GGML_BACKEND_GPU; + + const bool src1_stays_on_host = use_src1 && dst->op == GGML_OP_SCALE; + + // dd = data device + float * src0_ddf = nullptr; + float * src1_ddf = nullptr; + float * dst_ddf = nullptr; + + // as = actual size + size_t src0_asf = 0; + size_t src1_asf = 0; + size_t dst_asf = 0; + + ggml_cuda_set_device(g_main_device); + const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + if (src0_on_device) { + src0_ddf = (float *) src0_extra->data_device[g_main_device]; + } else { + src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream)); + } + + if (use_src1 && !src1_stays_on_host) { + if (src1_on_device) { + src1_ddf = (float *) src1_extra->data_device[g_main_device]; + } else { + src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream)); + } + } + if (dst_on_device) { + dst_ddf = (float *) dst_extra->data_device[g_main_device]; + } else { + dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf); + } + + // do the computation + op(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream); + CUDA_CHECK(cudaGetLastError()); + + // copy dst to host if necessary + if (!dst_on_device) { + CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream)); + } + + if (src0_asf > 0) { + ggml_cuda_pool_free(src0_ddf, src0_asf); + } + if (src1_asf > 0) { + ggml_cuda_pool_free(src1_ddf, src1_asf); + } + if (dst_asf > 0) { + ggml_cuda_pool_free(dst_ddf, dst_asf); + } + + if (dst->backend == GGML_BACKEND_CPU) { + CUDA_CHECK(cudaDeviceSynchronize()); + } +} + +static void ggml_cuda_op_mul_mat( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, + const bool convert_src1_to_q8_1) { + const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; const int64_t ne03 = src0->ne[3]; const int64_t nrows0 = ggml_nrows(src0); - const bool use_src1 = src1 != nullptr; - const int64_t ne10 = use_src1 ? src1->ne[0] : 1; - const int64_t ne11 = use_src1 ? src1->ne[1] : 1; - const int64_t ne12 = use_src1 ? src1->ne[2] : 1; - const int64_t ne13 = use_src1 ? src1->ne[3] : 1; - const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1; + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + const int64_t nrows1 = ggml_nrows(src1); GGML_ASSERT(ne03 == ne13); const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT); - GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT); - // strides for iteration over dims 3 and 2 - const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13; - const int64_t num_iters = flatten_rows ? 1 : num_iters_0; - const int64_t stride_mod = flatten_rows ? num_iters_0 : 1; - const int64_t src0_stride = ne00 * ne01 * stride_mod; - const int64_t src1_stride = ne10 * ne11 * stride_mod; - const int64_t dst_stride = ne0 * ne1 * stride_mod; + GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0); - const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01; - const int64_t i03_max = flatten_rows ? 1 : ne03; - const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12); - const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02; - GGML_ASSERT(!(flatten_rows && ne02 < ne12)); + const int64_t i02_divisor = ne12 / ne02; const size_t src0_ts = ggml_type_size(src0->type); const size_t src0_bs = ggml_blck_size(src0->type); + const size_t q8_1_ts = sizeof(block_q8_1); + const size_t q8_1_bs = QK8_1; - struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr; - struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; const bool src0_is_contiguous = ggml_is_contiguous(src0); - const bool src0_is_f32 = src0->type == GGML_TYPE_F32; - const bool src1_is_contiguous = use_src1 && ggml_is_contiguous(src1); - const bool src1_stays_on_host = use_src1 && ( - dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE); + const bool src1_is_contiguous = ggml_is_contiguous(src1); + const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ? + ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; + GGML_ASSERT(!(split && ne02 > 1)); + GGML_ASSERT(!(split && ne03 > 1)); GGML_ASSERT(!(split && ne02 < ne12)); - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); - // dd = data device - char * src0_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // quantized - float * src0_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float - float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; - float * dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; + char * src0_dd[GGML_CUDA_MAX_DEVICES] = {nullptr}; + float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float + char * src1_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // q8_1 + float * dst_dd[GGML_CUDA_MAX_DEVICES] = {nullptr}; - // asq = actual size quantized, asf = actual size float - size_t src0_asq[GGML_CUDA_MAX_DEVICES] = {0}; - size_t src0_asf[GGML_CUDA_MAX_DEVICES] = {0}; + // as = actual size + size_t src0_as[GGML_CUDA_MAX_DEVICES] = {0}; size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0}; - size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0}; + size_t src1_asq[GGML_CUDA_MAX_DEVICES] = {0}; + size_t dst_as[GGML_CUDA_MAX_DEVICES] = {0}; - // if multiple devices are used they need to wait for the main device - // here an event is recorded that signifies that the main device has finished calculating the input data - if (split && g_device_count > 1) { - CUDA_CHECK(cudaSetDevice(g_main_device)); - CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device])); - } + int64_t row_low[GGML_CUDA_MAX_DEVICES]; + int64_t row_high[GGML_CUDA_MAX_DEVICES]; - for (int id = 0; id < g_device_count; ++id) { - if (!split && id != g_main_device) { - continue; - } + for (int64_t id = 0; id < g_device_count; ++id) { + // by default, use all rows + row_low[id] = 0; + row_high[id] = ne01; - const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU && id == g_main_device; - const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; - - int64_t row_low, row_high; + // for multi GPU, get the row boundaries from tensor split + // and round to mul_mat_q tile sizes if (split) { const int64_t rounding = get_row_rounding(src0->type); - row_low = id == 0 ? 0 : nrows0*g_tensor_split[id]; - row_low -= row_low % rounding; - - if (id == g_device_count - 1) { - row_high = nrows0; - } else { - row_high = nrows0*g_tensor_split[id + 1]; - row_high -= row_high % rounding; + if (id != 0) { + row_low[id] = ne01*g_tensor_split[id]; + row_low[id] -= row_low[id] % rounding; + } + + if (id != g_device_count - 1) { + row_high[id] = ne01*g_tensor_split[id + 1]; + row_high[id] -= row_high[id] % rounding; } - } else { - row_low = 0; - row_high = nrows0*i02_divisor; } - if (row_low == row_high) { + } + + for (int64_t id = 0; id < g_device_count; ++id) { + if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { continue; } - int64_t row_diff = row_high - row_low; + const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device; + const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; - cudaSetDevice(id); - cudaStream_t cudaStream_main = g_cudaStreams_main[id]; - - // wait for main GPU data if necessary - if (split && id != g_main_device) { - CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device])); - } + ggml_cuda_set_device(id); + const cudaStream_t stream = g_cudaStreams[id][0]; if (src0_on_device && src0_is_contiguous) { - if (src0_is_f32) { - src0_ddf[id] = (float *) src0_extra->data_device[id]; - } else { - src0_ddq[id] = (char *) src0_extra->data_device[id]; - } + src0_dd[id] = (char *) src0_extra->data_device[id]; } else { - if (src0_is_f32) { - src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]); - } else { - src0_ddq[id] = (char *) ggml_cuda_pool_malloc(row_diff*ne00 * src0_ts/src0_bs, &src0_asq[id]); + const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0); + src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]); + } + + if (src1_on_device && src1_is_contiguous) { + src1_ddf[id] = (float *) src1_extra->data_device[id]; + } else { + src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]); + } + + if (convert_src1_to_q8_1) { + src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]); + + if (split && src1_on_device && src1_is_contiguous) { + quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); } } - if (src0_needs_f32 && !src0_is_f32) { - src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]); - } - - if (use_src1 && !src1_stays_on_host) { - if (src1_on_device && src1_is_contiguous) { - src1_ddf[id] = (float *) src1_extra->data_device[id]; - } else { - src1_ddf[id] = (float *) ggml_cuda_pool_malloc(num_iters*src1_stride * sizeof(float), &src1_asf[id]); - } - } if (dst_on_device) { - dst_ddf[id] = (float *) dst_extra->data_device[id]; + dst_dd[id] = (float *) dst_extra->data_device[id]; } else { - size_t size_dst_ddf = split ? row_diff*ne1 * sizeof(float) : num_iters*dst_stride * sizeof(float); - dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]); + const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst); + dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]); } + } - for (int64_t i03 = 0; i03 < i03_max; i03++) { - const int64_t i13 = i03 % ne13; - for (int64_t i02 = 0; i02 < i02_max; i02++) { - const int64_t i12 = i02 % ne12; + // if multiple devices are used they need to wait for the main device + // here an event is recorded that signals that the main device has finished calculating the input data + if (split && g_device_count > 1) { + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0])); + } - const int64_t i0 = i03*i02_max + i02; + const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; + for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) { + const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0; + const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride; - // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs - const int64_t i0_offset_low = row_low/rows_per_iter; - const int64_t i0_offset_high = row_high/rows_per_iter; + for (int64_t id = 0; id < g_device_count; ++id) { + if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { + continue; + } - int64_t i01_low = 0; - int64_t i01_high = rows_per_iter; - if (split) { - if (i0 < i0_offset_low || i0 > i0_offset_high) { - continue; - } - if (i0 == i0_offset_low) { - i01_low = row_low % rows_per_iter; - } - if (i0 == i0_offset_high) { - i01_high = row_high % rows_per_iter; - } - } + const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device; + const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; + const int64_t row_diff = row_high[id] - row_low[id]; - // There is possibly a bug in the Windows nvcc compiler regarding instruction reordering or optimizing out local variables. - // Removing the first assert or changing the order of the arguments causes the second assert to fail. - // Removing both asserts results in i01_high becoming 0 which in turn results in garbage output. - // The root cause seems to be a problem with i0_offset_high becoming 0 when it should always be >0 (for single GPU). - GGML_ASSERT(i01_low == 0 || g_device_count > 1); - GGML_ASSERT(i01_high == rows_per_iter || g_device_count > 1); + ggml_cuda_set_device(id); + const cudaStream_t stream = g_cudaStreams[id][is]; - const int64_t i01_diff = i01_high - i01_low; - if (i01_diff == 0) { - continue; - } - const int64_t i11 = i13*ne12 + i12; + // wait for main GPU data if necessary + if (split && (id != g_main_device || is != 0)) { + CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0])); + } + + for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) { + const int64_t i03 = i0 / ne12; + const int64_t i02 = i0 % ne12; + + const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs; // for split tensors the data begins at i0 == i0_offset_low - char * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs; - float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride; - float * src1_ddf_i = src1_ddf[id] + i11*src1_stride; - float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride; - - // for split tensors the data pointer needs to be rounded down - // to the bin edge for i03, i02 bins beyond the first - if (i0 - i0_offset_low > 0) { - GGML_ASSERT(!flatten_rows); - src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs; - src0_ddf_i -= (row_low % ne01)*ne00; - dst_ddf_i -= (row_low % ne0)*ne1; - } + char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs; + float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10; + char * src1_ddq_i = src1_ddq[id] + src1_ddq_i_offset; + float * dst_dd_i = dst_dd[id] + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); // the main device memory buffer can be on VRAM scratch, with space for all partial results // in that case an offset on dst_ddf_i is needed if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) { - dst_ddf_i += i01_low; // offset is 0 if no tensor split + dst_dd_i += row_low[id]; // offset is 0 if no tensor split } // copy src0, src1 to device if necessary - if (use_src1 && !src1_stays_on_host) { - if (src1->backend == GGML_BACKEND_CPU) { - GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1)); - int64_t nrows1 = flatten_rows ? nrows0 : ne11; - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main)); - } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) { - if (id != g_main_device) { - GGML_ASSERT(!flatten_rows); + if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) { + if (id != g_main_device) { + if (convert_src1_to_q8_1) { + char * src1_ddq_i_source = src1_ddq[g_main_device] + src1_ddq_i_offset; + CUDA_CHECK(cudaMemcpyAsync(src1_ddq_i, src1_ddq_i_source, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, + cudaMemcpyDeviceToDevice, stream)); + } else { float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device]; - src1_ddf_i_source += i11*src1_stride; - CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float), - cudaMemcpyDeviceToDevice, cudaStream_main)); + src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10; + CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_ncols*ne10*sizeof(float), + cudaMemcpyDeviceToDevice, stream)); } - } else if (src1_on_device && !src1_is_contiguous) { - GGML_ASSERT(!split); - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main)); - } else { - GGML_ASSERT(false); } + } else if (src1->backend == GGML_BACKEND_CPU || (src1_on_device && !src1_is_contiguous)) { + CUDA_CHECK(ggml_cuda_cpy_tensor_2d( + src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); + } else { + GGML_ASSERT(false); } - if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) { - if (src0_is_f32) { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main)); - } else { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main)); - } - } - - // convert src0 to f32 if it is necessary for the ggml_cuda_op - if (src0_needs_f32 && !src0_is_f32) { - to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main); + if (convert_src1_to_q8_1 && src1->backend == GGML_BACKEND_CPU) { + quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); } + if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) { + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, row_low[id], row_high[id], stream)); + } + // do the computation - op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main); + op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, + row_low[id], row_high[id], src1_ncols, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); // copy dst to host or other device if necessary @@ -6132,95 +6096,86 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU. // Instead they need to be copied to the correct slice in ne0 = dst row index. // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results. - float * dhf_dst_i = (float *) ((char *) dst_off_device + i01_low*sizeof(float) + i02*nb2 + i03*nb3); - CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_ddf_i, i01_diff*sizeof(float), - i01_diff*sizeof(float), ne1, kind, cudaStream_main)); + float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); + GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); + dhf_dst_i += src1_col_0*ne0 + row_low[id]; + CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_dd_i, row_diff*sizeof(float), + row_diff*sizeof(float), src1_ncols, kind, stream)); } else { float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); - CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main)); + GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); + dhf_dst_i += src1_col_0*ne0; + CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), kind, stream)); } } - // signify to main device that other device is done - if (split && g_device_count > 1 && id != g_main_device) { - CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main)); + // add event for the main device to wait on until other device is done + if (split && (id != g_main_device || is != 0)) { + CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream)); } } } } - // wait until each device is finished, then free their buffers - for (int id = 0; id < g_device_count; ++id) { - if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) { - continue; - } + for (int64_t id = 0; id < g_device_count; ++id) { + CUDA_CHECK(ggml_cuda_set_device(id)); - CUDA_CHECK(cudaSetDevice(id)); - - if (src0_asq[id] > 0) { - ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]); - } - if (src0_asf[id] > 0) { - ggml_cuda_pool_free(src0_ddf[id], src0_asf[id]); + // free buffers again when done + if (src0_as[id] > 0) { + ggml_cuda_pool_free(src0_dd[id], src0_as[id]); } if (src1_asf[id] > 0) { ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]); } - if (dst_asf[id] > 0) { - ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]); + if (src1_asq[id] > 0) { + ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]); + } + if (dst_as[id] > 0) { + ggml_cuda_pool_free(dst_dd[id], dst_as[id]); } } // main device waits for all other devices to be finished if (split && g_device_count > 1) { - CUDA_CHECK(cudaSetDevice(g_main_device)); - for (int id = 0; id < g_device_count; ++id) { - if (id != g_main_device && src0_extra->events[id]) { - CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id])); + int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; + is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS; + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + for (int64_t id = 0; id < g_device_count; ++id) { + for (int64_t is = 0; is < is_max; ++is) { + CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is])); } } } if (dst->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(cudaSetDevice(g_main_device)); + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); CUDA_CHECK(cudaDeviceSynchronize()); } } void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op. - // Due to flatten_rows == true this does in practice not make a difference however. - // Better solution would be nice but right now that would require disproportionate changes. - GGML_ASSERT( - (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) && - src1->type == GGML_TYPE_F32 && - (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16)); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add); } void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul); } void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_gelu, true, true); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu); } void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu); } void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_norm, true, true); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm); } void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm); } bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { @@ -6254,8 +6209,8 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr const int64_t ne12 = src1->ne[2]; - CUDA_CHECK(cudaSetDevice(g_main_device)); - cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device]; + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; void * src0_ddq = src0_extra->data_device[g_main_device]; @@ -6266,7 +6221,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main); + ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); } void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ @@ -6285,8 +6240,8 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1 const int64_t nb01 = src0->nb[1]; const int64_t nb02 = src0->nb[2]; - CUDA_CHECK(cudaSetDevice(g_main_device)); - cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device]; + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; void * src0_ddq = src0_extra->data_device[g_main_device]; @@ -6297,38 +6252,49 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1 struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - const int row_stride_x = nb01 / sizeof(half); - const int channel_stride_x = nb02 / sizeof(half); + const int64_t row_stride_x = nb01 / sizeof(half); + const int64_t channel_stride_x = nb02 / sizeof(half); - ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main); + ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); } void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; + int64_t min_compute_capability = INT_MAX; + for (int64_t id = 0; id < g_device_count; ++id) { + if (min_compute_capability > g_compute_capabilities[id] + && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) { + min_compute_capability = g_compute_capabilities[id]; + } + } + if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { ggml_cuda_mul_mat_vec_p021(src0, src1, dst); } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) { ggml_cuda_mul_mat_vec_nc(src0, src1, dst); }else if (src0->type == GGML_TYPE_F32) { - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false); + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false); - } else { - int min_compute_capability = INT_MAX; - for (int id = 0; id < g_device_count; ++id) { - if (min_compute_capability > g_compute_capabilities[id] - && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) { - min_compute_capability = g_compute_capabilities[id]; - } - } - if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) { - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false); +#ifdef GGML_CUDA_FORCE_DMMV + const bool use_mul_mat_vec_q = false; +#else + const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type); +#endif // GGML_CUDA_FORCE_DMMV + + if (use_mul_mat_vec_q) { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); } else { - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false); + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); + } + } else { + if (src1->backend == GGML_BACKEND_GPU && g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true); + } else { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } } } else { @@ -6337,8 +6303,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ } void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_scale, true, true); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -6367,8 +6332,8 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens const int64_t nb11 = src1->nb[1]; const int64_t nb12 = src1->nb[2]; - CUDA_CHECK(cudaSetDevice(g_main_device)); - cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device]; + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; @@ -6378,10 +6343,10 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, - ne10, ne11, nb10, nb11, nb12, cudaStream_main); + ne10, ne11, nb10, nb11, nb12, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, - ne10, ne11, nb10, nb11, nb12, cudaStream_main); + ne10, ne11, nb10, nb11, nb12, main_stream); } else { GGML_ASSERT(false); } @@ -6395,25 +6360,20 @@ void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens } void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf); } void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, true, true); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max); } void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented - - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, true); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope); } void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_alibi, true, true); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi); } void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -6423,7 +6383,7 @@ void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens } void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { - int nrows = ggml_nrows(tensor); + const int64_t nrows = ggml_nrows(tensor); const int64_t ne0 = tensor->ne[0]; @@ -6433,14 +6393,14 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu; memset(extra, 0, sizeof(*extra)); - for (int id = 0; id < g_device_count; ++id) { + for (int64_t id = 0; id < g_device_count; ++id) { if (backend == GGML_BACKEND_GPU && id != g_main_device) { continue; } - cudaSetDevice(id); + ggml_cuda_set_device(id); - int row_low, row_high; + int64_t row_low, row_high; if (backend == GGML_BACKEND_GPU) { row_low = 0; row_high = nrows; @@ -6490,7 +6450,9 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { extra->data_device[id] = buf; if (backend == GGML_BACKEND_GPU_SPLIT) { - CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming)); + for (int64_t is = 0; is < MAX_STREAMS; ++is) { + CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming)); + } } } @@ -6504,15 +6466,17 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) { ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; - for (int id = 0; id < g_device_count; ++id) { + for (int64_t id = 0; id < g_device_count; ++id) { if (extra->data_device[id] != nullptr) { - CUDA_CHECK(cudaSetDevice(id)); + CUDA_CHECK(ggml_cuda_set_device(id)); CUDA_CHECK(cudaFree(extra->data_device[id])); } - if (extra->events[id] != nullptr) { - CUDA_CHECK(cudaSetDevice(id)); - CUDA_CHECK(cudaEventDestroy(extra->events[id])); + for (int64_t is = 0; is < MAX_STREAMS; ++is) { + if (extra->events[id][is] != nullptr) { + CUDA_CHECK(ggml_cuda_set_device(id)); + CUDA_CHECK(cudaEventDestroy(extra->events[id][is])); + } } } @@ -6564,7 +6528,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo force_inplace; const size_t size = ggml_nbytes(tensor); - CUDA_CHECK(cudaSetDevice(g_main_device)); + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) { struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; char * src0_ddc = (char *) src0_extra->data_device[g_main_device];