Skip to content

Commit

Permalink
🚀 @huggingface/jinja 100% chat template coverage. (#536)
Browse files Browse the repository at this point in the history
With this PR, I'm pleased to announce that `@huggingface/jinja` is now
compatible with every single _valid_[^1] chat template on the HF Hub (as
of 2024/03/05), including the very complex function calling ones (like
[fireworks-ai/firefunction-v1](https://huggingface.co/fireworks-ai/firefunction-v1))! 🥳
We also match the Python output exactly in each case! Interestingly,
of the ~11k public conversational models, there were only ~250 unique
templates. This will also ensure all models tagged with `conversational`
will have working conversational widgets!

[^1]: Only 1 invalid chat template was discovered (containing invalid
quote characters), and a PR has been made for it
[here](https://huggingface.co/YokaiKoibito/llama2_70b_chat_uncensored-fp16/discussions/1#65e8686f034b83aeb8d40528).


Of course, future models may introduce more complex chat templates, and
we'll continue to add support for them!


cc @Wauplin @Rocketknight1 @osanseviero
  • Loading branch information
xenova authored Mar 6, 2024
1 parent d4757ea commit f570f3c
Show file tree
Hide file tree
Showing 7 changed files with 901 additions and 158 deletions.
20 changes: 15 additions & 5 deletions packages/jinja/src/ast.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,6 @@ export class NumericLiteral extends Literal<number> {
*/
export class StringLiteral extends Literal<string> {
override type = "StringLiteral";

constructor(value: string) {
super(value);
}
}

/**
Expand All @@ -133,6 +129,20 @@ export class BooleanLiteral extends Literal<boolean> {
override type = "BooleanLiteral";
}

/**
* Represents an array literal in the template, e.g. `[a, b, c]`.
*
* The type argument passed to `Literal` is `Expression[]`: the elements are
* kept as AST expression nodes rather than as already-evaluated values.
*/
export class ArrayLiteral extends Literal<Expression[]> {
// Node-type tag; `override` because a base class already declares `type`.
override type = "ArrayLiteral";
}

/**
* Represents an object literal in the template, e.g. `{"key": value}`.
*
* The type argument passed to `Literal` is `Map<Expression, Expression>`:
* both keys and values are kept as AST expression nodes. Because the keys are
* node objects, the Map compares them by reference, so two structurally equal
* key expressions are still distinct entries.
*/
export class ObjectLiteral extends Literal<Map<Expression, Expression>> {
// Node-type tag; `override` because a base class already declares `type`.
override type = "ObjectLiteral";
}

/**
* An operation with two sides, separated by an operator.
* Note: Either side can be a Complex Expression, with order
Expand All @@ -159,7 +169,7 @@ export class FilterExpression extends Expression {

constructor(
public operand: Expression,
public filter: Identifier // TODO: Add support for non-identifier filters
public filter: Identifier | CallExpression
) {
super();
}
Expand Down
26 changes: 18 additions & 8 deletions packages/jinja/src/lexer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ export const TOKEN_TYPES = Object.freeze({
CloseExpression: "CloseExpression", // }}
OpenSquareBracket: "OpenSquareBracket", // [
CloseSquareBracket: "CloseSquareBracket", // ]
OpenCurlyBracket: "OpenCurlyBracket", // {
CloseCurlyBracket: "CloseCurlyBracket", // }
Comma: "Comma", // ,
Dot: "Dot", // .
Colon: "Colon", // :
Expand Down Expand Up @@ -104,6 +106,8 @@ const ORDERED_MAPPING_TABLE: [string, TokenType][] = [
// Single character tokens
["(", TOKEN_TYPES.OpenParen],
[")", TOKEN_TYPES.CloseParen],
["{", TOKEN_TYPES.OpenCurlyBracket],
["}", TOKEN_TYPES.CloseCurlyBracket],
["[", TOKEN_TYPES.OpenSquareBracket],
["]", TOKEN_TYPES.CloseSquareBracket],
[",", TOKEN_TYPES.Comma],
Expand Down Expand Up @@ -154,19 +158,25 @@ function preprocess(template: string, options: PreprocessOptions = {}): string {
template = template.slice(0, -1);
}

if (options.trim_blocks) {
// If an application configures Jinja to trim_blocks, the first newline after
// a template tag is removed automatically (like in PHP).
template = template.replace(/%}\n/g, "%}");
}
// Replace all comments with a placeholder
// This ensures that comments don't interfere with the following options
template = template.replace(/{#.*?#}/gs, "{##}");

if (options.lstrip_blocks) {
// The lstrip_blocks option can also be set to strip tabs and spaces from the
// beginning of a line to the start of a block. (Nothing will be stripped if
// there are other characters before the start of the block.)
template = template.replace(/^[ \t]*{%/gm, "{%");
template = template.replace(/^[ \t]*({[#%])/gm, "$1");
}

if (options.trim_blocks) {
// If an application configures Jinja to trim_blocks, the first newline after
// a template tag is removed automatically (like in PHP).
template = template.replace(/([#%]})\n/g, "$1");
}

return template
.replace(/{##}/g, "") // Remove comments
.replace(/-%}\s*/g, "%}")
.replace(/\s*{%-/g, "{%")
.replace(/-}}\s*/g, "}}")
Expand Down Expand Up @@ -283,9 +293,9 @@ export function tokenize(source: string, options: PreprocessOptions = {}): Token
}
}

if (char === "'") {
if (char === "'" || char === '"') {
++cursorPosition; // Skip the opening quote
const str = consumeWhile((char) => char !== "'");
const str = consumeWhile((c) => c !== char);
tokens.push(new Token(str, TOKEN_TYPES.StringLiteral));
++cursorPosition; // Skip the closing quote
continue;
Expand Down
53 changes: 50 additions & 3 deletions packages/jinja/src/parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import {
NumericLiteral,
StringLiteral,
BooleanLiteral,
ArrayLiteral,
ObjectLiteral,
BinaryExpression,
FilterExpression,
TestExpression,
Expand Down Expand Up @@ -197,7 +199,16 @@ export function parse(tokens: Token[]): Program {

function parseExpression(): Statement {
// Choose parse function with lowest precedence
return parseLogicalOrExpression();
const a = parseLogicalOrExpression();
if (is(TOKEN_TYPES.If)) {
// Ternary expression
++current; // consume if
const predicate = parseLogicalOrExpression();
expect(TOKEN_TYPES.Else, "Expected else token");
const b = parseLogicalOrExpression();
return new If(predicate, [a], [b]);
}
return a;
}

function parseLogicalOrExpression(): Statement {
Expand Down Expand Up @@ -423,11 +434,14 @@ export function parse(tokens: Token[]): Program {
while (is(TOKEN_TYPES.Pipe)) {
// Support chaining filters
++current; // consume pipe
const filter = parsePrimaryExpression(); // should be an identifier
let filter = parsePrimaryExpression(); // should be an identifier
if (!(filter instanceof Identifier)) {
throw new SyntaxError(`Expected identifier for the filter`);
}
operand = new FilterExpression(operand, filter);
if (is(TOKEN_TYPES.OpenParen)) {
filter = parseCallExpression(filter);
}
operand = new FilterExpression(operand, filter as Identifier | CallExpression);
}
return operand;
}
Expand Down Expand Up @@ -457,6 +471,39 @@ export function parse(tokens: Token[]): Program {
++current; // consume closing parenthesis
return expression;
}
case TOKEN_TYPES.OpenSquareBracket: {
// Array literal, e.g. `[a, b, c]`: parse each element as a full expression
// and wrap the resulting nodes in an ArrayLiteral.
++current; // consume opening square bracket

const values = [];
// Keep reading elements until the current token is the closing bracket.
// NOTE(review): there is no explicit end-of-input check in this loop; it relies
// on `is`/`parseExpression` to fail on an unterminated literal — confirm.
while (!is(TOKEN_TYPES.CloseSquareBracket)) {
values.push(parseExpression());

// A comma after an element is consumed when present but is not required here,
// so a trailing comma is accepted and a missing separator is not rejected.
if (is(TOKEN_TYPES.Comma)) {
++current; // consume comma
}
}
++current; // consume closing square bracket

return new ArrayLiteral(values);
}
case TOKEN_TYPES.OpenCurlyBracket: {
// Object literal, e.g. `{"a": 1, "b": 2}`: parse `key : value` pairs, each side
// as a full expression, and wrap them in an ObjectLiteral.
++current; // consume opening curly bracket

// Keys are expression nodes, so entries are keyed by node identity
// (presumably each parseExpression call yields a fresh node — confirm).
const values = new Map();
// Keep reading pairs until the current token is the closing bracket.
// NOTE(review): there is no explicit end-of-input check in this loop; it relies
// on `is`/`expect`/`parseExpression` to fail on an unterminated literal — confirm.
while (!is(TOKEN_TYPES.CloseCurlyBracket)) {
const key = parseExpression();
// Presumably `expect` consumes the colon and raises with this message when the
// current token is something else — confirm against its definition.
expect(TOKEN_TYPES.Colon, "Expected colon between key and value in object literal");
const value = parseExpression();
values.set(key, value);

// A comma after a pair is consumed when present but is not required here,
// so a trailing comma is accepted and a missing separator is not rejected.
if (is(TOKEN_TYPES.Comma)) {
++current; // consume comma
}
}
++current; // consume closing curly bracket

return new ObjectLiteral(values);
}
default:
throw new SyntaxError(`Unexpected token: ${token.type}`);
}
Expand Down
Loading

0 comments on commit f570f3c

Please sign in to comment.