diff --git a/migration_checker/executor.py b/migration_checker/executor.py index 0a7c4f8..b88ae73 100644 --- a/migration_checker/executor.py +++ b/migration_checker/executor.py @@ -2,13 +2,15 @@ Helper to execute migrations and record results """ -from typing import Any, Callable, Union, cast +from typing import Any, Callable, Sequence, Union, cast import django +import sqlparse # type: ignore[import] from django.contrib.postgres.operations import NotInTransactionMixin from django.db import connections, transaction -from django.db.migrations import Migration +from django.db.migrations import Migration, RunSQL, SeparateDatabaseAndState from django.db.migrations.executor import MigrationExecutor +from django.db.migrations.operations.base import Operation from django.db.migrations.recorder import MigrationRecorder from django.db.migrations.state import ProjectState @@ -43,7 +45,7 @@ def __init__( *, database: str, apply_migrations: bool, - outputs: list[Union[ConsoleOutput, GithubCommentOutput]] + outputs: list[Union[ConsoleOutput, GithubCommentOutput]], ) -> None: self.database = database self.apply_migrations = apply_migrations @@ -118,11 +120,7 @@ def _apply_migration( # Some operations, like AddIndexConcurrently, cannot be run in a # transaction, so for those special cases we skip recording locks # because we ahave no way of doing that. - must_be_non_atomic = any( - isinstance(operation, NotInTransactionMixin) - for operation in migration.operations - ) - if must_be_non_atomic: + if self._must_be_non_atomic(migration.operations): return self._apply_non_atomic_migration(migration, state), None # Apply the migration in the database and record queries and locks @@ -137,6 +135,58 @@ def _apply_migration( return query_logger.queries, locks + def _must_be_non_atomic_query(self, query: str) -> bool: + """ + Try to detect if a raw query must be non-atomic. + """ + + patterns = [ + [ + (sqlparse.tokens.DDL, "CREATE"), + (sqlparse.tokens.Keyword, "INDEX"), + (sqlparse.tokens.Keyword, "CONCURRENTLY"), + ], + [ + (sqlparse.tokens.DDL, "DROP"), + (sqlparse.tokens.Keyword, "INDEX"), + (sqlparse.tokens.Keyword, "CONCURRENTLY"), + ], + ] + + for statement in sqlparse.parse(query): + for pattern in patterns: + if all( + any(token.match(ttype, value) for token in statement.tokens) + for ttype, value in pattern + ): + return True + return False + + def _must_be_non_atomic(self, operations: Sequence[Operation]) -> bool: + """ + Check if any of the operations must be run outside of a transaction. + This is the case for some operations, like AddIndexConcurrently. This + will recursivey check SeparateDatabaseAndState migrations. + """ + + for operation in operations: + if isinstance(operation, NotInTransactionMixin): + return True + if isinstance(operation, SeparateDatabaseAndState): + return self._must_be_non_atomic(operation.database_operations) + if isinstance(operation, RunSQL): + if isinstance(operation.sql, str): + return self._must_be_non_atomic_query(operation.sql) + else: + return any( + self._must_be_non_atomic_query(statement) + if isinstance(statement, str) + else self._must_be_non_atomic_query(statement[0]) + for statement in operation.sql + ) + + return False + def _apply_non_atomic_migration( self, migration: Migration, state: ProjectState ) -> list[str]: diff --git a/poetry.lock b/poetry.lock index 3b2de11..f421062 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,10 +1,9 @@ -# This file is automatically @generated by Poetry and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "asgiref" version = "3.6.0" description = "ASGI specs, helper code, and adapters" -category = "main" optional = false python-versions = ">=3.7" files = [ @@ -19,7 +18,6 @@ tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"] name = "attrs" version = "22.2.0" description = "Classes Without Boilerplate" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -38,7 +36,6 @@ tests-no-zope = ["cloudpickle", "cloudpickle", "hypothesis", "hypothesis", "mypy name = "black" version = "23.1.0" description = "The uncompromising code formatter." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -88,7 +85,6 @@ uvloop = ["uvloop (>=0.15.2)"] name = "click" version = "8.1.3" description = "Composable command line interface toolkit" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -103,7 +99,6 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." -category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -115,7 +110,6 @@ files = [ name = "django" version = "4.1.6" description = "A high-level Python web framework that encourages rapid development and clean, pragmatic design." -category = "main" optional = false python-versions = ">=3.8" files = [ @@ -136,7 +130,6 @@ bcrypt = ["bcrypt"] name = "django-stubs" version = "1.14.0" description = "Mypy stubs for Django" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -160,7 +153,6 @@ compatible-mypy = ["mypy (>=0.991,<1.0)"] name = "django-stubs-ext" version = "0.7.0" description = "Monkey-patching and extensions for django-stubs" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -176,7 +168,6 @@ typing-extensions = "*" name = "exceptiongroup" version = "1.1.0" description = "Backport of PEP 654 (exception groups)" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -191,7 +182,6 @@ test = ["pytest (>=6)"] name = "flake8" version = "6.0.0" description = "the modular source code checker: pep8 pyflakes and co" -category = "dev" optional = false python-versions = ">=3.8.1" files = [ @@ -208,7 +198,6 @@ pyflakes = ">=3.0.0,<3.1.0" name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -220,7 +209,6 @@ files = [ name = "isort" version = "5.12.0" description = "A Python utility / library to sort Python imports." -category = "dev" optional = false python-versions = ">=3.8.0" files = [ @@ -238,7 +226,6 @@ requirements-deprecated-finder = ["pip-api", "pipreqs"] name = "mccabe" version = "0.7.0" description = "McCabe checker, plugin for flake8" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -250,7 +237,6 @@ files = [ name = "mypy" version = "1.0.0" description = "Optional static typing for Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -297,7 +283,6 @@ reports = ["lxml"] name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." -category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -309,7 +294,6 @@ files = [ name = "packaging" version = "23.0" description = "Core utilities for Python packages" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -321,7 +305,6 @@ files = [ name = "pathspec" version = "0.11.0" description = "Utility library for gitignore style pattern matching of file paths." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -333,7 +316,6 @@ files = [ name = "platformdirs" version = "3.0.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -349,7 +331,6 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytes name = "pluggy" version = "1.0.0" description = "plugin and hook calling mechanisms for python" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -365,7 +346,6 @@ testing = ["pytest", "pytest-benchmark"] name = "psycopg2" version = "2.9.5" description = "psycopg2 - Python-PostgreSQL Database Adapter" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -388,7 +368,6 @@ files = [ name = "pycodestyle" version = "2.10.0" description = "Python style guide checker" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -400,7 +379,6 @@ files = [ name = "pyflakes" version = "3.0.1" description = "passive checker of Python programs" -category = "dev" optional = false python-versions = ">=3.6" files = [ @@ -412,7 +390,6 @@ files = [ name = "pytest" version = "7.2.1" description = "pytest: simple powerful testing with Python" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -434,21 +411,24 @@ testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2. [[package]] name = "sqlparse" -version = "0.4.3" +version = "0.4.4" description = "A non-validating SQL parser." -category = "main" optional = false python-versions = ">=3.5" files = [ - {file = "sqlparse-0.4.3-py3-none-any.whl", hash = "sha256:0323c0ec29cd52bceabc1b4d9d579e311f3e4961b98d174201d5622a23b85e34"}, - {file = "sqlparse-0.4.3.tar.gz", hash = "sha256:69ca804846bb114d2ec380e4360a8a340db83f0ccf3afceeb1404df028f57268"}, + {file = "sqlparse-0.4.4-py3-none-any.whl", hash = "sha256:5430a4fe2ac7d0f93e66f1efc6e1338a41884b7ddf2a350cedd20ccc4d9d28f3"}, + {file = "sqlparse-0.4.4.tar.gz", hash = "sha256:d446183e84b8349fa3061f0fe7f06ca94ba65b426946ffebe6e3e8295332420c"}, ] +[package.extras] +dev = ["build", "flake8"] +doc = ["sphinx"] +test = ["pytest", "pytest-cov"] + [[package]] name = "tomli" version = "2.0.1" description = "A lil' TOML parser" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -460,7 +440,6 @@ files = [ name = "types-pytz" version = "2022.7.1.0" description = "Typing stubs for pytz" -category = "dev" optional = false python-versions = "*" files = [ @@ -472,7 +451,6 @@ files = [ name = "types-pyyaml" version = "6.0.12.5" description = "Typing stubs for PyYAML" -category = "dev" optional = false python-versions = "*" files = [ @@ -484,7 +462,6 @@ files = [ name = "typing-extensions" version = "4.4.0" description = "Backported and Experimental Type Hints for Python 3.7+" -category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -496,7 +473,6 @@ files = [ name = "tzdata" version = "2022.7" description = "Provider of IANA time zone data" -category = "main" optional = false python-versions = ">=2" files = [ @@ -510,4 +486,4 @@ django = ["Django"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "d9f9054f4c284749ee4abbf2ea36d71e196cee1cdff2dfb87b1207f096652659" +content-hash = "1b4a84ab1af32c4ab489440c5f976f8b3e9991570c369c494644331d89967311" diff --git a/pyproject.toml b/pyproject.toml index d05b6cb..08fadc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ packages = [ [tool.poetry.dependencies] python = "^3.9" Django = {version = ">=3.2", optional = true } +sqlparse = ">=0.3.1" [tool.poetry.extras] django = ["Django"] diff --git a/tests/test_executor.py b/tests/test_executor.py index 9055b40..6b9770b 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -1,3 +1,13 @@ +from unittest.mock import Mock + +import pytest +from django.contrib.postgres.operations import ( + AddIndexConcurrently, + RemoveIndexConcurrently, +) +from django.db.migrations import AddIndex, RunSQL, SeparateDatabaseAndState +from django.db.migrations.operations.base import Operation + from migration_checker.executor import Executor from migration_checker.output import ConsoleOutput @@ -7,3 +17,50 @@ def test_executor(setup_db: None) -> None: database="default", apply_migrations=True, outputs=[ConsoleOutput()] ) executor.run() + + +@pytest.mark.parametrize( + "operation,must_be_non_atomic", + [ + (AddIndex("test", index=Mock()), False), + (AddIndexConcurrently("test", index=Mock()), True), + (RemoveIndexConcurrently("foo", "test"), True), + (RunSQL("CREATE INDEX foobar"), False), + (RunSQL("CREATE INDEX foobar CONCURRENTLY", RunSQL.noop), True), + (RunSQL("DROP INDEX foobar CONCURRENTLY", RunSQL.noop), True), + (RunSQL([("CREATE INDEX foobar", None)]), False), + (RunSQL([("CREATE INDEX foobar CONCURRENTLY", None)]), True), + ( + SeparateDatabaseAndState( + database_operations=[RunSQL("CREATE INDEX foobar")] + ), + False, + ), + ( + SeparateDatabaseAndState( + database_operations=[RunSQL("CREATE INDEX foobar CONCURRENTLY")] + ), + True, + ), + ( + SeparateDatabaseAndState( + database_operations=[AddIndex("test", index=Mock())] + ), + False, + ), + ( + SeparateDatabaseAndState( + database_operations=[AddIndexConcurrently("test", index=Mock())] + ), + True, + ), + ], +) +def test_run_sql_must_be_non_atomic( + operation: Operation, must_be_non_atomic: bool +) -> None: + executor = Executor( + database="default", apply_migrations=True, outputs=[ConsoleOutput()] + ) + + assert executor._must_be_non_atomic([operation]) is must_be_non_atomic