
Merge pull request #771 from lsst/tickets/DM-37249-v24
DM-37249-v24: backport transaction-level-pooling compatibility to v24
TallJimbo authored Jan 14, 2023
2 parents dde7253 + c517410 commit 01021c5
Showing 20 changed files with 664 additions and 473 deletions.
3 changes: 3 additions & 0 deletions doc/changes/DM-37249.misc.md
@@ -0,0 +1,3 @@
Rework transaction and connection management for compatibility with transaction-level connection pooling on the server.

Butler clients still hold long-lived connections, via delegation to SQLAlchemy's connection pooling, which can handle disconnections transparently most of the time. But we now wrap all temporary table usage and cursor iteration in transactions.
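
As a rough illustration of the cursor-iteration half of this change, here is a minimal sketch (against a throwaway in-memory SQLite database, with a made-up table; it is not Butler code): rows are fetched while an explicit transaction is open and only iterated afterwards, so nothing relies on a live cursor once the connection could be handed back to a transaction-level pooler such as pgbouncer.

```python
import sqlalchemy

# Stand-in engine; the servers this ticket targets are PostgreSQL behind pgbouncer.
engine = sqlalchemy.create_engine("sqlite://")
metadata = sqlalchemy.MetaData()
demo = sqlalchemy.Table(
    "demo", metadata, sqlalchemy.Column("name", sqlalchemy.String, primary_key=True)
)
metadata.create_all(engine)

with engine.begin() as connection:  # one transaction per unit of work
    connection.execute(demo.insert(), [{"name": "a"}, {"name": "b"}])

with engine.connect() as connection:
    with connection.begin():  # consume the cursor while the transaction is open
        rows = connection.execute(demo.select()).mappings().fetchall()

for row in rows:  # safe afterwards: the rows are fully materialized in memory
    print(row["name"])
```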
10 changes: 7 additions & 3 deletions python/lsst/daf/butler/registries/sql.py
@@ -192,7 +192,8 @@ def fromConfig(
writeable=writeable,
)
managerTypes = RegistryManagerTypes.fromConfig(config)
managers = managerTypes.loadRepo(database)
with database.session():
managers = managerTypes.loadRepo(database)
if defaults is None:
defaults = RegistryDefaults()
return cls(database, defaults, managers)
@@ -231,7 +232,8 @@ def dimensions(self) -> DimensionUniverse:

def refresh(self) -> None:
# Docstring inherited from lsst.daf.butler.registry.Registry
self._managers.refresh()
with self._db.transaction():
self._managers.refresh()

@contextlib.contextmanager
def transaction(self, *, savepoint: bool = False) -> Iterator[None]:
@@ -1179,7 +1181,9 @@ def queryDatasetAssociations(
flattenChains=flattenChains,
):
query = storage.select(collectionRecord)
for row in self._db.query(query).mappings():
with self._db.query(query) as sql_result:
sql_mappings = sql_result.mappings().fetchall()
for row in sql_mappings:
dataId = DataCoordinate.fromRequiredValues(
storage.datasetType.dimensions,
tuple(row[name] for name in storage.datasetType.dimensions.required.names),
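The hunk above, like most of those below, now consumes `query()` as a context manager and fetches results before leaving it. A hedged sketch of that pattern with a hypothetical `Database` wrapper (the real `lsst.daf.butler` registry `Database` class has a much richer interface):

```python
import contextlib
from collections.abc import Iterator

import sqlalchemy


class Database:
    """Hypothetical stand-in for the registry database wrapper."""

    def __init__(self, engine: sqlalchemy.engine.Engine):
        self._engine = engine

    @contextlib.contextmanager
    def query(
        self, sql: sqlalchemy.sql.expression.Executable
    ) -> Iterator[sqlalchemy.engine.CursorResult]:
        # Hold a transaction open for the lifetime of the result, so the cursor
        # is never used on a connection a pooler has already reassigned.
        with self._engine.connect() as connection:
            with connection.begin():
                yield connection.execute(sql)


# Usage mirrors the rewritten registry code: fetch inside the block, iterate after.
db = Database(sqlalchemy.create_engine("sqlite://"))
with db.query(sqlalchemy.text("SELECT 1 AS value")) as sql_result:
    rows = sql_result.mappings().fetchall()
```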
10 changes: 7 additions & 3 deletions python/lsst/daf/butler/registry/attributes.py
@@ -79,7 +79,8 @@ def initialize(cls, db: Database, context: StaticTablesContext) -> ButlerAttribu
def get(self, name: str, default: Optional[str] = None) -> Optional[str]:
# Docstring inherited from ButlerAttributeManager.
sql = sqlalchemy.sql.select(self._table.columns.value).where(self._table.columns.name == name)
row = self._db.query(sql).fetchone()
with self._db.query(sql) as sql_result:
row = sql_result.fetchone()
if row is not None:
return row[0]
return default
@@ -119,13 +120,16 @@ def items(self) -> Iterable[Tuple[str, str]]:
self._table.columns.name,
self._table.columns.value,
)
for row in self._db.query(sql):
with self._db.query(sql) as sql_result:
sql_rows = sql_result.fetchall()
for row in sql_rows:
yield row[0], row[1]

def empty(self) -> bool:
# Docstring inherited from ButlerAttributeManager.
sql = sqlalchemy.sql.select(sqlalchemy.sql.func.count()).select_from(self._table)
row = self._db.query(sql).fetchone()
with self._db.query(sql) as sql_result:
row = sql_result.fetchone()
return row[0] == 0

@classmethod
16 changes: 10 additions & 6 deletions python/lsst/daf/butler/registry/bridge/monolithic.py
@@ -185,7 +185,9 @@ def check(self, refs: Iterable[DatasetIdRef]) -> Iterable[DatasetIdRef]:
)
)
)
for row in self._db.query(sql).fetchall():
with self._db.query(sql) as sql_result:
sql_rows = sql_result.fetchall()
for row in sql_rows:
yield byId[row.dataset_id]

@contextmanager
@@ -233,9 +235,8 @@ def join_records(

# Run query, transform results into a list of dicts that we can later
# use to delete.
rows = [
dict(**row, datastore_name=self.datastoreName) for row in self._db.query(info_in_trash).mappings()
]
with self._db.query(info_in_trash) as sql_result:
rows = [dict(**row, datastore_name=self.datastoreName) for row in sql_result.mappings()]

# It is possible for trashed refs to be linked to artifacts that
# are still associated with refs that are not to be trashed. We
@@ -264,7 +265,8 @@ def join_records(
== items_not_in_trash.columns[record_column],
)
)
preserved = {row[record_column] for row in self._db.query(items_to_preserve).mappings()}
with self._db.query(items_to_preserve) as sql_result:
preserved = {row[record_column] for row in sql_result.mappings()}

# Convert results to a tuple of id+info and a record of the artifacts
# that should not be deleted from datastore. The id+info tuple is
@@ -360,7 +362,9 @@ def findDatastores(self, ref: DatasetIdRef) -> Iterable[str]:
.select_from(self._tables.dataset_location)
.where(self._tables.dataset_location.columns.dataset_id == ref.getCheckedId())
)
for row in self._db.query(sql).mappings():
with self._db.query(sql) as sql_result:
sql_rows = sql_result.mappings().fetchall()
for row in sql_rows:
yield row[self._tables.dataset_location.columns.datastore_name]
for name, bridge in self._ephemeral.items():
if ref in bridge:
14 changes: 9 additions & 5 deletions python/lsst/daf/butler/registry/collections/_base.py
@@ -283,9 +283,10 @@ def _load(self, manager: CollectionManager) -> CollectionSearch:
.where(self._table.columns.parent == self.key)
.order_by(self._table.columns.position)
)
return CollectionSearch.fromExpression(
[manager[row._mapping[self._table.columns.child]].name for row in self._db.query(sql)]
)
with self._db.query(sql) as sql_result:
return CollectionSearch.fromExpression(
tuple(manager[row[self._table.columns.child]].name for row in sql_result.mappings())
)


K = TypeVar("K")
@@ -340,7 +341,9 @@ def refresh(self) -> None:
records = []
chains = []
TimespanReprClass = self._db.getTimespanRepresentation()
for row in self._db.query(sql).mappings():
with self._db.query(sql) as sql_result:
sql_rows = sql_result.mappings().fetchall()
for row in sql_rows:
collection_id = row[self._tables.collection.columns[self._collectionIdName]]
name = row[self._tables.collection.columns.name]
type = CollectionType(row["type"])
@@ -465,7 +468,8 @@ def getDocumentation(self, key: Any) -> Optional[str]:
.select_from(self._tables.collection)
.where(self._tables.collection.columns[self._collectionIdName] == key)
)
return self._db.query(sql).scalar()
with self._db.query(sql) as sql_result:
return sql_result.scalar()

def setDocumentation(self, key: Any, doc: Optional[str]) -> None:
# Docstring inherited from CollectionManager.
90 changes: 73 additions & 17 deletions python/lsst/daf/butler/registry/databases/postgresql.py
@@ -96,7 +96,7 @@ def __init__(

@classmethod
def makeEngine(cls, uri: str, *, writeable: bool = True) -> sqlalchemy.engine.Engine:
return sqlalchemy.engine.create_engine(uri)
return sqlalchemy.engine.create_engine(uri, pool_size=1)

@classmethod
def fromEngine(
@@ -110,25 +110,65 @@ def fromEngine(
return cls(engine=engine, origin=origin, namespace=namespace, writeable=writeable)

@contextmanager
def transaction(
def _transaction(
self,
*,
interrupting: bool = False,
savepoint: bool = False,
lock: Iterable[sqlalchemy.schema.Table] = (),
) -> Iterator[None]:
with super().transaction(interrupting=interrupting, savepoint=savepoint, lock=lock):
assert self._session_connection is not None, "Guaranteed to have a connection in transaction"
if not self.isWriteable():
with closing(self._session_connection.connection.cursor()) as cursor:
cursor.execute("SET TRANSACTION READ ONLY")
else:
with closing(self._session_connection.connection.cursor()) as cursor:
# Make timestamps UTC, because we didn't use TIMESTAMPTZ for
# the column type. When we can tolerate a schema change,
# we should change that type and remove this line.
cursor.execute("SET TIME ZONE 0")
yield
for_temp_tables: bool = False,
) -> Iterator[tuple[bool, sqlalchemy.engine.Connection]]:
with super()._transaction(interrupting=interrupting, savepoint=savepoint, lock=lock) as (
is_new,
connection,
):
if is_new:
# pgbouncer with transaction-level pooling (which we aim to
# support) says that SET cannot be used, except for a list of
# "Startup parameters" that includes "timezone" (see
# https://www.pgbouncer.org/features.html#fnref:0). But I
# don't see "timezone" in PostgreSQL's list of parameters
# passed when creating a new connection
# (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS).
# Given that the pgbouncer docs say, "PgBouncer detects their
# changes and so it can guarantee they remain consistent for
# the client", I assume we can use "SET TIMESPAN" and pgbouncer
# will take care of clients that share connections being set
# consistently. And if that assumption is wrong, we should
# still probably be okay, since all clients should be Butler
# clients, and they'll all be setting the same thing.
#
# The "SET TRANSACTION READ ONLY" should also be safe, because
# it only ever acts on the current transaction; I think it's
# not included in pgbouncer's declaration that SET is
# incompatible with transaction-level pooling because
# PostgreSQL actually considers SET TRANSACTION to be a
# fundamentally different statement from SET (they have their
# own distinct doc pages, at least).
if not (self.isWriteable() or for_temp_tables):
# PostgreSQL permits writing to temporary tables inside
# read-only transactions, but it doesn't permit creating
# them.
with closing(connection.connection.cursor()) as cursor:
cursor.execute("SET TRANSACTION READ ONLY")
cursor.execute("SET TIME ZONE 0")
else:
with closing(connection.connection.cursor()) as cursor:
# Make timestamps UTC, because we didn't use TIMESTAMPTZ
# for the column type. When we can tolerate a schema
# change, we should change that type and remove this
# line.
cursor.execute("SET TIME ZONE 0")
yield is_new, connection

@contextmanager
def temporary_table(
self, spec: ddl.TableSpec, name: Optional[str] = None
) -> Iterator[sqlalchemy.schema.Table]:
# Docstring inherited.
with self.transaction(for_temp_tables=True):
with super().temporary_table(spec, name) as table:
yield table

def _lockTables(
self, connection: sqlalchemy.engine.Connection, tables: Iterable[sqlalchemy.schema.Table] = ()
@@ -171,6 +211,22 @@ def _convertExclusionConstraintSpec(
name=self.shrinkDatabaseEntityName("_".join(names)),
)

def _make_temporary_table(
self,
connection: sqlalchemy.engine.Connection,
spec: ddl.TableSpec,
name: Optional[str] = None,
**kwargs: Any,
) -> sqlalchemy.schema.Table:
# Docstring inherited
# Adding ON COMMIT DROP here is really quite defensive: we already
# manually drop the table at the end of the temporary_table context
# manager, and that will usually happen first. But this will guarantee
# that we drop the table at the end of the transaction even if the
# connection lasts longer, and that's good citizenship when connections
# may be multiplexed by e.g. pgbouncer.
return super()._make_temporary_table(connection, spec, name, postgresql_on_commit="DROP", **kwargs)

@classmethod
def getTimespanRepresentation(cls) -> Type[TimespanDatabaseRepresentation]:
# Docstring inherited.
@@ -193,7 +249,7 @@ def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
if column.name not in table.primary_key
}
query = query.on_conflict_do_update(constraint=table.primary_key, set_=data)
with self._connection() as connection:
with self._transaction() as (_, connection):
connection.execute(query, rows)

def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
@@ -207,7 +263,7 @@ def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only:
query = base_insert.on_conflict_do_nothing(constraint=table.primary_key)
else:
query = base_insert.on_conflict_do_nothing()
with self._connection() as connection:
with self._transaction() as (_, connection):
return connection.execute(query, rows).rowcount


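To make the temporary-table changes above concrete, here is a hedged sketch in plain SQLAlchemy (the URI, table name, and columns are made up) of creating a PostgreSQL temporary table with `ON COMMIT DROP`, so nothing leaks onto a connection that pgbouncer may later hand to another client:

```python
import sqlalchemy

metadata = sqlalchemy.MetaData()
scratch = sqlalchemy.Table(
    "scratch_ids",  # hypothetical name
    metadata,
    sqlalchemy.Column("dataset_id", sqlalchemy.BigInteger, primary_key=True),
    prefixes=["TEMPORARY"],        # emits CREATE TEMPORARY TABLE ...
    postgresql_on_commit="DROP",   # ... ON COMMIT DROP
)

engine = sqlalchemy.create_engine("postgresql+psycopg2://user@host/db")  # stand-in URI
with engine.connect() as connection:
    with connection.begin():  # temporary-table creation must sit inside a transaction
        scratch.create(connection)
        connection.execute(scratch.insert(), [{"dataset_id": 1}])
        # ... joins against scratch would go here ...
    # the COMMIT above drops the table, even if the connection outlives this block
```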
27 changes: 15 additions & 12 deletions python/lsst/daf/butler/registry/databases/sqlite.py
@@ -49,16 +49,6 @@ def _onSqlite3Connect(
cursor.execute("PRAGMA busy_timeout = 300000;") # in ms, so 5min (way longer than should be needed)


def _onSqlite3Begin(connection: sqlalchemy.engine.Connection) -> sqlalchemy.engine.Connection:
assert connection.dialect.name == "sqlite"
# Replace pysqlite's buggy transaction handling that never BEGINs with our
# own that does, and tell SQLite to try to acquire a lock as soon as we
# start a transaction (this should lead to more blocking and fewer
# deadlocks).
connection.execute(sqlalchemy.text("BEGIN IMMEDIATE"))
return connection


class SqliteDatabase(Database):
"""An implementation of the `Database` interface for SQLite3.
@@ -190,6 +180,19 @@ def creator() -> sqlite3.Connection:
engine = sqlalchemy.engine.create_engine(uri, creator=creator)

sqlalchemy.event.listen(engine, "connect", _onSqlite3Connect)

def _onSqlite3Begin(connection: sqlalchemy.engine.Connection) -> sqlalchemy.engine.Connection:
assert connection.dialect.name == "sqlite"
# Replace pysqlite's buggy transaction handling that never BEGINs
# with our own that does, and tell SQLite to try to acquire a lock
# as soon as we start a transaction that might involve writes (this
# should lead to more blocking and fewer deadlocks).
if writeable:
connection.execute(sqlalchemy.text("BEGIN IMMEDIATE"))
else:
connection.execute(sqlalchemy.text("BEGIN"))
return connection

sqlalchemy.event.listen(engine, "begin", _onSqlite3Begin)

return engine
@@ -319,7 +322,7 @@ def replace(self, table: sqlalchemy.schema.Table, *rows: dict) -> None:
if column.name not in table.primary_key
}
query = query.on_conflict_do_update(index_elements=table.primary_key, set_=data)
with self._connection() as connection:
with self._transaction() as (_, connection):
connection.execute(query, rows)

def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only: bool = False) -> int:
@@ -331,7 +334,7 @@ def ensure(self, table: sqlalchemy.schema.Table, *rows: dict, primary_key_only:
query = query.on_conflict_do_nothing(index_elements=table.primary_key)
else:
query = query.on_conflict_do_nothing()
with self._connection() as connection:
with self._transaction() as (_, connection):
return connection.execute(query, rows).rowcount

filename: Optional[str]
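The relocated `_onSqlite3Begin` above now issues `BEGIN IMMEDIATE` only for writeable engines and a plain `BEGIN` otherwise. A minimal standalone sketch of that event hook (it assumes, as the real `makeEngine` arranges through its connection creator, that pysqlite's own implicit transaction handling is disabled):

```python
import sqlalchemy

writeable = True  # illustrative; the real flag comes from the registry configuration
engine = sqlalchemy.create_engine("sqlite:///example.db")  # stand-in database


def _on_begin(connection: sqlalchemy.engine.Connection) -> None:
    # pysqlite never emits BEGIN on its own, so issue it explicitly; IMMEDIATE
    # takes SQLite's write lock up front, trading more blocking for fewer deadlocks.
    connection.execute(sqlalchemy.text("BEGIN IMMEDIATE" if writeable else "BEGIN"))


sqlalchemy.event.listen(engine, "begin", _on_begin)
```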
@@ -185,7 +185,9 @@ def refresh(self) -> None:
byName = {}
byId: Dict[DatasetId, ByDimensionsDatasetRecordStorage] = {}
c = self._static.dataset_type.columns
for row in self._db.query(self._static.dataset_type.select()).mappings():
with self._db.query(self._static.dataset_type.select()) as sql_result:
sql_rows = sql_result.mappings().fetchall()
for row in sql_rows:
name = row[c.name]
dimensions = self._dimensions.loadDimensionGraph(row[c.dimensions_key])
calibTableName = row[c.calibration_association_table]
@@ -346,7 +348,8 @@ def getDatasetRef(self, id: DatasetId) -> Optional[DatasetRef]:
.select_from(self._static.dataset)
.where(self._static.dataset.columns.id == id)
)
row = self._db.query(sql).mappings().fetchone()
with self._db.query(sql) as sql_result:
row = sql_result.mappings().fetchone()
if row is None:
return None
recordsForType = self._byId.get(row[self._static.dataset.columns.dataset_type_id])