CUDA: Fixed 7b q3_K_S with mul_mat_vec_q (#2313)

This commit is contained in:
Johannes Gäßler 2023-07-22 21:27:34 +02:00 committed by GitHub
parent b47b8a9cfe
commit b9b7d94fc1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -220,7 +220,7 @@ typedef struct {
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
#define WARP_SIZE 32
#define MATRIX_ROW_PADDING 256 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
#define CUDA_ADD_BLOCK_SIZE 256
#define CUDA_MUL_BLOCK_SIZE 256
@ -2815,8 +2815,8 @@ inline void ggml_cuda_op_mul_mat_vec(
#endif
if (use_mul_mat_vec_q) {
int64_t padded_row_size = ne00 + MATRIX_ROW_PADDING - 1;
padded_row_size -= padded_row_size % MATRIX_ROW_PADDING;
const int64_t padded_row_size = ne00 % MATRIX_ROW_PADDING == 0 ?
ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
size_t as;
void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main);
@ -3642,7 +3642,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
size_t size = ggml_nbytes_split(tensor, nrows_split);
const size_t original_size = size;
// pad last row to a multiple of 256 elements to avoid out-of-bounds memory accesses
// pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
if (ne0 % MATRIX_ROW_PADDING != 0) {
size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
* ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
@ -3658,7 +3658,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
}
CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(buf, buf_host, original_size, cudaMemcpyHostToDevice));
extra->data_device[id] = buf;