Improve cuBLAS performance by dequantizing on the GPU (#1065)

This commit is contained in:
slaren 2023-04-20 03:14:14 +02:00 committed by GitHub
parent 834695fe3a
commit 02d6988121
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 221 additions and 41 deletions

View file

@ -110,6 +110,7 @@ if (APPLE AND LLAMA_ACCELERATE)
message(WARNING "Accelerate framework not found")
endif()
endif()
if (LLAMA_OPENBLAS)
if (LLAMA_STATIC)
set(BLA_STATIC ON)
@ -150,6 +151,10 @@ if (LLAMA_CUBLAS)
if (CUDAToolkit_FOUND)
message(STATUS "cuBLAS found")
enable_language(CUDA)
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
add_compile_definitions(GGML_USE_CUBLAS)
if (LLAMA_STATIC)
@ -241,21 +246,26 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
message(STATUS "x86 detected")
if (MSVC)
if (LLAMA_AVX512)
add_compile_options(/arch:AVX512)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
# MSVC has no compile-time flags enabling specific
# AVX512 extensions, neither it defines the
# macros corresponding to the extensions.
# Do it manually.
if (LLAMA_AVX512_VBMI)
add_compile_definitions(__AVX512VBMI__)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
endif()
if (LLAMA_AVX512_VNNI)
add_compile_definitions(__AVX512VNNI__)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
endif()
elseif (LLAMA_AVX2)
add_compile_options(/arch:AVX2)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
elseif (LLAMA_AVX)
add_compile_options(/arch:AVX)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
endif()
else()
if (LLAMA_F16C)
@ -292,7 +302,8 @@ endif()
add_library(ggml OBJECT
ggml.c
ggml.h)
ggml.h
${GGML_CUDA_SOURCES})
target_include_directories(ggml PUBLIC .)
target_compile_features(ggml PUBLIC c_std_11) # don't bump
@ -314,6 +325,14 @@ if (BUILD_SHARED_LIBS)
target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD)
endif()
if (GGML_CUDA_SOURCES)
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF)
set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES OFF)
endif()
#
# programs, examples and tests
#

View file

@ -1,3 +1,6 @@
# Define the default target now so that it is always the first target
default: main quantize quantize-stats perplexity embedding vdot
ifndef UNAME_S
UNAME_S := $(shell uname -s)
endif
@ -100,6 +103,9 @@ endif
ifdef LLAMA_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
LDFLAGS += -lcublas_static -lculibos -lcudart_static -lcublasLt_static -lpthread -ldl -L/usr/local/cuda/lib64
OBJS += ggml-cuda.o
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
nvcc -arch=native -c -o $@ $<
endif
ifdef LLAMA_GPROF
CFLAGS += -pg
@ -137,8 +143,6 @@ $(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )
default: main quantize quantize-stats perplexity embedding vdot
#
# Build library
#
@ -155,35 +159,35 @@ common.o: examples/common.cpp examples/common.h
clean:
rm -vf *.o main quantize quantize-stats perplexity embedding benchmark-q4_0-matmult
main: examples/main/main.cpp ggml.o llama.o common.o
main: examples/main/main.cpp ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
@echo
@echo '==== Run ./main -h for help. ===='
@echo
quantize: examples/quantize/quantize.cpp ggml.o llama.o
quantize: examples/quantize/quantize.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
quantize-stats: examples/quantize-stats/quantize-stats.cpp ggml.o llama.o
quantize-stats: examples/quantize-stats/quantize-stats.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o
perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o
embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
vdot: pocs/vdot/vdot.cpp ggml.o
vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
libllama.so: llama.o ggml.o
libllama.so: llama.o ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
#
# Tests
#
benchmark: examples/benchmark/benchmark-q4_0-matmult.c ggml.o
benchmark: examples/benchmark/benchmark-q4_0-matmult.c ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o benchmark-q4_0-matmult $(LDFLAGS)
./benchmark-q4_0-matmult

116
ggml-cuda.cu Normal file
View file

@ -0,0 +1,116 @@
#include <stdint.h>
#include <cuda_fp16.h>
#include "ggml-cuda.h"
typedef uint16_t ggml_fp16_t;
static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
#define QK4_0 32
typedef struct {
float d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
#define QK4_1 32
typedef struct {
float d; // delta
float m; // min
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
#define QK4_2 16
typedef struct {
__half d; // delta
uint8_t qs[QK4_2 / 2]; // nibbles / quants
} block_q4_2;
static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
const block_q4_0 * x = (const block_q4_0 *) vx;
const int i = blockIdx.x;
const float d = x[i].d;
const uint8_t * pp = x[i].qs;
for (int l = 0; l < QK4_0; l += 2) {
const uint8_t vi = pp[l/2];
const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;
const float v0 = (vi0 - 8)*d;
const float v1 = (vi1 - 8)*d;
y[i*QK4_0 + l + 0] = v0;
y[i*QK4_0 + l + 1] = v1;
}
}
static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
const block_q4_1 * x = (const block_q4_1 *) vx;
const int i = blockIdx.x;
const float d = x[i].d;
const float m = x[i].m;
const uint8_t * pp = x[i].qs;
for (int l = 0; l < QK4_1; l += 2) {
const uint8_t vi = pp[l/2];
const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;
const float v0 = vi0*d + m;
const float v1 = vi1*d + m;
y[i*QK4_1 + l + 0] = v0;
y[i*QK4_1 + l + 1] = v1;
}
}
static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
const block_q4_2 * x = (const block_q4_2 *) vx;
const int i = blockIdx.x;
const float d = x[i].d;
const uint8_t * pp = x[i].qs;
for (int l = 0; l < QK4_2; l += 2) {
const uint8_t vi = pp[l/2];
const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;
const float v0 = (vi0 - 8)*d;
const float v1 = (vi1 - 8)*d;
y[i*QK4_2 + l + 0] = v0;
y[i*QK4_2 + l + 1] = v1;
}
}
extern "C" {
__host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_0;
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
}
__host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_1;
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
}
__host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_2;
dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
}
}

11
ggml-cuda.h Normal file
View file

@ -0,0 +1,11 @@
#ifdef __cplusplus
extern "C" {
#endif
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
#ifdef __cplusplus
}
#endif

80
ggml.c
View file

@ -150,23 +150,25 @@ inline static void* ggml_aligned_malloc(size_t size) {
#elif defined(GGML_USE_CUBLAS)
#include <cublas_v2.h>
#include <cuda_runtime.h>
#define CUDA_CHECK(err) \
do { \
cudaError_t err_ = (err); \
if (err_ != cudaSuccess) { \
printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
cudaGetErrorString(err_)); \
exit(1); \
} \
#include "ggml-cuda.h"
#define CUDA_CHECK(err) \
do { \
cudaError_t err_ = (err); \
if (err_ != cudaSuccess) { \
printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
cudaGetErrorString(err_)); \
exit(1); \
} \
} while (0)
#define CUBLAS_CHECK(err) \
do { \
cublasStatus_t err_ = (err); \
if (err_ != CUBLAS_STATUS_SUCCESS) { \
printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
exit(1); \
} \
#define CUBLAS_CHECK(err) \
do { \
cublasStatus_t err_ = (err); \
if (err_ != CUBLAS_STATUS_SUCCESS) { \
printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
exit(1); \
} \
} while (0)
static cublasHandle_t cublasH = NULL;
@ -177,6 +179,7 @@ static void init_cublas(void) {
CUBLAS_CHECK(cublasCreate(&cublasH));
CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
// configure logging to stdout
@ -7311,7 +7314,6 @@ static void ggml_compute_forward_mul_mat_f32(
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else
// zT = y * xT
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@ -7323,6 +7325,7 @@ static void ggml_compute_forward_mul_mat_f32(
}
}
#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
@ -7535,7 +7538,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else
const float * x = wdata;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@ -7553,6 +7555,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
}
#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
@ -7722,13 +7725,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
return;
}
float * const wdata = params->wdata;
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
#if defined(GGML_USE_CUBLAS)
float *d_X = NULL;
float *d_Y = NULL;
float *d_D = NULL;
float *d_Q = NULL;
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne10;
@ -7738,10 +7739,41 @@ static void ggml_compute_forward_mul_mat_q_f32(
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
if (type == GGML_TYPE_Q4_0) {
dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
}
else if (type == GGML_TYPE_Q4_1) {
dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
}
else if (type == GGML_TYPE_Q4_2) {
dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
}
else {
GGML_ASSERT(false);
}
#else
float * const wdata = params->wdata;
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
#endif
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
#if defined(GGML_USE_CUBLAS)
// copy and dequantize on device
CUDA_CHECK(
cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
CUDA_CHECK(cudaGetLastError());
#else
{
size_t id = 0;
for (int64_t i01 = 0; i01 < ne01; ++i01) {
@ -7749,15 +7781,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
id += ne00;
}
}
const float * x = wdata;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
#endif
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
#if defined(GGML_USE_CUBLAS)
// copy data to device
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
// compute
@ -7770,7 +7799,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
#else
// zT = y * xT
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@ -7783,9 +7811,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
}
#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
CUDA_CHECK(cudaFree(d_X));
CUDA_CHECK(cudaFree(d_Y));
CUDA_CHECK(cudaFree(d_D));
CUDA_CHECK(cudaFree(d_Q));
#endif
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);