Skip to content

Commit

Permalink
add initialize to cache (#598)
Browse files Browse the repository at this point in the history
* add initialize to cache

* initialize only if not none

---------

Co-authored-by: Rajas Bansal <[email protected]>
  • Loading branch information
rajasbansal and rajasbansal committed Oct 13, 2023
1 parent 6db7430 commit 7284987
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 5 deletions.
5 changes: 5 additions & 0 deletions src/autolabel/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ class BaseCache(ABC):
def __init__(self) -> None:
super().__init__()

@abstractmethod
def initialize():
"""initialize the cache. Must be implemented by classes derived from BaseCache."""
pass

@abstractmethod
def lookup(self, entry):
"""abstract method to retrieve a cached entry. Must be implemented by classes derived from BaseCache."""
Expand Down
6 changes: 5 additions & 1 deletion src/autolabel/cache/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ class RedisCache(BaseCache):
"""A cache system implemented with Redis"""

def __init__(self, endpoint: str, db: int = 0):
self.endpoint = endpoint
self.db = db

def initialize(self):
try:
from redis import Redis

self.redis = Redis.from_url(endpoint, db=db)
self.redis = Redis.from_url(self.endpoint, db=self.db)
except ImportError:
raise ImportError(
"redis is required to use the Redis Cache. Please install it with the following command: pip install redis"
Expand Down
4 changes: 2 additions & 2 deletions src/autolabel/cache/sqlalchemy_generation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ class SQLAlchemyGenerationCache(BaseCache):
"""A cache system implemented with SQL Alchemy"""

def __init__(self):
self.engine = create_db_engine()
self.engine = None
self.base = Base
self.session = None
self.initialize()

def initialize(self):
self.engine = create_db_engine()
self.base.metadata.create_all(self.engine)
self.session = sessionmaker(bind=self.engine)()

Expand Down
4 changes: 2 additions & 2 deletions src/autolabel/cache/sqlalchemy_transform_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ class SQLAlchemyTransformCache(BaseCache):
"""

def __init__(self):
self.engine = create_db_engine()
self.engine = None
self.base = Base
self.session = None
self.initialize()

def initialize(self):
self.engine = create_db_engine()
self.base.metadata.create_all(self.engine)
self.session = sessionmaker(bind=self.engine)()

Expand Down
5 changes: 5 additions & 0 deletions src/autolabel/labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ def __init__(
self.generation_cache = None
self.transform_cache = None

if self.generation_cache is not None:
self.generation_cache.initialize()
if self.transform_cache is not None:
self.transform_cache.initialize()

self.console = Console(quiet=not console_output)

self.config = (
Expand Down

0 comments on commit 7284987

Please sign in to comment.