whisper : add full CUDA and Metal offloading (#1472)

* whisper : migrate to ggml-backend

* whisper : fix logit reading

* whisper : fix tensor allocation during load

* whisper : fix beam-search with CUDA

* whisper : free backends + fix compile warning

* whisper : print when CUDA is enabled

* whisper : fix CoreML

* make : clean-up

* talk : fix compile warning

* whisper : support ggml_conv with CUDA and Metal (#1473)

* ggml : add CUDA support for ggml_conv

* whisper : remove ggml_repeat for conv bias + single backend

* cuda : fix im2col kernel

* metal : add im2col support + mul mat-vec f16 x f16

* bench-all : add q4 models

* whisper : clean-up

* quantize-all : fix

* ggml : im2col opts

* whisper : avoid whisper_model_data wrapper

* whisper : add note that ggml_mul_mat_pad does not work with CUDA

* whisper : factor out graph compute in common function

* whisper : fixes

* whisper : fix UB with measure buffers

* whisper : try to fix the parallel whisper_state functionality (#1479)

* whisper : try to fix the parallel whisper_state functionality

* whisper : fix multi-state Metal

* whisper : free backend instances in whisper_state
pull/1375/merge
Georgi Gerganov 2023-11-12 15:31:08 +02:00 committed by GitHub
parent ec7a6f04f9
commit b0502836b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 916 additions and 1594 deletions

1
.gitignore vendored
View File

@ -8,6 +8,7 @@
.DS_Store
build/
build-coreml/
build-em/
build-debug/
build-release/

View File

@ -307,7 +307,7 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
$(CC) $(CFLAGS) -c $< -o $@
WHISPER_OBJ += ggml-alloc.o ggml-backend.o ggml-quants.o
WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
$(CXX) $(CXXFLAGS) -c $< -o $@
@ -331,11 +331,11 @@ ggml-metal.o: ggml-metal.m ggml-metal.h
WHISPER_OBJ += ggml-metal.o
endif
libwhisper.a: ggml.o $(WHISPER_OBJ)
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
libwhisper.a: $(WHISPER_OBJ)
$(AR) rcs libwhisper.a $(WHISPER_OBJ)
libwhisper.so: ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
libwhisper.so: $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so $(WHISPER_OBJ) $(LDFLAGS)
clean:
rm -f *.o main stream command talk talk-llama bench quantize lsp libwhisper.a libwhisper.so
@ -349,30 +349,30 @@ CC_SDL=`sdl2-config --cflags --libs`
SRC_COMMON = examples/common.cpp examples/common-ggml.cpp
SRC_COMMON_SDL = examples/common-sdl.cpp
main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
main: examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o main $(LDFLAGS)
./main -h
bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
bench: examples/bench/bench.cpp $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp $(WHISPER_OBJ) -o bench $(LDFLAGS)
quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o quantize $(LDFLAGS)
quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o quantize $(LDFLAGS)
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
#
# Audio samples

View File

@ -181,7 +181,7 @@ private:
// It is assumed that PCM data is normalized to a range from -1 to 1
bool write_audio(const float * data, size_t length) {
for (size_t i = 0; i < length; ++i) {
const auto intSample = static_cast<const int16_t>(data[i] * 32767);
const int16_t intSample = data[i] * 32767;
file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
dataSize += sizeof(int16_t);
}

View File

@ -121,13 +121,13 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
return false;
}
std::string word;
char word[129];
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
fin.read((char *) &len, sizeof(len));
word.resize(len);
fin.read((char *) word.data(), len);
word[len] = '\0';
fin.read((char *) word, len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;

View File

@ -18,11 +18,11 @@ else
fi
models=( \
"tiny" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
"base" "base-q5_0" "base-q5_1" "base-q8_0" \
"small" "small-q5_0" "small-q5_1" "small-q8_0" \
"medium" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
"large" "large-q5_0" "large-q5_1" "large-q8_0" \
"tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
"base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
"small" "small-q4_0" "small-q4_1" "small-q5_0" "small-q5_1" "small-q8_0" \
"medium" "medium-q4_0" "medium-q4_1" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
"large" "large-q4_0" "large-q4_1" "large-q5_0" "large-q5_1" "large-q8_0" \
)
if [ "$encoder_only" -eq 0 ]; then
@ -83,6 +83,10 @@ for model in "${models[@]}"; do
config="$config COREML"
fi
if [[ $system_info == *"CUDA = 1"* ]]; then
config="$config CUDA"
fi
if [[ $system_info == *"METAL = 1"* ]]; then
config="$config METAL"
fi

View File

@ -15,33 +15,13 @@ declare -a filedex
cd `dirname $0`
cd ../
# Let's loop across all the objects in the 'models' dir:
for i in ./models/*; do
# Check to see if it's a file or directory
if [ -d "$i" ]; then
# It's a directory! We should make sure it's not empty first:
if [ "$(ls -A $i)" ]; then
# Passed! Let's go searching for bin files (shouldn't need to go more than a layer deep here)
for f in "$i"/*.bin; do
# [Neuron Activation]
newfile=`echo "${f##*/}" | cut -d _ -f 1`;
if [ "$newfile" != "q5" ]; then
./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" ${qtype1};
./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" ${qtype0};
filedex+=( "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" )
fi
done
fi
else
# It's a file! Let's make sure it's the right type:
if [ "${i##*.}" == "bin" ]; then
# And we probably want to skip the testing files
if [ "${i:9:8}" != "for-test" ]; then
# [Neuron Activation]
./quantize "${i}" "${i:-4}-${qtype1}.bin" ${qtype1};
./quantize "${i}" "${i:-4}-${qtype0}.bin" ${qtype0};
filedex+=( "${i:-4}-${qtype1}.bin" "${i:-4}-${qtype0}.bin" )
fi
for i in `ls ./models | grep ^ggml-.*.bin | grep -v "\-q"`; do
m="models/$i"
if [ -f "$m" ]; then
if [ "${m##*.}" == "bin" ]; then
./quantize "${m}" "${m::${#m}-4}-${qtype1}.bin" ${qtype1};
./quantize "${m}" "${m::${#m}-4}-${qtype0}.bin" ${qtype0};
filedex+=( "${m::${#m}-4}-${qtype1}.bin" "${m::${#m}-4}-${qtype0}.bin" )
fi
fi
done

View File

@ -4476,6 +4476,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
*dsti = __float2half(*xi);
}
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
const half * xi = (const half *) cxi;
half * dsti = (half *) cdsti;
*dsti = *xi;
}
template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@ -4729,6 +4736,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}
static __global__ void im2col_f32_f16(
const float * x, half * dst,
int ofs0, int ofs1, int IW, int IH, int CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
const int offset_dst =
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
(blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = __float2half(0.0f);
} else {
const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
}
}
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
@ -5618,6 +5644,16 @@ static void ggml_cpy_f32_f16_cuda(
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}
static void ggml_cpy_f16_f16_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
@ -5701,6 +5737,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
}
static void im2col_f32_f16_cuda(const float * x, half * dst,
int OH, int IW, int IH, int OW, int IC,
int KH, int KW, int N, int ofs0, int ofs1,
int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
dim3 block_nums(IC, OH, OW);
dim3 block_dims(N, KH, KW);
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
// buffer pool for cuda
#define MAX_CUDA_BUFFERS 256
@ -6483,7 +6528,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
}
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
size_t dst_f16_as = 0;
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
@ -6659,6 +6704,45 @@ inline void ggml_cuda_op_alibi(
(void) src1_dd;
}
inline void ggml_cuda_op_im2col(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
const int64_t N = src1->ne[is_2D ? 3 : 2];
const int64_t IC = src1->ne[is_2D ? 2 : 1];
const int64_t IH = is_2D ? src1->ne[1] : 1;
const int64_t IW = src1->ne[0];
const int64_t KH = is_2D ? src0->ne[1] : 1;
const int64_t KW = src0->ne[0];
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
OH, IW, IH, OW, IC, KH, KW, N,
ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
(void) src0;
(void) src0_dd;
}
inline void ggml_cuda_op_diag_mask_inf(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@ -7549,6 +7633,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, main_stream);
} else {
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
@ -7580,6 +7667,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
}
void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
}
static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
(void) src0;
(void) src1;
@ -7943,6 +8034,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
case GGML_OP_ALIBI:
func = ggml_cuda_alibi;
break;
case GGML_OP_IM2COL:
func = ggml_cuda_im2col;
break;
default:
return false;
}

View File

@ -26,7 +26,7 @@
#include <stdbool.h>
// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 16
#define GGML_METAL_MAX_BUFFERS 64
#define GGML_METAL_MAX_COMMAND_BUFFERS 32
struct ggml_tensor;

View File

@ -86,6 +86,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@ -114,6 +115,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(rope_f32);
GGML_METAL_DECL_KERNEL(rope_f16);
GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(im2col_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
@ -287,6 +289,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@ -317,6 +320,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(rope_f32);
GGML_METAL_ADD_KERNEL(rope_f16);
GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(im2col_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
@ -386,6 +390,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@ -416,6 +421,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(rope_f32);
GGML_METAL_DEL_KERNEL(rope_f16);
GGML_METAL_DEL_KERNEL(alibi_f32);
GGML_METAL_DEL_KERNEL(im2col_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
@ -473,6 +479,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
const int64_t tsize = ggml_nbytes(t);
if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
ctx = t->buffer->backend->context;
}
// find the view that contains the tensor fully
for (int i = 0; i < ctx->n_buffers; ++i) {
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
@ -1139,6 +1149,7 @@ void ggml_metal_graph_compute(
switch (src0t) {
case GGML_TYPE_F32:
{
GGML_ASSERT(src1t == GGML_TYPE_F32);
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
nrows = 4;
} break;
@ -1146,13 +1157,18 @@ void ggml_metal_graph_compute(
{
nth0 = 32;
nth1 = 1;
if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
nrows = ne11;
if (src1t == GGML_TYPE_F32) {
if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
nrows = ne11;
} else {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
nrows = 4;
}
} else {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
nrows = 4;
}
} break;
@ -1464,6 +1480,58 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_IM2COL:
{
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16);
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
const int32_t N = src1->ne[is_2D ? 3 : 2];
const int32_t IC = src1->ne[is_2D ? 2 : 1];
const int32_t IH = is_2D ? src1->ne[1] : 1;
const int32_t IW = src1->ne[0];
const int32_t KH = is_2D ? src0->ne[1] : 1;
const int32_t KW = src0->ne[0];
const int32_t OH = is_2D ? dst->ne[2] : 1;
const int32_t OW = dst->ne[1];
const int32_t CHW = IC * KH * KW;
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
switch (src0->type) {
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
default: GGML_ASSERT(false);
};
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
} break;
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:

View File

@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
constant int64_t & ne0,
constant int64_t & ne1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F32_F32;
@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
}
}
#define N_F16_F16 4
kernel void kernel_mul_mv_f16_f16(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F16_F16;
const int64_t im = tgpig.z;
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
if (ne00 < 128) {
for (int row = 0; row < N_F16_F16; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
sumf += (half) x[i] * (half) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
} else {
device const half4 * x4 = (device const half4 *)x;
for (int row = 0; row < N_F16_F16; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
device const half4 * y4 = (device const half4 *) y;
float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
}
kernel void kernel_mul_mv_f16_f32_1row(
device const char * src0,
device const char * src1,
@ -1229,6 +1302,39 @@ kernel void kernel_rope(
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
kernel void kernel_im2col_f16(
device const float * x,
device half * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
constant int32_t & IH,
constant int32_t & CHW,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int32_t & d0,
constant int32_t & d1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
const int32_t offset_dst =
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
kernel void kernel_cpy_f16_f16(
device const half * src0,
device half * dst,

1056
ggml.c

File diff suppressed because it is too large Load Diff

19
ggml.h
View File

@ -403,13 +403,8 @@ extern "C" {
GGML_OP_ROPE_BACK,
GGML_OP_ALIBI,
GGML_OP_CLAMP,
GGML_OP_CONV_1D,
GGML_OP_CONV_1D_STAGE_0, // internal
GGML_OP_CONV_1D_STAGE_1, // internal
GGML_OP_CONV_TRANSPOSE_1D,
GGML_OP_CONV_2D,
GGML_OP_CONV_2D_STAGE_0, // internal
GGML_OP_CONV_2D_STAGE_1, // internal
GGML_OP_IM2COL,
GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
@ -1398,6 +1393,18 @@ extern "C" {
float min,
float max);
GGML_API struct ggml_tensor * ggml_im2col(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1,
bool is_2D);
GGML_API struct ggml_tensor * ggml_conv_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,8 @@
#ifndef WHISPER_H
#define WHISPER_H
#include "ggml.h"
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
@ -110,15 +112,15 @@ extern "C" {
// Various functions for loading a ggml whisper model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
WHISPER_API struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_from_file_with_params (const char * path_model, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_with_params (struct whisper_model_loader * loader, struct whisper_context_params params);
// These are the same as the above, but the internal state of the context is not allocated automatically
// It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state (const char * path_model, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
WHISPER_API struct whisper_context * whisper_init_with_params_no_state (struct whisper_model_loader * loader, struct whisper_context_params params);
WHISPER_DEPRECATED(
WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model),
@ -570,8 +572,7 @@ extern "C" {
// Control logging output; default behavior is to print to stderr
typedef void (*whisper_log_callback)(const char * line);
WHISPER_API void whisper_set_log_callback(whisper_log_callback callback);
WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);
#ifdef __cplusplus
}