diff --git a/ggml-cuda.cu b/ggml-cuda.cu index f32e83ab6..abad9cc39 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -68,8 +68,9 @@ #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) #endif #define cudaMemcpy hipMemcpy -#define cudaMemcpy2DAsync hipMemcpy2DAsync #define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyPeerAsync hipMemcpyPeerAsync +#define cudaMemcpy2DAsync hipMemcpy2DAsync #define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice #define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost #define cudaMemcpyHostToDevice hipMemcpyHostToDevice @@ -163,7 +164,7 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) { const int8x4_t vb = reinterpret_cast(b); #if __has_builtin(__builtin_elementwise_sub_sat) const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); - return reinterpret_cast(c); + return reinterpret_cast(c); #else int8x4_t c; int16_t tmp; @@ -174,7 +175,7 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) { if(tmp < std::numeric_limits::min()) tmp = std::numeric_limits::min(); c[i] = tmp; } - return reinterpret_cast(c); + return reinterpret_cast(c); #endif // __has_builtin(__builtin_elementwise_sub_sat) } @@ -212,6 +213,28 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); +[[noreturn]] +static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) { + int id = -1; // in case cudaGetDevice fails + cudaGetDevice(&id); + + fprintf(stderr, "CUDA error: %s\n", msg); + fprintf(stderr, " current device: %d, in function %s at %s:%d\n", id, func, file, line); + fprintf(stderr, " %s\n", stmt); + // abort with GGML_ASSERT to get a stack trace + GGML_ASSERT(!"CUDA error"); +} + +#define CUDA_CHECK_GEN(err, success, error_fn) \ + do { \ + auto err_ = (err); \ + if (err_ != (success)) { \ + ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_)); \ + } \ + } while (0) + +#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString) + #if CUDART_VERSION >= 12000 static const char * cublas_get_error_str(const cublasStatus_t err) { return cublasGetStatusString(err); @@ -233,15 +256,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); } #endif // CUDART_VERSION >= 12000 -[[noreturn]] -static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) { - fprintf(stderr, "CUDA error: %s: %s\n", stmt, msg); - fprintf(stderr, " in function %s at %s:%d\n", func, file, line); - GGML_ASSERT(!"CUDA error"); -} - -#define CUDA_CHECK(err) do { auto err_ = (err); if (err_ != cudaSuccess) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cudaGetErrorString(err_)); } while (0) -#define CUBLAS_CHECK(err) do { auto err_ = (err); if (err_ != CUBLAS_STATUS_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cublas_get_error_str(err_)); } while (0) +#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str) #if !defined(GGML_USE_HIPBLAS) static const char * cu_get_error_str(CUresult err) { @@ -249,7 +264,7 @@ static const char * cu_get_error_str(CUresult err) { cuGetErrorString(err, &err_str); return err_str; } -#define CU_CHECK(err) do { auto err_ = (err); if (err_ != CUDA_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cu_get_error_str(err_)); } while (0) +#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str) #endif #if CUDART_VERSION >= 11100 @@ -306,10 +321,10 @@ typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * s typedef void (*ggml_cuda_op_mul_mat_t)( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, const cudaStream_t & stream); + const int64_t src1_padded_row_size, cudaStream_t stream); typedef void (*ggml_cuda_op_flatten_t)( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream); + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream); // QK = number of values after dequantization // QR = QK / number of values before dequantization @@ -515,15 +530,15 @@ struct ggml_tensor_extra_gpu { // this is faster on Windows // probably because the Windows CUDA libraries forget to make this check before invoking the drivers -inline cudaError_t ggml_cuda_set_device(const int device) { +static void ggml_cuda_set_device(const int device) { int current_device; CUDA_CHECK(cudaGetDevice(¤t_device)); if (device == current_device) { - return cudaSuccess; + return; } - return cudaSetDevice(device); + CUDA_CHECK(cudaSetDevice(device)); } static int g_device_count = -1; @@ -538,7 +553,6 @@ struct cuda_device_capabilities { static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} }; - static void * g_scratch_buffer = nullptr; static size_t g_scratch_size = 0; // disabled by default static size_t g_scratch_offset = 0; @@ -580,6 +594,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; + GGML_UNUSED(a); } static __device__ __forceinline__ float op_add(const float a, const float b) { @@ -701,7 +716,7 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) { dst[i] = x[i] / (1.0f + expf(-x[i])); } -static __global__ void gelu_quick_f32(const float *x, float *dst, int k) { +static __global__ void gelu_quick_f32(const float * x, float * dst, int k) { const float GELU_QUICK_COEF = -1.702f; const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -710,7 +725,7 @@ static __global__ void gelu_quick_f32(const float *x, float *dst, int k) { dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i]))); } -static __global__ void tanh_f32(const float *x, float *dst, int k) { +static __global__ void tanh_f32(const float * x, float * dst, int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; @@ -727,7 +742,7 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) { dst[i] = fmaxf(x[i], 0); } -static __global__ void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope) { +static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; @@ -780,7 +795,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c } } -static __global__ void concat_f32(const float *x,const float *y, float *dst, const int ne0, const int ne02) { +static __global__ void concat_f32(const float * x,const float * y, float * dst, const int ne0, const int ne02) { int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { return; @@ -805,7 +820,7 @@ static __global__ void concat_f32(const float *x,const float *y, float *dst, c } } -static __global__ void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor) { +static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int nb02, const int scale_factor) { int ne0 = ne00 * scale_factor; int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { @@ -825,7 +840,7 @@ static __global__ void upscale_f32(const float *x, float *dst, const int ne00, dst[offset_dst] = x[offset_src]; } -static __global__ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02) { +static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02) { int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { return; @@ -4727,7 +4742,6 @@ static __global__ void mul_mat_p021_f16_f32( const int row_y = col_x; - // y is not transposed but permuted const int iy = channel*nrows_y + row_y; @@ -5402,7 +5416,7 @@ struct bin_bcast_cuda { cne[3] = 1; }; - auto collapse_nb = [](size_t cnb[], int64_t cne[]) { + auto collapse_nb = [](size_t cnb[], const int64_t cne[]) { cnb[1] *= cne[1]; cnb[2] *= cne[2]; cnb[3] *= cne[3]; @@ -6566,18 +6580,16 @@ struct scoped_spin_lock { static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; // #define DEBUG_CUDA_MALLOC -struct cuda_buffer { +struct ggml_cuda_buffer { void * ptr = nullptr; size_t size = 0; }; -static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS]; +static ggml_cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS]; static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0}; -static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) { +static void * ggml_cuda_pool_malloc_leg(int device, size_t size, size_t * actual_size) { scoped_spin_lock lock(g_cuda_pool_lock); - int id; - CUDA_CHECK(cudaGetDevice(&id)); #ifdef DEBUG_CUDA_MALLOC int nnz = 0; size_t max_size = 0; @@ -6585,7 +6597,7 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) { size_t best_diff = 1ull << 36; int ibest = -1; for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { - cuda_buffer& b = g_cuda_buffer_pool[id][i]; + ggml_cuda_buffer& b = g_cuda_buffer_pool[device][i]; if (b.ptr != nullptr) { #ifdef DEBUG_CUDA_MALLOC ++nnz; @@ -6608,7 +6620,7 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) { } } if (ibest >= 0) { - cuda_buffer& b = g_cuda_buffer_pool[id][ibest]; + ggml_cuda_buffer& b = g_cuda_buffer_pool[device][ibest]; void * ptr = b.ptr; *actual_size = b.size; b.ptr = nullptr; @@ -6618,9 +6630,10 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) { void * ptr; size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); + ggml_cuda_set_device(device); CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size)); *actual_size = look_ahead_size; - g_cuda_pool_size[id] += look_ahead_size; + g_cuda_pool_size[device] += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz, (uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024)); @@ -6628,13 +6641,11 @@ static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) { return ptr; } -static void ggml_cuda_pool_free_leg(void * ptr, size_t size) { +static void ggml_cuda_pool_free_leg(int device, void * ptr, size_t size) { scoped_spin_lock lock(g_cuda_pool_lock); - int id; - CUDA_CHECK(cudaGetDevice(&id)); for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { - cuda_buffer& b = g_cuda_buffer_pool[id][i]; + ggml_cuda_buffer& b = g_cuda_buffer_pool[device][i]; if (b.ptr == nullptr) { b.ptr = ptr; b.size = size; @@ -6642,73 +6653,73 @@ static void ggml_cuda_pool_free_leg(void * ptr, size_t size) { } } fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n"); + ggml_cuda_set_device(device); CUDA_CHECK(cudaFree(ptr)); - g_cuda_pool_size[id] -= size; + g_cuda_pool_size[device] -= size; } #if !defined(GGML_USE_HIPBLAS) // pool with virtual memory -static std::vector g_cuda_pool_handles[GGML_CUDA_MAX_DEVICES]; static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES] = {0}; static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0}; static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 36; // 64 GB -static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) { +static void * ggml_cuda_pool_malloc_vmm(int device, size_t size, size_t * actual_size) { scoped_spin_lock lock(g_cuda_pool_lock); - int id; - CUDA_CHECK(cudaGetDevice(&id)); // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types const size_t alignment = 128; size = alignment * ((size + alignment - 1) / alignment); - size_t avail = g_cuda_pool_size[id] - g_cuda_pool_used[id]; + size_t avail = g_cuda_pool_size[device] - g_cuda_pool_used[device]; if (size > avail) { // round up to the next multiple of the granularity size_t reserve_size = size - avail; - const size_t granularity = g_device_caps[id].vmm_granularity; + const size_t granularity = g_device_caps[device].vmm_granularity; reserve_size = granularity * ((reserve_size + granularity - 1) / granularity); - GGML_ASSERT(g_cuda_pool_size[id] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE); + GGML_ASSERT(g_cuda_pool_size[device] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE); // allocate more physical memory CUmemAllocationProp prop = {}; prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = id; + prop.location.id = device; CUmemGenericAllocationHandle handle; CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0)); // reserve virtual address space (if not already reserved) - if (g_cuda_pool_addr[id] == 0) { - CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[id], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0)); + if (g_cuda_pool_addr[device] == 0) { + CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[device], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0)); } // map at the end of the pool - CU_CHECK(cuMemMap(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, 0, handle, 0)); + CU_CHECK(cuMemMap(g_cuda_pool_addr[device] + g_cuda_pool_size[device], reserve_size, 0, handle, 0)); + + // the memory allocation handle is no longer needed after mapping + CU_CHECK(cuMemRelease(handle)); // set access CUmemAccessDesc access = {}; access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - access.location.id = id; + access.location.id = device; access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, &access, 1)); + CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[device] + g_cuda_pool_size[device], reserve_size, &access, 1)); // add to the pool - g_cuda_pool_handles[id].push_back(handle); - g_cuda_pool_size[id] += reserve_size; + g_cuda_pool_size[device] += reserve_size; //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n", // id, (unsigned long long) (g_cuda_pool_size[id]/1024/1024), // (unsigned long long) (reserve_size/1024/1024)); } - GGML_ASSERT(g_cuda_pool_addr[id] != 0); + GGML_ASSERT(g_cuda_pool_addr[device] != 0); - void * ptr = (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]); + void * ptr = (void *) (g_cuda_pool_addr[device] + g_cuda_pool_used[device]); *actual_size = size; - g_cuda_pool_used[id] += size; + g_cuda_pool_used[device] += size; #ifdef DEBUG_CUDA_MALLOC printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr); @@ -6717,38 +6728,32 @@ static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) { return ptr; } -static void ggml_cuda_pool_free_vmm(void * ptr, size_t size) { +static void ggml_cuda_pool_free_vmm(int device, void * ptr, size_t size) { scoped_spin_lock lock(g_cuda_pool_lock); - int id; - CUDA_CHECK(cudaGetDevice(&id)); #ifdef DEBUG_CUDA_MALLOC printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr); #endif - g_cuda_pool_used[id] -= size; + g_cuda_pool_used[device] -= size; // all deallocations must be in reverse order of the allocations - GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id])); + GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[device] + g_cuda_pool_used[device])); } -static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); - if (g_device_caps[id].vmm) { - return ggml_cuda_pool_malloc_vmm(size, actual_size); +static void * ggml_cuda_pool_malloc(int device, size_t size, size_t * actual_size) { + if (g_device_caps[device].vmm) { + return ggml_cuda_pool_malloc_vmm(device, size, actual_size); } else { - return ggml_cuda_pool_malloc_leg(size, actual_size); + return ggml_cuda_pool_malloc_leg(device, size, actual_size); } } -static void ggml_cuda_pool_free(void * ptr, size_t size) { - int id; - CUDA_CHECK(cudaGetDevice(&id)); - if (g_device_caps[id].vmm) { - ggml_cuda_pool_free_vmm(ptr, size); +static void ggml_cuda_pool_free(int device, void * ptr, size_t size) { + if (g_device_caps[device].vmm) { + ggml_cuda_pool_free_vmm(device, ptr, size); } else { - ggml_cuda_pool_free_leg(ptr, size); + ggml_cuda_pool_free_leg(device, ptr, size); } } #else @@ -6758,13 +6763,15 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) { template struct cuda_pool_alloc { + int device = -1; T * ptr = nullptr; size_t actual_size = 0; // size is in number of elements T * alloc(size_t size) { GGML_ASSERT(ptr == nullptr); - ptr = (T *) ggml_cuda_pool_malloc(size * sizeof(T), &this->actual_size); + CUDA_CHECK(cudaGetDevice(&device)); + ptr = (T *) ggml_cuda_pool_malloc(device, size * sizeof(T), &this->actual_size); return ptr; } @@ -6774,7 +6781,7 @@ struct cuda_pool_alloc { ~cuda_pool_alloc() { if (ptr != nullptr) { - ggml_cuda_pool_free(ptr, actual_size); + ggml_cuda_pool_free(device, ptr, actual_size); } } @@ -6839,7 +6846,7 @@ void ggml_init_cublas() { alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; alloc_prop.location.id = id; - CU_CHECK(cuMemGetAllocationGranularity(&g_device_caps[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + CU_CHECK(cuMemGetAllocationGranularity(&g_device_caps[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); } #endif // !defined(GGML_USE_HIPBLAS) g_device_caps[id].vmm = !!device_vmm; @@ -6861,7 +6868,7 @@ void ggml_init_cublas() { } for (int id = 0; id < g_device_count; ++id) { - CUDA_CHECK(ggml_cuda_set_device(id)); + ggml_cuda_set_device(id); // create cuda streams for (int is = 0; is < MAX_STREAMS; ++is) { @@ -6976,7 +6983,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( static void ggml_cuda_op_get_rows( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) { + const float * src0_d, const float * src1_d, float * dst_d, cudaStream_t stream) { GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(dst->type == GGML_TYPE_F32); @@ -7018,9 +7025,9 @@ static void ggml_cuda_op_get_rows( } template -inline void ggml_cuda_op_bin_bcast( +static void ggml_cuda_op_bin_bcast( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -7039,7 +7046,7 @@ inline void ggml_cuda_op_bin_bcast( static void ggml_cuda_op_repeat( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & main_stream) { + const float * src0_d, const float * src1_d, float * dst_d, cudaStream_t main_stream) { ggml_cuda_op_bin_bcast>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream); @@ -7047,16 +7054,16 @@ static void ggml_cuda_op_repeat( (void) src1_d; } -inline void ggml_cuda_op_add( +static void ggml_cuda_op_add( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { ggml_cuda_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } -inline void ggml_cuda_op_acc( +static void ggml_cuda_op_acc( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -7073,23 +7080,23 @@ inline void ggml_cuda_op_acc( (void) dst; } -inline void ggml_cuda_op_mul( +static void ggml_cuda_op_mul( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { ggml_cuda_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } -inline void ggml_cuda_op_div( +static void ggml_cuda_op_div( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { ggml_cuda_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } -inline void ggml_cuda_op_gelu( +static void ggml_cuda_op_gelu( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7101,9 +7108,9 @@ inline void ggml_cuda_op_gelu( (void) src1_dd; } -inline void ggml_cuda_op_silu( +static void ggml_cuda_op_silu( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7115,9 +7122,9 @@ inline void ggml_cuda_op_silu( (void) src1_dd; } -inline void ggml_cuda_op_gelu_quick( +static void ggml_cuda_op_gelu_quick( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7129,9 +7136,9 @@ inline void ggml_cuda_op_gelu_quick( (void) src1_dd; } -inline void ggml_cuda_op_tanh( +static void ggml_cuda_op_tanh( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7143,9 +7150,9 @@ inline void ggml_cuda_op_tanh( (void) src1_dd; } -inline void ggml_cuda_op_relu( +static void ggml_cuda_op_relu( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7157,9 +7164,9 @@ inline void ggml_cuda_op_relu( (void) src1_dd; } -inline void ggml_cuda_op_leaky_relu( +static void ggml_cuda_op_leaky_relu( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7174,9 +7181,9 @@ inline void ggml_cuda_op_leaky_relu( (void) src1_dd; } -inline void ggml_cuda_op_sqr( +static void ggml_cuda_op_sqr( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7188,9 +7195,9 @@ inline void ggml_cuda_op_sqr( (void) src1_dd; } -inline void ggml_cuda_op_norm( +static void ggml_cuda_op_norm( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7208,10 +7215,9 @@ inline void ggml_cuda_op_norm( (void) src1_dd; } - -inline void ggml_cuda_op_group_norm( +static void ggml_cuda_op_group_norm( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7225,9 +7231,9 @@ inline void ggml_cuda_op_group_norm( (void) src1_dd; } -inline void ggml_cuda_op_concat( +static void ggml_cuda_op_concat( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -7241,9 +7247,9 @@ inline void ggml_cuda_op_concat( (void) dst; } -inline void ggml_cuda_op_upscale( +static void ggml_cuda_op_upscale( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); @@ -7258,9 +7264,9 @@ inline void ggml_cuda_op_upscale( (void) src1_dd; } -inline void ggml_cuda_op_pad( +static void ggml_cuda_op_pad( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); @@ -7275,9 +7281,9 @@ inline void ggml_cuda_op_pad( (void) src1_dd; } -inline void ggml_cuda_op_rms_norm( +static void ggml_cuda_op_rms_norm( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7295,10 +7301,10 @@ inline void ggml_cuda_op_rms_norm( (void) src1_dd; } -inline void ggml_cuda_op_mul_mat_q( +static void ggml_cuda_op_mul_mat_q( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, const cudaStream_t & stream) { + const int64_t src1_padded_row_size, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; @@ -7360,7 +7366,7 @@ inline void ggml_cuda_op_mul_mat_q( static int64_t get_row_rounding(ggml_type type) { int64_t min_compute_capability = INT_MAX; int64_t max_compute_capability = INT_MIN; - for (int64_t id = 0; id < g_device_count; ++id) { + for (int id = 0; id < g_device_count; ++id) { if (g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) { if (min_compute_capability > g_device_caps[id].cc) { min_compute_capability = g_device_caps[id].cc; @@ -7418,10 +7424,10 @@ static int64_t get_row_rounding(ggml_type type) { #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) } -inline void ggml_cuda_op_mul_mat_vec_q( +static void ggml_cuda_op_mul_mat_vec_q( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, const cudaStream_t & stream) { + const int64_t src1_padded_row_size, cudaStream_t stream) { GGML_ASSERT(ggml_nrows(src1) == 1); @@ -7471,10 +7477,10 @@ inline void ggml_cuda_op_mul_mat_vec_q( (void) src1_padded_row_size; } -inline void ggml_cuda_op_dequantize_mul_mat_vec( +static void ggml_cuda_op_dequantize_mul_mat_vec( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, const cudaStream_t & stream) { + const int64_t src1_padded_row_size, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; const int64_t row_diff = row_high - row_low; @@ -7545,10 +7551,10 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec( (void) src1_padded_row_size; } -inline void ggml_cuda_op_mul_mat_cublas( +static void ggml_cuda_op_mul_mat_cublas( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, const cudaStream_t & stream) { + const int64_t src1_padded_row_size, cudaStream_t stream) { GGML_ASSERT(src0_dd_i != nullptr); GGML_ASSERT(src1_ddf_i != nullptr); @@ -7637,9 +7643,9 @@ inline void ggml_cuda_op_mul_mat_cublas( (void) src1_padded_row_size; } -inline void ggml_cuda_op_rope( +static void ggml_cuda_op_rope( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); @@ -7717,9 +7723,9 @@ inline void ggml_cuda_op_rope( (void) src1_dd; } -inline void ggml_cuda_op_alibi( +static void ggml_cuda_op_alibi( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7748,9 +7754,9 @@ inline void ggml_cuda_op_alibi( (void) src1_dd; } -inline void ggml_cuda_op_im2col( +static void ggml_cuda_op_im2col( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -7783,10 +7789,9 @@ inline void ggml_cuda_op_im2col( (void) src0_dd; } - -inline void ggml_cuda_op_sum_rows( +static void ggml_cuda_op_sum_rows( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7801,9 +7806,9 @@ inline void ggml_cuda_op_sum_rows( (void) src1_dd; } -inline void ggml_cuda_op_argsort( +static void ggml_cuda_op_argsort( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_I32); @@ -7820,9 +7825,9 @@ inline void ggml_cuda_op_argsort( (void) src1_dd; } -inline void ggml_cuda_op_diag_mask_inf( +static void ggml_cuda_op_diag_mask_inf( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7840,9 +7845,9 @@ inline void ggml_cuda_op_diag_mask_inf( (void) src1_dd; } -inline void ggml_cuda_op_soft_max( +static void ggml_cuda_op_soft_max( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7861,9 +7866,9 @@ inline void ggml_cuda_op_soft_max( (void) dst; } -inline void ggml_cuda_op_scale( +static void ggml_cuda_op_scale( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7879,9 +7884,9 @@ inline void ggml_cuda_op_scale( (void) src1_dd; } -inline void ggml_cuda_op_clamp( +static void ggml_cuda_op_clamp( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -7974,12 +7979,12 @@ static void ggml_cuda_set_peer_access(const int n_tokens) { #ifdef NDEBUG for (int id = 0; id < g_device_count; ++id) { - CUDA_CHECK(ggml_cuda_set_device(id)); + ggml_cuda_set_device(id); CUDA_CHECK(cudaDeviceSynchronize()); } for (int id = 0; id < g_device_count; ++id) { - CUDA_CHECK(ggml_cuda_set_device(id)); + ggml_cuda_set_device(id); for (int id_other = 0; id_other < g_device_count; ++id_other) { if (id == id_other) { @@ -8013,7 +8018,6 @@ static void ggml_cuda_op_mul_mat( const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; const int64_t ne03 = src0->ne[3]; - const int64_t nrows0 = ggml_nrows(src0); const int64_t ne10 = src1->ne[0]; const int64_t ne11 = src1->ne[1]; @@ -8056,27 +8060,29 @@ static void ggml_cuda_op_mul_mat( GGML_ASSERT(!(split && ne03 > 1)); GGML_ASSERT(!(split && ne02 < ne12)); - // dd = data device - char * src0_dd[GGML_CUDA_MAX_DEVICES] = {nullptr}; - float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float - char * src1_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // q8_1 - float * dst_dd[GGML_CUDA_MAX_DEVICES] = {nullptr}; + struct dev_data { + cuda_pool_alloc src0_dd_alloc; + cuda_pool_alloc src1_ddf_alloc; + cuda_pool_alloc src1_ddq_alloc; + cuda_pool_alloc dst_dd_alloc; - // as = actual size - size_t src0_as[GGML_CUDA_MAX_DEVICES] = {0}; - size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0}; - size_t src1_asq[GGML_CUDA_MAX_DEVICES] = {0}; - size_t dst_as[GGML_CUDA_MAX_DEVICES] = {0}; + char * src0_dd = nullptr; + float * src1_ddf = nullptr; // float + char * src1_ddq = nullptr; // q8_1 + float * dst_dd = nullptr; - int64_t row_low[GGML_CUDA_MAX_DEVICES]; - int64_t row_high[GGML_CUDA_MAX_DEVICES]; + int64_t row_low; + int64_t row_high; + }; + + dev_data dev[GGML_CUDA_MAX_DEVICES]; int used_devices = 0; - for (int64_t id = 0; id < g_device_count; ++id) { + for (int id = 0; id < g_device_count; ++id) { // by default, use all rows - row_low[id] = 0; - row_high[id] = ne01; + dev[id].row_low = 0; + dev[id].row_high = ne01; // for multi GPU, get the row boundaries from tensor split // and round to mul_mat_q tile sizes @@ -8084,23 +8090,23 @@ static void ggml_cuda_op_mul_mat( const int64_t rounding = get_row_rounding(src0->type); if (id != 0) { - row_low[id] = ne01*g_tensor_split[id]; - if (row_low[id] < ne01) { - row_low[id] -= row_low[id] % rounding; + dev[id].row_low = ne01*g_tensor_split[id]; + if (dev[id].row_low < ne01) { + dev[id].row_low -= dev[id].row_low % rounding; } } if (id != g_device_count - 1) { - row_high[id] = ne01*g_tensor_split[id + 1]; - if (row_high[id] < ne01) { - row_high[id] -= row_high[id] % rounding; + dev[id].row_high = ne01*g_tensor_split[id + 1]; + if (dev[id].row_high < ne01) { + dev[id].row_high -= dev[id].row_high % rounding; } } } } - for (int64_t id = 0; id < g_device_count; ++id) { - if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { + for (int id = 0; id < g_device_count; ++id) { + if ((!split && id != g_main_device) || dev[id].row_low == dev[id].row_high) { continue; } @@ -8110,42 +8116,41 @@ static void ggml_cuda_op_mul_mat( const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; ggml_cuda_set_device(id); - const cudaStream_t stream = g_cudaStreams[id][0]; + cudaStream_t stream = g_cudaStreams[id][0]; if (src0_on_device && src0_is_contiguous) { - src0_dd[id] = (char *) src0_extra->data_device[id]; + dev[id].src0_dd = (char *) src0_extra->data_device[id]; } else { - // const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0); - src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]); + dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ggml_nbytes(src0)); } if (src1_on_device && src1_is_contiguous) { - src1_ddf[id] = (float *) src1_extra->data_device[id]; + dev[id].src1_ddf = (float *) src1_extra->data_device[id]; } else { - src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]); + dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ggml_nelements(src1)); } if (convert_src1_to_q8_1) { - src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]); + dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs); if (src1_on_device && src1_is_contiguous) { - quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream); + quantize_row_q8_1_cuda(dev[id].src1_ddf, dev[id].src1_ddq, ne10, nrows1, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); } } if (dst_on_device) { - dst_dd[id] = (float *) dst_extra->data_device[id]; + dev[id].dst_dd = (float *) dst_extra->data_device[id]; } else { - const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst); - dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]); + const size_t size_dst_ddf = split ? (dev[id].row_high - dev[id].row_low)*ne1 : ggml_nelements(dst); + dev[id].dst_dd = dev[id].dst_dd_alloc.alloc(size_dst_ddf); } } // if multiple devices are used they need to wait for the main device // here an event is recorded that signals that the main device has finished calculating the input data if (split && used_devices > 1) { - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + ggml_cuda_set_device(g_main_device); CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0])); } @@ -8154,17 +8159,17 @@ static void ggml_cuda_op_mul_mat( const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0; const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride; - for (int64_t id = 0; id < g_device_count; ++id) { - if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { + for (int id = 0; id < g_device_count; ++id) { + if ((!split && id != g_main_device) || dev[id].row_low == dev[id].row_high) { continue; } const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device; const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; - const int64_t row_diff = row_high[id] - row_low[id]; + const int64_t row_diff = dev[id].row_high - dev[id].row_low; ggml_cuda_set_device(id); - const cudaStream_t stream = g_cudaStreams[id][is]; + cudaStream_t stream = g_cudaStreams[id][is]; // wait for main GPU data if necessary if (split && (id != g_main_device || is != 0)) { @@ -8178,34 +8183,34 @@ static void ggml_cuda_op_mul_mat( const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs; // for split tensors the data begins at i0 == i0_offset_low - char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; - float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10; - char * src1_ddq_i = src1_ddq[id] + src1_ddq_i_offset; - float * dst_dd_i = dst_dd[id] + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); + char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; + float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10; + char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset; + float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); // the main device memory buffer can be on VRAM scratch, with space for all partial results // in that case an offset on dst_ddf_i is needed if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) { - dst_dd_i += row_low[id]; // offset is 0 if no tensor split + dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split } // copy src0, src1 to device if necessary if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) { if (id != g_main_device) { if (convert_src1_to_q8_1) { - char * src1_ddq_i_source = src1_ddq[g_main_device] + src1_ddq_i_offset; - CUDA_CHECK(cudaMemcpyAsync(src1_ddq_i, src1_ddq_i_source, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, - cudaMemcpyDeviceToDevice, stream)); + char * src1_ddq_i_source = dev[g_main_device].src1_ddq + src1_ddq_i_offset; + CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddq_i, id, src1_ddq_i_source, g_main_device, + src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream)); } else { float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device]; src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10; - CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_ncols*ne10*sizeof(float), - cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddf_i, id, src1_ddf_i_source, g_main_device, + src1_ncols*ne10*sizeof(float), stream)); } } } else if (src1->backend == GGML_BACKEND_CPU || (src1_on_device && !src1_is_contiguous)) { CUDA_CHECK(ggml_cuda_cpy_tensor_2d( - src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); + src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); } else { GGML_ASSERT(false); } @@ -8216,12 +8221,12 @@ static void ggml_cuda_op_mul_mat( } if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, row_low[id], row_high[id], stream)); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream)); } // do the computation op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, - row_low[id], row_high[id], src1_ncols, src1_padded_col_size, stream); + dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); // copy dst to host or other device if necessary @@ -8245,9 +8250,25 @@ static void ggml_cuda_op_mul_mat( // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results. float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); - dhf_dst_i += src1_col_0*ne0 + row_low[id]; - CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_dd_i, row_diff*sizeof(float), - row_diff*sizeof(float), src1_ncols, kind, stream)); + dhf_dst_i += src1_col_0*ne0 + dev[id].row_low; +#if !defined(GGML_USE_HIPBLAS) + if (kind == cudaMemcpyDeviceToDevice) { + // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices + cudaMemcpy3DPeerParms p = {}; + p.dstDevice = g_main_device; + p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols); + p.srcDevice = id; + p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols); + p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1); + CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream)); + } else +#endif + { + CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), + dst_dd_i, row_diff*sizeof(float), + row_diff*sizeof(float), src1_ncols, + kind, stream)); + } } else { float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); @@ -8264,35 +8285,14 @@ static void ggml_cuda_op_mul_mat( } } - for (int64_t id = 0; id < g_device_count; ++id) { - if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { - continue; - } - CUDA_CHECK(ggml_cuda_set_device(id)); - - // free buffers again when done - if (dst_as[id] > 0) { - ggml_cuda_pool_free(dst_dd[id], dst_as[id]); - } - if (src1_asq[id] > 0) { - ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]); - } - if (src1_asf[id] > 0) { - ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]); - } - if (src0_as[id] > 0) { - ggml_cuda_pool_free(src0_dd[id], src0_as[id]); - } - } - // main device waits for all other devices to be finished if (split && g_device_count > 1) { int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS; - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); - for (int64_t id = 0; id < g_device_count; ++id) { - if (row_low[id] == row_high[id]) { + ggml_cuda_set_device(g_main_device); + for (int id = 0; id < g_device_count; ++id) { + if (dev[id].row_low == dev[id].row_high) { continue; } for (int64_t is = 0; is < is_max; ++is) { @@ -8302,7 +8302,7 @@ static void ggml_cuda_op_mul_mat( } if (dst->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + ggml_cuda_set_device(g_main_device); CUDA_CHECK(cudaDeviceSynchronize()); } } @@ -8412,7 +8412,7 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens const int64_t ne12 = src1->ne[2]; - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + ggml_cuda_set_device(g_main_device); cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; @@ -8444,7 +8444,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor const int64_t ne12 = src1->ne[2]; - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + ggml_cuda_set_device(g_main_device); cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; @@ -8515,7 +8515,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const const int64_t ne1 = ggml_nelements(src1); const int64_t ne = ggml_nelements(dst); - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + ggml_cuda_set_device(g_main_device); cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream)); @@ -8656,7 +8656,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; int64_t min_compute_capability = INT_MAX; - for (int64_t id = 0; id < g_device_count; ++id) { + for (int id = 0; id < g_device_count; ++id) { if (min_compute_capability > g_device_caps[id].cc && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) { min_compute_capability = g_device_caps[id].cc; } @@ -8799,7 +8799,7 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) { const int64_t ne1 = ggml_nelements(src1); const int64_t ne = ggml_nelements(dst); - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + ggml_cuda_set_device(g_main_device); cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream)); @@ -8917,7 +8917,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s std::vector ids_host(ggml_nbytes(ids)); - const cudaStream_t stream = g_cudaStreams[g_main_device][0]; + cudaStream_t stream = g_cudaStreams[g_main_device][0]; if (ids->backend == GGML_BACKEND_GPU) { const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device]; @@ -9073,7 +9073,7 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg const int64_t nb11 = src1->nb[1]; const int64_t nb12 = src1->nb[2]; - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + ggml_cuda_set_device(g_main_device); cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; const ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; @@ -9163,7 +9163,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu; memset(extra, 0, sizeof(*extra)); - for (int64_t id = 0; id < g_device_count; ++id) { + for (int id = 0; id < g_device_count; ++id) { if (backend == GGML_BACKEND_GPU && id != g_main_device) { continue; } @@ -9234,15 +9234,14 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) { ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; - for (int64_t id = 0; id < g_device_count; ++id) { + for (int id = 0; id < g_device_count; ++id) { + ggml_cuda_set_device(id); if (extra->data_device[id] != nullptr) { - CUDA_CHECK(ggml_cuda_set_device(id)); CUDA_CHECK(cudaFree(extra->data_device[id])); } for (int64_t is = 0; is < MAX_STREAMS; ++is) { if (extra->events[id][is] != nullptr) { - CUDA_CHECK(ggml_cuda_set_device(id)); CUDA_CHECK(cudaEventDestroy(extra->events[id][is])); } } @@ -9296,7 +9295,7 @@ static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scra force_inplace; const size_t size = ggml_nbytes(tensor); - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + ggml_cuda_set_device(g_main_device); if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) { ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; @@ -9373,7 +9372,7 @@ void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) { GGML_ASSERT(ggml_is_contiguous(tensor)); ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + ggml_cuda_set_device(g_main_device); CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice)); } diff --git a/ggml.c b/ggml.c index d24560480..ed56e60a8 100644 --- a/ggml.c +++ b/ggml.c @@ -4041,7 +4041,6 @@ static struct ggml_tensor * ggml_group_norm_impl( result->op = GGML_OP_GROUP_NORM; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; // TODO: maybe store epsilon here? return result; } @@ -5541,7 +5540,6 @@ static struct ggml_tensor * ggml_upscale_impl( result->op_params[0] = scale_factor; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5846,7 +5844,6 @@ struct ggml_tensor * ggml_get_rel_pos( result->op = GGML_OP_GET_REL_POS; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } diff --git a/llama.cpp b/llama.cpp index 0b99f1e03..4aa59c4c0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -9519,7 +9519,8 @@ struct llama_context * llama_new_context_with_model( ctx->alloc = ggml_allocr_new_from_buffer(ctx->buf_alloc); #if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST) if (model->n_gpu_layers > 0) { - ggml_cuda_set_scratch_size(alloc_size); + // the CPU buffer adds this padding in case the malloc buffer is not aligned, so we need to do the same for the GPU buffer, since we use the same offsets + ggml_cuda_set_scratch_size(alloc_size + 64); LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MiB\n", __func__, alloc_size / 1024.0 / 1024.0); // calculate total VRAM usage