diff --git a/ggml-metal.m b/ggml-metal.m index 912ddc8..4b3eb49 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -668,7 +668,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: - return ctx->support_simdgroup_reduction; + return ctx->support_simdgroup_reduction && + (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32); case GGML_OP_CPY: case GGML_OP_DUP: case GGML_OP_CONT: