Skip to content

Commit

Permalink
Merge pull request #46 from tamagoko/last_insert_id
Browse files Browse the repository at this point in the history
adding support for including a get_last_insert_id
  • Loading branch information
tamagoko authored Jul 24, 2024
2 parents 66d8a40 + 839bb58 commit 340c321
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 17 deletions.
53 changes: 49 additions & 4 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,9 @@ query returning a dictionary where there are multiple results under each key. No
def get_status_by_name():
return QueryData("SELECT status, name FROM table")
==========
@sqlupdate
~~~~~~~~~~
==========
Handles any SQL that is not a select. This is primarily, but not limited to, ``insert``, ``update``, and ``delete``.


Expand All @@ -418,6 +418,10 @@ Handles any SQL that is not a select. This is primarily, but not limited to, ``i
def insert_items(item_dict):
return QueryData("INSERT INTO", template_params={'in__item_id':item_id_list})
---------------------------------
multiple queries in a transaction
---------------------------------
You can yield multiple QueryData objects. This is done in a transaction and it can be helpful for data integrity or just
a nice clean way to run a set of updates.

Expand All @@ -430,11 +434,16 @@ a nice clean way to run a set of updates.
yield QueryData(f'INSERT INTO table_1 {insert_values_1}', query_params=insert_values_params_1)
yield QueryData(f'INSERT INTO table_2 {insert_values_2}', query_params=insert_values_params_2)
if needed you can assign a callback to be ran after a query or set of queries completes successfully
--------------------------
getting the last insert id
--------------------------
You can assign a callback to be ran after a query or set of queries completes successfully. This is useful when you need
to get the last insert id for a table that has an auto incrementing id field. This allows you to set it as a parameter on
a follow up relational table within the same transaction scope.

.. code-block:: python
@sqlupdate(on_success=_handle_insert_success)
@sqlupdate()
def insert_items_with_callback(item_dict):
insert_values_1, insert_params_1 = TemplateGenerator.values('table1values', _get_values_for_1_from_items(item_dict))
insert_values_2, insert_params_2 = TemplateGenerator.values('table2values', _get_values_for_2_from_items(item_dict))
Expand All @@ -444,6 +453,42 @@ if needed you can assign a callback to be ran after a query or set of queries co
def _handle_insert_success(item_dict):
# callback logic here happens after the transaction is complete
`get_last_insert_id` is a placeholder kwarg that will be automatically overwritten by the sqlupdate decorator at run time.
Therefore, the assigned value in the function definition does not matter.


Using `get_last_insert_id` gives you the most recently set id. You can leverage this for later queries yielded, or you could
use it and set ids in a reference object passed in for access to the ides outside of the sqlupdate function.


.. code-block:: python
@sqlupdate()
def insert_item_with_get_last_insert(get_last_insert_id=None, item_dict):
insert_values, insert_params = TemplateGenerator.values('table1values', _get_values_from_items(item_dict))
yield QueryData(f'INSERT INTO table_1 {insert_values}', query_params=insert_values_params)
last_id = get_last_insert_id()
yield QueryData(f'INSERT INTO related_table_1 (table_1_id, value) VALUES (:table_1_id, :value)',
query_params={'table_1_id': last_id, 'value': 'some_value'})
.. note::
`get_last_insert_id` will get you the last inserted id from the most recently table inserted with an autoincrement.
Be sure to call `get_last_insert_id` right after you yield the query that inserts the record you need the id for.


.. code-block:: python
class Item(BaseModel):
id: int | None = None
name: str
@sqlupdate()
def insert_items_and_update_ids(items: List[Item], get_last_insert_id = None)
for item in items:
yield QueryData("INSERT INTO table (name) VALUES (:name)", query_params={'name': item.name})
last_id = get_last_insert_id()
item.id = last_id
@sqlexists
~~~~~~~~~~
This wraps a SQL query to determine if a row exists or not. If at least one row is returned from the query, it will
Expand Down
39 changes: 27 additions & 12 deletions dysql/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@
)
from .query_utils import get_query_data


logger = logging.getLogger("database")


# Always initialize a database container, it is never set again
_DATABASE_CONTAINER = DatabaseContainerSingleton()

Expand Down Expand Up @@ -182,7 +180,9 @@ def handle_query(*args, **kwargs):


def sqlupdate(
isolation_level="READ_COMMITTED", disable_foreign_key_checks=False, on_success=None
isolation_level="READ_COMMITTED",
disable_foreign_key_checks=False,
on_success=None,
):
"""
:param isolation_level should specify whether we can read data from transactions that are not
Expand All @@ -202,14 +202,22 @@ def sqlupdate(
anything
Examples::
@sqlinsert
def insert_example(key_values)
return "INSERT INTO table(id, value) VALUES (:id, :value)", key_values
@sqlinsert
def delete_example(ids)
return "DELETE FROM table", key_values
@sqlupdate
def insert_example(key_values)
return QueryData("INSERT INTO table(id, value) VALUES (:id, :value)", key_values)
@sqlupdate
def delete_example(ids)
return QueryData("DELETE FROM table WHERE id=:id", { "id": id })
@sqlupdate()
def insert_with_relations(get_last_insert_id = None):
yield QueryData("INSERT INTO table(value) VALUES (:value)", key_values)
id = get_last_insert_id()
yield "INSERT INTO relation_table(id, value) VALUES (:id, :value)", {
"id": id,
"value": "value"
})
"""

def update_wrapper(func):
Expand All @@ -225,7 +233,11 @@ def handle_query(*args, **kwargs):
) as conn_manager:
if disable_foreign_key_checks:
conn_manager.execute_query("SET FOREIGN_KEY_CHECKS=0")

last_insert_method = "get_last_insert_id"
if last_insert_method in inspect.signature(func).parameters:
kwargs[last_insert_method] = lambda: conn_manager.execute_query(
"SELECT LAST_INSERT_ID()"
).scalar()
if inspect.isgeneratorfunction(func):
logger.debug("handling each query before committing transaction")

Expand All @@ -245,6 +257,9 @@ def handle_query(*args, **kwargs):

if disable_foreign_key_checks:
conn_manager.execute_query("SET FOREIGN_KEY_CHECKS=1")

if last_insert_method in kwargs:
del kwargs[last_insert_method]
if on_success:
on_success(*args, **kwargs)

Expand Down
39 changes: 38 additions & 1 deletion dysql/test/test_sql_insert_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,19 @@

@pytest.fixture(name="mock_engine", autouse=True)
def mock_engine_fixture(mock_create_engine):
initial_id = 0

def handle_execute(query=None, args=None):
nonlocal initial_id
if "INSERT INTO get_last(name)" in query.text:
initial_id += 1
if "SELECT LAST_INSERT_ID()" == query.text:
return type("Result", (), {"scalar": lambda: initial_id})
return []

mock_engine = setup_mock_engine(mock_create_engine)
mock_engine.connect().execution_options().execute.side_effect = lambda x, y: []
execute_mock = mock_engine.connect().execution_options().execute
execute_mock.side_effect = handle_execute
return mock_engine


Expand Down Expand Up @@ -140,3 +151,29 @@ def insert_into_single_value(names):
return QueryData(
"INSERT INTO table(name) {values__name_col}", template_params=template_params
)


def test_last_insert_id():
@sqlupdate()
def insert(get_last_insert_id=None):
yield QueryData("INSERT INTO get_last(name) VALUES ('Tom')")
assert get_last_insert_id
assert get_last_insert_id() == 1
yield QueryData("INSERT INTO get_last(name) VALUES ('Jerry')")
assert get_last_insert_id() == 2

insert()


def test_last_insert_id_removed_before_callback():
def callback(**kwargs):
assert "get_last_insert_id" not in kwargs

@sqlupdate()
def insert(get_last_insert_id=None):
assert get_last_insert_id
yield QueryData("INSERT INTO get_last(name) VALUES ('Tom')")
yield QueryData("INSERT INTO get_last(name) VALUES ('Jerry')")
assert get_last_insert_id() == 2

insert()

0 comments on commit 340c321

Please sign in to comment.