From 7568d1a2b206331412106ea66da3f871025e0c3c Mon Sep 17 00:00:00 2001 From: Jiahao Li Date: Tue, 18 Jul 2023 01:39:29 +0800 Subject: [PATCH] Support dup & cont ops on CUDA (#2242) --- ggml-cuda.cu | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0646fa7b2..d3054a7fa 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -3537,6 +3537,11 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens (void) dst; } +void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { /* forwards to ggml_cuda_cpy with dst in the second-argument position and nullptr as the third argument, which the visible tail of ggml_cuda_cpy discards via (void) dst */ + ggml_cuda_cpy(src0, dst, nullptr); + (void) src1; /* src1 is not used by this wrapper; the cast only marks it as such */ +} + void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true); @@ -3670,7 +3675,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo // recursively assign CUDA buffers until a compute tensor is found if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) { const ggml_op src0_op = tensor->src[0]->op; - if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) { + if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) { /* this patch adds GGML_OP_PERMUTE to the src0 ops that are recursed into */ ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace); } } @@ -3776,6 +3781,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU); switch (tensor->op) { + case GGML_OP_DUP: /* new case: DUP is dispatched to ggml_cuda_dup */ + if (!any_on_device) { /* returns false when any_on_device is false, like the ADD case below */ + return false; + } + func = ggml_cuda_dup; + break; case GGML_OP_ADD: if (!any_on_device) { return false; @@ -3830,6 +3841,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ } func = ggml_cuda_cpy; break; + case GGML_OP_CONT: + 
if (!any_on_device) { + return false; + } + func = ggml_cuda_dup; /* CONT is routed to the same ggml_cuda_dup wrapper as DUP */ + break; case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: