diff --git a/.devops/full-rocm.Dockerfile b/.devops/full-rocm.Dockerfile new file mode 100644 index 000000000..6c521e9b4 --- /dev/null +++ b/.devops/full-rocm.Dockerfile @@ -0,0 +1,44 @@ +ARG UBUNTU_VERSION=22.04 + +# This needs to generally match the container host's environment. +ARG ROCM_VERSION=5.6 + +# Target the CUDA build image +ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete + +FROM ${BASE_ROCM_DEV_CONTAINER} as build + +# Unless otherwise specified, we make a fat build. +# List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878 +# This is mostly tied to rocBLAS supported archs. +ARG ROCM_DOCKER_ARCH=\ + gfx803 \ + gfx900 \ + gfx906 \ + gfx908 \ + gfx90a \ + gfx1010 \ + gfx1030 \ + gfx1100 \ + gfx1101 \ + gfx1102 + +COPY requirements.txt requirements.txt + +RUN pip install --upgrade pip setuptools wheel \ + && pip install -r requirements.txt + +WORKDIR /app + +COPY . . + +# Set nvcc architecture +ENV GPU_TARGETS=${ROCM_DOCKER_ARCH} +# Enable ROCm +ENV LLAMA_HIPBLAS=1 +ENV CC=/opt/rocm/llvm/bin/clang +ENV CXX=/opt/rocm/llvm/bin/clang++ + +RUN make + +ENTRYPOINT ["/app/.devops/tools.sh"] diff --git a/.devops/main-rocm.Dockerfile b/.devops/main-rocm.Dockerfile new file mode 100644 index 000000000..789deff6d --- /dev/null +++ b/.devops/main-rocm.Dockerfile @@ -0,0 +1,44 @@ +ARG UBUNTU_VERSION=22.04 + +# This needs to generally match the container host's environment. +ARG ROCM_VERSION=5.6 + +# Target the CUDA build image +ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete + +FROM ${BASE_ROCM_DEV_CONTAINER} as build + +# Unless otherwise specified, we make a fat build. +# List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878 +# This is mostly tied to rocBLAS supported archs. +ARG ROCM_DOCKER_ARCH=\ + gfx803 \ + gfx900 \ + gfx906 \ + gfx908 \ + gfx90a \ + gfx1010 \ + gfx1030 \ + gfx1100 \ + gfx1101 \ + gfx1102 + +COPY requirements.txt requirements.txt + +RUN pip install --upgrade pip setuptools wheel \ + && pip install -r requirements.txt + +WORKDIR /app + +COPY . . + +# Set nvcc architecture +ENV GPU_TARGETS=${ROCM_DOCKER_ARCH} +# Enable ROCm +ENV LLAMA_HIPBLAS=1 +ENV CC=/opt/rocm/llvm/bin/clang +ENV CXX=/opt/rocm/llvm/bin/clang++ + +RUN make + +ENTRYPOINT [ "/app/main" ] diff --git a/.dockerignore b/.dockerignore index 462fac23a..c6ef6c86c 100644 --- a/.dockerignore +++ b/.dockerignore @@ -5,14 +5,7 @@ .vscode/ .DS_Store -build/ -build-em/ -build-debug/ -build-release/ -build-static/ -build-no-accel/ -build-sanitize-addr/ -build-sanitize-thread/ +build*/ models/* diff --git a/.gitignore b/.gitignore index 6cb7d9bc6..e5faab774 100644 --- a/.gitignore +++ b/.gitignore @@ -16,20 +16,7 @@ .vs/ .vscode/ -build/ -build-em/ -build-debug/ -build-release/ -build-ci-debug/ -build-ci-release/ -build-static/ -build-cublas/ -build-opencl/ -build-metal/ -build-mpi/ -build-no-accel/ -build-sanitize-addr/ -build-sanitize-thread/ +build*/ out/ tmp/ diff --git a/CMakeLists.txt b/CMakeLists.txt index bb63ef98e..ba008bcc6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -74,6 +74,7 @@ set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kern set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF) set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") +option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_METAL "llama: use Metal" OFF) option(LLAMA_MPI "llama: use MPI" OFF) @@ -352,6 +353,43 @@ if (LLAMA_CLBLAST) endif() endif() +if (LLAMA_HIPBLAS) + list(APPEND CMAKE_PREFIX_PATH /opt/rocm) + + if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") + message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang") + endif() + if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++") + endif() + + find_package(hip) + find_package(hipblas) + find_package(rocblas) + + if (${hipblas_FOUND} AND ${hip_FOUND}) + message(STATUS "HIP and hipBLAS found") + add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS) + add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h) + if (LLAMA_CUDA_FORCE_DMMV) + target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV) + endif() + target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) + target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) + target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) + target_compile_definitions(ggml-rocm PRIVATE CC_TURING=1000000000) + set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX) + target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas) + + if (LLAMA_STATIC) + message(FATAL_ERROR "Static linking not supported for HIP/ROCm") + endif() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm) + else() + message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm") + endif() +endif() + if (LLAMA_ALL_WARNINGS) if (NOT MSVC) set(c_flags diff --git a/Makefile b/Makefile index d31acc450..a3400e491 100644 --- a/Makefile +++ b/Makefile @@ -280,6 +280,30 @@ ggml-opencl.o: ggml-opencl.cpp ggml-opencl.h $(CXX) $(CXXFLAGS) -c $< -o $@ endif # LLAMA_CLBLAST +ifdef LLAMA_HIPBLAS + ROCM_PATH ?= /opt/rocm + HIPCC ?= $(ROCM_PATH)/bin/hipcc + GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch) + LLAMA_CUDA_DMMV_X ?= 32 + LLAMA_CUDA_MMV_Y ?= 1 + LLAMA_CUDA_KQUANTS_ITER ?= 2 + CFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS + CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS + LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib + LDFLAGS += -lhipblas -lamdhip64 -lrocblas + HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS)) + HIPFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) + HIPFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y) + HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) + HIPFLAGS += -DCC_TURING=1000000000 +ifdef LLAMA_CUDA_FORCE_DMMV + HIPFLAGS += -DGGML_CUDA_FORCE_DMMV +endif # LLAMA_CUDA_FORCE_DMMV + OBJS += ggml-cuda.o +ggml-cuda.o: ggml-cuda.cu ggml-cuda.h + $(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $< +endif # LLAMA_HIPBLAS + ifdef LLAMA_METAL CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG CXXFLAGS += -DGGML_USE_METAL diff --git a/README.md b/README.md index eebb11392..95471fdbb 100644 --- a/README.md +++ b/README.md @@ -422,6 +422,35 @@ Building the program with BLAS support may lead to some performance improvements | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | +- #### hipBLAS + + This provide BLAS acceleation on HIP supported GPU like AMD GPU. + Make sure to have ROCm installed. + You can download it from your Linux distro's package manager or from here: [ROCm Quick Start (Linux)](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html). + Windows support is coming soon... + + - Using `make`: + ```bash + make LLAMA_HIPBLAS=1 + ``` + - Using `CMake`: + ```bash + mkdir build + cd build + CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DLLAMA_HIPBLAS=ON + cmake --build . + ``` + + The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used. + If your GPU is not officialy supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 or 11.0.0 on RDNA3. + The following compilation options are also available to tweak performance (yes, they refer to CUDA, not HIP, because it uses the same code as the cuBLAS version above): + + | Option | Legal values | Default | Description | + |-------------------------|------------------------|---------|-------------| + | LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the HIP dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | + | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the HIP mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. | + | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per HIP thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | + - #### CLBlast OpenCL acceleration is provided by the matrix multiplication kernels from the [CLBlast](https://github.com/CNugteren/CLBlast) project and custom kernels for ggml that can generate tokens on the GPU. diff --git a/common/common.cpp b/common/common.cpp index 53002ba30..ff19ec4e5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -613,9 +613,11 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); fprintf(stdout, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n"); fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n"); +#ifdef GGML_USE_CUBLAS fprintf(stdout, " -nommq, --no-mul-mat-q\n"); - fprintf(stdout, " use cuBLAS instead of custom mul_mat_q CUDA kernels.\n"); + fprintf(stdout, " use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n"); fprintf(stdout, " Not recommended since this is both slower and uses more VRAM.\n"); +#endif // GGML_USE_CUBLAS #endif fprintf(stdout, " --mtest compute maximum memory usage\n"); fprintf(stdout, " --export export the computation graph to 'llama.ggml'\n"); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 36057bfca..7a2811584 100755 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -18,9 +18,7 @@ #include "llama.h" #include "common.h" #include "build-info.h" -#ifdef GGML_USE_CUBLAS #include "ggml-cuda.h" -#endif // utils static uint64_t get_time_ns() { @@ -504,7 +502,7 @@ struct test { static std::string get_backend() { if (cuda) { - return "CUDA"; + return GGML_CUDA_NAME; } if (opencl) { return "OpenCL"; diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 3bd1caf23..83d53c13c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6,15 +6,116 @@ #include #include +#if defined(GGML_USE_HIPBLAS) +#include +#include +#include +#ifdef __HIP_PLATFORM_AMD__ +// for rocblas_initialize() +#include "rocblas/rocblas.h" +#endif +#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F +#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N HIPBLAS_OP_N +#define CUBLAS_OP_T HIPBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define CUDA_R_16F HIPBLAS_R_16F +#define CUDA_R_32F HIPBLAS_R_32F +#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +#define cublasCreate hipblasCreate +#define cublasGemmEx hipblasGemmEx +#define cublasHandle_t hipblasHandle_t +#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetStream hipblasSetStream +#define cublasSgemm hipblasSgemm +#define cublasStatus_t hipblasStatus_t +#define cudaDeviceProp hipDeviceProp_t +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaError_t hipError_t +#define cudaEventCreateWithFlags hipEventCreateWithFlags +#define cudaEventDisableTiming hipEventDisableTiming +#define cudaEventRecord hipEventRecord +#define cudaEvent_t hipEvent_t +#define cudaEventDestroy hipEventDestroy +#define cudaFree hipFree +#define cudaFreeHost hipHostFree +#define cudaGetDevice hipGetDevice +#define cudaGetDeviceCount hipGetDeviceCount +#define cudaGetDeviceProperties hipGetDeviceProperties +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaMalloc hipMalloc +#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) +#define cudaMemcpy hipMemcpy +#define cudaMemcpy2DAsync hipMemcpy2DAsync +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyKind hipMemcpyKind +#define cudaMemset hipMemset +#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize +#define cudaSetDevice hipSetDevice +#define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamNonBlocking hipStreamNonBlocking +#define cudaStreamSynchronize hipStreamSynchronize +#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0) +#define cudaStream_t hipStream_t +#define cudaSuccess hipSuccess +#else #include #include #include +#endif #include "ggml-cuda.h" #include "ggml.h" #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products +#ifndef CC_TURING #define CC_TURING 700 +#endif + +#if defined(GGML_USE_HIPBLAS) +#define __CUDA_ARCH__ 1300 + +typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); +static __device__ __forceinline__ int __vsubss4(const int a, const int b) { + const int8x4_t va = reinterpret_cast(a); + const int8x4_t vb = reinterpret_cast(b); + const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); + return reinterpret_cast(c); +} + +static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { +#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) + c = __builtin_amdgcn_sdot4(a, b, c, false); +#elif defined(__gfx1100__) + c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); +#elif defined(__gfx1010__) || defined(__gfx900__) + int tmp1; + int tmp2; + asm("\n \ + v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \ + v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \ + v_add3_u32 %0, %1, %2, %0 \n \ + v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \ + v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \ + v_add3_u32 %0, %1, %2, %0 \n \ + " + : "+v"(c), "=&v"(tmp1), "=&v"(tmp2) + : "v"(a), "v"(b) + ); +#else + const int8x4_t va = reinterpret_cast(a); + const int8x4_t vb = reinterpret_cast(b); + c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3]; +#endif + return c; +} +#endif #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -424,8 +525,8 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_1 * x = (const block_q4_1 *) vx; - const dfloat d = x[ib].dm.x; - const dfloat m = x[ib].dm.y; + const dfloat d = __low2half(x[ib].dm); + const dfloat m = __high2half(x[ib].dm); const int vui = x[ib].qs[iqs]; @@ -467,8 +568,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q5_1 * x = (const block_q5_1 *) vx; - const dfloat d = x[ib].dm.x; - const dfloat m = x[ib].dm.y; + const dfloat d = __low2half(x[ib].dm); + const dfloat m = __high2half(x[ib].dm); uint32_t qh; memcpy(&qh, x[ib].qh, sizeof(qh)); @@ -520,8 +621,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float const uint8_t q = x[i].qs[32*n + l]; float * y = yy + i*QK_K + 128*n; - float dall = x[i].dm.x; - float dmin = x[i].dm.y; + float dall = __low2half(x[i].dm); + float dmin = __high2half(x[i].dm); y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); @@ -531,8 +632,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float const int il = tid%16; // 0...15 const uint8_t q = x[i].qs[il] >> (2*is); float * y = yy + i*QK_K + 16*is + il; - float dall = x[i].dm.x; - float dmin = x[i].dm.y; + float dall = __low2half(x[i].dm); + float dmin = __high2half(x[i].dm); y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); #endif @@ -618,8 +719,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float float * y = yy + i*QK_K + 64*il + n*ir; - const float dall = x[i].dm.x; - const float dmin = x[i].dm.y; + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); const uint8_t * q = x[i].qs + 32*il + n*ir; @@ -657,8 +758,8 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float float * y = yy + i*QK_K + 64*il + 2*ir; - const float dall = x[i].dm.x; - const float dmin = x[i].dm.y; + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); const uint8_t * ql = x[i].qs + 32*il + 2*ir; const uint8_t * qh = x[i].qh + 2*ir; @@ -770,8 +871,8 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * y = yy + i * QK_K + y_offset; const uint8_t * q = x[i].qs + q_offset; - const float dall = x[i].dm.x; - const float dmin = x[i].dm.y; + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); aux[0] = a[0] & 0x0f0f0f0f; @@ -991,8 +1092,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * y1 = yy + i*QK_K + y_offset; const float * y2 = y1 + 128; - const float dall = x[i].dm.x; - const float dmin = x[i].dm.y; + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); const uint16_t * a = (const uint16_t *)x[i].scales; aux[0] = a[im+0] & kmask1; @@ -1124,8 +1225,8 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * y1 = yy + i*QK_K + y_offset; const float * y2 = y1 + 128; - const float dall = x[i].dm.x; - const float dmin = x[i].dm.y; + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); const uint16_t * a = (const uint16_t *)x[i].scales; aux[0] = a[im+0] & kmask1; @@ -1348,8 +1449,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest return; } - y[ib].ds.x = d; - y[ib].ds.y = sum; + reinterpret_cast(y[ib].ds.x) = d; + reinterpret_cast(y[ib].ds.y) = sum; } template @@ -2346,7 +2447,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1( u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); } - return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, bq8_1->ds.x); + return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); } template static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { @@ -2432,7 +2533,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1( #pragma unroll for (int i = 0; i < QR2_K; ++ i) { u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); - d8[i] = bq8_1[bq8_offset + i].ds.x; + d8[i] = __low2half(bq8_1[bq8_offset + i].ds); } return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8); @@ -2551,7 +2652,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1( #pragma unroll for (int i = 0; i < QR3_K; ++i) { u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); - d8[i] = bq8_1[bq8_offset + i].ds.x; + d8[i] = __low2half(bq8_1[bq8_offset + i].ds); } return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); @@ -2720,7 +2821,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( for (int i = 0; i < QR4_K; ++i) { const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - d8[i] = bq8i->ds.x; + d8[i] = __low2half(bq8i->ds); const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); u[2*i+0] = q8[0]; @@ -2747,8 +2848,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( const float dall = bq4_K->d[0]; const float dmin = bq4_K->d[1]; - const float d8_1 = bq8_1[0].ds.x; - const float d8_2 = bq8_1[1].ds.x; + const float d8_1 = __low2float(bq8_1[0].ds); + const float d8_2 = __low2float(bq8_1[1].ds); const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); @@ -2901,7 +3002,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( #pragma unroll for (int i = 0; i < QR5_K; ++i) { const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - d8[i] = bq8i->ds.x; + d8[i] = __low2float(bq8i->ds); const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); u[2*i+0] = q8[0]; @@ -2919,8 +3020,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( const float d = bq5_K->d; - const float d8_1 = bq8_1[0].ds.x; - const float d8_2 = bq8_1[1].ds.x; + const float d8_1 = __low2half(bq8_1[0].ds); + const float d8_2 = __low2half(bq8_1[1].ds); const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); @@ -3075,7 +3176,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1( #pragma unroll for (int i = 0; i < QR6_K; ++i) { u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); - d8[i] = bq8_1[bq8_offset + 2*i].ds.x; + d8[i] = __low2half(bq8_1[bq8_offset + 2*i].ds); } return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); @@ -3243,7 +3344,7 @@ static __device__ __forceinline__ void mul_mat_q( *dsi_dst = *dsi_src; } else { float * dfi_dst = (float *) dsi_dst; - *dfi_dst = (*dsi_src).x; + *dfi_dst = __low2half(*dsi_src); } } @@ -4944,10 +5045,18 @@ void ggml_init_cublas() { static bool initialized = false; if (!initialized) { + +#ifdef __HIP_PLATFORM_AMD__ + // Workaround for a rocBLAS bug when using multiple graphics cards: + // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346 + rocblas_initialize(); + CUDA_CHECK(cudaDeviceSynchronize()); +#endif + CUDA_CHECK(cudaGetDeviceCount(&g_device_count)); GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0; - fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count); + fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count); for (int id = 0; id < g_device_count; ++id) { cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); diff --git a/ggml-cuda.h b/ggml-cuda.h index f66bb1678..a72e82069 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -2,6 +2,14 @@ #include "ggml.h" +#ifdef GGML_USE_HIPBLAS +#define GGML_CUDA_NAME "ROCm" +#define GGML_CUBLAS_NAME "hipBLAS" +#else +#define GGML_CUDA_NAME "CUDA" +#define GGML_CUBLAS_NAME "cuBLAS" +#endif + #ifdef __cplusplus extern "C" { #endif diff --git a/llama.cpp b/llama.cpp index 52ba31d79..d12b6d1cb 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1836,7 +1836,7 @@ static void llm_load_tensors( (void) main_gpu; (void) mul_mat_q; #if defined(GGML_USE_CUBLAS) - LLAMA_LOG_INFO("%s: using CUDA for GPU acceleration\n", __func__); + LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__); ggml_cuda_set_main_device(main_gpu); ggml_cuda_set_mul_mat_q(mul_mat_q); #define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU