Use jinja template for chat formatting (#730) (#744)

* Use jinja template for chat formatting

* Add support for transformers js chat template

* update to latest transformers version

* Make sure to `add_generation_prompt`

* unindent
Nathan Sarrazin 2024-04-04 12:38:54 +02:00, committed by GitHub
parent e02792165b
commit 0819256ea4
9 changed files with 102 additions and 58 deletions
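At a glance: chat prompts are no longer built from hand-written Handlebars templates by default. When a model defines no `chatPromptTemplate`, chat-ui now renders the Jinja chat template bundled with the model's tokenizer. A minimal sketch of the transformers.js API this relies on (model id and conversation are illustrative, not taken from this commit):

```ts
import { AutoTokenizer } from "@xenova/transformers";

// Load a tokenizer from the Hub; it carries its own Jinja chat template.
const tokenizer = await AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2");

const prompt = tokenizer.apply_chat_template(
  [
    { role: "user", content: "Hello!" },
    { role: "assistant", content: "Hi! How can I help?" },
    { role: "user", content: "Tell me a joke." },
  ],
  // tokenize: false returns the rendered string instead of token ids;
  // add_generation_prompt appends the tokens that cue the assistant's reply.
  { tokenize: false, add_generation_prompt: true }
);
// `prompt` ends with the model's assistant cue, e.g. a trailing "[/INST]" for Mistral.
```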

File: .env.template

@@ -64,7 +64,7 @@ MODELS=`[
       "description": "The latest and biggest model from Meta, fine-tuned for chat.",
       "logoUrl": "https://huggingface.co/datasets/huggingchat/models-logo/resolve/main/meta-logo.png",
       "websiteUrl": "https://ai.meta.com/llama/",
-      "preprompt": " ",
+      "preprompt": "",
       "chatPromptTemplate" : "<s>[INST] <<SYS>>\n{{preprompt}}\n<</SYS>>\n\n{{#each messages}}{{#ifUser}}{{content}} [/INST] {{/ifUser}}{{#ifAssistant}}{{content}} </s><s>[INST] {{/ifAssistant}}{{/each}}",
       "promptExamples": [
         {
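For contrast, the legacy Handlebars path above renders prompts like this. A self-contained sketch that re-registers the `ifUser`/`ifAssistant` helpers the repo defines in its template utility (conversation invented):

```ts
import Handlebars from "handlebars";

// Mirrors of the helpers registered in src/lib/utils/template.ts.
Handlebars.registerHelper("ifUser", function (this: { from: string }, options: Handlebars.HelperOptions) {
  return this.from === "user" ? options.fn(this) : "";
});
Handlebars.registerHelper("ifAssistant", function (this: { from: string }, options: Handlebars.HelperOptions) {
  return this.from === "assistant" ? options.fn(this) : "";
});

const render = Handlebars.compile(
  "<s>[INST] <<SYS>>\n{{preprompt}}\n<</SYS>>\n\n{{#each messages}}{{#ifUser}}{{content}} [/INST] {{/ifUser}}{{#ifAssistant}}{{content}} </s><s>[INST] {{/ifAssistant}}{{/each}}",
  { noEscape: true }
);

const prompt = render({
  preprompt: "",
  messages: [
    { from: "user", content: "Hello" },
    { from: "assistant", content: "Hi there" },
    { from: "user", content: "How are you?" },
  ],
});
// prompt === "<s>[INST] <<SYS>>\n\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] How are you? [/INST] "
```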

File: package-lock.json (generated)

@@ -12,7 +12,7 @@
       "@huggingface/inference": "^2.6.3",
       "@iconify-json/bi": "^1.1.21",
       "@resvg/resvg-js": "^2.6.0",
-      "@xenova/transformers": "^2.6.0",
+      "@xenova/transformers": "^2.16.1",
       "autoprefixer": "^10.4.14",
       "browser-image-resizer": "^2.4.1",
       "date-fns": "^2.29.3",
@@ -660,6 +660,14 @@
         "node": ">=18"
       }
     },
+    "node_modules/@huggingface/jinja": {
+      "version": "0.2.2",
+      "resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.2.2.tgz",
+      "integrity": "sha512-/KPde26khDUIPkTGU82jdtTW9UAuvUTumCAbFs/7giR0SxsvZC4hru51PBvpijH6BVkHcROcvZM/lpy5h1jRRA==",
+      "engines": {
+        "node": ">=18"
+      }
+    },
     "node_modules/@humanwhocodes/config-array": {
       "version": "0.11.8",
       "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.8.tgz",
@@ -2407,10 +2415,11 @@
       }
     },
     "node_modules/@xenova/transformers": {
-      "version": "2.6.0",
-      "resolved": "https://registry.npmjs.org/@xenova/transformers/-/transformers-2.6.0.tgz",
-      "integrity": "sha512-k9bs+reiwhn+kx0d4FYnlBTWtl8D5Q4fIzoKYxKbTTSVyS33KXbQESRpdIxiU9gtlMKML2Sw0Oep4FYK9dQCsQ==",
+      "version": "2.16.1",
+      "resolved": "https://registry.npmjs.org/@xenova/transformers/-/transformers-2.16.1.tgz",
+      "integrity": "sha512-p2ii7v7oC3Se0PC012dn4vt196GCroaN5ngOYJYkfg0/ce8A5frsrnnnktOBJuejG3bW5Hreb7JZ/KxtUaKd8w==",
       "dependencies": {
+        "@huggingface/jinja": "^0.2.2",
         "onnxruntime-web": "1.14.0",
         "sharp": "^0.32.0"
       },

File: package.json

@@ -54,7 +54,7 @@
     "@huggingface/inference": "^2.6.3",
     "@iconify-json/bi": "^1.1.21",
     "@resvg/resvg-js": "^2.6.0",
-    "@xenova/transformers": "^2.6.0",
+    "@xenova/transformers": "^2.16.1",
     "autoprefixer": "^10.4.14",
    "browser-image-resizer": "^2.4.1",
     "date-fns": "^2.29.3",
@@ -83,8 +83,8 @@
   },
   "optionalDependencies": {
     "@anthropic-ai/sdk": "^0.17.1",
-    "@google-cloud/vertexai": "^0.5.0",
     "aws4fetch": "^1.0.17",
-    "openai": "^4.14.2",
+    "@google-cloud/vertexai": "^0.5.0",
+    "openai": "^4.14.2"
   }
 }
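The version bump pulls in @huggingface/jinja, the small Jinja engine transformers.js uses to evaluate chat templates. Used in isolation it looks roughly like this (the template string is a toy, not any real model's):

```ts
import { Template } from "@huggingface/jinja";

// Toy chat template for illustration; real ones ship in tokenizer_config.json.
const template = new Template(
  "{% for message in messages %}{{ message.role }}: {{ message.content }}\n{% endfor %}" +
    "{% if add_generation_prompt %}assistant:{% endif %}"
);

const rendered = template.render({
  messages: [{ role: "user", content: "Hello!" }],
  add_generation_prompt: true,
});
// rendered === "user: Hello!\nassistant:"
```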

File: src/lib/components/TokensCounter.svelte

@@ -1,6 +1,7 @@
 <script lang="ts">
   import type { Model } from "$lib/types/Model";
-  import { AutoTokenizer, PreTrainedTokenizer } from "@xenova/transformers";
+  import { getTokenizer } from "$lib/utils/getTokenizer";
+  import type { PreTrainedTokenizer } from "@xenova/transformers";
 
   export let classNames = "";
   export let prompt = "";
@@ -9,23 +10,6 @@
   let tokenizer: PreTrainedTokenizer | undefined = undefined;
 
-  async function getTokenizer(_modelTokenizer: Exclude<Model["tokenizer"], undefined>) {
-    if (typeof _modelTokenizer === "string") {
-      // return auto tokenizer
-      return await AutoTokenizer.from_pretrained(_modelTokenizer);
-    }
-    {
-      // construct & return pretrained tokenizer
-      const { tokenizerUrl, tokenizerConfigUrl } = _modelTokenizer satisfies {
-        tokenizerUrl: string;
-        tokenizerConfigUrl: string;
-      };
-      const tokenizerJSON = await (await fetch(tokenizerUrl)).json();
-      const tokenizerConfig = await (await fetch(tokenizerConfigUrl)).json();
-      return new PreTrainedTokenizer(tokenizerJSON, tokenizerConfig);
-    }
-  }
-
   async function tokenizeText(_prompt: string) {
     if (!tokenizer) {
       return;
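The token counter now goes through the shared helper instead of its private copy; roughly, the counting boils down to this (prop values are stand-ins for the component's `model.tokenizer` and `prompt`):

```ts
import { getTokenizer } from "$lib/utils/getTokenizer";

// Stand-in values; in the component these come from props.
const modelTokenizer = "mistralai/Mistral-7B-Instruct-v0.2"; // hypothetical example id
const prompt = "Hello world";

const tokenizer = await getTokenizer(modelTokenizer);
const tokenCount = tokenizer.encode(prompt).length; // what the counter displays
```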

File: src/lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints.ts

@@ -1,6 +1,6 @@
 import { z } from "zod";
 import type { EmbeddingEndpoint } from "../embeddingEndpoints";
-import type { Tensor, Pipeline } from "@xenova/transformers";
+import type { Tensor, FeatureExtractionPipeline } from "@xenova/transformers";
 import { pipeline } from "@xenova/transformers";
@@ -11,9 +11,9 @@ export const embeddingEndpointTransformersJSParametersSchema = z.object({
 // Use the Singleton pattern to enable lazy construction of the pipeline.
 class TransformersJSModelsSingleton {
-  static instances: Array<[string, Promise<Pipeline>]> = [];
+  static instances: Array<[string, Promise<FeatureExtractionPipeline>]> = [];
 
-  static async getInstance(modelName: string): Promise<Pipeline> {
+  static async getInstance(modelName: string): Promise<FeatureExtractionPipeline> {
     const modelPipelineInstance = this.instances.find(([name]) => name === modelName);
 
     if (modelPipelineInstance) {
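Recent transformers.js releases expose per-task pipeline types, so the singleton can return a precisely typed FeatureExtractionPipeline instead of the catch-all Pipeline. A sketch of the typed call site (model name is an example):

```ts
import { pipeline } from "@xenova/transformers";

// pipeline("feature-extraction", ...) resolves to a FeatureExtractionPipeline,
// so the call below is fully typed. Model name is an example.
const extractor = await pipeline("feature-extraction", "Xenova/gte-small");

const embeddings = await extractor(["first sentence", "second sentence"], {
  pooling: "mean",
  normalize: true,
});
// embeddings is a Tensor of shape [2, hiddenSize]
```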

File: src/lib/server/models.ts

@@ -14,7 +14,10 @@ import endpointTgi from "./endpoints/tgi/endpointTgi";
 import { sum } from "$lib/utils/sum";
 import { embeddingModels, validateEmbeddingModelByName } from "./embeddingModels";
+import type { PreTrainedTokenizer } from "@xenova/transformers";
 
 import JSON5 from "json5";
+import { getTokenizer } from "$lib/utils/getTokenizer";
 
 type Optional<T, K extends keyof T> = Pick<Partial<T>, K> & Omit<T, K>;
@@ -39,23 +42,9 @@ const modelConfig = z.object({
     .optional(),
   datasetName: z.string().min(1).optional(),
   datasetUrl: z.string().url().optional(),
-  userMessageToken: z.string().default(""),
-  userMessageEndToken: z.string().default(""),
-  assistantMessageToken: z.string().default(""),
-  assistantMessageEndToken: z.string().default(""),
-  messageEndToken: z.string().default(""),
   preprompt: z.string().default(""),
   prepromptUrl: z.string().url().optional(),
-  chatPromptTemplate: z
-    .string()
-    .default(
-      "{{preprompt}}" +
-        "{{#each messages}}" +
-        "{{#ifUser}}{{@root.userMessageToken}}{{content}}{{@root.userMessageEndToken}}{{/ifUser}}" +
-        "{{#ifAssistant}}{{@root.assistantMessageToken}}{{content}}{{@root.assistantMessageEndToken}}{{/ifAssistant}}" +
-        "{{/each}}" +
-        "{{assistantMessageToken}}"
-    ),
+  chatPromptTemplate: z.string().optional(),
   promptExamples: z
     .array(
       z.object({
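Since `chatPromptTemplate` is now optional, a MODELS entry can omit it and point at a tokenizer instead, letting the tokenizer's built-in chat template take over. A minimal hypothetical entry (model id illustrative; other schema fields omitted):

MODELS=`[
  {
    "name": "mistralai/Mistral-7B-Instruct-v0.2",
    "tokenizer": "mistralai/Mistral-7B-Instruct-v0.2",
    "preprompt": ""
  }
]`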
@@ -84,11 +73,64 @@
 const modelsRaw = z.array(modelConfig).parse(JSON5.parse(MODELS));
 
+async function getChatPromptRender(
+  m: z.infer<typeof modelConfig>
+): Promise<ReturnType<typeof compileTemplate<ChatTemplateInput>>> {
+  if (m.chatPromptTemplate) {
+    return compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, m);
+  }
+
+  let tokenizer: PreTrainedTokenizer;
+
+  if (!m.tokenizer) {
+    throw new Error(
+      "No tokenizer specified and no chat prompt template specified for model " + m.name
+    );
+  }
+
+  try {
+    tokenizer = await getTokenizer(m.tokenizer);
+  } catch (e) {
+    throw Error(
+      "Failed to load tokenizer for model " +
+        m.name +
+        " consider setting chatPromptTemplate manually or making sure the model is available on the hub."
+    );
+  }
+
+  const renderTemplate = ({ messages, preprompt }: ChatTemplateInput) => {
+    let formattedMessages: { role: string; content: string }[] = messages.map((message) => ({
+      content: message.content,
+      role: message.from,
+    }));
+
+    if (preprompt) {
+      formattedMessages = [
+        {
+          role: "system",
+          content: preprompt,
+        },
+        ...formattedMessages,
+      ];
+    }
+
+    const output = tokenizer.apply_chat_template(formattedMessages, {
+      tokenize: false,
+      add_generation_prompt: true,
+    });
+
+    if (typeof output !== "string") {
+      throw new Error("Failed to apply chat template, the output is not a string");
+    }
+
+    return output;
+  };
+
+  return renderTemplate;
+}
+
 const processModel = async (m: z.infer<typeof modelConfig>) => ({
   ...m,
-  userMessageEndToken: m?.userMessageEndToken || m?.messageEndToken,
-  assistantMessageEndToken: m?.assistantMessageEndToken || m?.messageEndToken,
-  chatPromptRender: compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, m),
+  chatPromptRender: await getChatPromptRender(m),
   id: m.id || m.name,
   displayName: m.displayName || m.name,
   preprompt: m.prepromptUrl ? await fetch(m.prepromptUrl).then((r) => r.text()) : m.preprompt,
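End to end, the new render path behaves like this sketch, assuming the module exports the processed array as `models` (conversation invented):

```ts
import { models } from "$lib/server/models";

// Sketch only; `models[0]` stands for any processed model entry.
const model = models[0];

const prompt = model.chatPromptRender({
  preprompt: "You are a helpful assistant.",
  messages: [
    { from: "user", content: "What is Jinja?" },
    { from: "assistant", content: "A templating language." },
    { from: "user", content: "Show me an example." },
  ],
});
// `preprompt` becomes a leading system message, `from` is remapped to `role`,
// and the tokenizer's chat template renders the final model-specific string,
// ending with the assistant generation cue (add_generation_prompt: true).
```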

File: src/lib/types/Template.ts

@@ -1,13 +1,5 @@
 import type { Message } from "./Message";
 
-export type LegacyParamatersTemplateInput = {
-  preprompt?: string;
-  userMessageToken: string;
-  userMessageEndToken: string;
-  assistantMessageToken: string;
-  assistantMessageEndToken: string;
-};
-
 export type ChatTemplateInput = {
   messages: Pick<Message, "from" | "content">[];
   preprompt?: string;
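After the cleanup, ChatTemplateInput is the only template input type left. An example value (contents invented):

```ts
import type { ChatTemplateInput } from "$lib/types/Template";

// Exactly the shape chatPromptRender receives.
const input: ChatTemplateInput = {
  preprompt: "You are concise.",
  messages: [
    { from: "user", content: "Hi" },
    { from: "assistant", content: "Hello! How can I help?" },
  ],
};
```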

File: src/lib/utils/getTokenizer.ts (new file)

@@ -0,0 +1,18 @@
+import type { Model } from "$lib/types/Model";
+import { AutoTokenizer, PreTrainedTokenizer } from "@xenova/transformers";
+
+export async function getTokenizer(_modelTokenizer: Exclude<Model["tokenizer"], undefined>) {
+  if (typeof _modelTokenizer === "string") {
+    // return auto tokenizer
+    return await AutoTokenizer.from_pretrained(_modelTokenizer);
+  } else {
+    // construct & return pretrained tokenizer
+    const { tokenizerUrl, tokenizerConfigUrl } = _modelTokenizer satisfies {
+      tokenizerUrl: string;
+      tokenizerConfigUrl: string;
+    };
+    const tokenizerJSON = await (await fetch(tokenizerUrl)).json();
+    const tokenizerConfig = await (await fetch(tokenizerConfigUrl)).json();
+    return new PreTrainedTokenizer(tokenizerJSON, tokenizerConfig);
+  }
+}
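The new helper accepts either a Hub model id or a pair of URLs. Both call forms, sketched (ids and URLs are placeholders):

```ts
import { getTokenizer } from "$lib/utils/getTokenizer";

// String form: resolve a tokenizer from the Hugging Face Hub (id is an example).
const fromHub = await getTokenizer("mistralai/Mistral-7B-Instruct-v0.2");

// Object form: fetch tokenizer.json and tokenizer_config.json directly (placeholder URLs).
const fromUrls = await getTokenizer({
  tokenizerUrl: "https://example.com/tokenizer.json",
  tokenizerConfigUrl: "https://example.com/tokenizer_config.json",
});
```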

File: src/lib/utils/template.ts

@@ -1,5 +1,4 @@
 import type { Message } from "$lib/types/Message";
-import type { LegacyParamatersTemplateInput } from "$lib/types/Template";
 import Handlebars from "handlebars";
 
 Handlebars.registerHelper("ifUser", function (this: Pick<Message, "from" | "content">, options) {
@@ -13,8 +12,8 @@ Handlebars.registerHelper(
   }
 );
 
-export function compileTemplate<T>(input: string, model: LegacyParamatersTemplateInput) {
-  const template = Handlebars.compile<T & LegacyParamatersTemplateInput>(input, {
+export function compileTemplate<T>(input: string, model: { preprompt: string }) {
+  const template = Handlebars.compile<T>(input, {
     knownHelpers: { ifUser: true, ifAssistant: true },
     knownHelpersOnly: true,
     noEscape: true,
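With LegacyParamatersTemplateInput gone, compileTemplate only threads `preprompt` through. A usage sketch, assuming the compiled function spreads the model fields into the template context as the original code does (template string invented):

```ts
import { compileTemplate } from "$lib/utils/template";
import type { ChatTemplateInput } from "$lib/types/Template";

// Invented template using the registered ifUser/ifAssistant helpers.
const render = compileTemplate<ChatTemplateInput>(
  "{{preprompt}}{{#each messages}}{{#ifUser}}<user>{{content}}</user>{{/ifUser}}{{#ifAssistant}}<bot>{{content}}</bot>{{/ifAssistant}}{{/each}}",
  { preprompt: "" }
);

const out = render({ messages: [{ from: "user", content: "Hi" }] });
// out === "<user>Hi</user>"
```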