diff --git a/udf/anomaly-detection/src/_config.py b/udf/anomaly-detection/src/_config.py index 30c4ade7..ec161c15 100644 --- a/udf/anomaly-detection/src/_config.py +++ b/udf/anomaly-detection/src/_config.py @@ -21,6 +21,8 @@ class ReTrainConf: min_train_size: int = 2000 retrain_freq_hr: int = 24 resume_training: bool = False + model_expiry_sec: int = 86400 # 24 hrs + dedup_expiry_sec: int = 1800 # 30 mins @dataclass diff --git a/udf/anomaly-detection/src/udf/trainer.py b/udf/anomaly-detection/src/udf/trainer.py index 81c3395f..a2b00f4b 100644 --- a/udf/anomaly-detection/src/udf/trainer.py +++ b/udf/anomaly-detection/src/udf/trainer.py @@ -46,7 +46,7 @@ class Trainer: def fetch_prometheus_data(cls, payload: TrainerPayload) -> pd.DataFrame: prometheus_conf = ConfigManager.get_prom_config() if prometheus_conf is None: - _LOGGER.error("Prometheus config is not available") + _LOGGER.error("%s - Prometheus config is not available", payload.uuid) return pd.DataFrame() data_fetcher = PrometheusDataFetcher(prometheus_conf.server) return data_fetcher.fetch_data( @@ -61,7 +61,7 @@ def fetch_druid_data(cls, payload: TrainerPayload) -> pd.DataFrame: druid_conf = ConfigManager.get_druid_config() fetcher_conf = stream_config.druid_fetcher if druid_conf is None: - _LOGGER.error("Druid config is not available") + _LOGGER.error("%s - Druid config is not available", payload.uuid) return pd.DataFrame() data_fetcher = DruidFetcher(url=druid_conf.url, endpoint=druid_conf.endpoint) @@ -79,28 +79,43 @@ def fetch_druid_data(cls, payload: TrainerPayload) -> pd.DataFrame: @classmethod def fetch_data(cls, payload: TrainerPayload) -> pd.DataFrame: + _start_train = time.perf_counter() stream_config = ConfigManager.get_stream_config(payload.config_id) + + _df = pd.DataFrame() if stream_config.source == DataSource.PROMETHEUS: - return cls.fetch_prometheus_data(payload) + _df = cls.fetch_prometheus_data(payload) elif stream_config.source == DataSource.DRUID: - return cls.fetch_druid_data(payload) 
+ _df = cls.fetch_druid_data(payload) + else: + _LOGGER.error( + "%s - Data source is not supported, source: %s, keys: %s", + payload.uuid, + stream_config.source, + payload.composite_keys, + ) + return _df - _LOGGER.error( - "Data source is not supported, source: %s, keys: %s", + _LOGGER.debug( + "%s - Time taken to fetch data from %s: %.3f sec, df shape: %s", + payload.uuid, stream_config.source, - payload.composite_keys, + time.perf_counter() - _start_train, + _df.shape, ) - return pd.DataFrame() + return _df @classmethod - def _is_new_request(cls, redis_client: redis_client_t, payload: TrainerPayload) -> bool: + def _is_new_request( + cls, redis_client: redis_client_t, dedup_expiry: int, payload: TrainerPayload + ) -> bool: _ckeys = ":".join(payload.composite_keys) r_key = f"train::{_ckeys}" value = redis_client.get(r_key) if value: return False - redis_client.setex(r_key, time=REQUEST_EXPIRY, value=1) + redis_client.setex(r_key, time=dedup_expiry, value=1) return True @classmethod @@ -157,6 +172,7 @@ def _train_and_save( model_cfg = numalogic_conf.model preproc_cfgs = numalogic_conf.preprocess + retrain_cfg = ConfigManager.get_retrain_config(payload.config_id) # TODO: filter the metrics here @@ -173,7 +189,7 @@ def _train_and_save( # TODO if one of the models fail to save, delete the previously saved models and transition stage # Save main model - model_registry = RedisRegistry(client=redis_client) + model_registry = RedisRegistry(client=redis_client, ttl=retrain_cfg.model_expiry_sec) try: version = model_registry.save( skeys=skeys, @@ -238,17 +254,17 @@ def _train_and_save( def run(self, keys: List[str], datum: Datum) -> Messages: messages = Messages() redis_client = get_redis_client_from_conf() - payload = TrainerPayload(**orjson.loads(datum.value)) - is_new = self._is_new_request(redis_client, payload) + + retrain_config = ConfigManager.get_retrain_config(payload.config_id) + numalogic_config = ConfigManager.get_numalogic_config(payload.config_id) + + 
is_new = self._is_new_request(redis_client, retrain_config.dedup_expiry_sec, payload) if not is_new: messages.append(Message.to_drop()) return messages - retrain_config = ConfigManager.get_retrain_config(payload.config_id) - numalogic_config = ConfigManager.get_numalogic_config(payload.config_id) - try: df = self.fetch_data(payload) except Exception as err: @@ -275,6 +291,6 @@ def run(self, keys: List[str], datum: Datum) -> Messages: train_df = get_feature_df(df, payload.metrics) self._train_and_save(numalogic_config, payload, redis_client, train_df) - messages.append(Message(keys=keys, value=train_df.to_json())) + messages.append(Message(keys=keys, value=payload.to_json())) return messages