cuda : fix const ptrs warning causing ROCm build issues (#3913)

Author: Georgi Gerganov
Date:   2023-11-02 20:32:11 +02:00 (committed by GitHub)
Commit: c7743fe1c1
Parent: d6069051de


@@ -7248,7 +7248,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 __global__ void k_compute_batched_ptrs(
         const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
-        void ** ptrs,
+        const void ** ptrs_src, void ** ptrs_dst,
         int ne12, int ne13,
         int ne23,
         int nb02, int nb03,
@@ -7265,9 +7265,9 @@ __global__ void k_compute_batched_ptrs(
     int i03 = i13 / r3;
     int i02 = i12 / r2;

-    ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*nb02   + i03*nb03;
-    ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs[2*ne23 + i12 + i13*ne12] = (char *)  dst_f16    + i12* nb2/2 + i13* nb3/2;
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02   + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)  dst_f16    + i12* nb2/2 + i13* nb3/2;
 }

 static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
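The reason for splitting the single ptrs table: src0_as_f16 and src1_as_f16 are pointers to const, so storing them into one mutable void ** array forced the old (char *) casts to discard the const qualifier, which the ROCm build rejects. A minimal standalone sketch of the pattern (hypothetical repro, not repo code):

    // repro.cpp - hypothetical minimal sketch, not part of the repo.
    // Old pattern: one mutable table holds pointers derived from const inputs,
    // so the cast has to discard the const qualifier:
    static void old_style(const float * src, float * dst, void ** ptrs) {
        ptrs[0] = (char *) src; // warning: cast discards const (build error on ROCm)
        ptrs[1] = (char *) dst; // fine: dst is already mutable
    }

    // New pattern: separate const and mutable tables, nothing discarded:
    static void new_style(const float * src, float * dst,
                          const void ** ptrs_src, void ** ptrs_dst) {
        ptrs_src[0] = (const char *) src;
        ptrs_dst[0] = (      char *) dst;
    }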
@@ -7372,14 +7372,20 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     } else {
         // use cublasGemmBatchedEx
         const int ne23 = ne12*ne13;

-        // allocate device memory for pointers
-        size_t ptrs_s = 0;
-        void ** ptrs_as = (void **) ggml_cuda_pool_malloc_async(3*ne23*sizeof(void *), &ptrs_s, id, main_stream);
+        const void ** ptrs_src = nullptr;
+              void ** ptrs_dst = nullptr;
+
+        size_t ptrs_src_s = 0;
+        size_t ptrs_dst_s = 0;
+
+        ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
+        ptrs_dst = (      void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);

         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
                 src0_as_f16, src1_as_f16, dst_f16,
-                ptrs_as,
+                ptrs_src, ptrs_dst,
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
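For orientation, the layout of the two tables that this hunk allocates and the kernel fills, written out as a sketch (batch_index is a hypothetical helper; the arithmetic is copied from k_compute_batched_ptrs above):

    // Each table is indexed by the flattened batch element:
    static inline int batch_index(int i12, int i13, int ne12) {
        return i12 + i13*ne12; // flattened batch id in [0, ne23)
    }
    // ptrs_src[0*ne23 + batch_index(i12, i13, ne12)] -> src0 slice (const)
    // ptrs_src[1*ne23 + batch_index(i12, i13, ne12)] -> src1 slice (const)
    // ptrs_dst[0*ne23 + batch_index(i12, i13, ne12)] -> dst  slice (mutable)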
@@ -7390,15 +7396,18 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const void * const *) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
-                            (const void * const *) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
-                &beta_f16,  (      void **       ) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
+                &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                            (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+                &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
                 ne23,
                 CUBLAS_COMPUTE_16F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));

-        // free device memory for pointers
-        if (ptrs_s != 0) {
-            ggml_cuda_pool_free_async(ptrs_as, ptrs_s, id, main_stream);
+        if (ptrs_src_s != 0) {
+            ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
+        }
+        if (ptrs_dst_s != 0) {
+            ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
         }
     }
 #endif
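The split also matches the constness of the cuBLAS entry point itself: the A and B pointer arrays are pointer-to-const while the C array is not, so ptrs_src and ptrs_dst can now be passed with only benign casts instead of casting one allocation both ways. Paraphrasing the prototype as I recall it from the CUDA headers (verify against cublas_api.h in your toolkit version):

    // Paraphrased, not copied verbatim from any specific CUDA release:
    cublasStatus_t cublasGemmBatchedEx(cublasHandle_t handle,
            cublasOperation_t transa, cublasOperation_t transb,
            int m, int n, int k,
            const void * alpha,
            const void * const Aarray[], cudaDataType Atype, int lda,
            const void * const Barray[], cudaDataType Btype, int ldb,
            const void * beta,
            void * const Carray[], cudaDataType Ctype, int ldc,
            int batchCount,
            cublasComputeType_t computeType, cublasGemmAlgo_t algo);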