From 115e208bd340f175b23964524670418fe6f72c31 Mon Sep 17 00:00:00 2001 From: Andi Albrecht Date: Thu, 12 Oct 2023 21:11:50 +0200 Subject: [PATCH] Add option to remove trailing semicolon when splitting (fixes #742). --- CHANGELOG | 5 +++++ sqlparse/__init__.py | 6 ++++-- sqlparse/engine/filter_stack.py | 5 ++++- sqlparse/filters/__init__.py | 2 ++ sqlparse/filters/others.py | 9 +++++++++ tests/test_split.py | 28 ++++++++++++++++++++++++++++ 6 files changed, 52 insertions(+), 3 deletions(-) diff --git a/CHANGELOG b/CHANGELOG index 525918a2..0ede2800 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -6,6 +6,11 @@ Notable Changes * Drop support for Python 3.5. * Python 3.12 is now supported (pr725, by hugovk). +Enhancements: + +* Splitting statements now allows to remove the semicolon at the end. + Some database backends love statements without semicolon (issue742). + Bug Fixes * Ignore dunder attributes when creating Tokens (issue672). diff --git a/sqlparse/__init__.py b/sqlparse/__init__.py index cfd4e2fd..b80b2d60 100644 --- a/sqlparse/__init__.py +++ b/sqlparse/__init__.py @@ -59,12 +59,14 @@ def format(sql, encoding=None, **options): return ''.join(stack.run(sql, encoding)) -def split(sql, encoding=None): +def split(sql, encoding=None, strip_semicolon=False): """Split *sql* into single statements. :param sql: A string containing one or more SQL statements. :param encoding: The encoding of the statement (optional). + :param strip_semicolon: If True, remove trainling semicolons + (default: False). :returns: A list of strings. """ - stack = engine.FilterStack() + stack = engine.FilterStack(strip_semicolon=strip_semicolon) return [str(stmt).strip() for stmt in stack.run(sql, encoding)] diff --git a/sqlparse/engine/filter_stack.py b/sqlparse/engine/filter_stack.py index 9665a224..3feba377 100644 --- a/sqlparse/engine/filter_stack.py +++ b/sqlparse/engine/filter_stack.py @@ -10,14 +10,17 @@ from sqlparse import lexer from sqlparse.engine import grouping from sqlparse.engine.statement_splitter import StatementSplitter +from sqlparse.filters import StripTrailingSemicolonFilter class FilterStack: - def __init__(self): + def __init__(self, strip_semicolon=False): self.preprocess = [] self.stmtprocess = [] self.postprocess = [] self._grouping = False + if strip_semicolon: + self.stmtprocess.append(StripTrailingSemicolonFilter()) def enable_grouping(self): self._grouping = True diff --git a/sqlparse/filters/__init__.py b/sqlparse/filters/__init__.py index 5bd6b325..06169460 100644 --- a/sqlparse/filters/__init__.py +++ b/sqlparse/filters/__init__.py @@ -8,6 +8,7 @@ from sqlparse.filters.others import SerializerUnicode from sqlparse.filters.others import StripCommentsFilter from sqlparse.filters.others import StripWhitespaceFilter +from sqlparse.filters.others import StripTrailingSemicolonFilter from sqlparse.filters.others import SpacesAroundOperatorsFilter from sqlparse.filters.output import OutputPHPFilter @@ -25,6 +26,7 @@ 'SerializerUnicode', 'StripCommentsFilter', 'StripWhitespaceFilter', + 'StripTrailingSemicolonFilter', 'SpacesAroundOperatorsFilter', 'OutputPHPFilter', diff --git a/sqlparse/filters/others.py b/sqlparse/filters/others.py index 9e617c37..da7c0e79 100644 --- a/sqlparse/filters/others.py +++ b/sqlparse/filters/others.py @@ -126,6 +126,15 @@ def process(self, stmt): return stmt +class StripTrailingSemicolonFilter: + + def process(self, stmt): + while stmt.tokens and (stmt.tokens[-1].is_whitespace + or stmt.tokens[-1].value == ';'): + stmt.tokens.pop() + return stmt + + # --------------------------- # postprocess diff --git a/tests/test_split.py b/tests/test_split.py index e79750e8..30a50c59 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -166,3 +166,31 @@ def test_split_mysql_handler_for(load_file): # see issue581 stmts = sqlparse.split(load_file('mysql_handler.sql')) assert len(stmts) == 2 + + +@pytest.mark.parametrize('sql, expected', [ + ('select * from foo;', ['select * from foo']), + ('select * from foo', ['select * from foo']), + ('select * from foo; select * from bar;', [ + 'select * from foo', + 'select * from bar', + ]), + (' select * from foo;\n\nselect * from bar;\n\n\n\n', [ + 'select * from foo', + 'select * from bar', + ]), + ('select * from foo\n\n; bar', ['select * from foo', 'bar']), +]) +def test_split_strip_semicolon(sql, expected): + stmts = sqlparse.split(sql, strip_semicolon=True) + assert len(stmts) == len(expected) + for idx, expectation in enumerate(expected): + assert stmts[idx] == expectation + + +def test_split_strip_semicolon_procedure(load_file): + stmts = sqlparse.split(load_file('mysql_handler.sql'), + strip_semicolon=True) + assert len(stmts) == 2 + assert stmts[0].endswith('end') + assert stmts[1].endswith('end')