cuda : revert CUDA pool stuff (#3944)

* Revert "cuda : add ROCM aliases for CUDA pool stuff (#3918)"

This reverts commit 629f917cd6.

* Revert "cuda : use CUDA memory pool with async memory allocation/deallocation when available (#3903)"

This reverts commit d6069051de.

ggml-ci
slaren 2023-11-05 08:12:13 +01:00 committed by GitHub
parent f28af0d81a
commit 48ade94538

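For reference, the feature being reverted is built on CUDA's stream-ordered memory pool API (available since CUDA 11.2, and exposed on ROCm through the hip* aliases removed in the first hunk below). The following is a minimal standalone sketch, not llama.cpp code: the CHECK macro, device index, stream, and buffer size are invented for illustration. It shows how cudaDeviceGetMemPool, cudaMemPoolAttrReleaseThreshold, cudaMallocFromPoolAsync, and cudaFreeAsync combine into the allocate/free-on-a-stream pattern that ggml_cuda_pool_malloc_async and ggml_cuda_pool_free_async wrapped.

// Minimal standalone sketch (not llama.cpp code) of the stream-ordered
// memory pool pattern used by the reverted commits.
// Requires CUDA 11.2+ and a device reporting cudaDevAttrMemoryPoolsSupported.
#include <cstdio>
#include <cstdlib>
#include <cstdint>
#include <cuda_runtime.h>

#define CHECK(err)                                                      \
    do {                                                                \
        cudaError_t err_ = (err);                                       \
        if (err_ != cudaSuccess) {                                      \
            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_,       \
                    __FILE__, __LINE__, cudaGetErrorString(err_));      \
            exit(1);                                                    \
        }                                                               \
    } while (0)

int main() {
    int device = 0;
    CHECK(cudaSetDevice(device));

    int pools_supported = 0;
    CHECK(cudaDeviceGetAttribute(&pools_supported, cudaDevAttrMemoryPoolsSupported, device));
    if (!pools_supported) {
        // Without pool support the code would fall back to a plain cudaMalloc path.
        fprintf(stderr, "device %d has no memory pool support\n", device);
        return 0;
    }

    // Grab the device's default memory pool and raise the release threshold so
    // freed blocks stay cached in the pool instead of being returned to the OS.
    cudaMemPool_t pool = nullptr;
    CHECK(cudaDeviceGetMemPool(&pool, device));
    uint64_t threshold = UINT64_MAX;
    CHECK(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold));

    cudaStream_t stream = nullptr;
    CHECK(cudaStreamCreate(&stream));

    // Stream-ordered allocation and deallocation: both calls are enqueued on the
    // stream, so the buffer is only valid for work ordered between them on that stream.
    void * buf = nullptr;
    const size_t nbytes = 1 << 20;
    CHECK(cudaMallocFromPoolAsync(&buf, nbytes, pool, stream));
    CHECK(cudaMemsetAsync(buf, 0, nbytes, stream));
    CHECK(cudaFreeAsync(buf, stream));

    CHECK(cudaStreamSynchronize(stream));
    CHECK(cudaStreamDestroy(stream));
    return 0;
}

With the revert applied, every ggml_cuda_pool_malloc_async call site in the diff goes back to the synchronous ggml_cuda_pool_malloc / ggml_cuda_pool_free buffer pool, and buffers are freed in the per-device loop after the compute loop rather than asynchronously on g_cudaStreams[id][0].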

@@ -39,10 +39,6 @@
 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
 #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
-#define cudaDeviceGetMemPool hipDeviceGetMemPool
-#define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold
-#define cudaMemPoolSetAttribute hipMemPoolSetAttribute
-#define cudaMemPool_t hipMemPool_t
 #define cudaDeviceProp hipDeviceProp_t
 #define cudaDeviceSynchronize hipDeviceSynchronize
 #define cudaError_t hipError_t
@@ -52,7 +48,6 @@
 #define cudaEvent_t hipEvent_t
 #define cudaEventDestroy hipEventDestroy
 #define cudaFree hipFree
-#define cudaFreeAsync hipFreeAsync
 #define cudaFreeHost hipHostFree
 #define cudaGetDevice hipGetDevice
 #define cudaGetDeviceCount hipGetDeviceCount
@@ -60,7 +55,6 @@
 #define cudaGetErrorString hipGetErrorString
 #define cudaGetLastError hipGetLastError
 #define cudaMalloc hipMalloc
-#define cudaMallocFromPoolAsync hipMallocFromPoolAsync
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpy2DAsync hipMemcpy2DAsync
@@ -187,11 +181,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do { \
         cudaError_t err_ = (err); \
         if (err_ != cudaSuccess) { \
-            int dev_id; \
-            cudaGetDevice(&dev_id); \
+            int id; \
+            cudaGetDevice(&id); \
             fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                 cudaGetErrorString(err_)); \
-            fprintf(stderr, "current device: %d\n", dev_id); \
+            fprintf(stderr, "current device: %d\n", id); \
             exit(1); \
         } \
     } while (0)
@@ -201,11 +195,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do { \
         cublasStatus_t err_ = (err); \
         if (err_ != CUBLAS_STATUS_SUCCESS) { \
-            int dev_id; \
-            cudaGetDevice(&dev_id); \
+            int id; \
+            cudaGetDevice(&id); \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
                 err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
-            fprintf(stderr, "current device: %d\n", dev_id); \
+            fprintf(stderr, "current device: %d\n", id); \
             exit(1); \
         } \
     } while (0)
@@ -471,7 +465,6 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 #define MAX_STREAMS 8
 static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
-static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };
 
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -5780,16 +5773,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     return ptr;
 }
 
-static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
-    if (g_cudaMemPools[id] == nullptr) {
-        return ggml_cuda_pool_malloc(size, actual_size);
-    }
-    void *ptr;
-    CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
-    *actual_size = size;
-    return ptr;
-}
-
 static void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
@@ -5808,13 +5791,6 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
 }
 
-static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
-    if (g_cudaMemPools[id] == nullptr) {
-        return ggml_cuda_pool_free(ptr, actual_size);
-    }
-    CUDA_CHECK(cudaFreeAsync(ptr, stream));
-}
-
 void ggml_init_cublas() {
     static bool initialized = false;
@@ -5869,13 +5845,6 @@ void ggml_init_cublas() {
         // create cublas handle
         CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
         CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
-
-        // configure memory pool
-        cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
-        if (err == cudaSuccess) {
-            size_t treshold = UINT64_MAX;
-            CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
-        }
     }
 
     // configure logging to stdout
@@ -6469,7 +6438,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = row_diff*ne00;
-            src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
+            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
             to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
@@ -6480,12 +6449,13 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = src1_ncols*ne10;
-            src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
+            src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
-        size_t dst_f16_as = 0;
-        half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
+
+        size_t dst_as = 0;
+        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
 
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
@@ -6503,15 +6473,14 @@ inline void ggml_cuda_op_mul_mat_cublas(
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
 
-        if (dst_f16_as != 0) {
-            ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
-        }
+        ggml_cuda_pool_free(dst_f16, dst_as);
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
+            ggml_cuda_pool_free(src0_as_f16, src0_as);
         }
         if (src1_as != 0) {
-            ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
+            ggml_cuda_pool_free(src1_as_f16, src1_as);
         }
     }
     else {
@@ -6521,7 +6490,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
         if (src0->type != GGML_TYPE_F32) {
             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
             GGML_ASSERT(to_fp32_cuda != nullptr);
-            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
+            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
             to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
         }
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
@@ -6538,7 +6507,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
                     &beta, dst_dd_i, ldc));
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
+            ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
         }
     }
@@ -6961,22 +6930,21 @@ static void ggml_cuda_op_mul_mat(
             src0_dd[id] = (char *) src0_extra->data_device[id];
         } else {
             const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
-            src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
+            src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
         }
 
         if (src1_on_device && src1_is_contiguous) {
             src1_ddf[id] = (float *) src1_extra->data_device[id];
         } else {
-            src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
+            src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
         }
 
         if (convert_src1_to_q8_1) {
-            const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
-            src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
+            src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
 
             if (src1_on_device && src1_is_contiguous) {
                 quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
-                // CUDA_CHECK(cudaGetLastError());
+                CUDA_CHECK(cudaGetLastError());
             }
         }
@@ -6984,7 +6952,7 @@ static void ggml_cuda_op_mul_mat(
             dst_dd[id] = (float *) dst_extra->data_device[id];
         } else {
             const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
-            dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream);
+            dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
         }
     }
@@ -7110,6 +7078,24 @@ static void ggml_cuda_op_mul_mat(
         }
     }
 
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        CUDA_CHECK(ggml_cuda_set_device(id));
+
+        // free buffers again when done
+        if (src0_as[id] > 0) {
+            ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
+        }
+        if (src1_asf[id] > 0) {
+            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+        }
+        if (src1_asq[id] > 0) {
+            ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
+        }
+        if (dst_as[id] > 0) {
+            ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
+        }
+    }
+
     // main device waits for all other devices to be finished
     if (split && g_device_count > 1) {
         int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
@@ -7127,21 +7113,6 @@ static void ggml_cuda_op_mul_mat(
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaDeviceSynchronize());
     }
-
-    for (int64_t id = 0; id < g_device_count; ++id) {
-        if (src0_as[id] > 0) {
-            ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
-        }
-        if (src1_asf[id] > 0) {
-            ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
-        }
-        if (src1_asq[id] > 0) {
-            ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
-        }
-        if (dst_as[id] > 0) {
-            ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
-        }
-    }
 }
 
 static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7328,11 +7299,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     GGML_ASSERT(to_fp16_cuda != nullptr);
 
     size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
 
     size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -7386,8 +7357,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         size_t ptrs_src_s = 0;
         size_t ptrs_dst_s = 0;
 
-        ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
-        ptrs_dst = ( void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
+        ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
+        ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
@@ -7400,6 +7371,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 dst->nb[2], dst->nb[3],
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
+
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
@@ -7411,22 +7383,19 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
         if (ptrs_src_s != 0) {
-            ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
+            ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
         }
         if (ptrs_dst_s != 0) {
-            ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
+            ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
         }
     }
 #endif
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
     to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
 
-    if (src1_as != 0) {
-        ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
-    }
-    if (dst_as != 0) {
-        ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
-    }
+    ggml_cuda_pool_free(src1_as_f16, src1_as);
+    ggml_cuda_pool_free(dst_f16, dst_as);
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {