Skip to content

Commit

Permalink
🚀 jinja - add filter operator (#419)
Browse files Browse the repository at this point in the history
Adds support for the filter/pipe operator:
- `{{ arr | length }}` gets array length
- `{{ 1 + arr | length }}` computes 1 + (array length), order important.
- `{{ 2 + arr | sort | length }}` computes 2 + ((array sort) length),
order important
- `{{ (arr | sort)[0] }}` computes (array sort)[0]
 
Needed for
[this](https://discuss.huggingface.co/t/issue-with-llama-2-chat-template-and-out-of-date-documentation/61645/3)
complicated chat template. (cc @Rocketknight1)


At the moment, we don't support user-defined or non-identifier filter
functions, but could be added in future. Especially after we add
user-defined functions.
  • Loading branch information
xenova committed Dec 16, 2023
1 parent 329090c commit a03d92a
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 2 deletions.
15 changes: 15 additions & 0 deletions packages/jinja/src/ast.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,21 @@ export class BinaryExpression extends Expression {
}
}

/**
* An operation with two sides, separated by the | operator.
* Operator precedence: https://github.com/pallets/jinja/issues/379#issuecomment-168076202
*/
export class FilterExpression extends Expression {
override type = "FilterExpression";

constructor(
public operand: Expression,
public filter: Identifier // TODO: Add support for non-identifier filters
) {
super();
}
}

/**
* An operation with one side (operator on the left).
*/
Expand Down
2 changes: 2 additions & 0 deletions packages/jinja/src/lexer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export const TOKEN_TYPES = Object.freeze({
Comma: "Comma", // ,
Dot: "Dot", // .
Colon: "Colon", // :
Pipe: "Pipe", // |

CallOperator: "CallOperator", // ()
AdditiveBinaryOperator: "AdditiveBinaryOperator", // + -
Expand Down Expand Up @@ -106,6 +107,7 @@ const ORDERED_MAPPING_TABLE: [string, TokenType][] = [
[",", TOKEN_TYPES.Comma],
[".", TOKEN_TYPES.Dot],
[":", TOKEN_TYPES.Colon],
["|", TOKEN_TYPES.Pipe],
// Comparison operators
["<=", TOKEN_TYPES.ComparisonBinaryOperator],
[">=", TOKEN_TYPES.ComparisonBinaryOperator],
Expand Down
20 changes: 18 additions & 2 deletions packages/jinja/src/parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
StringLiteral,
BooleanLiteral,
BinaryExpression,
FilterExpression,
UnaryExpression,
SliceExpression,
} from "./ast";
Expand Down Expand Up @@ -353,17 +354,32 @@ export function parse(tokens: Token[]): Program {
}

function parseMultiplicativeExpression(): Statement {
let left = parseCallMemberExpression();
let left = parseFilterExpression();

while (is(TOKEN_TYPES.MultiplicativeBinaryOperator)) {
const operator = tokens[current];
++current;
const right = parseCallMemberExpression();
const right = parseFilterExpression();
left = new BinaryExpression(operator, left, right);
}
return left;
}

function parseFilterExpression(): Statement {
let operand = parseCallMemberExpression();

while (is(TOKEN_TYPES.Pipe)) {
// Support chaining filters
++current; // consume pipe
const filter = parsePrimaryExpression(); // should be an identifier
if (!(filter instanceof Identifier)) {
throw new SyntaxError(`Expected identifier for the filter`);
}
operand = new FilterExpression(operand, filter);
}
return operand;
}

function parsePrimaryExpression(): Statement {
// Primary expression: number, string, identifier, function call, parenthesized expression
const token = tokens[current];
Expand Down
54 changes: 54 additions & 0 deletions packages/jinja/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import type {
CallExpression,
Identifier,
BinaryExpression,
FilterExpression,
UnaryExpression,
SliceExpression,
} from "./ast";
Expand Down Expand Up @@ -292,6 +293,57 @@ export class Interpreter {
throw new SyntaxError(`Unknown operator "${node.operator.value}" between ${left.type} and ${right.type}`);
}

/**
* Evaulates expressions following the filter operation type.
*/
private evaluateFilterExpression(node: FilterExpression, environment: Environment): AnyRuntimeValue {
const operand = this.evaluate(node.operand, environment);

// For now, we only support the built-in filters
// TODO: Add support for non-identifier filters
// e.g., functions which return filters: {{ numbers | select("odd") }}
// TODO: Add support for user-defined filters
// const filter = environment.lookupVariable(node.filter.value);
// if (!(filter instanceof FunctionValue)) {
// throw new Error(`Filter must be a function: got ${filter.type}`);
// }
// return filter.value([operand], environment);

if (operand instanceof ArrayValue) {
switch (node.filter.value) {
case "first":
return operand.value[0];
case "last":
return operand.value[operand.value.length - 1];
case "length":
return new NumericValue(operand.value.length);
case "reverse":
return new ArrayValue(operand.value.reverse());
case "sort":
return new ArrayValue(
operand.value.sort((a, b) => {
if (a.type !== b.type) {
throw new Error(`Cannot compare different types: ${a.type} and ${b.type}`);
}
switch (a.type) {
case "NumericValue":
return (a as NumericValue).value - (b as NumericValue).value;
case "StringValue":
return (a as StringValue).value.localeCompare((b as StringValue).value);
default:
throw new Error(`Cannot compare type: ${a.type}`);
}
})
);
default:
throw new Error(`Unknown filter: ${node.filter.value}`);
}
}

// TODO add support for StringValue operand
throw new Error(`Cannot apply filter to type: ${operand.type}`);
}

/**
* Evaulates expressions following the unary operation type.
*/
Expand Down Expand Up @@ -510,6 +562,8 @@ export class Interpreter {
return this.evaluateUnaryExpression(statement as UnaryExpression, environment);
case "BinaryExpression":
return this.evaluateBinaryExpression(statement as BinaryExpression, environment);
case "FilterExpression":
return this.evaluateFilterExpression(statement as FilterExpression, environment);

default:
throw new SyntaxError(`Unknown node type: ${statement.type}`);
Expand Down
46 changes: 46 additions & 0 deletions packages/jinja/test/templates.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ const TEST_STRINGS = {

// Substring inclusion
SUBSTRING_INCLUSION: `|{{ '' in 'abc' }}|{{ 'a' in 'abc' }}|{{ 'd' in 'abc' }}|{{ 'ab' in 'abc' }}|{{ 'ac' in 'abc' }}|{{ 'abc' in 'abc' }}|{{ 'abcd' in 'abc' }}|`,

// Filter operator
FILTER_OPERATOR: `{{ arr | length }}{{ 1 + arr | length }}{{ 2 + arr | sort | length }}{{ (arr | sort)[0] }}`,
};

const TEST_PARSED = {
Expand Down Expand Up @@ -1094,6 +1097,41 @@ const TEST_PARSED = {
{ value: "}}", type: "CloseExpression" },
{ value: "|", type: "Text" },
],

// Filter operator
FILTER_OPERATOR: [
{ value: "{{", type: "OpenExpression" },
{ value: "arr", type: "Identifier" },
{ value: "|", type: "Pipe" },
{ value: "length", type: "Identifier" },
{ value: "}}", type: "CloseExpression" },
{ value: "{{", type: "OpenExpression" },
{ value: "1", type: "NumericLiteral" },
{ value: "+", type: "AdditiveBinaryOperator" },
{ value: "arr", type: "Identifier" },
{ value: "|", type: "Pipe" },
{ value: "length", type: "Identifier" },
{ value: "}}", type: "CloseExpression" },
{ value: "{{", type: "OpenExpression" },
{ value: "2", type: "NumericLiteral" },
{ value: "+", type: "AdditiveBinaryOperator" },
{ value: "arr", type: "Identifier" },
{ value: "|", type: "Pipe" },
{ value: "sort", type: "Identifier" },
{ value: "|", type: "Pipe" },
{ value: "length", type: "Identifier" },
{ value: "}}", type: "CloseExpression" },
{ value: "{{", type: "OpenExpression" },
{ value: "(", type: "OpenParen" },
{ value: "arr", type: "Identifier" },
{ value: "|", type: "Pipe" },
{ value: "sort", type: "Identifier" },
{ value: ")", type: "CloseParen" },
{ value: "[", type: "OpenSquareBracket" },
{ value: "0", type: "NumericLiteral" },
{ value: "]", type: "CloseSquareBracket" },
{ value: "}}", type: "CloseExpression" },
],
};

const TEST_CONTEXT = {
Expand Down Expand Up @@ -1196,6 +1234,11 @@ const TEST_CONTEXT = {

// Substring inclusion
SUBSTRING_INCLUSION: {},

// Filter operator
FILTER_OPERATOR: {
arr: [3, 2, 1],
},
};

const EXPECTED_OUTPUTS = {
Expand Down Expand Up @@ -1262,6 +1305,9 @@ const EXPECTED_OUTPUTS = {

// Substring inclusion
SUBSTRING_INCLUSION: `|true|true|false|true|false|true|false|`,

// Filter operator
FILTER_OPERATOR: `3451`,
};

describe("Templates", () => {
Expand Down

0 comments on commit a03d92a

Please sign in to comment.