diff --git a/CMakeLists.txt b/CMakeLists.txt index d7aa051da..1f9fdd30f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -110,6 +110,7 @@ if (APPLE AND LLAMA_ACCELERATE) message(WARNING "Accelerate framework not found") endif() endif() + if (LLAMA_OPENBLAS) if (LLAMA_STATIC) set(BLA_STATIC ON) @@ -150,6 +151,10 @@ if (LLAMA_CUBLAS) if (CUDAToolkit_FOUND) message(STATUS "cuBLAS found") + enable_language(CUDA) + + set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h) + add_compile_definitions(GGML_USE_CUBLAS) if (LLAMA_STATIC) @@ -241,21 +246,26 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$") message(STATUS "x86 detected") if (MSVC) if (LLAMA_AVX512) - add_compile_options(/arch:AVX512) + add_compile_options($<$:/arch:AVX512>) + add_compile_options($<$:/arch:AVX512>) # MSVC has no compile-time flags enabling specific # AVX512 extensions, neither it defines the # macros corresponding to the extensions. # Do it manually. if (LLAMA_AVX512_VBMI) - add_compile_definitions(__AVX512VBMI__) + add_compile_definitions($<$:__AVX512VBMI__>) + add_compile_definitions($<$:__AVX512VBMI__>) endif() if (LLAMA_AVX512_VNNI) - add_compile_definitions(__AVX512VNNI__) + add_compile_definitions($<$:__AVX512VNNI__>) + add_compile_definitions($<$:__AVX512VNNI__>) endif() elseif (LLAMA_AVX2) - add_compile_options(/arch:AVX2) + add_compile_options($<$:/arch:AVX2>) + add_compile_options($<$:/arch:AVX2>) elseif (LLAMA_AVX) - add_compile_options(/arch:AVX) + add_compile_options($<$:/arch:AVX>) + add_compile_options($<$:/arch:AVX>) endif() else() if (LLAMA_F16C) @@ -292,7 +302,8 @@ endif() add_library(ggml OBJECT ggml.c - ggml.h) + ggml.h + ${GGML_CUDA_SOURCES}) target_include_directories(ggml PUBLIC .) target_compile_features(ggml PUBLIC c_std_11) # don't bump @@ -314,6 +325,14 @@ if (BUILD_SHARED_LIBS) target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD) endif() +if (GGML_CUDA_SOURCES) + message(STATUS "GGML CUDA sources found, configuring CUDA architecture") + set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF) + set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto") + set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES OFF) +endif() + + # # programs, examples and tests # diff --git a/Makefile b/Makefile index d9a2d836b..4bf481aa2 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,6 @@ +# Define the default target now so that it is always the first target +default: main quantize quantize-stats perplexity embedding vdot + ifndef UNAME_S UNAME_S := $(shell uname -s) endif @@ -100,6 +103,9 @@ endif ifdef LLAMA_CUBLAS CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include LDFLAGS += -lcublas_static -lculibos -lcudart_static -lcublasLt_static -lpthread -ldl -L/usr/local/cuda/lib64 + OBJS += ggml-cuda.o +ggml-cuda.o: ggml-cuda.cu ggml-cuda.h + nvcc -arch=native -c -o $@ $< endif ifdef LLAMA_GPROF CFLAGS += -pg @@ -137,8 +143,6 @@ $(info I CC: $(CCV)) $(info I CXX: $(CXXV)) $(info ) -default: main quantize quantize-stats perplexity embedding vdot - # # Build library # @@ -155,35 +159,35 @@ common.o: examples/common.cpp examples/common.h clean: rm -vf *.o main quantize quantize-stats perplexity embedding benchmark-q4_0-matmult -main: examples/main/main.cpp ggml.o llama.o common.o +main: examples/main/main.cpp ggml.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) @echo @echo '==== Run ./main -h for help. ====' @echo -quantize: examples/quantize/quantize.cpp ggml.o llama.o +quantize: examples/quantize/quantize.cpp ggml.o llama.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) -quantize-stats: examples/quantize-stats/quantize-stats.cpp ggml.o llama.o +quantize-stats: examples/quantize-stats/quantize-stats.cpp ggml.o llama.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) -perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o +perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) -embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o +embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) -vdot: pocs/vdot/vdot.cpp ggml.o +vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) -libllama.so: llama.o ggml.o +libllama.so: llama.o ggml.o $(OBJS) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) # # Tests # -benchmark: examples/benchmark/benchmark-q4_0-matmult.c ggml.o +benchmark: examples/benchmark/benchmark-q4_0-matmult.c ggml.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o benchmark-q4_0-matmult $(LDFLAGS) ./benchmark-q4_0-matmult diff --git a/ggml-cuda.cu b/ggml-cuda.cu new file mode 100644 index 000000000..7cd116602 --- /dev/null +++ b/ggml-cuda.cu @@ -0,0 +1,116 @@ +#include +#include +#include "ggml-cuda.h" + +typedef uint16_t ggml_fp16_t; +static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size"); + +#define QK4_0 32 +typedef struct { + float d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +typedef struct { + float d; // delta + float m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK4_2 16 +typedef struct { + __half d; // delta + uint8_t qs[QK4_2 / 2]; // nibbles / quants +} block_q4_2; +static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding"); + + +static __global__ void dequantize_block_q4_0(const void * vx, float * y) { + const block_q4_0 * x = (const block_q4_0 *) vx; + + const int i = blockIdx.x; + + const float d = x[i].d; + + const uint8_t * pp = x[i].qs; + + for (int l = 0; l < QK4_0; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; + + const float v0 = (vi0 - 8)*d; + const float v1 = (vi1 - 8)*d; + + y[i*QK4_0 + l + 0] = v0; + y[i*QK4_0 + l + 1] = v1; + } +} + +static __global__ void dequantize_block_q4_1(const void * vx, float * y) { + const block_q4_1 * x = (const block_q4_1 *) vx; + + const int i = blockIdx.x; + + const float d = x[i].d; + const float m = x[i].m; + + const uint8_t * pp = x[i].qs; + + for (int l = 0; l < QK4_1; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; + + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; + + y[i*QK4_1 + l + 0] = v0; + y[i*QK4_1 + l + 1] = v1; + } +} + +static __global__ void dequantize_block_q4_2(const void * vx, float * y) { + const block_q4_2 * x = (const block_q4_2 *) vx; + + const int i = blockIdx.x; + + const float d = x[i].d; + + const uint8_t * pp = x[i].qs; + + for (int l = 0; l < QK4_2; l += 2) { + const uint8_t vi = pp[l/2]; + + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; + + const float v0 = (vi0 - 8)*d; + const float v1 = (vi1 - 8)*d; + + y[i*QK4_2 + l + 0] = v0; + y[i*QK4_2 + l + 1] = v1; + } +} + +extern "C" { + __host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK4_0; + dequantize_block_q4_0<<>>(vx, y); + } + + __host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK4_1; + dequantize_block_q4_1<<>>(vx, y); + } + + __host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK4_2; + dequantize_block_q4_2<<>>(vx, y); + } +} diff --git a/ggml-cuda.h b/ggml-cuda.h new file mode 100644 index 000000000..646caafc6 --- /dev/null +++ b/ggml-cuda.h @@ -0,0 +1,11 @@ +#ifdef __cplusplus +extern "C" { +#endif + +void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); +void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); +void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream); + +#ifdef __cplusplus +} +#endif diff --git a/ggml.c b/ggml.c index 431cdb9c9..9a3430859 100644 --- a/ggml.c +++ b/ggml.c @@ -150,23 +150,25 @@ inline static void* ggml_aligned_malloc(size_t size) { #elif defined(GGML_USE_CUBLAS) #include #include -#define CUDA_CHECK(err) \ - do { \ - cudaError_t err_ = (err); \ - if (err_ != cudaSuccess) { \ - printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ - cudaGetErrorString(err_)); \ - exit(1); \ - } \ +#include "ggml-cuda.h" + +#define CUDA_CHECK(err) \ + do { \ + cudaError_t err_ = (err); \ + if (err_ != cudaSuccess) { \ + printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ + cudaGetErrorString(err_)); \ + exit(1); \ + } \ } while (0) -#define CUBLAS_CHECK(err) \ - do { \ - cublasStatus_t err_ = (err); \ - if (err_ != CUBLAS_STATUS_SUCCESS) { \ - printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \ - exit(1); \ - } \ +#define CUBLAS_CHECK(err) \ + do { \ + cublasStatus_t err_ = (err); \ + if (err_ != CUBLAS_STATUS_SUCCESS) { \ + printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \ + exit(1); \ + } \ } while (0) static cublasHandle_t cublasH = NULL; @@ -177,6 +179,7 @@ static void init_cublas(void) { CUBLAS_CHECK(cublasCreate(&cublasH)); CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking)); + CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream)); // configure logging to stdout @@ -7311,7 +7314,6 @@ static void ggml_compute_forward_mul_mat_f32( // copy data to host CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); - CUDA_CHECK(cudaStreamSynchronize(cudaStream)); #else // zT = y * xT cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, @@ -7323,6 +7325,7 @@ static void ggml_compute_forward_mul_mat_f32( } } #if defined(GGML_USE_CUBLAS) + CUDA_CHECK(cudaStreamSynchronize(cudaStream)); CUDA_CHECK(cudaFree(d_X)); CUDA_CHECK(cudaFree(d_Y)); CUDA_CHECK(cudaFree(d_D)); @@ -7535,7 +7538,6 @@ static void ggml_compute_forward_mul_mat_f16_f32( // copy data to host CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); - CUDA_CHECK(cudaStreamSynchronize(cudaStream)); #else const float * x = wdata; const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); @@ -7553,6 +7555,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( } #if defined(GGML_USE_CUBLAS) + CUDA_CHECK(cudaStreamSynchronize(cudaStream)); CUDA_CHECK(cudaFree(d_X)); CUDA_CHECK(cudaFree(d_Y)); CUDA_CHECK(cudaFree(d_D)); @@ -7722,13 +7725,11 @@ static void ggml_compute_forward_mul_mat_q_f32( return; } - float * const wdata = params->wdata; - dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; - #if defined(GGML_USE_CUBLAS) float *d_X = NULL; float *d_Y = NULL; float *d_D = NULL; + float *d_Q = NULL; const float alpha = 1.0f; const float beta = 0.0f; const int x_ne = ne01 * ne10; @@ -7738,10 +7739,41 @@ static void ggml_compute_forward_mul_mat_q_f32( CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne)); CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne)); CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne)); + CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type])); + + void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL; + if (type == GGML_TYPE_Q4_0) { + dequantize_row_q_cuda = dequantize_row_q4_0_cuda; + } + else if (type == GGML_TYPE_Q4_1) { + dequantize_row_q_cuda = dequantize_row_q4_1_cuda; + } + else if (type == GGML_TYPE_Q4_2) { + dequantize_row_q_cuda = dequantize_row_q4_2_cuda; + } + else { + GGML_ASSERT(false); + } +#else + float * const wdata = params->wdata; + dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; #endif for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + +#if defined(GGML_USE_CUBLAS) + // copy and dequantize on device + CUDA_CHECK( + cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02, + GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream)); + + dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream); + CUDA_CHECK(cudaGetLastError()); +#else { size_t id = 0; for (int64_t i01 = 0; i01 < ne01; ++i01) { @@ -7749,15 +7781,12 @@ static void ggml_compute_forward_mul_mat_q_f32( id += ne00; } } - const float * x = wdata; - const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); +#endif - float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); #if defined(GGML_USE_CUBLAS) // copy data to device - CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream)); CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream)); // compute @@ -7770,7 +7799,6 @@ static void ggml_compute_forward_mul_mat_q_f32( // copy data to host CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); - CUDA_CHECK(cudaStreamSynchronize(cudaStream)); #else // zT = y * xT cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, @@ -7783,9 +7811,11 @@ static void ggml_compute_forward_mul_mat_q_f32( } #if defined(GGML_USE_CUBLAS) + CUDA_CHECK(cudaStreamSynchronize(cudaStream)); CUDA_CHECK(cudaFree(d_X)); CUDA_CHECK(cudaFree(d_Y)); CUDA_CHECK(cudaFree(d_D)); + CUDA_CHECK(cudaFree(d_Q)); #endif //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);