diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c1ec306f0..e8a1e77cb 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1,11 +1,38 @@ +#include +#include #include #include -#include #include -#include "ggml-cuda.h" -typedef uint16_t ggml_fp16_t; -static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size"); +#include +#include +#include + +#include "ggml-cuda.h" +#include "ggml.h" + +static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); + +#define CUDA_CHECK(err) \ + do { \ + cudaError_t err_ = (err); \ + if (err_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ + cudaGetErrorString(err_)); \ + exit(1); \ + } \ + } while (0) + +#define CUBLAS_CHECK(err) \ + do { \ + cublasStatus_t err_ = (err); \ + if (err_ != CUBLAS_STATUS_SUCCESS) { \ + fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \ + exit(1); \ + } \ + } while (0) + +typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); #define QK4_0 32 typedef struct { @@ -24,14 +51,14 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 b #define QK4_2 16 typedef struct { - __half d; // delta + half d; // delta uint8_t qs[QK4_2 / 2]; // nibbles / quants } block_q4_2; static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding"); #define QK5_0 32 typedef struct { - __half d; // delta + half d; // delta uint8_t qh[4]; // 5-th bit of quants uint8_t qs[QK5_0 / 2]; // nibbles / quants } block_q5_0; @@ -39,9 +66,9 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5 #define QK5_1 32 typedef struct { - __half d; // delta - __half m; // min - uint32_t qh; // 5-th bit of quants + half d; // delta + half m; // min + uint8_t qh[4]; // 5-th bit of quants uint8_t qs[QK5_1 / 2]; // nibbles / quants } block_q5_1; static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); @@ -162,7 +189,8 @@ static __global__ void dequantize_block_q5_1(const void * vx, float * y) { const uint8_t * pp = x[i].qs; - const uint32_t qh = x[i].qh; + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); for (int l = 0; l < QK5_1; l += 2) { const uint8_t vi = pp[l/2]; @@ -197,37 +225,50 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) { } } -void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { +static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_0; dequantize_block_q4_0<<>>(vx, y); } -void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { +static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_1; dequantize_block_q4_1<<>>(vx, y); } -void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) { +static void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_2; dequantize_block_q4_2<<>>(vx, y); } -void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { +static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK5_0; dequantize_block_q5_0<<>>(vx, y); } -void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { +static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK5_1; dequantize_block_q5_1<<>>(vx, y); } -void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { +static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK8_0; dequantize_block_q8_0<<>>(vx, y); } -dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) { +// TODO: optimize +static __global__ void convert_fp16_to_fp32(const void * vx, float * y) { + const half * x = (const half *) vx; + + const int i = blockIdx.x; + + y[i] = __half2float(x[i]); +} + +static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) { + convert_fp16_to_fp32<<>>(x, y); +} + +static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; @@ -241,6 +282,8 @@ dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) { return dequantize_row_q5_1_cuda; case GGML_TYPE_Q8_0: return dequantize_row_q8_0_cuda; + case GGML_TYPE_F16: + return convert_fp16_to_fp32_cuda; default: return nullptr; } @@ -271,7 +314,7 @@ struct cuda_buffer { static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS]; static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; -void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { +static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { scoped_spin_lock lock(g_cuda_pool_lock); for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { @@ -290,7 +333,7 @@ void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { return ptr; } -void ggml_cuda_pool_free(void * ptr, size_t size) { +static void ggml_cuda_pool_free(void * ptr, size_t size) { scoped_spin_lock lock(g_cuda_pool_lock); for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { @@ -305,28 +348,55 @@ void ggml_cuda_pool_free(void * ptr, size_t size) { CUDA_CHECK(cudaFree(ptr)); } -cublasHandle_t g_cublasH = nullptr; -cudaStream_t g_cudaStream = nullptr; -cudaStream_t g_cudaStream2 = nullptr; -cudaEvent_t g_cudaEvent = nullptr; +#define GGML_CUDA_MAX_STREAMS 8 +#define GGML_CUDA_MAX_EVENTS 64 +static cublasHandle_t g_cublasH = nullptr; +static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = { nullptr }; +static cudaStream_t g_cudaStreams2[GGML_CUDA_MAX_STREAMS] = { nullptr }; +static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_EVENTS] = { nullptr }; void ggml_init_cublas() { if (g_cublasH == nullptr) { - // create cublas handle, bind a stream - CUBLAS_CHECK(cublasCreate(&g_cublasH)); - CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking)); - CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream)); + // create streams + for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) { + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking)); + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[i], cudaStreamNonBlocking)); + } + // create events + for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) { + CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming)); + } - // create additional stream and event for synchronization - CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking)); - CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming)); + // create cublas handle + CUBLAS_CHECK(cublasCreate(&g_cublasH)); + CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH)); // configure logging to stdout - // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL)); + // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); } } -cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) { +void * ggml_cuda_host_malloc(size_t size) { + if (getenv("GGML_CUDA_NO_PINNED") != nullptr) { + return nullptr; + } + + void * ptr = nullptr; + cudaError_t err = cudaMallocHost((void **) &ptr, size); + if (err != cudaSuccess) { + fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n", + size/1024.0/1024.0, cudaGetErrorString(err)); + return nullptr; + } + + return ptr; +} + +void ggml_cuda_host_free(void * ptr) { + CUDA_CHECK(cudaFreeHost(ptr)); +} + +static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) { const uint64_t ne0 = src->ne[0]; const uint64_t ne1 = src->ne[1]; const uint64_t nb0 = src->nb[0]; @@ -354,22 +424,293 @@ cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, } } -void * ggml_cuda_host_malloc(size_t size) { - if (getenv("GGML_CUDA_NO_PINNED") != nullptr) { - return nullptr; +static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const float alpha = 1.0f; + const float beta = 0.0f; + const int x_ne = ne01 * ne00; + const int y_ne = ne11 * ne10; + const int d_ne = ne11 * ne01; + const int n_mm = ne03 * ne02; + + size_t x_size, y_size, d_size; + float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size); + float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size); + float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + int i = i03*ne02 + i02; + cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS]; + + float * c_X = d_X + i * x_ne; + float * c_Y = d_Y + i * y_ne; + float * c_D = d_D + i * d_ne; + + // copy data to device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream)); + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream)); + + // compute + CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream)); + CUBLAS_CHECK( + cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha, c_X, ne00, + c_Y, ne10, + &beta, c_D, ne01)); + + // copy dst to host + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); + } } - void * ptr = nullptr; - cudaError_t err = cudaMallocHost((void **) &ptr, size); - if (err != cudaSuccess) { - fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n", - size/1024.0/1024.0, cudaGetErrorString(err)); - return nullptr; + CUDA_CHECK(cudaDeviceSynchronize()); + ggml_cuda_pool_free(d_X, x_size); + ggml_cuda_pool_free(d_Y, y_size); + ggml_cuda_pool_free(d_D, d_size); +} + +static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) { + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const float alpha = 1.0f; + const float beta = 0.0f; + const int x_ne = ne01 * ne00; + const int y_ne = ne11 * ne10; + const int d_ne = ne11 * ne01; + const int n_mm = ne03 * ne02; + + size_t x_size, y_size, d_size; + half * d_X = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size); + half * d_Y = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size); + float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size); + + bool src1_cont_rows = nb10 == sizeof(float); + bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + int i = i03*ne02 + i02; + cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS]; + + half * c_X = d_X + i * x_ne; + half * c_Y = d_Y + i * y_ne; + float * c_D = d_D + i * d_ne; + + // copy src0 to device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream)); + + // convert src1 to fp16 + // TODO: use multiple threads + ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02); + char * src1i = (char *) src1->data + i03*nb13 + i02*nb12; + if (src1_cont_rows) { + if (src1_cont_cols) { + ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11); + } + else { + for (int64_t i01 = 0; i01 < ne11; i01++) { + ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10); + } + } + } + else { + for (int64_t i01 = 0; i01 < ne11; i01++) { + for (int64_t i00 = 0; i00 < ne10; i00++) { + // very slow due to no inlining + tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10)); + } + } + } + + // copy src1 to device + CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream)); + + // compute + CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream)); + CUBLAS_CHECK( + cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha, c_X, CUDA_R_16F, ne00, + c_Y, CUDA_R_16F, ne10, + &beta, c_D, CUDA_R_32F, ne01, + CUBLAS_COMPUTE_32F_FAST_16F, + CUBLAS_GEMM_DEFAULT)); + + // copy dst to host + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); + } } - return ptr; + CUDA_CHECK(cudaDeviceSynchronize()); + ggml_cuda_pool_free(d_X, x_size); + ggml_cuda_pool_free(d_Y, y_size); + ggml_cuda_pool_free(d_D, d_size); } -void ggml_cuda_host_free(void * ptr) { - CUDA_CHECK(cudaFreeHost(ptr)); +static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + const ggml_type type = src0->type; + + const float alpha = 1.0f; + const float beta = 0.0f; + const int x_ne = ne01 * ne00; + const int y_ne = ne11 * ne10; + const int d_ne = ne11 * ne01; + const int n_mm = ne03 * ne02; + const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type); + + size_t x_size, y_size, d_size, q_size; + float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size); + float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size); + float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size); + char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type); + GGML_ASSERT(to_fp32_cuda != nullptr); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + int i = i03*ne02 + i02; + cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS]; + cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS]; + cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS]; + + float * c_X = d_X + i * x_ne; + float * c_Y = d_Y + i * y_ne; + float * c_D = d_D + i * d_ne; + char * c_Q = d_Q + i * q_sz; + + // copy src0 and convert to fp32 on device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2)); + to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); + + // copy src1 to device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream)); + + // wait for conversion + CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); + + // compute + CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream)); + CUBLAS_CHECK( + cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha, c_X, ne00, + c_Y, ne10, + &beta, c_D, ne01)); + + // copy dst to host + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); + } + } + + CUDA_CHECK(cudaDeviceSynchronize()); + ggml_cuda_pool_free(d_X, x_size); + ggml_cuda_pool_free(d_Y, y_size); + ggml_cuda_pool_free(d_D, d_size); + ggml_cuda_pool_free(d_Q, q_size); +} + +bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + // TODO: find the optimal values for these + if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && + src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { + + return true; + } + + return false; +} + +bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) { + size_t src0_sz = ggml_nbytes(src0); + size_t src1_sz = ggml_nbytes(src1); + + // mul_mat_q: src0 is converted to fp32 on device + size_t mul_mat_q_transfer = src0_sz + src1_sz; + + // mul_mat_f16: src1 is converted to fp16 on cpu + size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1); + + // choose the smaller one to transfer to the device + // TODO: this is not always the best choice due to the overhead of converting to fp16 + return mul_mat_f16_transfer < mul_mat_q_transfer; +} + +void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) { + GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst)); + + if (src0->type == GGML_TYPE_F32) { + ggml_cuda_mul_mat_f32(src0, src1, dst); + } + else if (src0->type == GGML_TYPE_F16) { + if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) { + ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize); + } + else { + ggml_cuda_mul_mat_q_f32(src0, src1, dst); + } + } + else if (ggml_is_quantized(src0->type)) { + ggml_cuda_mul_mat_q_f32(src0, src1, dst); + } + else { + GGML_ASSERT(false); + } +} + +size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) { + return ggml_nelements(src1) * sizeof(ggml_fp16_t); + } + else { + return 0; + } } diff --git a/ggml-cuda.h b/ggml-cuda.h index 36782d9e7..f7d6a8bc1 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -1,54 +1,19 @@ -#include -#include #include "ggml.h" #ifdef __cplusplus extern "C" { #endif -#define CUDA_CHECK(err) \ - do { \ - cudaError_t err_ = (err); \ - if (err_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ - cudaGetErrorString(err_)); \ - exit(1); \ - } \ - } while (0) - -#define CUBLAS_CHECK(err) \ - do { \ - cublasStatus_t err_ = (err); \ - if (err_ != CUBLAS_STATUS_SUCCESS) { \ - fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \ - exit(1); \ - } \ - } while (0) - -extern cublasHandle_t g_cublasH; -extern cudaStream_t g_cudaStream; -extern cudaStream_t g_cudaStream2; -extern cudaEvent_t g_cudaEvent; - void ggml_init_cublas(void); + +bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); +size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); +void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize); + +// TODO: export these with GGML_API void * ggml_cuda_host_malloc(size_t size); void ggml_cuda_host_free(void * ptr); -void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size); -void ggml_cuda_pool_free(void * ptr, size_t size); - -void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); -void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); -void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream); -void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); -void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); -void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); - -cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream); - -typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); -dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type); - #ifdef __cplusplus } #endif diff --git a/ggml.c b/ggml.c index 5b5ed925e..bce7a7a57 100644 --- a/ggml.c +++ b/ggml.c @@ -135,14 +135,6 @@ inline static void* ggml_aligned_malloc(size_t size) { #define UNUSED(x) (void)(x) #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0) -#define GGML_ASSERT(x) \ - do { \ - if (!(x)) { \ - fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ - abort(); \ - } \ - } while (0) - #if defined(GGML_USE_ACCELERATE) #include #elif defined(GGML_USE_OPENBLAS) @@ -370,6 +362,32 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) { return GGML_FP32_TO_FP16(x); } +void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n) { + for (size_t i = 0; i < n; i++) { + y[i] = GGML_FP16_TO_FP32(x[i]); + } +} + +void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) { + size_t i = 0; +#if defined(__F16C__) + for (; i + 7 < n; i += 8) { + __m256 x_vec = _mm256_loadu_ps(x + i); + __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128((__m128i *)(y + i), y_vec); + } + for(; i + 3 < n; i += 4) { + __m128 x_vec = _mm_loadu_ps(x + i); + __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storel_epi64((__m128i *)(y + i), y_vec); + } +#endif + for (; i < n; i++) { + y[i] = GGML_FP32_TO_FP16(x[i]); + } +} + + // // timing // @@ -4325,12 +4343,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); } - // initialize cuBLAS - #if defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_CUBLAS) ggml_init_cublas(); - #elif defined(GGML_USE_CLBLAST) +#elif defined(GGML_USE_CLBLAST) ggml_cl_init(); - #endif +#endif is_first_call = false; } @@ -8101,7 +8118,7 @@ static void ggml_compute_forward_rms_norm( // ggml_compute_forward_mul_mat -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) // helper function to determine if it is better to use BLAS or not // for large matrices, BLAS is faster static bool ggml_compute_forward_mul_mat_use_blas( @@ -8117,12 +8134,9 @@ static bool ggml_compute_forward_mul_mat_use_blas( const int64_t ne1 = dst->ne[1]; // TODO: find the optimal values for these - if ( -#if !defined(GGML_USE_CUBLAS) - ggml_is_contiguous(src0) && + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && -#endif - ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) { + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ return true; @@ -8130,7 +8144,6 @@ static bool ggml_compute_forward_mul_mat_use_blas( return false; } - #endif static void ggml_compute_forward_mul_mat_f32( @@ -8146,7 +8159,7 @@ static void ggml_compute_forward_mul_mat_f32( const int64_t ne02 = src0->ne[2]; const int64_t ne03 = src0->ne[3]; -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) const int64_t ne10 = src1->ne[0]; #endif const int64_t ne11 = src1->ne[1]; @@ -8203,7 +8216,16 @@ static void ggml_compute_forward_mul_mat_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) +#if defined(GGML_USE_CUBLAS) + if (ggml_cuda_can_mul_mat(src0, src1, dst)) { + if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) { + ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize); + } + return; + } +#endif + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { if (params->ith != 0) { return; @@ -8217,43 +8239,13 @@ static void ggml_compute_forward_mul_mat_f32( return; } -#if defined(GGML_USE_CUBLAS) - const float alpha = 1.0f; - const float beta = 0.0f; - const int x_ne = ne01 * ne00; - const int y_ne = ne11 * ne10; - const int d_ne = ne11 * ne01; - - size_t x_size, y_size, d_size; - float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); - float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); - float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); -#endif - for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { -#if !defined(GGML_USE_CUBLAS) const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03); const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); -#endif float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); -#if defined(GGML_USE_CUBLAS) - // copy data to device - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream)); - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream)); - - // compute - CUBLAS_CHECK( - cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - &alpha, d_X, ne00, - d_Y, ne10, - &beta, d_D, ne01)); - - // copy data to host - CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); -#elif defined(GGML_USE_CLBLAST) +#if defined(GGML_USE_CLBLAST) // zT = y * xT ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T, ne11, ne01, ne10, @@ -8270,12 +8262,6 @@ static void ggml_compute_forward_mul_mat_f32( #endif } } -#if defined(GGML_USE_CUBLAS) - CUDA_CHECK(cudaStreamSynchronize(g_cudaStream)); - ggml_cuda_pool_free(d_X, x_size); - ggml_cuda_pool_free(d_Y, y_size); - ggml_cuda_pool_free(d_D, d_size); -#endif //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); return; @@ -8405,7 +8391,16 @@ static void ggml_compute_forward_mul_mat_f16_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) +#if defined(GGML_USE_CUBLAS) + if (ggml_cuda_can_mul_mat(src0, src1, dst)) { + if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) { + ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize); + } + return; + } +#endif + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { GGML_ASSERT(nb10 == sizeof(float)); @@ -8421,37 +8416,8 @@ static void ggml_compute_forward_mul_mat_f16_f32( return; } -#if defined(GGML_USE_CUBLAS) - const float alpha = 1.0f; - const float beta = 0.0f; - const int x_ne = ne01 * ne00; - const int y_ne = ne11 * ne10; - const int d_ne = ne11 * ne01; - - size_t x_size, y_size, d_size; - ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); - ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); - float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); -#endif for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { -#if defined(GGML_USE_CUBLAS) - // copy src0 while converting src1 - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream)); - - // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16 - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02); - { - size_t id = 0; - for (int64_t i01 = 0; i01 < ne11; ++i01) { - for (int64_t i00 = 0; i00 < ne10; ++i00) { - wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10)); - } - } - - assert(id*sizeof(ggml_fp16_t) <= params->wsize); - } -#else float * const wdata = params->wdata; { size_t id = 0; @@ -8463,28 +8429,8 @@ static void ggml_compute_forward_mul_mat_f16_f32( assert(id*sizeof(float) <= params->wsize); } -#endif -#if defined(GGML_USE_CUBLAS) - const ggml_fp16_t * y = (ggml_fp16_t *) wdata; - float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); - - // copy data to device - CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream)); - - // compute - CUBLAS_CHECK( - cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - &alpha, d_X, CUDA_R_16F, ne00, - d_Y, CUDA_R_16F, ne10, - &beta, d_D, CUDA_R_32F, ne01, - CUBLAS_COMPUTE_32F, - CUBLAS_GEMM_DEFAULT)); - - // copy data to host - CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); -#elif defined(GGML_USE_CLBLAST) +#if defined(GGML_USE_CLBLAST) const float * x = wdata; const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); @@ -8513,12 +8459,6 @@ static void ggml_compute_forward_mul_mat_f16_f32( } } -#if defined(GGML_USE_CUBLAS) - CUDA_CHECK(cudaStreamSynchronize(g_cudaStream)); - ggml_cuda_pool_free(d_X, x_size); - ggml_cuda_pool_free(d_Y, y_size); - ggml_cuda_pool_free(d_D, d_size); -#endif /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/ return; @@ -8671,7 +8611,16 @@ static void ggml_compute_forward_mul_mat_q_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) +#if defined(GGML_USE_CUBLAS) + if (ggml_cuda_can_mul_mat(src0, src1, dst)) { + if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) { + ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize); + } + return; + } +#endif + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { if (params->ith != 0) { return; @@ -8685,25 +8634,8 @@ static void ggml_compute_forward_mul_mat_q_f32( return; } -#if defined(GGML_USE_CUBLAS) - const float alpha = 1.0f; - const float beta = 0.0f; - const int x_ne = ne01 * ne00; - const int y_ne = ne11 * ne10; - const int d_ne = ne11 * ne01; - - size_t x_size, y_size, d_size, q_size; - float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); - float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); - float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); - void * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size); - - const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type); - GGML_ASSERT(dequantize_row_q_cuda != NULL); -#else float * const wdata = params->wdata; dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; -#endif for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { @@ -8711,14 +8643,7 @@ static void ggml_compute_forward_mul_mat_q_f32( float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); -#if defined(GGML_USE_CUBLAS) - // copy and dequantize on device - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream2)); - - dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2); - CUDA_CHECK(cudaGetLastError()); - CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2)); -#elif defined(GGML_USE_CLBLAST) +#if defined(GGML_USE_CLBLAST) const void* x = (char *) src0->data + i03*nb03 + i02*nb02; #else { @@ -8734,24 +8659,7 @@ static void ggml_compute_forward_mul_mat_q_f32( const float * x = wdata; #endif -#if defined(GGML_USE_CUBLAS) - // copy data to device - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream)); - - // wait for dequantization - CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0)); - - // compute - CUBLAS_CHECK( - cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - &alpha, d_X, ne00, - d_Y, ne10, - &beta, d_D, ne01)); - - // copy data to host - CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); -#elif defined(GGML_USE_CLBLAST) +#if defined(GGML_USE_CLBLAST) // zT = y * xT ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T, ne11, ne01, ne10, @@ -8769,13 +8677,6 @@ static void ggml_compute_forward_mul_mat_q_f32( } } -#if defined(GGML_USE_CUBLAS) - CUDA_CHECK(cudaStreamSynchronize(g_cudaStream)); - ggml_cuda_pool_free(d_X, x_size); - ggml_cuda_pool_free(d_Y, y_size); - ggml_cuda_pool_free(d_D, d_size); - ggml_cuda_pool_free(d_Q, q_size); -#endif //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); return; @@ -11759,18 +11660,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) size_t cur = 0; +#if defined(GGML_USE_CUBLAS) + if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) { + node->n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node); + } + else +#endif if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1; // TODO: this actually is doing nothing // the threads are still spinning -#if defined(GGML_USE_CUBLAS) - // with cuBLAS, we need memory for the full 3D / 4D data of src1 - cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); -#else // here we need memory just for single 2D matrix from src0 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); -#endif } else { cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); } @@ -11779,13 +11683,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) #endif } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { cur = 0; -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1; } #endif } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1; cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); diff --git a/ggml.h b/ggml.h index d6feacd78..ef5a048c3 100644 --- a/ggml.h +++ b/ggml.h @@ -197,6 +197,14 @@ #define GGML_MAX_OPT 4 #define GGML_DEFAULT_N_THREADS 4 +#define GGML_ASSERT(x) \ + do { \ + if (!(x)) { \ + fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + #ifdef __cplusplus extern "C" { #endif @@ -212,6 +220,9 @@ extern "C" { GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x); GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x); + GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n); + GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n); + struct ggml_object; struct ggml_context;