From 799a1cb13b0b1b560ab0ceff485caed68faa8f1f Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 13 Dec 2023 13:04:25 +0100 Subject: [PATCH] llama : add Mixtral support (#4406) * convert : support Mixtral as LLAMA arch * convert : fix n_ff typo * llama : model loading * ggml : sync latest ggml_mul_mat_id * llama : update graph to support MoE * llama : fix cur -> cur_expert * llama : first working version * llama : fix expert weighting in the FFN * ggml : ggml_get_rows support 2D indexing [n_tokens, n_experts] (cpu only) * ggml : add n_as argument to ggml_mul_mat_id * ggml : fix ggml_get_rows to take into account ne02 / ne11 * metal : add more general support for ggml_get_rows + tests * llama : add basic support for offloading moe with CUDA * metal : add/mul/div use general kernel when src1 not cont * metal : reduce the kernel launches for ggml_mul_mat_id * ggml : get_rows : support non-contiguos tensors with gaps, generalize up to 3D * ggml : update get_rows f16 and q * cuda : support non-contiguous src1 in get_rows * llama : offload missing ffn_moe_silu * metal : fix ggml_get_rows to work with non-cont src1 * metal : add indirect mat-vec kernels for all quantization types * llama : do not quantize expert gating tensors * llama : add n_expert and n_expert_used to hparams + change quants * test-backend-ops : add moe test * cuda : fix get_rows when ncols is odd * convert : determine n_ctx correctly * metal : fix ggml_mul_mat_id for F32 * test-backend-ops : make experts more evenly probable (test_moe) * test-backend-ops : cleanup, add moe test for batches * test-backend-ops : add cpy from f32 -> all types test * test-backend-ops : fix dequantize block offset * llama : fix hard-coded number of experts * test-backend-ops : simplify and disable slow tests to avoid CI timeout * test-backend-ops : disable MOE test with thread sanitizer * cuda : fix mul_mat_id with multi gpu * convert : use 1e6 rope_freq_base for mixtral * convert : fix style * convert : support safetensors format * gguf-py : bump version * metal : add cpy f16 -> f32 kernel * metal : fix binary ops for ne10 % 4 != 0 * test-backend-ops : add one more sum_rows test * ggml : do not use BLAS with ggml_mul_mat_id * convert-hf : support for mixtral-instruct (#4428) * convert : typo fix, add additional hyperparameters, use LLaMA arch for Mixtral-instruct * convert : use sentencepiece tokenizer for Mixtral-instruct * convert : make flake8 happy * metal : fix soft_max kernels ref: https://github.com/ggerganov/ggml/pull/621/commits/1914017863d2f9ab8ecc0281cc2a56d683668b92 * metal : limit kernels to not use more than the allowed threads --------- Co-authored-by: Georgi Gerganov Co-authored-by: Radek Pilar --- Makefile | 5 + convert-hf-to-gguf.py | 21 +- convert.py | 74 +- ggml-cuda.cu | 297 +++++-- ggml-metal.m | 332 ++++++-- ggml-metal.metal | 1320 +++++++++++++++++++++++++++++--- ggml.c | 168 ++-- ggml.h | 6 +- gguf-py/gguf/constants.py | 16 +- gguf-py/gguf/gguf_writer.py | 6 + gguf-py/gguf/tensor_mapping.py | 39 +- gguf-py/pyproject.toml | 2 +- llama.cpp | 202 ++++- tests/test-backend-ops.cpp | 277 +++++-- 14 files changed, 2370 insertions(+), 395 deletions(-) diff --git a/Makefile b/Makefile index e77595952..b7afda2b5 100644 --- a/Makefile +++ b/Makefile @@ -399,6 +399,11 @@ ifdef LLAMA_CUBLAS MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib OBJS += ggml-cuda.o NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math + +ifdef LLAMA_DEBUG + NVCCFLAGS += -lineinfo +endif + ifdef LLAMA_CUDA_NVCC NVCC = $(LLAMA_CUDA_NVCC) else diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index bced1f561..e46a7813a 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -77,8 +77,18 @@ class Model: self.gguf_writer.add_embedding_length(n_embd) if (n_ff := self.hparams.get("intermediate_size")) is not None: self.gguf_writer.add_feed_forward_length(n_ff) - if (n_head := self.hparams.get("num_attention_head")) is not None: + if (n_head := self.hparams.get("num_attention_heads")) is not None: self.gguf_writer.add_head_count(n_head) + if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None: + self.gguf_writer.add_head_count_kv(n_head_kv) + + if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None: + self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps) + if (n_experts := self.hparams.get("num_local_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: + self.gguf_writer.add_expert_used_count(n_experts_used) + self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True)) def write_tensors(self): @@ -170,6 +180,8 @@ class Model: return StableLMModel if model_architecture == "QWenLMHeadModel": return QwenModel + if model_architecture == "MixtralForCausalLM": + return MixtralModel return Model def _is_model_safetensors(self) -> bool: @@ -207,6 +219,8 @@ class Model: return gguf.MODEL_ARCH.STABLELM if arch == "QWenLMHeadModel": return gguf.MODEL_ARCH.QWEN + if arch == "MixtralForCausalLM": + return gguf.MODEL_ARCH.LLAMA raise NotImplementedError(f'Architecture "{arch}" not supported!') @@ -837,6 +851,11 @@ class StableLMModel(Model): self.gguf_writer.add_layer_norm_eps(1e-5) +class MixtralModel(Model): + def set_vocab(self): + self._set_vocab_sentencepiece() + + class QwenModel(Model): @staticmethod def token_bytes_to_string(b): diff --git a/convert.py b/convert.py index a6fc6b8ea..e4b69d172 100755 --- a/convert.py +++ b/convert.py @@ -42,6 +42,7 @@ NDArray: TypeAlias = 'np.ndarray[Any, Any]' ARCH = gguf.MODEL_ARCH.LLAMA DEFAULT_CONCURRENCY = 8 + # # data types # @@ -62,10 +63,10 @@ class UnquantizedDataType(DataType): pass -DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0']) -DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0']) -DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = []) -DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0']) +DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0']) +DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0']) +DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = []) +DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0']) @dataclass(frozen=True) @@ -151,14 +152,16 @@ GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = { @dataclass class Params: - n_vocab: int - n_embd: int - n_layer: int - n_ctx: int - n_ff: int - n_head: int - n_head_kv: int - f_norm_eps: float + n_vocab: int + n_embd: int + n_layer: int + n_ctx: int + n_ff: int + n_head: int + n_head_kv: int + n_experts: int | None = None + n_experts_used: int | None = None + f_norm_eps: float | None = None rope_scaling_type: gguf.RopeScalingType | None = None f_rope_freq_base: float | None = None @@ -233,6 +236,13 @@ class Params: raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n" "Suggestion: provide 'config.json' of the model in the same directory containing model files.") + n_experts = None + n_experts_used = None + + if "num_local_experts" in config: + n_experts = config["num_local_experts"] + n_experts_used = config["num_experts_per_tok"] + return Params( n_vocab = config["vocab_size"], n_embd = config["hidden_size"], @@ -241,6 +251,8 @@ class Params: n_ff = config["intermediate_size"], n_head = (n_head := config["num_attention_heads"]), n_head_kv = config.get("num_key_value_heads", n_head), + n_experts = n_experts, + n_experts_used = n_experts_used, f_norm_eps = config["rms_norm_eps"], f_rope_freq_base = config.get("rope_theta"), rope_scaling_type = rope_scaling_type, @@ -255,8 +267,15 @@ class Params: def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: config = json.load(open(config_path)) + n_experts = None + n_experts_used = None + f_rope_freq_base = None + # hack to determine LLaMA v1 vs v2 vs CodeLlama - if config.get("rope_theta") == 1000000: + if config.get("moe"): + # Mixtral + n_ctx = 32768 + elif config.get("rope_theta") == 1000000: # CodeLlama n_ctx = 16384 elif config["norm_eps"] == 1e-05: @@ -266,16 +285,27 @@ class Params: # LLaMA v1 n_ctx = 2048 + if "layers.0.feed_forward.w1.weight" in model: + n_ff = model["layers.0.feed_forward.w1.weight"].shape[0] + + if config.get("moe"): + n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0] + n_experts = config["moe"]["num_experts"] + n_experts_used = config["moe"]["num_experts_per_tok"] + f_rope_freq_base = 1e6 + return Params( n_vocab = model["tok_embeddings.weight"].shape[0], n_embd = config["dim"], n_layer = config["n_layers"], n_ctx = n_ctx, - n_ff = model["layers.0.feed_forward.w1.weight"].shape[0], + n_ff = n_ff, n_head = (n_head := config["n_heads"]), n_head_kv = config.get("n_kv_heads", n_head), + n_experts = n_experts, + n_experts_used = n_experts_used, f_norm_eps = config["norm_eps"], - f_rope_freq_base = config.get("rope_theta"), + f_rope_freq_base = config.get("rope_theta", f_rope_freq_base), ) @staticmethod @@ -832,7 +862,17 @@ class OutputFile: self.gguf.add_rope_dimension_count(params.n_embd // params.n_head) self.gguf.add_head_count (params.n_head) self.gguf.add_head_count_kv (params.n_head_kv) - self.gguf.add_layer_norm_rms_eps (params.f_norm_eps) + + if params.n_experts: + self.gguf.add_expert_count(params.n_experts) + + if params.n_experts_used: + self.gguf.add_expert_used_count(params.n_experts_used) + + if params.f_norm_eps: + self.gguf.add_layer_norm_rms_eps(params.f_norm_eps) + else: + raise ValueError('f_norm_eps is None') if params.f_rope_freq_base is not None: self.gguf.add_rope_freq_base(params.f_rope_freq_base) @@ -956,7 +996,7 @@ class OutputFile: def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType: - wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) +".weight"].data_type + wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32): return GGMLFileType.AllF32 diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 85f7a2937..9e1acd3f1 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1,13 +1,15 @@ #include +#include +#include +#include #include #include -#include #include #include #include #include -#include -#include +#include + #if defined(GGML_USE_HIPBLAS) #include @@ -1684,31 +1686,65 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest } template -static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) { - const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2; - const int row = blockDim.y*blockIdx.y + threadIdx.y; +static __global__ void k_get_rows( + const void * src0, const int32_t * src1, dst_t * dst, + int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ + /*size_t s0,*/ size_t s1, size_t s2, size_t s3, + /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, + size_t s10, size_t s11, size_t s12/*, size_t s13*/) { - if (col >= ncols) { + const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2; + const int i10 = blockDim.y*blockIdx.y + threadIdx.y; + const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; + const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; + + if (i00 >= ne00) { return; } - const int r = y[row]; + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - // copy x[r*ncols + col] to dst[row*ncols + col] - const int xi = r*ncols + col; - const int di = row*ncols + col; + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03; - const int ib = xi/qk; // block index - const int iqs = (xi%qk)/qr; // quant index - const int iybs = di - di%qk; // y block start index + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index + const int iybs = i00 - i00%qk; // dst block start index const int y_offset = qr == 1 ? 1 : qk/2; // dequantize dfloat2 v; - dequantize_kernel(x, ib, iqs, v); + dequantize_kernel(src0_row, ib, iqs, v); - dst[iybs + iqs + 0] = v.x; - dst[iybs + iqs + y_offset] = v.y; + dst_row[iybs + iqs + 0] = v.x; + dst_row[iybs + iqs + y_offset] = v.y; +} + +template +static __global__ void k_get_rows_float( + const src0_t * src0, const int32_t * src1, dst_t * dst, + int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ + /*size_t s0,*/ size_t s1, size_t s2, size_t s3, + /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, + size_t s10, size_t s11, size_t s12/*, size_t s13*/) { + + const int i00 = blockIdx.x*blockDim.x + threadIdx.x; + const int i10 = blockDim.y*blockIdx.y + threadIdx.y; + const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; + const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; + + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); + + dst_row[i00] = src0_row[i00]; } template @@ -5053,11 +5089,69 @@ static __global__ void im2col_f32_f16( } template -static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) { +static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); - const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); - const dim3 block_nums(block_num_x, nrows, 1); - k_get_rows<<>>(x, y, dst, ncols); + const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); + const dim3 block_nums(block_num_x, ne10, ne11*ne12); + + // strides in elements + //const size_t s0 = nb0 / ggml_element_size(dst); + const size_t s1 = nb1 / ggml_element_size(dst); + const size_t s2 = nb2 / ggml_element_size(dst); + const size_t s3 = nb3 / ggml_element_size(dst); + + const size_t s10 = nb10 / ggml_element_size(src1); + const size_t s11 = nb11 / ggml_element_size(src1); + const size_t s12 = nb12 / ggml_element_size(src1); + //const size_t s13 = nb13 / ggml_element_size(src1); + + GGML_ASSERT(ne00 % 2 == 0); + + k_get_rows<<>>( + src0_dd, src1_dd, dst_dd, + ne00, /*ne01, ne02, ne03,*/ + /*ne10, ne11,*/ ne12, /*ne13,*/ + /* s0,*/ s1, s2, s3, + /* nb00,*/ nb01, nb02, nb03, + s10, s11, s12/*, s13*/); + + (void) dst; +} + +template +static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); + const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; + const dim3 block_nums(block_num_x, ne10, ne11*ne12); + + // strides in elements + //const size_t s0 = nb0 / ggml_element_size(dst); + const size_t s1 = nb1 / ggml_element_size(dst); + const size_t s2 = nb2 / ggml_element_size(dst); + const size_t s3 = nb3 / ggml_element_size(dst); + + const size_t s10 = nb10 / ggml_element_size(src1); + const size_t s11 = nb11 / ggml_element_size(src1); + const size_t s12 = nb12 / ggml_element_size(src1); + //const size_t s13 = nb13 / ggml_element_size(src1); + + k_get_rows_float<<>>( + src0_dd, src1_dd, dst_dd, + ne00, /*ne01, ne02, ne03,*/ + /*ne10, ne11,*/ ne12, /*ne13,*/ + /* s0,*/ s1, s2, s3, + /* nb00,*/ nb01, nb02, nb03, + s10, s11, s12/*, s13*/); + + (void) dst; } template @@ -5069,7 +5163,6 @@ struct bin_bcast_cuda { GGML_TENSOR_BINARY_OP_LOCALS - int nr0 = ne10/ne0; int nr1 = ne11/ne1; int nr2 = ne12/ne2; @@ -5117,26 +5210,28 @@ struct bin_bcast_cuda { int64_t ne12 = cne1[2]; int64_t ne13 = cne1[3]; - //size_t nb0 = cnb0[0]; + size_t nb0 = cnb0[0]; size_t nb1 = cnb0[1]; size_t nb2 = cnb0[2]; size_t nb3 = cnb0[3]; - //size_t nb10 = cnb1[0]; + size_t nb10 = cnb1[0]; size_t nb11 = cnb1[1]; size_t nb12 = cnb1[2]; size_t nb13 = cnb1[3]; - //size_t s0 = nb0 / sizeof(src1_t); + size_t s0 = nb0 / sizeof(src1_t); size_t s1 = nb1 / sizeof(src1_t); size_t s2 = nb2 / sizeof(src1_t); size_t s3 = nb3 / sizeof(src1_t); - //size_t s10 = nb10 / sizeof(src1_t); + size_t s10 = nb10 / sizeof(src1_t); size_t s11 = nb11 / sizeof(src1_t); size_t s12 = nb12 / sizeof(src1_t); size_t s13 = nb13 / sizeof(src1_t); + GGML_ASSERT(s0 == 1); + GGML_ASSERT(s10 == 1); const int block_size = 128; @@ -6447,36 +6542,34 @@ static void ggml_cuda_op_get_rows( GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_contiguous(dst)); - const int ncols = src0->ne[0]; - const int nrows = ggml_nelements(src1); + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); + GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); const int32_t * src1_i32 = (const int32_t *) src1_d; switch (src0->type) { case GGML_TYPE_F16: - get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream); + get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream); break; case GGML_TYPE_F32: - get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream); + get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream); break; case GGML_TYPE_Q4_0: - get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); break; case GGML_TYPE_Q4_1: - get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); break; case GGML_TYPE_Q5_0: - get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); break; case GGML_TYPE_Q5_1: - get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); break; case GGML_TYPE_Q8_0: - get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); break; default: // TODO: k-quants @@ -8234,36 +8327,69 @@ static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) { } #endif -static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) { +static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #if 0 -//#ifdef CUDA_USE_TENSOR_CORES -// const bool use_tensor_cores = true; -//#else -// const bool use_tensor_cores = false; -//#endif - ggml_cuda_mul_mat_id_cublas(dst); - // TODO: mmq/mmv support -#else - const struct ggml_tensor * ids = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - const int id = dst->op_params[0]; - - int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device]; - - int32_t a_id; - CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0])); - CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0])); - - GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]); - const struct ggml_tensor * src0 = dst->src[a_id + 2]; - - ggml_cuda_mul_mat(src0, src1, dst); #endif - (void) _src0; - (void) _src1; + GGML_ASSERT(dst->backend == GGML_BACKEND_GPU); + + const struct ggml_tensor * ids = src0; + const int32_t id = ((int32_t *) dst->op_params)[0]; + const int32_t n_as = ((int32_t *) dst->op_params)[1]; + + std::vector ids_host(ggml_nbytes(ids)); + + if (ids->backend == GGML_BACKEND_GPU) { + const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device]; + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0])); + CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0])); + } else { + memcpy(ids_host.data(), ids->data, ggml_nbytes(ids)); + } + + const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra; + const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra; + + ggml_tensor_extra_gpu src1_row_extra; + ggml_tensor_extra_gpu dst_row_extra; + + ggml_tensor src1_row = *src1; + ggml_tensor dst_row = *dst; + + src1_row.ne[1] = 1; + dst_row.ne[1] = 1; + + src1_row.nb[2] = src1_row.nb[1]; + dst_row.nb[2] = dst_row.nb[1]; + + src1_row.nb[3] = src1_row.nb[1]; + dst_row.nb[3] = dst_row.nb[1]; + + src1_row.extra = &src1_row_extra; + dst_row.extra = &dst_row_extra; + + + for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { + //int32_t row_id; + //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0])); + //CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0])); + + const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); + + GGML_ASSERT(row_id >= 0 && row_id < n_as); + + const struct ggml_tensor * src0_row = dst->src[row_id + 2]; + + src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1]; + src1_row.data = (char *) src1->data + i01*src1->nb[1]; + + dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1]; + dst_row.data = (char *) dst->data + i01*dst->nb[1]; + + ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row); + } } static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -9181,6 +9307,45 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten } return true; } break; + case GGML_OP_GET_ROWS: + { + switch (op->src[0]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + return true; + default: + return false; + } + } break; + case GGML_OP_CPY: + { + ggml_type src0_type = op->src[0]->type; + ggml_type src1_type = op->src[1]->type; + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) { + return true; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return true; + } + return false; + } break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -9188,7 +9353,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten case GGML_OP_TRANSPOSE: case GGML_OP_NORM: case GGML_OP_REPEAT: - case GGML_OP_GET_ROWS: case GGML_OP_DUP: case GGML_OP_ADD: case GGML_OP_MUL: @@ -9197,7 +9361,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_CLAMP: - case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: @@ -9264,7 +9427,9 @@ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * use UNUSED(params); } -extern "C" int ggml_backend_cuda_reg_devices() { +extern "C" int ggml_backend_cuda_reg_devices(); + +int ggml_backend_cuda_reg_devices() { int device_count = ggml_cuda_get_device_count(); //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization for (int i = 0; i < device_count; i++) { diff --git a/ggml-metal.m b/ggml-metal.m index f9bd69dc8..1dcfa6edd 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -102,6 +102,21 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32); + //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16); + GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32); + //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row); + //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4); + GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32); + GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32); + GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32); + GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32); + GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_f32_f32); GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); @@ -140,6 +155,7 @@ struct ggml_metal_context { //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0); //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1); GGML_METAL_DECL_KERNEL(cpy_f16_f16); + GGML_METAL_DECL_KERNEL(cpy_f16_f32); GGML_METAL_DECL_KERNEL(concat); GGML_METAL_DECL_KERNEL(sqr); GGML_METAL_DECL_KERNEL(sum_rows); @@ -177,6 +193,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data); } else { char* buffer2 = malloc(len+1); + va_end(args); + va_start(args, format); vsnprintf(buffer2, len+1, format, args); buffer2[len] = 0; ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data); @@ -352,6 +370,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32); + //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16); + GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32); + //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row); + //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4); + GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32); + GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32); + GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32); + GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32); if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); @@ -392,6 +425,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0); //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1); GGML_METAL_ADD_KERNEL(cpy_f16_f16); + GGML_METAL_ADD_KERNEL(cpy_f16_f32); GGML_METAL_ADD_KERNEL(concat); GGML_METAL_ADD_KERNEL(sqr); GGML_METAL_ADD_KERNEL(sum_rows); @@ -452,6 +486,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32); + //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16); + GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32); + //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row); + //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4); + GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32); + GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32); + GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32); + GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32); + GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32); if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); @@ -492,6 +541,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0); //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1); GGML_METAL_DEL_KERNEL(cpy_f16_f16); + GGML_METAL_DEL_KERNEL(cpy_f16_f32); GGML_METAL_DEL_KERNEL(concat); GGML_METAL_DEL_KERNEL(sqr); GGML_METAL_DEL_KERNEL(sum_rows); @@ -803,8 +853,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) { case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_GET_ROWS: case GGML_OP_CONCAT: case GGML_OP_ADD: case GGML_OP_MUL: @@ -819,14 +870,38 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) { case GGML_OP_ROPE: case GGML_OP_IM2COL: case GGML_OP_ARGSORT: - case GGML_OP_DUP: - case GGML_OP_CPY: - case GGML_OP_CONT: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: return true; + case GGML_OP_CPY: + case GGML_OP_DUP: + case GGML_OP_CONT: + { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + switch (op->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return true; + default: + return false; + } + case GGML_TYPE_F16: + switch (op->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } + default: + return false; + }; + } case GGML_OP_DIAG_MASK_INF: - case GGML_OP_GET_ROWS: { return op->ne[0] % 4 == 0; } @@ -1001,34 +1076,37 @@ void ggml_metal_graph_compute( case GGML_OP_MUL: case GGML_OP_DIV: { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - bool bcast_row = false; int64_t nb = ne00; - if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) { + id pipeline = nil; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + // src1 is a row GGML_ASSERT(ne11 == 1); nb = ne00 / 4; switch (dst->op) { - case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break; - case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break; - case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break; + case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break; + case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break; + case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break; default: GGML_ASSERT(false); } bcast_row = true; } else { switch (dst->op) { - case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break; - case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break; - case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break; + case GGML_OP_ADD: pipeline = ctx->pipeline_add; break; + case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break; + case GGML_OP_DIV: pipeline = ctx->pipeline_div; break; default: GGML_ASSERT(false); } } + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; @@ -1063,7 +1141,7 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } else { - const int nth = MIN(1024, ne0); + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } @@ -1193,7 +1271,11 @@ void ggml_metal_graph_compute( const float scale = ((float *) dst->op_params)[0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + } [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; @@ -1444,7 +1526,7 @@ void ggml_metal_graph_compute( else if (src0t == GGML_TYPE_Q6_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { - int64_t ny = (ne11 + nrows - 1)/nrows; + const int64_t ny = (ne11 + nrows - 1)/nrows; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } @@ -1456,7 +1538,7 @@ void ggml_metal_graph_compute( GGML_ASSERT(src0t == GGML_TYPE_I32); - const int n_as = ne00; + const int n_as = ((int32_t *) dst->op_params)[1]; // TODO: make this more general GGML_ASSERT(n_as <= 8); @@ -1488,14 +1570,22 @@ void ggml_metal_graph_compute( // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel - int ne11_mm_min = 0; + int ne11_mm_min = 1; const int idx = ((int32_t *) dst->op_params)[0]; + // batch size + GGML_ASSERT(ne01 == ne11); + + const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - ne11 > ne11_mm_min) { + // !!! + // TODO: for now, always use mat-vec kernels until we figure out how to improve the + // indirect matrix multiplication + // !!! + if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) { switch (src2->type) { case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break; @@ -1514,19 +1604,22 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3]; - [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; - [encoder setBytes:&idx length:sizeof(idx) atIndex:15]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:16]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:17]; + [encoder setBytes:&idx length:sizeof(idx) atIndex:18]; // TODO: how to make this an array? read Metal docs for (int j = 0; j < n_as; ++j) { struct ggml_tensor * src_cur = dst->src[2 + j]; @@ -1534,11 +1627,157 @@ void ggml_metal_graph_compute( size_t offs_src_cur = 0; id id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); - [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j]; + [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j]; } [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + + // TODO: processing one row at a time (ne11 -> 1) is not efficient + [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } else { + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + // use custom matrix x vector kernel + switch (src2t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32]; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32]; + } break; + case GGML_TYPE_Q4_0: + { + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32]; + } break; + case GGML_TYPE_Q4_1: + { + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32]; + } break; + case GGML_TYPE_Q5_0: + { + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32]; + } break; + case GGML_TYPE_Q5_1: + { + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32]; + } break; + case GGML_TYPE_Q8_0: + { + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32]; + } break; + case GGML_TYPE_Q2_K: + { + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32]; + } break; + case GGML_TYPE_Q3_K: + { + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32]; + } break; + case GGML_TYPE_Q4_K: + { + nth0 = 4; //1; + nth1 = 8; //32; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32]; + } break; + case GGML_TYPE_Q5_K: + { + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32]; + } break; + case GGML_TYPE_Q6_K: + { + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32]; + } break; + default: + { + GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); + GGML_ASSERT(false && "not implemented"); + } + }; + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6]; + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; + [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:20]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:21]; + [encoder setBytes:&idx length:sizeof(idx) atIndex:22]; + // TODO: how to make this an array? read Metal docs + for (int j = 0; j < n_as; ++j) { + struct ggml_tensor * src_cur = dst->src[2 + j]; + + size_t offs_src_cur = 0; + id id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); + + [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j]; + } + + if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || + src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || + src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src2t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src2t == GGML_TYPE_Q3_K) { +#ifdef GGML_QKK_64 + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#else + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#endif + } + else if (src2t == GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src2t == GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + const int64_t ny = (_ne1 + nrows - 1)/nrows; + [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } } } break; case GGML_OP_GET_ROWS: @@ -1559,16 +1798,19 @@ void ggml_metal_graph_compute( default: GGML_ASSERT(false && "not implemented"); } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; - const int64_t n = ggml_nelements(src1); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; case GGML_OP_RMS_NORM: { @@ -1813,7 +2055,7 @@ void ggml_metal_graph_compute( { switch (dstt) { case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break; - case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break; + case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break; default: GGML_ASSERT(false && "not implemented"); }; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 2f8ea22d6..773fac124 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -347,9 +347,9 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 ? src1 + i01*ne00 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; + device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; // parallel max float lmax = -INFINITY; @@ -385,7 +385,12 @@ kernel void kernel_soft_max( pdst[i00] = exp_psrc0; } + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + float sum = simd_sum(lsum); + if (ntg > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = 0.0f; @@ -428,9 +433,9 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; + device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max float4 lmax4 = -INFINITY; @@ -468,7 +473,13 @@ kernel void kernel_soft_max_4( } const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + float sum = simd_sum(lsum); + if (ntg > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = 0.0f; @@ -731,7 +742,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre // giard against the number of rows not being divisible by // N_DST, so this is another explicit assumption of the implementation. template -void mul_vec_q_n_f32( +void mul_vec_q_n_f32_impl( device const void * src0, device const float * src1, device float * dst, @@ -813,7 +824,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -832,7 +843,7 @@ kernel void kernel_mul_mv_q4_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -851,7 +862,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -870,28 +881,28 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); } #define NB_Q8_0 8 -kernel void kernel_mul_mv_q8_0_f32( +void kernel_mul_mv_q8_0_f32_impl( device const void * src0, device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nr = N_DST; const int nsg = N_SIMDGROUP; const int nw = N_SIMDWIDTH; @@ -945,9 +956,29 @@ kernel void kernel_mul_mv_q8_0_f32( } } +[[host_name("kernel_mul_mv_q8_0_f32")]] +kernel void kernel_mul_mv_q8_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); +} + #define N_F32_F32 4 -kernel void kernel_mul_mv_f32_f32( +void kernel_mul_mv_f32_f32_impl( device const char * src0, device const char * src1, device float * dst, @@ -965,8 +996,8 @@ kernel void kernel_mul_mv_f32_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1025,6 +1056,32 @@ kernel void kernel_mul_mv_f32_f32( } } +[[host_name("kernel_mul_mv_f32_f32")]] +kernel void kernel_mul_mv_f32_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} + #define N_F16_F16 4 kernel void kernel_mul_mv_f16_f16( @@ -1105,7 +1162,7 @@ kernel void kernel_mul_mv_f16_f16( } } -kernel void kernel_mul_mv_f16_f32_1row( +void kernel_mul_mv_f16_f32_1row_impl( device const char * src0, device const char * src1, device float * dst, @@ -1123,8 +1180,8 @@ kernel void kernel_mul_mv_f16_f32_1row( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1161,12 +1218,10 @@ kernel void kernel_mul_mv_f16_f32_1row( dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; } } - } -#define N_F16_F32 4 - -kernel void kernel_mul_mv_f16_f32( +[[host_name("kernel_mul_mv_f16_f32_1row")]] +kernel void kernel_mul_mv_f16_f32_1row( device const char * src0, device const char * src1, device float * dst, @@ -1187,6 +1242,33 @@ kernel void kernel_mul_mv_f16_f32( constant uint & r2 [[buffer(17)]], constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} + +#define N_F16_F32 4 + +void kernel_mul_mv_f16_f32_impl( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; @@ -1244,6 +1326,32 @@ kernel void kernel_mul_mv_f16_f32( } } +[[host_name("kernel_mul_mv_f16_f32")]] +kernel void kernel_mul_mv_f16_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); +} + // Assumes row size (ne00) is a multiple of 4 kernel void kernel_mul_mv_f16_f32_l4( device const char * src0, @@ -1601,8 +1709,8 @@ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_ar template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; kernel void kernel_cpy_f16_f16( - device const half * src0, - device half * dst, + device const half * src0, + device half * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -1641,6 +1749,47 @@ kernel void kernel_cpy_f16_f16( } } +kernel void kernel_cpy_f16_f32( + device const half * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + kernel void kernel_cpy_f32_f16( device const float * src0, device half * dst, @@ -2064,19 +2213,19 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { //====================================== dot products ========================= -kernel void kernel_mul_mv_q2_K_f32( +void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2214,8 +2363,8 @@ kernel void kernel_mul_mv_q2_K_f32( } } -#if QK_K == 256 -kernel void kernel_mul_mv_q3_K_f32( +[[host_name("kernel_mul_mv_q2_K_f32")]] +kernel void kernel_mul_mv_q2_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -2229,8 +2378,29 @@ kernel void kernel_mul_mv_q3_K_f32( constant uint & r2 [[buffer(17)]], constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +#if QK_K == 256 +void kernel_mul_mv_q3_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; @@ -2373,19 +2543,19 @@ kernel void kernel_mul_mv_q3_K_f32( } } #else -kernel void kernel_mul_mv_q3_K_f32( +void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2450,20 +2620,41 @@ kernel void kernel_mul_mv_q3_K_f32( } #endif -#if QK_K == 256 -kernel void kernel_mul_mv_q4_K_f32( +[[host_name("kernel_mul_mv_q3_K_f32")]] +kernel void kernel_mul_mv_q3_K_f32( device const void * src0, device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01 [[buffer(4)]], - constant int64_t & ne02 [[buffer(5)]], - constant int64_t & ne10 [[buffer(9)]], - constant int64_t & ne12 [[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +#if QK_K == 256 +void kernel_mul_mv_q4_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2564,19 +2755,19 @@ kernel void kernel_mul_mv_q4_K_f32( } } #else -kernel void kernel_mul_mv_q4_K_f32( +void kernel_mul_mv_q4_K_f32_impl( device const void * src0, device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2660,7 +2851,8 @@ kernel void kernel_mul_mv_q4_K_f32( } #endif -kernel void kernel_mul_mv_q5_K_f32( +[[host_name("kernel_mul_mv_q4_K_f32")]] +kernel void kernel_mul_mv_q4_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -2677,6 +2869,26 @@ kernel void kernel_mul_mv_q5_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q5_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; @@ -2836,10 +3048,10 @@ kernel void kernel_mul_mv_q5_K_f32( dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; } } - } -kernel void kernel_mul_mv_q6_K_f32( +[[host_name("kernel_mul_mv_q5_K_f32")]] +kernel void kernel_mul_mv_q5_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -2853,8 +3065,28 @@ kernel void kernel_mul_mv_q6_K_f32( constant uint & r2 [[buffer(17)]], constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +void kernel_mul_mv_q6_K_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; @@ -2945,6 +3177,27 @@ kernel void kernel_mul_mv_q6_K_f32( } } +[[host_name("kernel_mul_mv_q6_K_f32")]] +kernel void kernel_mul_mv_q6_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + //============================= templates and their specializations ============================= // NOTE: this is not dequantizing - we are simply fitting the template @@ -3219,22 +3472,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg template kernel void kernel_get_rows( device const void * src0, - device const int * src1, + device const char * src1, device float * dst, constant int64_t & ne00, constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, constant uint64_t & nb1, - uint tgpig[[threadgroup_position_in_grid]], + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], - uint tptg[[threads_per_threadgroup]]) { - const int i = tgpig; - const int r = ((device int32_t *) src1)[i]; + uint3 tptg [[threads_per_threadgroup]]) { + //const int64_t i = tgpig; + //const int64_t r = ((device int32_t *) src1)[i]; - for (int ind = tiitg; ind < ne00/16; ind += tptg) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { float4x4 temp; dequantize_func( - ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp; + ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); + *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + } +} + +kernel void kernel_get_rows_f32( + device const void * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_f16( + device const void * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; } } @@ -3426,19 +3747,22 @@ kernel void kernel_mul_mm(device const uchar * src0, template kernel void kernel_mul_mm_id( - device const int32_t * ids, + device const uchar * ids, device const uchar * src1, - device float * dst, + device uchar * dst, + constant int64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, constant int64_t & nb01, constant int64_t & nb02, constant int64_t & ne12, + constant int64_t & ne13, constant int64_t & nb10, constant int64_t & nb11, constant int64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, + constant int64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -3456,10 +3780,16 @@ kernel void kernel_mul_mm_id( uint sgitg[[simdgroup_index_in_threadgroup]]) { device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + kernel_mul_mm_impl( - src0[ids[idx]], - src1, - dst, + src0[id], + src1 + bid*nb11, + (device float *) (dst + bid*nb1), ne00, ne02, nb01, @@ -3484,17 +3814,26 @@ kernel void kernel_mul_mm_id( #define QK_NL 4 #endif +// +// get rows +// + typedef void (get_rows_t)( device const void * src0, - device const int * src1, + device const char * src1, device float * dst, constant int64_t & ne00, constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, constant uint64_t & nb1, - uint, uint, uint); + constant uint64_t & nb2, + uint3, uint, uint3); -template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; +//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; +//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; @@ -3506,6 +3845,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; +// +// matrix-matrix multiplication +// + typedef void (mat_mm_t)( device const uchar * src0, device const uchar * src1, @@ -3538,20 +3881,27 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +// +// indirect matrix-matrix multiplication +// + typedef void (mat_mm_id_t)( - device const int32_t * ids, + device const uchar * ids, device const uchar * src1, - device float * dst, + device uchar * dst, + constant int64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, constant int64_t & nb01, constant int64_t & nb02, constant int64_t & ne12, + constant int64_t & ne13, constant int64_t & nb10, constant int64_t & nb11, constant int64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, + constant int64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -3578,3 +3928,775 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; + +// +// matrix-vector multiplication +// + +[[host_name("kernel_mul_mv_id_f32_f32")]] +kernel void kernel_mul_mv_id_f32_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_f32_f32_impl( + src0[id], + src1 + bid*nb11, + (device float *) (dst + bid*nb1), + ne00, + ne01, + ne02, + nb00, + nb01, + nb02, + ne10, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg); +} + +[[host_name("kernel_mul_mv_id_f16_f32")]] +kernel void kernel_mul_mv_id_f16_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_f16_f32_impl( + src0[id], + src1 + bid*nb11, + (device float *) (dst + bid*nb1), + ne00, + ne01, + ne02, + nb00, + nb01, + nb02, + ne10, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg); +} + +[[host_name("kernel_mul_mv_id_q8_0_f32")]] +kernel void kernel_mul_mv_id_q8_0_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q8_0_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q4_0_f32")]] +kernel void kernel_mul_mv_id_q4_0_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q4_1_f32")]] +kernel void kernel_mul_mv_id_q4_1_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q5_0_f32")]] +kernel void kernel_mul_mv_id_q5_0_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q5_1_f32")]] +kernel void kernel_mul_mv_id_q5_1_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + mul_vec_q_n_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q2_K_f32")]] +kernel void kernel_mul_mv_id_q2_K_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q2_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q3_K_f32")]] +kernel void kernel_mul_mv_id_q3_K_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q3_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q4_K_f32")]] +kernel void kernel_mul_mv_id_q4_K_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q4_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q5_K_f32")]] +kernel void kernel_mul_mv_id_q5_K_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q5_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} + +[[host_name("kernel_mul_mv_id_q6_K_f32")]] +kernel void kernel_mul_mv_id_q6_K_f32( + device const char * ids, + device const char * src1, + device uchar * dst, + constant int64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_q6_K_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + (device float *) ( dst + bid*nb1), + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg, + sgitg); +} diff --git a/ggml.c b/ggml.c index eb7989dc4..66658ff4b 100644 --- a/ggml.c +++ b/ggml.c @@ -4075,17 +4075,18 @@ struct ggml_tensor * ggml_mul_mat( struct ggml_tensor * ggml_mul_mat_id( struct ggml_context * ctx, - struct ggml_tensor * as[], + struct ggml_tensor * const as[], + int n_as, struct ggml_tensor * ids, int id, struct ggml_tensor * b) { - int64_t n_as = ids->ne[0]; - GGML_ASSERT(ids->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_is_vector(ids)); + GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); + GGML_ASSERT(ids->ne[1] == b->ne[1]); + GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]); GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2); - GGML_ASSERT(id >= 0 && id < n_as); + GGML_ASSERT(id >= 0 && id < ids->ne[0]); bool is_node = false; @@ -4097,13 +4098,14 @@ struct ggml_tensor * ggml_mul_mat_id( struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne); ggml_set_op_params_i32(result, 0, id); + ggml_set_op_params_i32(result, 1, n_as); result->op = GGML_OP_MUL_MAT_ID; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = ids; result->src[1] = b; - for (int64_t i = 0; i < n_as; i++) { + for (int i = 0; i < n_as; i++) { struct ggml_tensor * a = as[i]; GGML_ASSERT(ggml_are_same_shape(as[0], a)); GGML_ASSERT(ggml_can_mul_mat(a, b)); @@ -4731,7 +4733,9 @@ struct ggml_tensor * ggml_get_rows( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); + GGML_ASSERT(a->ne[2] == b->ne[1]); + GGML_ASSERT(b->ne[3] == 1); + GGML_ASSERT(b->type == GGML_TYPE_I32); bool is_node = false; @@ -4741,7 +4745,7 @@ struct ggml_tensor * ggml_get_rows( // TODO: implement non F32 return //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); - struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]); + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]); result->op = GGML_OP_GET_ROWS; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -9504,8 +9508,11 @@ static bool ggml_compute_forward_mul_mat_use_blas( const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; + // NOTE: with GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float) + // all the experts for each batch element and the processing would become incredibly slow // TODO: find the optimal values for these - if (ggml_is_contiguous(src0) && + if (dst->op != GGML_OP_MUL_MAT_ID && + ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && //src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && @@ -9519,11 +9526,16 @@ static bool ggml_compute_forward_mul_mat_use_blas( } #endif +// off1 = offset in i11 and i1 +// cne1 = ne11 and ne1 +// in a normal matrix multiplication, off1 = 0 and cne1 = ne1 +// during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1 static void ggml_compute_forward_mul_mat( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - struct ggml_tensor * dst) { + struct ggml_tensor * dst, + int64_t off1, int64_t cne1) { int64_t t0 = ggml_perf_time_us(); UNUSED(t0); @@ -9591,10 +9603,9 @@ static void ggml_compute_forward_mul_mat( const int64_t i03 = i13/r3; const int64_t i02 = i12/r2; - const void * x = (char *) src0->data + i02*nb02 + i03*nb03; - const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); - - float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); + const void * x = (char *) src0->data + i02*nb02 + i03*nb03; + const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13); + float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3); if (type != GGML_TYPE_F32) { float * const wdata = params->wdata; @@ -9611,10 +9622,10 @@ static void ggml_compute_forward_mul_mat( } cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - ne11, ne01, ne10, - 1.0f, y, ne10, - x, ne00, - 0.0f, d, ne01); + cne1, ne01, ne10, + 1.0f, y, ne10, + x, ne00, + 0.0f, d, ne01); } } @@ -9630,6 +9641,7 @@ static void ggml_compute_forward_mul_mat( const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type); assert(params->wsize >= ne11*ne12*ne13*row_size); + assert(src1->type == GGML_TYPE_F32); for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { @@ -9652,7 +9664,7 @@ static void ggml_compute_forward_mul_mat( const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type); const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = ne11*ne12*ne13; // src1 rows + const int64_t nr1 = cne1*ne12*ne13; // src1 rows //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); @@ -9694,9 +9706,9 @@ static void ggml_compute_forward_mul_mat( for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { - const int64_t i13 = (ir1/(ne12*ne11)); - const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11; - const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11); + const int64_t i13 = (ir1/(ne12*cne1)); + const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1; + const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1; // broadcast src0 into src1 const int64_t i03 = i13/r3; @@ -9736,20 +9748,28 @@ static void ggml_compute_forward_mul_mat( static void ggml_compute_forward_mul_mat_id( const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { - const struct ggml_tensor * ids = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + // during GGML_TASK_INIT the entire src1 is converted to vec_dot_type + ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]); + return; + } - const int id = ggml_get_op_params_i32(dst, 0); + const struct ggml_tensor * ids = src0; + const int id = ggml_get_op_params_i32(dst, 0); + const int n_as = ggml_get_op_params_i32(dst, 1); - const int a_id = ((int32_t *)ids->data)[id]; + for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { + const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]); - GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]); + GGML_ASSERT(row_id >= 0 && row_id < n_as); - const struct ggml_tensor * src0 = dst->src[a_id + 2]; - - ggml_compute_forward_mul_mat(params, src0, src1, dst); + const struct ggml_tensor * src0_row = dst->src[row_id + 2]; + ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1); + } } // ggml_compute_forward_out_prod @@ -10325,21 +10345,30 @@ static void ggml_compute_forward_get_rows_q( return; } - const int nc = src0->ne[0]; - const int nr = ggml_nelements(src1); + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr); + const enum ggml_type type = src0->type; ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - assert( dst->ne[0] == nc); - assert( dst->ne[1] == nr); - assert(src0->nb[0] == ggml_type_size(type)); + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == ggml_type_size(type)); + assert(ggml_nrows(dst) == nr); - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; + // TODO: multi-thread + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - dequantize_row_q( - (const void *) ((char *) src0->data + r*src0->nb[1]), - (float *) ((char *) dst->data + i*dst->nb[1]), nc); + dequantize_row_q( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } + } } } @@ -10354,19 +10383,26 @@ static void ggml_compute_forward_get_rows_f16( return; } - const int nc = src0->ne[0]; - const int nr = ggml_nelements(src1); + GGML_TENSOR_BINARY_OP_LOCALS - assert( dst->ne[0] == nc); - assert( dst->ne[1] == nr); - assert(src0->nb[0] == sizeof(ggml_fp16_t)); + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr); - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(ggml_fp16_t)); + assert(ggml_nrows(dst) == nr); - for (int j = 0; j < nc; ++j) { - ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j]; - ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v); + // TODO: multi-thread + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + ggml_fp16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } } } } @@ -10382,19 +10418,27 @@ static void ggml_compute_forward_get_rows_f32( return; } - const int nc = src0->ne[0]; - const int nr = ggml_nelements(src1); + GGML_TENSOR_BINARY_OP_LOCALS - assert( dst->ne[0] == nc); - assert( dst->ne[1] == nr); - assert(src0->nb[0] == sizeof(float)); + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr); - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(float)); + assert(ggml_nrows(dst) == nr); - ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i*dst->nb[1]), - (float *) ((char *) src0->data + r*src0->nb[1])); + // TODO: multi-thread + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), + (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); + } + } } } @@ -14037,11 +14081,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_MUL_MAT: { - ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]); } break; case GGML_OP_MUL_MAT_ID: { - ggml_compute_forward_mul_mat_id(params, tensor); + ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_OUT_PROD: { diff --git a/ggml.h b/ggml.h index 41a075e92..32f256481 100644 --- a/ggml.h +++ b/ggml.h @@ -217,7 +217,7 @@ #define GGML_MAX_DIMS 4 #define GGML_MAX_PARAMS 2048 #define GGML_MAX_CONTEXTS 64 -#define GGML_MAX_SRC 6 +#define GGML_MAX_SRC 10 #define GGML_MAX_NAME 64 #define GGML_MAX_OP_PARAMS 64 #define GGML_DEFAULT_N_THREADS 4 @@ -1051,7 +1051,8 @@ extern "C" { // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b) GGML_API struct ggml_tensor * ggml_mul_mat_id( struct ggml_context * ctx, - struct ggml_tensor * as[], + struct ggml_tensor * const as[], + int n_as, struct ggml_tensor * ids, int id, struct ggml_tensor * b); @@ -1263,6 +1264,7 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // supports 3D: a->ne[2] == b->ne[1] GGML_API struct ggml_tensor * ggml_get_rows( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 685c88f1a..12133882b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -38,6 +38,8 @@ class Keys: FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" + EXPERT_COUNT = "{arch}.expert_count" + EXPERT_USED_COUNT = "{arch}.expert_used_count" class Attention: HEAD_COUNT = "{arch}.attention.head_count" @@ -111,10 +113,14 @@ class MODEL_TENSOR(IntEnum): ATTN_NORM = auto() ATTN_NORM_2 = auto() ATTN_ROT_EMBD = auto() + FFN_GATE_INP = auto() + FFN_NORM = auto() FFN_GATE = auto() FFN_DOWN = auto() FFN_UP = auto() - FFN_NORM = auto() + FFN_GATE_EXP = auto() + FFN_DOWN_EXP = auto() + FFN_UP_EXP = auto() ATTN_Q_NORM = auto() ATTN_K_NORM = auto() @@ -154,10 +160,14 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", + MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", + MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}", + MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}", + MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -172,10 +182,14 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, ], MODEL_ARCH.GPTNEOX: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index b8ec977c8..73e021607 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -339,6 +339,12 @@ class GGUFWriter: def add_clamp_kqv(self, value: float) -> None: self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value) + def add_expert_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count) + + def add_expert_used_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count) + def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index cc6236014..0115ea1c6 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -149,6 +149,11 @@ class TensorNameMap: "model.layers.{bid}.ln2", # yi ), + MODEL_TENSOR.FFN_GATE_INP: ( + "layers.{bid}.feed_forward.gate", # mixtral + "model.layers.{bid}.block_sparse_moe.gate", # mixtral + ), + # Feed-forward up MODEL_TENSOR.FFN_UP: ( "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox @@ -164,11 +169,21 @@ class TensorNameMap: "transformer.h.{bid}.mlp.w1", # qwen ), + MODEL_TENSOR.FFN_UP_EXP: ( + "layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral + "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral + ), + # Feed-forward gate MODEL_TENSOR.FFN_GATE: ( - "model.layers.{bid}.mlp.gate_proj", # llama-hf refact - "layers.{bid}.feed_forward.w1", # llama-pth - "transformer.h.{bid}.mlp.w2", # qwen + "model.layers.{bid}.mlp.gate_proj", # llama-hf refact + "layers.{bid}.feed_forward.w1", # llama-pth + "transformer.h.{bid}.mlp.w2", # qwen + ), + + MODEL_TENSOR.FFN_GATE_EXP: ( + "layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral + "model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral ), # Feed-forward down @@ -185,6 +200,11 @@ class TensorNameMap: "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon ), + MODEL_TENSOR.FFN_DOWN_EXP: ( + "layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral + "model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral + ), + MODEL_TENSOR.ATTN_Q_NORM: ( "language_model.encoder.layers.{bid}.self_attention.q_layernorm", ), @@ -213,11 +233,14 @@ class TensorNameMap: for tensor, keys in self.block_mappings_cfg.items(): if tensor not in MODEL_TENSORS[arch]: continue - tensor_name = TENSOR_NAMES[tensor].format(bid = bid) - self.mapping[tensor_name] = (tensor, tensor_name) - for key in keys: - key = key.format(bid = bid) - self.mapping[key] = (tensor, tensor_name) + # TODO: make this configurable + n_experts = 8 + for xid in range(n_experts): + tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid) + self.mapping[tensor_name] = (tensor, tensor_name) + for key in keys: + key = key.format(bid = bid, xid = xid) + self.mapping[key] = (tensor, tensor_name) def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None: result = self.mapping.get(key) diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index e6374bfe8..9789c2c87 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gguf" -version = "0.6.0" +version = "0.7.0" description = "Read and write ML models in GGUF for GGML" authors = ["GGML "] packages = [ diff --git a/llama.cpp b/llama.cpp index 54fa9e43e..0e5ab044c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -91,7 +91,8 @@ #define LLAMA_ATTRIBUTE_FORMAT(...) #endif -#define LLAMA_MAX_NODES 8192 +#define LLAMA_MAX_NODES 8192 +#define LLAMA_MAX_EXPERTS 8 // // logging @@ -231,6 +232,8 @@ enum llm_kv { LLM_KV_FEED_FORWARD_LENGTH, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, + LLM_KV_EXPERT_COUNT, + LLM_KV_EXPERT_USED_COUNT, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -281,6 +284,8 @@ static std::map LLM_KV_NAMES = { { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, + { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -338,10 +343,14 @@ enum llm_tensor { LLM_TENSOR_ATTN_NORM, LLM_TENSOR_ATTN_NORM_2, LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_DOWN_EXP, + LLM_TENSOR_FFN_GATE_EXP, + LLM_TENSOR_FFN_UP_EXP, LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, }; @@ -360,10 +369,14 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, }, }, { @@ -585,6 +598,10 @@ struct LLM_TN { std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const { return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix; } + + std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const { + return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid, xid) + "." + suffix; + } }; // @@ -1164,6 +1181,8 @@ struct llama_hparams { uint32_t n_layer; uint32_t n_rot; uint32_t n_ff; + uint32_t n_expert = 0; + uint32_t n_expert_used = 0; float f_norm_eps; float f_norm_rms_eps; @@ -1178,15 +1197,18 @@ struct llama_hparams { float f_max_alibi_bias; bool operator!=(const llama_hparams & other) const { - if (this->vocab_only != other.vocab_only) return true; - if (this->n_vocab != other.n_vocab) return true; - if (this->n_ctx_train != other.n_ctx_train) return true; - if (this->n_embd != other.n_embd) return true; - if (this->n_head != other.n_head) return true; - if (this->n_head_kv != other.n_head_kv) return true; - if (this->n_layer != other.n_layer) return true; - if (this->n_rot != other.n_rot) return true; - if (this->n_ff != other.n_ff) return true; + if (this->vocab_only != other.vocab_only) return true; + if (this->n_vocab != other.n_vocab) return true; + if (this->n_ctx_train != other.n_ctx_train) return true; + if (this->n_embd != other.n_embd) return true; + if (this->n_head != other.n_head) return true; + if (this->n_head_kv != other.n_head_kv) return true; + if (this->n_layer != other.n_layer) return true; + if (this->n_rot != other.n_rot) return true; + if (this->n_ff != other.n_ff) return true; + if (this->n_expert != other.n_expert) return true; + if (this->n_expert_used != other.n_expert_used) return true; + if (this->rope_finetuned != other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; @@ -1268,6 +1290,12 @@ struct llama_layer { struct ggml_tensor * ffn_down; // w2 struct ggml_tensor * ffn_up; // w3 + // ff MoE + struct ggml_tensor * ffn_gate_inp; + struct ggml_tensor * ffn_gate_exp[LLAMA_MAX_EXPERTS]; + struct ggml_tensor * ffn_down_exp[LLAMA_MAX_EXPERTS]; + struct ggml_tensor * ffn_up_exp [LLAMA_MAX_EXPERTS]; + // ff bias struct ggml_tensor * ffn_down_b; // b2 struct ggml_tensor * ffn_up_b; // b3 @@ -2440,6 +2468,16 @@ static void llm_load_hparams( ml.get_key (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff); ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head); ml.get_key (LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key (LLM_KV_EXPERT_COUNT, hparams.n_expert, false); + ml.get_key (LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + + GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS); + GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert); + if (hparams.n_expert > 0) { + GGML_ASSERT(hparams.n_expert_used > 0); + } else { + GGML_ASSERT(hparams.n_expert_used == 0); + } // n_head_kv is optional, default to n_head hparams.n_head_kv = hparams.n_head; @@ -2871,6 +2909,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); + LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); + LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); @@ -3025,9 +3065,26 @@ static void llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); - layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_gate_inp = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, backend, false); + + if (layer.ffn_gate_inp == nullptr) { + GGML_ASSERT(hparams.n_expert == 0); + GGML_ASSERT(hparams.n_expert_used == 0); + + layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + } else { + GGML_ASSERT(hparams.n_expert > 0); + GGML_ASSERT(hparams.n_expert_used > 0); + + // MoE branch + for (uint32_t x = 0; x < hparams.n_expert; ++x) { + layer.ffn_gate_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd, n_ff}, backend_split); + layer.ffn_down_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd}, backend_split); + layer.ffn_up_exp[x] = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), {n_embd, n_ff}, backend_split); + } + } if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -3037,8 +3094,18 @@ static void llm_load_tensors( (layer.bk ? ggml_nbytes(layer.bk) : 0) + (layer.bv ? ggml_nbytes(layer.bv) : 0) + (layer.bo ? ggml_nbytes(layer.bo) : 0) + - ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_gate) + - ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); + ggml_nbytes(layer.ffn_norm); + + if (layer.ffn_gate_inp == nullptr) { + vram_weights += + ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); + } else { + vram_weights += ggml_nbytes(layer.ffn_gate_inp); + for (uint32_t x = 0; x < hparams.n_expert; ++x) { + vram_weights += + ggml_nbytes(layer.ffn_gate_exp[x]) + ggml_nbytes(layer.ffn_down_exp[x]) + ggml_nbytes(layer.ffn_up_exp[x]); + } + } } } } break; @@ -4019,6 +4086,8 @@ struct llm_build_context { const int64_t n_head_kv; const int64_t n_embd_head; const int64_t n_embd_gqa; + const int64_t n_expert; + const int64_t n_expert_used; const float freq_base; const float freq_scale; @@ -4060,6 +4129,8 @@ struct llm_build_context { n_head_kv (hparams.n_head_kv), n_embd_head (hparams.n_embd_head()), n_embd_gqa (hparams.n_embd_gqa()), + n_expert (hparams.n_expert), + n_expert_used (hparams.n_expert_used), freq_base (cparams.rope_freq_base), freq_scale (cparams.rope_freq_scale), ext_factor (cparams.yarn_ext_factor), @@ -4184,7 +4255,7 @@ struct llm_build_context { cb(ffn_inp, "ffn_inp", il); // feed-forward network - { + if (model.layers[il].ffn_gate_inp == nullptr) { cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); @@ -4196,6 +4267,69 @@ struct llm_build_context { model.layers[il].ffn_down, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] + cb(logits, "ffn_moe_logits", il); + + ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] + cb(probs, "ffn_moe_probs", il); + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok] + cb(selected_experts->src[0], "ffn_moe_argsort", il); + + ggml_tensor * weights = ggml_get_rows(ctx0, + ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); + cb(weights, "ffn_moe_weights", il); + + weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok] + + ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); + cb(weights_sum, "ffn_moe_weights_sum", il); + + weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok] + cb(weights, "ffn_moe_weights_norm", il); + + // compute expert outputs + ggml_tensor * moe_out = nullptr; + + for (int i = 0; i < n_expert_used; ++i) { + ggml_tensor * cur_expert; + + ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur); + cb(cur_up, "ffn_moe_up", il); + + ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur); + cb(cur_gate, "ffn_moe_gate", il); + + cur_gate = ggml_silu(ctx0, cur_gate); + cb(cur_gate, "ffn_moe_silu", il); + + cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd] + cb(cur_expert, "ffn_moe_gate_par", il); + + cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd] + cb(cur_expert, "ffn_moe_down", il); + + cur_expert = ggml_mul(ctx0, cur_expert, + ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); + cb(cur_expert, "ffn_moe_weighted", il); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx0, moe_out, cur_expert); + cb(moe_out, "ffn_moe_out", il); + } + } + + cur = moe_out; } cur = ggml_add(ctx0, cur, ffn_inp); @@ -5450,6 +5584,20 @@ static const std::unordered_map k_offload_map { "ffn_relu", OFFLOAD_FUNC }, { "ffn_sqr(relu)", OFFLOAD_FUNC }, + { "ffn_moe_logits", OFFLOAD_FUNC }, + { "ffn_moe_probs", OFFLOAD_FUNC }, + { "ffn_moe_argsort", OFFLOAD_FUNC }, + { "ffn_moe_weights", OFFLOAD_FUNC }, + { "ffn_moe_weights_sum", OFFLOAD_FUNC }, + { "ffn_moe_weights_norm", OFFLOAD_FUNC }, + { "ffn_moe_weighted", OFFLOAD_FUNC }, + { "ffn_moe_up", OFFLOAD_FUNC }, + { "ffn_moe_gate", OFFLOAD_FUNC }, + { "ffn_moe_silu", OFFLOAD_FUNC }, + { "ffn_moe_gate_par", OFFLOAD_FUNC }, + { "ffn_moe_down", OFFLOAD_FUNC }, + { "ffn_moe_out", OFFLOAD_FUNC }, + { "l_out", OFFLOAD_FUNC }, { "result_norm", OFFLOAD_FUNC_EMB }, @@ -8067,11 +8215,9 @@ static void llama_convert_tensor_internal( workers.clear(); } -static ggml_type get_k_quant_type( - quantize_state_internal & qs, - ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype -) { +static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) { const std::string name = ggml_get_name(tensor); + // TODO: avoid hardcoded tensor names - use the TN_* constants const llm_arch arch = qs.model.arch; const auto tn = LLM_TN(arch); @@ -8105,7 +8251,18 @@ static ggml_type get_k_quant_type( // nearly negligible increase in model size by quantizing this tensor with more bits: if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K; } + if (qs.model.hparams.n_expert == 8) { + // for the 8-expert model, bumping this to Q8_0 trades just ~128MB + // TODO: explore better strategies + new_type = GGML_TYPE_Q8_0; + } ++qs.i_attention_wv; + } else if (name.find("attn_k.weight") != std::string::npos) { + if (qs.model.hparams.n_expert == 8) { + // for the 8-expert model, bumping this to Q8_0 trades just ~128MB + // TODO: explore better strategies + new_type = GGML_TYPE_Q8_0; + } } else if (name.find("ffn_down.weight") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { @@ -8318,6 +8475,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s quantize &= params->quantize_output_tensor || name != "output.weight"; quantize &= !params->only_copy; + // do not quantize expert gating tensors + quantize &= name.find("ffn_gate_inp.weight") == std::string::npos; + enum ggml_type new_type; void * new_data; size_t new_size; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e0155ac1c..44830b4d4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -20,8 +20,6 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m size_t size = ggml_nelements(tensor); std::vector data(size); - std::random_device rd; - #if 0 std::default_random_engine generator(rd()); std::uniform_real_distribution distribution(min, max); @@ -31,6 +29,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } #endif auto init_thread = [&](size_t start, size_t end) { + std::random_device rd; std::default_random_engine generator(rd()); std::uniform_real_distribution distribution(min, max); @@ -51,7 +50,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m t.join(); } - if (tensor->type == GGML_TYPE_F32) { + if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) { ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float)); } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) { GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0); @@ -71,23 +70,28 @@ static std::vector tensor_to_float(const ggml_tensor * t) { std::vector buf(ggml_nbytes(t)); ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t)); + ggml_type_traits_t tt = ggml_internal_get_type_traits(t->type); + size_t bs = ggml_blck_size(t->type); + // access elements by index to avoid gaps in views for (int64_t i3 = 0; i3 < t->ne[3]; i3++) { for (int64_t i2 = 0; i2 < t->ne[2]; i2++) { for (int64_t i1 = 0; i1 < t->ne[1]; i1++) { - for (int64_t i0 = 0; i0 < t->ne[0]; i0++) { - size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0*t->nb[0]; - float v; + for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) { + size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0]; if (t->type == GGML_TYPE_F16) { - v = (float) ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]); + tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i])); } else if (t->type == GGML_TYPE_F32) { - v = *(float *) &buf[i]; + tv.push_back(*(float *) &buf[i]); } else if (t->type == GGML_TYPE_I32) { - v = *(int32_t *) &buf[i]; + tv.push_back((float)*(int32_t *) &buf[i]); + } else if (ggml_is_quantized(t->type)) { + std::vector vq(ggml_blck_size(t->type)); + tt.to_float(&buf[i], vq.data(), ggml_blck_size(t->type)); + tv.insert(tv.end(), vq.begin(), vq.end()); } else { GGML_ASSERT(false); } - tv.push_back(v); } } } @@ -233,6 +237,10 @@ static bool ggml_is_view_op(enum ggml_op op) { struct test_case { virtual ~test_case() {} + virtual std::string op_desc(ggml_tensor * t) { + return ggml_op_desc(t); + } + virtual std::string vars() { return ""; } @@ -240,7 +248,7 @@ struct test_case { virtual ggml_tensor * build_graph(ggml_context * ctx) = 0; virtual double max_nmse_err() { - return 1e-6; + return 1e-7; } virtual void initialize_tensors(ggml_context * ctx) { @@ -270,13 +278,13 @@ struct test_case { ggml_tensor * out = build_graph(ctx); - if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) { - //printf(" %s: skipping\n", ggml_op_desc(out)); + if (op_name != nullptr && op_desc(out) != op_name) { + //printf(" %s: skipping\n", op_desc(out).c_str()); ggml_free(ctx); return true; } - printf(" %s(%s): ", ggml_op_desc(out), vars().c_str()); + printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str()); fflush(stdout); // check if backends support op @@ -317,7 +325,7 @@ struct test_case { for (size_t i = 0; i < f1.size(); i++) { // check for nans if (std::isnan(f1[i]) || std::isnan(f2[i])) { - printf("NaN at index %zu ", i); + printf("[%s] NaN at index %zu (%f %f) ", ggml_op_desc(t1), i, f1[i], f2[i]); ud->ok = false; return true; } @@ -325,12 +333,12 @@ struct test_case { if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) { if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) { if (std::signbit(f1[i]) != std::signbit(f2[i])) { - printf("inf sign mismatch: %f %f ", f1[i], f2[i]); + printf("[%s] inf sign mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]); ud->ok = false; return true; } } else { - printf("inf mismatch: %f %f ", f1[i], f2[i]); + printf("[%s] inf mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]); ud->ok = false; return true; } @@ -339,10 +347,16 @@ struct test_case { double err = nmse(f1.data(), f2.data(), f1.size()); if (err > ud->max_err) { - printf("NMSE = %f ", err); + printf("[%s] NMSE = %f ", ggml_op_desc(t1), err); + //for (int i = 0; i < f1.size(); i++) { + // printf("(%f, %f) ", f1[i], f2[i]); + //} + //printf("\n"); ud->ok = false; } return true; + + GGML_UNUSED(index); }; ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud); @@ -372,13 +386,13 @@ struct test_case { ggml_tensor * out = build_graph(ctx); - if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) { - //printf(" %s: skipping\n", ggml_op_desc(out)); + if (op_name != nullptr && op_desc(out) != op_name) { + //printf(" %s: skipping\n", op_desc(out).c_str()); ggml_free(ctx); return true; } - int len = printf(" %s(%s): ", ggml_op_desc(out), vars().c_str()); + int len = printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str()); fflush(stdout); // check if backends support op @@ -430,8 +444,9 @@ struct test_case { return size; }; for (int i = 0; i < gf->n_nodes; i++) { - if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out) + if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out) { continue; + } mem += tensor_op_size(gf->nodes[i]); } @@ -486,17 +501,22 @@ struct test_get_rows : public test_case { const int n; // cols const int m; // rows const int r; // rows to get + const int b; // batch size + const bool v; // view (non-contiguous src1) std::string vars() override { - return VARS_TO_STR4(type, n, m, r); + return VARS_TO_STR6(type, n, m, r, b, v); } - test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3) - : type(type), n(n), m(m), r(r) {} + test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false) + : type(type), n(n), m(m), r(r), b(b), v(v) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * in = ggml_new_tensor_2d(ctx, type, n, m); - ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, r); + ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b); + ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b); + if (v) { + rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0); + } ggml_tensor * out = ggml_get_rows(ctx, in, rows); return out; } @@ -504,12 +524,13 @@ struct test_get_rows : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } // rows - std::vector data(r); - for (int i = 0; i < r; i++) { + std::vector data(r*b); + for (int i = 0; i < r*b; i++) { data[i] = rand() % m; } - ggml_backend_tensor_set(t, data.data(), 0, r * sizeof(int)); + ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int)); } else { init_tensor_uniform(t); } @@ -770,11 +791,10 @@ struct test_mul_mat_id : public test_case { const int64_t m; const int64_t n; const int64_t k; - const std::array bs; // dims 3 and 4 - const std::array nr; // repeat in dims 3 and 4 + const bool v; // view (non-contiguous ids) std::string vars() override { - return VARS_TO_STR9(type_a, type_b, n_mats, id, m, n, k, bs, nr); + return VARS_TO_STR8(type_a, type_b, n_mats, id, m, n, k, v); } double max_nmse_err() override { @@ -782,7 +802,7 @@ struct test_mul_mat_id : public test_case { } size_t op_size(ggml_tensor * t) override { - size_t a = ggml_nbytes(t->src[2]) * n * nr[0] * nr[1]; + size_t a = ggml_nbytes(t->src[2]) * n; size_t b = ggml_nbytes(t->src[1]) * m; size_t c = ggml_nbytes(t); return a + b + c; @@ -792,35 +812,41 @@ struct test_mul_mat_id : public test_case { test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, int n_mats = 2, int id = 0, - int64_t m = 32, int64_t n = 32, int64_t k = 32, - std::array bs = {10, 10}, - std::array nr = {2, 2}) + int64_t m = 32, int64_t n = 32, int64_t k = 32, bool v = false) : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id), - m(m), n(n), k(k), bs(bs), nr(nr) {} + m(m), n(n), k(k), v(v) {} ggml_tensor * build_graph(ggml_context * ctx) override { // C^T = A * B^T: (k, m) * (k, n) => (m, n) std::vector mats; for (int i = 0; i < n_mats; i++) { - ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]); + ggml_tensor * a = ggml_new_tensor_2d(ctx, type_a, k, m); mats.push_back(a); } - ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_mats); - ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); - ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), ids, id, b); + ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n); + if (v) { + ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0); + } + ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n); + ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), n_mats, ids, v ? id/2 : id, b); return out; } void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } // ids - std::vector data(n_mats); - for (int i = 0; i < n_mats; i++) { - data[i] = i; + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i % n_mats; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); } - std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()())); - ggml_backend_tensor_set(t, data.data(), 0, n_mats * sizeof(int)); } else { init_tensor_uniform(t); } @@ -1109,6 +1135,90 @@ struct test_sum_rows : public test_case { } }; +// Mixtral MOE +struct test_moe : public test_case { + const int n_experts; + const int n_experts_per_tok; + const int n_tokens; + const int n_embd; + const int n_ff; + + std::string op_desc(ggml_tensor * t) override { + return "MOE"; + + GGML_UNUSED(t); + } + + std::string vars() override { + return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff); + } + + test_moe(int n_experts = 8, int n_experts_per_tok = 2, int n_tokens = 1, int n_embd = 4096, int n_ff = 14336) + : n_experts(n_experts), n_experts_per_tok(n_experts_per_tok), n_tokens(n_tokens), n_embd(n_embd), n_ff(n_ff) { + } + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts); + + std::vector ffn_up_exp(n_experts); + std::vector ffn_gate_exp(n_experts); + std::vector ffn_down_exp(n_experts); + + for (int i = 0; i < n_experts; ++i) { + ffn_up_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); + ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); + ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd); + } + + ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens); + + ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur); + ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, 1.0f/sqrtf(n_embd)); + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok); + + ggml_tensor * weights = ggml_get_rows(ctx, + ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts); + + weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens); + + ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); + + weights = ggml_div(ctx, weights, weights_sum); + + // compute expert outputs + ggml_tensor * moe_out = nullptr; + + for (int i = 0; i < n_experts_per_tok; ++i) { + ggml_tensor * cur_expert; + + ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur); + + ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur); + + cur_gate = ggml_silu(ctx, cur_gate); + + cur_expert = ggml_mul(ctx, cur_up, cur_gate); + + cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert); + + cur_expert = ggml_mul(ctx, cur_expert, + ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx, moe_out, cur_expert); + } + } + + cur = moe_out; + + return cur; + } +}; + enum test_mode { MODE_TEST, MODE_PERF, @@ -1117,14 +1227,28 @@ enum test_mode { static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) { std::vector> test_cases; + const ggml_type all_types[] = { + GGML_TYPE_F32, GGML_TYPE_F16, + GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, + GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, + GGML_TYPE_Q8_0, + GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, + GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, + GGML_TYPE_Q6_K + }; + // unary ops for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) { test_cases.emplace_back(new test_unary((ggml_unary_op) op)); } - for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - test_cases.emplace_back(new test_get_rows(type, 10, 5, 3)); - test_cases.emplace_back(new test_get_rows(type, 16, 5, 3)); + test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false)); + for (ggml_type type : all_types) { + for (int b : {1, 7}) { + for (bool v : {false, true}) { + test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, v)); + } + } } test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1})); @@ -1134,7 +1258,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2})); test_cases.emplace_back(new test_dup()); - test_cases.emplace_back(new test_cpy()); + + for (ggml_type type : all_types) { + test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, type, {256, 10, 10, 1})); + } + test_cases.emplace_back(new test_cont()); auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr) { @@ -1144,6 +1272,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op }; add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1}); @@ -1170,8 +1299,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1}); - add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1}); - add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1}); + //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1}); + //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1}); test_cases.emplace_back(new test_scale()); @@ -1180,16 +1309,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps)); } - const ggml_type all_types[] = { - GGML_TYPE_F32, GGML_TYPE_F16, - GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, - GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, - GGML_TYPE_Q8_0, - GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, - GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, - GGML_TYPE_Q6_K - }; - for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { // FIXME: CPU crashes on f16xf16 @@ -1213,9 +1332,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { - for (int n_mats : {1, 2, 4}) { + for (int n_mats : {2, 4, 8}) { for (int id = 0; id < n_mats; id++) { - test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, {1, 1}, {1, 1})); + for (bool v : {false, true}) { + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v)); + } } } } @@ -1247,10 +1368,18 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_concat()); for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) { + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); } - test_cases.emplace_back(new test_sum_rows()); + test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, {10, 10, 10, 10})); + test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, {2, 1, 1, 1})); + +#if !defined(__SANITIZE_THREAD__) + // FIXME: these tests use too much memory with thread sanitizer + test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 14336)); + //test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336)); +#endif // run tests if (mode == MODE_TEST) { @@ -1267,14 +1396,17 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op ggml_backend_free(backend_cpu); return n_ok == test_cases.size(); - } else if (mode == MODE_PERF) { + } + + if (mode == MODE_PERF) { for (auto & test : test_cases) { test->eval_perf(backend, op_name); } return true; - } else { - GGML_ASSERT(false); } + + GGML_ASSERT(false); + return false; } static void usage(char ** argv) { @@ -1347,11 +1479,12 @@ int main(int argc, char ** argv) { } printf("%zu/%zu backends passed\n", n_ok, ggml_backend_reg_get_count()); + if (n_ok != ggml_backend_reg_get_count()) { printf("\033[1;31mFAIL\033[0m\n"); return 1; - } else { - printf("\033[1;32mOK\033[0m\n"); - return 0; } + + printf("\033[1;32mOK\033[0m\n"); + return 0; }