ggml : sync latest (SAM + SD operators, CUDA alibi) (#2709)

* ggml : sync latest (SAM + SD operators, CUDA alibi)

ggml-ci

* ggml : fix tabs
This commit is contained in:
Georgi Gerganov 2023-08-22 14:22:08 +03:00 committed by GitHub
parent 8e4364f2af
commit ef3f333d37
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 1090 additions and 61 deletions

View file

@ -1868,10 +1868,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd);
t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch);
t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch);
t04->grad = expand(gb, ggml_add_inplace(ctx0,
ggml_add_inplace(ctx0,

View file

@ -76,7 +76,7 @@ struct ggml_allocr {
};
#ifdef GGML_ALLOCATOR_DEBUG
static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i] == NULL) {
alloc->allocated_tensors[i] = tensor;
@ -85,7 +85,7 @@ static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tens
}
GGML_ASSERT(!"out of allocated_tensors");
}
static void remove_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) {
static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i] == tensor ||
(alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {

View file

@ -259,6 +259,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
#define CUDA_ALIBI_BLOCK_SIZE 32
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
@ -3940,6 +3941,29 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
}
static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
const int n_heads_log2_floor, const float m0, const float m1) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;
if (col >= ncols) {
return;
}
const int row = blockDim.y*blockIdx.y + threadIdx.y;
const int i = row*ncols + col;
const int k = row/k_rows;
float m_k;
if (k < n_heads_log2_floor) {
m_k = powf(m0, k + 1);
} else {
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
}
dst[i] = col * m_k + x[i];
}
static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int row = blockDim.y*blockIdx.y + threadIdx.y;
@ -4766,6 +4790,15 @@ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, con
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
}
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
const int k_rows, const int n_heads_log2_floor, const float m0,
const float m1, cudaStream_t stream) {
const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
const dim3 block_nums(num_blocks_x, nrows, 1);
alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
}
static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
@ -5501,6 +5534,41 @@ inline void ggml_cuda_op_rope(
(void) i1;
}
inline void ggml_cuda_op_alibi(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
cudaStream_t & cudaStream_main){
GGML_ASSERT(src0_ddf_i != nullptr);
GGML_ASSERT(dst_ddf_i != nullptr);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t i01_diff = i01_high - i01_low;
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
GGML_ASSERT(ne01 + n_past == ne00);
GGML_ASSERT(n_head == ne02);
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
// compute
alibi_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_heads_log2_floor, m0, m1, cudaStream_main);
(void) src1;
(void) src0_ddq_i;
(void) src1_ddf_i;
(void) i1;
}
inline void ggml_cuda_op_diag_mask_inf(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@ -6121,6 +6189,11 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
}
void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_alibi, true, true);
}
void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
(void) src0;
(void) src1;
@ -6456,6 +6529,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
}
func = ggml_cuda_rope;
break;
case GGML_OP_ALIBI:
if (!any_on_device) {
return false;
}
func = ggml_cuda_alibi;
break;
default:
return false;
}

929
ggml.c

File diff suppressed because it is too large Load diff

115
ggml.h
View file

@ -211,6 +211,7 @@
#define GGML_MAX_OP_PARAMS 32
#define GGML_DEFAULT_N_THREADS 4
#define GGML_EXIT_SUCCESS 0
#define GGML_EXIT_ABORTED 1
@ -345,10 +346,12 @@ extern "C" {
GGML_OP_ARGMAX,
GGML_OP_REPEAT,
GGML_OP_REPEAT_BACK,
GGML_OP_CONCAT,
GGML_OP_SILU_BACK,
GGML_OP_NORM, // normalize
GGML_OP_RMS_NORM,
GGML_OP_RMS_NORM_BACK,
GGML_OP_GROUP_NORM,
GGML_OP_MUL_MAT,
GGML_OP_OUT_PROD,
@ -374,14 +377,19 @@ extern "C" {
GGML_OP_CLAMP,
GGML_OP_CONV_1D,
GGML_OP_CONV_2D,
GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF,
GGML_OP_FLASH_ATTN_BACK,
GGML_OP_WIN_PART,
GGML_OP_WIN_UNPART,
GGML_OP_GET_REL_POS,
GGML_OP_ADD_REL_POS,
GGML_OP_UNARY,
@ -805,6 +813,13 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
// concat a and b on dim 2
// used in stable-diffusion
GGML_API struct ggml_tensor * ggml_concat(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_abs(
struct ggml_context * ctx,
struct ggml_tensor * a);
@ -913,6 +928,19 @@ extern "C" {
struct ggml_tensor * a,
float eps);
// group normalize along ne0*ne1*n_groups
// used in stable-diffusion
// TODO: eps is hardcoded to 1e-6 for now
GGML_API struct ggml_tensor * ggml_group_norm(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_groups);
GGML_API struct ggml_tensor * ggml_group_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_groups);
// a - x
// b - dy
// TODO: update with configurable eps
@ -1213,6 +1241,15 @@ extern "C" {
float freq_base,
float freq_scale);
// xPos RoPE, in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
float base,
bool down);
// rotary position embedding backward, i.e compute dx from dy
// a - dy
GGML_API struct ggml_tensor * ggml_rope_back(
@ -1221,7 +1258,11 @@ extern "C" {
int n_past,
int n_dims,
int mode,
int n_ctx);
int n_ctx,
float freq_base,
float freq_scale,
float xpos_base,
bool xpos_down);
// alibi position embedding
// in-place, returns view(a)
@ -1248,6 +1289,15 @@ extern "C" {
int p0, // padding
int d0); // dilation
// conv_1d with padding = half
// alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
GGML_API struct ggml_tensor* ggml_conv_1d_ph(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s,
int d);
GGML_API struct ggml_tensor * ggml_conv_2d(
struct ggml_context * ctx,
struct ggml_tensor * a,
@ -1259,14 +1309,38 @@ extern "C" {
int d0,
int d1);
// conv_1d with padding = half
// alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
GGML_API struct ggml_tensor * ggml_conv_1d_ph(
// kernel size is a->ne[0] x a->ne[1]
// stride is equal to kernel size
// padding is zero
// example:
// a: 16 16 3 768
// b: 1024 1024 3 1
// res: 64 64 768 1
// used in sam
GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// kernel size is a->ne[0] x a->ne[1]
// stride is 1
// padding is half
// example:
// a: 3 3 256 256
// b: 64 64 256 1
// res: 64 64 256 1
// used in sam
GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s,
int d);
int stride);
enum ggml_op_pool {
GGML_OP_POOL_MAX,
@ -1293,6 +1367,13 @@ extern "C" {
int p0,
int p1);
// nearest interpolate
// used in stable-diffusion
GGML_API struct ggml_tensor * ggml_upscale(
struct ggml_context * ctx,
struct ggml_tensor * a,
int scale_factor);
GGML_API struct ggml_tensor * ggml_flash_attn(
struct ggml_context * ctx,
struct ggml_tensor * q,
@ -1346,6 +1427,27 @@ extern "C" {
struct ggml_tensor * a,
enum ggml_unary_op op);
// used in sam
GGML_API struct ggml_tensor * ggml_get_rel_pos(
struct ggml_context * ctx,
struct ggml_tensor * a,
int qh,
int kh);
// used in sam
GGML_API struct ggml_tensor * ggml_add_rel_pos(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * pw,
struct ggml_tensor * ph);
GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * pw,
struct ggml_tensor * ph);
// custom operators
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
@ -1500,6 +1602,7 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * tensor);
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);

View file

@ -1,14 +1,16 @@
#!/bin/bash
cp -rpv ../ggml/src/ggml.c ./ggml.c
cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
cp -rpv ../ggml/src/ggml.c ./ggml.c
cp -rpv ../ggml/src/ggml-alloc.c ./ggml-alloc.c
cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h
cp -rpv ../ggml/tests/test-opt.cpp ./tests/test-opt.cpp
cp -rpv ../ggml/tests/test-grad0.cpp ./tests/test-grad0.cpp