From 9337ee4fed4fbc47b09bcecba5ec0b1c4bdb90dd Mon Sep 17 00:00:00 2001 From: saville Date: Wed, 20 Dec 2023 14:14:36 -0700 Subject: [PATCH] Convert build to ruff and run formatting --- .github/workflows/build.yaml | 6 +- .pre-commit-config.yaml | 10 + .pylint-license-header | 7 - .pylintrc | 526 --------------------- .python-version | 2 +- dysql/__init__.py | 24 + dysql/annotations.py | 4 +- dysql/connections.py | 42 +- dysql/databases.py | 78 +-- dysql/mappers.py | 18 +- dysql/pydantic_mappers.py | 13 +- dysql/query_utils.py | 108 +++-- dysql/test/__init__.py | 14 +- dysql/test/test_annotations.py | 52 +- dysql/test/test_database_initialization.py | 116 +++-- dysql/test/test_mappers.py | 224 ++++++--- dysql/test/test_mariadbmap_tojson.py | 57 ++- dysql/test/test_pydantic_mappers.py | 347 +++++++------- dysql/test/test_sql_decorator.py | 155 +++--- dysql/test/test_sql_exists_decorator.py | 56 +-- dysql/test/test_sql_in_list_templates.py | 193 +++++--- dysql/test/test_sql_insert_templates.py | 71 +-- dysql/test/test_template_generators.py | 221 ++++++--- dysql/test_managers.py | 94 ++-- ruff.toml | 20 + setup.py | 82 ++-- test_requirements.in | 10 + test_requirements.txt | 109 ++++- 28 files changed, 1293 insertions(+), 1366 deletions(-) create mode 100644 .pre-commit-config.yaml delete mode 100644 .pylint-license-header delete mode 100644 .pylintrc create mode 100644 ruff.toml create mode 100644 test_requirements.in diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index dc222e8..c46ef00 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -27,10 +27,8 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt pip install -r test_requirements.txt - - name: Lint with pycodestyle - run: pycodestyle dysql - - name: Lint with pylint - run: pylint dysql + - name: Run pre-commit checks + run: pre-commit run --all-files - name: Test with pytest run: pytest --junitxml=test-reports/test-results.xml - name: Publish test results diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..0b2f4c7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.1.7 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format diff --git a/.pylint-license-header b/.pylint-license-header deleted file mode 100644 index b6311f6..0000000 --- a/.pylint-license-header +++ /dev/null @@ -1,7 +0,0 @@ -""" -Copyright 20\d\d Adobe -All Rights Reserved. - -NOTICE: Adobe permits you to use, modify, and distribute this file in accordance -with the terms of the Adobe license agreement accompanying it. -""" \ No newline at end of file diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 35582a9..0000000 --- a/.pylintrc +++ /dev/null @@ -1,526 +0,0 @@ -[MASTER] -# A comma-separated list of package or module names from where C extensions may -# be loaded. Extensions are loading into the active Python interpreter and may -# run arbitrary code. -extension-pkg-whitelist=pydantic - -# Specify a score threshold to be exceeded before program exits with error. -fail-under=10.0 - -# Add files or directories to the ignore list. They should be base names, not -# paths. -ignore=CVS - -# Add files or directories matching the regex patterns to the ignore list. The -# regex matches against base names, not paths. 
-ignore-patterns= - -# Python code to execute, usually for sys.path manipulation such as -# pygtk.require(). -#init-hook= - -# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the -# number of processors available to use. -jobs=1 - -# Control the amount of potential inferred values when inferring a single -# object. This can help the performance when dealing with large functions or -# complex, nested conditions. -limit-inference-results=100 - -# List of plugins (as comma separated values of python module names) to load, -# usually to register additional checkers. -load-plugins=pylintfileheader - -# The file header to enforce -file-header-path=.pylint-license-header -# Ignore empty files such as __init__.py -file-header-ignore-empty-files=yes - -# Pickle collected data for later comparisons. -persistent=yes - -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages. -suggestion-mode=yes - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. -confidence= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once). You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use "--disable=all --enable=classes -# --disable=W". -disable=raw-checker-failed, - bad-inline-option, - locally-disabled, - file-ignored, - suppressed-message, - useless-suppression, - deprecated-pragma, - use-symbolic-message-instead, - logging-fstring-interpolation, - missing-module-docstring, - missing-class-docstring, - missing-function-docstring, - duplicate-code - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -enable=c-extension-no-member - - -[REPORTS] - -# Python expression which should return a score less than or equal to 10. You -# have access to the variables 'error', 'warning', 'refactor', and 'convention' -# which contain the number of messages in each category, as well as 'statement' -# which is the total number of statements analyzed. This score is used by the -# global evaluation report (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details. -#msg-template= - -# Set the output format. Available formats are text, parseable, colorized, json -# and msvs (visual studio). You can also give a reporter class, e.g. -# mypackage.mymodule.MyReporterClass. 
-output-format=text - -# Tells whether to display a full report or only the messages. -reports=no - -# Activate the evaluation score. -score=yes - - -[REFACTORING] - -# Maximum number of nested blocks for function / method body -max-nested-blocks=5 - -# Complete name of functions that never returns. When checking for -# inconsistent-return-statements if a never returning function is called then -# it will be considered as an explicit return statement and no message will be -# printed. -never-returning-functions=sys.exit - - -[LOGGING] - -# The type of string formatting that logging methods do. `old` means using % -# formatting, `new` is for `{}` formatting. -logging-format-style=new - -# Logging modules to check that the string format arguments are in logging -# function parameter format. -logging-modules=logging - - -[SPELLING] - -# Limits count of emitted suggestions for spelling mistakes. -max-spelling-suggestions=4 - -# Spelling dictionary name. Available dictionaries: none. To make it work, -# install the python-enchant package. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains the private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to the private dictionary (see the -# --spelling-private-dict-file option) instead of raising a message. -spelling-store-unknown-words=no - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=FIXME, - XXX, - TODO - -# Regular expression of note tags to take in consideration. -#notes-rgx= - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members= - -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# Tells whether to warn about missing members when the owner of the attribute -# is inferred to be None. -ignore-none=yes - -# This flag controls whether pylint should warn about no-member and similar -# checks whenever an opaque object is returned when inferring. The inference -# can return multiple potential results while evaluating a Python object, but -# some branches might not be evaluated, which results in partial inference. In -# that case, it might be useful to still emit no-member and other checks for -# the rest of the inferred objects. -ignore-on-opaque-inference=yes - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis). It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# Show a hint with possible names when a member name was not found. The aspect -# of finding the hint is based on edit distance. 
-missing-member-hint=yes - -# The minimum edit distance a name should have in order to be considered a -# similar match for a missing member name. -missing-member-hint-distance=1 - -# The total number of similar names that should be taken in consideration when -# showing a hint for a missing member. -missing-member-max-choices=1 - -# List of decorators that change the signature of a decorated function. -signature-mutators= - - -[VARIABLES] - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid defining new builtins when possible. -additional-builtins= - -# Tells whether unused global variables should be treated as a violation. -allow-global-unused-variables=yes - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_, - _cb - -# A regular expression matching the name of dummy variables (i.e. expected to -# not be used). -dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ - -# Argument names that match this expression will be ignored. Default to name -# with leading underscore. -ignored-argument-names=_.*|^ignored_|^unused_ - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io - - -[FORMAT] - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=^\s*(# )??$ - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 -# tab). -indent-string=' ' - -# Maximum number of characters on a single line. -max-line-length=120 - -# Maximum number of lines in a module. -max-module-lines=1000 - -# Allow the body of a class to be on the same line as the declaration if body -# contains single statement. -single-line-class-stmt=no - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=no - - -[SIMILARITIES] - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - -# Minimum lines number of a similarity. -min-similarity-lines=4 - - -[BASIC] - -# Naming style matching correct argument names. -argument-naming-style=snake_case - -# Regular expression matching correct argument names. Overrides argument- -# naming-style. -#argument-rgx= - -# Naming style matching correct attribute names. -attr-naming-style=snake_case - -# Regular expression matching correct attribute names. Overrides attr-naming- -# style. -#attr-rgx= - -# Bad variable names which should always be refused, separated by a comma. -bad-names=foo, - bar, - baz, - toto, - tutu, - tata - -# Bad variable names regexes, separated by a comma. If names match any regex, -# they will always be refused -bad-names-rgxs= - -# Naming style matching correct class attribute names. -class-attribute-naming-style=any - -# Regular expression matching correct class attribute names. Overrides class- -# attribute-naming-style. -#class-attribute-rgx= - -# Naming style matching correct class names. 
-class-naming-style=PascalCase - -# Regular expression matching correct class names. Overrides class-naming- -# style. -#class-rgx= - -# Naming style matching correct constant names. -const-naming-style=UPPER_CASE - -# Regular expression matching correct constant names. Overrides const-naming- -# style. -#const-rgx= - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=-1 - -# Naming style matching correct function names. -function-naming-style=snake_case - -# Regular expression matching correct function names. Overrides function- -# naming-style. -#function-rgx= - -# Good variable names which should always be accepted, separated by a comma. -good-names=i, - j, - k, - ex, - Run, - _ - -# Good variable names regexes, separated by a comma. If names match any regex, -# they will always be accepted -good-names-rgxs= - -# Include a hint for the correct naming format with invalid-name. -include-naming-hint=no - -# Naming style matching correct inline iteration names. -inlinevar-naming-style=any - -# Regular expression matching correct inline iteration names. Overrides -# inlinevar-naming-style. -#inlinevar-rgx= - -# Naming style matching correct method names. -method-naming-style=snake_case - -# Regular expression matching correct method names. Overrides method-naming- -# style. -#method-rgx= - -# Naming style matching correct module names. -module-naming-style=snake_case - -# Regular expression matching correct module names. Overrides module-naming- -# style. -#module-rgx= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=^_ - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -# These decorators are taken in consideration only for invalid-name. -property-classes=abc.abstractproperty - -# Naming style matching correct variable names. -variable-naming-style=snake_case - -# Regular expression matching correct variable names. Overrides variable- -# naming-style. -#variable-rgx= - - -[STRING] - -# This flag controls whether inconsistent-quotes generates a warning when the -# character used as a quote delimiter is used inconsistently within a module. -check-quote-consistency=no - -# This flag controls whether the implicit-str-concat should generate a warning -# on implicit string concatenation in sequences defined over several lines. -check-str-concat-over-line-jumps=no - - -[IMPORTS] - -# List of modules that can be imported at any level, not just the top level -# one. -allow-any-import-level= - -# Allow wildcard imports from modules that define __all__. -allow-wildcard-with-all=no - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - -# Deprecated modules which should not be used, separated by a comma. -deprecated-modules=optparse,tkinter.tix - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled). -ext-import-graph= - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled). 
-import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled). -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant - -# Couples of modules and preferred modules, separated by a comma. -preferred-modules= - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp, - __post_init__ - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=cls - - -[DESIGN] - -# Maximum number of arguments for function / method. -max-args=5 - -# Maximum number of attributes for a class (see R0902). -max-attributes=7 - -# Maximum number of boolean expressions in an if statement (see R0916). -max-bool-expr=5 - -# Maximum number of branch for function / method body. -max-branches=12 - -# Maximum number of locals for function / method body. -max-locals=15 - -# Maximum number of parents for a class (see R0901). -max-parents=7 - -# Maximum number of public methods for a class (see R0904). -max-public-methods=20 - -# Maximum number of return / yield for function / method body. -max-returns=6 - -# Maximum number of statements in function / method body. -max-statements=50 - -# Minimum number of public methods for a class (see R0903). -min-public-methods=2 - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. Defaults to -# "BaseException, Exception". 
-overgeneral-exceptions=buildins.BaseException, - buildins.Exception diff --git a/.python-version b/.python-version index afad818..0c7d5f5 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.11.0 +3.11.4 diff --git a/dysql/__init__.py b/dysql/__init__.py index d15d41a..2dd5258 100644 --- a/dysql/__init__.py +++ b/dysql/__init__.py @@ -31,3 +31,27 @@ set_default_connection_parameters, ) from .exceptions import DBNotPreparedError + + +__all__ = [ + "BaseMapper", + "DbMapResult", + "RecordCombiningMapper", + "SingleRowMapper", + "SingleColumnMapper", + "SingleRowAndColumnMapper", + "CountMapper", + "KeyValueMapper", + "QueryData", + "QueryDataError", + "TemplateGenerators", + "sqlexists", + "sqlquery", + "sqlupdate", + "is_set_current_database_supported", + "reset_current_database", + "set_current_database", + "set_database_init_hook", + "set_default_connection_parameters", + "DBNotPreparedError", +] diff --git a/dysql/annotations.py b/dysql/annotations.py index 5b5e66f..4a12d05 100644 --- a/dysql/annotations.py +++ b/dysql/annotations.py @@ -11,7 +11,7 @@ from pydantic import BeforeValidator # pylint: disable=invalid-name -T = TypeVar('T') +T = TypeVar("T") def _transform_csv(value: str) -> T: @@ -19,7 +19,7 @@ def _transform_csv(value: str) -> T: return None if isinstance(value, str): - return list(map(str.strip, value.split(','))) + return list(map(str.strip, value.split(","))) if isinstance(value, list): return value diff --git a/dysql/connections.py b/dysql/connections.py index 9ce6b62..29ee2ac 100644 --- a/dysql/connections.py +++ b/dysql/connections.py @@ -21,7 +21,7 @@ from .query_utils import get_query_data -logger = logging.getLogger('database') +logger = logging.getLogger("database") # Always initialize a database container, it is never set again @@ -35,8 +35,10 @@ class _ConnectionManager: def __init__(self, func, isolation_level, transaction, *args, **kwargs): self._transaction = None - self._connection = _DATABASE_CONTAINER.current_database.engine.connect().execution_options( - isolation_level=isolation_level + self._connection = ( + _DATABASE_CONTAINER.current_database.engine.connect().execution_options( + isolation_level=isolation_level + ) ) if transaction: self._transaction = self._connection.begin() @@ -75,7 +77,10 @@ def execute_query(self, query, params=None) -> sqlalchemy.engine.CursorResult: return self._connection.execute(sqlalchemy.text(query), params) -def sqlquery(mapper: Union[BaseMapper, Type[BaseMapper]] = None, isolation_level: str = 'READ_COMMITTED'): +def sqlquery( + mapper: Union[BaseMapper, Type[BaseMapper]] = None, + isolation_level: str = "READ_COMMITTED", +): """ query allows for defining a parameterize select query that is then executed :param mapper: a class extending from or an instance of BaseMapper, defaults to @@ -120,7 +125,9 @@ def handle_query(*args, **kwargs): if inspect.isclass(actual_mapper): actual_mapper = actual_mapper() - with _ConnectionManager(func, isolation_level, False, *args, **kwargs) as conn_manager: + with _ConnectionManager( + func, isolation_level, False, *args, **kwargs + ) as conn_manager: data = func(*args, **kwargs) query, params = get_query_data(data) records = conn_manager.execute_query(query, params) @@ -131,7 +138,7 @@ def handle_query(*args, **kwargs): return decorator -def sqlexists(isolation_level='READ_COMMITTED'): +def sqlexists(isolation_level="READ_COMMITTED"): """ exists query allows for defining a parameterize select query that is wrapped in an exists clause and then executed @@ -155,15 +162,17 @@ 
def get_items_exists(): def decorator(func): def handle_query(*args, **kwargs): functools.wraps(func, handle_query) - with _ConnectionManager(func, isolation_level, False, *args, **kwargs) as conn_manager: + with _ConnectionManager( + func, isolation_level, False, *args, **kwargs + ) as conn_manager: data = func(*args, **kwargs) query, params = get_query_data(data) query = query.lstrip() - if query.startswith('SELECT EXISTS'): - query = query.replace('SELECT EXISTS', 'SELECT 1 WHERE EXISTS') - if not query.startswith('SELECT 1 WHERE EXISTS'): - query = f'SELECT 1 WHERE EXISTS ( {query} )' + if query.startswith("SELECT EXISTS"): + query = query.replace("SELECT EXISTS", "SELECT 1 WHERE EXISTS") + if not query.startswith("SELECT 1 WHERE EXISTS"): + query = f"SELECT 1 WHERE EXISTS ( {query} )" result = conn_manager.execute_query(query, params).scalar() return result == 1 @@ -172,7 +181,9 @@ def handle_query(*args, **kwargs): return decorator -def sqlupdate(isolation_level='READ_COMMITTED', disable_foreign_key_checks=False, on_success=None): +def sqlupdate( + isolation_level="READ_COMMITTED", disable_foreign_key_checks=False, on_success=None +): """ :param isolation_level should specify whether we can read data from transactions that are not yet committed defaults to READ_COMMITTED @@ -209,8 +220,9 @@ def update_wrapper(func): def handle_query(*args, **kwargs): functools.wraps(func) - with _ConnectionManager(func, isolation_level, True, *args, **kwargs) as conn_manager: - + with _ConnectionManager( + func, isolation_level, True, *args, **kwargs + ) as conn_manager: if disable_foreign_key_checks: conn_manager.execute_query("SET FOREIGN_KEY_CHECKS=0") @@ -228,7 +240,7 @@ def handle_query(*args, **kwargs): data = func(*args, **kwargs) query, params = get_query_data(data) if isinstance(params, list): - raise Exception('Params must not be a list') + raise Exception("Params must not be a list") conn_manager.execute_query(query, params) if disable_foreign_key_checks: diff --git a/dysql/databases.py b/dysql/databases.py index 7276855..72a359b 100644 --- a/dysql/databases.py +++ b/dysql/databases.py @@ -15,18 +15,21 @@ from .exceptions import DBNotPreparedError -logger = logging.getLogger('database') +logger = logging.getLogger("database") _DEFAULT_CONNECTION_PARAMS = {} try: import contextvars - CURRENT_DATABASE_VAR = contextvars.ContextVar("dysql_current_database", default='') + + CURRENT_DATABASE_VAR = contextvars.ContextVar("dysql_current_database", default="") except ImportError: CURRENT_DATABASE_VAR = None -def set_database_init_hook(hook_method: Callable[[Optional[str], sqlalchemy.engine.Engine], None]) -> None: +def set_database_init_hook( + hook_method: Callable[[Optional[str], sqlalchemy.engine.Engine], None], +) -> None: """ Sets an initialization hook whenever a new database is initialized. This method will receive the database name (may be none) and the sqlalchemy engine as parameters. @@ -55,7 +58,7 @@ def set_current_database(database: str) -> None: f'Cannot set the current database on Python "{sys.version}", please upgrade your Python version' ) CURRENT_DATABASE_VAR.set(database) - logger.debug(f'Set current database to {database}') + logger.debug(f"Set current database to {database}") def reset_current_database() -> None: @@ -63,7 +66,7 @@ def reset_current_database() -> None: Helper method to reset the current database to the default. Internally, this calls `set_current_database` with an empty string. 
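    Example (illustrative; assumes connection parameters were already set via
    set_default_connection_parameters, and that set_current_database is
    supported on the running Python version):

        set_current_database("tenant_db")
        try:
            ...  # decorated query functions now run against tenant_db
        finally:
            reset_current_database()  # back to the default database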
""" - set_current_database('') + set_current_database("") def _get_current_database() -> str: @@ -75,25 +78,27 @@ def _get_current_database() -> str: if CURRENT_DATABASE_VAR: database = CURRENT_DATABASE_VAR.get() if not database: - database = _DEFAULT_CONNECTION_PARAMS.get('database') + database = _DEFAULT_CONNECTION_PARAMS.get("database") return database def _validate_param(name: str, value: str) -> None: if not value: - raise DBNotPreparedError(f'Database parameter "{name}" is not set or empty and is required') + raise DBNotPreparedError( + f'Database parameter "{name}" is not set or empty and is required' + ) def set_default_connection_parameters( - host: str, - user: str, - password: str, - database: str, - port: int = 3306, - pool_size: int = 10, - pool_recycle: int = 3600, - echo_queries: bool = False, - charset: str = 'utf8' + host: str, + user: str, + password: str, + database: str, + port: int = 3306, + pool_size: int = 10, + pool_recycle: int = 3600, + echo_queries: bool = False, + charset: str = "utf8", ): # pylint: disable=too-many-arguments,unused-argument """ Initializes the parameters to use when connecting to the database. This is a subset of the parameters @@ -111,10 +116,10 @@ def set_default_connection_parameters( :param charset: the charset for the sql engine to initialize with. (default utf8) :exception DBNotPrepareError: happens when required parameters are missing """ - _validate_param('host', host) - _validate_param('user', user) - _validate_param('password', password) - _validate_param('database', database) + _validate_param("host", host) + _validate_param("user", user) + _validate_param("password", password) + _validate_param("database", database) _DEFAULT_CONNECTION_PARAMS.update(locals()) @@ -129,30 +134,31 @@ def __init__(self, database: Optional[str]) -> None: @classmethod def set_init_hook( - cls, - hook_method: Callable[[Optional[str], sqlalchemy.engine.Engine], None], + cls, + hook_method: Callable[[Optional[str], sqlalchemy.engine.Engine], None], ) -> None: cls.hook_method = hook_method @property def engine(self) -> sqlalchemy.engine.Engine: if not self._engine: - user = _DEFAULT_CONNECTION_PARAMS.get('user') - password = _DEFAULT_CONNECTION_PARAMS.get('password') - host = _DEFAULT_CONNECTION_PARAMS.get('host') - port = _DEFAULT_CONNECTION_PARAMS.get('port') - charset = _DEFAULT_CONNECTION_PARAMS.get('charset') + user = _DEFAULT_CONNECTION_PARAMS.get("user") + password = _DEFAULT_CONNECTION_PARAMS.get("password") + host = _DEFAULT_CONNECTION_PARAMS.get("host") + port = _DEFAULT_CONNECTION_PARAMS.get("port") + charset = _DEFAULT_CONNECTION_PARAMS.get("charset") - url = f'mysql+mysqlconnector://{user}:{password}@{host}:{port}/{self.database}?charset={charset}' + url = f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{self.database}?charset={charset}" self._engine = sqlalchemy.create_engine( url, - pool_recycle=_DEFAULT_CONNECTION_PARAMS.get('pool_recycle'), - pool_size=_DEFAULT_CONNECTION_PARAMS.get('pool_size'), - echo=_DEFAULT_CONNECTION_PARAMS.get('echo_queries'), + pool_recycle=_DEFAULT_CONNECTION_PARAMS.get("pool_recycle"), + pool_size=_DEFAULT_CONNECTION_PARAMS.get("pool_size"), + echo=_DEFAULT_CONNECTION_PARAMS.get("echo_queries"), pool_pre_ping=True, ) - hook_method: Optional[Callable[[Optional[str], sqlalchemy.engine.Engine], None]] = \ - getattr(self.__class__, 'hook_method', None) + hook_method: Optional[ + Callable[[Optional[str], sqlalchemy.engine.Engine], None] + ] = getattr(self.__class__, "hook_method", None) if hook_method: 
hook_method(self.database, self._engine) @@ -163,6 +169,7 @@ class DatabaseContainer(dict): """ Implementation of a dictionary that always provides a Database class instance, even if the key is missing. """ + def __getitem__(self, database: Optional[str]) -> Database: """ Override getitem to always return an instance of a database, which includes a lazy-initialized engine. @@ -173,7 +180,7 @@ def __getitem__(self, database: Optional[str]) -> Database: """ if not _DEFAULT_CONNECTION_PARAMS: raise DBNotPreparedError( - 'Unable to connect to a database, set_default_connection_parameters must first be called' + "Unable to connect to a database, set_default_connection_parameters must first be called" ) if not super().__contains__(database): @@ -194,7 +201,8 @@ class DatabaseContainerSingleton(DatabaseContainer): All instantiations of this class will result in the same instance every time due to the override of the __new__ method. """ - def __new__(cls, *args, **kwargs) -> 'DatabaseContainer': + + def __new__(cls, *args, **kwargs) -> "DatabaseContainer": instance = cls.__dict__.get("__instance__") if instance is not None: return instance diff --git a/dysql/mappers.py b/dysql/mappers.py index 6fad1a3..f560a6a 100644 --- a/dysql/mappers.py +++ b/dysql/mappers.py @@ -23,13 +23,12 @@ class MapperError(Exception): class DbMapResultBase(abc.ABC): - @classmethod def get_key_columns(cls): - return ['id'] + return ["id"] @classmethod - def create_instance(cls, *args, **kwargs) -> 'DbMapResultBase': + def create_instance(cls, *args, **kwargs) -> "DbMapResultBase": """ Called instead of constructor, used to support different ways of creating objects instead of constructors (if desired). @@ -73,7 +72,7 @@ class DbMapResult(DbMapResultBase): def __init__(self, **kwargs): self.__dict__ = kwargs # pylint: disable=invalid-name - if not self.__dict__.get('id'): + if not self.__dict__.get("id"): self.id = None def __getitem__(self, field: str) -> Any: @@ -102,8 +101,8 @@ def get_raw(value: Any) -> dict: else: raw[key] = get_raw(value) # Remove the id field if it was never set - if raw.get('id') is None: - del raw['id'] + if raw.get("id") is None: + del raw["id"] return raw def has(self, field: str) -> bool: @@ -117,6 +116,7 @@ class BaseMapper(metaclass=abc.ABCMeta): """ Extend this class and implement the map_records method to map the results from a database query. """ + @abc.abstractmethod def map_records(self, records: sqlalchemy.engine.CursorResult) -> Any: pass @@ -186,6 +186,7 @@ class SingleRowMapper(BaseMapper): Returns a single mapped result from one or more records. The first record is returned even if there are multiple records from the database. """ + def __init__(self, record_mapper: Optional[Type[DbMapResultBase]] = DbMapResult): self.record_mapper = record_mapper @@ -207,6 +208,7 @@ class SingleColumnMapper(BaseMapper): Returns the first column value for each record from the database, even if multiple columns are defined. This will return a list of scalar values. """ + def map_records(self, records: sqlalchemy.engine.CursorResult) -> Any: results = [] for record in records: @@ -221,6 +223,7 @@ class SingleRowAndColumnMapper(BaseMapper): Returns the first column in the first record from the database, even if multiple records or columns exist. This will return a single scalar or None if there are no records. 
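    Example pairing with the query decorator (illustrative table name):

        @sqlquery(mapper=SingleRowAndColumnMapper)
        def count_users():
            return QueryData("SELECT COUNT(*) FROM users")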
""" + def map_records(self, records: sqlalchemy.engine.CursorResult) -> Any: for record in records: for value in record.values(): @@ -242,9 +245,10 @@ class KeyValueMapper(BaseMapper): the each key may have more than 1 result. this will returns a dictionary of lists when set """ + def __init__(self, key_column=0, value_column=1, has_multiple_values_per_key=False): if key_column == value_column: - raise MapperError('key and value columns cannot be the same') + raise MapperError("key and value columns cannot be the same") self.key_column = key_column self.value_column = value_column diff --git a/dysql/pydantic_mappers.py b/dysql/pydantic_mappers.py index c6e89e1..5f0d2d2 100644 --- a/dysql/pydantic_mappers.py +++ b/dysql/pydantic_mappers.py @@ -25,6 +25,7 @@ class DbMapResultModel(BaseModel, DbMapResultBase): Additionally, lists, sets, and dicts will ignore null values from the database. Therefore you must provide default values for these fields when used or else validation will fail. """ + # List fields that are aggregated into a string of comma seperated values with basic string splitting on commas _csv_list_fields: Set[str] = set() # List field that are json objects @@ -39,7 +40,7 @@ class DbMapResultModel(BaseModel, DbMapResultBase): _dict_value_mappings: Dict[str, str] = {} @classmethod - def create_instance(cls, *args, **kwargs) -> 'DbMapResultModel': + def create_instance(cls, *args, **kwargs) -> "DbMapResultModel": # Uses the construct method to prevent validation when mapping results return cls.model_construct(*args, **kwargs) @@ -49,7 +50,9 @@ def _map_json(self, current_dict: dict, record: sqlalchemy.engine.Row, field: st if not value: return if not self._has_been_mapped(): - current_dict[field] = TypeAdapter(model_field.annotation).validate_json(value) + current_dict[field] = TypeAdapter(model_field.annotation).validate_json( + value + ) def _map_list(self, current_dict: dict, record: sqlalchemy.engine.Row, field: str): if record[field] is None: @@ -81,7 +84,9 @@ def _map_dict(self, current_dict: dict, record: sqlalchemy.engine.Row, field: st else: current_dict[model_field_name] = {record[field]: record[value_field]} - def _map_list_from_string(self, current_dict: dict, record: sqlalchemy.engine.Row, field: str): + def _map_list_from_string( + self, current_dict: dict, record: sqlalchemy.engine.Row, field: str + ): list_string = record[field] if not list_string: # See note above for lists @@ -89,7 +94,7 @@ def _map_list_from_string(self, current_dict: dict, record: sqlalchemy.engine.Ro # force it to be a string list_string = str(list_string) - values_from_string = list(map(str.strip, list_string.split(','))) + values_from_string = list(map(str.strip, list_string.split(","))) model_field = self.model_fields[field] # pre-validates the list we are expecting because we want to ensure all records are validated diff --git a/dysql/query_utils.py b/dysql/query_utils.py index f7990a4..4cd5b0d 100644 --- a/dysql/query_utils.py +++ b/dysql/query_utils.py @@ -17,7 +17,9 @@ # group 4 - column_name # group 5 - empty space after template. 
this helps us ensure we add space to a template after but # only 1 -LIST_TEMPLATE_REGEX = re.compile(r'(( +)?{(in|not_in|values)__([A-Za-z_]+\.)?([A-Za-z_]+)}( +)?)') +LIST_TEMPLATE_REGEX = re.compile( + r"(( +)?{(in|not_in|values)__([A-Za-z_]+\.)?([A-Za-z_]+)}( +)?)" +) class TemplateGenerators: @@ -28,20 +30,20 @@ class TemplateGenerators: @classmethod def get_template(cls, name: str): - if name == 'in': + if name == "in": return cls.in_column - if name == 'not_in': + if name == "not_in": return cls.not_in_column - if name == 'values': + if name == "values": return cls.values return None @staticmethod def in_column( - name: str, - values: Union[str, Iterable[str]], - legacy_key: str = None, - is_multi_column: bool = False + name: str, + values: Union[str, Iterable[str]], + legacy_key: str = None, + is_multi_column: bool = False, ) -> Tuple[str, Optional[dict]]: """ Returns query and params for using "IN" SQL queries. @@ -52,18 +54,16 @@ def in_column( :return: a tuple of the query string and the params dictionary """ if not values: - return '1 <> 1', None + return "1 <> 1", None key_name = TemplateGenerators._get_key(name, legacy_key, is_multi_column) keys, values = TemplateGenerators._parameterize_list(key_name, values) if is_multi_column: - keys = f'({keys})' - return f'{name} IN {keys}', values + keys = f"({keys})" + return f"{name} IN {keys}", values @staticmethod def in_multi_column( - name: str, - values: Union[str, Iterable[str]], - legacy_key: str = None + name: str, values: Union[str, Iterable[str]], legacy_key: str = None ): """ A wrapper for in_column with is_multi_column set to true @@ -72,10 +72,10 @@ def in_multi_column( @staticmethod def not_in_column( - name: str, - values: Union[str, Iterable[str]], - legacy_key: str = None, - is_multi_column: bool = False + name: str, + values: Union[str, Iterable[str]], + legacy_key: str = None, + is_multi_column: bool = False, ) -> Tuple[str, Optional[dict]]: """ Returns query and params for using "NOT IN" SQL queries. @@ -86,20 +86,17 @@ def not_in_column( :return: a tuple of the query string and the params dictionary """ if not values: - return '1 = 1', None + return "1 = 1", None key_name = TemplateGenerators._get_key(name, legacy_key, is_multi_column) keys, values = TemplateGenerators._parameterize_list(key_name, values) if is_multi_column: - keys = f'({keys})' - return f'{name} NOT IN {keys}', values + keys = f"({keys})" + return f"{name} NOT IN {keys}", values @staticmethod def not_in_multi_column( - name: str, - values: Union[str, Iterable[str]], - legacy_key: str = None + name: str, values: Union[str, Iterable[str]], legacy_key: str = None ): - """ A wrapper for not_in_column with is_multi_column set to true """ @@ -107,9 +104,9 @@ def not_in_multi_column( @staticmethod def values( - name: str, - values: Union[str, Iterable[str]], - legacy_key: str = None, + name: str, + values: Union[str, Iterable[str]], + legacy_key: str = None, ) -> Tuple[str, Optional[dict]]: """ Returns query and params for using "VALUES" SQL queries. 
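# A minimal sketch of the list-template helpers defined above, assuming the
# behavior shown in this patch: in_column() expands a column name and a list
# of values into an IN clause plus a dict of bind parameters, and falls back
# to the always-false guard "1 <> 1" when the list is empty. The exact
# generated parameter names are an internal detail and may differ.
from dysql import TemplateGenerators

clause, params = TemplateGenerators.in_column("id", [1, 2, 3])
# clause is something like "id IN ( :id_0, :id_1, :id_2 )", with params
# holding the matching values keyed by those placeholder names

empty_clause, empty_params = TemplateGenerators.in_column("id", [])
assert empty_clause == "1 <> 1" and empty_params is None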
@@ -119,10 +116,10 @@ def values( :return: a tuple of the query string and the params dictionary """ if not values: - raise ListTemplateException(f'Must have values for {name} template') + raise ListTemplateException(f"Must have values for {name} template") key_name = TemplateGenerators._get_key(name, legacy_key, False) keys, values = TemplateGenerators._parameterize_list(key_name, values) - return f'VALUES {keys}', values + return f"VALUES {keys}", values @staticmethod def _get_key(key: str, legacy_key: str, is_multi_column: bool) -> str: @@ -130,11 +127,13 @@ def _get_key(key: str, legacy_key: str, is_multi_column: bool) -> str: if legacy_key: key_name = legacy_key if is_multi_column: - key_name = re.sub('[, ()]', '', key_name) + key_name = re.sub("[, ()]", "", key_name) return key_name @staticmethod - def _parameterize_inner_list(key: str, values: Union[str, Iterable[str]]) -> Tuple[str, Optional[dict]]: + def _parameterize_inner_list( + key: str, values: Union[str, Iterable[str]] + ) -> Tuple[str, Optional[dict]]: param_values = {} parameterized_keys = [] if not isinstance(values, (list, tuple)): @@ -149,7 +148,9 @@ def _parameterize_inner_list(key: str, values: Union[str, Iterable[str]]) -> Tup return f"( :{', :'.join(parameterized_keys)} )", param_values @staticmethod - def _parameterize_list(key: str, values: Union[str, Iterable[str]]) -> Tuple[str, Optional[dict]]: + def _parameterize_list( + key: str, values: Union[str, Iterable[str]] + ) -> Tuple[str, Optional[dict]]: """ Build a string with parameterized values and a dictionary with key value pairs matching the string parameters. @@ -162,16 +163,19 @@ def _parameterize_list(key: str, values: Union[str, Iterable[str]]) -> Tuple[str values = tuple((values,)) for index, value in enumerate(values): - if isinstance(value, tuple) or key.startswith('values'): - param_string, inner_param_values = TemplateGenerators._parameterize_inner_list( - f'{key}_{str(index)}', value + if isinstance(value, tuple) or key.startswith("values"): + ( + param_string, + inner_param_values, + ) = TemplateGenerators._parameterize_inner_list( + f"{key}_{str(index)}", value ) param_values.update(inner_param_values) param_inner_keys.append(param_string) else: return TemplateGenerators._parameterize_inner_list(key, values) - return ', '.join(param_inner_keys), param_values + return ", ".join(param_inner_keys), param_values class ListTemplateException(Exception): @@ -195,10 +199,10 @@ class QueryData: """ def __init__( - self, - query: str, - query_params: dict = None, - template_params: dict = None, + self, + query: str, + query_params: dict = None, + template_params: dict = None, ): """ Constructor. 
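# A hedged end-to-end sketch tying QueryData to the {in__...} placeholders
# matched by LIST_TEMPLATE_REGEX above; the table and column names here are
# illustrative only. get_query_data() replaces the placeholder with a
# parameterized IN clause before the decorator executes the query.
from dysql import QueryData, sqlquery


@sqlquery()
def get_users_by_id(user_ids):
    return QueryData(
        "SELECT * FROM users WHERE {in__id}",
        template_params={"in__id": user_ids},
    )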
@@ -257,11 +261,11 @@ def __validate_keys_clean_query(query, template_params): # validate if template_params is None or template_params.get(key) is None: missing_keys.append(key) - elif key == 'values' and len(template_params.get(key)) == 0: + elif key == "values" and len(template_params.get(key)) == 0: missing_keys.append(key) if len(missing_keys) > 0: - raise ListTemplateException(f'Missing template keys {missing_keys}') + raise ListTemplateException(f"Missing template keys {missing_keys}") # Clean whitespace as templates will add their own padding later on query = query.replace(groups[0], groups[0].strip()) @@ -270,7 +274,9 @@ def __validate_keys_clean_query(query, template_params): def __validate_query_and_params(data: QueryData) -> None: if not isinstance(data, QueryData): - raise QueryDataError('SQL annotated methods must return an instance of QueryData for query information') + raise QueryDataError( + "SQL annotated methods must return an instance of QueryData for query information" + ) def get_query_data(data: QueryData) -> Tuple[str, dict]: @@ -282,18 +288,22 @@ def get_query_data(data: QueryData) -> Tuple[str, dict]: __validate_query_and_params(data) params = {} - query, validated_keys = __validate_keys_clean_query(data.query, data.template_params) + query, validated_keys = __validate_keys_clean_query( + data.query, data.template_params + ) if data.query_params: params.update(data.query_params) for key in validated_keys: - list_template_key, column_name = tuple(key.split('__')) + list_template_key, column_name = tuple(key.split("__")) template_to_use = TemplateGenerators.get_template(list_template_key) - template_query, param_dict = template_to_use(column_name, data.template_params[key], legacy_key=key) + template_query, param_dict = template_to_use( + column_name, data.template_params[key], legacy_key=key + ) if param_dict: params.update(param_dict) - query_key = '{' + key + '}' - query = query.replace(query_key, f' {template_query} ') + query_key = "{" + key + "}" + query = query.replace(query_key, f" {template_query} ") return query, params diff --git a/dysql/test/__init__.py b/dysql/test/__init__.py index 5af67b5..ad292a5 100644 --- a/dysql/test/__init__.py +++ b/dysql/test/__init__.py @@ -11,9 +11,9 @@ from dysql import set_default_connection_parameters, databases -@pytest.fixture(name='mock_create_engine') +@pytest.fixture(name="mock_create_engine") def mock_create_engine_fixture(): - create_mock = patch('dysql.databases.sqlalchemy.create_engine') + create_mock = patch("dysql.databases.sqlalchemy.create_engine") try: yield create_mock.start() finally: @@ -28,7 +28,7 @@ def setup_mock_engine(mock_create_engine): mock_engine = Mock() mock_engine.connect().execution_options().__enter__ = Mock() mock_engine.connect().execution_options().__exit__ = Mock() - set_default_connection_parameters('fake', 'user', 'password', 'test') + set_default_connection_parameters("fake", "user", "password", "test") # Clear out the databases before attempting to mock anything databases.DatabaseContainerSingleton().clear() @@ -42,7 +42,9 @@ def _verify_query_params(mock_engine, expected_query, expected_args): def _verify_query(mock_engine, expected_query): - execute_call = mock_engine.connect.return_value.execution_options.return_value.execute + execute_call = ( + mock_engine.connect.return_value.execution_options.return_value.execute + ) execute_call.assert_called() query = execute_call.call_args[0][0].text @@ -50,7 +52,9 @@ def _verify_query(mock_engine, expected_query): def 
_verify_query_args(mock_engine, expected_args): - execute_call = mock_engine.connect.return_value.execution_options.return_value.execute + execute_call = ( + mock_engine.connect.return_value.execution_options.return_value.execute + ) query_args = execute_call.call_args[0][1] assert query_args diff --git a/dysql/test/test_annotations.py b/dysql/test/test_annotations.py index 03a75a1..d7bd9f0 100644 --- a/dysql/test/test_annotations.py +++ b/dysql/test/test_annotations.py @@ -29,33 +29,39 @@ class NullableIntCSVModel(BaseModel): values: FromCSVToList[Union[List[int], None]] -@pytest.mark.parametrize('cls, values, expected', [ - (StrCSVModel, '1,2,3', ['1', '2', '3']), - (StrCSVModel, 'a,b', ['a', 'b']), - (StrCSVModel, 'a', ['a']), - (NullableStrCSVModel, '', None), - (NullableStrCSVModel, None, None), - (StrCSVModel, ['a', 'b'], ['a', 'b']), - (StrCSVModel, ['a', 'b', 'c'], ['a', 'b', 'c']), - (IntCSVModel, '1,2,3', [1, 2, 3]), - (IntCSVModel, '1', [1]), - (NullableIntCSVModel, '', None), - (NullableIntCSVModel, None, None), - (IntCSVModel, ['1', '2', '3'], [1, 2, 3]), - (IntCSVModel, ['1', '2', '3', 4, 5], [1, 2, 3, 4, 5]) -]) +@pytest.mark.parametrize( + "cls, values, expected", + [ + (StrCSVModel, "1,2,3", ["1", "2", "3"]), + (StrCSVModel, "a,b", ["a", "b"]), + (StrCSVModel, "a", ["a"]), + (NullableStrCSVModel, "", None), + (NullableStrCSVModel, None, None), + (StrCSVModel, ["a", "b"], ["a", "b"]), + (StrCSVModel, ["a", "b", "c"], ["a", "b", "c"]), + (IntCSVModel, "1,2,3", [1, 2, 3]), + (IntCSVModel, "1", [1]), + (NullableIntCSVModel, "", None), + (NullableIntCSVModel, None, None), + (IntCSVModel, ["1", "2", "3"], [1, 2, 3]), + (IntCSVModel, ["1", "2", "3", 4, 5], [1, 2, 3, 4, 5]), + ], +) def test_from_csv_to_list(cls, values, expected): assert expected == cls(values=values).values -@pytest.mark.parametrize('cls, values', [ - (StrCSVModel, ''), - (StrCSVModel, None), - (IntCSVModel, 'a,b,c'), - (IntCSVModel, ''), - (IntCSVModel, None), - (IntCSVModel, ['a', 'b', 'c']), -]) +@pytest.mark.parametrize( + "cls, values", + [ + (StrCSVModel, ""), + (StrCSVModel, None), + (IntCSVModel, "a,b,c"), + (IntCSVModel, ""), + (IntCSVModel, None), + (IntCSVModel, ["a", "b", "c"]), + ], +) def test_from_csv_to_list_invalid(cls, values): with pytest.raises(ValidationError): cls(values=values) diff --git a/dysql/test/test_database_initialization.py b/dysql/test/test_database_initialization.py index a380f04..15bea45 100644 --- a/dysql/test/test_database_initialization.py +++ b/dysql/test/test_database_initialization.py @@ -12,7 +12,13 @@ import pytest import dysql.connections -from dysql import sqlquery, DBNotPreparedError, set_default_connection_parameters, QueryData, set_database_init_hook +from dysql import ( + sqlquery, + DBNotPreparedError, + set_default_connection_parameters, + QueryData, + set_database_init_hook, +) from dysql.test import mock_create_engine_fixture, setup_mock_engine _ = mock_create_engine_fixture @@ -34,7 +40,7 @@ def query(): return QueryData("SELECT * FROM table") -@pytest.fixture(autouse=True, name='mock_engine') +@pytest.fixture(autouse=True, name="mock_engine") def fixture_mock_engine(mock_create_engine): dysql.databases.DatabaseContainerSingleton().clear() dysql.databases._DEFAULT_CONNECTION_PARAMS.clear() @@ -53,36 +59,44 @@ def fixture_mock_engine(mock_create_engine): @pytest.fixture(autouse=True) def fixture_reset_init_hook(): yield - if hasattr(dysql.databases.Database, 'hook_method'): - delattr(dysql.databases.Database, 'hook_method') + if 
hasattr(dysql.databases.Database, "hook_method"): + delattr(dysql.databases.Database, "hook_method") def test_nothing_set(): dysql.databases._DEFAULT_CONNECTION_PARAMS.clear() with pytest.raises(DBNotPreparedError) as error: query() - assert str(error.value) == \ - "Unable to connect to a database, set_default_connection_parameters must first be called" - - -@pytest.mark.parametrize('host, user, password, database, failed_field', [ - (None, 'u', 'p', 'd', 'host'), - ('', 'u', 'p', 'd', 'host'), - ('h', None, 'p', 'd', 'user'), - ('h', '', 'p', 'd', 'user'), - ('h', 'u', None, 'd', 'password'), - ('h', 'u', '', 'd', 'password'), - ('h', 'u', 'p', None, 'database'), - ('h', 'u', 'p', '', 'database'), -]) + assert ( + str(error.value) + == "Unable to connect to a database, set_default_connection_parameters must first be called" + ) + + +@pytest.mark.parametrize( + "host, user, password, database, failed_field", + [ + (None, "u", "p", "d", "host"), + ("", "u", "p", "d", "host"), + ("h", None, "p", "d", "user"), + ("h", "", "p", "d", "user"), + ("h", "u", None, "d", "password"), + ("h", "u", "", "d", "password"), + ("h", "u", "p", None, "database"), + ("h", "u", "p", "", "database"), + ], +) def test_fields_required(host, user, password, database, failed_field): with pytest.raises(DBNotPreparedError) as error: set_default_connection_parameters(host, user, password, database) - assert str(error.value) == f'Database parameter "{failed_field}" is not set or empty and is required' + assert ( + str(error.value) + == f'Database parameter "{failed_field}" is not set or empty and is required' + ) def test_minimal_credentials(mock_engine): - set_default_connection_parameters('h', 'u', 'p', 'd') + set_default_connection_parameters("h", "u", "p", "d") mock_engine.connect().execution_options().execute.return_value = [] query() @@ -91,26 +105,28 @@ def test_minimal_credentials(mock_engine): def test_init_hook(mock_engine): init_hook = mock.MagicMock() set_database_init_hook(init_hook) - set_default_connection_parameters('h', 'u', 'p', 'd') + set_default_connection_parameters("h", "u", "p", "d") mock_engine.connect().execution_options().execute.return_value = [] query() - init_hook.assert_called_once_with('d', mock_engine) + init_hook.assert_called_once_with("d", mock_engine) -@pytest.mark.skipif('3.6' in sys.version, reason='set_current_database is not supported on python 3.6') +@pytest.mark.skipif( + "3.6" in sys.version, reason="set_current_database is not supported on python 3.6" +) def test_init_hook_multiple_databases(mock_engine): init_hook = mock.MagicMock() set_database_init_hook(init_hook) - set_default_connection_parameters('h', 'u', 'p', 'd1') + set_default_connection_parameters("h", "u", "p", "d1") mock_engine.connect().execution_options().execute.return_value = [] query() - dysql.databases.set_current_database('d2') + dysql.databases.set_current_database("d2") query() assert init_hook.call_args_list == [ - mock.call('d1', mock_engine), - mock.call('d2', mock_engine), + mock.call("d1", mock_engine), + mock.call("d2", mock_engine), ] @@ -122,10 +138,10 @@ def test_current_database_default(mock_engine, mock_create_engine): # Only one database is initialized assert len(db_container) == 1 - assert 'test' in db_container - assert db_container.current_database.database == 'test' + assert "test" in db_container + assert db_container.current_database.database == "test" mock_create_engine.assert_called_once_with( - 'mysql+mysqlconnector://user:password@fake:3306/test?charset=utf8', + 
"mysql+mysqlconnector://user:password@fake:3306/test?charset=utf8", echo=False, pool_pre_ping=True, pool_recycle=3600, @@ -135,14 +151,16 @@ def test_current_database_default(mock_engine, mock_create_engine): def test_different_charset(mock_engine, mock_create_engine): db_container = dysql.databases.DatabaseContainerSingleton() - set_default_connection_parameters('host', 'user', 'password', 'database', charset='other') + set_default_connection_parameters( + "host", "user", "password", "database", charset="other" + ) assert len(db_container) == 0 mock_engine.connect().execution_options().execute.return_value = [] query() # Only one database is initialized mock_create_engine.assert_called_once_with( - 'mysql+mysqlconnector://user:password@host:3306/database?charset=other', + "mysql+mysqlconnector://user:password@host:3306/database?charset=other", echo=False, pool_pre_ping=True, pool_recycle=3600, @@ -152,24 +170,26 @@ def test_different_charset(mock_engine, mock_create_engine): def test_is_set_current_database_supported(): # This test only returns different outputs depending on the python runtime - if '3.6' in sys.version: + if "3.6" in sys.version: assert not dysql.databases.is_set_current_database_supported() else: assert dysql.databases.is_set_current_database_supported() -@pytest.mark.skipif('3.6' in sys.version, reason='set_current_database is not supported on python 3.6') +@pytest.mark.skipif( + "3.6" in sys.version, reason="set_current_database is not supported on python 3.6" +) def test_current_database_set(mock_engine, mock_create_engine): db_container = dysql.databases.DatabaseContainerSingleton() - dysql.databases.set_current_database('db1') + dysql.databases.set_current_database("db1") mock_engine.connect().execution_options().execute.return_value = [] query() assert len(db_container) == 1 - assert 'db1' in db_container - assert db_container.current_database.database == 'db1' + assert "db1" in db_container + assert db_container.current_database.database == "db1" mock_create_engine.assert_called_once_with( - 'mysql+mysqlconnector://user:password@fake:3306/db1?charset=utf8', + "mysql+mysqlconnector://user:password@fake:3306/db1?charset=utf8", echo=False, pool_pre_ping=True, pool_recycle=3600, @@ -177,39 +197,41 @@ def test_current_database_set(mock_engine, mock_create_engine): ) -@pytest.mark.skipif('3.6' in sys.version, reason='set_current_database is not supported on python 3.6') +@pytest.mark.skipif( + "3.6" in sys.version, reason="set_current_database is not supported on python 3.6" +) def test_current_database_cached(mock_engine, mock_create_engine): db_container = dysql.databases.DatabaseContainerSingleton() mock_engine.connect().execution_options().execute.return_value = [] query() assert len(db_container) == 1 - assert 'test' in db_container - assert db_container.current_database.database == 'test' + assert "test" in db_container + assert db_container.current_database.database == "test" - dysql.databases.set_current_database('db1') + dysql.databases.set_current_database("db1") query() assert len(db_container) == 2 - assert 'test' in db_container - assert db_container.current_database.database == 'db1' + assert "test" in db_container + assert db_container.current_database.database == "db1" # Set back to default dysql.databases.reset_current_database() query() assert len(db_container) == 2 - assert db_container.current_database.database == 'test' + assert db_container.current_database.database == "test" assert mock_create_engine.call_count == 2 assert 
mock_create_engine.call_args_list == [ mock.call( - 'mysql+mysqlconnector://user:password@fake:3306/test?charset=utf8', + "mysql+mysqlconnector://user:password@fake:3306/test?charset=utf8", echo=False, pool_pre_ping=True, pool_recycle=3600, pool_size=10, ), mock.call( - 'mysql+mysqlconnector://user:password@fake:3306/db1?charset=utf8', + "mysql+mysqlconnector://user:password@fake:3306/db1?charset=utf8", echo=False, pool_pre_ping=True, pool_recycle=3600, diff --git a/dysql/test/test_mappers.py b/dysql/test/test_mappers.py index dc624da..02de9a0 100644 --- a/dysql/test/test_mappers.py +++ b/dysql/test/test_mappers.py @@ -12,7 +12,8 @@ SingleRowMapper, SingleColumnMapper, SingleRowAndColumnMapper, - CountMapper, KeyValueMapper + CountMapper, + KeyValueMapper, ) from dysql.mappers import MapperError @@ -29,112 +30,185 @@ def _unwrap_results(results): def test_record_combining(self): mapper = RecordCombiningMapper() assert len(mapper.map_records([])) == 0 - assert self._unwrap_results(mapper.map_records([{'a': 1, 'b': 2}])) == [{'a': 1, 'b': 2}] - assert self._unwrap_results(mapper.map_records([ - {'id': 1, 'a': 1, 'b': 2}, - {'id': 2, 'a': 1, 'b': 2}, - {'id': 1, 'c': 3}, - ])) == [ - {'id': 1, 'a': 1, 'b': 2, 'c': 3}, - {'id': 2, 'a': 1, 'b': 2}, - ] + assert self._unwrap_results(mapper.map_records([{"a": 1, "b": 2}])) == [ + {"a": 1, "b": 2} + ] + assert self._unwrap_results( + mapper.map_records( + [ + {"id": 1, "a": 1, "b": 2}, + {"id": 2, "a": 1, "b": 2}, + {"id": 1, "c": 3}, + ] + ) + ) == [ + {"id": 1, "a": 1, "b": 2, "c": 3}, + {"id": 2, "a": 1, "b": 2}, + ] @staticmethod def test_record_combining_no_record_mapper(): mapper = RecordCombiningMapper(record_mapper=None) assert len(mapper.map_records([])) == 0 - assert mapper.map_records([{'a': 1, 'b': 2}]) == [{'a': 1, 'b': 2}] - assert mapper.map_records([ - {'id': 1, 'a': 1, 'b': 2}, - {'id': 2, 'a': 1, 'b': 2}, - {'id': 1, 'c': 3}, - ]) == [ - {'id': 1, 'a': 1, 'b': 2}, - {'id': 2, 'a': 1, 'b': 2}, - {'id': 1, 'c': 3}, - ] + assert mapper.map_records([{"a": 1, "b": 2}]) == [{"a": 1, "b": 2}] + assert mapper.map_records( + [ + {"id": 1, "a": 1, "b": 2}, + {"id": 2, "a": 1, "b": 2}, + {"id": 1, "c": 3}, + ] + ) == [ + {"id": 1, "a": 1, "b": 2}, + {"id": 2, "a": 1, "b": 2}, + {"id": 1, "c": 3}, + ] @staticmethod def test_single_row(): mapper = SingleRowMapper() assert mapper.map_records([]) is None - assert mapper.map_records([{'a': 1, 'b': 2}]).raw() == {'a': 1, 'b': 2} - assert mapper.map_records([ - {'id': 1, 'a': 1, 'b': 2}, - {'id': 2, 'a': 1, 'b': 2}, - {'id': 1, 'c': 3}, - ]).raw() == {'id': 1, 'a': 1, 'b': 2} + assert mapper.map_records([{"a": 1, "b": 2}]).raw() == {"a": 1, "b": 2} + assert mapper.map_records( + [ + {"id": 1, "a": 1, "b": 2}, + {"id": 2, "a": 1, "b": 2}, + {"id": 1, "c": 3}, + ] + ).raw() == {"id": 1, "a": 1, "b": 2} @staticmethod def test_single_row_no_record_mapper(): mapper = SingleRowMapper(record_mapper=None) assert mapper.map_records([]) is None - assert mapper.map_records([{'a': 1, 'b': 2}]) == {'a': 1, 'b': 2} - assert mapper.map_records([ - {'id': 1, 'a': 1, 'b': 2}, - {'id': 2, 'a': 1, 'b': 2}, - {'id': 1, 'c': 3}, - ]) == {'id': 1, 'a': 1, 'b': 2} + assert mapper.map_records([{"a": 1, "b": 2}]) == {"a": 1, "b": 2} + assert mapper.map_records( + [ + {"id": 1, "a": 1, "b": 2}, + {"id": 2, "a": 1, "b": 2}, + {"id": 1, "c": 3}, + ] + ) == {"id": 1, "a": 1, "b": 2} @staticmethod def test_single_column(): mapper = SingleColumnMapper() assert len(mapper.map_records([])) == 0 - assert mapper.map_records([{'a': 
1, 'b': 2}]) == [1] - assert mapper.map_records([ - {'id': 'myid1', 'a': 1, 'b': 2}, - {'id': 'myid2', 'a': 1, 'b': 2}, - {'id': 'myid3', 'c': 3}, - ]) == ['myid1', 'myid2', 'myid3'] + assert mapper.map_records([{"a": 1, "b": 2}]) == [1] + assert mapper.map_records( + [ + {"id": "myid1", "a": 1, "b": 2}, + {"id": "myid2", "a": 1, "b": 2}, + {"id": "myid3", "c": 3}, + ] + ) == ["myid1", "myid2", "myid3"] @staticmethod - @pytest.mark.parametrize('mapper', [ - SingleRowAndColumnMapper(), - # Alias for the other one - CountMapper(), - ]) + @pytest.mark.parametrize( + "mapper", + [ + SingleRowAndColumnMapper(), + # Alias for the other one + CountMapper(), + ], + ) def test_single_column_and_row(mapper): assert mapper.map_records([]) is None - assert mapper.map_records([{'a': 1, 'b': 2}]) == 1 - assert mapper.map_records([ - {'id': 'myid', 'a': 1, 'b': 2}, - {'id': 2, 'a': 1, 'b': 2}, - {'id': 1, 'c': 3}, - ]) == 'myid' + assert mapper.map_records([{"a": 1, "b": 2}]) == 1 + assert ( + mapper.map_records( + [ + {"id": "myid", "a": 1, "b": 2}, + {"id": 2, "a": 1, "b": 2}, + {"id": 1, "c": 3}, + ] + ) + == "myid" + ) @staticmethod - @pytest.mark.parametrize('mapper, expected', [ - (KeyValueMapper(), {'a': 4, 'b': 7}), - (KeyValueMapper(key_column='column_named_something'), {'a': 4, 'b': 7}), - (KeyValueMapper(value_column='column_with_some_value'), {'a': 4, 'b': 7}), - (KeyValueMapper(has_multiple_values_per_key=True), {'a': [1, 2, 3, 4], 'b': [3, 4, 5, 6, 7]}), - (KeyValueMapper(key_column='column_named_something', has_multiple_values_per_key=True), - {'a': [1, 2, 3, 4], 'b': [3, 4, 5, 6, 7]}), - (KeyValueMapper(key_column='column_with_some_value', value_column='column_named_something', - has_multiple_values_per_key=True), - {1: ['a'], 2: ['a'], 3: ['a', 'b'], 4: ['a', 'b'], 5: ['b'], 6: ['b'], 7: ['b']}), - (KeyValueMapper(key_column='column_with_some_value', value_column='column_named_something'), - {1: 'a', 2: 'a', 3: 'b', 4: 'b', 5: 'b', 6: 'b', 7: 'b'}), - ]) + @pytest.mark.parametrize( + "mapper, expected", + [ + (KeyValueMapper(), {"a": 4, "b": 7}), + (KeyValueMapper(key_column="column_named_something"), {"a": 4, "b": 7}), + (KeyValueMapper(value_column="column_with_some_value"), {"a": 4, "b": 7}), + ( + KeyValueMapper(has_multiple_values_per_key=True), + {"a": [1, 2, 3, 4], "b": [3, 4, 5, 6, 7]}, + ), + ( + KeyValueMapper( + key_column="column_named_something", + has_multiple_values_per_key=True, + ), + {"a": [1, 2, 3, 4], "b": [3, 4, 5, 6, 7]}, + ), + ( + KeyValueMapper( + key_column="column_with_some_value", + value_column="column_named_something", + has_multiple_values_per_key=True, + ), + { + 1: ["a"], + 2: ["a"], + 3: ["a", "b"], + 4: ["a", "b"], + 5: ["b"], + 6: ["b"], + 7: ["b"], + }, + ), + ( + KeyValueMapper( + key_column="column_with_some_value", + value_column="column_named_something", + ), + {1: "a", 2: "a", 3: "b", 4: "b", 5: "b", 6: "b", 7: "b"}, + ), + ], + ) def test_key_mapper_key_has_multiple(mapper, expected): - result = mapper.map_records([ - HelperRow(('column_named_something', 'column_with_some_value'), ['a', 1]), - HelperRow(('column_named_something', 'column_with_some_value'), ['a', 2]), - HelperRow(('column_named_something', 'column_with_some_value'), ['a', 3]), - HelperRow(('column_named_something', 'column_with_some_value'), ['a', 4]), - HelperRow(('column_named_something', 'column_with_some_value'), ['b', 3]), - HelperRow(('column_named_something', 'column_with_some_value'), ['b', 4]), - HelperRow(('column_named_something', 'column_with_some_value'), ['b', 5]), 
- HelperRow(('column_named_something', 'column_with_some_value'), ['b', 6]), - HelperRow(('column_named_something', 'column_with_some_value'), ['b', 7]), - ]) + result = mapper.map_records( + [ + HelperRow( + ("column_named_something", "column_with_some_value"), ["a", 1] + ), + HelperRow( + ("column_named_something", "column_with_some_value"), ["a", 2] + ), + HelperRow( + ("column_named_something", "column_with_some_value"), ["a", 3] + ), + HelperRow( + ("column_named_something", "column_with_some_value"), ["a", 4] + ), + HelperRow( + ("column_named_something", "column_with_some_value"), ["b", 3] + ), + HelperRow( + ("column_named_something", "column_with_some_value"), ["b", 4] + ), + HelperRow( + ("column_named_something", "column_with_some_value"), ["b", 5] + ), + HelperRow( + ("column_named_something", "column_with_some_value"), ["b", 6] + ), + HelperRow( + ("column_named_something", "column_with_some_value"), ["b", 7] + ), + ] + ) assert len(result) == len(expected) assert result == expected @staticmethod def test_key_mapper_key_value_same(): - with pytest.raises(MapperError, match='key and value columns cannot be the same'): - KeyValueMapper(key_column='same', value_column='same') + with pytest.raises( + MapperError, match="key and value columns cannot be the same" + ): + KeyValueMapper(key_column="same", value_column="same") class HelperRow: # pylint: disable=too-few-public-methods diff --git a/dysql/test/test_mariadbmap_tojson.py b/dysql/test/test_mariadbmap_tojson.py index 61bda69..0e5da2e 100644 --- a/dysql/test/test_mariadbmap_tojson.py +++ b/dysql/test/test_mariadbmap_tojson.py @@ -11,29 +11,36 @@ from dysql import DbMapResult -@pytest.mark.parametrize("name, expected_json, mariadb_map", [ - ( - 'basic', - '{"id": 1, "name": "test"}', - DbMapResult(id=1, name='test'), - ), - ( - 'basic_with_list', - '{"id": 1, "name": "test", "my_list": ["a", "b", "c"]}', - DbMapResult(id=1, name='test', my_list=['a', 'b', 'c']), - ), - ( - 'inner_map', - '{"id": 1, "inner_map": {"id": 2, "name": "inner_test"}}', - DbMapResult(id=1, inner_map=DbMapResult(id=2, name='inner_test')), - ), - ( - 'inner_list_of_maps', - '{"id": 1, "inner_map_list": [{"id": 2, "name": "inner_test_2"}, {"id": 3, "name": "inner_test_3"}]}', - DbMapResult(id=1, inner_map_list=[ - DbMapResult(id=2, name='inner_test_2'), DbMapResult(id=3, name='inner_test_3') - ]) - ), -]) +@pytest.mark.parametrize( + "name, expected_json, mariadb_map", + [ + ( + "basic", + '{"id": 1, "name": "test"}', + DbMapResult(id=1, name="test"), + ), + ( + "basic_with_list", + '{"id": 1, "name": "test", "my_list": ["a", "b", "c"]}', + DbMapResult(id=1, name="test", my_list=["a", "b", "c"]), + ), + ( + "inner_map", + '{"id": 1, "inner_map": {"id": 2, "name": "inner_test"}}', + DbMapResult(id=1, inner_map=DbMapResult(id=2, name="inner_test")), + ), + ( + "inner_list_of_maps", + '{"id": 1, "inner_map_list": [{"id": 2, "name": "inner_test_2"}, {"id": 3, "name": "inner_test_3"}]}', + DbMapResult( + id=1, + inner_map_list=[ + DbMapResult(id=2, name="inner_test_2"), + DbMapResult(id=3, name="inner_test_3"), + ], + ), + ), + ], +) def test_raw_json_format(name, expected_json, mariadb_map): - assert json.dumps(mariadb_map.raw()) == expected_json, 'error with ' + name + assert json.dumps(mariadb_map.raw()) == expected_json, "error with " + name diff --git a/dysql/test/test_pydantic_mappers.py b/dysql/test/test_pydantic_mappers.py index 4893361..04a1b23 100644 --- a/dysql/test/test_pydantic_mappers.py +++ b/dysql/test/test_pydantic_mappers.py @@ -26,10 +26,10 
@@ class ConversionDbModel(DbMapResultModel): class CombiningDbModel(DbMapResultModel): - _list_fields: Set[str] = {'list1'} - _set_fields: Set[str] = {'set1'} - _dict_key_fields: Dict[str, str] = {'key1': 'dict1', 'key2': 'dict2'} - _dict_value_mappings: Dict[str, str] = {'dict1': 'val1', 'dict2': 'val2'} + _list_fields: Set[str] = {"list1"} + _set_fields: Set[str] = {"set1"} + _dict_key_fields: Dict[str, str] = {"key1": "dict1", "key2": "dict2"} + _dict_value_mappings: Dict[str, str] = {"dict1": "val1", "dict2": "val2"} id: int = None list1: List[str] @@ -43,7 +43,7 @@ class DefaultListCombiningDbModel(CombiningDbModel): class ListWithStringsModel(DbMapResultModel): - _csv_list_fields: Set[str] = {'list1', 'list2'} + _csv_list_fields: Set[str] = {"list1", "list2"} id: int list1: Optional[List[str]] = None @@ -51,7 +51,7 @@ class ListWithStringsModel(DbMapResultModel): class JsonModel(DbMapResultModel): - _json_fields: Set[str] = {'json1', 'json2'} + _json_fields: Set[str] = {"json1", "json2"} id: int json1: dict @@ -59,12 +59,11 @@ class JsonModel(DbMapResultModel): class MultiKeyModel(DbMapResultModel): - @classmethod def get_key_columns(cls): - return ['a', 'b'] + return ["a", "b"] - _list_fields = {'c'} + _list_fields = {"c"} a: int b: str c: List[str] @@ -76,242 +75,216 @@ def _unwrap_results(results): def test_field_conversion(): mapper = SingleRowMapper(record_mapper=ConversionDbModel) - assert mapper.map_records([ - {'id': 1, 'field_str': 'str1', 'field_int': 1, 'field_bool': 1}, - ]).raw() == {'id': 1, 'field_str': 'str1', 'field_int': 1, 'field_bool': True} + assert mapper.map_records( + [ + {"id": 1, "field_str": "str1", "field_int": 1, "field_bool": 1}, + ] + ).raw() == {"id": 1, "field_str": "str1", "field_int": 1, "field_bool": True} def test_complex_object_record_combining(): mapper = RecordCombiningMapper(record_mapper=CombiningDbModel) assert len(mapper.map_records([])) == 0 - assert _unwrap_results(mapper.map_records([ - {'id': 1, 'list1': 'val1', 'set1': 'val2', 'key1': 'k1', 'val1': 'v1', 'key2': 'k3', 'val2': 3}, - {'id': 2, 'list1': 'val1'}, - {'id': 1, 'list1': 'val3', 'set1': 'val4', 'key1': 'k2', 'val1': 'v2', 'key2': 'k4', 'val2': 4}, - ])) == [ - { - 'id': 1, - 'list1': ['val1', 'val3'], - 'set1': {'val2', 'val4'}, - 'dict1': {'k1': 'v1', 'k2': 'v2'}, - 'dict2': {'k3': 3, 'k4': 4}, - }, - { - 'id': 2, - 'list1': ['val1'], - 'set1': set(), - 'dict1': {}, - 'dict2': {}, - }, - ] + assert _unwrap_results( + mapper.map_records( + [ + { + "id": 1, + "list1": "val1", + "set1": "val2", + "key1": "k1", + "val1": "v1", + "key2": "k3", + "val2": 3, + }, + {"id": 2, "list1": "val1"}, + { + "id": 1, + "list1": "val3", + "set1": "val4", + "key1": "k2", + "val1": "v2", + "key2": "k4", + "val2": 4, + }, + ] + ) + ) == [ + { + "id": 1, + "list1": ["val1", "val3"], + "set1": {"val2", "val4"}, + "dict1": {"k1": "v1", "k2": "v2"}, + "dict2": {"k3": 3, "k4": 4}, + }, + { + "id": 2, + "list1": ["val1"], + "set1": set(), + "dict1": {}, + "dict2": {}, + }, + ] def test_record_combining_multi_column_id(): mapper = RecordCombiningMapper(MultiKeyModel) assert len(mapper.map_records([])) == 0 - assert _unwrap_results(mapper.map_records([ - {'a': 1, 'b': 'test', 'c': 'one'}, - {'a': 1, 'b': 'test', 'c': 'two'}, - {'a': 1, 'b': 'test', 'c': 'three'}, - {'a': 2, 'b': 'test', 'c': 'one'}, - {'a': 2, 'b': 'test', 'c': 'two'}, - {'a': 1, 'b': 'othertest', 'c': 'one'}, - ])) == [ - {'a': 1, 'b': 'test', 'c': ['one', 'two', 'three']}, - {'a': 2, 'b': 'test', 'c': ['one', 'two']}, - {'a': 1, 'b': 
'othertest', 'c': ['one']}, + assert _unwrap_results( + mapper.map_records( + [ + {"a": 1, "b": "test", "c": "one"}, + {"a": 1, "b": "test", "c": "two"}, + {"a": 1, "b": "test", "c": "three"}, + {"a": 2, "b": "test", "c": "one"}, + {"a": 2, "b": "test", "c": "two"}, + {"a": 1, "b": "othertest", "c": "one"}, + ] + ) + ) == [ + {"a": 1, "b": "test", "c": ["one", "two", "three"]}, + {"a": 2, "b": "test", "c": ["one", "two"]}, + {"a": 1, "b": "othertest", "c": ["one"]}, ] def test_complex_object_with_null_values(): mapper = SingleRowMapper(record_mapper=DefaultListCombiningDbModel) - assert mapper.map_records([ - {'id': 1}, - ]).raw() == { - 'id': 1, - 'list1': [], - 'set1': set(), - 'dict1': {}, - 'dict2': {}, - } + assert mapper.map_records( + [ + {"id": 1}, + ] + ).raw() == { + "id": 1, + "list1": [], + "set1": set(), + "dict1": {}, + "dict2": {}, + } def test_csv_list_field(): mapper = SingleRowMapper(record_mapper=ListWithStringsModel) - assert mapper.map_records([{ - 'id': 1, - 'list1': 'a,b,c,d', - 'list2': '1,2,3,4' - }]).raw() == { - 'id': 1, - 'list1': ['a', 'b', 'c', 'd'], - 'list2': [1, 2, 3, 4] - } + assert mapper.map_records( + [{"id": 1, "list1": "a,b,c,d", "list2": "1,2,3,4"}] + ).raw() == {"id": 1, "list1": ["a", "b", "c", "d"], "list2": [1, 2, 3, 4]} def test_csv_list_field_single_value(): mapper = SingleRowMapper(record_mapper=ListWithStringsModel) - assert mapper.map_records([{ - 'id': 1, - 'list1': 'a', - 'list2': '1' - }]).raw() == { - 'id': 1, - 'list1': ['a'], - 'list2': [1] - } + assert mapper.map_records([{"id": 1, "list1": "a", "list2": "1"}]).raw() == { + "id": 1, + "list1": ["a"], + "list2": [1], + } def test_csv_list_field_empty_string(): mapper = SingleRowMapper(record_mapper=ListWithStringsModel) - assert mapper.map_records([{ - 'id': 1, - 'list1': '', - 'list2': '' - }]).raw() == { - 'id': 1, - 'list1': None, - 'list2': [] - } + assert mapper.map_records([{"id": 1, "list1": "", "list2": ""}]).raw() == { + "id": 1, + "list1": None, + "list2": [], + } def test_csv_list_field_extends(): mapper = RecordCombiningMapper(record_mapper=ListWithStringsModel) - assert mapper.map_records([{ - 'id': 1, - 'list1': 'a,b', - 'list2': '1,2' - }, { - 'id': 1, - 'list1': 'c,d', - 'list2': '3,4' - }])[0].raw() == { - 'id': 1, - 'list1': ['a', 'b', 'c', 'd'], - 'list2': [1, 2, 3, 4] - } + assert mapper.map_records( + [ + {"id": 1, "list1": "a,b", "list2": "1,2"}, + {"id": 1, "list1": "c,d", "list2": "3,4"}, + ] + )[0].raw() == {"id": 1, "list1": ["a", "b", "c", "d"], "list2": [1, 2, 3, 4]} def test_csv_list_field_multiple_records_duplicates(): mapper = RecordCombiningMapper(record_mapper=ListWithStringsModel) - assert mapper.map_records([{ - 'id': 1, - 'list1': 'a,b,c,d', - 'list2': '1,2,3,4' - }, { - 'id': 1, - 'list1': 'a,b,c,d', - 'list2': '1,2,3,4' - }])[0].raw() == { - 'id': 1, - 'list1': ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'd'], - 'list2': [1, 2, 3, 4, 1, 2, 3, 4] - } + assert mapper.map_records( + [ + {"id": 1, "list1": "a,b,c,d", "list2": "1,2,3,4"}, + {"id": 1, "list1": "a,b,c,d", "list2": "1,2,3,4"}, + ] + )[0].raw() == { + "id": 1, + "list1": ["a", "b", "c", "d", "a", "b", "c", "d"], + "list2": [1, 2, 3, 4, 1, 2, 3, 4], + } def test_csv_list_field_none_in_first_record(): mapper = RecordCombiningMapper(record_mapper=ListWithStringsModel) - assert mapper.map_records([{ - 'id': 1, - 'list1': 'a,b,c,d', - 'list2': None - }, { - 'id': 1, - 'list1': 'e,f,g', - 'list2': '1,2,3,4' - }])[0].raw() == { - 'id': 1, - 'list1': ['a', 'b', 'c', 'd', 'e', 'f', 'g'], - 'list2': [1, 2, 
3, 4] - } + assert mapper.map_records( + [ + {"id": 1, "list1": "a,b,c,d", "list2": None}, + {"id": 1, "list1": "e,f,g", "list2": "1,2,3,4"}, + ] + )[0].raw() == { + "id": 1, + "list1": ["a", "b", "c", "d", "e", "f", "g"], + "list2": [1, 2, 3, 4], + } def test_csv_list_field_without_mapping_ignored(): mapper = SingleRowMapper(record_mapper=ListWithStringsModel) - assert mapper.map_records([{ - 'id': 1, - 'list1': 'a,b, c,d', - 'list2': '1,2,3,4', - 'list3': 'x,y,z' - }]).raw() == { - 'id': 1, - 'list1': ['a', 'b', 'c', 'd'], - 'list2': [1, 2, 3, 4] - } + assert mapper.map_records( + [{"id": 1, "list1": "a,b, c,d", "list2": "1,2,3,4", "list3": "x,y,z"}] + ).raw() == {"id": 1, "list1": ["a", "b", "c", "d"], "list2": [1, 2, 3, 4]} def test_csv_list_field_invalid_type(): mapper = RecordCombiningMapper(record_mapper=ListWithStringsModel) with pytest.raises(ValidationError, match="1 validation error for list"): - mapper.map_records([{ - 'id': 1, - 'list1': 'a,b', - 'list2': '1,2' - }, { - 'id': 1, - 'list1': 'c,d', - 'list2': '3,a' - }]) + mapper.map_records( + [ + {"id": 1, "list1": "a,b", "list2": "1,2"}, + {"id": 1, "list1": "c,d", "list2": "3,a"}, + ] + ) def test_json_field(): mapper = SingleRowMapper(record_mapper=JsonModel) - assert mapper.map_records([{ - 'id': 1, - 'json1': json.dumps({ - 'a': 1, - 'b': 2, - 'c': { - 'x': 10, - 'y': 9, - 'z': { - 'deep': 'value' - } + assert mapper.map_records( + [ + { + "id": 1, + "json1": json.dumps( + {"a": 1, "b": 2, "c": {"x": 10, "y": 9, "z": {"deep": "value"}}} + ), } - }) - }]).model_dump() == { - 'id': 1, - 'json1': { - 'a': 1, - 'b': 2, - 'c': { - 'x': 10, - 'y': 9, - 'z': { - 'deep': 'value' - } - } - }, - 'json2': None - } - - -@pytest.mark.parametrize('json1, json2', [ - ('{ "json": value', None), - ('{ "json": value', '{ "json": value }'), - ('{ "json": value }', '{ "json": value'), - (None, '{ "json": value'), -]) + ] + ).model_dump() == { + "id": 1, + "json1": {"a": 1, "b": 2, "c": {"x": 10, "y": 9, "z": {"deep": "value"}}}, + "json2": None, + } + + +@pytest.mark.parametrize( + "json1, json2", + [ + ('{ "json": value', None), + ('{ "json": value', '{ "json": value }'), + ('{ "json": value }', '{ "json": value'), + (None, '{ "json": value'), + ], +) def test_invalid_json(json1, json2): - with pytest.raises(ValidationError, match='Invalid JSON'): + with pytest.raises(ValidationError, match="Invalid JSON"): mapper = SingleRowMapper(record_mapper=JsonModel) - mapper.map_records([{ - 'id': 1, - 'json1': json1, - 'json2': json2 - }]) + mapper.map_records([{"id": 1, "json1": json1, "json2": json2}]) def test_json_none(): mapper = SingleRowMapper(record_mapper=JsonModel) - assert mapper.map_records([{ - 'id': 1, - 'json1': '{ "first": "value" }', - 'json2': None - }]).model_dump() == { - 'id': 1, - 'json1': { - 'first': 'value', + assert mapper.map_records( + [{"id": 1, "json1": '{ "first": "value" }', "json2": None}] + ).model_dump() == { + "id": 1, + "json1": { + "first": "value", }, - 'json2': None + "json2": None, } diff --git a/dysql/test/test_sql_decorator.py b/dysql/test/test_sql_decorator.py index e98716a..b714152 100644 --- a/dysql/test/test_sql_decorator.py +++ b/dysql/test/test_sql_decorator.py @@ -7,8 +7,15 @@ """ import pytest -from dysql import sqlupdate, sqlquery, DbMapResult, CountMapper, SingleRowMapper, \ - QueryData, QueryDataError +from dysql import ( + sqlupdate, + sqlquery, + DbMapResult, + CountMapper, + SingleRowMapper, + QueryData, + QueryDataError, +) from dysql.test import mock_create_engine_fixture, setup_mock_engine 
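For orientation before the next hunks: every test in test_sql_decorator.py exercises the same contract, where a decorated function returns (or yields) a QueryData and the decorator executes it against the mocked engine. A minimal sketch of that pattern, using only the keyword arguments these tests assert on (the function name, table name, and parameter are hypothetical):

import pytest  # noqa: F401 - the real tests below drive this via fixtures

from dysql import QueryData, sqlquery

# Hypothetical query function: the decorator opens a READ_COMMITTED connection
# (the default these tests assert on), executes the returned QueryData, and
# maps the resulting rows.
@sqlquery(isolation_level="READ_COMMITTED")
def get_by_id(record_id):
    return QueryData(
        "SELECT * FROM my_table WHERE id=:id",
        query_params={"id": record_id},
    )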
@@ -23,22 +30,26 @@ class TestSqlSelectDecorator: @staticmethod @pytest.fixture def mock_results(): - return [{ - 'id': 1, - 'name': 'jack', - 'email': 'jack@adobe.com', - 'hobbies': ['golf', 'bikes', 'coding'] - }, { - 'id': 2, - 'name': 'flora', - 'email': 'flora@adobe.com', - 'hobbies': ['golf', 'coding'] - }, { - 'id': 3, - 'name': 'Terrence', - 'email': 'terrence@adobe.com', - 'hobbies': ['coding', 'watching tv', 'hot dog eating contest'] - }] + return [ + { + "id": 1, + "name": "jack", + "email": "jack@adobe.com", + "hobbies": ["golf", "bikes", "coding"], + }, + { + "id": 2, + "name": "flora", + "email": "flora@adobe.com", + "hobbies": ["golf", "coding"], + }, + { + "id": 3, + "name": "Terrence", + "email": "terrence@adobe.com", + "hobbies": ["coding", "watching tv", "hot dog eating contest"], + }, + ] @staticmethod @pytest.fixture @@ -70,8 +81,8 @@ def test_single(self, mock_results, mock_engine): def test_count(self, mock_engine): mock_engine.connect.return_value.execution_options.return_value.execute.return_value = [ - {'count': 2}, - {'count': 3}, + {"count": 2}, + {"count": 3}, ] assert self._select_count() == 2 @@ -88,10 +99,12 @@ def test_execute_params(self, mock_engine): call_args = mock_engine.connect.return_value.execution_options.return_value.execute.call_args assert call_args[0][0].text == "SELECT * FROM my_table WHERE id=:id" - assert call_args[0][1] == {'id': 3} + assert call_args[0][1] == {"id": 3} def test_list_results_map(self, mock_results, mock_engine): - mock_engine.connect.return_value.execution_options.return_value.execute.return_value = [mock_results[2]] + mock_engine.connect.return_value.execution_options.return_value.execute.return_value = [ + mock_results[2] + ] results = self._select_filtered(3) assert len(results) == 1 @@ -100,12 +113,12 @@ def test_list_results_map(self, mock_results, mock_engine): def test_isolation_default(self, mock_engine): mock_connect = mock_engine.connect.return_value.execution_options self._select_all() - mock_connect.assert_called_with(isolation_level='READ_COMMITTED') + mock_connect.assert_called_with(isolation_level="READ_COMMITTED") def test_isolation_default_read_uncommited(self, mock_engine): mock_connect = mock_engine.connect.return_value.execution_options self._select_uncommitted() - mock_connect.assert_called_with(isolation_level='READ_UNCOMMITTED') + mock_connect.assert_called_with(isolation_level="READ_UNCOMMITTED") mock_connect.return_value.execute.assert_called() @staticmethod @@ -130,10 +143,12 @@ def _select_count(): @staticmethod @sqlquery() def _select_filtered(_id): - return QueryData("SELECT * FROM my_table WHERE id=:id", query_params={'id': _id}) + return QueryData( + "SELECT * FROM my_table WHERE id=:id", query_params={"id": _id} + ) @staticmethod - @sqlquery(isolation_level='READ_UNCOMMITTED') + @sqlquery(isolation_level="READ_UNCOMMITTED") def _select_uncommitted(): return QueryData("SELECT * FROM uncommitted") @@ -154,18 +169,18 @@ def mock_connect(mock_engine): return mock_engine.connect.return_value.execution_options def test_isolation_default(self, mock_connect): - self._update_something({'id': 1, 'value': 'test'}) - mock_connect.assert_called_with(isolation_level='READ_COMMITTED') + self._update_something({"id": 1, "value": "test"}) + mock_connect.assert_called_with(isolation_level="READ_COMMITTED") def test_isolation_default_read_uncommited(self, mock_connect): - self._update_something_uncommited_isolation({'id': 1, 'value': 'test'}) - 
mock_connect.assert_called_with(isolation_level='READ_UNCOMMITTED') + self._update_something_uncommited_isolation({"id": 1, "value": "test"}) + mock_connect.assert_called_with(isolation_level="READ_UNCOMMITTED") mock_connect.return_value.begin.assert_called() mock_connect.return_value.begin.return_value.commit.assert_called() mock_connect.return_value.__exit__.assert_called() def test_transaction(self, mock_connect): - self._update_something({'id': 1, 'value': 'test'}) + self._update_something({"id": 1, "value": "test"}) mock_connect().begin.assert_called() mock_connect().begin.return_value.commit.assert_called() mock_connect().__exit__.assert_called() @@ -173,16 +188,18 @@ def test_transaction(self, mock_connect): def test_transaction_fails(self, mock_connect): mock_connect().execute.side_effect = Exception("error") with pytest.raises(Exception): - self._update_something({'id': 1, 'value': 'test'}) + self._update_something({"id": 1, "value": "test"}) mock_connect().begin.return_value.commit.assert_not_called() mock_connect().begin.return_value.rollback.assert_called() mock_connect().__exit__.assert_called() def test_execute_passes_values(self, mock_engine): - values = {'id': 1, 'value': 'test'} + values = {"id": 1, "value": "test"} self._update_something(values) - execute_call = mock_engine.connect.return_value.execution_options.return_value.execute + execute_call = ( + mock_engine.connect.return_value.execution_options.return_value.execute + ) execute_call.assert_called() execute_call_args = execute_call.call_args[0] @@ -192,7 +209,9 @@ def test_execute_passes_values(self, mock_engine): def test_execute_query_values_none(self, mock_engine): self._update_something(None) - execute_call = mock_engine.connect.return_value.execution_options.return_value.execute + execute_call = ( + mock_engine.connect.return_value.execution_options.return_value.execute + ) execute_call.assert_called() self._expect_args_list(execute_call.call_args_list[0], "INSERT something") @@ -203,9 +222,9 @@ def test_execute_query_values_not_given(self, mock_engine): def test_execute_multi_yield(self, mock_connect): expected_values = [ - {'id': 1, 'value': 'test 1'}, - {'id': 2, 'value': 'test 2'}, - {'id': 3, 'value': 'test 3'}, + {"id": 1, "value": "test 1"}, + {"id": 2, "value": "test 2"}, + {"id": 3, "value": "test 3"}, ] self._update_something_multi_yield(expected_values) @@ -217,55 +236,59 @@ def test_execute_multi_yield(self, mock_connect): def test_execute_fails_list_if_multi_false(self): expected_values = [ - {'id': 1, 'value': 'test 1'}, - {'id': 2, 'value': 'test 2'}, - {'id': 3, 'value': 'test 3'}, + {"id": 1, "value": "test 1"}, + {"id": 2, "value": "test 2"}, + {"id": 3, "value": "test 3"}, ] with pytest.raises(Exception): self._update_list_when_multi_false(expected_values) def test_execute_multi_yield_and_lists(self, mock_engine): expected_values = [ - {'id': 1, 'value': 'test 1'}, - {'id': 2, 'value': 'test 2'}, - {'id': 3, 'value': 'test 3'}, + {"id": 1, "value": "test 1"}, + {"id": 2, "value": "test 2"}, + {"id": 3, "value": "test 3"}, ] other_expected_values = [ - {'id': 5, 'value': 'test 5'}, - {'id': 6, 'value': 'test 6'}, - {'id': 7, 'value': 'test 7'}, + {"id": 5, "value": "test 5"}, + {"id": 6, "value": "test 6"}, + {"id": 7, "value": "test 7"}, ] self._update_yield_with_lists(expected_values, other_expected_values) - execute_call = mock_engine.connect.return_value.execution_options.return_value.execute + execute_call = ( + mock_engine.connect.return_value.execution_options.return_value.execute + ) 
assert execute_call.call_count == 2 self._expect_args_list( execute_call.call_args_list[0], "INSERT some VALUES ( :values__in_0_0, :values__in_0_1 ), ( :values__in_1_0, :values__in_1_1 ), " - "( :values__in_2_0, :values__in_2_1 ) " + "( :values__in_2_0, :values__in_2_1 ) ", ) self._expect_args_list( execute_call.call_args_list[1], "INSERT some more VALUES ( :values__other_0_0, :values__other_0_1 ), " - "( :values__other_1_0, :values__other_1_1 ), ( :values__other_2_0, :values__other_2_1 ) " + "( :values__other_1_0, :values__other_1_1 ), ( :values__other_2_0, :values__other_2_1 ) ", ) def test_execute_multi_yield_and_lists_some_no_params(self, mock_engine): expected_values = [ - {'id': 1, 'value': 'test 1'}, - {'id': 2, 'value': 'test 2'}, - {'id': 3, 'value': 'test 3'}, + {"id": 1, "value": "test 1"}, + {"id": 2, "value": "test 2"}, + {"id": 3, "value": "test 3"}, ] self._update_yield_with_lists_some_no_params(expected_values) - execute_call = mock_engine.connect.return_value.execution_options.return_value.execute + execute_call = ( + mock_engine.connect.return_value.execution_options.return_value.execute + ) assert execute_call.call_count == 4 self._expect_args_list( execute_call.call_args_list[0], "INSERT some VALUES ( :values__in_0_0, :values__in_0_1 ), ( :values__in_1_0, :values__in_1_1 ), " - "( :values__in_2_0, :values__in_2_1 ) " + "( :values__in_2_0, :values__in_2_1 ) ", ) self._expect_args_list(execute_call.call_args_list[1], "UPDATE some more") self._expect_args_list(execute_call.call_args_list[2], "UPDATE some more") @@ -274,13 +297,17 @@ def test_execute_multi_yield_and_lists_some_no_params(self, mock_engine): def test_set_foreign_key_checks_default(self, mock_engine): self._update_something_no_params() - execute_call = mock_engine.connect.return_value.execution_options.return_value.execute + execute_call = ( + mock_engine.connect.return_value.execution_options.return_value.execute + ) assert execute_call.call_count == 1 def test_set_foreign_key_checks_true(self, mock_engine): self._update_without_foreign_key_checks() - execute_call = mock_engine.connect.return_value.execution_options.return_value.execute + execute_call = ( + mock_engine.connect.return_value.execution_options.return_value.execute + ) assert execute_call.call_count == 4 assert execute_call.call_args_list[0][0][0].text == "SET FOREIGN_KEY_CHECKS=0" assert execute_call.call_args_list[-1][0][0].text == "SET FOREIGN_KEY_CHECKS=1" @@ -301,7 +328,7 @@ def _update_something(values): return QueryData("INSERT something", query_params=values) @staticmethod - @sqlupdate(isolation_level='READ_UNCOMMITTED') + @sqlupdate(isolation_level="READ_UNCOMMITTED") def _update_something_uncommited_isolation(values): return QueryData(f"INSERT WITH uncommitted {values}") @@ -337,11 +364,11 @@ def _update_yield_with_lists(multiple_values, other_values): """ yield QueryData( "INSERT some {values__in}", - template_params=_get_template_params('values__in', multiple_values) + template_params=_get_template_params("values__in", multiple_values), ) yield QueryData( "INSERT some more {values__other}", - template_params=_get_template_params('values__other', other_values) + template_params=_get_template_params("values__other", other_values), ) @staticmethod @@ -353,7 +380,7 @@ def _update_yield_with_lists_some_no_params(multiple_values): """ yield QueryData( "INSERT some {values__in}", - template_params=_get_template_params('values__in', multiple_values) + template_params=_get_template_params("values__in", multiple_values), ) yield 
QueryData("UPDATE some more") yield QueryData("UPDATE some more") @@ -361,14 +388,14 @@ def _update_yield_with_lists_some_no_params(multiple_values): def _get_template_params(key, values): - ''' + """ template parameters is handled here as an example of what you might end up doing. usually data coming in is going to be a list :param key: the template key we want to use to build our template with :param values: the list of objects or mariadbmaps, :return: a keyed list of tuples or mariadbmaps - ''' + """ if isinstance(values[0], DbMapResult): return {key: values} - return {key: [(v['id'], v['value']) for v in values]} + return {key: [(v["id"], v["value"]) for v in values]} diff --git a/dysql/test/test_sql_exists_decorator.py b/dysql/test/test_sql_exists_decorator.py index 3980bca..e965f82 100644 --- a/dysql/test/test_sql_exists_decorator.py +++ b/dysql/test/test_sql_exists_decorator.py @@ -15,29 +15,28 @@ _ = mock_create_engine_fixture -TRUE_QUERY = 'SELECT 1 from table' -TRUE_QUERY_PARAMS = 'SELECT 1 from table where key=:key' -FALSE_QUERY = 'SELECT 1 from false_table ' -FALSE_QUERY_PARAMS = 'SELECT 1 from table where key=:key' -TRUE_PARAMS = {'key': 123} -FALSE_PARAMS = {'key': 456} -SELECT_EXISTS_QUERY = 'SELECT 1 WHERE EXISTS ( {} )' -SELECT_EXISTS_NO_WHERE_QUERY = 'SELECT EXISTS ( {} )' +TRUE_QUERY = "SELECT 1 from table" +TRUE_QUERY_PARAMS = "SELECT 1 from table where key=:key" +FALSE_QUERY = "SELECT 1 from false_table " +FALSE_QUERY_PARAMS = "SELECT 1 from table where key=:key" +TRUE_PARAMS = {"key": 123} +FALSE_PARAMS = {"key": 456} +SELECT_EXISTS_QUERY = "SELECT 1 WHERE EXISTS ( {} )" +SELECT_EXISTS_NO_WHERE_QUERY = "SELECT EXISTS ( {} )" EXISTS_QUERIES = { - 'true': SELECT_EXISTS_QUERY.format(TRUE_QUERY), - 'false': SELECT_EXISTS_QUERY.format(FALSE_QUERY), - 'true_params': SELECT_EXISTS_QUERY.format(TRUE_QUERY_PARAMS), - 'false_params': SELECT_EXISTS_QUERY.format(FALSE_QUERY_PARAMS), - 'true_no_where': SELECT_EXISTS_NO_WHERE_QUERY.format(TRUE_QUERY), - 'false_no_where': SELECT_EXISTS_NO_WHERE_QUERY.format(FALSE_QUERY) + "true": SELECT_EXISTS_QUERY.format(TRUE_QUERY), + "false": SELECT_EXISTS_QUERY.format(FALSE_QUERY), + "true_params": SELECT_EXISTS_QUERY.format(TRUE_QUERY_PARAMS), + "false_params": SELECT_EXISTS_QUERY.format(FALSE_QUERY_PARAMS), + "true_no_where": SELECT_EXISTS_NO_WHERE_QUERY.format(TRUE_QUERY), + "false_no_where": SELECT_EXISTS_NO_WHERE_QUERY.format(FALSE_QUERY), } @pytest.fixture(autouse=True) def mock_engine_fixture(mock_create_engine): mock_engine = setup_mock_engine(mock_create_engine) - mock_engine.connect.return_value.execution_options.return_value.execute.side_effect = \ - _check_query_and_return_result + mock_engine.connect.return_value.execution_options.return_value.execute.side_effect = _check_query_and_return_result mock_engine.connect().execution_options().__enter__ = Mock() mock_engine.connect().execution_options().__exit__ = Mock() @@ -64,7 +63,7 @@ def test_exists_query_contains_with_exists_true(): exists(exists()) will always give a 'true' result """ # should match against the same query, should still match - assert _exists_specified('true') + assert _exists_specified("true") def test_exists_query_contains_with_exists_false(): @@ -73,7 +72,7 @@ def test_exists_query_contains_with_exists_false(): exists(exists()) will always give a 'true' result """ # should match against the same query, should still match - assert not _exists_specified('false') + assert not _exists_specified("false") def test_exists_without_where_true(): @@ -105,12 +104,12 @@ def 
test_exists_query_starts_with_exists_handles_whitespace(): @sqlexists() def _select_exists_no_where_false(): - return QueryData(EXISTS_QUERIES['false_no_where']) + return QueryData(EXISTS_QUERIES["false_no_where"]) @sqlexists() def _select_exists_no_where_true(): - return QueryData(EXISTS_QUERIES['true_no_where']) + return QueryData(EXISTS_QUERIES["true_no_where"]) @sqlexists() @@ -140,8 +139,11 @@ def _exists_false_params(): @sqlexists() def _exists_whitespace(): - return QueryData(""" - """ + TRUE_QUERY) + return QueryData( + """ + """ + + TRUE_QUERY + ) def _check_query_and_return_result(query, params): @@ -154,11 +156,11 @@ def _check_query_and_return_result(query, params): scalar_mock = Mock() # default mock responses to true, then we only handle setting false responses scalar_mock.scalar.return_value = 1 - if query.text == EXISTS_QUERIES['true_params']: - assert params.get('key') == 123 - if query.text == EXISTS_QUERIES['false_params']: - assert params.get('key') == 456 + if query.text == EXISTS_QUERIES["true_params"]: + assert params.get("key") == 123 + if query.text == EXISTS_QUERIES["false_params"]: + assert params.get("key") == 456 scalar_mock.scalar.return_value = 0 - elif query.text == EXISTS_QUERIES['false']: + elif query.text == EXISTS_QUERIES["false"]: scalar_mock.scalar.return_value = 0 return scalar_mock diff --git a/dysql/test/test_sql_in_list_templates.py b/dysql/test/test_sql_in_list_templates.py index e7a1fcd..17fb07e 100644 --- a/dysql/test/test_sql_in_list_templates.py +++ b/dysql/test/test_sql_in_list_templates.py @@ -10,8 +10,13 @@ import dysql from dysql import QueryData, sqlquery -from dysql.test import \ - _verify_query, _verify_query_args, _verify_query_params, mock_create_engine_fixture, setup_mock_engine +from dysql.test import ( + _verify_query, + _verify_query_args, + _verify_query_params, + mock_create_engine_fixture, + setup_mock_engine, +) _ = mock_create_engine_fixture @@ -27,151 +32,179 @@ def mock_engine_fixture(mock_create_engine): def test_list_in_numbers(mock_engine): _query( "SELECT * FROM table WHERE {in__column_a}", - template_params={'in__column_a': [1, 2, 3, 4]} + template_params={"in__column_a": [1, 2, 3, 4]}, ) _verify_query_params( mock_engine, "SELECT * FROM table WHERE column_a IN ( :in__column_a_0, :in__column_a_1, :in__column_a_2, :in__column_a_3 ) ", { - 'in__column_a_0': 1, - 'in__column_a_1': 2, - 'in__column_a_2': 3, - 'in__column_a_3': 4 - } + "in__column_a_0": 1, + "in__column_a_1": 2, + "in__column_a_2": 3, + "in__column_a_3": 4, + }, ) def test_list_in__strings(mock_engine): _query( "SELECT * FROM table WHERE {in__column_a}", - template_params={'in__column_a': ['a', 'b', 'c', 'd']} + template_params={"in__column_a": ["a", "b", "c", "d"]}, ) _verify_query_params( mock_engine, "SELECT * FROM table WHERE column_a IN ( :in__column_a_0, :in__column_a_1, :in__column_a_2, :in__column_a_3 ) ", { - 'in__column_a_0': 'a', - 'in__column_a_1': 'b', - 'in__column_a_2': 'c', - 'in__column_a_3': 'd' - }) + "in__column_a_0": "a", + "in__column_a_1": "b", + "in__column_a_2": "c", + "in__column_a_3": "d", + }, + ) def test_list_not_in_numbers(mock_engine): _query( "SELECT * FROM table WHERE {not_in__column_b}", - template_params={'not_in__column_b': [1, 2, 3, 4]} + template_params={"not_in__column_b": [1, 2, 3, 4]}, ) _verify_query_params( mock_engine, "SELECT * FROM table WHERE column_b NOT IN ( :not_in__column_b_0, :not_in__column_b_1, " ":not_in__column_b_2, :not_in__column_b_3 ) ", { - 'not_in__column_b_0': 1, - 'not_in__column_b_1': 2, - 
'not_in__column_b_2': 3, - 'not_in__column_b_3': 4 - }) + "not_in__column_b_0": 1, + "not_in__column_b_1": 2, + "not_in__column_b_2": 3, + "not_in__column_b_3": 4, + }, + ) def test_list_not_in_strings(mock_engine): _query( "SELECT * FROM table WHERE {not_in__column_b}", - template_params={'not_in__column_b': ['a', 'b', 'c', 'd']} + template_params={"not_in__column_b": ["a", "b", "c", "d"]}, ) _verify_query_params( mock_engine, "SELECT * FROM table WHERE column_b NOT IN ( :not_in__column_b_0, :not_in__column_b_1, " ":not_in__column_b_2, :not_in__column_b_3 ) ", { - 'not_in__column_b_0': 'a', - 'not_in__column_b_1': 'b', - 'not_in__column_b_2': 'c', - 'not_in__column_b_3': 'd' - }) + "not_in__column_b_0": "a", + "not_in__column_b_1": "b", + "not_in__column_b_2": "c", + "not_in__column_b_3": "d", + }, + ) def test_list_in_handles_empty(mock_engine): _query( - "SELECT * FROM table WHERE {in__column_a}", - template_params={'in__column_a': []} + "SELECT * FROM table WHERE {in__column_a}", template_params={"in__column_a": []} ) _verify_query(mock_engine, "SELECT * FROM table WHERE 1 <> 1 ") def test_list_in_handles_no_param(): - with pytest.raises(dysql.query_utils.ListTemplateException, match="['in__column_a']"): + with pytest.raises( + dysql.query_utils.ListTemplateException, match="['in__column_a']" + ): _query("SELECT * FROM table WHERE {in__column_a}") def test_list_in_multiple_lists(mock_engine): - _query("SELECT * FROM table WHERE {in__column_a} OR {in__column_b}", template_params={ - 'in__column_a': ['first', 'second'], - 'in__column_b': [1, 2]}) + _query( + "SELECT * FROM table WHERE {in__column_a} OR {in__column_b}", + template_params={"in__column_a": ["first", "second"], "in__column_b": [1, 2]}, + ) _verify_query( mock_engine, "SELECT * FROM table WHERE column_a IN ( :in__column_a_0, :in__column_a_1 ) " - "OR column_b IN ( :in__column_b_0, :in__column_b_1 ) " + "OR column_b IN ( :in__column_b_0, :in__column_b_1 ) ", ) def test_list_in_multiple_lists_one_empty(mock_engine): - _query("SELECT * FROM table WHERE {in__column_a} OR {in__column_b}", template_params={ - 'in__column_a': ['first', 'second'], - 'in__column_b': []}) + _query( + "SELECT * FROM table WHERE {in__column_a} OR {in__column_b}", + template_params={"in__column_a": ["first", "second"], "in__column_b": []}, + ) _verify_query( mock_engine, - "SELECT * FROM table WHERE column_a IN ( :in__column_a_0, :in__column_a_1 ) OR 1 <> 1 " + "SELECT * FROM table WHERE column_a IN ( :in__column_a_0, :in__column_a_1 ) OR 1 <> 1 ", ) def test_list_in_multiple_lists_one_missing(): - with pytest.raises(dysql.query_utils.ListTemplateException, match="['in__column_a']"): - _query("SELECT * FROM table WHERE {in__column_a} OR {in__column_b} ", template_params={'in__column_b': [1, 2]}) + with pytest.raises( + dysql.query_utils.ListTemplateException, match="['in__column_a']" + ): + _query( + "SELECT * FROM table WHERE {in__column_a} OR {in__column_b} ", + template_params={"in__column_b": [1, 2]}, + ) def test_list_in_multiple_lists_all_missing(): - with pytest.raises(dysql.query_utils.ListTemplateException, match="['in__column_a','in__column_b']"): + with pytest.raises( + dysql.query_utils.ListTemplateException, match="['in__column_a','in__column_b']" + ): _query("SELECT * FROM table WHERE {in__column_a} OR {in__column_b} ") def test_list_not_in_handles_empty(mock_engine): _query( "SELECT * FROM table WHERE {not_in__column_b}", - template_params={'not_in__column_b': []} + template_params={"not_in__column_b": []}, ) _verify_query(mock_engine, 
"SELECT * FROM table WHERE 1 = 1 ") def test_list_not_in_handles_no_param(): - with pytest.raises(dysql.query_utils.ListTemplateException, match="['not_in__column_b']"): + with pytest.raises( + dysql.query_utils.ListTemplateException, match="['not_in__column_b']" + ): _query("SELECT * FROM table WHERE {not_in__column_b} ") def test_list_gives_template_space_before(mock_engine): - _query("SELECT * FROM table WHERE{in__space}", template_params={'in__space': [9, 8]}) - _verify_query(mock_engine, "SELECT * FROM table WHERE space IN ( :in__space_0, :in__space_1 ) ") + _query( + "SELECT * FROM table WHERE{in__space}", template_params={"in__space": [9, 8]} + ) + _verify_query( + mock_engine, + "SELECT * FROM table WHERE space IN ( :in__space_0, :in__space_1 ) ", + ) def test_list_gives_template_space_after(mock_engine): - _query("SELECT * FROM table WHERE {in__space}AND other_condition = 1", template_params={'in__space': [9, 8]}) + _query( + "SELECT * FROM table WHERE {in__space}AND other_condition = 1", + template_params={"in__space": [9, 8]}, + ) _verify_query( mock_engine, - "SELECT * FROM table WHERE space IN ( :in__space_0, :in__space_1 ) AND other_condition = 1" + "SELECT * FROM table WHERE space IN ( :in__space_0, :in__space_1 ) AND other_condition = 1", ) def test_list_gives_template_space_before_and_after(mock_engine): - _query("SELECT * FROM table WHERE{in__space}AND other_condition = 1", template_params={'in__space': [9, 8]}) + _query( + "SELECT * FROM table WHERE{in__space}AND other_condition = 1", + template_params={"in__space": [9, 8]}, + ) _verify_query( mock_engine, - "SELECT * FROM table WHERE space IN ( :in__space_0, :in__space_1 ) AND other_condition = 1" + "SELECT * FROM table WHERE space IN ( :in__space_0, :in__space_1 ) AND other_condition = 1", ) def test_in_contains_whitespace(mock_engine): - _query("{in__column_one}", template_params={'in__column_one': [1, 2]}) - _verify_query(mock_engine, " column_one IN ( :in__column_one_0, :in__column_one_1 ) ") + _query("{in__column_one}", template_params={"in__column_one": [1, 2]}) + _verify_query( + mock_engine, " column_one IN ( :in__column_one_0, :in__column_one_1 ) " + ) def test_template_handles_table_qualifier(mock_engine): @@ -184,64 +217,68 @@ def test_template_handles_table_qualifier(mock_engine): """ _query( "SELECT * FROM table WHERE {in__table.column}", - template_params={'in__table.column': [1, 2]} + template_params={"in__table.column": [1, 2]}, ) _verify_query( mock_engine, - "SELECT * FROM table WHERE table.column IN ( :in__table_column_0, :in__table_column_1 ) " - ) - _verify_query_args( - mock_engine, - { - 'in__table_column_0': 1, - 'in__table_column_1': 2 - } + "SELECT * FROM table WHERE table.column IN ( :in__table_column_0, :in__table_column_1 ) ", ) + _verify_query_args(mock_engine, {"in__table_column_0": 1, "in__table_column_1": 2}) def test_template_handles_multiple_table_qualifier(mock_engine): _query( "SELECT * FROM table WHERE {in__table.column} AND {not_in__other_column}", - template_params={'in__table.column': [1, 2], 'not_in__other_column': ['a', 'b']} + template_params={ + "in__table.column": [1, 2], + "not_in__other_column": ["a", "b"], + }, ) _verify_query( mock_engine, "SELECT * FROM table WHERE table.column IN ( :in__table_column_0, :in__table_column_1 ) " - "AND other_column NOT IN ( :not_in__other_column_0, :not_in__other_column_1 ) " + "AND other_column NOT IN ( :not_in__other_column_0, :not_in__other_column_1 ) ", ) _verify_query_args( mock_engine, { - 'in__table_column_0': 1, - 
'in__table_column_1': 2, - 'not_in__other_column_0': 'a', - 'not_in__other_column_1': 'b', - } + "in__table_column_0": 1, + "in__table_column_1": 2, + "not_in__other_column_0": "a", + "not_in__other_column_1": "b", + }, ) def test_empty_in_contains_whitespace(mock_engine): - _query("{in__column_one}", template_params={'in__column_one': []}) + _query("{in__column_one}", template_params={"in__column_one": []}) _verify_query(mock_engine, " 1 <> 1 ") def test_multiple_templates_same_column_diff_table(mock_engine): - template_params = {'in__table.status': ['on', 'off', 'waiting'], - 'in__other_table.status': ['success', 'partial_success']} + template_params = { + "in__table.status": ["on", "off", "waiting"], + "in__other_table.status": ["success", "partial_success"], + } expected_params_from_template = { - 'in__table_status_0': 'on', - 'in__table_status_1': 'off', - 'in__table_status_2': 'waiting', - 'in__other_table_status_0': 'success', - 'in__other_table_status_1': 'partial_success' + "in__table_status_0": "on", + "in__table_status_1": "off", + "in__table_status_2": "waiting", + "in__other_table_status_0": "success", + "in__other_table_status_1": "partial_success", } # writing each of these queries out to help see what we expect compared to # the query we actually sent - _query("SELECT * FROM table WHERE {in__table.status} AND {in__other_table.status}", template_params=template_params) - expected_query = "SELECT * FROM table WHERE table.status IN ( :in__table_status_0, :in__table_status_1, " \ - ":in__table_status_2 ) AND other_table.status IN ( :in__other_table_status_0, " \ - ":in__other_table_status_1 ) " + _query( + "SELECT * FROM table WHERE {in__table.status} AND {in__other_table.status}", + template_params=template_params, + ) + expected_query = ( + "SELECT * FROM table WHERE table.status IN ( :in__table_status_0, :in__table_status_1, " + ":in__table_status_2 ) AND other_table.status IN ( :in__other_table_status_0, " + ":in__other_table_status_1 ) " + ) connection = mock_engine.connect.return_value.execution_options.return_value execute_call = connection.execute diff --git a/dysql/test/test_sql_insert_templates.py b/dysql/test/test_sql_insert_templates.py index d74d608..c38986c 100644 --- a/dysql/test/test_sql_insert_templates.py +++ b/dysql/test/test_sql_insert_templates.py @@ -9,7 +9,12 @@ import dysql from dysql import QueryData, sqlupdate, QueryDataError -from dysql.test import _verify_query, _verify_query_args, mock_create_engine_fixture, setup_mock_engine +from dysql.test import ( + _verify_query, + _verify_query_args, + mock_create_engine_fixture, + setup_mock_engine, +) _ = mock_create_engine_fixture @@ -32,57 +37,66 @@ def select_with_string(): def test_insert_single_column(mock_engine): - insert_into_single_value(['Tom', 'Jerry']) + insert_into_single_value(["Tom", "Jerry"]) _verify_query( mock_engine, - "INSERT INTO table(name) VALUES ( :values__name_col_0 ), ( :values__name_col_1 ) " + "INSERT INTO table(name) VALUES ( :values__name_col_0 ), ( :values__name_col_1 ) ", ) def test_insert_single_column_single_value(mock_engine): - insert_into_single_value('Tom') - _verify_query(mock_engine, "INSERT INTO table(name) VALUES ( :values__name_col_0 ) ") + insert_into_single_value("Tom") + _verify_query( + mock_engine, "INSERT INTO table(name) VALUES ( :values__name_col_0 ) " + ) def test_insert_single_value_empty(): - with pytest.raises(dysql.query_utils.ListTemplateException, match="['values_name_col']"): + with pytest.raises( + dysql.query_utils.ListTemplateException, 
match="['values_name_col']" + ): insert_into_single_value([]) def test_insert_single_value_no_key(): - with pytest.raises(dysql.query_utils.ListTemplateException, match="['values_name_col']"): + with pytest.raises( + dysql.query_utils.ListTemplateException, match="['values_name_col']" + ): insert_into_single_value(None) def test_insert_multiple_values(mock_engine): insert_into_multiple_values( [ - {'name': 'Tom', 'email': 'tom@adobe.com'}, - {'name': 'Jerry', 'email': 'jerry@adobe.com'} + {"name": "Tom", "email": "tom@adobe.com"}, + {"name": "Jerry", "email": "jerry@adobe.com"}, ] ) _verify_query( mock_engine, "INSERT INTO table(name, email) VALUES ( :values__users_0_0, :values__users_0_1 ), " - "( :values__users_1_0, :values__users_1_1 ) " + "( :values__users_1_0, :values__users_1_1 ) ", ) _verify_query_args( mock_engine, { - 'values__users_0_0': 'Tom', - 'values__users_0_1': 'tom@adobe.com', - 'values__users_1_0': 'Jerry', - 'values__users_1_1': 'jerry@adobe.com' - } + "values__users_0_0": "Tom", + "values__users_0_1": "tom@adobe.com", + "values__users_1_0": "Jerry", + "values__users_1_1": "jerry@adobe.com", + }, ) -@pytest.mark.parametrize('args', [ - ([('bob', 'bob@email.com')]), - ([('bob', 'bob@email.com'), ('tom', 'tom@email.com')]), - None, - (), -]) +@pytest.mark.parametrize( + "args", + [ + ([("bob", "bob@email.com")]), + ([("bob", "bob@email.com"), ("tom", "tom@email.com")]), + None, + (), + ], +) def test_insert_with_callback(args): def callback(items): assert items == args @@ -102,24 +116,27 @@ def callback(): def insert(): yield QueryData("INSERT INTO table(name, email)") - mock_engine.connect().execution_options.return_value.execute.side_effect = Exception() + mock_engine.connect().execution_options.return_value.execute.side_effect = ( + Exception() + ) with pytest.raises(Exception): insert() @sqlupdate() def insert_into_multiple_values(users): - yield QueryData("INSERT INTO table(name, email) {values__users}", - template_params={'values__users': [(d['name'], d['email']) for d in users]}) + yield QueryData( + "INSERT INTO table(name, email) {values__users}", + template_params={"values__users": [(d["name"], d["email"]) for d in users]}, + ) @sqlupdate() def insert_into_single_value(names): template_params = {} if names is not None: - template_params = {'values__name_col': names} + template_params = {"values__name_col": names} return QueryData( - "INSERT INTO table(name) {values__name_col}", - template_params=template_params + "INSERT INTO table(name) {values__name_col}", template_params=template_params ) diff --git a/dysql/test/test_template_generators.py b/dysql/test/test_template_generators.py index 4f2afe9..4925e87 100644 --- a/dysql/test/test_template_generators.py +++ b/dysql/test/test_template_generators.py @@ -14,88 +14,169 @@ class TestTemplatesGenerators: """ Test we get templates back from Templates """ + number_values = [1, 2, 3, 4] - string_values = ['1', '2', '3', '4'] - insert_values = [('ironman', 1), ('batman', 2)] + string_values = ["1", "2", "3", "4"] + insert_values = [("ironman", 1), ("batman", 2)] tuple_values = [(1, 2), (3, 4)] - query = 'column_a IN ( :column_a_0, :column_a_1, :column_a_2, :column_a_3 )' + query = "column_a IN ( :column_a_0, :column_a_1, :column_a_2, :column_a_3 )" - query_with_list_of_tuples = \ - '(column_a, column_b) IN (( :column_acolumn_b_0_0, :column_acolumn_b_0_1 ), ' \ - '( :column_acolumn_b_1_0, :column_acolumn_b_1_1 ))' - not_query_with_list_of_tuples = \ - '(column_a, column_b) NOT IN (( :column_acolumn_b_0_0, 
:column_acolumn_b_0_1 ), ' \ - '( :column_acolumn_b_1_0, :column_acolumn_b_1_1 ))' - query_with_table = \ - 'table.column_a IN ( :table_column_a_0, :table_column_a_1, :table_column_a_2, :table_column_a_3 )' - not_query = 'column_a NOT IN ( :column_a_0, :column_a_1, :column_a_2, :column_a_3 )' - not_query_with_table = \ - 'table.column_a NOT IN ( :table_column_a_0, :table_column_a_1, :table_column_a_2, :table_column_a_3 )' + query_with_list_of_tuples = ( + "(column_a, column_b) IN (( :column_acolumn_b_0_0, :column_acolumn_b_0_1 ), " + "( :column_acolumn_b_1_0, :column_acolumn_b_1_1 ))" + ) + not_query_with_list_of_tuples = ( + "(column_a, column_b) NOT IN (( :column_acolumn_b_0_0, :column_acolumn_b_0_1 ), " + "( :column_acolumn_b_1_0, :column_acolumn_b_1_1 ))" + ) + query_with_table = "table.column_a IN ( :table_column_a_0, :table_column_a_1, :table_column_a_2, :table_column_a_3 )" + not_query = "column_a NOT IN ( :column_a_0, :column_a_1, :column_a_2, :column_a_3 )" + not_query_with_table = "table.column_a NOT IN ( :table_column_a_0, :table_column_a_1, :table_column_a_2, :table_column_a_3 )" - test_query_data = ('template_function, column_name,column_values,expected_query', [ - (TemplateGenerators.in_column, 'column_a', number_values, query), - (TemplateGenerators.in_column, 'table.column_a', number_values, query_with_table), - (TemplateGenerators.in_column, 'column_a', string_values, query), - (TemplateGenerators.in_column, 'table.column_a', string_values, query_with_table), - (TemplateGenerators.in_column, 'column_a', [], '1 <> 1'), - (TemplateGenerators.in_multi_column, '(column_a, column_b)', tuple_values, query_with_list_of_tuples), - (TemplateGenerators.not_in_multi_column, '(column_a, column_b)', tuple_values, not_query_with_list_of_tuples), - (TemplateGenerators.not_in_column, 'column_a', number_values, not_query), - (TemplateGenerators.not_in_column, 'table.column_a', number_values, not_query_with_table), - (TemplateGenerators.not_in_column, 'column_a', string_values, not_query), - (TemplateGenerators.not_in_column, 'table.column_a', string_values, not_query_with_table), - (TemplateGenerators.not_in_column, 'column_a', [], '1 = 1'), - ( - TemplateGenerators.values, - 'someid', - insert_values, - "VALUES ( :someid_0_0, :someid_0_1 ), ( :someid_1_0, :someid_1_1 )", - ) - ]) + test_query_data = ( + "template_function, column_name,column_values,expected_query", + [ + (TemplateGenerators.in_column, "column_a", number_values, query), + ( + TemplateGenerators.in_column, + "table.column_a", + number_values, + query_with_table, + ), + (TemplateGenerators.in_column, "column_a", string_values, query), + ( + TemplateGenerators.in_column, + "table.column_a", + string_values, + query_with_table, + ), + (TemplateGenerators.in_column, "column_a", [], "1 <> 1"), + ( + TemplateGenerators.in_multi_column, + "(column_a, column_b)", + tuple_values, + query_with_list_of_tuples, + ), + ( + TemplateGenerators.not_in_multi_column, + "(column_a, column_b)", + tuple_values, + not_query_with_list_of_tuples, + ), + (TemplateGenerators.not_in_column, "column_a", number_values, not_query), + ( + TemplateGenerators.not_in_column, + "table.column_a", + number_values, + not_query_with_table, + ), + (TemplateGenerators.not_in_column, "column_a", string_values, not_query), + ( + TemplateGenerators.not_in_column, + "table.column_a", + string_values, + not_query_with_table, + ), + (TemplateGenerators.not_in_column, "column_a", [], "1 = 1"), + ( + TemplateGenerators.values, + "someid", + insert_values, + "VALUES ( 
:someid_0_0, :someid_0_1 ), ( :someid_1_0, :someid_1_1 )", + ), + ], + ) parameter_numbers = { - 'column_a_0': number_values[0], - 'column_a_1': number_values[1], - 'column_a_2': number_values[2], - 'column_a_3': number_values[3] + "column_a_0": number_values[0], + "column_a_1": number_values[1], + "column_a_2": number_values[2], + "column_a_3": number_values[3], } with_table_parameter_numbers = { - 'table_column_a_0': number_values[0], - 'table_column_a_1': number_values[1], - 'table_column_a_2': number_values[2], - 'table_column_a_3': number_values[3] + "table_column_a_0": number_values[0], + "table_column_a_1": number_values[1], + "table_column_a_2": number_values[2], + "table_column_a_3": number_values[3], } parameter_strings = { - 'column_a_0': string_values[0], - 'column_a_1': string_values[1], - 'column_a_2': string_values[2], - 'column_a_3': string_values[3] + "column_a_0": string_values[0], + "column_a_1": string_values[1], + "column_a_2": string_values[2], + "column_a_3": string_values[3], } with_table_parameter_strings = { - 'table_column_a_0': string_values[0], - 'table_column_a_1': string_values[1], - 'table_column_a_2': string_values[2], - 'table_column_a_3': string_values[3] + "table_column_a_0": string_values[0], + "table_column_a_1": string_values[1], + "table_column_a_2": string_values[2], + "table_column_a_3": string_values[3], } - test_params_data = ('template_function, column_name,column_values,expected_params', [ - (TemplateGenerators.in_column, 'column_a', number_values, parameter_numbers), - (TemplateGenerators.in_column, 'table.column_a', number_values, with_table_parameter_numbers), - (TemplateGenerators.in_column, 'column_a', string_values, parameter_strings), - (TemplateGenerators.in_column, 'table.column_a', string_values, with_table_parameter_strings), - (TemplateGenerators.in_column, 'column_a', [], None), - (TemplateGenerators.not_in_column, 'column_a', number_values, parameter_numbers), - (TemplateGenerators.not_in_column, 'table.column_a', number_values, with_table_parameter_numbers), - (TemplateGenerators.not_in_column, 'column_a', string_values, parameter_strings), - (TemplateGenerators.not_in_column, 'table.column_a', string_values, with_table_parameter_strings), - (TemplateGenerators.not_in_column, 'column_a', [], None), - (TemplateGenerators.values, 'someid', insert_values, { - 'someid_0_0': insert_values[0][0], - 'someid_0_1': insert_values[0][1], - 'someid_1_0': insert_values[1][0], - 'someid_1_1': insert_values[1][1], - - }) - ]) + test_params_data = ( + "template_function, column_name,column_values,expected_params", + [ + ( + TemplateGenerators.in_column, + "column_a", + number_values, + parameter_numbers, + ), + ( + TemplateGenerators.in_column, + "table.column_a", + number_values, + with_table_parameter_numbers, + ), + ( + TemplateGenerators.in_column, + "column_a", + string_values, + parameter_strings, + ), + ( + TemplateGenerators.in_column, + "table.column_a", + string_values, + with_table_parameter_strings, + ), + (TemplateGenerators.in_column, "column_a", [], None), + ( + TemplateGenerators.not_in_column, + "column_a", + number_values, + parameter_numbers, + ), + ( + TemplateGenerators.not_in_column, + "table.column_a", + number_values, + with_table_parameter_numbers, + ), + ( + TemplateGenerators.not_in_column, + "column_a", + string_values, + parameter_strings, + ), + ( + TemplateGenerators.not_in_column, + "table.column_a", + string_values, + with_table_parameter_strings, + ), + (TemplateGenerators.not_in_column, "column_a", [], None), + 
( + TemplateGenerators.values, + "someid", + insert_values, + { + "someid_0_0": insert_values[0][0], + "someid_0_1": insert_values[0][1], + "someid_1_0": insert_values[1][0], + "someid_1_1": insert_values[1][1], + }, + ), + ], + ) @staticmethod @pytest.mark.parametrize(*test_query_data) @@ -112,4 +193,4 @@ def test_params(template_function, column_name, column_values, expected_params): @staticmethod def test_insert_none(): with pytest.raises(ListTemplateException): - TemplateGenerators.values('someid', None) + TemplateGenerators.values("someid", None) diff --git a/dysql/test_managers.py b/dysql/test_managers.py index 2f8f97c..eb71b10 100644 --- a/dysql/test_managers.py +++ b/dysql/test_managers.py @@ -13,7 +13,11 @@ from typing import Optional from .connections import sqlquery -from .databases import is_set_current_database_supported, set_current_database, set_default_connection_parameters +from .databases import ( + is_set_current_database_supported, + set_current_database, + set_default_connection_parameters, +) from .mappers import CountMapper from .query_utils import QueryData @@ -24,18 +28,19 @@ class DbTestManagerBase(abc.ABC): """ Base class for all test managers. See individual implementations for usage details. """ + # pylint: disable=too-many-instance-attributes def __init__( - self, - host: str, - username: str, - password: str, - db_name: str, - schema_db_name: Optional[str], - docker_container: Optional[str] = None, - keep_db: bool = False, - **connection_defaults, + self, + host: str, + username: str, + password: str, + db_name: str, + schema_db_name: Optional[str], + docker_container: Optional[str] = None, + keep_db: bool = False, + **connection_defaults, ): # pylint: disable=too-many-arguments """ Constructor, any unknown kwargs are passed directly to set_default_connection_parameters. 
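The hunks that follow reformat the database test managers. Functionally, DbTestManagerBase is a context manager: __enter__ creates the test database (and waits for the schema tables when schema_db_name is given), while __exit__ drops it again unless keep_db is set. A minimal sketch of the intended wiring as a pytest fixture, mirroring the usage example in the MariaDbTestManager docstring further down; the import path follows this file's location, and the fixture scope and database name are assumptions:

import pytest

from dysql.test_managers import MariaDbTestManager

@pytest.fixture(autouse=True, scope="class")
def setup_db():
    # __enter__ creates the scratch database; __exit__ tears it down
    # afterwards unless keep_db=True was passed.
    with MariaDbTestManager("testdb_example"):
        yield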
@@ -61,19 +66,19 @@ def __init__(
     @staticmethod
     def _is_running_in_docker():
-        if os.path.exists('/.dockerenv'):
+        if os.path.exists("/.dockerenv"):
             return True
-        if os.path.exists('/proc/1/cgroup'):
-            with open('/proc/1/cgroup', 'rt', encoding='utf8') as fobj:
+        if os.path.exists("/proc/1/cgroup"):
+            with open("/proc/1/cgroup", "rt", encoding="utf8") as fobj:
                 contents = fobj.read()
-                for marker in ('docker', 'kubepod', 'lxc'):
+                for marker in ("docker", "kubepod", "lxc"):
                     if marker in contents:
                         return True
         return False

     def __enter__(self):
-        LOGGER.debug(f'Setting up database : {self.db_name}')
+        LOGGER.debug(f"Setting up database : {self.db_name}")

         # Set the host based on whether we are in buildrunner or not (to test locally)
         self._create_test_db()
@@ -94,7 +99,7 @@ def __enter__(self):

     def __exit__(self, exc_type, exc_val, exc_tb):
         if not self.keep_db:
-            LOGGER.debug(f'Tearing down database : {self.db_name}')
+            LOGGER.debug(f"Tearing down database : {self.db_name}")
             self._tear_down_test_db()

     @abc.abstractmethod
@@ -125,8 +130,8 @@ def _wait_for_tables_exist(self) -> None:
         tables_exist = False
         expected_count = self._get_tables_count(self.schema_db_name)
         while not tables_exist:
-            sleep(.25)
-            LOGGER.debug('Tables are still not ready')
+            sleep(0.25)
+            LOGGER.debug("Tables are still not ready")
             actual_count = self._get_tables_count(self.db_name)
             tables_exist = expected_count == actual_count

@@ -141,12 +146,14 @@ def _run(self, command: str) -> subprocess.CompletedProcess:
         try:
             LOGGER.debug(f"Executing : '{command}'")
-            completed_process = subprocess.run(command, shell=True, timeout=30, check=True, capture_output=True)
+            completed_process = subprocess.run(
+                command, shell=True, timeout=30, check=True, capture_output=True
+            )
             LOGGER.debug(f"Executed : {completed_process.stdout}")

             return completed_process
         except subprocess.CalledProcessError:
-            LOGGER.exception(f'Error handling command : {command}')
+            LOGGER.exception(f"Error handling command : {command}")
             raise

@@ -163,16 +170,17 @@ def setup_db(self):
             with MariaDbTestManager(f'testdb_{self.__class__.__name__.lower()}'):
                 yield
     """
+    # pylint: disable=too-few-public-methods
    def __init__(
-            self,
-            db_name: str,
-            schema_db_name: Optional[str] = None,
-            echo_queries: bool = False,
-            keep_db: bool = False,
-            pool_size=3,
-            charset='utf8'
+        self,
+        db_name: str,
+        schema_db_name: Optional[str] = None,
+        echo_queries: bool = False,
+        keep_db: bool = False,
+        pool_size=3,
+        charset="utf8",
     ):  # pylint: disable=too-many-arguments
         """
         :param db_name: the name you want for your test database
@@ -183,37 +191,45 @@ def __init__(
         :param charset: This allows you to override the default charset if you need something besides utf8
         """
         super().__init__(
-            os.getenv('MARIA_HOST', 'localhost'),
-            os.getenv('MARIA_USERNAME', 'root'),
-            os.getenv('MARIA_PASSWORD', 'password'),
+            os.getenv("MARIA_HOST", "localhost"),
+            os.getenv("MARIA_USERNAME", "root"),
+            os.getenv("MARIA_PASSWORD", "password"),
             db_name,
             schema_db_name,
             port=3306,
             echo_queries=echo_queries,
             pool_size=pool_size,
-            docker_container=os.getenv('MARIA_CONTAINER_NAME', 'mariadb'),
+            docker_container=os.getenv("MARIA_CONTAINER_NAME", "mariadb"),
             keep_db=keep_db,
-            charset=charset
+            charset=charset,
         )

     def _create_test_db(self) -> None:
-        self._run(f'mysql -p{self.password} -h{self.host} -N -e "DROP DATABASE IF EXISTS {self.db_name}"')
-        self._run(f'mysql -p{self.password} -h{self.host} -s -N -e "CREATE DATABASE IF NOT EXISTS {self.db_name}"')
+        self._run(
+            f'mysql -p{self.password} -h{self.host} -N -e "DROP DATABASE IF EXISTS {self.db_name}"'
"DROP DATABASE IF EXISTS {self.db_name}"' + ) + self._run( + f'mysql -p{self.password} -h{self.host} -s -N -e "CREATE DATABASE IF NOT EXISTS {self.db_name}"' + ) if self.schema_db_name: self._run( - f'mysqldump --no-data -p{self.password} {self.schema_db_name} -h{self.host} ' - f'| mysql -p{self.password} {self.db_name} -h{self.host}' + f"mysqldump --no-data -p{self.password} {self.schema_db_name} -h{self.host} " + f"| mysql -p{self.password} {self.db_name} -h{self.host}" ) def _tear_down_test_db(self) -> None: - self._run(f'echo "DROP DATABASE IF EXISTS {self.db_name} " | mysql -p{self.password} -h{self.host}') + self._run( + f'echo "DROP DATABASE IF EXISTS {self.db_name} " | mysql -p{self.password} -h{self.host}' + ) @sqlquery(mapper=CountMapper()) def _get_tables_count(self, db_name: str) -> int: # pylint: disable=unused-argument return QueryData( - ''' + """ SELECT count(1) FROM information_schema.TABLES WHERE TABLE_SCHEMA=:db_name - ''', query_params={'db_name': db_name}) + """, + query_params={"db_name": db_name}, + ) diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..302d077 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,20 @@ +[lint] +# Enable the copyright rule +preview = true +extend-select = ["CPY"] + +[extend-per-file-ignores] +"setup.py" = ["CPY"] +"dysql/version.py" = ["CPY"] + +[lint.flake8-copyright] +min-file-size = 200 +notice-rgx = """\ +\"\"\" +Copyright [(2)\\d{3}]* Adobe +All Rights Reserved. + +NOTICE: Adobe permits you to use, modify, and distribute this file in accordance +with the terms of the Adobe license agreement accompanying it. +\"\"\" +""" diff --git a/setup.py b/setup.py index faccb97..16a007e 100644 --- a/setup.py +++ b/setup.py @@ -2,18 +2,14 @@ import os import subprocess import types -from datetime import datetime from setuptools import setup, find_packages -BASE_VERSION = '3.0' -SOURCE_DIR = os.path.dirname( - os.path.abspath(__file__) -) -DYSQL_DIR = os.path.join(SOURCE_DIR, 'dysql') -VERSION_FILE = os.path.join(DYSQL_DIR, 'version.py') -HEADER_FILE = os.path.join(SOURCE_DIR, '.pylint-license-header') +BASE_VERSION = "3.0" +SOURCE_DIR = os.path.dirname(os.path.abspath(__file__)) +DYSQL_DIR = os.path.join(SOURCE_DIR, "dysql") +VERSION_FILE = os.path.join(DYSQL_DIR, "version.py") def get_version(): @@ -22,78 +18,74 @@ def get_version(): """ if os.path.exists(VERSION_FILE): # Read version from file - loader = importlib.machinery.SourceFileLoader('dysql_version', VERSION_FILE) + loader = importlib.machinery.SourceFileLoader("dysql_version", VERSION_FILE) version_mod = types.ModuleType(loader.name) loader.exec_module(version_mod) existing_version = version_mod.__version__ # pylint: disable=no-member - print(f'Using existing dysql version: {existing_version}') + print(f"Using existing dysql version: {existing_version}") return existing_version # Generate the version from the base version and the git commit number, then store it in the file try: cmd = subprocess.Popen( args=[ - 'git', - 'rev-list', - '--count', - 'HEAD', + "git", + "rev-list", + "--count", + "HEAD", ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, - encoding='utf8', + encoding="utf8", ) stdout = cmd.communicate()[0] output = stdout.strip() if cmd.returncode == 0: - new_version = '{0}.{1}'.format(BASE_VERSION, output) - print(f'Setting version to {new_version}') + new_version = "{0}.{1}".format(BASE_VERSION, output) + print(f"Setting version to {new_version}") # write the version file if os.path.exists(DYSQL_DIR): - with open(HEADER_FILE, 'r', encoding='utf8') as fobj: - 
-                    header = header.replace('20\d\d', datetime.now().strftime('%Y'))
-                with open(VERSION_FILE, 'w', encoding='utf8') as fobj:
-                    fobj.write(f"{header}\n__version__ = '{new_version}'\n")
+                with open(VERSION_FILE, "w", encoding="utf8") as fobj:
+                    fobj.write(f"__version__ = '{new_version}'\n")
             return new_version
     except Exception as exc:
-        print(f'Could not generate version from git commits: {exc}')
+        print(f"Could not generate version from git commits: {exc}")

     # If all else fails, use development version
-    return f'{BASE_VERSION}.DEVELOPMENT'
+    return f"{BASE_VERSION}.DEVELOPMENT"


-with open(os.path.join(os.path.dirname(__file__), 'README.rst')) as fobj:
+with open(os.path.join(os.path.dirname(__file__), "README.rst")) as fobj:
     long_description = fobj.read().strip()


 setup(
-    name='dy-sql',
+    name="dy-sql",
     version=get_version(),
-    license='MIT',
-    description='Dynamically runs SQL queries and executions.',
+    license="MIT",
+    description="Dynamically runs SQL queries and executions.",
     long_description=long_description,
-    long_description_content_type='text/x-rst',
-    author='Adobe',
-    author_email='noreply@adobe.com',
-    url='https://github.com/adobe/dy-sql',
-    platforms=['Any'],
-    packages=find_packages(exclude=('*test*',)),
+    long_description_content_type="text/x-rst",
+    author="Adobe",
+    author_email="noreply@adobe.com",
+    url="https://github.com/adobe/dy-sql",
+    platforms=["Any"],
+    packages=find_packages(exclude=("*test*",)),
     zip_safe=False,
     install_requires=[
         # SQLAlchemy 2+ is not yet submitted
-        'sqlalchemy<2',
+        "sqlalchemy<2",
         # now using features only found in pydantic 2+
-        'pydantic>=2',
+        "pydantic>=2",
     ],
     classifiers=[
-        'Development Status :: 4 - Beta',
-        'License :: OSI Approved :: MIT License',
-        'Operating System :: OS Independent',
-        'Programming Language :: Python',
-        'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.6',
-        'Programming Language :: Python :: 3.7',
-        'Programming Language :: Python :: 3.8',
+        "Development Status :: 4 - Beta",
+        "License :: OSI Approved :: MIT License",
+        "Operating System :: OS Independent",
+        "Programming Language :: Python",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
     ],
-    )
+)
diff --git a/test_requirements.in b/test_requirements.in
new file mode 100644
index 0000000..4662d09
--- /dev/null
+++ b/test_requirements.in
@@ -0,0 +1,10 @@
+pytest>=6.2.4
+pytest-randomly>=3.10.1
+pytest-cov>=2.12.1
+ruff>=0.1.7
+ruff-lsp>=0.0.45
+# Python 3.8 only supports up to 3.5.0
+pre-commit<3.6
+# Lock major version only
+docker<7
+tox-pyenv
diff --git a/test_requirements.txt b/test_requirements.txt
index 6ecbb54..50e8d0b 100644
--- a/test_requirements.txt
+++ b/test_requirements.txt
@@ -1,6 +1,107 @@
+#
+# This file is autogenerated by pip-compile with Python 3.11
+# by the following command:
+#
+#    pip-compile test_requirements.in
+#
+attrs==23.1.0
+    # via
+    #   cattrs
+    #   lsprotocol
+    #   pytest
+cattrs==23.2.3
+    # via lsprotocol
+certifi==2023.11.17
+    # via requests
+cfgv==3.4.0
+    # via pre-commit
+charset-normalizer==3.3.2
+    # via requests
+coverage==7.3.4
+    # via pytest-cov
+distlib==0.3.8
+    # via virtualenv
+docker==6.1.3
+    # via -r test_requirements.in
+filelock==3.13.1
+    # via
+    #   tox
+    #   virtualenv
+identify==2.5.33
+    # via pre-commit
+idna==3.6
+    # via requests
+iniconfig==2.0.0
+    # via pytest
+lsprotocol==2023.0.0
+    # via
+    #   pygls
+    #   ruff-lsp
+nodeenv==1.8.0
+    # via pre-commit
+packaging==23.2
+    # via
+    #   docker
+    #   pytest
+    #   ruff-lsp
+    #   tox
+platformdirs==4.1.0
+    # via
+    #   tox
+    #   virtualenv
+pluggy==0.13.1
+    # via
+    #   pytest
+    #   tox
+pre-commit==3.5.0
+    # via -r test_requirements.in
+py==1.11.0
+    # via
+    #   pytest
+    #   tox
+pygls==1.2.1
+    # via ruff-lsp
 pytest==6.2.4
-pytest-randomly==3.10.1
+    # via
+    #   -r test_requirements.in
+    #   pytest-cov
+    #   pytest-randomly
 pytest-cov==2.12.1
-pylint>2.10.2
-pylintfileheader==0.3.0
-pycodestyle==2.8.0
+    # via -r test_requirements.in
+pytest-randomly==3.10.1
+    # via -r test_requirements.in
+pyyaml==6.0.1
+    # via pre-commit
+requests==2.31.0
+    # via docker
+ruff==0.1.8
+    # via
+    #   -r test_requirements.in
+    #   ruff-lsp
+ruff-lsp==0.0.48
+    # via -r test_requirements.in
+six==1.16.0
+    # via tox
+toml==0.10.2
+    # via
+    #   pytest
+    #   pytest-cov
+tox==3.28.0
+    # via tox-pyenv
+tox-pyenv==1.1.0
+    # via -r test_requirements.in
+typing-extensions==4.9.0
+    # via ruff-lsp
+urllib3==2.1.0
+    # via
+    #   docker
+    #   requests
+virtualenv==20.25.0
+    # via
+    #   pre-commit
+    #   tox
+websocket-client==1.7.0
+    # via docker
+
+# The following packages are considered to be unsafe in a requirements file:
+# setuptools
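For reference, the copyright enforcement configured in ruff.toml above can be reasoned about as a plain regular-expression check: the CPY rule (flake8-copyright, enabled via preview = true and extend-select) flags any checked file whose opening does not match notice-rgx, with files under min-file-size (200 bytes here, e.g. empty __init__.py files) skipped and setup.py plus dysql/version.py exempted per-file. A rough Python sketch of that check follows — re.match semantics are an assumption; ruff's own matcher is the authority:

    import re

    # The notice-rgx from ruff.toml, unescaped into a Python pattern. Note that
    # [(2)\d{3}]* is a character class repeated zero or more times, so it accepts
    # any run of digits (e.g. "2023") plus the literal characters ( ) { 3 }.
    NOTICE_RGX = re.compile(
        '"""\n'
        "Copyright [(2)\\d{3}]* Adobe\n"
        "All Rights Reserved.\n"
        "\n"
        "NOTICE: Adobe permits you to use, modify, and distribute this file in accordance\n"
        "with the terms of the Adobe license agreement accompanying it.\n"
        '"""'
    )

    header = (
        '"""\n'
        "Copyright 2023 Adobe\n"
        "All Rights Reserved.\n"
        "\n"
        "NOTICE: Adobe permits you to use, modify, and distribute this file in accordance\n"
        "with the terms of the Adobe license agreement accompanying it.\n"
        '"""\n'
    )

    assert NOTICE_RGX.match(header)             # notice present: CPY passes
    assert not NOTICE_RGX.match("import os\n")  # notice missing: CPY flags the file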