Skip to content

Commit

Permalink
DEV: Use newer APIs for Python >= 3.11
Browse files Browse the repository at this point in the history
  • Loading branch information
czgdp1807 committed Aug 22, 2024
1 parent ccb2162 commit d1d465f
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions jupyter_cache/cache/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import datetime
import os
from pathlib import Path
import sys
from typing import Any, Dict, List, Optional, Union

from sqlalchemy import JSON, Column, DateTime, Integer, String, Text
Expand Down Expand Up @@ -51,6 +52,12 @@ def get_version(path: Union[str, Path]) -> Optional[str]:
return version_file.read_text().strip()


def datetime_utcnow():
if sys.version_info.minor >= 11:
return lambda : datetime.datetime.now(datetime.UTC)
return datetime.datetime.utcnow


@contextmanager
def session_context(engine: Engine):
"""Open a connection to the database."""
Expand Down Expand Up @@ -128,7 +135,7 @@ class NbProjectRecord(OrmBase):
"""A list of file assets required for the notebook to run."""
exec_data = Column(JSON(), nullable=True)
"""Data on how to execute the notebook."""
created = Column(DateTime, nullable=False, default=lambda: datetime.datetime.now(datetime.UTC) )
created = Column(DateTime, nullable=False, default=datetime_utcnow())
traceback = Column(Text(), nullable=True, default="")
"""A traceback is added if a notebook fails to execute fully."""

Expand Down Expand Up @@ -288,9 +295,9 @@ class NbCacheRecord(OrmBase):
description = Column(String(255), nullable=False, default="")
data = Column(JSON())
"""Extra data, such as the execution time."""
created = Column(DateTime, nullable=False, default=lambda: datetime.datetime.now(datetime.UTC))
created = Column(DateTime, nullable=False, default=datetime_utcnow())
accessed = Column(
DateTime, nullable=False, default=lambda: datetime.datetime.now(datetime.UTC), onupdate=lambda: datetime.datetime.now(datetime.UTC)
DateTime, nullable=False, default=datetime_utcnow(), onupdate=datetime_utcnow()
)

def __repr__(self):
Expand Down Expand Up @@ -368,7 +375,7 @@ def touch(pk, db: Engine):
record = session.query(NbCacheRecord).filter_by(pk=pk).one_or_none()
if record is None:
raise KeyError(f"Cache record not found for NB with PK: {pk}")
record.accessed = datetime.datetime.now(datetime.UTC)
record.accessed = datetime_utcnow()()
session.commit()

def touch_hashkey(hashkey, db: Engine):
Expand All @@ -379,7 +386,7 @@ def touch_hashkey(hashkey, db: Engine):
)
if record is None:
raise KeyError(f"Cache record not found for NB with hashkey: {hashkey}")
record.accessed = datetime.datetime.now(datetime.UTC)
record.accessed = datetime_utcnow()()
session.commit()

@staticmethod
Expand Down

0 comments on commit d1d465f

Please sign in to comment.