From a03d92af2e05f65ca933d7d07a7ca9848175dcaa Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 16 Dec 2023 15:09:18 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=80=20jinja=20-=20add=20filter=20opera?= =?UTF-8?q?tor=20(#419)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds support for the filter/pipe operator: - `{{ arr | length }}` gets array length - `{{ 1 + arr | length }}` computes 1 + (array length), order important. - `{{ 2 + arr | sort | length }}` computes 2 + ((array sort) length), order important - `{{ (arr | sort)[0] }}` computes (array sort)[0] Needed for [this](https://discuss.huggingface.co/t/issue-with-llama-2-chat-template-and-out-of-date-documentation/61645/3) complicated chat template. (cc @Rocketknight1) At the moment, we don't support user-defined or non-identifier filter functions, but could be added in future. Especially after we add user-defined functions. --- packages/jinja/src/ast.ts | 15 ++++++++ packages/jinja/src/lexer.ts | 2 + packages/jinja/src/parser.ts | 20 +++++++++- packages/jinja/src/runtime.ts | 54 +++++++++++++++++++++++++++ packages/jinja/test/templates.test.js | 46 +++++++++++++++++++++++ 5 files changed, 135 insertions(+), 2 deletions(-) diff --git a/packages/jinja/src/ast.ts b/packages/jinja/src/ast.ts index 407eeaf58..20ad4478a 100644 --- a/packages/jinja/src/ast.ts +++ b/packages/jinja/src/ast.ts @@ -150,6 +150,21 @@ export class BinaryExpression extends Expression { } } +/** + * An operation with two sides, separated by the | operator. + * Operator precedence: https://github.com/pallets/jinja/issues/379#issuecomment-168076202 + */ +export class FilterExpression extends Expression { + override type = "FilterExpression"; + + constructor( + public operand: Expression, + public filter: Identifier // TODO: Add support for non-identifier filters + ) { + super(); + } +} + /** * An operation with one side (operator on the left). */ diff --git a/packages/jinja/src/lexer.ts b/packages/jinja/src/lexer.ts index 3dd6eaafb..61c3fd7f3 100644 --- a/packages/jinja/src/lexer.ts +++ b/packages/jinja/src/lexer.ts @@ -20,6 +20,7 @@ export const TOKEN_TYPES = Object.freeze({ Comma: "Comma", // , Dot: "Dot", // . Colon: "Colon", // : + Pipe: "Pipe", // | CallOperator: "CallOperator", // () AdditiveBinaryOperator: "AdditiveBinaryOperator", // + - @@ -106,6 +107,7 @@ const ORDERED_MAPPING_TABLE: [string, TokenType][] = [ [",", TOKEN_TYPES.Comma], [".", TOKEN_TYPES.Dot], [":", TOKEN_TYPES.Colon], + ["|", TOKEN_TYPES.Pipe], // Comparison operators ["<=", TOKEN_TYPES.ComparisonBinaryOperator], [">=", TOKEN_TYPES.ComparisonBinaryOperator], diff --git a/packages/jinja/src/parser.ts b/packages/jinja/src/parser.ts index feee4e905..f033e4b06 100644 --- a/packages/jinja/src/parser.ts +++ b/packages/jinja/src/parser.ts @@ -13,6 +13,7 @@ import { StringLiteral, BooleanLiteral, BinaryExpression, + FilterExpression, UnaryExpression, SliceExpression, } from "./ast"; @@ -353,17 +354,32 @@ export function parse(tokens: Token[]): Program { } function parseMultiplicativeExpression(): Statement { - let left = parseCallMemberExpression(); + let left = parseFilterExpression(); while (is(TOKEN_TYPES.MultiplicativeBinaryOperator)) { const operator = tokens[current]; ++current; - const right = parseCallMemberExpression(); + const right = parseFilterExpression(); left = new BinaryExpression(operator, left, right); } return left; } + function parseFilterExpression(): Statement { + let operand = parseCallMemberExpression(); + + while (is(TOKEN_TYPES.Pipe)) { + // Support chaining filters + ++current; // consume pipe + const filter = parsePrimaryExpression(); // should be an identifier + if (!(filter instanceof Identifier)) { + throw new SyntaxError(`Expected identifier for the filter`); + } + operand = new FilterExpression(operand, filter); + } + return operand; + } + function parsePrimaryExpression(): Statement { // Primary expression: number, string, identifier, function call, parenthesized expression const token = tokens[current]; diff --git a/packages/jinja/src/runtime.ts b/packages/jinja/src/runtime.ts index 885a0b559..b2a57995e 100644 --- a/packages/jinja/src/runtime.ts +++ b/packages/jinja/src/runtime.ts @@ -11,6 +11,7 @@ import type { CallExpression, Identifier, BinaryExpression, + FilterExpression, UnaryExpression, SliceExpression, } from "./ast"; @@ -292,6 +293,57 @@ export class Interpreter { throw new SyntaxError(`Unknown operator "${node.operator.value}" between ${left.type} and ${right.type}`); } + /** + * Evaulates expressions following the filter operation type. + */ + private evaluateFilterExpression(node: FilterExpression, environment: Environment): AnyRuntimeValue { + const operand = this.evaluate(node.operand, environment); + + // For now, we only support the built-in filters + // TODO: Add support for non-identifier filters + // e.g., functions which return filters: {{ numbers | select("odd") }} + // TODO: Add support for user-defined filters + // const filter = environment.lookupVariable(node.filter.value); + // if (!(filter instanceof FunctionValue)) { + // throw new Error(`Filter must be a function: got ${filter.type}`); + // } + // return filter.value([operand], environment); + + if (operand instanceof ArrayValue) { + switch (node.filter.value) { + case "first": + return operand.value[0]; + case "last": + return operand.value[operand.value.length - 1]; + case "length": + return new NumericValue(operand.value.length); + case "reverse": + return new ArrayValue(operand.value.reverse()); + case "sort": + return new ArrayValue( + operand.value.sort((a, b) => { + if (a.type !== b.type) { + throw new Error(`Cannot compare different types: ${a.type} and ${b.type}`); + } + switch (a.type) { + case "NumericValue": + return (a as NumericValue).value - (b as NumericValue).value; + case "StringValue": + return (a as StringValue).value.localeCompare((b as StringValue).value); + default: + throw new Error(`Cannot compare type: ${a.type}`); + } + }) + ); + default: + throw new Error(`Unknown filter: ${node.filter.value}`); + } + } + + // TODO add support for StringValue operand + throw new Error(`Cannot apply filter to type: ${operand.type}`); + } + /** * Evaulates expressions following the unary operation type. */ @@ -510,6 +562,8 @@ export class Interpreter { return this.evaluateUnaryExpression(statement as UnaryExpression, environment); case "BinaryExpression": return this.evaluateBinaryExpression(statement as BinaryExpression, environment); + case "FilterExpression": + return this.evaluateFilterExpression(statement as FilterExpression, environment); default: throw new SyntaxError(`Unknown node type: ${statement.type}`); diff --git a/packages/jinja/test/templates.test.js b/packages/jinja/test/templates.test.js index b8682a8da..d543582f7 100644 --- a/packages/jinja/test/templates.test.js +++ b/packages/jinja/test/templates.test.js @@ -66,6 +66,9 @@ const TEST_STRINGS = { // Substring inclusion SUBSTRING_INCLUSION: `|{{ '' in 'abc' }}|{{ 'a' in 'abc' }}|{{ 'd' in 'abc' }}|{{ 'ab' in 'abc' }}|{{ 'ac' in 'abc' }}|{{ 'abc' in 'abc' }}|{{ 'abcd' in 'abc' }}|`, + + // Filter operator + FILTER_OPERATOR: `{{ arr | length }}{{ 1 + arr | length }}{{ 2 + arr | sort | length }}{{ (arr | sort)[0] }}`, }; const TEST_PARSED = { @@ -1094,6 +1097,41 @@ const TEST_PARSED = { { value: "}}", type: "CloseExpression" }, { value: "|", type: "Text" }, ], + + // Filter operator + FILTER_OPERATOR: [ + { value: "{{", type: "OpenExpression" }, + { value: "arr", type: "Identifier" }, + { value: "|", type: "Pipe" }, + { value: "length", type: "Identifier" }, + { value: "}}", type: "CloseExpression" }, + { value: "{{", type: "OpenExpression" }, + { value: "1", type: "NumericLiteral" }, + { value: "+", type: "AdditiveBinaryOperator" }, + { value: "arr", type: "Identifier" }, + { value: "|", type: "Pipe" }, + { value: "length", type: "Identifier" }, + { value: "}}", type: "CloseExpression" }, + { value: "{{", type: "OpenExpression" }, + { value: "2", type: "NumericLiteral" }, + { value: "+", type: "AdditiveBinaryOperator" }, + { value: "arr", type: "Identifier" }, + { value: "|", type: "Pipe" }, + { value: "sort", type: "Identifier" }, + { value: "|", type: "Pipe" }, + { value: "length", type: "Identifier" }, + { value: "}}", type: "CloseExpression" }, + { value: "{{", type: "OpenExpression" }, + { value: "(", type: "OpenParen" }, + { value: "arr", type: "Identifier" }, + { value: "|", type: "Pipe" }, + { value: "sort", type: "Identifier" }, + { value: ")", type: "CloseParen" }, + { value: "[", type: "OpenSquareBracket" }, + { value: "0", type: "NumericLiteral" }, + { value: "]", type: "CloseSquareBracket" }, + { value: "}}", type: "CloseExpression" }, + ], }; const TEST_CONTEXT = { @@ -1196,6 +1234,11 @@ const TEST_CONTEXT = { // Substring inclusion SUBSTRING_INCLUSION: {}, + + // Filter operator + FILTER_OPERATOR: { + arr: [3, 2, 1], + }, }; const EXPECTED_OUTPUTS = { @@ -1262,6 +1305,9 @@ const EXPECTED_OUTPUTS = { // Substring inclusion SUBSTRING_INCLUSION: `|true|true|false|true|false|true|false|`, + + // Filter operator + FILTER_OPERATOR: `3451`, }; describe("Templates", () => {