Add support for cohere endpoints (#976)

* Add support for cohere endpoints

* readme update
pull/977/head
Nathan Sarrazin 2024-04-04 16:02:28 +02:00 committed by GitHub
parent 0819256ea4
commit a99cca3f95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 352 additions and 0 deletions

1
.env
View File

@ -13,6 +13,7 @@ OPENAI_API_KEY=#your openai api key here
ANTHROPIC_API_KEY=#your anthropic api key here
CLOUDFLARE_ACCOUNT_ID=#your cloudflare account id here
CLOUDFLARE_API_TOKEN=#your cloudflare api token here
COHERE_API_TOKEN=#your cohere api token here
HF_ACCESS_TOKEN=#LEGACY! Use HF_TOKEN instead

View File

@ -560,6 +560,28 @@ You can find the list of models available on Cloudflare [here](https://developer
> [!NOTE]
> Cloudlare Workers AI currently do not support custom sampling parameters like temperature, top_p, etc.
#### Cohere
You can also use Cohere to run their models directly from chat-ui. You will need to have a Cohere account, then get your [API token](https://dashboard.cohere.com/api-keys). You can either specify it directly in your `.env.local` using the `COHERE_API_TOKEN` variable, or you can set it in the endpoint config.
Here is an example of a Cohere model config. You can set which model you want to use by setting the `id` field to the model name.
```env
{
"name" : "CohereForAI/c4ai-command-r-v01",
"id": "command-r",
"description": "C4AI Command-R is a research release of a 35 billion parameter highly performant generative model",
"endpoints": [
{
"type": "cohere",
<!-- optionally specify these, or use COHERE_API_TOKEN
"apiKey": "your-api-token"
-->
}
]
}
```
##### Google Vertex models
Chat UI can connect to the google Vertex API endpoints ([List of supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models)).

209
package-lock.json generated
View File

@ -75,6 +75,7 @@
"@anthropic-ai/sdk": "^0.17.1",
"@google-cloud/vertexai": "^0.5.0",
"aws4fetch": "^1.0.17",
"cohere-ai": "^7.9.0",
"openai": "^4.14.2"
}
},
@ -2881,6 +2882,25 @@
"node": ">=8"
}
},
"node_modules/call-bind": {
"version": "1.0.7",
"resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz",
"integrity": "sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==",
"optional": true,
"dependencies": {
"es-define-property": "^1.0.0",
"es-errors": "^1.3.0",
"function-bind": "^1.1.2",
"get-intrinsic": "^1.2.4",
"set-function-length": "^1.2.1"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/callsites": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz",
@ -3041,6 +3061,19 @@
"@types/estree": "^1.0.0"
}
},
"node_modules/cohere-ai": {
"version": "7.9.0",
"resolved": "https://registry.npmjs.org/cohere-ai/-/cohere-ai-7.9.0.tgz",
"integrity": "sha512-iHPG4dule+nMlw88Xe0USGZbLlXuRC4yvOvfCqoEdW9tHOc0vkiPfiyjBalDcPKj9KEvWZfii84kyN5HyTMySw==",
"optional": true,
"dependencies": {
"form-data": "4.0.0",
"js-base64": "3.7.2",
"node-fetch": "2.7.0",
"qs": "6.11.2",
"url-join": "4.0.1"
}
},
"node_modules/color": {
"version": "4.2.3",
"resolved": "https://registry.npmjs.org/color/-/color-4.2.3.tgz",
@ -3358,6 +3391,23 @@
"node": ">=0.10.0"
}
},
"node_modules/define-data-property": {
"version": "1.1.4",
"resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz",
"integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==",
"optional": true,
"dependencies": {
"es-define-property": "^1.0.0",
"es-errors": "^1.3.0",
"gopd": "^1.0.1"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/delayed-stream": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz",
@ -3508,6 +3558,27 @@
"url": "https://github.com/fb55/entities?sponsor=1"
}
},
"node_modules/es-define-property": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.0.tgz",
"integrity": "sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==",
"optional": true,
"dependencies": {
"get-intrinsic": "^1.2.4"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/es-errors": {
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz",
"integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==",
"optional": true,
"engines": {
"node": ">= 0.4"
}
},
"node_modules/es6-promise": {
"version": "3.3.1",
"resolved": "https://registry.npmjs.org/es6-promise/-/es6-promise-3.3.1.tgz",
@ -4107,6 +4178,25 @@
"node": "*"
}
},
"node_modules/get-intrinsic": {
"version": "1.2.4",
"resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.4.tgz",
"integrity": "sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==",
"optional": true,
"dependencies": {
"es-errors": "^1.3.0",
"function-bind": "^1.1.2",
"has-proto": "^1.0.1",
"has-symbols": "^1.0.3",
"hasown": "^2.0.0"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/get-stream": {
"version": "6.0.1",
"resolved": "https://registry.npmjs.org/get-stream/-/get-stream-6.0.1.tgz",
@ -4219,6 +4309,18 @@
"node": ">=14"
}
},
"node_modules/gopd": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz",
"integrity": "sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==",
"optional": true,
"dependencies": {
"get-intrinsic": "^1.1.3"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/graceful-fs": {
"version": "4.2.11",
"resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz",
@ -4284,6 +4386,42 @@
"node": ">=8"
}
},
"node_modules/has-property-descriptors": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz",
"integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==",
"optional": true,
"dependencies": {
"es-define-property": "^1.0.0"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/has-proto": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.3.tgz",
"integrity": "sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==",
"optional": true,
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/has-symbols": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.0.3.tgz",
"integrity": "sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==",
"optional": true,
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/hash-wasm": {
"version": "4.9.0",
"resolved": "https://registry.npmjs.org/hash-wasm/-/hash-wasm-4.9.0.tgz",
@ -4625,6 +4763,12 @@
"url": "https://github.com/sponsors/panva"
}
},
"node_modules/js-base64": {
"version": "3.7.2",
"resolved": "https://registry.npmjs.org/js-base64/-/js-base64-3.7.2.tgz",
"integrity": "sha512-NnRs6dsyqUXejqk/yv2aiXlAvOs56sLkX6nUdeaNezI5LFFLlsZjOThmwnrcwh5ZZRwZlCMnVAY3CvhIhoVEKQ==",
"optional": true
},
"node_modules/js-sdsl": {
"version": "4.3.0",
"resolved": "https://registry.npmjs.org/js-sdsl/-/js-sdsl-4.3.0.tgz",
@ -5426,6 +5570,15 @@
"node": ">= 6"
}
},
"node_modules/object-inspect": {
"version": "1.13.1",
"resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.1.tgz",
"integrity": "sha512-5qoj1RUiKOMsCCNLV1CBiPYE10sziTsnmNxkAI/rZhiD63CF7IqdFGC/XzjWjpSgLf0LxXX3bDFIh0E18f6UhQ==",
"optional": true,
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/object-stream": {
"version": "0.0.1",
"resolved": "https://registry.npmjs.org/object-stream/-/object-stream-0.0.1.tgz",
@ -6225,6 +6378,21 @@
"teleport": ">=0.2.0"
}
},
"node_modules/qs": {
"version": "6.11.2",
"resolved": "https://registry.npmjs.org/qs/-/qs-6.11.2.tgz",
"integrity": "sha512-tDNIz22aBzCDxLtVH++VnTfzxlfeK5CbqohpSqpJgj1Wg/cQbStNAz3NuqCs5vV+pjBsK4x4pN9HlVh7rcYRiA==",
"optional": true,
"dependencies": {
"side-channel": "^1.0.4"
},
"engines": {
"node": ">=0.6"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/querystringify": {
"version": "2.2.0",
"resolved": "https://registry.npmjs.org/querystringify/-/querystringify-2.2.0.tgz",
@ -6557,6 +6725,23 @@
"integrity": "sha512-RVnVQxTXuerk653XfuliOxBP81Sf0+qfQE73LIYKcyMYHG94AuH0kgrQpRDuTZnSmjpysHmzxJXKNfa6PjFhyQ==",
"dev": true
},
"node_modules/set-function-length": {
"version": "1.2.2",
"resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz",
"integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==",
"optional": true,
"dependencies": {
"define-data-property": "^1.1.4",
"es-errors": "^1.3.0",
"function-bind": "^1.1.2",
"get-intrinsic": "^1.2.4",
"gopd": "^1.0.1",
"has-property-descriptors": "^1.0.2"
},
"engines": {
"node": ">= 0.4"
}
},
"node_modules/sharp": {
"version": "0.33.2",
"resolved": "https://registry.npmjs.org/sharp/-/sharp-0.33.2.tgz",
@ -6617,6 +6802,24 @@
"node": ">=8"
}
},
"node_modules/side-channel": {
"version": "1.0.6",
"resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.6.tgz",
"integrity": "sha512-fDW/EZ6Q9RiO8eFG8Hj+7u/oW+XrPTIChwCOM2+th2A6OblDtYYIpve9m+KvI9Z4C9qSEXlaGR6bTEYHReuglA==",
"optional": true,
"dependencies": {
"call-bind": "^1.0.7",
"es-errors": "^1.3.0",
"get-intrinsic": "^1.2.4",
"object-inspect": "^1.13.1"
},
"engines": {
"node": ">= 0.4"
},
"funding": {
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/siginfo": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/siginfo/-/siginfo-2.0.0.tgz",
@ -7647,6 +7850,12 @@
"punycode": "^2.1.0"
}
},
"node_modules/url-join": {
"version": "4.0.1",
"resolved": "https://registry.npmjs.org/url-join/-/url-join-4.0.1.tgz",
"integrity": "sha512-jk1+QP6ZJqyOiuEI9AEWQfju/nB2Pw466kbA0LEZljHwKeMgd9WrAEgEGxjPDD2+TNbbb37rTyhEfrCXfuKXnA==",
"optional": true
},
"node_modules/url-parse": {
"version": "1.5.10",
"resolved": "https://registry.npmjs.org/url-parse/-/url-parse-1.5.10.tgz",

View File

@ -85,6 +85,7 @@
"@anthropic-ai/sdk": "^0.17.1",
"@google-cloud/vertexai": "^0.5.0",
"aws4fetch": "^1.0.17",
"cohere-ai": "^7.9.0",
"openai": "^4.14.2"
}
}

View File

@ -0,0 +1,114 @@
import { z } from "zod";
import { COHERE_API_TOKEN } from "$env/static/private";
import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import type { Cohere, CohereClient } from "cohere-ai";
import { buildPrompt } from "$lib/buildPrompt";
export const endpointCohereParametersSchema = z.object({
weight: z.number().int().positive().default(1),
model: z.any(),
type: z.literal("cohere"),
apiKey: z.string().default(COHERE_API_TOKEN),
raw: z.boolean().default(false),
});
export async function endpointCohere(
input: z.input<typeof endpointCohereParametersSchema>
): Promise<Endpoint> {
const { apiKey, model, raw } = endpointCohereParametersSchema.parse(input);
let cohere: CohereClient;
try {
cohere = new (await import("cohere-ai")).CohereClient({
token: apiKey,
});
} catch (e) {
throw new Error("Failed to import @anthropic-ai/sdk", { cause: e });
}
return async ({ messages, preprompt, generateSettings, continueMessage }) => {
let system = preprompt;
if (messages?.[0]?.from === "system") {
system = messages[0].content;
}
const parameters = { ...model.parameters, ...generateSettings };
return (async function* () {
let stream;
let tokenId = 0;
if (raw) {
const prompt = await buildPrompt({
messages: messages.filter((message) => message.from !== "system"),
model,
preprompt: system,
continueMessage,
});
stream = await cohere.chatStream({
message: prompt,
rawPrompting: true,
model: model.id ?? model.name,
p: parameters?.top_p,
k: parameters?.top_k,
maxTokens: parameters?.max_new_tokens,
temperature: parameters?.temperature,
stopSequences: parameters?.stop,
frequencyPenalty: parameters?.frequency_penalty,
});
} else {
const formattedMessages = messages
.filter((message) => message.from !== "system")
.map((message) => ({
role: message.from === "user" ? "USER" : "CHATBOT",
message: message.content,
})) satisfies Cohere.ChatMessage[];
stream = await cohere.chatStream({
model: model.id ?? model.name,
chatHistory: formattedMessages.slice(0, -1),
message: formattedMessages[formattedMessages.length - 1].message,
preamble: system,
p: parameters?.top_p,
k: parameters?.top_k,
maxTokens: parameters?.max_new_tokens,
temperature: parameters?.temperature,
stopSequences: parameters?.stop,
frequencyPenalty: parameters?.frequency_penalty,
});
}
for await (const output of stream) {
if (output.eventType === "text-generation") {
yield {
token: {
id: tokenId++,
text: output.text,
logprob: 0,
special: false,
},
generated_text: null,
details: null,
} satisfies TextGenerationStreamOutput;
} else if (output.eventType === "stream-end") {
if (["ERROR", "ERROR_TOXIC", "ERROR_LIMIT"].includes(output.finishReason)) {
throw new Error(output.finishReason);
}
yield {
token: {
id: tokenId++,
text: "",
logprob: 0,
special: true,
},
generated_text: output.response.text,
details: null,
};
}
}
})();
};
}

View File

@ -16,6 +16,7 @@ import type { Model } from "$lib/types/Model";
import endpointCloudflare, {
endpointCloudflareParametersSchema,
} from "./cloudflare/endpointCloudflare";
import { endpointCohere, endpointCohereParametersSchema } from "./cohere/endpointCohere";
// parameters passed when generating text
export interface EndpointParameters {
@ -46,6 +47,7 @@ export const endpoints = {
ollama: endpointOllama,
vertex: endpointVertex,
cloudflare: endpointCloudflare,
cohere: endpointCohere,
};
export const endpointSchema = z.discriminatedUnion("type", [
@ -57,5 +59,6 @@ export const endpointSchema = z.discriminatedUnion("type", [
endpointOllamaParametersSchema,
endpointVertexParametersSchema,
endpointCloudflareParametersSchema,
endpointCohereParametersSchema,
]);
export default endpoints;

View File

@ -174,6 +174,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
return await endpoints.vertex(args);
case "cloudflare":
return await endpoints.cloudflare(args);
case "cohere":
return await endpoints.cohere(args);
default:
// for legacy reason
return endpoints.tgi(args);