Compare commits

...

68 Commits

Author SHA1 Message Date
Georgi Gerganov 3a468e6f9f
llama : fix type of KQ_mask and KQ_pos 2024-03-22 17:12:17 +02:00
Georgi Gerganov 9495d3982d
Merge branch 'master' into gg/flash-attn 2024-03-22 16:34:34 +02:00
Georgi Gerganov 58c7f6167c
ggml : fix F16 store (ARM NEON) 2024-03-04 20:44:57 +02:00
Georgi Gerganov e307882c34
Merge branch 'master' into gg/flash-attn 2024-03-04 20:42:48 +02:00
Georgi Gerganov 6aefd11204
llama : adapt new models to F16 KQ_mask 2024-03-03 13:50:54 +02:00
Georgi Gerganov 02a645e7b7
Merge branch 'master' into gg/flash-attn 2024-03-03 13:44:11 +02:00
Georgi Gerganov f249c997a8
llama : adapt to F16 KQ_pos 2024-02-19 13:31:02 +02:00
Georgi Gerganov 31109ca00a
Merge branch 'master' into gg/flash-attn 2024-02-19 12:58:18 +02:00
Georgi Gerganov 6875997fd6
Merge branch 'master' into gg/flash-attn 2024-02-12 21:16:58 +02:00
Georgi Gerganov 1846e92a90
cuda : minor 2024-02-04 11:01:01 +02:00
Georgi Gerganov ef68fac2a8
cuda : fix matrix names 2024-02-03 18:36:58 +02:00
Georgi Gerganov cfd9732b2e
cuda : simplify softmax 2024-02-03 18:31:55 +02:00
Georgi Gerganov e04ff39181
cuda : fix -INF block check 2024-02-03 16:57:46 +02:00
Georgi Gerganov 5b263dd83a
cuda : unroll Q*K^T loop 2024-02-03 16:12:20 +02:00
Georgi Gerganov 3b1c4e7673
cuda : speed-up reduce part of the kernel 2024-02-03 15:36:05 +02:00
Georgi Gerganov a7b471569b
cuda : switch to 1 warp for bs > 16 2024-02-03 15:17:49 +02:00
Georgi Gerganov b958151e3f
cuda : use half2 in softmax 2024-02-03 15:00:25 +02:00
Georgi Gerganov c51f27c0db
cuda : avoid __hisinf branches 2024-02-03 14:27:36 +02:00
Georgi Gerganov 92472ea22c
cuda : unroll some of the loops 2024-02-03 14:10:01 +02:00
Georgi Gerganov 1f8a592482
cuda : make loops use the same loop values
Thanks Johannes again for the tip
2024-02-03 14:01:32 +02:00
Georgi Gerganov 7c34655b36
cuda : use int instead of int64_t
Noticeably improves performance (thanks to Johannes)
2024-02-03 13:39:46 +02:00
Georgi Gerganov b150abe83e
cuda : avoid warp_reduce for smax 2024-02-03 13:17:47 +02:00
Georgi Gerganov b68a112204
cuda : fix __hisinf() result check 2024-02-02 15:12:28 +02:00
Georgi Gerganov 12eaa22628
tests : update dims 2024-02-02 11:55:38 +02:00
Georgi Gerganov db1f3c482e
cuda : avoid zeroing fragments 2024-02-01 22:08:37 +02:00
Georgi Gerganov c6769b9422
tests : minor fix 2024-02-01 21:24:26 +02:00
Georgi Gerganov cda5a60a41
metal : optimize softmax 2024-02-01 21:05:31 +02:00
Georgi Gerganov 56e45a239e
metal : optimize softmax for C > 32 2024-02-01 20:16:32 +02:00
Georgi Gerganov 41d136b602
Merge branch 'master' into gg/flash-attn 2024-02-01 19:51:41 +02:00
Georgi Gerganov 5a19a9f6d0
cuda : add flash_attn kernel (wip) 2024-02-01 19:50:23 +02:00
Georgi Gerganov 2e46013749
cuda : fix soft_max to use correct mask size 2024-02-01 16:47:20 +02:00
Georgi Gerganov 910b15bb40
ggml : fix ggml_soft_max mask requirement 2024-02-01 16:41:02 +02:00
Georgi Gerganov 8ad92dc1ec
ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext 2024-01-31 20:39:29 +02:00
Georgi Gerganov 2ddc9bbef1
Merge branch 'master' into gg/flash-attn 2024-01-31 18:49:43 +02:00
Georgi Gerganov 3d03bcb7af
Merge branch 'master' into gg/flash-attn 2024-01-30 21:49:13 +02:00
Georgi Gerganov 78df5527e4
tests : ifdef 2024-01-30 21:46:49 +02:00
Georgi Gerganov d073e4f933
metal : fix array initialization 2024-01-30 21:45:32 +02:00
Georgi Gerganov 5fcb9c1c5a
metal : faster inner loop for C == 32 2024-01-29 19:51:26 +02:00
Georgi Gerganov c6c1132e5e
tests : more 2024-01-29 18:22:28 +02:00
Georgi Gerganov abeaf0d90e
metal : disable buffer allocation logs 2024-01-29 18:12:24 +02:00
Georgi Gerganov 4794821a31
tests : add ATTN tests 2024-01-29 16:44:55 +02:00
Georgi Gerganov 1db22d7032
metal : support Q > 8 2024-01-28 23:16:20 +02:00
Georgi Gerganov 134c81c78d
metal : minor 2024-01-28 22:23:40 +02:00
Georgi Gerganov 0ad44baf33
Merge branch 'master' into gg/flash-attn 2024-01-28 21:53:51 +02:00
Georgi Gerganov 8612864108
ggml : fix f16 mad 2024-01-28 18:10:16 +02:00
Georgi Gerganov 3a428a1097
metal : improve precision 2024-01-28 17:47:22 +02:00
Georgi Gerganov ecc466a460
metal : add tests, fix scaling, support C > 32 2024-01-28 16:06:18 +02:00
Georgi Gerganov 77f6976a87
metal : move output into local memory + optimize
- the result from each simdgroup now stays in the registers
- significantly reduced SRAM usage
- more efficient skipping of -INF blocks
- avoid simdgroup barrier in hot loop
- add comments
2024-01-28 15:30:24 +02:00
Georgi Gerganov b3dd7d975f
Merge branch 'master' into gg/flash-attn 2024-01-28 10:54:11 +02:00
Georgi Gerganov 6fea843b24
metal : add parallel reduce version (disabled) 2024-01-25 18:09:30 +02:00
Georgi Gerganov f9ca5dcbe8
llama : avoid ggml_cast, use F32 query 2024-01-25 17:46:07 +02:00
Georgi Gerganov 40ea8cd1ac
metal : fix comment 2024-01-25 16:31:39 +02:00
Georgi Gerganov 432ad04ffa
metal : scale and mask in matrix form 2024-01-25 15:47:52 +02:00
Georgi Gerganov d917746ddb
metal : avoid redundant loads of the attention 2024-01-25 15:00:49 +02:00
Georgi Gerganov 1446a12b29
metal : efficient flash_attn_f16 implementation 2024-01-25 13:40:31 +02:00
Georgi Gerganov 17720fad66
metal : parallel reduce across heads 2024-01-21 23:01:46 +02:00
Georgi Gerganov 77d08f3272
metal : parallelize across KV size 2024-01-21 22:26:45 +02:00
Georgi Gerganov a4b6341c7b
wip : template for rows per warp 2024-01-21 19:06:30 +02:00
Georgi Gerganov f31955f5d1
wip : 4 rows per simd group 2024-01-21 18:01:28 +02:00
Georgi Gerganov 8cde449b8b
wip : 8 rows per simd group 2024-01-21 17:37:24 +02:00
Georgi Gerganov b97325800a
metal : specialize for head size 2024-01-21 12:01:55 +02:00
Georgi Gerganov 52ae085750
metal : reduce branches 2024-01-21 11:59:09 +02:00
Georgi Gerganov 528da7515e
metal : f16 precision 2024-01-21 11:13:24 +02:00
Georgi Gerganov 1173f49c3b
metal : initial implementation 2024-01-21 10:15:02 +02:00
Georgi Gerganov a9681febd6
ggml : online attention (CPU) 2024-01-20 16:45:41 +02:00
Georgi Gerganov c3cdfffa88
Merge branch 'master' into gg/flash-attn 2024-01-20 10:12:07 +02:00
Georgi Gerganov fa7ebcca99
ggml : fix GQA support in ggml_flash_attn_ext 2024-01-19 20:06:26 +02:00
Georgi Gerganov a1c004ef2e
ggml : add ggml_flash_attn_ext API 2024-01-18 18:55:48 +02:00
8 changed files with 1854 additions and 77 deletions
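Note: the last commit in the list (a1c004ef2e) adds the ggml_flash_attn_ext API that the CUDA and Metal kernels in this diff implement, but the header change itself is not part of this compare view. The builder signature below is therefore an assumption on our part; the tensor types and mask padding, however, are taken directly from the asserts in the backend code further down. A hedged usage sketch (ctx, head_dim, n_q, n_kv and n_head are hypothetical variables):

    // hypothetical sketch, signature assumed (not shown in this diff)
    struct ggml_tensor * q    = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_dim, n_q,  n_head);        // F32 queries
    struct ggml_tensor * k    = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, head_dim, n_kv, n_head);        // F16 keys
    struct ggml_tensor * v    = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, head_dim, n_kv, n_head);        // F16 values
    struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv, GGML_PAD(n_q, 16));       // padded F16 mask, per the CUDA assert
    struct ggml_tensor * out  = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf((float) head_dim)); // scale ends up in op_params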

View File

@ -100,7 +100,7 @@ int main(int argc, char ** argv) {
ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_max;
ctx_params.n_batch = 512;
ctx_params.n_batch = 2048;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

View File

@ -134,6 +134,7 @@
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <mma.h>
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
@ -1621,7 +1622,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
return a;
}
#ifdef GGML_CUDA_F16
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
@ -1634,7 +1634,19 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
#endif // GGML_CUDA_F16
static __device__ __forceinline__ half warp_reduce_sum(half x) {
#if __CUDA_ARCH__ >= CC_VOLTA
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = __hadd(__shfl_xor_sync(0xffffffff, x, mask, 32), x);
}
return x;
#else
(void) x;
NO_DEVICE_CODE;
#endif
}
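The half and half2 reductions added in this hunk use the same XOR-shuffle butterfly as the existing float versions: after log2(32) = 5 exchange steps every lane of the warp holds the full reduction. A standalone illustration of the pattern (ours, not part of the diff), using the float flavor:

    // launch as <<<1, 32>>>: one value per lane of a single warp
    static __global__ void warp_sum_demo(const float * in, float * out) {
        float x = in[threadIdx.x];
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            x += __shfl_xor_sync(0xffffffff, x, mask, 32); // butterfly exchange with lane ^ mask
        }
        if (threadIdx.x == 0) {
            *out = x; // at this point all 32 lanes hold the same total
        }
    }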
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
@ -1644,18 +1656,30 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
return x;
}
//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
//#pragma unroll
// for (int mask = 16; mask > 0; mask >>= 1) {
// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
// }
// return x;
//#else
// GGML_UNUSED(x);
// NO_DEVICE_CODE;
//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
//}
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
}
return x;
#else
GGML_UNUSED(x);
NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
}
static __device__ __forceinline__ half warp_reduce_max(half x) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = __hmax(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
}
return x;
#else
(void) x;
NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
}
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
@ -2011,6 +2035,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
@ -7141,7 +7166,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
}
template <bool vals_smem, int ncols_template, int block_size_template>
static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
static __global__ void soft_max_f32(const float * x, const half * mask, const half * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
@ -7183,7 +7208,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
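The softmax input itself is unchanged by this hunk; only the storage of the mask and position tensors moves from F32 to F16, with a per-element conversion back to F32:

    val = x*scale + __half2float(mask) + slope*__half2float(pos)

The mask values llama.cpp feeds here are 0 or -INFINITY, both exactly representable in F16, so the narrower storage should not change the masked softmax result.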
@ -7359,6 +7384,536 @@ static __global__ void pool2d_nchw_kernel(
o_ptr[cur_oh * ow + cur_ow] = res;
}
#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256
template<int block_size, int k_seq_len, int k_head_dim>
static __global__ void flash_attn_f32(
const float* __restrict__ q,
const float* __restrict__ k,
const float* __restrict__ v,
float* __restrict__ kqv,
float kq_scale,
int head_dim, int seq_len, int num_heads) {
const int head = blockIdx.x / seq_len;
const int head_size = head_dim * seq_len;
const int s = blockIdx.x % seq_len;
extern __shared__ char flash_attn_shmem_f32[];
float* S = (float*)flash_attn_shmem_f32;
float* warp_data = (float*)(flash_attn_shmem_f32 + seq_len * sizeof(float));
// QK^T
#pragma unroll
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
const int is = threadIdx.x + is0;
if(is >= seq_len) {
break;
}
const int key_offset = is * head_dim + head * head_size;
const int query_offset = s * head_dim + head * head_size;
float tmp = 0.0f;
for(int d = 0; d < head_dim; d++) {
tmp += k[key_offset + d] * q[query_offset + d];
}
S[is] = tmp * kq_scale;
}
__syncthreads();
float max_val = -INFINITY;
// get the max
#pragma unroll
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
const int is = threadIdx.x + is0;
if(is >= seq_len) {
break;
}
max_val = fmaxf(max_val , S[is]);
}
max_val = warp_reduce_max(max_val);
{ // get max from all threads
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
warp_data[warp_id] = max_val;
}
__syncthreads();
max_val = warp_data[lane_id];
max_val = warp_reduce_max(max_val);
}
// softmax(QK^T)
float sum = 0.0f;
#pragma unroll
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
const int is = threadIdx.x + is0;
if(is >= seq_len) {
break;
}
float tmp = expf(S[is] - max_val);
sum += tmp;
S[is] = tmp;
}
__syncthreads();
sum = warp_reduce_sum(sum);
{ // softmax sum partials
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
warp_data[warp_id] = sum;
}
__syncthreads();
sum = warp_data[lane_id];
sum = warp_reduce_sum(sum);
}
float inv_sum = 1.0f / sum;
#pragma unroll
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
const int is = threadIdx.x + is0;
if(is >= seq_len) {
break;
}
S[is] *= inv_sum;
}
__syncthreads();
// softmax(QK^T)V
#pragma unroll
for (int d0 = threadIdx.x; d0 < k_head_dim; d0 += block_size) {
const int d = threadIdx.x + d0;
if(d >= head_dim) {
break;
}
const int dst_index = d + s * head_dim + head * head_size;
const int value_offset = d * seq_len + head * head_size;
float temp = 0.0f;
#pragma unroll
for(int ic = 0; ic < k_seq_len;ic++) {
if(ic >= seq_len) {
break;
}
temp += v[value_offset + ic] * S[ic];
}
kqv[dst_index] = temp;
}
}
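For reference, the F32 kernel above is not a tiled/online formulation: per query position s (and head) it materializes the full row of seq_len scores in shared memory and computes plain scaled-dot-product attention with a max-subtracted softmax (our summary of the code):

    S[i]    = scale * sum_d K[i][d] * Q[s][d]                         // i = 0 .. seq_len-1
    P[i]    = expf(S[i] - max_j S[j]) / sum_j expf(S[j] - max_j S[j])
    O[s][d] = sum_i P[i] * V[d][i]                                    // V laid out transposed (ne0 = seq_len, ne1 = head_dim)

The launcher further down (flash_attn_f32_cuda) instantiates it only for k_seq_len = 1024 and k_head_dim = 64.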
#if __CUDA_ARCH__ >= CC_VOLTA
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
#endif
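These typedefs wrap the CUDA WMMA API used by the tensor-core kernel that follows. As a minimal, self-contained reminder of how the fragments fit together (our sketch, not part of the diff; it uses a float accumulator for simplicity, whereas the kernel below keeps its accumulators in half):

    #include <cuda_fp16.h>
    #include <mma.h>
    using namespace nvcuda;

    // requires sm_70+; launch as <<<1, 32>>>: one warp computes C = A*B for 16x16 half tiles
    static __global__ void wmma_16x16_demo(const half * A, const half * B, float * C) {
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> c;

        wmma::fill_fragment(c, 0.0f);
        wmma::load_matrix_sync(a, A, 16);                        // leading dimension 16
        wmma::load_matrix_sync(b, B, 16);
        wmma::mma_sync(c, a, b, c);                              // c += a*b
        wmma::store_matrix_sync(C, c, 16, wmma::mem_row_major);
    }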
// based on metal version
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
static __global__ void flash_attn_ext_f16(
const char* __restrict__ q,
const char* __restrict__ k,
const char* __restrict__ v,
const char* __restrict__ mask,
float* __restrict__ dst,
float scale,
int ne00,
int ne01,
int ne02,
int ne03,
int ne10,
int ne11,
int ne12,
int ne13,
int ne31,
int nb31,
int nb01,
int nb02,
int nb03,
int nb11,
int nb12,
int nb13,
int ne0,
int ne1,
int ne2,
int ne3) {
#if __CUDA_ARCH__ >= CC_VOLTA
const int warp_id = threadIdx.y;
const int lane_id = threadIdx.x;
const int num_warps = blockDim.y; // number of warps
const int iq3 = blockIdx.z;
const int iq2 = blockIdx.y;
const int iq1 = blockIdx.x * Q;
const int D16 = D/16;
const int Q16 = Q/16;
const int C16 = C/16;
const int NW = WARP_SIZE;
const int SH = (C + Q); // shared memory per simdgroup in (half)
const int T = D + num_warps*SH; // shared memory size per query in (half)
const int T2 = T/2; // shared memory size per query in (half2)
const int C2 = C/2;
const int D2 = D/2;
extern __shared__ half __flash_attn_f16_shmem[];
// pq
half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data
half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2
half16x16_acc zr;
half16x16_acc lo[Q16][D16];
// load heads from Q to shared memory
#pragma unroll
for (int j0 = 0; j0 < Q; j0 += num_warps) {
const int j = j0 + warp_id;
if (j >= Q) {
break;
}
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
#pragma unroll
for (int i0 = 0; i0 < D2; i0 += NW) {
const int i = i0 + lane_id;
if (i >= D2) {
break;
}
if (iq1 + j < ne01) {
sq2[j*T2 + i] = __float22half2_rn(q2[i]);
} else {
sq2[j*T2 + i] = make_half2(0.0, 0.0);
}
}
}
nvcuda::wmma::fill_fragment(zr, 0.0);
// zero out lo
for (int j = 0; j < Q16; ++j) {
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
}
}
// zero out shared memory SH
for (int j = 0; j < Q; ++j) {
for (int i0 = 0; i0 < SH; i0 += NW) {
const int i = i0 + lane_id;
if (i >= SH) {
break;
}
ss[j*T + i] = 0.0;
}
}
__syncthreads();
{
half S = __float2half(0.0f);
half M[Q];
for (int i = 0; i < Q; ++i) {
M[i] = CUDART_MIN_DENORM_FP16;
}
// assume K and V are same shape
const int ne22 = ne12;
const int ne23 = ne13;
const int nb21 = nb11;
const int nb22 = nb12;
const int nb23 = nb13;
// broadcast
const int rk2 = ne02/ne12;
const int rk3 = ne03/ne13;
const int rv2 = ne02/ne22;
const int rv3 = ne03/ne23;
// k indices
const int ik2 = iq2 / rk2;
const int ik3 = iq3 / rk3;
// v indices
const int iv2 = iq2 / rv2;
const int iv3 = iq3 / rv3;
// load the queries from shared memory into local memory
half16x16_a mq[Q16][D16];
for (int j = 0; j < Q16; ++j) {
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
}
}
// pointer to the mask
const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
// prepare diagonal scale matrix
half16x16_b mscale;
for (int i = 0; i < 16; ++i) {
ss[i*T + i] = __float2half(scale);
}
nvcuda::wmma::load_matrix_sync(mscale, ss, T);
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) {
const int ic = ic0 + warp_id*C;
if (ic >= ne11) {
break;
}
// Q*K^T
{
#pragma unroll
for (int cc = 0; cc < C16; ++cc) {
half16x16_acc mqk[Q16];
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::fill_fragment(mqk[j], 0);
}
const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
for (int i = 0; i < D16; ++i) {
half16x16_bT mk; // transposed key
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
}
}
// mqk = mqk*scale + mask
for (int j = 0; j < Q16; ++j) {
half16x16_a mqka;
half16x16_acc mm;
if (mp) {
nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
}
// convert accumulator to matrix_a
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);
nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr);
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
}
}
}
// used to detect blocks full of -INF
half2 smax = make_half2(-INFINITY, -INFINITY);
// online softmax
for (int j = 0; j < Q; ++j) {
const half m = M[j];
for (int p0 = 0; p0 < C2; p0 += NW) {
const int p = p0 + lane_id;
const half2 s = ss2[j*T2 + p];
smax = __hmax2(smax, s);
M[j] = __hmax(M[j], __hmax(s.x, s.y));
}
M[j] = warp_reduce_max(M[j]);
// local sum
half2 ls = make_half2(0.0f, 0.0f);
half2 M2 = make_half2(M[j], M[j]);
for (int p0 = 0; p0 < C2; p0 += NW) {
const int p = p0 + lane_id;
const half2 s = ss2[j*T2 + p];
const half2 vs = h2exp(s - M2);
ls += vs;
// the P matrix from the paper (Q rows, C columns)
ss2[j*T2 + p] = vs;
}
ls = warp_reduce_sum(ls);
const half ms = hexp(m - M[j]);
// create a QxQ diagonal matrix for rescaling the output
if (lane_id == j) {
ss[j*T + C + j] = ms;
S = S*ms + ls.x + ls.y;
}
}
smax = warp_reduce_max(smax);
// skip -INF blocks
if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) {
continue;
}
// O = diag(ms)*O
for (int j = 0; j < Q16; ++j) {
half16x16_a mm;
half16x16_b lob;
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
for (int i = 0; i < D16; ++i) {
// convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
}
}
// restore zeros
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
}
// O = O + (Q*K^T)*V
{
for (int cc = 0; cc < C16; ++cc) {
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
half16x16_b mv[D16];
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half));
}
half16x16_a ms[Q16];
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
}
for (int j = 0; j < Q16; ++j) {
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
}
}
}
}
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
if (lane_id < Q) {
ss[lane_id*T + 0] = S;
ss[lane_id*T + 1] = M[lane_id];
}
}
// reduce the warps sequentially
for (int sg = 1; sg < num_warps; ++sg) {
__syncthreads();
// each simdgroup stores its output to shared memory, reusing sq
if (warp_id == sg) {
for (int j = 0; j < Q16; ++j) {
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
}
}
}
__syncthreads();
// the first simdgroup accumulates the results from the other simdgroups
if (warp_id == 0) {
for (int j = lane_id; j < Q; j += NW) {
const half S0 = ss[j*T + 0];
const half S1 = ss[j*T + sg*SH + 0];
const half M0 = ss[j*T + 1];
const half M1 = ss[j*T + sg*SH + 1];
const half M = __hmax(M0, M1);
const half ms0 = hexp(M0 - M);
const half ms1 = hexp(M1 - M);
const half S = S0*ms0 + S1*ms1;
ss[j*T + 0] = S;
ss[j*T + 1] = M;
ss[j*T + C + j ] = ms0;
ss[j*T + C + j + sg*SH] = ms1;
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (int j = 0; j < Q16; ++j) {
half16x16_a ms0;
half16x16_a ms1;
half16x16_b t;
half16x16_acc t2;
nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
nvcuda::wmma::mma_sync(t2, ms1, t, zr);
// convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);
nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
}
}
}
}
// store result to shared memory (reuse sq)
if (warp_id == 0) {
for (int j = 0; j < Q16; ++j) {
for (int i = 0; i < D16; ++i) {
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
}
}
}
// final rescale with 1/S and store to global memory
if (warp_id == 0) {
for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
const half S = ss[j*T + 0];
for (int i0 = 0; i0 < D; i0 += NW) {
const int i = i0 + lane_id;
if (i >= D) {
break;
}
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
}
}
}
#else
NO_DEVICE_CODE;
#endif
}
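The sequential warp reduction near the end of this kernel merges the per-warp partial states (S, M, O) the same way two softmax streams are merged in log-sum-exp form; written out (our notation, following the code):

    M   = max(M0, M1)
    ms0 = exp(M0 - M),  ms1 = exp(M1 - M)
    S   = S0*ms0 + S1*ms1
    O   = diag(ms0)*O0 + diag(ms1)*O1

so the single final normalization dst = O/S produces the same result as one softmax over the full KV range.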
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
@ -8760,7 +9315,7 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}
static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
static void soft_max_f32_cuda(const float * x, const half * mask, const half * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
@ -8821,6 +9376,13 @@ static void im2col_cuda(const float * x, T* dst,
im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) {
int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
int num_blocks = num_heads * seq_len;
flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE, 1024, 64><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
}
static cudaError_t ggml_cuda_cpy_tensor_2d(
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
@ -9877,7 +10439,7 @@ static void ggml_cuda_op_soft_max(
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional
const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
@ -9890,16 +10452,16 @@ static void ggml_cuda_op_soft_max(
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
// positions tensor
float * src2_dd = nullptr;
half * src2_dd = nullptr;
ggml_tensor * src2 = dst->src[2];
const bool use_src2 = src2 != nullptr;
if (use_src2) {
src2_dd = (float *)src2->data;
src2_dd = (half *)src2->data;
}
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream);
soft_max_f32_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, src2_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream);
}
static void ggml_cuda_op_scale(
@ -10987,6 +11549,189 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, const ggml_ten
}
}
inline void ggml_cuda_flash_attn(ggml_backend_cuda_context & ctx, const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV) {
GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(K->type == GGML_TYPE_F32);
GGML_ASSERT(V->type == GGML_TYPE_F32);
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
ggml_cuda_set_device(ctx.device);
const int64_t d_head = Q->ne[0];
const int64_t sequence_length = Q->ne[1];
const int64_t num_heads = Q->ne[2];
GGML_ASSERT(Q->ne[0] == d_head);
GGML_ASSERT(K->ne[0] == d_head);
GGML_ASSERT(V->ne[1] == d_head);
GGML_ASSERT(Q->ne[1] == sequence_length);
GGML_ASSERT(K->ne[1] == sequence_length);
GGML_ASSERT(V->ne[0] == sequence_length);
GGML_ASSERT(Q->ne[2] == num_heads);
GGML_ASSERT(K->ne[2] == num_heads);
GGML_ASSERT(V->ne[2] == num_heads);
float KQ_scale = 1.0f / sqrtf((float)d_head);
flash_attn_f32_cuda(
(float *) Q->data, // Query
(float *) K->data, // Key
(float *) V->data, // Value
(float *) KQV->data, // dst
KQ_scale, d_head, sequence_length, num_heads, ctx.stream());
}
inline void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(K->type == GGML_TYPE_F16);
GGML_ASSERT(V->type == GGML_TYPE_F16);
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
ggml_cuda_set_device(ctx.device);
const cudaStream_t main_stream = ctx.stream();
float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
#define NQPB 16
#define NCPW 128
const int nqpb = NQPB; // queries per block
const int ncpw = NCPW; // cache values per warp (does not work for other values)
GGML_ASSERT(NQPB <= 32);
const int nwarps_max = 8; // TODO: we don't want to launch too many warps. how many is too many?
// TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1;
dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
dim3 block_dim(32, nwarps, 1);
const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
// increase shared memory limit to 96KB
//const size_t shmem_max = 96*1024;
//cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
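As a worked check of the sizing above (the inputs are the defaults in this function, the arithmetic is ours): with NQPB = 16, NCPW = 128 and sizeof(float)/2 = 2 bytes per half,

    shmem(D, nwarps) = 16*(D + nwarps*(128 + 16))*2 bytes
    shmem(128, 4)    = 16*(128 +  576)*2 = 22528 bytes (22 KiB)
    shmem(128, 8)    = 16*(128 + 1152)*2 = 40960 bytes (40 KiB)
    shmem(256, 8)    = 16*(256 + 1152)*2 = 45056 bytes (44 KiB)

all of which fit in the default 48 KiB dynamic shared-memory limit, which is presumably why the cudaFuncSetAttribute bump to 96 KiB above is left commented out.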
switch (Q->ne[0]) {
case 64:
flash_attn_ext_f16<64, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? (const char *) mask->data : nullptr, // Mask
(float *) KQV->data, // dst
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
break;
case 80:
flash_attn_ext_f16<80, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? (const char *) mask->data : nullptr, // Mask
(float *) KQV->data, // dst
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
break;
case 96:
flash_attn_ext_f16<96, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? (const char *) mask->data : nullptr, // Mask
(float *) KQV->data, // dst
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
break;
case 112:
flash_attn_ext_f16<112, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? (const char *) mask->data : nullptr, // Mask
(float *) KQV->data, // dst
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
break;
case 128:
flash_attn_ext_f16<128, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? (const char *) mask->data : nullptr, // Mask
(float *) KQV->data, // dst
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
break;
case 256:
flash_attn_ext_f16<256, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) Q->data, // Query
(const char *) K->data, // Key
(const char *) V->data, // Value
mask ? (const char *) mask->data : nullptr, // Mask
(float *) KQV->data, // dst
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
break;
default:
break;
}
CUDA_CHECK(cudaGetLastError());
}
static void ggml_cuda_scale(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(ctx, src0, src1, dst, ggml_cuda_op_scale);
}
@ -11250,11 +11995,22 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ARGSORT:
func = ggml_cuda_argsort;
break;
case GGML_OP_FLASH_ATTN:
break;
case GGML_OP_FLASH_ATTN_EXT:
break;
default:
return false;
}
func(ctx, tensor->src[0], tensor->src[1], tensor);
if (tensor->op == GGML_OP_FLASH_ATTN) {
ggml_cuda_flash_attn(ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
} else if(tensor->op == GGML_OP_FLASH_ATTN_EXT) {
ggml_cuda_flash_attn_ext(ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
} else {
func(ctx, tensor->src[0], tensor->src[1], tensor);
}
return true;
}
@ -11521,6 +12277,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
return true;
default:
return false;

View File

@ -168,6 +168,12 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@ -444,6 +450,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
[metal_function release]; \
GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
(int) kernel->pipeline.threadExecutionWidth); \
if (error) { \
GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
[metal_library release]; \
@ -594,6 +603,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@ -724,6 +739,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
return true;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
@ -1263,6 +1279,8 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_OP_SOFT_MAX:
{
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16);
int nth = 32; // SIMD width
id<MTLComputePipelineState> pipeline = nil;
@ -2436,6 +2454,106 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
GGML_ASSERT(ggml_are_same_shape(src1, src2));
GGML_ASSERT(src3);
size_t offs_src3 = 0;
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
const int64_t ne31 = src3 ? src3->ne[1] : 0;
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
const uint64_t nb31 = src3 ? src3->nb[1] : 0;
const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
float scale;
memcpy(&scale, dst->op_params, sizeof(float));
id<MTLComputePipelineState> pipeline = nil;
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
default:
{
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
GGML_ASSERT(false && "add template specialization for this size");
}
}
// TODO: extend if necessary
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
[encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
[encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);
// simdgroups per threadgroup (a.k.a. warps)
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} break;
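For comparison with the CUDA path, the same sizing with the Metal defaults (nqptg = 8, ncpsg = 32; arithmetic ours) for a head size of 128 and nsg = 4 simdgroups gives

    smem = 8*(128 + 4*(32 + 8))*2 = 8*288*2 = 4608 bytes

comfortably below maxThreadgroupMemoryLength (commonly 32 KiB on recent Apple GPUs), so the assert above should only trip for unusually large head sizes or simdgroup counts.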
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
@ -2639,10 +2757,13 @@ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backe
UNUSED(buft);
}
static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
#ifndef GGML_METAL_NDEBUG
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
if (@available(macOS 10.12, iOS 16.0, *)) {
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
__func__,
size_aligned / 1024.0 / 1024.0,
device.currentAllocatedSize / 1024.0 / 1024.0,
device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
@ -2652,10 +2773,15 @@ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
GGML_METAL_LOG_INFO("\n");
}
} else {
GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
__func__,
size_aligned / 1024.0 / 1024.0,
device.currentAllocatedSize / 1024.0 / 1024.0);
}
#endif
#endif
UNUSED(device);
UNUSED(size_aligned);
}
GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@ -2689,8 +2815,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
return NULL;
}
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
ggml_backend_metal_log_allocated_size(device);
ggml_backend_metal_log_allocated_size(device, size_aligned);
return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
}
@ -2777,7 +2902,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
return false;
}
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
ggml_backend_metal_log_allocated_size(device, size_aligned);
++ctx->n_buffers;
} else {
@ -2800,7 +2925,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
return false;
}
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i);
ggml_backend_metal_log_allocated_size(device, size_step_aligned);
if (i + size_step < size) {
GGML_METAL_LOG_INFO("\n");
}
@ -2809,8 +2935,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
}
}
ggml_backend_metal_log_allocated_size(device);
return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
}

View File

@ -318,10 +318,10 @@ kernel void kernel_sum_rows(
}
kernel void kernel_soft_max(
device const float * src0,
device const float * src1,
device const float * src2,
device float * dst,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -340,10 +340,10 @@ kernel void kernel_soft_max(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
device const float * ppos = src2 != src0 ? src2 : nullptr;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr;
device const half * ppos = src2 != src0 ? (device const half *) src2 : nullptr;
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
float slope = 0.0f;
@ -422,10 +422,10 @@ kernel void kernel_soft_max(
}
kernel void kernel_soft_max_4(
device const float * src0,
device const float * src1,
device const float * src2,
device float * dst,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -444,10 +444,10 @@ kernel void kernel_soft_max_4(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
device const half4 * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr;
device const half4 * ppos = src2 != src0 ? (device const half4 *) src2 : nullptr;
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
float slope = 0.0f;
@ -464,7 +464,7 @@ kernel void kernel_soft_max_4(
float4 lmax4 = -INFINITY;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@ -490,7 +490,7 @@ kernel void kernel_soft_max_4(
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
@ -2031,6 +2031,411 @@ kernel void kernel_leaky_relu_f32(
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
}
typedef void (flash_attn_ext_f16_t)(
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
threadgroup half * shared,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]);
// ref: https://arxiv.org/pdf/2307.08691.pdf
template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
kernel void kernel_flash_attn_ext_f16(
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const uint nsg = ntg.y; // number of simdgroups
const int64_t iq3 = tgpig[2];
const int64_t iq2 = tgpig[1];
const int64_t iq1 = tgpig[0]*Q;
const int64_t D4 = D/4;
const int64_t D8 = D/8;
const int64_t Q8 = Q/8;
const int64_t NW = N_SIMDWIDTH;
const int64_t SH = (C + Q); // shared memory per simdgroup in (half)
const int64_t T = D + nsg*SH; // shared memory size per query in (half)
const int64_t T4 = T/4; // shared memory size per query in (half4)
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
simdgroup_half8x8 lo[Q8][D8];
// load heads from Q to shared memory
for (int64_t j = sgitg; j < Q; j += nsg) {
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
for (int64_t i = tiisg; i < D4; i += NW) {
if (iq1 + j < ne01) {
sq4[j*T4 + i] = (half4) q4[i];
} else {
sq4[j*T4 + i] = 0.0h;
}
}
}
// zero out lo
for (int64_t j = 0; j < Q8; ++j) {
for (int64_t i = 0; i < D8; ++i) {
lo[j][i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
}
}
// zero out shared memory SH
for (int64_t j = 0; j < Q; ++j) {
for (int64_t i = tiisg; i < SH; i += NW) {
ss[j*T + i] = 0.0h;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
{
half S[Q] = { [0 ... Q-1] = 0.0h };
half M[Q] = { [0 ... Q-1] = -INFINITY };
// assume K and V are same shape
const int64_t ne22 = ne12;
const int64_t ne23 = ne13;
const uint64_t nb21 = nb11;
const uint64_t nb22 = nb12;
const uint64_t nb23 = nb13;
// broadcast
const int64_t rk2 = ne02/ne12;
const int64_t rk3 = ne03/ne13;
const int64_t rv2 = ne02/ne22;
const int64_t rv3 = ne03/ne23;
// k indices
const int64_t ik2 = iq2 / rk2;
const int64_t ik3 = iq3 / rk3;
// v indices
const int64_t iv2 = iq2 / rv2;
const int64_t iv3 = iq3 / rv3;
// load the queries from shared memory into local memory
simdgroup_half8x8 mq[Q8][D8];
for (int64_t j = 0; j < Q8; ++j) {
for (int64_t i = 0; i < D8; ++i) {
simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
}
}
// pointer to the mask
device const half * mp = (device const half *) (mask + iq1*nb31);
// prepare diagonal scale matrix
simdgroup_half8x8 mscale(scale);
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) {
// Q*K^T
{
for (int cc = 0; cc < C/8; ++cc) {
simdgroup_half8x8 mqk[Q8];
for (int64_t j = 0; j < Q8; ++j) {
mqk[j] = make_filled_simdgroup_matrix<half, 8>(0.h);
}
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
for (int64_t i = 0; i < D8; ++i) {
simdgroup_half8x8 mk;
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
for (int64_t j = 0; j < Q8; ++j) {
simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
}
}
// mqk = mqk*scale + mask
for (int64_t j = 0; j < Q8; ++j) {
simdgroup_half8x8 mm;
simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
}
}
}
// used to detect blocks full of -INF
half smax = -INFINITY;
// online softmax
if (C == 32) {
half ms[Q];
for (int64_t j = 0; j < Q; ++j) {
const int64_t p = tiisg;
const half m = M[j];
const half s = ss[j*T + p];
smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s));
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
S[j] = S[j]*ms[j] + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs;
}
// create a QxQ diagonal matrix for rescaling the output
if (tiisg < Q) {
ss[tiisg*T + C + tiisg] = ms[tiisg];
}
} else {
half ms[Q];
for (int64_t j = 0; j < Q; ++j) {
const half m = M[j];
for (int64_t p = tiisg; p < C; p += NW) {
const half s = ss[j*T + p];
smax = max(smax, s);
M[j] = max(M[j], s);
}
smax = simd_max(smax);
M[j] = simd_max(M[j]);
ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
// local sum
half ls = 0.0h;
for (int64_t p = tiisg; p < C; p += NW) {
const half s = ss[j*T + p];
const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
ls += vs;
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs;
}
S[j] = S[j]*ms[j] + simd_sum(ls);
}
// create a QxQ diagonal matrix for rescaling the output
if (tiisg < Q) {
ss[tiisg*T + C + tiisg] = ms[tiisg];
}
}
// skip -INF blocks
if (smax == -INFINITY) {
continue;
}
// O = diag(ms)*O
for (int64_t j = 0; j < Q8; ++j) {
simdgroup_half8x8 mm;
simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false);
for (int64_t i = 0; i < D8; ++i) {
simdgroup_multiply(lo[j][i], mm, lo[j][i]);
}
}
// O = O + (Q*K^T)*V
{
for (int cc = 0; cc < C/8; ++cc) {
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
for (int64_t i = 0; i < D8; ++i) {
simdgroup_half8x8 mk;
simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
for (int64_t j = 0; j < Q8; ++j) {
simdgroup_half8x8 mv;
simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false);
simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]);
}
}
}
}
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (int64_t j = 0; j < Q; ++j) {
if (tiisg == 0) {
ss[j*T + 0] = S[j];
ss[j*T + 1] = M[j];
}
}
}
// reduce the warps sequentially
for (int64_t sg = 1; sg < nsg; ++sg) {
half S = { 0.0h };
half M = { -INFINITY };
threadgroup_barrier(mem_flags::mem_threadgroup);
// each simdgroup stores its output to shared memory, reusing sq
if (sgitg == sg) {
for (int64_t j = 0; j < Q8; ++j) {
for (int64_t i = 0; i < D8; ++i) {
simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// the first simdgroup accumulates the results from the other simdgroups
if (sgitg == 0) {
for (int64_t j = 0; j < Q; ++j) {
const half S0 = ss[j*T + 0];
const half S1 = ss[j*T + sg*SH + 0];
const half M0 = ss[j*T + 1];
const half M1 = ss[j*T + sg*SH + 1];
M = max(M0, M1);
const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M);
const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M);
S = S0*ms0 + S1*ms1;
if (tiisg == 0) {
ss[j*T + 0] = S;
ss[j*T + 1] = M;
ss[j*T + C + j ] = ms0;
ss[j*T + C + j + sg*SH] = ms1;
}
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (int64_t j = 0; j < Q8; ++j) {
simdgroup_half8x8 t;
simdgroup_half8x8 ms0;
simdgroup_half8x8 ms1;
simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false);
simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false);
for (int64_t i = 0; i < D8; ++i) {
simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false);
simdgroup_multiply(t, ms1, t);
simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t);
}
}
}
}
// store result to shared memory (reuse sq)
if (sgitg == 0) {
for (int64_t j = 0; j < Q8; ++j) {
for (int64_t i = 0; i < D8; ++i) {
simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
}
}
}
device float4 * dst4 = (device float4 *) dst;
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
const half S = ss[j*T + 0];
for (int64_t i = tiisg; i < D4; i += NW) {
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
}
}
}
}
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>;
kernel void kernel_cpy_f16_f16(
device const half * src0,
device half * dst,

ggml.c

@@ -896,7 +896,7 @@ inline static float vaddvq_f32(float32x4_t v) {
#define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
#define GGML_F16_VEC_SET1 GGML_F16x8_SET1
#define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), r[i])
#define GGML_F16_VEC_FMA GGML_F16x8_FMA
#define GGML_F16_VEC_ADD GGML_F16x8_ADD
#define GGML_F16_VEC_MUL GGML_F16x8_MUL
@@ -922,7 +922,7 @@ inline static float vaddvq_f32(float32x4_t v) {
#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i])
#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
@@ -1089,7 +1089,7 @@ do { \
#if defined(__F16C__)
// the _mm256_cvt intrinsics require F16C
#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
#else
static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
@@ -1607,6 +1607,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
#endif
}
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
#if defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
GGML_F16_VEC ax[GGML_F16_ARR];
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}
// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
}
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
}
#endif
}
// xs and vs are byte strides of x and v
inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
@@ -1691,6 +1722,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#endif
}
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}
// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
}
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
}
#endif
}
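The two new F16 helpers above are the building blocks of the scalar flash-attention path further down: ggml_vec_scale_f16 rescales the running value accumulator and ggml_vec_mad_f16 adds a weighted value row. A small usage sketch, assuming the internal FP16 conversion macros from ggml-impl.h are available; sizes and constants are arbitrary:

ggml_fp16_t acc[8], row[8];
for (int i = 0; i < 8; ++i) {
    acc[i] = GGML_FP32_TO_FP16(1.0f);
    row[i] = GGML_FP32_TO_FP16(0.5f);
}
ggml_vec_scale_f16(8, acc, 0.25f);     // acc[i] = 1.0*0.25       = 0.25
ggml_vec_mad_f16  (8, acc, row, 2.0f); // acc[i] = 0.25 + 0.5*2.0 = 1.25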
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
@@ -1945,6 +2005,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"LEAKY_RELU",
"FLASH_ATTN",
"FLASH_ATTN_EXT",
"FLASH_FF",
"FLASH_ATTN_BACK",
"SSM_CONV",
@@ -1971,7 +2032,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -2035,6 +2096,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"leaky_relu(x)",
"flash_attn(x)",
"flash_attn_ext(x)",
"flash_ff(x)",
"flash_attn_back(x)",
"ssm_conv(x)",
@@ -2061,7 +2123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -4497,6 +4559,8 @@ struct ggml_tensor * ggml_mul_mat(
void ggml_mul_mat_set_prec(
struct ggml_tensor * a,
enum ggml_prec prec) {
GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
const int32_t prec_i32 = (int32_t) prec;
ggml_set_op_params_i32(a, 0, prec_i32);
@@ -5333,14 +5397,15 @@ static struct ggml_tensor * ggml_soft_max_impl(
GGML_ASSERT(ggml_is_contiguous(a));
if (mask) {
GGML_ASSERT(mask->type == GGML_TYPE_F16);
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(ggml_is_matrix(mask));
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
}
if (pos) {
GGML_ASSERT(ggml_is_vector(pos));
GGML_ASSERT(pos->type == GGML_TYPE_F32);
GGML_ASSERT(pos->type == GGML_TYPE_F16);
GGML_ASSERT(pos->ne[0] == a->ne[0]);
}
@@ -6152,6 +6217,59 @@ struct ggml_tensor * ggml_flash_attn(
return result;
}
// ggml_flash_attn_ext
struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * mask,
float scale) {
GGML_ASSERT(ggml_can_mul_mat(k, q));
// TODO: check if vT can be multiplied by (k*qT)
if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(mask->ne[2] == 1);
GGML_ASSERT(mask->ne[3] == 1);
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
}
bool is_node = false;
if (q->grad || k->grad || v->grad) {
is_node = true;
}
// permute(0, 2, 1, 3)
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne);
float params[] = { scale };
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_FLASH_ATTN_EXT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = q;
result->src[1] = k;
result->src[2] = v;
result->src[3] = mask;
return result;
}
void ggml_flash_attn_ext_set_prec(
struct ggml_tensor * a,
enum ggml_prec prec) {
GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
const int32_t prec_i32 = (int32_t) prec;
ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
}
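A hedged usage sketch of the new operator; the tensor shapes follow the comment block added to ggml.h below, and the context and dimension variables (ctx, D, n_head, n_kv, n_batch) are assumed to exist:

struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D,    n_batch, n_head, 1);
struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, D,    n_kv,    n_head, 1);
struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, D,    n_kv,    n_head, 1); // not transposed
struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD), 1, 1);

struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf((float) D));
ggml_flash_attn_ext_set_prec(out, GGML_PREC_DEFAULT);
// out has shape [D, n_head, n_batch, 1] (permuted), as documented in ggml.h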
// ggml_flash_ff
struct ggml_tensor * ggml_flash_ff(
@@ -12184,12 +12302,14 @@ static void ggml_compute_forward_soft_max_f32(
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
// broadcast the mask across rows
float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
ggml_vec_cpy_f32 (nc, wp, sp);
ggml_vec_scale_f32(nc, wp, scale);
if (mp) {
ggml_vec_acc_f32(nc, wp, mp);
for (int i = 0; i < nc; ++i) {
wp[i] += GGML_FP16_TO_FP32(mp[i]);
}
}
// ALiBi bias
@@ -14466,6 +14586,197 @@ static void ggml_compute_forward_flash_attn(
}
}
// ggml_compute_forward_flash_attn_ext
static void ggml_compute_forward_flash_attn_ext_f16(
const struct ggml_compute_params * params,
const struct ggml_tensor * q,
const struct ggml_tensor * k,
const struct ggml_tensor * v,
const struct ggml_tensor * mask,
struct ggml_tensor * dst) {
int64_t t0 = ggml_perf_time_us();
UNUSED(t0);
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
const int64_t D = neq0;
const int64_t N = neq1;
GGML_ASSERT(ne0 == D);
GGML_ASSERT(ne2 == N);
GGML_ASSERT(nbq0 == sizeof(float));
GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
GGML_ASSERT(neq0 == D);
GGML_ASSERT(nek0 == D);
GGML_ASSERT(nev0 == D);
GGML_ASSERT(neq1 == N);
GGML_ASSERT(nev0 == D);
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// broadcast factors
const int64_t rk2 = neq2/nek2;
const int64_t rk3 = neq3/nek3;
const int64_t rv2 = neq2/nev2;
const int64_t rv3 = neq3/nev3;
if (params->type == GGML_TASK_TYPE_INIT) {
return;
}
if (params->type == GGML_TASK_TYPE_FINALIZE) {
return;
}
// parallelize by q rows using ggml_vec_dot_f32
// total rows in q
const int nr = neq1*neq2*neq3;
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
float scale = 1.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
const int iq3 = ir/(neq2*neq1);
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
float S = 0.0f;
float M = -INFINITY;
float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
memset(V16, 0, D*sizeof(ggml_fp16_t));
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
// k indices
const int ik3 = iq3 / rk3;
const int ik2 = iq2 / rk2;
// v indices
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;
// online softmax / attention
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
for (int64_t ic = 0; ic < nek1; ++ic) {
const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mv == -INFINITY) {
continue;
}
float s;
// convert Q to F16 in V32
{
const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
for (int64_t d = 0; d < D; ++d) {
Q16[d] = GGML_FP32_TO_FP16(pq[d]);
}
}
ggml_vec_dot_f16(D,
&s, 0,
(ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
Q16, 0, 1);
s = s*scale + mv;
const float Mold = M;
float ms = 1.0f;
float vs = 1.0f;
if (s > M) {
M = s;
ms = expf(Mold - M);
// V = V*expf(Mold - M)
ggml_vec_scale_f16(D, V16, ms);
} else {
vs = expf(s - M);
}
const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
// V += v*expf(s - M)
ggml_vec_mad_f16(D, V16, v16, vs);
S = S*ms + vs;
}
// V /= S
for (int64_t d = 0; d < D; ++d) {
V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
}
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
// original
//memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
// permute(0, 2, 1, 3)
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
}
}
static void ggml_compute_forward_flash_attn_ext(
const struct ggml_compute_params * params,
const struct ggml_tensor * q,
const struct ggml_tensor * k,
const struct ggml_tensor * v,
const struct ggml_tensor * mask,
struct ggml_tensor * dst) {
switch (dst->op_params[1]) {
case GGML_PREC_DEFAULT:
{
ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
} break;
default:
{
// TODO: implement F32 precision
GGML_ASSERT(false);
} break;
}
}
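For reference, the loop in ggml_compute_forward_flash_attn_ext_f16 keeps a running maximum M, a running denominator S and an unnormalized accumulator V. A small standalone C check, illustrative only and using scalar values in place of value rows, that the streaming update matches a direct softmax-weighted sum:

#include <math.h>
#include <stdio.h>

int main(void) {
    const float s[4] = { 0.5f, 2.0f, -1.0f, 1.5f }; // scaled KQ scores for one query row
    const float v[4] = { 1.0f, 2.0f,  3.0f, 4.0f }; // value entries (1-dim rows for simplicity)

    // streaming accumulation, same update as in the function above
    float M = -INFINITY, S = 0.0f, V = 0.0f;
    for (int i = 0; i < 4; ++i) {
        const float Mold = M;
        float ms = 1.0f;
        float vs = 1.0f;
        if (s[i] > M) {
            M  = s[i];
            ms = expf(Mold - M); // rescale what has been accumulated so far
            V *= ms;
        } else {
            vs = expf(s[i] - M);
        }
        V += v[i]*vs;
        S  = S*ms + vs;
    }

    // direct reference: dot(softmax(s), v)
    float m = -INFINITY, sum = 0.0f, ref = 0.0f;
    for (int i = 0; i < 4; ++i) m = fmaxf(m, s[i]);
    for (int i = 0; i < 4; ++i) sum += expf(s[i] - m);
    for (int i = 0; i < 4; ++i) ref += v[i]*expf(s[i] - m)/sum;

    printf("streaming = %f, direct = %f\n", V/S, ref); // the two agree up to rounding
    return 0;
}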
// ggml_compute_forward_flash_ff
static void ggml_compute_forward_flash_ff_f16(
@@ -16293,6 +16604,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
const bool masked = t != 0;
ggml_compute_forward_flash_attn(params, masked, tensor);
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
} break;
case GGML_OP_FLASH_FF:
{
ggml_compute_forward_flash_ff(params, tensor);
@@ -17305,6 +17620,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_FLASH_ATTN:
case GGML_OP_FLASH_ATTN_EXT:
{
struct ggml_tensor * flash_grad = NULL;
if (src0->grad || src1->grad || tensor->src[2]->grad) {
@@ -18071,6 +18387,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
n_tasks = n_threads;
} break;
case GGML_OP_FLASH_ATTN:
case GGML_OP_FLASH_ATTN_EXT:
{
n_tasks = n_threads;
} break;
@@ -18474,6 +18791,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
}
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne00 = node->src[0]->ne[0]; // D
cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
} break;
case GGML_OP_FLASH_FF:
{
if (node->src[1]->type == GGML_TYPE_F32) {

ggml.h

@@ -472,6 +472,7 @@ extern "C" {
GGML_OP_LEAKY_RELU,
GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_ATTN_EXT,
GGML_OP_FLASH_FF,
GGML_OP_FLASH_ATTN_BACK,
GGML_OP_SSM_CONV,
@@ -1716,6 +1717,25 @@ extern "C" {
struct ggml_tensor * v,
bool masked);
#define GGML_KQ_MASK_PAD 32
// q: [n_embd, n_batch, n_head, 1]
// k: [n_embd, n_kv, n_head_kv, 1]
// v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd, n_head, n_batch, 1] !! permuted !!
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * mask,
float scale);
GGML_API void ggml_flash_attn_ext_set_prec(
struct ggml_tensor * a,
enum ggml_prec prec);
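The GGML_KQ_MASK_PAD requirement means callers must round the mask's second dimension up; for example, with n_batch == 50 the mask needs at least GGML_PAD(50, 32) == 64 rows. A schematic allocation, assuming ctx, n_kv and n_batch exist:

struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16,
        n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD));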
GGML_API struct ggml_tensor * ggml_flash_attn_back(
struct ggml_context * ctx,
struct ggml_tensor * q,

llama.cpp

@@ -104,6 +104,7 @@
#define LLAMA_MAX_NODES 8192
#define LLAMA_MAX_EXPERTS 8
#define LLAMA_FLASH_ATTN
//
// logging
@@ -5265,23 +5266,34 @@ static void llm_build_kv_store(
GGML_ASSERT(kv.size == n_ctx);
// compute the transposed [n_tokens, n_embd] V matrix
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
cb(v_cur_t, "v_cur_t", il);
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
(ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
cb(k_cache_view, "k_cache_view", il);
// important: storing RoPE-ed version of K in the KV cache!
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
#if defined(LLAMA_FLASH_ATTN)
// NOTE: the V cache is not transposed when using FLASH attention !!
struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
(ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head);
cb(v_cache_view, "v_cache_view", il);
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
GGML_UNUSED(n_ctx);
#else
// compute the transposed [n_tokens, n_embd] V matrix
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
cb(v_cur_t, "v_cur_t", il);
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
( n_ctx)*ggml_element_size(kv.v_l[il]),
(kv_head)*ggml_element_size(kv.v_l[il]));
cb(v_cache_view, "v_cache_view", il);
// important: storing RoPE-ed version of K in the KV cache!
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
#endif
}
static struct ggml_tensor * llm_build_norm(
@@ -5442,6 +5454,33 @@ static struct ggml_tensor * llm_build_kqv(
0);
cb(k, "k", il);
struct ggml_tensor * cur;
#if defined(LLAMA_FLASH_ATTN)
GGML_UNUSED(model);
GGML_UNUSED(n_ctx);
GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention");
// split cached v into n_head heads (not transposed)
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv.v_l[il]->type, n_embd_head_k),
0);
cb(v, "v", il);
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT);
//printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
//printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
//printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
//printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]);
//printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]);
cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
#else
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
@@ -5491,8 +5530,9 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
cb(kqv_merged, "kqv_merged", il);
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
cb(cur, "kqv_merged_cont", il);
#endif
ggml_build_forward_expand(graph, cur);
@@ -5770,20 +5810,20 @@ struct llm_build_context {
struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
if (causal) {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens);
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
} else {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
}
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
ggml_set_input(lctx.inp_KQ_mask);
return lctx.inp_KQ_mask;
return ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16);
}
struct ggml_tensor * build_inp_KQ_pos() {
lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
cb(lctx.inp_KQ_pos, "KQ_pos", -1);
ggml_set_input(lctx.inp_KQ_pos);
return lctx.inp_KQ_pos;
return ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16);
}
struct ggml_tensor * build_inp_mean() {
@@ -9126,7 +9166,7 @@ static int llama_decode_internal(
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}
@@ -13022,6 +13062,10 @@ struct llama_context * llama_new_context_with_model(
const auto & hparams = model->hparams;
auto & cparams = ctx->cparams;
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch);
// TODO: maybe add n_seq_max here too
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;

tests/test-backend-ops.cpp

@@ -571,9 +571,19 @@ struct test_case {
// duplicate the op
size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
#if 0
for (int i = 1; i < n_runs; i++) {
gf->nodes[gf->n_nodes++] = out;
}
#else
int n_nodes = gf->n_nodes;
n_runs = 1000;
for (int i = 1; i < n_runs; i++) {
for (int j = 0; j < n_nodes; j++) {
gf->nodes[gf->n_nodes++] = gf->nodes[j];
}
}
#endif
// calculate memory
size_t mem = n_runs * op_size(out);
@@ -1103,11 +1113,11 @@ struct test_soft_max : public test_case {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * mask = nullptr;
if (this->mask) {
mask = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]);
}
ggml_tensor * pos = nullptr;
if (max_bias > 0.0f) {
pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, ne[0]);
}
ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
return out;
@@ -1477,6 +1487,76 @@ struct test_leaky_relu : public test_case {
}
};
// GGML_OP_FLASH_ATTN_EXT
struct test_flash_attn_ext : public test_case {
const int64_t hs; // head size
const int64_t nh; // num heads
const int64_t kv; // kv size
const int64_t nb; // batch size
std::string vars() override {
return VARS_TO_STR4(hs, nh, kv, nb);
}
double max_nmse_err() override {
return 5e-4;
}
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
: hs(hs), nh(nh), kv(kv), nb(nb) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
return out;
}
};
// Attention
struct test_attn : public test_case {
const int64_t hs; // head size
const int64_t nh; // num heads
const int64_t kv; // kv size
const int64_t nb; // batch size
std::string op_desc(ggml_tensor * t) override {
return "ATTN";
GGML_UNUSED(t);
}
std::string vars() override {
return VARS_TO_STR4(hs, nh, kv, nb);
}
double max_nmse_err() override {
return 5e-4;
}
test_attn(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
: hs(hs), nh(nh), kv(kv), nb(nb) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, 1);
struct ggml_tensor * cur;
cur = ggml_mul_mat (ctx, k, q);
cur = ggml_soft_max_ext(ctx, cur, mask, nullptr, 1.0f/sqrtf(hs), 0.0f);
cur = ggml_mul_mat (ctx, v, cur);
cur = ggml_permute (ctx, cur, 0, 2, 1, 3);
cur = ggml_cont_2d (ctx, cur, hs*nh, nb);
return cur;
}
};
// Mixtral MOE
struct test_moe : public test_case {
const int n_experts;
@@ -1749,7 +1829,7 @@ struct test_llama : public test_llm {
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);
ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
@@ -1871,7 +1951,7 @@ struct test_falcon : public test_llm {
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);
ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
@@ -2180,6 +2260,30 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_timestep_embedding());
test_cases.emplace_back(new test_leaky_relu());
#if 1
for (int hs : { 128, 64, 80, }) {
for (int nh : { 32, }) {
for (int kv : { 512, 1024, 2048, 4096, }) {
for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {
test_cases.emplace_back(new test_attn (hs, nh, kv, nb));
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
}
}
}
}
#else
for (int hs : { 128, }) {
for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) {
for (int nb : { 1, 2, 4, 8, 512 }) {
test_cases.emplace_back(new test_attn (hs, nh, kv, nb));
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
}
}
}
}
#endif
// these tests are disabled to save execution time, but they can be handy for debugging
#if 0
#if !defined(__SANITIZE_THREAD__)