diff --git a/src/autolabel/cache/base.py b/src/autolabel/cache/base.py index 299c4de9..e4a0270a 100644 --- a/src/autolabel/cache/base.py +++ b/src/autolabel/cache/base.py @@ -9,6 +9,11 @@ class BaseCache(ABC): def __init__(self) -> None: super().__init__() + @abstractmethod + def initialize(): + """initialize the cache. Must be implemented by classes derived from BaseCache.""" + pass + @abstractmethod def lookup(self, entry): """abstract method to retrieve a cached entry. Must be implemented by classes derived from BaseCache.""" diff --git a/src/autolabel/cache/redis_cache.py b/src/autolabel/cache/redis_cache.py index bf39f73a..d7e5a848 100644 --- a/src/autolabel/cache/redis_cache.py +++ b/src/autolabel/cache/redis_cache.py @@ -12,10 +12,14 @@ class RedisCache(BaseCache): """A cache system implemented with Redis""" def __init__(self, endpoint: str, db: int = 0): + self.endpoint = endpoint + self.db = db + + def initialize(self): try: from redis import Redis - self.redis = Redis.from_url(endpoint, db=db) + self.redis = Redis.from_url(self.endpoint, db=self.db) except ImportError: raise ImportError( "redis is required to use the Redis Cache. Please install it with the following command: pip install redis" diff --git a/src/autolabel/cache/sqlalchemy_generation_cache.py b/src/autolabel/cache/sqlalchemy_generation_cache.py index 338d4931..3803a37b 100644 --- a/src/autolabel/cache/sqlalchemy_generation_cache.py +++ b/src/autolabel/cache/sqlalchemy_generation_cache.py @@ -16,12 +16,12 @@ class SQLAlchemyGenerationCache(BaseCache): """A cache system implemented with SQL Alchemy""" def __init__(self): - self.engine = create_db_engine() + self.engine = None self.base = Base self.session = None - self.initialize() def initialize(self): + self.engine = create_db_engine() self.base.metadata.create_all(self.engine) self.session = sessionmaker(bind=self.engine)() diff --git a/src/autolabel/cache/sqlalchemy_transform_cache.py b/src/autolabel/cache/sqlalchemy_transform_cache.py index 9a144970..9b7aa54b 100644 --- a/src/autolabel/cache/sqlalchemy_transform_cache.py +++ b/src/autolabel/cache/sqlalchemy_transform_cache.py @@ -20,12 +20,12 @@ class SQLAlchemyTransformCache(BaseCache): """ def __init__(self): - self.engine = create_db_engine() + self.engine = None self.base = Base self.session = None - self.initialize() def initialize(self): + self.engine = create_db_engine() self.base.metadata.create_all(self.engine) self.session = sessionmaker(bind=self.engine)() diff --git a/src/autolabel/labeler.py b/src/autolabel/labeler.py index ad8a3360..d3187419 100644 --- a/src/autolabel/labeler.py +++ b/src/autolabel/labeler.py @@ -82,6 +82,11 @@ def __init__( self.generation_cache = None self.transform_cache = None + if self.generation_cache is not None: + self.generation_cache.initialize() + if self.transform_cache is not None: + self.transform_cache.initialize() + self.console = Console(quiet=not console_output) self.config = (