diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 9ebc57aff..36a251ecc 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -515,15 +515,15 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float const block_q2_K * x = (const block_q2_K *)vx + ib0; - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 const int step = 16/K_QUANTS_PER_ITERATION; - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0...7 + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...14 in steps of 4 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 const int q_offset = 32*im + l0; const int s_offset = 8*im; const int y_offset = 128*im + l0; @@ -578,27 +578,30 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float } } -static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols) { +static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { const uint16_t kmask1 = 0x0303; const uint16_t kmask2 = 0x0f0f; - const int row = blockIdx.x; + const int row = blockIdx.y*blockDim.y + threadIdx.y; + if (row > nrows) return; + const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; const block_q3_K * x = (const block_q3_K *)vx + ib0; - const int tid = threadIdx.x/2; // 0...15 - const int ix = threadIdx.x%2; // 0, 1 + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - const int n = 2; // iterations in the inner loop - const int im = tid/8; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - 8*im; // 0...7 + const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop + const int step = 16/K_QUANTS_PER_ITERATION; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0....15 or 0...7 const uint8_t m = 1 << (4*im); - const int l0 = n*in; // 0...28 in steps of 4 + const int l0 = n*in; // 0...15 or 0...14 in steps of 2 const int q_offset = 32*im + l0; const int y_offset = 128*im + l0; @@ -609,7 +612,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += 2) { + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { const float * y = yy + i * QK_K + y_offset; const uint8_t * q = x[i].qs + q_offset; @@ -650,22 +653,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float } } -static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols) { +static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int row = blockIdx.x; + const int row = blockIdx.y*blockDim.y + threadIdx.y; + if (row > nrows) return; const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; - const int tid = threadIdx.x/2; // 0...15 - const int ix = threadIdx.x%2; + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - const int il = tid/4; // 0...3 - const int ir = tid - 4*il;// 0...3 - const int n = 4; + const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 + + const int il = tid/step; // 0...3 + const int ir = tid - step*il; // 0...7 or 0...3 + const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const int in = il%2; @@ -681,7 +687,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += 2) { + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { const uint8_t * q1 = x[i].qs + q_offset; const uint8_t * q2 = q1 + 64; @@ -736,7 +742,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float const int il = tid/4; // 0...3 const int ir = tid - 4*il;// 0...3 - const int n = 4; + const int n = 2; const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const int in = il%2; @@ -775,11 +781,16 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float float4 sum = {0.f, 0.f, 0.f, 0.f}; float smin = 0; for (int l = 0; l < n; ++l) { - sum.x += y1[l+ 0] * ((ql1[l] & 0xF) + (qh[l] & (hm1 << 0) ? 16 : 0)); - sum.y += y1[l+32] * ((ql1[l] >> 4) + (qh[l] & (hm1 << 1) ? 16 : 0)); - sum.z += y2[l+ 0] * ((ql2[l] & 0xF) + (qh[l] & (hm2 << 0) ? 16 : 0)); - sum.w += y2[l+32] * ((ql2[l] >> 4) + (qh[l] & (hm2 << 1) ? 16 : 0)); - smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0)) + + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0)); + sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0)) + + y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0)); + sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0)) + + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0)); + sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0)) + + y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0)); + smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3] + + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; } tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; @@ -1311,7 +1322,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; + const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 const int block_num_y = (nrows + ny - 1) / ny; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(32, ny, 1); @@ -1320,14 +1331,20 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const dim3 block_dims(32, 1, 1); - dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const dim3 block_dims(32, 1, 1); - dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {