diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 8c2712308..2e759d43e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -10039,14 +10039,22 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten } return false; } break; + case GGML_OP_DUP: + case GGML_OP_REPEAT: + case GGML_OP_CONCAT: + { + ggml_type src0_type = op->src[0]->type; + if (src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16) { + return true; + } + return false; + } break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: case GGML_OP_NORM: - case GGML_OP_REPEAT: - case GGML_OP_DUP: case GGML_OP_ADD: case GGML_OP_MUL: case GGML_OP_DIV: @@ -10063,7 +10071,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: - case GGML_OP_CONCAT: case GGML_OP_GROUP_NORM: case GGML_OP_UPSCALE: case GGML_OP_PAD: