From 6c0f65b6b63b223bec1059ecd037697b068f7e63 Mon Sep 17 00:00:00 2001
From: kukushking
Date: Tue, 11 Jul 2023 17:19:16 +0100
Subject: [PATCH] fix: RDS Data API - allow ANSI-compatible identifiers. (#2391)

---
 awswrangler/data_api/rds.py | 33 ++++++++++++++++-----
 tests/unit/test_data_api.py | 57 +++++++++++++++++++++++++++++++++++++
 2 files changed, 83 insertions(+), 7 deletions(-)

diff --git a/awswrangler/data_api/rds.py b/awswrangler/data_api/rds.py
index d59c73feb..e9a5e9d86 100644
--- a/awswrangler/data_api/rds.py
+++ b/awswrangler/data_api/rds.py
@@ -1,6 +1,7 @@
 """RDS Data API Connector."""
 import datetime as dt
 import logging
+import re
 import time
 import uuid
 from decimal import Decimal
@@ -227,6 +228,19 @@ def _get_statement_result(self, request_id: str) -> pd.DataFrame:
         return dataframe
 
 
+def escape_identifier(identifier: str, sql_mode: str = "mysql") -> str:
+    """Escape identifiers. Uses MySQL-compatible backticks by default."""
+    if not isinstance(identifier, str):
+        raise TypeError("SQL identifier must be a string")
+    if re.search(r"\W", identifier):
+        raise TypeError(f"SQL identifier contains invalid characters: {identifier}")
+    if sql_mode == "mysql":
+        return f"`{identifier}`"
+    elif sql_mode == "ansi":
+        return f'"{identifier}"'
+    raise ValueError(f"Unknown SQL MODE: {sql_mode}")
+
+
 def connect(
     resource_arn: str, database: str, secret_arn: str = "", boto3_session: Optional[boto3.Session] = None, **kwargs: Any
 ) -> RdsDataApi:
@@ -271,8 +285,8 @@ def read_sql_query(sql: str, con: RdsDataApi, database: Optional[str] = None) ->
     return con.execute(sql, database=database)
 
 
-def _drop_table(con: RdsDataApi, table: str, database: str, transaction_id: str) -> None:
-    sql = f"DROP TABLE IF EXISTS `{table}`"
+def _drop_table(con: RdsDataApi, table: str, database: str, transaction_id: str, sql_mode: str) -> None:
+    sql = f"DROP TABLE IF EXISTS {escape_identifier(table, sql_mode=sql_mode)}"
     _logger.debug("Drop table query:\n%s", sql)
     con.execute(sql, database=database, transaction_id=transaction_id)
 
@@ -292,9 +306,10 @@ def _create_table(
     index: bool,
     dtype: Optional[Dict[str, str]],
     varchar_lengths: Optional[Dict[str, int]],
+    sql_mode: str,
 ) -> None:
     if mode == "overwrite":
-        _drop_table(con=con, table=table, database=database, transaction_id=transaction_id)
+        _drop_table(con=con, table=table, database=database, transaction_id=transaction_id, sql_mode=sql_mode)
     elif _does_table_exist(con=con, table=table, database=database, transaction_id=transaction_id):
         return
 
@@ -306,8 +321,8 @@ def _create_table(
         varchar_lengths=varchar_lengths,
         converter_func=_data_types.pyarrow2mysql,
     )
-    cols_str: str = "".join([f"`{k}` {v},\n" for k, v in mysql_types.items()])[:-2]
-    sql = f"CREATE TABLE IF NOT EXISTS `{table}` (\n{cols_str})"
+    cols_str: str = "".join([f"{escape_identifier(k, sql_mode=sql_mode)} {v},\n" for k, v in mysql_types.items()])[:-2]
+    sql = f"CREATE TABLE IF NOT EXISTS {escape_identifier(table, sql_mode=sql_mode)} (\n{cols_str})"
     _logger.debug("Create table query:\n%s", sql)
     con.execute(sql, database=database, transaction_id=transaction_id)
 
@@ -388,6 +403,7 @@ def to_sql(
     varchar_lengths: Optional[Dict[str, int]] = None,
     use_column_names: bool = False,
     chunksize: int = 200,
+    sql_mode: str = "mysql",
 ) -> None:
     """
     Insert data using an SQL query on a Data API connection.
@@ -439,19 +455,22 @@ def to_sql(
             index=index,
             dtype=dtype,
             varchar_lengths=varchar_lengths,
+            sql_mode=sql_mode,
         )
 
         if index:
             df = df.reset_index(level=df.index.names)
 
         if use_column_names:
-            insertion_columns = "(" + ", ".join([f"`{col}`" for col in df.columns]) + ")"
+            insertion_columns = (
+                "(" + ", ".join([f"{escape_identifier(col, sql_mode=sql_mode)}" for col in df.columns]) + ")"
+            )
         else:
             insertion_columns = ""
 
         placeholders = ", ".join([f":{col}" for col in df.columns])
 
-        sql = f"""INSERT INTO `{table}` {insertion_columns} VALUES ({placeholders})"""
+        sql = f"INSERT INTO {escape_identifier(table, sql_mode=sql_mode)} {insertion_columns} VALUES ({placeholders})"
         parameter_sets = _generate_parameter_sets(df)
 
         for parameter_sets_chunk in _utils.chunkify(parameter_sets, max_length=chunksize):
diff --git a/tests/unit/test_data_api.py b/tests/unit/test_data_api.py
index a8d7c83b4..a2af98326 100644
--- a/tests/unit/test_data_api.py
+++ b/tests/unit/test_data_api.py
@@ -39,6 +39,13 @@ def mysql_serverless_connector(databases_parameters: Dict[str, Any]) -> "RdsData
         yield con
 
 
+@pytest.fixture
+def postgresql_serverless_connector(databases_parameters: Dict[str, Any]) -> "RdsDataApi":
+    con = create_rds_connector("postgresql_serverless", databases_parameters)
+    with con:
+        yield con
+
+
 def test_connect_redshift_serverless_iam_role(databases_parameters: Dict[str, Any]) -> None:
     workgroup_name = databases_parameters["redshift_serverless"]["workgroup"]
     database = databases_parameters["redshift_serverless"]["database"]
@@ -68,6 +75,16 @@ def mysql_serverless_table(mysql_serverless_connector: "RdsDataApi") -> Iterator
         mysql_serverless_connector.execute(f"DROP TABLE IF EXISTS test.{name}")
 
 
+@pytest.fixture(scope="function")
+def postgresql_serverless_table(postgresql_serverless_connector: "RdsDataApi") -> Iterator[str]:
+    name = f"tbl_{get_time_str_with_random_suffix()}"
+    print(f"Table name: {name}")
+    try:
+        yield name
+    finally:
+        postgresql_serverless_connector.execute(f"DROP TABLE IF EXISTS test.{name}")
+
+
 def test_data_api_redshift_columnless_query(redshift_connector: "RedshiftDataApi") -> None:
     dataframe = wr.data_api.redshift.read_sql_query("SELECT 1", con=redshift_connector)
     unknown_column_indicator = "?column?"
@@ -223,3 +240,43 @@ def test_data_api_mysql_to_sql_mode(
 def test_data_api_exception(mysql_serverless_connector: "RdsDataApi", mysql_serverless_table: str) -> None:
     with pytest.raises(boto3.client("rds-data").exceptions.BadRequestException):
         wr.data_api.rds.read_sql_query("CUPCAKE", con=mysql_serverless_connector)
+
+
+def test_data_api_mysql_ansi(mysql_serverless_connector: "RdsDataApi", mysql_serverless_table: str) -> None:
+    database = "test"
+    frame = pd.DataFrame([[42, "test"]], columns=["id", "name"])
+
+    mysql_serverless_connector.execute("SET SESSION sql_mode='ANSI_QUOTES';")
+
+    wr.data_api.rds.to_sql(
+        df=frame,
+        con=mysql_serverless_connector,
+        table=mysql_serverless_table,
+        database=database,
+        sql_mode="ansi",
+    )
+
+    out_frame = wr.data_api.rds.read_sql_query(
+        f"SELECT name FROM {mysql_serverless_table} WHERE id = 42", con=mysql_serverless_connector
+    )
+    expected_dataframe = pd.DataFrame([["test"]], columns=["name"])
+    assert_pandas_equals(out_frame, expected_dataframe)
+
+
+def test_data_api_postgresql(postgresql_serverless_connector: "RdsDataApi", postgresql_serverless_table: str) -> None:
+    database = "test"
+    frame = pd.DataFrame([[42, "test"]], columns=["id", "name"])
+
+    wr.data_api.rds.to_sql(
+        df=frame,
+        con=postgresql_serverless_connector,
+        table=postgresql_serverless_table,
+        database=database,
+        sql_mode="ansi",
+    )
+
+    out_frame = wr.data_api.rds.read_sql_query(
+        f"SELECT name FROM {postgresql_serverless_table} WHERE id = 42", con=postgresql_serverless_connector
+    )
+    expected_dataframe = pd.DataFrame([["test"]], columns=["name"])
+    assert_pandas_equals(out_frame, expected_dataframe)