ggml : sync latest ggml + llama.cpp updates (quantization)

pull/833/head
Georgi Gerganov 2023-04-29 12:31:52 +03:00
parent 5cc17418c7
commit acec73ab6e
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 3980 additions and 1609 deletions

365
ggml-cuda.cu 100644
View File

@ -0,0 +1,365 @@
#include <stdint.h>
#include <stdio.h>
#include <cuda_fp16.h>
#include <atomic>
#include "ggml-cuda.h"
typedef uint16_t ggml_fp16_t;
static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
#define QK4_0 32
typedef struct {
float d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
#define QK4_1 32
typedef struct {
float d; // delta
float m; // min
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
#define QK4_2 16
typedef struct {
__half d; // delta
uint8_t qs[QK4_2 / 2]; // nibbles / quants
} block_q4_2;
static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
#define QK5_0 32
typedef struct {
__half d; // delta
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_0 / 2]; // nibbles / quants
} block_q5_0;
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
#define QK5_1 32
typedef struct {
__half d; // delta
__half m; // min
uint32_t qh; // 5-th bit of quants
uint8_t qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1;
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
#define QK8_0 32
typedef struct {
float d; // delta
int8_t qs[QK8_0]; // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
const block_q4_0 * x = (const block_q4_0 *) vx;
const int i = blockIdx.x;
const float d = x[i].d;
const uint8_t * pp = x[i].qs;
for (int l = 0; l < QK4_0; l += 2) {
const uint8_t vi = pp[l/2];
const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;
const float v0 = (vi0 - 8)*d;
const float v1 = (vi1 - 8)*d;
y[i*QK4_0 + l + 0] = v0;
y[i*QK4_0 + l + 1] = v1;
}
}
static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
const block_q4_1 * x = (const block_q4_1 *) vx;
const int i = blockIdx.x;
const float d = x[i].d;
const float m = x[i].m;
const uint8_t * pp = x[i].qs;
for (int l = 0; l < QK4_1; l += 2) {
const uint8_t vi = pp[l/2];
const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;
const float v0 = vi0*d + m;
const float v1 = vi1*d + m;
y[i*QK4_1 + l + 0] = v0;
y[i*QK4_1 + l + 1] = v1;
}
}
static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
const block_q4_2 * x = (const block_q4_2 *) vx;
const int i = blockIdx.x;
const float d = x[i].d;
const uint8_t * pp = x[i].qs;
for (int l = 0; l < QK4_2; l += 2) {
const uint8_t vi = pp[l/2];
const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;
const float v0 = (vi0 - 8)*d;
const float v1 = (vi1 - 8)*d;
y[i*QK4_2 + l + 0] = v0;
y[i*QK4_2 + l + 1] = v1;
}
}
static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
const block_q5_0 * x = (const block_q5_0 *) vx;
const int i = blockIdx.x;
const float d = x[i].d;
const uint8_t * pp = x[i].qs;
uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh));
for (int l = 0; l < QK5_0; l += 2) {
const uint8_t vi = pp[l/2];
const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
const int8_t vi0 = ((vi & 0xf) | vh0);
const int8_t vi1 = ((vi >> 4) | vh1);
const float v0 = (vi0 - 16)*d;
const float v1 = (vi1 - 16)*d;
y[i*QK5_0 + l + 0] = v0;
y[i*QK5_0 + l + 1] = v1;
}
}
static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
const block_q5_1 * x = (const block_q5_1 *) vx;
const int i = blockIdx.x;
const float d = x[i].d;
const float m = x[i].m;
const uint8_t * pp = x[i].qs;
const uint32_t qh = x[i].qh;
for (int l = 0; l < QK5_1; l += 2) {
const uint8_t vi = pp[l/2];
const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
const int8_t vi0 = (vi & 0xf) | vh0;
const int8_t vi1 = (vi >> 4) | vh1;
const float v0 = vi0*d + m;
const float v1 = vi1*d + m;
y[i*QK5_1 + l + 0] = v0;
y[i*QK5_1 + l + 1] = v1;
}
}
static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
const block_q8_0 * x = (const block_q8_0 *) vx;
const int i = blockIdx.x;
const float d = x[i].d;
const int8_t * pp = x[i].qs;
for (int l = 0; l < QK8_0; l++) {
const int8_t vi = pp[l];
y[i*QK8_0 + l] = vi*d;
}
}
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_0;
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
}
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_1;
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
}
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_2;
dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
}
void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK5_0;
dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
}
void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK5_1;
dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
}
void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK8_0;
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}
dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1:
return dequantize_row_q4_1_cuda;
case GGML_TYPE_Q4_2:
return dequantize_row_q4_2_cuda;
case GGML_TYPE_Q5_0:
return dequantize_row_q5_0_cuda;
case GGML_TYPE_Q5_1:
return dequantize_row_q5_1_cuda;
case GGML_TYPE_Q8_0:
return dequantize_row_q8_0_cuda;
default:
return nullptr;
}
}
// buffer pool for cuda
#define MAX_CUDA_BUFFERS 16
struct scoped_spin_lock {
std::atomic_flag& lock;
scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
while (lock.test_and_set(std::memory_order_acquire)) {
; // spin
}
}
~scoped_spin_lock() {
lock.clear(std::memory_order_release);
}
scoped_spin_lock(const scoped_spin_lock&) = delete;
scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
};
struct cuda_buffer {
void * ptr = nullptr;
size_t size = 0;
};
static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock);
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
cuda_buffer& b = g_cuda_buffer_pool[i];
if (b.size >= size && b.ptr != nullptr) {
void * ptr = b.ptr;
*actual_size = b.size;
b.ptr = nullptr;
b.size = 0;
return ptr;
}
}
void * ptr;
CUDA_CHECK(cudaMalloc((void **) &ptr, size));
*actual_size = size;
return ptr;
}
void ggml_cuda_pool_free(void * ptr, size_t size) {
scoped_spin_lock lock(g_cuda_pool_lock);
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
cuda_buffer& b = g_cuda_buffer_pool[i];
if (b.ptr == nullptr) {
b.ptr = ptr;
b.size = size;
return;
}
}
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
CUDA_CHECK(cudaFree(ptr));
}
cublasHandle_t g_cublasH = nullptr;
cudaStream_t g_cudaStream = nullptr;
cudaStream_t g_cudaStream2 = nullptr;
cudaEvent_t g_cudaEvent = nullptr;
void ggml_init_cublas() {
if (g_cublasH == nullptr) {
// create cublas handle, bind a stream
CUBLAS_CHECK(cublasCreate(&g_cublasH));
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
// create additional stream and event for synchronization
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));
// configure logging to stdout
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
}
}
cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
const uint64_t ne0 = src->ne[0];
const uint64_t ne1 = src->ne[1];
const uint64_t nb0 = src->nb[0];
const uint64_t nb1 = src->nb[1];
const uint64_t nb2 = src->nb[2];
const uint64_t nb3 = src->nb[3];
const enum ggml_type type = src->type;
const size_t ts = ggml_type_size(type);
const size_t bs = ggml_blck_size(type);
const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
if (nb0 == ts && nb1 == ts*ne0/bs) {
return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
} else if (nb0 == ts) {
return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
} else {
for (uint64_t i1 = 0; i1 < ne1; i1++) {
const void * rx = (const void *) ((const char *) x + i1*nb1);
void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
// pretend the row is a matrix with cols=1
cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
if (r != cudaSuccess) return r;
}
return cudaSuccess;
}
}
void * ggml_cuda_host_malloc(size_t size) {
void * ptr;
CUDA_CHECK(cudaMallocHost((void **) &ptr, size));
return ptr;
}
void ggml_cuda_host_free(void * ptr) {
CUDA_CHECK(cudaFreeHost(ptr));
}

54
ggml-cuda.h 100644
View File

@ -0,0 +1,54 @@
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "ggml.h"
#ifdef __cplusplus
extern "C" {
#endif
#define CUDA_CHECK(err) \
do { \
cudaError_t err_ = (err); \
if (err_ != cudaSuccess) { \
fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
cudaGetErrorString(err_)); \
exit(1); \
} \
} while (0)
#define CUBLAS_CHECK(err) \
do { \
cublasStatus_t err_ = (err); \
if (err_ != CUBLAS_STATUS_SUCCESS) { \
fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
exit(1); \
} \
} while (0)
extern cublasHandle_t g_cublasH;
extern cudaStream_t g_cudaStream;
extern cudaStream_t g_cudaStream2;
extern cudaEvent_t g_cudaEvent;
void ggml_init_cublas(void);
void * ggml_cuda_host_malloc(size_t size);
void ggml_cuda_host_free(void * ptr);
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
void ggml_cuda_pool_free(void * ptr, size_t size);
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream);
typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type);
#ifdef __cplusplus
}
#endif

3817
ggml.c

File diff suppressed because it is too large Load Diff

275
ggml.h
View File

@ -169,14 +169,27 @@
//
//
#ifdef __cplusplus
extern "C" {
#ifdef GGML_SHARED
# if defined(_WIN32) && !defined(__MINGW32__)
# ifdef GGML_BUILD
# define GGML_API __declspec(dllexport)
# else
# define GGML_API __declspec(dllimport)
# endif
# else
# define GGML_API __attribute__ ((visibility ("default")))
# endif
#else
# define GGML_API
#endif
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#define GGML_FILE_MAGIC 0x67676d6c // "ggml"
#define GGML_FILE_VERSION 1
#define GGML_MAX_DIMS 4
#define GGML_MAX_NODES 4096
#define GGML_MAX_PARAMS 16
@ -184,6 +197,10 @@ extern "C" {
#define GGML_MAX_OPT 4
#define GGML_DEFAULT_N_THREADS 4
#ifdef __cplusplus
extern "C" {
#endif
#ifdef __ARM_NEON
// we use the built-in 16-bit float type
typedef __fp16 ggml_fp16_t;
@ -192,18 +209,23 @@ typedef uint16_t ggml_fp16_t;
#endif
// convert FP16 <-> FP32
float ggml_fp16_to_fp32(ggml_fp16_t x);
ggml_fp16_t ggml_fp32_to_fp16(float x);
GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
struct ggml_object;
struct ggml_context;
enum ggml_type {
// explicitly numbered values are used in llama.cpp files
GGML_TYPE_F32 = 0,
GGML_TYPE_F16 = 1,
GGML_TYPE_Q4_0 = 2,
GGML_TYPE_Q4_1 = 3,
GGML_TYPE_Q4_2 = 4,
// GGML_TYPE_Q4_3 (5) support has been removed
GGML_TYPE_Q5_0 = 6,
GGML_TYPE_Q5_1 = 7,
GGML_TYPE_Q8_0 = 8,
GGML_TYPE_Q8_1 = 9,
GGML_TYPE_I8,
GGML_TYPE_I16,
GGML_TYPE_I32,
@ -247,6 +269,7 @@ enum ggml_op {
GGML_OP_DIAG_MASK_INF,
GGML_OP_SOFT_MAX,
GGML_OP_ROPE,
GGML_OP_ALIBI,
GGML_OP_CONV_1D_1S,
GGML_OP_CONV_1D_2S,
@ -338,56 +361,64 @@ struct ggml_init_params {
bool no_alloc; // don't allocate memory for the tensor data
};
void ggml_time_init(void); // call this once at the beginning of the program
int64_t ggml_time_ms(void);
int64_t ggml_time_us(void);
int64_t ggml_cycles(void);
int64_t ggml_cycles_per_ms(void);
// misc
void ggml_print_object (const struct ggml_object * obj);
void ggml_print_objects(const struct ggml_context * ctx);
GGML_API void ggml_time_init(void); // call this once at the beginning of the program
GGML_API int64_t ggml_time_ms(void);
GGML_API int64_t ggml_time_us(void);
GGML_API int64_t ggml_cycles(void);
GGML_API int64_t ggml_cycles_per_ms(void);
int64_t ggml_nelements(const struct ggml_tensor * tensor);
size_t ggml_nbytes (const struct ggml_tensor * tensor);
GGML_API void ggml_print_object (const struct ggml_object * obj);
GGML_API void ggml_print_objects(const struct ggml_context * ctx);
int ggml_blck_size (enum ggml_type type);
size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
GGML_API int64_t ggml_nelements(const struct ggml_tensor * tensor);
GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor);
size_t ggml_element_size(const struct ggml_tensor * tensor);
GGML_API int ggml_blck_size (enum ggml_type type);
GGML_API size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
GGML_API float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
struct ggml_context * ggml_init(struct ggml_init_params params);
void ggml_free(struct ggml_context * ctx);
GGML_API const char * ggml_type_name(enum ggml_type type);
size_t ggml_used_mem(const struct ggml_context * ctx);
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
GGML_API bool ggml_is_quantized(enum ggml_type type);
struct ggml_tensor * ggml_new_tensor(
// main
GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
GGML_API void ggml_free(struct ggml_context * ctx);
GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);
GGML_API size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
GGML_API struct ggml_tensor * ggml_new_tensor(
struct ggml_context * ctx,
enum ggml_type type,
int n_dims,
const int64_t *ne);
struct ggml_tensor * ggml_new_tensor_1d(
GGML_API struct ggml_tensor * ggml_new_tensor_1d(
struct ggml_context * ctx,
enum ggml_type type,
int64_t ne0);
struct ggml_tensor * ggml_new_tensor_2d(
GGML_API struct ggml_tensor * ggml_new_tensor_2d(
struct ggml_context * ctx,
enum ggml_type type,
int64_t ne0,
int64_t ne1);
struct ggml_tensor * ggml_new_tensor_3d(
GGML_API struct ggml_tensor * ggml_new_tensor_3d(
struct ggml_context * ctx,
enum ggml_type type,
int64_t ne0,
int64_t ne1,
int64_t ne2);
struct ggml_tensor * ggml_new_tensor_4d(
GGML_API struct ggml_tensor * ggml_new_tensor_4d(
struct ggml_context * ctx,
enum ggml_type type,
int64_t ne0,
@ -395,122 +426,127 @@ struct ggml_tensor * ggml_new_tensor_4d(
int64_t ne2,
int64_t ne3);
struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
GGML_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
GGML_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
GGML_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
void * ggml_get_data (const struct ggml_tensor * tensor);
float * ggml_get_data_f32(const struct ggml_tensor * tensor);
GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
//
// operations on tensors with backpropagation
//
struct ggml_tensor * ggml_dup(
GGML_API struct ggml_tensor * ggml_dup(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_add(
GGML_API struct ggml_tensor * ggml_add(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * ggml_sub(
GGML_API struct ggml_tensor * ggml_add_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * ggml_mul(
GGML_API struct ggml_tensor * ggml_sub(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * ggml_div(
GGML_API struct ggml_tensor * ggml_mul(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * ggml_sqr(
GGML_API struct ggml_tensor * ggml_div(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_sqr(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_sqrt(
GGML_API struct ggml_tensor * ggml_sqrt(
struct ggml_context * ctx,
struct ggml_tensor * a);
// return scalar
// TODO: compute sum along rows
struct ggml_tensor * ggml_sum(
GGML_API struct ggml_tensor * ggml_sum(
struct ggml_context * ctx,
struct ggml_tensor * a);
// mean along rows
struct ggml_tensor * ggml_mean(
GGML_API struct ggml_tensor * ggml_mean(
struct ggml_context * ctx,
struct ggml_tensor * a);
// if a is the same shape as b, and a is not parameter, return a
// otherwise, return a new tensor: repeat(a) to fit in b
struct ggml_tensor * ggml_repeat(
GGML_API struct ggml_tensor * ggml_repeat(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * ggml_abs(
GGML_API struct ggml_tensor * ggml_abs(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_sgn(
GGML_API struct ggml_tensor * ggml_sgn(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_neg(
GGML_API struct ggml_tensor * ggml_neg(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_step(
GGML_API struct ggml_tensor * ggml_step(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_relu(
GGML_API struct ggml_tensor * ggml_relu(
struct ggml_context * ctx,
struct ggml_tensor * a);
// TODO: double-check this computation is correct
struct ggml_tensor * ggml_gelu(
GGML_API struct ggml_tensor * ggml_gelu(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_silu(
GGML_API struct ggml_tensor * ggml_silu(
struct ggml_context * ctx,
struct ggml_tensor * a);
// normalize along rows
// TODO: eps is hardcoded to 1e-5 for now
struct ggml_tensor * ggml_norm(
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_rms_norm(
GGML_API struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
struct ggml_tensor * a);
// A: m rows, n columns
// B: p rows, n columns (i.e. we transpose it internally)
// result is m columns, p rows
struct ggml_tensor * ggml_mul_mat(
GGML_API struct ggml_tensor * ggml_mul_mat(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
@ -520,32 +556,32 @@ struct ggml_tensor * ggml_mul_mat(
//
// in-place, returns view(a)
struct ggml_tensor * ggml_scale(
GGML_API struct ggml_tensor * ggml_scale(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// a -> b, return view(b)
struct ggml_tensor * ggml_cpy(
GGML_API struct ggml_tensor * ggml_cpy(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// make contiguous
struct ggml_tensor * ggml_cont(
GGML_API struct ggml_tensor * ggml_cont(
struct ggml_context * ctx,
struct ggml_tensor * a);
// return view(a), b specifies the new shape
// TODO: when we start computing gradient, make a copy instead of view
struct ggml_tensor * ggml_reshape(
GGML_API struct ggml_tensor * ggml_reshape(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// return view(a)
// TODO: when we start computing gradient, make a copy instead of view
struct ggml_tensor * ggml_reshape_2d(
GGML_API struct ggml_tensor * ggml_reshape_2d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
@ -553,7 +589,7 @@ struct ggml_tensor * ggml_reshape_2d(
// return view(a)
// TODO: when we start computing gradient, make a copy instead of view
struct ggml_tensor * ggml_reshape_3d(
GGML_API struct ggml_tensor * ggml_reshape_3d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
@ -561,13 +597,13 @@ struct ggml_tensor * ggml_reshape_3d(
int64_t ne2);
// offset in bytes
struct ggml_tensor * ggml_view_1d(
GGML_API struct ggml_tensor * ggml_view_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
size_t offset);
struct ggml_tensor * ggml_view_2d(
GGML_API struct ggml_tensor * ggml_view_2d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
@ -575,7 +611,7 @@ struct ggml_tensor * ggml_view_2d(
size_t nb1, // row stride in bytes
size_t offset);
struct ggml_tensor * ggml_view_3d(
GGML_API struct ggml_tensor * ggml_view_3d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int64_t ne0,
@ -585,7 +621,7 @@ struct ggml_tensor * ggml_view_3d(
size_t nb2, // slice stride in bytes
size_t offset);
struct ggml_tensor * ggml_permute(
GGML_API struct ggml_tensor * ggml_permute(
struct ggml_context * ctx,
struct ggml_tensor * a,
int axis0,
@ -594,60 +630,69 @@ struct ggml_tensor * ggml_permute(
int axis3);
// alias for ggml_permute(ctx, a, 1, 0, 2, 3)
struct ggml_tensor * ggml_transpose(
GGML_API struct ggml_tensor * ggml_transpose(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_get_rows(
GGML_API struct ggml_tensor * ggml_get_rows(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// set elements above the diagonal to -INF
// in-place, returns view(a)
struct ggml_tensor * ggml_diag_mask_inf(
GGML_API struct ggml_tensor * ggml_diag_mask_inf(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past);
// in-place, returns view(a)
struct ggml_tensor * ggml_soft_max(
GGML_API struct ggml_tensor * ggml_soft_max(
struct ggml_context * ctx,
struct ggml_tensor * a);
// rotary position embedding
// in-place, returns view(a)
// if mode == 1, skip n_past elements
// if mode & 1 == 1, skip n_past elements
// if mode & 2 == 1, GPT-NeoX style
// TODO: avoid creating a new tensor every time
struct ggml_tensor * ggml_rope(
GGML_API struct ggml_tensor * ggml_rope(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode);
// alibi position embedding
// in-place, returns view(a)
struct ggml_tensor * ggml_alibi(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_head);
// padding = 1
// TODO: we don't support extra parameters for now
// that's why we are hard-coding the stride, padding, and dilation
// not great ..
struct ggml_tensor * ggml_conv_1d_1s(
GGML_API struct ggml_tensor * ggml_conv_1d_1s(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * ggml_conv_1d_2s(
GGML_API struct ggml_tensor * ggml_conv_1d_2s(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * ggml_flash_attn(
GGML_API struct ggml_tensor * ggml_flash_attn(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
bool masked);
struct ggml_tensor * ggml_flash_ff(
GGML_API struct ggml_tensor * ggml_flash_ff(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b0,
@ -656,15 +701,15 @@ struct ggml_tensor * ggml_flash_ff(
struct ggml_tensor * c1);
// Mapping operations
typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
GGML_API typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
GGML_API typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
struct ggml_tensor * ggml_map_unary_f32(
GGML_API struct ggml_tensor * ggml_map_unary_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
const ggml_unary_op_f32_t fun);
struct ggml_tensor * ggml_map_binary_f32(
GGML_API struct ggml_tensor * ggml_map_binary_f32(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
@ -674,23 +719,23 @@ struct ggml_tensor * ggml_map_binary_f32(
// automatic differentiation
//
void ggml_set_param(
GGML_API void ggml_set_param(
struct ggml_context * ctx,
struct ggml_tensor * tensor);
void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
void ggml_graph_reset (struct ggml_cgraph * cgraph);
GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
// print info and performance information for the graph
void ggml_graph_print(const struct ggml_cgraph * cgraph);
GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
// dump the graph into a file using the dot format
void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
//
// optimization
@ -783,10 +828,10 @@ struct ggml_opt_params {
} lbfgs;
};
struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
// optimize the function defined by the tensor f
enum ggml_opt_result ggml_opt(
GGML_API enum ggml_opt_result ggml_opt(
struct ggml_context * ctx,
struct ggml_opt_params params,
struct ggml_tensor * f);
@ -795,26 +840,36 @@ enum ggml_opt_result ggml_opt(
// quantization
//
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
//
// system info
//
int ggml_cpu_has_avx(void);
int ggml_cpu_has_avx2(void);
int ggml_cpu_has_avx512(void);
int ggml_cpu_has_fma(void);
int ggml_cpu_has_neon(void);
int ggml_cpu_has_arm_fma(void);
int ggml_cpu_has_f16c(void);
int ggml_cpu_has_fp16_va(void);
int ggml_cpu_has_wasm_simd(void);
int ggml_cpu_has_blas(void);
int ggml_cpu_has_sse3(void);
int ggml_cpu_has_vsx(void);
GGML_API int ggml_cpu_has_avx (void);
GGML_API int ggml_cpu_has_avx2 (void);
GGML_API int ggml_cpu_has_avx512 (void);
GGML_API int ggml_cpu_has_avx512_vbmi(void);
GGML_API int ggml_cpu_has_avx512_vnni(void);
GGML_API int ggml_cpu_has_fma (void);
GGML_API int ggml_cpu_has_neon (void);
GGML_API int ggml_cpu_has_arm_fma (void);
GGML_API int ggml_cpu_has_f16c (void);
GGML_API int ggml_cpu_has_fp16_va (void);
GGML_API int ggml_cpu_has_wasm_simd (void);
GGML_API int ggml_cpu_has_blas (void);
GGML_API int ggml_cpu_has_cublas (void);
GGML_API int ggml_cpu_has_clblast (void);
GGML_API int ggml_cpu_has_gpublas (void);
GGML_API int ggml_cpu_has_sse3 (void);
GGML_API int ggml_cpu_has_vsx (void);
//
// Internal types and functions exposed for tests and benchmarks
@ -834,7 +889,9 @@ typedef struct {
dequantize_row_q_t dequantize_row_q;
quantize_row_q_t quantize_row_q;
quantize_row_q_t quantize_row_q_reference;
quantize_row_q_t quantize_row_q_dot;
vec_dot_q_t vec_dot_q;
enum ggml_type vec_dot_type;
} quantize_fns_t;
quantize_fns_t ggml_internal_get_quantize_fn(size_t i);