Skip to content

Commit

Permalink
Add option to remove trailing semicolon when splitting (fixes #742).
Browse files Browse the repository at this point in the history
  • Loading branch information
andialbrecht committed Oct 12, 2023
1 parent 6eca7ae commit 115e208
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 3 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ Notable Changes
* Drop support for Python 3.5.
* Python 3.12 is now supported (pr725, by hugovk).

Enhancements:

* Splitting statements now allows to remove the semicolon at the end.
Some database backends love statements without semicolon (issue742).

Bug Fixes

* Ignore dunder attributes when creating Tokens (issue672).
Expand Down
6 changes: 4 additions & 2 deletions sqlparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,14 @@ def format(sql, encoding=None, **options):
return ''.join(stack.run(sql, encoding))


def split(sql, encoding=None):
def split(sql, encoding=None, strip_semicolon=False):
"""Split *sql* into single statements.
:param sql: A string containing one or more SQL statements.
:param encoding: The encoding of the statement (optional).
:param strip_semicolon: If True, remove trainling semicolons
(default: False).
:returns: A list of strings.
"""
stack = engine.FilterStack()
stack = engine.FilterStack(strip_semicolon=strip_semicolon)
return [str(stmt).strip() for stmt in stack.run(sql, encoding)]
5 changes: 4 additions & 1 deletion sqlparse/engine/filter_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@
from sqlparse import lexer
from sqlparse.engine import grouping
from sqlparse.engine.statement_splitter import StatementSplitter
from sqlparse.filters import StripTrailingSemicolonFilter


class FilterStack:
def __init__(self):
def __init__(self, strip_semicolon=False):
self.preprocess = []
self.stmtprocess = []
self.postprocess = []
self._grouping = False
if strip_semicolon:
self.stmtprocess.append(StripTrailingSemicolonFilter())

def enable_grouping(self):
self._grouping = True
Expand Down
2 changes: 2 additions & 0 deletions sqlparse/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sqlparse.filters.others import SerializerUnicode
from sqlparse.filters.others import StripCommentsFilter
from sqlparse.filters.others import StripWhitespaceFilter
from sqlparse.filters.others import StripTrailingSemicolonFilter
from sqlparse.filters.others import SpacesAroundOperatorsFilter

from sqlparse.filters.output import OutputPHPFilter
Expand All @@ -25,6 +26,7 @@
'SerializerUnicode',
'StripCommentsFilter',
'StripWhitespaceFilter',
'StripTrailingSemicolonFilter',
'SpacesAroundOperatorsFilter',

'OutputPHPFilter',
Expand Down
9 changes: 9 additions & 0 deletions sqlparse/filters/others.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ def process(self, stmt):
return stmt


class StripTrailingSemicolonFilter:

def process(self, stmt):
while stmt.tokens and (stmt.tokens[-1].is_whitespace
or stmt.tokens[-1].value == ';'):
stmt.tokens.pop()
return stmt


# ---------------------------
# postprocess

Expand Down
28 changes: 28 additions & 0 deletions tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,31 @@ def test_split_mysql_handler_for(load_file):
# see issue581
stmts = sqlparse.split(load_file('mysql_handler.sql'))
assert len(stmts) == 2


@pytest.mark.parametrize('sql, expected', [
('select * from foo;', ['select * from foo']),
('select * from foo', ['select * from foo']),
('select * from foo; select * from bar;', [
'select * from foo',
'select * from bar',
]),
(' select * from foo;\n\nselect * from bar;\n\n\n\n', [
'select * from foo',
'select * from bar',
]),
('select * from foo\n\n; bar', ['select * from foo', 'bar']),
])
def test_split_strip_semicolon(sql, expected):
stmts = sqlparse.split(sql, strip_semicolon=True)
assert len(stmts) == len(expected)
for idx, expectation in enumerate(expected):
assert stmts[idx] == expectation


def test_split_strip_semicolon_procedure(load_file):
stmts = sqlparse.split(load_file('mysql_handler.sql'),
strip_semicolon=True)
assert len(stmts) == 2
assert stmts[0].endswith('end')
assert stmts[1].endswith('end')

0 comments on commit 115e208

Please sign in to comment.