Support dup & cont ops on CUDA (#2242)

Jiahao Li 2023-07-18 01:39:29 +08:00 committed by GitHub
parent b7647436cc
commit 7568d1a2b2


@@ -3537,6 +3537,11 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     (void) dst;
 }
 
+void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_cpy(src0, dst, nullptr);
+    (void) src1;
+}
+
 void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true);
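
Note on the forwarding above (annotation, not part of the diff): ggml_cuda_cpy copies its first argument into its second and ignores its third, as the `(void) dst;` context line shows. The output tensor of a DUP or CONT node is allocated contiguously, so passing that output as the copy target is all a dup needs:

    // ggml_cuda_cpy(src, copy_target, /*ignored*/ dst) writes `src` into
    // `copy_target` using the target's own shape and strides. For DUP the
    // node's output tensor is the copy target, and src1 carries no
    // meaning, hence the (void) cast.
    ggml_cuda_cpy(src0, dst, nullptr);
    (void) src1;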
@@ -3670,7 +3675,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     // recursively assign CUDA buffers until a compute tensor is found
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
-        if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
+        if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
             ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
         }
     }
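
Why PERMUTE joins RESHAPE, TRANSPOSE, and VIEW here: all four ops produce views that alias their source's data instead of owning a buffer, so the buffer-assignment pass must look through them to reach the tensor that actually holds the data. A minimal sketch of the aliasing being relied on (illustrative; assumes a valid ggml context `ctx`):

    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 8, 4, 2);
    struct ggml_tensor * p = ggml_permute(ctx, a, 1, 0, 2, 3); // view: p->data aliases a->data
    // A node whose src[0] is `p` now recurses through the permute so the
    // CUDA buffer is assigned to `a`, the data-owning tensor.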
@@ -3776,6 +3781,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
 
     switch (tensor->op) {
+        case GGML_OP_DUP:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_dup;
+            break;
         case GGML_OP_ADD:
             if (!any_on_device) {
                 return false;
@@ -3830,6 +3841,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_cpy;
             break;
+        case GGML_OP_CONT:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_dup;
+            break;
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
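
With the two dispatch cases in place, DUP and CONT nodes on GPU-resident tensors are handled by ggml_cuda_dup instead of falling back to the CPU. A minimal usage sketch (illustrative and simplified; assumes a valid ggml context `ctx` and a CUDA build; in llama.cpp the offload calls happen during graph construction):

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 32);
    struct ggml_tensor * t = ggml_transpose(ctx, a); // non-contiguous view
    struct ggml_tensor * c = ggml_cont(ctx, t);      // GGML_OP_CONT node
    ggml_cuda_assign_buffers(c);                     // mark `c` for the CUDA backend
    // During ggml_cuda_compute_forward, the GGML_OP_CONT case now selects
    // func = ggml_cuda_dup, which copies `t` into `c`'s contiguous buffer.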